Unverified commit 1b769dcc, authored by Jee Jee Li and committed by GitHub
Browse files

[Bugfix] Fix Ernie4_5_MoeForCausalLM shared experts (#21717)


Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
parent 2cc57119
...@@ -109,8 +109,8 @@ class Ernie4_5_MoeMoE(nn.Module): ...@@ -109,8 +109,8 @@ class Ernie4_5_MoeMoE(nn.Module):
layer_idx = extract_layer_index(prefix) layer_idx = extract_layer_index(prefix)
self.layer_idx = layer_idx self.layer_idx = layer_idx
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.moe_num_shared_experts = getattr(config, "moe_num_shared_experts", self.has_shared_experts = (getattr(config, "moe_num_shared_experts", 0)
None) > 0)
if self.tp_size > config.moe_num_experts: if self.tp_size > config.moe_num_experts:
raise ValueError( raise ValueError(
...@@ -137,7 +137,7 @@ class Ernie4_5_MoeMoE(nn.Module): ...@@ -137,7 +137,7 @@ class Ernie4_5_MoeMoE(nn.Module):
prefix=f"{prefix}.experts", prefix=f"{prefix}.experts",
e_score_correction_bias=self.gate.e_score_correction_bias) e_score_correction_bias=self.gate.e_score_correction_bias)
if self.moe_num_shared_experts is not None: if self.has_shared_experts:
intermediate_size = (config.moe_intermediate_size * intermediate_size = (config.moe_intermediate_size *
config.moe_num_shared_experts) config.moe_num_shared_experts)
self.shared_experts = Ernie4_5_MoeMLP( self.shared_experts = Ernie4_5_MoeMLP(
...@@ -153,7 +153,8 @@ class Ernie4_5_MoeMoE(nn.Module): ...@@ -153,7 +153,8 @@ class Ernie4_5_MoeMoE(nn.Module):
orig_shape = hidden_states.shape orig_shape = hidden_states.shape
hidden_dim = hidden_states.shape[-1] hidden_dim = hidden_states.shape[-1]
hidden_states = hidden_states.view(-1, hidden_dim) hidden_states = hidden_states.view(-1, hidden_dim)
if self.moe_num_shared_experts is not None: shared_output = None
if self.has_shared_experts:
shared_output = self.shared_experts(hidden_states) shared_output = self.shared_experts(hidden_states)
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
...@@ -161,7 +162,7 @@ class Ernie4_5_MoeMoE(nn.Module): ...@@ -161,7 +162,7 @@ class Ernie4_5_MoeMoE(nn.Module):
final_hidden_states = self.experts(hidden_states=hidden_states, final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits) router_logits=router_logits)
if self.moe_num_shared_experts is not None and \ if self.has_shared_experts and \
shared_output is not None: shared_output is not None:
final_hidden_states = final_hidden_states + shared_output final_hidden_states = final_hidden_states + shared_output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment