Unverified Commit 8617f867 authored by Yongye Zhu's avatar Yongye Zhu Committed by GitHub
Browse files

[Bugfix] Fix DSV32 weight loading (#38870)


Signed-off-by: default avatarYongye Zhu <zyy1102000@gmail.com>
parent 06fd9ffc
...@@ -184,11 +184,16 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts): ...@@ -184,11 +184,16 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
self.config = vllm_config.model_config.hf_config self.config = vllm_config.model_config.hf_config
self.quant_config = vllm_config.quant_config
self.model = DeepSeekMultiTokenPredictor( self.model = DeepSeekMultiTokenPredictor(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
) )
# Set MoE hyperparameters # Set MoE hyperparameters
self.set_moe_parameters() self.set_moe_parameters()
self.is_fp4_ckpt = (
self.quant_config is not None
and self.quant_config.get_name() == "modelopt_fp4"
)
def set_moe_parameters(self): def set_moe_parameters(self):
self.expert_weights = [] self.expert_weights = []
...@@ -241,10 +246,15 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts): ...@@ -241,10 +246,15 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
("fused_qkv_a_proj", "q_a_proj", 0), ("fused_qkv_a_proj", "q_a_proj", 0),
("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1), ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
# Fused indexer wk + weights_proj ]
if self.is_fp4_ckpt:
# Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj)
indexer_fused_mapping = [
("wk_weights_proj", "wk", 0), ("wk_weights_proj", "wk", 0),
("wk_weights_proj", "weights_proj", 1), ("wk_weights_proj", "weights_proj", 1),
] ]
stacked_params_mapping.extend(indexer_fused_mapping)
expert_params_mapping = SharedFusedMoE.make_expert_params_mapping( expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
self, self,
......
...@@ -625,6 +625,11 @@ class Indexer(nn.Module): ...@@ -625,6 +625,11 @@ class Indexer(nn.Module):
super().__init__() super().__init__()
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.config = config self.config = config
self.quant_config = quant_config
self.is_fp4_ckpt = (
self.quant_config is not None
and self.quant_config.get_name() == "modelopt_fp4"
)
# self.indexer_cfg = config.attn_module_list_cfg[0]["attn_index"] # self.indexer_cfg = config.attn_module_list_cfg[0]["attn_index"]
self.topk_tokens = config.index_topk self.topk_tokens = config.index_topk
self.n_head = config.index_n_heads # 64 self.n_head = config.index_n_heads # 64
...@@ -639,10 +644,13 @@ class Indexer(nn.Module): ...@@ -639,10 +644,13 @@ class Indexer(nn.Module):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.wq_b", prefix=f"{prefix}.wq_b",
) )
if self.is_fp4_ckpt:
# Fused wk + weights_proj: single GEMM producing [head_dim + n_head]. # Fused wk + weights_proj: single GEMM producing [head_dim + n_head].
# weights_proj does not get quantized, so we run both with quant_config=None # weights_proj does not get quantized,
# wk may be upcasted from the default quant; experiments show fusion is always # so we run both with quant_config=None
# faster unless WK proj is in FP4, which is not the case for all known quants. # wk may be upcasted from the default quant;
# experiments show fusion is always faster unless WK proj is in FP4,
# which is not the case for all known quants.
self.wk_weights_proj = MergedColumnParallelLinear( self.wk_weights_proj = MergedColumnParallelLinear(
hidden_size, hidden_size,
[self.head_dim, self.n_head], [self.head_dim, self.n_head],
...@@ -651,6 +659,21 @@ class Indexer(nn.Module): ...@@ -651,6 +659,21 @@ class Indexer(nn.Module):
disable_tp=True, disable_tp=True,
prefix=f"{prefix}.wk_weights_proj", prefix=f"{prefix}.wk_weights_proj",
) )
else:
self.wk = ReplicatedLinear(
hidden_size,
self.head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.wk",
)
self.weights_proj = ReplicatedLinear(
hidden_size,
self.n_head,
bias=False,
quant_config=None,
prefix=f"{prefix}.weights_proj",
)
self.k_norm = LayerNorm(self.head_dim, eps=1e-6) self.k_norm = LayerNorm(self.head_dim, eps=1e-6)
self.softmax_scale = self.head_dim**-0.5 self.softmax_scale = self.head_dim**-0.5
...@@ -691,11 +714,14 @@ class Indexer(nn.Module): ...@@ -691,11 +714,14 @@ class Indexer(nn.Module):
q_pe, q_nope = torch.split( q_pe, q_nope = torch.split(
q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1 q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1
) )
if self.is_fp4_ckpt:
# Fused wk + weights_proj: one GEMM, then split # Fused wk + weights_proj: one GEMM, then split
kw, _ = self.wk_weights_proj(hidden_states) kw, _ = self.wk_weights_proj(hidden_states)
k = kw[:, : self.head_dim] k = kw[:, : self.head_dim]
weights_raw = kw[:, self.head_dim :] weights = kw[:, self.head_dim :]
else:
k, _ = self.wk(hidden_states)
weights, _ = self.weights_proj(hidden_states)
k = self.k_norm(k) k = self.k_norm(k)
k_pe, k_nope = torch.split( k_pe, k_nope = torch.split(
...@@ -726,7 +752,7 @@ class Indexer(nn.Module): ...@@ -726,7 +752,7 @@ class Indexer(nn.Module):
q_scale = q_scale.view(-1, self.n_head, 1) q_scale = q_scale.view(-1, self.n_head, 1)
weights = ( weights = (
weights_raw.unsqueeze(-1) * q_scale * self.softmax_scale * self.n_head**-0.5 weights.unsqueeze(-1) * q_scale * self.softmax_scale * self.n_head**-0.5
) )
weights = weights.squeeze(-1) weights = weights.squeeze(-1)
...@@ -1314,6 +1340,10 @@ class DeepseekV2ForCausalLM( ...@@ -1314,6 +1340,10 @@ class DeepseekV2ForCausalLM(
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.is_fp4_ckpt = (
self.quant_config is not None
and self.quant_config.get_name() == "modelopt_fp4"
)
qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0) qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0)
qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0) qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0)
...@@ -1439,6 +1469,7 @@ class DeepseekV2ForCausalLM( ...@@ -1439,6 +1469,7 @@ class DeepseekV2ForCausalLM(
("qkv_proj", "k_proj", "k"), ("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"), ("qkv_proj", "v_proj", "v"),
] ]
if self.is_fp4_ckpt:
# Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj) # Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj)
indexer_fused_mapping = [ indexer_fused_mapping = [
("wk_weights_proj", "wk", 0), ("wk_weights_proj", "wk", 0),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment