Unverified Commit c2c4f57f authored by Pavani Majety, committed by GitHub

[DeepseekR1-FP4] Add Support for nvidia/DeepSeekR1-FP4 model (#6853)


Signed-off-by: Pavani Majety <pmajety@nvidia.com>
parent 23881fa6
@@ -556,7 +556,8 @@ class FusedMoE(torch.nn.Module):
             loaded_weight = loaded_weight.to(param.data.device)
             if (
-                param.data[expert_id] != 1
+                "compressed" in self.quant_method.__class__.__name__.lower()
+                and param.data[expert_id] != 1
                 and (param.data[expert_id] - loaded_weight).abs() > 1e-5
             ):
                 raise ValueError(
@@ -580,6 +581,23 @@ class FusedMoE(torch.nn.Module):
                 tp_rank=tp_rank,
             )
             return
+        if "ModelOpt" in self.quant_method.__class__.__name__:
+            if "weight_scale_2" in weight_name or "input_scale" in weight_name:
+                self._load_per_tensor_weight_scale(
+                    shard_id=shard_id,
+                    param=param,
+                    loaded_weight=loaded_weight,
+                    expert_id=expert_id,
+                )
+            elif "weight" in weight_name:
+                self._load_model_weight_or_group_weight_scale(
+                    shard_id=shard_id,
+                    shard_dim=shard_dim,
+                    loaded_weight=loaded_weight,
+                    expert_data=expert_data,
+                    tp_rank=tp_rank,
+                )
+            return
         # Case weight scales and zero_points
         if "scale" in weight_name or "zero" in weight_name:
...
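Editor's note: the new ModelOpt branch dispatches purely on the checkpoint tensor name. Per-tensor quantities (`weight_scale_2`, `input_scale`) go through `_load_per_tensor_weight_scale`, while anything containing `weight` (the packed weights and their block scales) goes through the sharded loader. A standalone sketch of that name-based dispatch; `classify_modelopt_moe_tensor` is a hypothetical helper, not part of the diff:

```python
# Illustrative only: mirrors the name-based dispatch in the ModelOpt branch above.
def classify_modelopt_moe_tensor(weight_name: str) -> str:
    if "weight_scale_2" in weight_name or "input_scale" in weight_name:
        return "per_tensor_scale"  # -> _load_per_tensor_weight_scale
    if "weight" in weight_name:  # matches both "weight" and "weight_scale"
        return "sharded_weight_or_block_scale"  # -> _load_model_weight_or_group_weight_scale
    return "other"


assert classify_modelopt_moe_tensor("w13_weight_scale_2") == "per_tensor_scale"
assert classify_modelopt_moe_tensor("w2_weight_scale") == "sharded_weight_or_block_scale"
assert classify_modelopt_moe_tensor("w13_input_scale") == "per_tensor_scale"
```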
 # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py
 import logging
-from typing import Any, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional
 import torch
 from torch.nn.parameter import Parameter
-from sglang.srt.layers.linear import LinearBase, LinearMethodBase
+from sglang.srt.layers.linear import (
+    LinearBase,
+    LinearMethodBase,
+    UnquantizedLinearMethod,
+)
+from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
@@ -15,10 +20,12 @@ from sglang.srt.layers.quantization.base_config import (
 from sglang.srt.layers.quantization.fp8_utils import (
     apply_fp8_linear,
     cutlass_fp8_supported,
+    is_sm100_supported,
 )
 from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
 from sglang.srt.layers.quantization.utils import (
     convert_to_channelwise,
+    is_layer_skipped,
     requantize_with_max_scale,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
@@ -270,9 +277,16 @@ class ModelOptFp4Config(QuantizationConfig):
         )
         is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method
         kv_cache_quant_algo = quant_config["kv_cache_quant_algo"]
+        if not kv_cache_quant_algo:
+            kv_cache_quant_algo = "auto"
         group_size = quant_config["group_size"]
         exclude_modules = quant_config["exclude_modules"]
         if not (group_size and kv_cache_quant_algo and exclude_modules):
+            logger.warning(
+                f"group_size: {group_size},"
+                f"kv_cache_quant_algo: {kv_cache_quant_algo},"
+                f"exclude_modules: {exclude_modules}"
+            )
             raise ValueError(
                 "NVFP4 quantization requires group size and "
                 "kv_cache_quant_algo specified in "
@@ -285,19 +299,30 @@ class ModelOptFp4Config(QuantizationConfig):
             exclude_modules,
         )

+    def is_layer_excluded(self, prefix: str, exclude_modules: list):
+        import regex as re
+
+        for pattern in exclude_modules:
+            regex_str = pattern.replace(".", r"\.").replace("*", r".*")
+            if re.fullmatch(regex_str, prefix):
+                return True
+        return False
+
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
-        if self.exclude_modules and any(
-            module in prefix for module in self.exclude_modules
-        ):
-            return None
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+
         if isinstance(layer, LinearBase):
+            if is_layer_skipped(prefix, self.exclude_modules) or self.is_layer_excluded(
+                prefix, self.exclude_modules
+            ):
+                return UnquantizedLinearMethod()
             return ModelOptFp4LinearMethod(self)
         if self.kv_cache_quant_algo and isinstance(layer, RadixAttention):
             return ModelOptFp8KVCacheMethod(self)
+        elif isinstance(layer, FusedMoE):
+            return ModelOptNvFp4FusedMoEMethod(self)
         return None

     def get_scaled_act_names(self) -> List[str]:
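Editor's note: exclusion now happens inside `get_quant_method`; excluded linear layers fall back to `UnquantizedLinearMethod` instead of returning `None`, and `is_layer_excluded` adds wildcard patterns on top of the plain prefix matching done by `is_layer_skipped`. A minimal sketch of the wildcard matching, using the stdlib `re` module rather than `regex` (equivalent for these simple patterns) and an illustrative exclude pattern:

```python
import re


# Sketch of ModelOptFp4Config.is_layer_excluded's pattern handling.
def matches_exclude(prefix: str, exclude_modules: list) -> bool:
    for pattern in exclude_modules:
        # Escape literal dots, then turn "*" wildcards into ".*".
        regex_str = pattern.replace(".", r"\.").replace("*", r".*")
        if re.fullmatch(regex_str, prefix):
            return True
    return False


print(matches_exclude("model.layers.0.mlp.gate", ["*.mlp.gate"]))             # True
print(matches_exclude("model.layers.0.self_attn.kv_b_proj", ["*.mlp.gate"]))  # False
```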
@@ -461,3 +486,305 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
         if bias is not None:
             out = out + bias
         return out.view(*output_shape)
+
+
+class ModelOptNvFp4FusedMoEMethod:
+    """
+    MoE Method for FP4 Quantization with Blockscales and PerTensorScales
+
+    Args:
+        quant_config: NVFP4 Quant Config
+    """
+
+    def __new__(cls, *args, **kwargs):
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
+
+        if not hasattr(cls, "_initialized"):
+            original_init = cls.__init__
+            new_cls = type(
+                cls.__name__,
+                (FusedMoEMethodBase,),
+                {
+                    "__init__": original_init,
+                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
+                },
+            )
+            obj = super(new_cls, new_cls).__new__(new_cls)
+            obj.__init__(*args, **kwargs)
+            return obj
+        return super().__new__(cls)
+
+    def __init__(self, quant_config: ModelOptFp4Config):
+        self.quant_config = quant_config
+        if not is_sm100_supported():
+            raise ValueError(
+                "Current platform does not support NVFP4"
+                " quantization. Please use Blackwell and"
+                " above."
+            )
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size_per_partition: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        if not self.quant_config.is_checkpoint_nvfp4_serialized:
+            raise ValueError(
+                "NVFP4 quantization was selected, "
+                " dynamic quantization is not supported."
+            )
+        layer.num_experts = num_experts
+        layer.params_dtype = params_dtype
+        layer.quant_config = self.quant_config
+        weight_dtype = torch.uint8
+        weight_scale_dtype = torch.float8_e4m3fn
+        weight_loader = extra_weight_attrs.get("weight_loader")
+        # GEMM 1
+        w13_weight = ModelWeightParameter(
+            data=torch.empty(
+                num_experts,
+                2 * intermediate_size_per_partition,
+                # 2 fp4 items are packed in the input dimension
+                hidden_size // 2,
+                dtype=weight_dtype,
+            ),
+            input_dim=1,
+            output_dim=2,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("w13_weight", w13_weight)
+
+        # GEMM 2
+        w2_weight = ModelWeightParameter(
+            data=torch.empty(
+                num_experts,
+                hidden_size,
+                # 2 fp4 items are packed in the input dimension
+                intermediate_size_per_partition // 2,
+                dtype=weight_dtype,
+            ),
+            input_dim=1,
+            output_dim=2,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("w2_weight", w2_weight)
+
+        w13_weight_scale = ModelWeightParameter(
+            data=torch.empty(
+                num_experts,
+                2 * intermediate_size_per_partition,
+                # 2 fp4 items are packed in the input dimension
+                hidden_size // self.quant_config.group_size,
+                dtype=weight_scale_dtype,
+            ),
+            input_dim=1,
+            output_dim=2,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("w13_weight_scale", w13_weight_scale)
+
+        w2_weight_scale = ModelWeightParameter(
+            data=torch.empty(
+                num_experts,
+                hidden_size,
+                # 2 fp4 items are packed in the input dimension
+                intermediate_size_per_partition // self.quant_config.group_size,
+                dtype=weight_scale_dtype,
+            ),
+            input_dim=1,
+            output_dim=2,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("w2_weight_scale", w2_weight_scale)
+
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
+
+        extra_weight_attrs.update(
+            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
+        )
+
+        w13_weight_scale_2 = PerTensorScaleParameter(
+            data=torch.empty(num_experts, 2, dtype=torch.float32),
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)
+
+        w2_weight_scale_2 = PerTensorScaleParameter(
+            data=torch.empty(num_experts, dtype=torch.float32),
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)
+
+        extra_weight_attrs.update(
+            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
+        )
+
+        w13_input_scale = PerTensorScaleParameter(
+            data=torch.empty(num_experts, 2, dtype=torch.float32),
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("w13_input_scale", w13_input_scale)
+
+        w2_input_scale = PerTensorScaleParameter(
+            data=torch.empty(num_experts, dtype=torch.float32),
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("w2_input_scale", w2_input_scale)
+
+    def swizzle_blockscale(self, scale: torch.tensor):
+        assert scale.dtype == torch.float8_e4m3fn
+        # Pad and blockwise interleave weight_scale
+        scale_ndim = scale.ndim
+        if scale.ndim == 2:
+            scale = scale.unsqueeze(0)
+        assert scale.ndim == 3
+        B, M, K = scale.shape
+        round_up_multiple = lambda x, m: (x + m - 1) // m * m
+        M_padded = round_up_multiple(M, 128)
+        K_padded = round_up_multiple(K, 4)
+        padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
+        padded_scale[:B, :M, :K] = scale
+        batches, rows, cols = padded_scale.shape
+        assert rows % 128 == 0
+        assert cols % 4 == 0
+        padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32, cols // 4, 4)
+        swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
+        swizzled_scale = swizzled_scale.contiguous().cuda()
+        return (
+            swizzled_scale.reshape(M, K)
+            if scale_ndim == 2
+            else swizzled_scale.reshape(B, M, K)
+        )
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # GEMM 1
+        if not torch.allclose(
+            layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
+        ):
+            logger.warning_once(
+                "w1_weight_scale_2 must match w3_weight_scale_2. "
+                "Accuracy may be affected."
+            )
+
+        w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
+        layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)
+
+        w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
+        layer.g1_alphas = Parameter(
+            (w13_input_scale * w13_weight_scale_2).to(torch.float32),
+            requires_grad=False,
+        )
+
+        assert (
+            layer.w13_weight_scale.shape[2] % 16 == 0
+        ), "Expected weight_scale.dim(1) to be divisible by 16"
+        assert (
+            layer.w13_weight_scale.dtype == torch.float8_e4m3fn
+        ), "Weight Blockscale must be represented as FP8-E4M3"
+        w13_blockscale_swizzled = self.swizzle_blockscale(layer.w13_weight_scale)
+        layer.w13_blockscale_swizzled = Parameter(
+            w13_blockscale_swizzled, requires_grad=False
+        )
+
+        # This is for quantization, so we need to invert it.
+        layer.w13_input_scale_quant = Parameter(
+            (1 / w13_input_scale).to(torch.float32), requires_grad=False
+        )
+
+        layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
+
+        # GEMM 2
+        layer.g2_alphas = Parameter(
+            (layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
+            requires_grad=False,
+        )
+
+        # This is for quantization, so we need to invert it.
+        layer.w2_input_scale_quant = Parameter(
+            (1 / layer.w2_input_scale).to(torch.float32), requires_grad=False
+        )
+
+        assert (
+            layer.w2_weight_scale.shape[2] % 16 == 0
+        ), "Expected weight_scale.dim(1) to be divisible by 16"
+        assert (
+            layer.w2_weight_scale.dtype == torch.float8_e4m3fn
+        ), "Weight Blockscale must be represented as FP8-E4M3"
+        w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale)
+        layer.w2_blockscale_swizzled = Parameter(
+            w2_blockscale_swizzled, requires_grad=False
+        )
+        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
+
+        device = layer.w13_weight.device
+        layer.cutlass_moe_params = CutlassMoEParams(
+            CutlassMoEType.BlockscaledFP4,
+            device,
+            num_experts=layer.num_experts,
+            intermediate_size_per_partition=layer.w2_weight.shape[2] * 2,  # n
+            hidden_size=layer.w13_weight.shape[2] * 2,  # k
+        )
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        num_fused_shared_experts: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
+        inplace: bool = True,
+        no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
+    ) -> torch.Tensor:
+        assert activation == "silu", "Only SiLU activation is supported."
+
+        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+        from sglang.srt.layers.moe.topk import select_experts
+
+        topk_weights, topk_ids = select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            num_fused_shared_experts=num_fused_shared_experts,
+            custom_routing_function=custom_routing_function,
+            correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
+        )
+
+        from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
+
+        return cutlass_moe_fp4(
+            a=x,
+            a1_gscale=layer.w13_input_scale_quant,
+            w1_fp4=layer.w13_weight,
+            w1_blockscale=layer.w13_blockscale_swizzled,
+            w1_alphas=layer.g1_alphas,
+            a2_gscale=layer.w2_input_scale_quant,
+            w2_fp4=layer.w2_weight,
+            w2_blockscale=layer.w2_blockscale_swizzled,
+            w2_alphas=layer.g2_alphas,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            params=layer.cutlass_moe_params,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+        ).to(x.dtype)
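Editor's note: NVFP4 packs two 4-bit values per `uint8` byte and stores one FP8 (E4M3) scale per `group_size` (16) elements along the input dimension, so the buffers created in `create_weights` are easy to sanity-check. A quick shape walk-through with illustrative sizes (not taken from the actual DeepSeek-R1 configuration):

```python
# Illustrative sizes only; the real dimensions come from the model's HF config.
num_experts = 8
hidden_size = 7168
intermediate_size_per_partition = 2048
group_size = 16  # NVFP4 block size along the input dimension

# w13 holds gate_proj and up_proj fused together, so the output dim is doubled.
w13_weight_shape = (num_experts, 2 * intermediate_size_per_partition, hidden_size // 2)
w13_scale_shape = (
    num_experts,
    2 * intermediate_size_per_partition,
    hidden_size // group_size,
)

print(w13_weight_shape)  # (8, 4096, 3584) packed uint8, two fp4 values per byte
print(w13_scale_shape)   # (8, 4096, 448) float8_e4m3fn block scales
```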
@@ -1746,7 +1746,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             global_server_args_dict["disable_shared_experts_fusion"] = False
             log_info_on_rank0(
                 logger,
-                "Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled.",
+                "Deepseek V3/R1 with fp8/fp4 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled.",
             )

     def get_input_embeddings(self) -> nn.Embedding:
@@ -1926,6 +1926,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                     self_attn.use_deep_gemm_bmm = True

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
         if is_nextn:
             if hasattr(self.config, "num_nextn_predict_layers"):
                 num_nextn_layers = self.config.num_nextn_predict_layers
@@ -1982,6 +1983,21 @@ class DeepseekV2ForCausalLM(nn.Module):
                     "up_proj.qzeros",
                     "up_proj.scales",
                 ]
+            elif self.quant_config.get_name() == "modelopt_fp4":
+                suffix_list = [
+                    "down_proj.weight",
+                    "down_proj.weight_scale",
+                    "down_proj.weight_scale_2",
+                    "down_proj.input_scale",
+                    "gate_proj.weight",
+                    "gate_proj.weight_scale",
+                    "gate_proj.weight_scale_2",
+                    "gate_proj.input_scale",
+                    "up_proj.weight",
+                    "up_proj.weight_scale",
+                    "up_proj.weight_scale_2",
+                    "up_proj.input_scale",
+                ]
             else:
                 raise ValueError(
                     f"Unsupported shared expert fusion for quantization: {self.quant_config.get_name()}."
@@ -2125,7 +2141,6 @@ class DeepseekV2ForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
                 if fuse_qkv_a_proj and (
                     "q_a_proj" in name or "kv_a_proj_with_mqa" in name
                 ):
@@ -2151,9 +2166,12 @@ class DeepseekV2ForCausalLM(nn.Module):
                         fused_weight = torch.cat(
                             [q_a_proj_weight, kv_a_proj_weight], dim=0
                         )
-                        param_name = name.replace(
-                            "q_a_proj", "fused_qkv_a_proj_with_mqa"
-                        )
+                        param_name = (
+                            name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
+                            if "q_a_proj" in name
+                            else name.replace(
+                                "kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa"
+                            )
+                        )
                         param = params_dict[param_name]
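Editor's note: the fused-parameter rename now handles both halves of the pair, so `q_a_proj` and `kv_a_proj_with_mqa` weights both map onto the single `fused_qkv_a_proj_with_mqa` parameter. A standalone sketch of that mapping; the helper name is made up for illustration:

```python
# Illustrative helper mirroring the renaming logic in the diff above.
def fused_param_name(name: str) -> str:
    if "q_a_proj" in name:
        return name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
    return name.replace("kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa")


print(fused_param_name("model.layers.3.self_attn.q_a_proj.weight"))
# model.layers.3.self_attn.fused_qkv_a_proj_with_mqa.weight
print(fused_param_name("model.layers.3.self_attn.kv_a_proj_with_mqa.weight"))
# model.layers.3.self_attn.fused_qkv_a_proj_with_mqa.weight
```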
@@ -2164,6 +2182,16 @@ class DeepseekV2ForCausalLM(nn.Module):
                         cached_a_proj.pop(q_a_proj_name)
                         cached_a_proj.pop(kv_a_proj_name)
                 else:
+                    if (
+                        "k_scale" in name or "v_scale" in name
+                    ) and name not in params_dict:
+                        # modelopt attn kv scale is named differently
+                        if any(scale in name for scale in ["k_scale", "v_scale"]):
+                            name = name.replace("_proj", "attn_mqa")
+                        else:
+                            logger.warning(
+                                f"Unknown scale found in checkpoint: {name}"
+                            )
                     param = params_dict[name]
                     weight_loader = getattr(
                         param, "weight_loader", default_weight_loader
...