Unverified Commit c2c4f57f authored by Pavani Majety, committed by GitHub

[DeepseekR1-FP4] Add Support for nvidia/DeepSeekR1-FP4 model (#6853)


Signed-off-by: Pavani Majety <pmajety@nvidia.com>
parent 23881fa6
......@@ -556,7 +556,8 @@ class FusedMoE(torch.nn.Module):
loaded_weight = loaded_weight.to(param.data.device)
if (
param.data[expert_id] != 1
"compressed" in self.quant_method.__class__.__name__.lower()
and param.data[expert_id] != 1
and (param.data[expert_id] - loaded_weight).abs() > 1e-5
):
raise ValueError(
......@@ -580,6 +581,23 @@ class FusedMoE(torch.nn.Module):
tp_rank=tp_rank,
)
return
if "ModelOpt" in self.quant_method.__class__.__name__:
if "weight_scale_2" in weight_name or "input_scale" in weight_name:
self._load_per_tensor_weight_scale(
shard_id=shard_id,
param=param,
loaded_weight=loaded_weight,
expert_id=expert_id,
)
elif "weight" in weight_name:
self._load_model_weight_or_group_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank,
)
return
# Case weight scales and zero_points
if "scale" in weight_name or "zero" in weight_name:
......
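For context on the ModelOpt branch added above: a hedged sketch of what a per-tensor scale loader such as _load_per_tensor_weight_scale typically does, inferred only from the call signature and the scale shapes created later in this diff. The body below is an assumption for illustration, not SGLang code.

import torch

def _load_per_tensor_weight_scale_sketch(
    shard_id: str,
    param: torch.nn.Parameter,
    loaded_weight: torch.Tensor,
    expert_id: int,
) -> None:
    # Assumed behavior: w1/w3 keep one scale slot each per expert
    # (shape [num_experts, 2]); w2 keeps a single scale per expert
    # (shape [num_experts]).
    param_data = param.data
    if shard_id in ("w1", "w3"):
        idx = 0 if shard_id == "w1" else 1
        param_data[expert_id][idx] = loaded_weight
    elif shard_id == "w2":
        param_data[expert_id] = loaded_weight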
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py
import logging
from typing import Any, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from sglang.srt.layers.linear import LinearBase, LinearMethodBase
from sglang.srt.layers.linear import (
LinearBase,
LinearMethodBase,
UnquantizedLinearMethod,
)
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
......@@ -15,10 +20,12 @@ from sglang.srt.layers.quantization.base_config import (
from sglang.srt.layers.quantization.fp8_utils import (
apply_fp8_linear,
cutlass_fp8_supported,
is_sm100_supported,
)
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
from sglang.srt.layers.quantization.utils import (
convert_to_channelwise,
is_layer_skipped,
requantize_with_max_scale,
)
from sglang.srt.layers.radix_attention import RadixAttention
......@@ -270,9 +277,16 @@ class ModelOptFp4Config(QuantizationConfig):
)
is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method
kv_cache_quant_algo = quant_config["kv_cache_quant_algo"]
if not kv_cache_quant_algo:
kv_cache_quant_algo = "auto"
group_size = quant_config["group_size"]
exclude_modules = quant_config["exclude_modules"]
if not (group_size and kv_cache_quant_algo and exclude_modules):
logger.warning(
f"group_size: {group_size},"
f"kv_cache_quant_algo: {kv_cache_quant_algo},"
f"exclude_modules: {exclude_modules}"
)
raise ValueError(
"NVFP4 quantization requires group size and "
"kv_cache_quant_algo specified in "
......@@ -285,19 +299,30 @@ class ModelOptFp4Config(QuantizationConfig):
exclude_modules,
)
def is_layer_excluded(self, prefix: str, exclude_modules: list):
import regex as re
for pattern in exclude_modules:
regex_str = pattern.replace(".", r"\.").replace("*", r".*")
if re.fullmatch(regex_str, prefix):
return True
return False
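A quick illustration of the wildcard matching in is_layer_excluded above, using a made-up exclude list (the module names are examples only, not taken from a real checkpoint config):

import regex as re

def is_layer_excluded(prefix, exclude_modules):
    for pattern in exclude_modules:
        # "." is matched literally, "*" becomes a greedy wildcard.
        regex_str = pattern.replace(".", r"\.").replace("*", r".*")
        if re.fullmatch(regex_str, prefix):
            return True
    return False

exclude = ["model.layers.*.self_attn.kv_b_proj", "lm_head"]
print(is_layer_excluded("model.layers.3.self_attn.kv_b_proj", exclude))  # True
print(is_layer_excluded("model.layers.3.mlp.gate_proj", exclude))        # False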
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
if self.exclude_modules and any(
module in prefix for module in self.exclude_modules
):
return None
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
if isinstance(layer, LinearBase):
if is_layer_skipped(prefix, self.exclude_modules) or self.is_layer_excluded(
prefix, self.exclude_modules
):
return UnquantizedLinearMethod()
return ModelOptFp4LinearMethod(self)
if self.kv_cache_quant_algo and isinstance(layer, RadixAttention):
return ModelOptFp8KVCacheMethod(self)
elif isinstance(layer, FusedMoE):
return ModelOptNvFp4FusedMoEMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
......@@ -461,3 +486,305 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
if bias is not None:
out = out + bias
return out.view(*output_shape)
class ModelOptNvFp4FusedMoEMethod:
"""
MoE Method for FP4 Quantization with Blockscales and PerTensorScales
Args:
quant_config: NVFP4 Quant Config
"""
def __new__(cls, *args, **kwargs):
from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
if not hasattr(cls, "_initialized"):
original_init = cls.__init__
new_cls = type(
cls.__name__,
(FusedMoEMethodBase,),
{
"__init__": original_init,
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
},
)
obj = super(new_cls, new_cls).__new__(new_cls)
obj.__init__(*args, **kwargs)
return obj
return super().__new__(cls)
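The __new__ above defers importing FusedMoEMethodBase until instantiation and rebuilds the class with that base injected, which avoids a circular import at module load time. A standalone sketch of the same pattern, with illustrative class names that are not part of SGLang's API:

class LazyBase:
    def tag_base(self):
        return "base method available"

class Method:
    def __new__(cls, *args, **kwargs):
        # Rebuild the class with the lazily obtained base injected, then
        # construct and explicitly initialize an instance of that class.
        new_cls = type(
            cls.__name__,
            (LazyBase,),
            {k: v for k, v in cls.__dict__.items() if k not in ("__dict__", "__weakref__")},
        )
        obj = super(new_cls, new_cls).__new__(new_cls)
        obj.__init__(*args, **kwargs)
        return obj

    def __init__(self, name):
        self.name = name

m = Method("fp4")
print(isinstance(m, LazyBase), m.tag_base(), m.name)  # True base method available fp4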
def __init__(self, quant_config: ModelOptFp4Config):
self.quant_config = quant_config
if not is_sm100_supported():
raise ValueError(
"Current platform does not support NVFP4"
" quantization. Please use Blackwell and"
" above."
)
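A hedged sketch of the kind of check a capability gate like is_sm100_supported usually reduces to; the function name below is hypothetical and not SGLang's API:

import torch

def has_blackwell_capability() -> bool:
    # Assumption: SM100 (Blackwell) corresponds to CUDA compute capability 10.x.
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 10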
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
if not self.quant_config.is_checkpoint_nvfp4_serialized:
raise ValueError(
"NVFP4 quantization was selected, "
" dynamic quantization is not supported."
)
layer.num_experts = num_experts
layer.params_dtype = params_dtype
layer.quant_config = self.quant_config
weight_dtype = torch.uint8
weight_scale_dtype = torch.float8_e4m3fn
weight_loader = extra_weight_attrs.get("weight_loader")
# GEMM 1
w13_weight = ModelWeightParameter(
data=torch.empty(
num_experts,
2 * intermediate_size_per_partition,
# 2 fp4 items are packed in the input dimension
hidden_size // 2,
dtype=weight_dtype,
),
input_dim=1,
output_dim=2,
weight_loader=weight_loader,
)
layer.register_parameter("w13_weight", w13_weight)
# GEMM 2
w2_weight = ModelWeightParameter(
data=torch.empty(
num_experts,
hidden_size,
# 2 fp4 items are packed in the input dimension
intermediate_size_per_partition // 2,
dtype=weight_dtype,
),
input_dim=1,
output_dim=2,
weight_loader=weight_loader,
)
layer.register_parameter("w2_weight", w2_weight)
w13_weight_scale = ModelWeightParameter(
data=torch.empty(
num_experts,
2 * intermediate_size_per_partition,
# 2 fp4 items are packed in the input dimension
hidden_size // self.quant_config.group_size,
dtype=weight_scale_dtype,
),
input_dim=1,
output_dim=2,
weight_loader=weight_loader,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
w2_weight_scale = ModelWeightParameter(
data=torch.empty(
num_experts,
hidden_size,
# 2 fp4 items are packed in the input dimension
intermediate_size_per_partition // self.quant_config.group_size,
dtype=weight_scale_dtype,
),
input_dim=1,
output_dim=2,
weight_loader=weight_loader,
)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
)
w13_weight_scale_2 = PerTensorScaleParameter(
data=torch.empty(num_experts, 2, dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2)
w2_weight_scale_2 = PerTensorScaleParameter(
data=torch.empty(num_experts, dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2)
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
)
w13_input_scale = PerTensorScaleParameter(
data=torch.empty(num_experts, 2, dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("w13_input_scale", w13_input_scale)
w2_input_scale = PerTensorScaleParameter(
data=torch.empty(num_experts, dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("w2_input_scale", w2_input_scale)
def swizzle_blockscale(self, scale: torch.Tensor):
assert scale.dtype == torch.float8_e4m3fn
# Pad and blockwise interleave weight_scale
scale_ndim = scale.ndim
if scale.ndim == 2:
scale = scale.unsqueeze(0)
assert scale.ndim == 3
B, M, K = scale.shape
round_up_multiple = lambda x, m: (x + m - 1) // m * m
M_padded = round_up_multiple(M, 128)
K_padded = round_up_multiple(K, 4)
padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
padded_scale[:B, :M, :K] = scale
batches, rows, cols = padded_scale.shape
assert rows % 128 == 0
assert cols % 4 == 0
padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32, cols // 4, 4)
swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
swizzled_scale = swizzled_scale.contiguous().cuda()
return (
swizzled_scale.reshape(M, K)
if scale_ndim == 2
else swizzled_scale.reshape(B, M, K)
)
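A minimal sketch of the layout math in swizzle_blockscale on a dummy tensor (float32 here purely for illustration; the real path asserts torch.float8_e4m3fn and moves the result to CUDA):

import torch

B, M, K = 2, 200, 10                                # arbitrary example sizes
scale = torch.zeros(B, M, K)                        # stand-in for an fp8 blockscale tensor
round_up = lambda x, m: (x + m - 1) // m * m
M_pad, K_pad = round_up(M, 128), round_up(K, 4)     # 256, 12

padded = torch.zeros(B, M_pad, K_pad)
padded[:, :M, :K] = scale
# View as (B, M_pad//128, 4, 32, K_pad//4, 4) tiles, then interleave
# 128-row x 4-column blocks as expected by the blockscaled FP4 kernel.
tiled = padded.reshape(B, M_pad // 128, 4, 32, K_pad // 4, 4)
swizzled = tiled.permute(0, 1, 4, 3, 2, 5).contiguous()
print(swizzled.shape)  # torch.Size([2, 2, 3, 32, 4, 4])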
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# GEMM 1
if not torch.allclose(
layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]
):
logger.warning_once(
"w1_weight_scale_2 must match w3_weight_scale_2. "
"Accuracy may be affected."
)
w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0]
layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False)
w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
layer.g1_alphas = Parameter(
(w13_input_scale * w13_weight_scale_2).to(torch.float32),
requires_grad=False,
)
assert (
layer.w13_weight_scale.shape[2] % 16 == 0
), "Expected weight_scale.dim(1) to be divisible by 16"
assert (
layer.w13_weight_scale.dtype == torch.float8_e4m3fn
), "Weight Blockscale must be represented as FP8-E4M3"
w13_blockscale_swizzled = self.swizzle_blockscale(layer.w13_weight_scale)
layer.w13_blockscale_swizzled = Parameter(
w13_blockscale_swizzled, requires_grad=False
)
# This is for quantization, so we need to invert it.
layer.w13_input_scale_quant = Parameter(
(1 / w13_input_scale).to(torch.float32), requires_grad=False
)
layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
# GEMM 2
layer.g2_alphas = Parameter(
(layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32),
requires_grad=False,
)
# This is for quantization, so we need to invert it.
layer.w2_input_scale_quant = Parameter(
(1 / layer.w2_input_scale).to(torch.float32), requires_grad=False
)
assert (
layer.w2_weight_scale.shape[2] % 16 == 0
), "Expected weight_scale.dim(1) to be divisible by 16"
assert (
layer.w2_weight_scale.dtype == torch.float8_e4m3fn
), "Weight Blockscale must be represented as FP8-E4M3"
w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale)
layer.w2_blockscale_swizzled = Parameter(
w2_blockscale_swizzled, requires_grad=False
)
layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
device = layer.w13_weight.device
layer.cutlass_moe_params = CutlassMoEParams(
CutlassMoEType.BlockscaledFP4,
device,
num_experts=layer.num_experts,
intermediate_size_per_partition=layer.w2_weight.shape[2] * 2,  # n
hidden_size=layer.w13_weight.shape[2] * 2,  # k
)
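A toy illustration of the scale algebra in process_weights_after_loading (numbers made up): activations are quantized with 1 / input_scale, and each GEMM output is rescaled by alpha = input_scale * weight_scale_2.

import torch

input_scale = torch.tensor([0.02])      # per-tensor activation scale from the checkpoint
weight_scale_2 = torch.tensor([0.05])   # per-tensor global weight scale from the checkpoint

alpha = (input_scale * weight_scale_2).to(torch.float32)   # plays the role of g1_alphas / g2_alphas
input_scale_quant = (1 / input_scale).to(torch.float32)    # plays the role of *_input_scale_quant

print(alpha.item(), input_scale_quant.item())  # approximately 0.001 and 50.0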
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
assert activation == "silu", "Only SiLU activation is supported."
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
from sglang.srt.layers.moe.topk import select_experts
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)
from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
return cutlass_moe_fp4(
a=x,
a1_gscale=layer.w13_input_scale_quant,
w1_fp4=layer.w13_weight,
w1_blockscale=layer.w13_blockscale_swizzled,
w1_alphas=layer.g1_alphas,
a2_gscale=layer.w2_input_scale_quant,
w2_fp4=layer.w2_weight,
w2_blockscale=layer.w2_blockscale_swizzled,
w2_alphas=layer.g2_alphas,
topk_weights=topk_weights,
topk_ids=topk_ids,
params=layer.cutlass_moe_params,
apply_router_weight_on_input=apply_router_weight_on_input,
).to(x.dtype)
......@@ -1746,7 +1746,7 @@ class DeepseekV2ForCausalLM(nn.Module):
global_server_args_dict["disable_shared_experts_fusion"] = False
log_info_on_rank0(
logger,
"Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled.",
"Deepseek V3/R1 with fp8/fp4 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled.",
)
def get_input_embeddings(self) -> nn.Embedding:
......@@ -1926,6 +1926,7 @@ class DeepseekV2ForCausalLM(nn.Module):
self_attn.use_deep_gemm_bmm = True
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
if is_nextn:
if hasattr(self.config, "num_nextn_predict_layers"):
num_nextn_layers = self.config.num_nextn_predict_layers
......@@ -1982,6 +1983,21 @@ class DeepseekV2ForCausalLM(nn.Module):
"up_proj.qzeros",
"up_proj.scales",
]
elif self.quant_config.get_name() == "modelopt_fp4":
suffix_list = [
"down_proj.weight",
"down_proj.weight_scale",
"down_proj.weight_scale_2",
"down_proj.input_scale",
"gate_proj.weight",
"gate_proj.weight_scale",
"gate_proj.weight_scale_2",
"gate_proj.input_scale",
"up_proj.weight",
"up_proj.weight_scale",
"up_proj.weight_scale_2",
"up_proj.input_scale",
]
else:
raise ValueError(
f"Unsupported shared expert fusion for quantization: {self.quant_config.get_name()}."
......@@ -2125,7 +2141,6 @@ class DeepseekV2ForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if fuse_qkv_a_proj and (
"q_a_proj" in name or "kv_a_proj_with_mqa" in name
):
......@@ -2151,9 +2166,12 @@ class DeepseekV2ForCausalLM(nn.Module):
fused_weight = torch.cat(
[q_a_proj_weight, kv_a_proj_weight], dim=0
)
param_name = name.replace(
"q_a_proj", "fused_qkv_a_proj_with_mqa"
param_name = (
name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
if "q_a_proj" in name
else name.replace(
"kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa"
)
)
param = params_dict[param_name]
......@@ -2164,6 +2182,16 @@ class DeepseekV2ForCausalLM(nn.Module):
cached_a_proj.pop(q_a_proj_name)
cached_a_proj.pop(kv_a_proj_name)
else:
if (
"k_scale" in name or "v_scale" in name
) and name not in params_dict:
# modelopt attn kv scale is named differently
if any(scale in name for scale in ["k_scale", "v_scale"]):
name = name.replace("_proj", "attn_mqa")
else:
logger.warning(
f"Unknown scale found in checkpoint: {name}"
)
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
......