Commit 2ceeaafd authored by dongcl

Fix a bug caused by MTP > 1 when using MoE a2a (all-to-all) overlap

parent 040838a0
@@ -205,6 +205,15 @@ class TransformerLayer(MegatronCoreTransformerLayer):
         ]
         return tuple(outputs)

+    def _submodule_shared_expert_forward(self, pre_mlp_layernorm_output):
+        """
+        Performs a forward pass for shared experts.
+        """
+        shared_expert_output = None
+        if self.mlp.use_shared_expert and not self.mlp.shared_expert_overlap:
+            shared_expert_output = self.mlp.shared_experts(pre_mlp_layernorm_output)
+        return shared_expert_output
+
     def _submodule_dispatch_forward(self, tokens_per_expert, permutated_local_input_tokens):
         """
         Dispatches tokens to the appropriate experts based on the router output.
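
The new helper runs the shared experts only when they are configured and when the token dispatcher is not already overlapping them itself. Below is a minimal, self-contained sanity check of that gating logic; it is not part of the commit, and it stubs the layer with SimpleNamespace instead of a real TransformerLayer.

from types import SimpleNamespace

def _submodule_shared_expert_forward(self, pre_mlp_layernorm_output):
    # Same gating as the helper added in the patch: skip entirely unless shared
    # experts exist, and defer to the dispatcher when it overlaps them itself.
    shared_expert_output = None
    if self.mlp.use_shared_expert and not self.mlp.shared_expert_overlap:
        shared_expert_output = self.mlp.shared_experts(pre_mlp_layernorm_output)
    return shared_expert_output

def make_layer(use_shared_expert, shared_expert_overlap):
    # Stand-in for a TransformerLayer: only the attributes the helper touches.
    mlp = SimpleNamespace(
        use_shared_expert=use_shared_expert,
        shared_expert_overlap=shared_expert_overlap,
        shared_experts=lambda x: ("shared", x),
    )
    return SimpleNamespace(mlp=mlp)

assert _submodule_shared_expert_forward(make_layer(False, False), "h") is None
assert _submodule_shared_expert_forward(make_layer(True, True), "h") is None
assert _submodule_shared_expert_forward(make_layer(True, False), "h") == ("shared", "h")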
@@ -234,13 +243,14 @@ class TransformerLayer(MegatronCoreTransformerLayer):
         and optional shared-expert computations.
         """
         shared_expert_output = None
-        if self.mlp.use_shared_expert and not self.mlp.shared_expert_overlap:
-            shared_expert_output = self.mlp.shared_experts(pre_mlp_layernorm_output)
         (dispatched_input, tokens_per_expert) = (
             self.mlp.token_dispatcher.dispatch_postprocess(tokens_per_expert, global_input_tokens)
         )
         expert_output, mlp_bias = self.mlp.experts(dispatched_input, tokens_per_expert)
         expert_output = self.mlp.token_dispatcher.combine_preprocess(expert_output)
+        if self.mlp.use_shared_expert and not self.mlp.shared_expert_overlap:
+            shared_expert_output = self.mlp.shared_experts(pre_mlp_layernorm_output)
         return expert_output, shared_expert_output, mlp_bias

     def _submodule_combine_forward(self, hidden_states):
...
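
Factoring the shared-expert pass into its own submodule lets a schedule run that dense GEMM while the all-to-all for the routed tokens is still in flight. The sketch below is only an illustration of that idea, not code from this repository: launch_dispatch_a2a and wait_dispatch_a2a are hypothetical placeholders for whatever the a2a-overlap scheduler actually exposes, while the dispatcher and expert calls are the ones visible in the diff above.

def moe_forward_with_a2a_overlap(layer, pre_mlp_layernorm_output,
                                 tokens_per_expert, permutated_local_input_tokens):
    # Hypothetical async hooks (placeholders, not part of this commit).
    handle = launch_dispatch_a2a(layer, tokens_per_expert, permutated_local_input_tokens)

    # Overlap window: the shared-expert GEMM only needs the local pre-MLP
    # layernorm output, so it can run while the all-to-all communicates the
    # routed tokens between ranks.
    shared_expert_output = layer._submodule_shared_expert_forward(pre_mlp_layernorm_output)

    global_input_tokens = wait_dispatch_a2a(handle)

    # Post-communication path, mirroring the compound forward in the diff.
    dispatched_input, tokens_per_expert = layer.mlp.token_dispatcher.dispatch_postprocess(
        tokens_per_expert, global_input_tokens
    )
    expert_output, mlp_bias = layer.mlp.experts(dispatched_input, tokens_per_expert)
    expert_output = layer.mlp.token_dispatcher.combine_preprocess(expert_output)
    return expert_output, shared_expert_output, mlp_bias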