evt_fugx1 / dcu_megatron · Commits · 040838a0

Commit 040838a0, authored Jun 10, 2025 by dongcl
Parent: 2a0c4358

    moe a2a overlap support self-attention; fix bug when use_shared_expert is false

Showing 5 changed files with 23 additions and 6 deletions (+23, -6).
    dcu_megatron/adaptor/features_manager.py                +4  -0
    dcu_megatron/core/models/gpt/fine_grained_schedule.py   +9  -3
    dcu_megatron/core/pipeline_parallel/combined_1f1b.py    +4  -2
    dcu_megatron/core/transformer/attention.py              +4  -0
    dcu_megatron/core/transformer/moe/moe_layer.py          +2  -1
dcu_megatron/adaptor/features_manager.py  (+4, -0)

@@ -16,6 +16,7 @@ def a2a_overlap_adaptation(patches_manager):
         TELayerNormColumnParallelLinear,
     )
     from ..core.transformer.multi_latent_attention import MLASelfAttention
+    from ..core.transformer.attention import SelfAttention
     from ..core.transformer.mlp import MLP
     from ..core.transformer.moe.experts import TEGroupedMLP
     from ..core.transformer.moe.moe_layer import MoELayer

@@ -61,6 +62,9 @@ def a2a_overlap_adaptation(patches_manager):
     patches_manager.register_patch('megatron.core.transformer.multi_latent_attention.MLASelfAttention.backward_dw',
                                    MLASelfAttention.backward_dw,
                                    create_dummy=True)
+    patches_manager.register_patch('megatron.core.transformer.attention.SelfAttention.backward_dw',
+                                   SelfAttention.backward_dw,
+                                   create_dummy=True)
     patches_manager.register_patch('megatron.core.transformer.mlp.MLP.backward_dw',
                                    MLP.backward_dw,
                                    create_dummy=True)
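The new registration mirrors the existing MLA and MLP hooks: backward_dw does not exist on upstream Megatron's SelfAttention, so the adaptor has to create it when installing the replacement. The PatchesManager itself is not part of this diff, so the following is only a minimal sketch of one plausible behavior of register_patch(target, func, create_dummy=True); the function body is invented for illustration.

    # Hypothetical sketch only: the real PatchesManager in dcu_megatron is not shown
    # in this commit. Assumed behavior: resolve the dotted path and attach func to
    # the target class, creating the attribute even when upstream Megatron has no
    # backward_dw to replace.
    import importlib

    def register_patch(target, func, create_dummy=False):
        module_path, cls_name, attr_name = target.rsplit('.', 2)
        cls = getattr(importlib.import_module(module_path), cls_name)
        if not hasattr(cls, attr_name) and not create_dummy:
            raise AttributeError(f"{cls_name} has no attribute {attr_name} to patch")
        setattr(cls, attr_name, func)  # monkey-patch (or create) the method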
dcu_megatron/core/models/gpt/fine_grained_schedule.py  (+9, -3)

@@ -307,7 +307,8 @@ class MoeAttnNode(TransformerLayerNode):
         # detached here
         self.common_state.probs = self.detach(probs)
         self.common_state.residual = self.detach(hidden_states)
-        self.common_state.pre_mlp_layernorm_output = self.detach(pre_mlp_layernorm_output)
+        if self.layer.mlp.use_shared_expert:
+            self.common_state.pre_mlp_layernorm_output = self.detach(pre_mlp_layernorm_output)
         return permutated_local_input_tokens

@@ -333,7 +334,10 @@ class MoeDispatchNode(TransformerLayerNode):
 class MoeMlPNode(TransformerLayerNode):
     def forward_impl(self, global_input_tokens):
-        pre_mlp_layernorm_output = self.common_state.pre_mlp_layernorm_output
+        if self.layer.mlp.use_shared_expert:
+            pre_mlp_layernorm_output = self.common_state.pre_mlp_layernorm_output
+        else:
+            pre_mlp_layernorm_output = None
         token_dispatcher = self.layer.mlp.token_dispatcher
         with token_dispatcher.per_batch_state_context(self.common_state):
             expert_output, shared_expert_output, mlp_bias = self.layer._submodule_moe_forward(

@@ -343,6 +347,8 @@ class MoeMlPNode(TransformerLayerNode):
         # pre_mlp_layernorm_output used
         self.common_state.pre_mlp_layernorm_output = None
+        if shared_expert_output is None:
+            return expert_output
         return expert_output, shared_expert_output

     def dw(self):

@@ -351,7 +357,7 @@ class MoeMlPNode(TransformerLayerNode):
 class MoeCombineNode(TransformerLayerNode):
-    def forward_impl(self, expert_output, shared_expert_output):
+    def forward_impl(self, expert_output, shared_expert_output=None):
         # TODO(lhb): if dw use grad of residual and probs, necessary synchronization should be add
         residual = self.common_state.residual
         token_dispatcher = self.layer.mlp.token_dispatcher
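Taken together, these hunks make the shared-expert path optional end to end: MoeAttnNode only stashes pre_mlp_layernorm_output when the layer actually has a shared expert, MoeMlPNode passes None otherwise and drops the second return value, and MoeCombineNode now tolerates a missing shared_expert_output. Below is a minimal, self-contained sketch of that control flow; every name in it is an invented stand-in, not the real node classes.

    # Toy control-flow sketch (invented names), assuming the scheduler simply
    # forwards whatever the MLP node returns into the combine node.
    from types import SimpleNamespace

    def moe_mlp_forward(common_state, use_shared_expert):
        # Only read the stashed layernorm output when a shared expert exists.
        pre_mlp_layernorm_output = (
            common_state.pre_mlp_layernorm_output if use_shared_expert else None
        )
        expert_output = "expert_out"
        shared_expert_output = "shared_out" if pre_mlp_layernorm_output is not None else None
        # The real node returns the bare tensor in this case; a 1-tuple is used
        # here only so the toy unpacking below works.
        if shared_expert_output is None:
            return (expert_output,)
        return expert_output, shared_expert_output

    def moe_combine_forward(expert_output, shared_expert_output=None):
        # shared_expert_output now defaults to None, so a single-output MLP node is fine.
        if shared_expert_output is None:
            return expert_output
        return expert_output, shared_expert_output

    state = SimpleNamespace(pre_mlp_layernorm_output="ln_out")
    print(moe_combine_forward(*moe_mlp_forward(state, use_shared_expert=False)))  # expert_out
    print(moe_combine_forward(*moe_mlp_forward(state, use_shared_expert=True)))   # ('expert_out', 'shared_out')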
dcu_megatron/core/pipeline_parallel/combined_1f1b.py  (+4, -2)

@@ -67,6 +67,7 @@ class ScheduleNode:
             allow_unreachable=True,
             accumulate_grad=True,
         )
         return output_grad

     def forward(self, inputs=(), stream_wait_event=None, stream_record_event=None):

@@ -105,8 +106,9 @@ class ScheduleNode:
         if self.free_inputs:
             for input in inputs:
-                input.record_stream(self.stream)
-                input.untyped_storage().resize_(0)
+                if input is not None:
+                    input.record_stream(self.stream)
+                    input.untyped_storage().resize_(0)
         return self.output
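The guarded loop is standard PyTorch stream bookkeeping: record_stream tells the caching allocator that the tensor is still needed by work queued on the node's stream, and untyped_storage().resize_(0) then releases the backing memory early. With shared_expert_output now allowed to be missing, the inputs sequence can contain None entries, hence the new check. A short sketch of the pattern, assuming a recent PyTorch (untyped_storage requires 2.x) and a CUDA stream passed by the caller:

    # Illustrative sketch of the input-freeing pattern with the new None guard.
    # `stream` is expected to be a torch.cuda.Stream owned by the schedule node.
    def free_inputs(inputs, stream):
        for tensor in inputs:
            if tensor is None:                      # e.g. a missing shared_expert_output
                continue
            tensor.record_stream(stream)            # keep storage alive for work on `stream`
            tensor.untyped_storage().resize_(0)     # then drop the backing memory early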
dcu_megatron/core/transformer/attention.py  (new file, mode 100644, +4, -0)

+class SelfAttention():
+    def backward_dw(self):
+        self.linear_qkv.backward_dw()
+        self.linear_proj.backward_dw()
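This new file only carries the method that the features_manager patch attaches to Megatron's SelfAttention: backward_dw delegates the deferred weight-gradient (dW) computation to the QKV and output projections, presumably so the fine-grained schedule can run dW separately and overlap it with the all-to-all communication. A toy illustration of that delegation; DummyLinear is invented here, while the real layers are expected to provide their own backward_dw:

    # Stand-in sketch: DummyLinear is invented; it only shows the delegation order.
    class DummyLinear:
        def __init__(self, name):
            self.name = name

        def backward_dw(self):
            print(f"dW for {self.name}")

    class SelfAttention:
        def __init__(self):
            self.linear_qkv = DummyLinear("linear_qkv")
            self.linear_proj = DummyLinear("linear_proj")

        def backward_dw(self):
            # Same delegation as the patched method added in this commit.
            self.linear_qkv.backward_dw()
            self.linear_proj.backward_dw()

    SelfAttention().backward_dw()  # prints dW for linear_qkv, then linear_proj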
dcu_megatron/core/transformer/moe/moe_layer.py  (+2, -1)

 class MoELayer():
     def backward_dw(self):
         self.experts.backward_dw()
-        self.shared_experts.backward_dw()
+        if self.use_shared_expert and not self.shared_expert_overlap:
+            self.shared_experts.backward_dw()
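This is the use_shared_expert fix named in the commit message: the old code called shared_experts.backward_dw() unconditionally, which breaks when no shared expert is configured (shared_experts is presumably None or absent in that case), and the new guard also skips the call when shared_expert_overlap is set, presumably because the overlapped path owns the shared expert's dW. A small stand-in reproduction of the failure mode and the guard:

    # Stand-in classes (invented) reproducing the fix: with use_shared_expert=False
    # there is no shared_experts module (assumed None here), so the old
    # unconditional call raised on None; the guard skips it.
    class Experts:
        def backward_dw(self):
            print("experts dW")

    class MoELayer:
        def __init__(self, use_shared_expert, shared_expert_overlap=False):
            self.experts = Experts()
            self.use_shared_expert = use_shared_expert
            self.shared_expert_overlap = shared_expert_overlap
            self.shared_experts = Experts() if use_shared_expert else None

        def backward_dw(self):
            self.experts.backward_dw()
            if self.use_shared_expert and not self.shared_expert_overlap:
                self.shared_experts.backward_dw()

    MoELayer(use_shared_expert=False).backward_dw()  # no longer fails on a missing shared expert
    MoELayer(use_shared_expert=True).backward_dw()   # runs both dW passes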