Unverified Commit 2dff2e21 authored by Benjamin Chislett's avatar Benjamin Chislett Committed by GitHub
Browse files

[Bugfix] Fix MTP weight loading (#21941)

parent 71470bc4
...@@ -182,6 +182,8 @@ class DeepSeekMTP(nn.Module, SupportsPP): ...@@ -182,6 +182,8 @@ class DeepSeekMTP(nn.Module, SupportsPP):
stacked_params_mapping = [ stacked_params_mapping = [
("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
("fused_qkv_a_proj", "q_a_proj", 0),
("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
] ]
expert_params_mapping = FusedMoE.make_expert_params_mapping( expert_params_mapping = FusedMoE.make_expert_params_mapping(
...@@ -212,6 +214,13 @@ class DeepSeekMTP(nn.Module, SupportsPP): ...@@ -212,6 +214,13 @@ class DeepSeekMTP(nn.Module, SupportsPP):
if (("mlp.experts." in name) and name not in params_dict): if (("mlp.experts." in name) and name not in params_dict):
continue continue
name = name.replace(weight_name, param_name) name = name.replace(weight_name, param_name)
# QKV fusion is optional, fall back to normal
# weight loading if it's not enabled
if ((param_name == "fused_qkv_a_proj")
and name not in params_dict):
continue
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment