"docs/vscode:/vscode.git/clone" did not exist on "62cbde8d41ac39e4b3a1f5bbbbc546cc93f1d84d"
Unverified commit 21176b00, authored by Pavani Majety, committed by GitHub

[Bugfix] Fix weight loading for the original nvidia/Deepseek-R1-FP4 checkpoint (#9940)


Signed-off-by: Pavani Majety <pmajety@nvidia.com>
Co-authored-by: Yineng Zhang <me@zhyncs.com>
Co-authored-by: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
parent 94100294
@@ -642,10 +642,22 @@ class ModelOptFp4Config(QuantizationConfig):
     def is_layer_excluded(self, prefix: str, exclude_modules: list):
         import regex as re
 
+        fused_patterns = ["q_a_proj", "q_b_proj", "kv_a_proj_with_mqa", "kv_b_proj"]
+        prefix_split = prefix.split(".")
         for pattern in exclude_modules:
             regex_str = pattern.replace(".", r"\.").replace("*", r".*")
+            pattern_split = pattern.split(".")
             if re.fullmatch(regex_str, prefix):
                 return True
+            elif (
+                pattern_split[-1] in fused_patterns
+                and pattern_split[-1] in prefix_split[-1]
+            ):
+                # Check if the last part of the excluded pattern is contained in the last part of the prefix.
+                # This handles fused modules like fused_qkv_a_proj_with_mqa that contain q_a_proj and kv_a_proj_with_mqa,
+                # e.g. model.layers.{i}.self_attn.{fused_weight_name}
+                assert len(prefix_split) == 5 and len(pattern_split) == 5
+                return True
         return False
 
     def get_quant_method(
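For reference, here is a minimal standalone sketch of the new matching rule. The exclusion pattern and layer prefixes below are assumed examples, not values taken from the actual nvidia/Deepseek-R1-FP4 config; the point is that the plain wildcard regex misses the fused `fused_qkv_a_proj_with_mqa` weight, while the added last-component check catches it:

```python
# Sketch of the exclusion check above; the example pattern/prefixes are assumptions.
import regex as re

FUSED_PATTERNS = ["q_a_proj", "q_b_proj", "kv_a_proj_with_mqa", "kv_b_proj"]


def is_layer_excluded(prefix: str, exclude_modules: list) -> bool:
    prefix_split = prefix.split(".")
    for pattern in exclude_modules:
        regex_str = pattern.replace(".", r"\.").replace("*", r".*")
        pattern_split = pattern.split(".")
        if re.fullmatch(regex_str, prefix):
            return True  # plain wildcard match (unchanged behavior)
        elif (
            pattern_split[-1] in FUSED_PATTERNS
            and pattern_split[-1] in prefix_split[-1]
        ):
            # Fused weights such as fused_qkv_a_proj_with_mqa do not match the
            # wildcard regex, but their last name component still contains the
            # excluded sub-module name, so treat the layer as excluded.
            return True
    return False


exclude = ["model.layers.*.self_attn.kv_a_proj_with_mqa"]
print(is_layer_excluded("model.layers.0.self_attn.kv_a_proj_with_mqa", exclude))
# True  -> matched by the regex
print(is_layer_excluded("model.layers.0.self_attn.fused_qkv_a_proj_with_mqa", exclude))
# True  -> only matched by the new fused-module check
```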
@@ -1250,8 +1262,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 layer.w13_weight_scale,
             )
-            logger.info_once("Applied flashinfer weight processing for both w13 and w2")
         else:
             # CUTLASS processing - handle w13 and w2 separately

@@ -1268,7 +1278,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
         # Both flashinfer cutlass and regular cutlass use same processing for w2
-        logger.info_once("Applied weight processing for both w13 and w2")
 
         # Set up CUTLASS MoE parameters
         device = layer.w13_weight.device
......
@@ -654,11 +654,13 @@ class ServerArgs:
         ], "The expert parallel size must be 1 or the same as the tensor parallel size"
 
         if self.moe_runner_backend == "flashinfer_trtllm":
-            if not self.disable_shared_experts_fusion:
-                self.disable_shared_experts_fusion = True
-                logger.warning(
-                    "FlashInfer TRTLLM MoE is enabled. --disable-shared-experts-fusion is automatically set."
-                )
+            assert (
+                self.quantization == "modelopt_fp4" or self.quantization == "fp8"
+            ), "modelopt_fp4 quantization is required for Flashinfer TRTLLM MoE"
+            self.disable_shared_experts_fusion = True
+            logger.warning(
+                "FlashInfer TRTLLM MoE is enabled. --disable-shared-experts-fusion is automatically set."
+            )
 
         # DeepEP MoE
         if self.moe_a2a_backend == "deepep":
......
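For context, here is a minimal stand-alone sketch of what the ServerArgs change means in practice. `MoeBackendArgs` is a hypothetical stand-in for the relevant `ServerArgs` fields, not the actual class: with the flashinfer_trtllm MoE runner backend, modelopt_fp4 or fp8 quantization is now required, and shared-experts fusion is always forced off instead of only when the flag was unset.

```python
# Hypothetical stand-in for the relevant ServerArgs fields (illustration only).
import logging
from dataclasses import dataclass
from typing import Optional

logger = logging.getLogger(__name__)


@dataclass
class MoeBackendArgs:
    moe_runner_backend: str = "auto"
    quantization: Optional[str] = None
    disable_shared_experts_fusion: bool = False

    def check_flashinfer_trtllm(self) -> None:
        if self.moe_runner_backend == "flashinfer_trtllm":
            assert self.quantization in ("modelopt_fp4", "fp8"), (
                "modelopt_fp4 or fp8 quantization is required for FlashInfer TRTLLM MoE"
            )
            # Fusion is now forced off unconditionally.
            self.disable_shared_experts_fusion = True
            logger.warning(
                "FlashInfer TRTLLM MoE is enabled. "
                "--disable-shared-experts-fusion is automatically set."
            )


args = MoeBackendArgs(moe_runner_backend="flashinfer_trtllm", quantization="modelopt_fp4")
args.check_flashinfer_trtllm()  # passes; disable_shared_experts_fusion is now True
# MoeBackendArgs(moe_runner_backend="flashinfer_trtllm").check_flashinfer_trtllm()
# would raise AssertionError because no supported quantization is set.
```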