You need to sign in or sign up before continuing.
Unverified Commit 492143bf authored by Yineng Zhang's avatar Yineng Zhang Committed by GitHub
Browse files

fix: resolve qwen2 moe weight loader (#1252)

parent 0a97d796
...@@ -401,24 +401,12 @@ class Qwen2MoeForCausalLM(nn.Module): ...@@ -401,24 +401,12 @@ class Qwen2MoeForCausalLM(nn.Module):
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
] ]
expert_params_mapping = [ expert_params_mapping = FusedMoE.make_expert_params_mapping(
# These are the weights for the experts ckpt_gate_proj_name="gate_proj",
# (param_name, weight_name, expert_id, shard_id) ckpt_down_proj_name="down_proj",
( ckpt_up_proj_name="up_proj",
( num_experts=self.config.num_experts,
"experts.w13_weight" )
if weight_name in ["gate_proj", "up_proj"]
else "experts.w2_weight"
),
f"experts.{expert_id}.{weight_name}.weight",
expert_id,
shard_id,
)
for expert_id in range(self.config.num_experts)
for shard_id, weight_name in enumerate(
["gate_proj", "down_proj", "up_proj"]
)
]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
for name, loaded_weight in weights: for name, loaded_weight in weights:
...@@ -458,7 +446,7 @@ class Qwen2MoeForCausalLM(nn.Module): ...@@ -458,7 +446,7 @@ class Qwen2MoeForCausalLM(nn.Module):
weight_loader( weight_loader(
param, param,
loaded_weight, loaded_weight,
weight_name, name,
shard_id=shard_id, shard_id=shard_id,
expert_id=expert_id, expert_id=expert_id,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment