Unverified Commit a14654dd authored by Baizhou Zhang, committed by GitHub

Fix weight loading bug for Deepseek v3+nextn (#5684)

parent 5d93a950
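
For context: when `q_lora_rank` is set, the attention module exposes a single `fused_qkv_a_proj_with_mqa` parameter instead of separate `q_a_proj` and `kv_a_proj_with_mqa` projections, so the two checkpoint tensors must be concatenated along the output dimension before loading. The sketch below only illustrates that shape relationship; the dimension values are assumed DeepSeek-V3-like numbers, not taken from this diff.

```python
# Minimal sketch of the weight fusion this commit performs during loading.
# Dimensions are assumed DeepSeek-V3-like values, used only for illustration.
import torch

hidden_size = 7168        # assumed model hidden size
q_lora_rank = 1536        # assumed q_a_proj output dim (low-rank query projection)
kv_lora_rank = 512        # assumed latent KV dim
qk_rope_head_dim = 64     # assumed rope dim carried by kv_a_proj_with_mqa

q_a_proj_weight = torch.randn(q_lora_rank, hidden_size)
kv_a_proj_weight = torch.randn(kv_lora_rank + qk_rope_head_dim, hidden_size)

# Same operation as in the diff: concatenate along the output dimension (dim 0),
# matching the row layout of the fused_qkv_a_proj_with_mqa parameter.
fused_weight = torch.cat([q_a_proj_weight, kv_a_proj_weight], dim=0)
assert fused_weight.shape == (q_lora_rank + kv_lora_rank + qk_rope_head_dim, hidden_size)
```
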
@@ -242,6 +242,12 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
             num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
         )
+        # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
+        fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
+            self.config.q_lora_rank is not None
+        )
+        cached_a_proj = {} if fuse_qkv_a_proj else None
+
         nextn_layer_prefix = "model.layers.0"
         nextn_spec_weight_names = [
             "shared_head.norm",
@@ -313,11 +319,51 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
                     if name.endswith(".bias") and name not in params_dict:
                         continue
-                    param = params_dict[name]
-                    weight_loader = getattr(
-                        param, "weight_loader", default_weight_loader
-                    )
-                    weight_loader(param, loaded_weight)
+                    # Handle fused_qkv_a_proj
+                    if fuse_qkv_a_proj and (
+                        "q_a_proj" in name or "kv_a_proj_with_mqa" in name
+                    ):
+                        cached_a_proj[name] = loaded_weight
+                        q_a_proj_name = (
+                            name
+                            if "q_a_proj" in name
+                            else name.replace("kv_a_proj_with_mqa", "q_a_proj")
+                        )
+                        kv_a_proj_name = (
+                            name
+                            if "kv_a_proj_with_mqa" in name
+                            else name.replace("q_a_proj", "kv_a_proj_with_mqa")
+                        )
+                        # When both q_a_proj and kv_a_proj_with_mqa have been cached, load the fused weight into the parameter
+                        if (
+                            q_a_proj_name in cached_a_proj
+                            and kv_a_proj_name in cached_a_proj
+                        ):
+                            q_a_proj_weight = cached_a_proj[q_a_proj_name]
+                            kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
+                            fused_weight = torch.cat(
+                                [q_a_proj_weight, kv_a_proj_weight], dim=0
+                            )
+                            param_name = name.replace(
+                                "q_a_proj", "fused_qkv_a_proj_with_mqa"
+                            )
+                            param = params_dict[param_name]
+                            weight_loader = getattr(
+                                param, "weight_loader", default_weight_loader
+                            )
+                            weight_loader(param, fused_weight)
+                            cached_a_proj.pop(q_a_proj_name)
+                            cached_a_proj.pop(kv_a_proj_name)
+                    else:
+                        param = params_dict[name]
+                        weight_loader = getattr(
+                            param, "weight_loader", default_weight_loader
+                        )
+                        weight_loader(param, loaded_weight)
         self_attn = self.model.decoder.self_attn
         if hasattr(self_attn.kv_b_proj, "qweight"):
......
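
Because each checkpoint yields `q_a_proj` and `kv_a_proj_with_mqa` as separate tensors, the loader buffers whichever half arrives first and only writes the fused parameter once both halves of a layer are present. Below is a stripped-down, standalone sketch of that cache-and-fuse pattern; the helper name is hypothetical and it handles either arrival order.

```python
# Hypothetical standalone sketch of the cache-and-fuse pattern used in load_weights above.
from typing import Dict, Iterable, Tuple
import torch

def fuse_qkv_a_proj_weights(
    weights: Iterable[Tuple[str, torch.Tensor]],
) -> Dict[str, torch.Tensor]:
    """Pair q_a_proj / kv_a_proj_with_mqa tensors per layer and concat along dim 0."""
    cached: Dict[str, torch.Tensor] = {}
    fused: Dict[str, torch.Tensor] = {}
    for name, weight in weights:
        if "q_a_proj" not in name and "kv_a_proj_with_mqa" not in name:
            continue
        cached[name] = weight
        # Derive both expected names regardless of which half arrived first.
        if "q_a_proj" in name:
            q_name, kv_name = name, name.replace("q_a_proj", "kv_a_proj_with_mqa")
        else:
            q_name, kv_name = name.replace("kv_a_proj_with_mqa", "q_a_proj"), name
        if q_name in cached and kv_name in cached:
            fused_name = q_name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
            fused[fused_name] = torch.cat(
                [cached.pop(q_name), cached.pop(kv_name)], dim=0
            )
    return fused
```

Feeding it the two tensors of a single layer, in either order, yields one `fused_qkv_a_proj_with_mqa` entry whose leading rows come from `q_a_proj`, mirroring the `torch.cat(..., dim=0)` in the diff.
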