Commit 040838a0 authored by dongcl

MoE A2A overlap: support self-attention; fix a bug when use_shared_expert is False

parent 2a0c4358
@@ -16,6 +16,7 @@ def a2a_overlap_adaptation(patches_manager):
         TELayerNormColumnParallelLinear,
     )
     from ..core.transformer.multi_latent_attention import MLASelfAttention
+    from ..core.transformer.attention import SelfAttention
     from ..core.transformer.mlp import MLP
     from ..core.transformer.moe.experts import TEGroupedMLP
     from ..core.transformer.moe.moe_layer import MoELayer
@@ -61,6 +62,9 @@ def a2a_overlap_adaptation(patches_manager):
     patches_manager.register_patch('megatron.core.transformer.multi_latent_attention.MLASelfAttention.backward_dw',
                                    MLASelfAttention.backward_dw,
                                    create_dummy=True)
+    patches_manager.register_patch('megatron.core.transformer.attention.SelfAttention.backward_dw',
+                                   SelfAttention.backward_dw,
+                                   create_dummy=True)
     patches_manager.register_patch('megatron.core.transformer.mlp.MLP.backward_dw',
                                    MLP.backward_dw,
                                    create_dummy=True)
......
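For context, backward_dw is not defined on the upstream Megatron classes, which is why the registrations above pass create_dummy=True. The patches_manager internals are not part of this diff; the sketch below is only an illustration of how a dotted-path patch with a create_dummy fallback could work, and every name in it is hypothetical.

import importlib

def register_patch(target, replacement, create_dummy=False):
    # Illustrative only: resolve a dotted 'module.Class.attr' path and
    # monkey-patch the attribute. With create_dummy=True the method is
    # attached even if the upstream class never defined it, which is how
    # a backward_dw hook can land on a class such as SelfAttention.
    module_path, cls_name, attr = target.rsplit('.', 2)
    cls = getattr(importlib.import_module(module_path), cls_name)
    if not hasattr(cls, attr) and not create_dummy:
        raise AttributeError(f'{target} does not exist and create_dummy is False')
    setattr(cls, attr, replacement)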
@@ -307,7 +307,8 @@ class MoeAttnNode(TransformerLayerNode):
         # detached here
         self.common_state.probs = self.detach(probs)
         self.common_state.residual = self.detach(hidden_states)
-        self.common_state.pre_mlp_layernorm_output = self.detach(pre_mlp_layernorm_output)
+        if self.layer.mlp.use_shared_expert:
+            self.common_state.pre_mlp_layernorm_output = self.detach(pre_mlp_layernorm_output)
         return permutated_local_input_tokens
@@ -333,7 +334,10 @@ class MoeDispatchNode(TransformerLayerNode):
 class MoeMlPNode(TransformerLayerNode):
     def forward_impl(self, global_input_tokens):
-        pre_mlp_layernorm_output = self.common_state.pre_mlp_layernorm_output
+        if self.layer.mlp.use_shared_expert:
+            pre_mlp_layernorm_output = self.common_state.pre_mlp_layernorm_output
+        else:
+            pre_mlp_layernorm_output = None
         token_dispatcher = self.layer.mlp.token_dispatcher
         with token_dispatcher.per_batch_state_context(self.common_state):
             expert_output, shared_expert_output, mlp_bias = self.layer._submodule_moe_forward(
@@ -343,6 +347,8 @@ class MoeMlPNode(TransformerLayerNode):
         # pre_mlp_layernorm_output used
         self.common_state.pre_mlp_layernorm_output = None
+        if shared_expert_output is None:
+            return expert_output
         return expert_output, shared_expert_output

     def dw(self):
@@ -351,7 +357,7 @@ class MoeCombineNode(TransformerLayerNode):
 class MoeCombineNode(TransformerLayerNode):
-    def forward_impl(self, expert_output, shared_expert_output):
+    def forward_impl(self, expert_output, shared_expert_output=None):
         # TODO(lhb): if dw uses the grads of residual and probs, the necessary synchronization should be added
         residual = self.common_state.residual
         token_dispatcher = self.layer.mlp.token_dispatcher
......
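The hunks above are the use_shared_expert fix described in the commit message: the attention node only stashes pre_mlp_layernorm_output when a shared expert will actually consume it, the MLP node reads it back under the same condition (passing None otherwise), and the combine node tolerates a missing shared-expert tensor. A rough sketch of the resulting hand-off between the nodes, with helper names invented for illustration:

def mlp_node_forward(layer, common_state, global_input_tokens):
    # Only the shared expert consumes the pre-MLP layernorm output, so the
    # reference is fetched (and kept alive) only when use_shared_expert is set.
    if layer.mlp.use_shared_expert:
        pre_mlp_layernorm_output = common_state.pre_mlp_layernorm_output
    else:
        pre_mlp_layernorm_output = None
    expert_output, shared_expert_output = layer.moe_forward(   # hypothetical hook
        global_input_tokens, pre_mlp_layernorm_output)
    common_state.pre_mlp_layernorm_output = None               # release the reference early
    if shared_expert_output is None:
        return expert_output                                   # single-tensor output path
    return expert_output, shared_expert_output

def combine_node_forward(expert_output, shared_expert_output=None):
    # The default of None matches the single-tensor path above, so the node
    # chain works whether or not a shared expert exists.
    return expert_output if shared_expert_output is None else expert_output + shared_expert_output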
@@ -67,6 +67,7 @@ class ScheduleNode:
             allow_unreachable=True,
             accumulate_grad=True,
         )
+        return output_grad

     def forward(self, inputs=(), stream_wait_event=None, stream_record_event=None):
@@ -105,8 +106,9 @@ class ScheduleNode:
         if self.free_inputs:
             for input in inputs:
-                input.record_stream(self.stream)
-                input.untyped_storage().resize_(0)
+                if input is not None:
+                    input.record_stream(self.stream)
+                    input.untyped_storage().resize_(0)
         return self.output
......
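A consequence of the optional shared-expert output is that the free_inputs path can now see None entries, hence the guard above. A minimal standalone sketch of that freeing pattern, assuming CUDA tensors and a torch.cuda.Stream (record_stream and untyped_storage().resize_(0) are standard PyTorch calls; the function name is made up):

import torch

def free_inputs(inputs, stream):
    for inp in inputs:
        if inp is None:                       # e.g. a missing shared_expert_output
            continue
        # Tell the caching allocator the tensor is still in use on `stream`,
        # then drop its storage so the memory can be reclaimed.
        inp.record_stream(stream)
        inp.untyped_storage().resize_(0)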
+class SelfAttention():
+    def backward_dw(self):
+        self.linear_qkv.backward_dw()
+        self.linear_proj.backward_dw()


 class MoELayer():
     def backward_dw(self):
         self.experts.backward_dw()
-        self.shared_experts.backward_dw()
+        if self.use_shared_expert and not self.shared_expert_overlap:
+            self.shared_experts.backward_dw()
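How these backward_dw methods are driven is not shown in this diff; the sketch below is only a guess at the intended scheduling, where the data-gradient backward runs first and the deferred weight-gradient work is issued later so it can overlap with all-to-all communication (all names here are illustrative):

def backward_with_deferred_dw(layer, node, output_grad):
    # Run the data-gradient backward first; the weight-gradient GEMMs are
    # deferred to backward_dw so they can overlap with a2a communication.
    grad = node.backward(output_grad)
    # ... all-to-all for another micro-batch could be in flight here ...
    layer.self_attention.backward_dw()   # linear_qkv / linear_proj wgrad
    layer.mlp.backward_dw()              # experts (+ shared experts when enabled)
    return grad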