Unverified Commit 15f34013 authored by Charles Chen, committed by GitHub

Fix MTP with Deepseek R1 Fp4 (#7376)

parent d04163b3
@@ -330,6 +330,12 @@ class FusedMoE(torch.nn.Module):
         self.tp_rank = get_tensor_model_parallel_rank()
         self.num_experts = num_experts
         self.expert_map = None
+        if enable_flashinfer_moe and quant_config is None:
+            logger.warning("Disable flashinfer MoE when quantization config is None.")
+            enable_flashinfer_moe = False
+            enable_ep_moe = False
         self.enable_flashinfer_moe = enable_flashinfer_moe
         if enable_ep_moe:
             assert (
......
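The hunk above keeps FusedMoE off the flashinfer path when no quantization config is supplied: both flags are cleared before they are stored, so downstream checks see a consistent state. A minimal sketch of the resulting flag handling, assuming a simplified constructor; FusedMoESketch and its argument list are stand-ins, not the real sglang class:

import logging

logger = logging.getLogger(__name__)


class FusedMoESketch:
    # Simplified stand-in for FusedMoE's flag handling, not the real class.

    def __init__(self, quant_config=None, enable_flashinfer_moe=False, enable_ep_moe=False):
        # The flashinfer MoE path is only taken together with a quantization
        # config; without one, fall back to the default MoE path.
        if enable_flashinfer_moe and quant_config is None:
            logger.warning("Disable flashinfer MoE when quantization config is None.")
            enable_flashinfer_moe = False
            enable_ep_moe = False
        self.enable_flashinfer_moe = enable_flashinfer_moe
        self.enable_ep_moe = enable_ep_moe


# With no quant_config, both flags end up disabled even if requested.
moe = FusedMoESketch(quant_config=None, enable_flashinfer_moe=True, enable_ep_moe=True)
assert moe.enable_flashinfer_moe is False and moe.enable_ep_moe is False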
@@ -44,6 +44,12 @@ class DeepseekModelNextN(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
+            logger.warning(
+                "Overriding DeepseekV3ForCausalLMNextN quant config for modelopt_fp4 Deepseek model."
+            )
+            quant_config = None
         self.vocab_size = config.vocab_size
         self.embed_tokens = VocabParallelEmbedding(
......
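This hunk makes the NextN (MTP) draft module drop a modelopt_fp4 quant config so it is constructed unquantized. A minimal sketch of that override logic, assuming a hypothetical QuantConfig stand-in for sglang's real quantization config object and a helper name (resolve_nextn_quant_config) that is not in the source:

import logging
from dataclasses import dataclass

logger = logging.getLogger(__name__)


@dataclass
class QuantConfig:
    # Hypothetical minimal stand-in; the real config object exposes get_name().
    name: str

    def get_name(self) -> str:
        return self.name


def resolve_nextn_quant_config(quant_config):
    # The NextN / MTP draft module stays unquantized when the main model
    # ships a modelopt_fp4 quantization config.
    if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
        logger.warning(
            "Overriding DeepseekV3ForCausalLMNextN quant config for "
            "modelopt_fp4 Deepseek model."
        )
        return None
    return quant_config


assert resolve_nextn_quant_config(QuantConfig("modelopt_fp4")) is None
assert resolve_nextn_quant_config(QuantConfig("awq")).get_name() == "awq"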
@@ -2201,7 +2201,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                         q_a_proj_weight = cached_a_proj[q_a_proj_name]
                         kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
                         cat_dim = 0
-                        if (
+                        if self.quant_config is not None and (
                             self.quant_config.get_name() == "awq"
                             or self.quant_config.get_name() == "moe_wna16"
                         ):
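This one-line change works together with the NextN override above: once quant_config can be None, calling get_name() on it during weight loading would raise AttributeError, so the check is short-circuited first. A toy illustration (uses_transposed_cat_dim is a hypothetical helper name, not from the source):

def uses_transposed_cat_dim(quant_config) -> bool:
    # Mirrors the guarded condition: only consult get_name() when a
    # quantization config actually exists.
    return quant_config is not None and quant_config.get_name() in ("awq", "moe_wna16")


assert uses_transposed_cat_dim(None) is False  # no AttributeError for unquantized models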
@@ -2232,6 +2232,13 @@ class DeepseekV2ForCausalLM(nn.Module):
                 for scale in ["k_scale", "v_scale"]:
                     if scale in name:
                         name = name.replace(f"{scale[0]}_proj", "attn_mqa")
+                        break
+                if name not in params_dict:
+                    # modelopt ckpt contains not needed weights for MTP module:
+                    # model.decoder.self_attn.attn_mqa.v_scale and
+                    # model.decoder.self_attn.attn_mqa.k_scale
+                    logger.warning(f"{name} not found in params_dict.")
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(
                     param, "weight_loader", default_weight_loader
......
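The final hunk makes load_weights tolerant of the extra attention scale tensors that modelopt checkpoints ship for the MTP module: after k_scale / v_scale names are rewritten to the attn_mqa prefix, any name without a matching parameter is logged and skipped instead of raising KeyError. A condensed sketch of that loop (load_scales and the plain dict of tensors are illustrative, not the real params_dict machinery):

import logging

import torch

logger = logging.getLogger(__name__)


def load_scales(weights, params_dict):
    # Toy version of the k_scale / v_scale handling in load_weights.
    for name, loaded_weight in weights:
        for scale in ["k_scale", "v_scale"]:
            if scale in name:
                # e.g. ...self_attn.k_proj.k_scale -> ...self_attn.attn_mqa.k_scale
                name = name.replace(f"{scale[0]}_proj", "attn_mqa")
                break
        if name not in params_dict:
            # The checkpoint ships scale tensors the model does not use.
            logger.warning(f"{name} not found in params_dict.")
            continue
        params_dict[name].copy_(loaded_weight)


params = {"model.layers.0.self_attn.attn_mqa.k_scale": torch.ones(1)}
weights = [
    ("model.layers.0.self_attn.k_proj.k_scale", torch.tensor([2.0])),
    # No attn_mqa.v_scale parameter exists, so this entry is skipped with a warning.
    ("model.layers.0.self_attn.v_proj.v_scale", torch.tensor([3.0])),
]
load_scales(weights, params)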