Unverified commit c3948ba6, authored by Ke Bao, committed by GitHub

Reorder loop in shared expert weight loading (#5719)

parent 269c457e
@@ -215,11 +215,11 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
                 "up_proj.weight_scale_inv",
             ]
             names_to_remove = []
-            for num_repeat in range(self.n_share_experts_fusion):
-                for suffix in suffix_list:
-                    shared_expert_weight_name = (
-                        f"model.layers.0.mlp.shared_experts.{suffix}"
-                    )
+            for suffix in suffix_list:
+                shared_expert_weight_name = (
+                    f"model.layers.0.mlp.shared_experts.{suffix}"
+                )
+                for num_repeat in range(self.n_share_experts_fusion):
                     weights_list.append(
                         (
                             f"model.layers.0."
@@ -229,7 +229,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
                             weights_dict[shared_expert_weight_name],
                         )
                     )
-                    names_to_remove += [shared_expert_weight_name]
+                names_to_remove += [shared_expert_weight_name]
             weights = [w for w in weights_list if w[0] not in names_to_remove]
         # Params for weights, fp8 weight scales, fp8 activation scales
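The reorder makes the suffix loop the outer loop: shared_expert_weight_name is now built once per suffix instead of once per fused replica, and it is recorded in names_to_remove once rather than n_share_experts_fusion times. Below is a minimal self-contained sketch of the reordered pattern; the config values and weights_dict are hypothetical stand-ins for the model attributes used in the real code, not the actual sglang state.

# Sketch of the reordered loop. n_routed_experts, n_share_experts_fusion,
# and weights_dict are hypothetical stand-ins for the model attributes.
n_routed_experts = 4
n_share_experts_fusion = 2
suffix_list = ["gate_proj.weight", "up_proj.weight"]
weights_dict = {
    f"model.layers.0.mlp.shared_experts.{s}": f"<tensor:{s}>"
    for s in suffix_list
}

weights_list = list(weights_dict.items())  # original checkpoint entries
names_to_remove = []
for suffix in suffix_list:
    # Built once per suffix (previously rebuilt for every replica).
    shared_expert_weight_name = f"model.layers.0.mlp.shared_experts.{suffix}"
    for num_repeat in range(n_share_experts_fusion):
        # Clone the shared expert weight under a routed-expert name.
        weights_list.append(
            (
                f"model.layers.0.mlp.experts."
                f"{n_routed_experts + num_repeat}.{suffix}",
                weights_dict[shared_expert_weight_name],
            )
        )
    # Recorded once per suffix (previously appended once per replica).
    names_to_remove.append(shared_expert_weight_name)

# Drop the original shared-expert entries, keeping only the clones.
weights = [w for w in weights_list if w[0] not in names_to_remove]
for name, _ in weights:
    print(name)  # experts 4 and 5 now carry the shared expert weights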
@@ -1650,11 +1650,11 @@ class DeepseekV2ForCausalLM(nn.Module):
                 desc=f"Cloning {self.n_share_experts_fusion} "
                 "replicas of the shared expert into MoE",
             ):
-                for num_repeat in range(self.n_share_experts_fusion):
-                    for suffix in suffix_list:
-                        shared_expert_weight_name = (
-                            f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
-                        )
+                for suffix in suffix_list:
+                    shared_expert_weight_name = (
+                        f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
+                    )
+                    for num_repeat in range(self.n_share_experts_fusion):
                         weights_list.append(
                             (
                                 f"model.layers.{moe_layer}."
@@ -1664,7 +1664,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                                 weights_dict[shared_expert_weight_name],
                             )
                         )
-                        names_to_remove += [shared_expert_weight_name]
+                    names_to_remove += [shared_expert_weight_name]
             weights = [w for w in weights_list if w[0] not in names_to_remove]
         # Params for weights, fp8 weight scales, fp8 activation scales
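The same reordering is applied per MoE layer in DeepseekV2ForCausalLM. Behavior is unchanged: weights_list.append still runs once per (suffix, replica) pair, and the final filter only tests membership in names_to_remove, where the per-replica duplicate entries were redundant anyway.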