"tests/distributed/synced_batchnorm/unit_test.sh" did not exist on "e12c1ec300c6b7369e17ba733996eacefae462a9"
Commit 7a088067 authored by dongcl

1f1b overlap only supports MoEAlltoAllTokenDispatcher

parent 6dcd0fb8
@@ -12,8 +12,7 @@ from megatron.core.utils import (
 )
 from megatron.core.transformer.moe.moe_layer import MoELayer
 from megatron.core.transformer.transformer_layer import TransformerLayer as MegatronCoreTransformerLayer
 from dcu_megatron.core.transformer.utils import SubmoduleCallables, TransformerLayerSubmoduleCallables
+from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher
 
 class TransformerLayer(MegatronCoreTransformerLayer):
@@ -34,7 +33,10 @@ class TransformerLayer(MegatronCoreTransformerLayer):
         inference_params: Optional[Any] = None,
     ):
-        if not isinstance(self.mlp, MoELayer):
+        if (
+            not isinstance(self.mlp, MoELayer)
+            or not isinstance(self.mlp.token_dispatcher, MoEAlltoAllTokenDispatcher)
+        ):
             return super().forward(
                 hidden_states=hidden_states,
                 context=context,
@@ -55,7 +57,7 @@ class TransformerLayer(MegatronCoreTransformerLayer):
             pre_mlp_layernorm_output,
             tokens_per_expert,
             permutated_local_input_tokens,
-            probs,
+            _,
         ) = self._submodule_attention_router_compound_forward(
             hidden_states,
             attention_mask,
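
To make the new control flow concrete, here is a minimal, self-contained sketch of the guard this commit introduces. Every class here (BaseLayer, OverlappedLayer, MoEAllGatherTokenDispatcher, and the stub MoELayer/MoEAlltoAllTokenDispatcher) is a hypothetical stand-in, not the real Megatron or dcu_megatron type, and the forward signature is simplified; the point is only the isinstance-based early exit, where any layer whose MLP is not a MoE layer with an all-to-all token dispatcher falls back to the unmodified parent forward().

```python
# Hedged sketch of the dispatcher guard; all classes are stand-ins, not
# the real megatron.core / dcu_megatron types.

class MoEAlltoAllTokenDispatcher:       # stand-in for the supported dispatcher
    pass

class MoEAllGatherTokenDispatcher:      # a dispatcher the overlap path does not support
    pass

class MoELayer:
    def __init__(self, dispatcher):
        self.token_dispatcher = dispatcher

class BaseLayer:                        # stand-in for MegatronCoreTransformerLayer
    def forward(self, hidden_states):
        return f"base forward({hidden_states})"

class OverlappedLayer(BaseLayer):       # stand-in for the patched TransformerLayer
    def __init__(self, mlp):
        self.mlp = mlp

    def forward(self, hidden_states):
        # Same shape as the diff: bail out to the unmodified implementation
        # unless both conditions for the 1f1b-overlap schedule hold.
        if (
            not isinstance(self.mlp, MoELayer)
            or not isinstance(self.mlp.token_dispatcher, MoEAlltoAllTokenDispatcher)
        ):
            return super().forward(hidden_states)
        return f"overlapped 1f1b forward({hidden_states})"

print(OverlappedLayer(MoELayer(MoEAlltoAllTokenDispatcher())).forward("x"))   # overlap path
print(OverlappedLayer(MoELayer(MoEAllGatherTokenDispatcher())).forward("x"))  # falls back
print(OverlappedLayer(mlp=object()).forward("x"))                             # falls back
```

The diff also rebinds the `probs` return value of the compound forward to `_`, discarding an output that the overlapped path evidently no longer consumes.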