Commit 040838a0 authored by dongcl

MoE A2A overlap: support SelfAttention; fix a bug when use_shared_expert is False

parent 2a0c4358
@@ -16,6 +16,7 @@ def a2a_overlap_adaptation(patches_manager):
         TELayerNormColumnParallelLinear,
     )
     from ..core.transformer.multi_latent_attention import MLASelfAttention
+    from ..core.transformer.attention import SelfAttention
     from ..core.transformer.mlp import MLP
     from ..core.transformer.moe.experts import TEGroupedMLP
     from ..core.transformer.moe.moe_layer import MoELayer
@@ -61,6 +62,9 @@ def a2a_overlap_adaptation(patches_manager):
     patches_manager.register_patch('megatron.core.transformer.multi_latent_attention.MLASelfAttention.backward_dw',
                                    MLASelfAttention.backward_dw,
                                    create_dummy=True)
+    patches_manager.register_patch('megatron.core.transformer.attention.SelfAttention.backward_dw',
+                                   SelfAttention.backward_dw,
+                                   create_dummy=True)
     patches_manager.register_patch('megatron.core.transformer.mlp.MLP.backward_dw',
                                    MLP.backward_dw,
                                    create_dummy=True)
...
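The `create_dummy=True` flag matters for the new `SelfAttention` patch: presumably upstream Megatron's `SelfAttention` does not define `backward_dw` at all, so the patch has to attach a brand-new method rather than replace an existing one. Below is a minimal sketch of that pattern with a toy manager that takes classes instead of dotted module paths; `PatchesManager` and `SelfAttentionStub` are illustrative names, not MindSpeed's actual API.

```python
class PatchesManager:
    def __init__(self):
        self._patches = []

    def register_patch(self, target_cls, attr_name, replacement, create_dummy=False):
        self._patches.append((target_cls, attr_name, replacement, create_dummy))

    def apply(self):
        for cls, attr, fn, create_dummy in self._patches:
            if not hasattr(cls, attr) and not create_dummy:
                raise AttributeError(f"{cls.__name__} has no attribute {attr}")
            # create_dummy=True lets the patch attach a method the upstream
            # class never defined, e.g. backward_dw on SelfAttention.
            setattr(cls, attr, fn)


class SelfAttentionStub:  # stand-in for megatron's SelfAttention
    pass


def backward_dw(self):
    self.linear_qkv.backward_dw()
    self.linear_proj.backward_dw()


manager = PatchesManager()
manager.register_patch(SelfAttentionStub, "backward_dw", backward_dw, create_dummy=True)
manager.apply()
assert hasattr(SelfAttentionStub, "backward_dw")
```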
@@ -307,6 +307,7 @@ class MoeAttnNode(TransformerLayerNode):
         # detached here
         self.common_state.probs = self.detach(probs)
         self.common_state.residual = self.detach(hidden_states)
-        self.common_state.pre_mlp_layernorm_output = self.detach(pre_mlp_layernorm_output)
+        if self.layer.mlp.use_shared_expert:
+            self.common_state.pre_mlp_layernorm_output = self.detach(pre_mlp_layernorm_output)
         return permutated_local_input_tokens
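`MoeAttnNode` hands tensors to the later nodes of the same layer through `common_state`, detaching them so this node's autograd graph ends at the boundary. The sketch below illustrates that pattern and why `pre_mlp_layernorm_output` is now cached only when a shared expert will consume it; the names `detach_at_boundary`, `CommonState`, and `cache_attn_outputs` are illustrative, not MindSpeed's implementation.

```python
import torch


def detach_at_boundary(t):
    # Cut the autograd graph at a node boundary while returning a tensor that
    # can itself receive gradients when a later node runs its backward pass.
    if t is None:
        return None
    out = t.detach()
    out.requires_grad_(t.requires_grad)
    return out


class CommonState:
    """Plain container shared by the nodes of one transformer layer (toy version)."""


def cache_attn_outputs(state, probs, hidden_states, pre_mlp_layernorm_output, use_shared_expert):
    state.probs = detach_at_boundary(probs)
    state.residual = detach_at_boundary(hidden_states)
    # Cache the layernorm output only if a shared expert will read it later;
    # otherwise there is no reason to keep it alive for the whole schedule.
    if use_shared_expert:
        state.pre_mlp_layernorm_output = detach_at_boundary(pre_mlp_layernorm_output)


state = CommonState()
h = torch.randn(2, 4, requires_grad=True)
cache_attn_outputs(state, torch.softmax(h, dim=-1), h, h * 2.0, use_shared_expert=False)
assert not hasattr(state, "pre_mlp_layernorm_output")
```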
@@ -333,7 +334,10 @@ class MoeDispatchNode(TransformerLayerNode):
 
 class MoeMlPNode(TransformerLayerNode):
     def forward_impl(self, global_input_tokens):
-        pre_mlp_layernorm_output = self.common_state.pre_mlp_layernorm_output
+        if self.layer.mlp.use_shared_expert:
+            pre_mlp_layernorm_output = self.common_state.pre_mlp_layernorm_output
+        else:
+            pre_mlp_layernorm_output = None
         token_dispatcher = self.layer.mlp.token_dispatcher
         with token_dispatcher.per_batch_state_context(self.common_state):
             expert_output, shared_expert_output, mlp_bias = self.layer._submodule_moe_forward(
@@ -343,6 +347,8 @@ class MoeMlPNode(TransformerLayerNode):
         # pre_mlp_layernorm_output used
         self.common_state.pre_mlp_layernorm_output = None
+        if shared_expert_output is None:
+            return expert_output
         return expert_output, shared_expert_output
 
     def dw(self):
@@ -351,7 +357,7 @@ class MoeMlPNode(TransformerLayerNode):
 
 class MoeCombineNode(TransformerLayerNode):
-    def forward_impl(self, expert_output, shared_expert_output):
+    def forward_impl(self, expert_output, shared_expert_output=None):
         # TODO(lhb): if dw uses grads of residual and probs, necessary synchronization should be added
         residual = self.common_state.residual
         token_dispatcher = self.layer.mlp.token_dispatcher
...
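Taken together, the hunks above make the shared-expert path optional end to end: `MoeMlPNode` skips the cached layernorm output when no shared expert exists, returns a single tensor instead of a pair, and `MoeCombineNode` now tolerates a missing `shared_expert_output`. A toy version of that control flow, with plain functions and a tensor add standing in for the real dispatcher/combine logic:

```python
import torch


def mlp_node_forward(expert_output, shared_expert_output):
    # Single-output path when no shared expert ran; two-output path otherwise.
    if shared_expert_output is None:
        return expert_output
    return expert_output, shared_expert_output


def combine_node_forward(expert_output, shared_expert_output=None):
    # Defaulting to None lets the scheduler call this node with either one or
    # two inputs, matching whatever mlp_node_forward produced.
    combined = expert_output
    if shared_expert_output is not None:
        combined = combined + shared_expert_output
    return combined


expert_out = torch.randn(4, 8)
assert combine_node_forward(expert_out).shape == (4, 8)                     # no shared expert
assert combine_node_forward(expert_out, torch.randn(4, 8)).shape == (4, 8)  # with shared expert
```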
@@ -67,6 +67,7 @@ class ScheduleNode:
                 allow_unreachable=True,
                 accumulate_grad=True,
             )
         return output_grad
 
     def forward(self, inputs=(), stream_wait_event=None, stream_record_event=None):
@@ -105,6 +106,7 @@ class ScheduleNode:
         if self.free_inputs:
             for input in inputs:
-                input.record_stream(self.stream)
-                input.untyped_storage().resize_(0)
+                if input is not None:
+                    input.record_stream(self.stream)
+                    input.untyped_storage().resize_(0)
...
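With the single-output MoE path, a node's input tuple can now contain `None` placeholders, so the early-free loop in `ScheduleNode` has to skip them before releasing storage. A standalone sketch of that freeing pattern follows; `free_inputs_early` is a hypothetical name, but `Tensor.record_stream` and `untyped_storage().resize_(0)` are the actual PyTorch calls used in the diff.

```python
import torch


def free_inputs_early(inputs, stream):
    # Release each input's backing storage as soon as the node has consumed it.
    for t in inputs:
        if t is None:  # placeholder from the single-output MoE path
            continue
        # Tell the caching allocator the tensor was used on `stream` so its
        # memory is not reused before work queued there finishes.
        t.record_stream(stream)
        # Free the storage now instead of waiting for Python garbage collection.
        t.untyped_storage().resize_(0)


if torch.cuda.is_available():
    s = torch.cuda.Stream()
    x = torch.randn(1024, device="cuda")
    free_inputs_early((x, None), s)
    assert x.untyped_storage().size() == 0
```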
+class SelfAttention():
+    def backward_dw(self):
+        self.linear_qkv.backward_dw()
+        self.linear_proj.backward_dw()
+
 class MoELayer():
     def backward_dw(self):
         self.experts.backward_dw()
-        self.shared_experts.backward_dw()
+        if self.use_shared_expert and not self.shared_expert_overlap:
+            self.shared_experts.backward_dw()
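The guard added to `MoELayer.backward_dw` is the `use_shared_expert` bug fix from the commit message: when `use_shared_expert` is False the layer has no shared-expert module to step, so calling `backward_dw` on it unconditionally would fail. A stand-in reproduction of the fixed logic, using illustrative classes rather than Megatron's:

```python
class _Expert:
    def backward_dw(self):
        pass


class MoELayerSketch:
    def __init__(self, shared_experts=None, shared_expert_overlap=False):
        self.experts = _Expert()
        self.shared_experts = shared_experts  # None when no shared expert is configured
        self.use_shared_expert = shared_experts is not None
        self.shared_expert_overlap = shared_expert_overlap

    def backward_dw(self):
        self.experts.backward_dw()
        # Without this guard, the call below would hit None (or a module whose
        # weight-grad step is already handled by the overlap path) and fail.
        if self.use_shared_expert and not self.shared_expert_overlap:
            self.shared_experts.backward_dw()


MoELayerSketch().backward_dw()                            # use_shared_expert=False: no crash
MoELayerSketch(shared_experts=_Expert()).backward_dw()    # shared expert present: both dW steps run
```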