evt_fugx1 / dcu_megatron · Commits

Commit 1d497357, authored Jun 19, 2025 by dongcl
split routed experts and shared experts
Parent: 2ceeaafd

Showing 5 changed files with 121 additions and 26 deletions
dcu_megatron/adaptor/features_manager.py               +6   -0
dcu_megatron/core/models/gpt/fine_grained_schedule.py  +78  -7
dcu_megatron/core/transformer/moe/moe_layer.py         +7   -0
dcu_megatron/core/transformer/moe/token_dispatcher.py  +2   -6
dcu_megatron/core/transformer/transformer_layer.py     +28  -13
dcu_megatron/adaptor/features_manager.py

@@ -74,3 +74,9 @@ def a2a_overlap_adaptation(patches_manager):
     patches_manager.register_patch('megatron.core.transformer.moe.moe_layer.MoELayer.backward_dw',
                                    MoELayer.backward_dw,
                                    create_dummy=True)
+    patches_manager.register_patch('megatron.core.transformer.moe.moe_layer.MoELayer.backward_routed_expert_dw',
+                                   MoELayer.backward_routed_expert_dw,
+                                   create_dummy=True)
+    patches_manager.register_patch('megatron.core.transformer.moe.moe_layer.MoELayer.backward_shared_expert_dw',
+                                   MoELayer.backward_shared_expert_dw,
+                                   create_dummy=True)
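register_patch points a dotted path inside upstream Megatron-LM at a replacement attribute from this repository; create_dummy=True presumably asks the manager to install the attribute even when upstream does not define it yet. The manager itself is not part of this diff, so the following is only a standalone sketch of the register-then-apply pattern (the PatchesManager class and its apply_patches method here are hypothetical):

import importlib


class PatchesManager:
    """Hypothetical sketch of a register-then-apply monkey-patch manager."""

    def __init__(self):
        self._patches = []  # (dotted_path, replacement, create_dummy)

    def register_patch(self, dotted_path, replacement, create_dummy=False):
        # Only record the patch; application happens later in one pass.
        self._patches.append((dotted_path, replacement, create_dummy))

    def apply_patches(self):
        for dotted_path, replacement, create_dummy in self._patches:
            parts = dotted_path.split('.')
            # Find the longest importable module prefix of the dotted path.
            module, split = None, len(parts)
            for i in range(len(parts) - 1, 0, -1):
                try:
                    module = importlib.import_module('.'.join(parts[:i]))
                    split = i
                    break
                except ModuleNotFoundError:
                    continue
            if module is None:
                continue  # target module unavailable; skip this patch
            # Walk the remaining attribute chain (e.g. the class name).
            target = module
            for name in parts[split:-1]:
                target = getattr(target, name)
            if hasattr(target, parts[-1]) or create_dummy:
                # create_dummy=True installs the attribute even when the
                # upstream object does not define it yet.
                setattr(target, parts[-1], replacement)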
dcu_megatron/core/models/gpt/fine_grained_schedule.py

@@ -356,10 +356,54 @@ class MoeMlPNode(TransformerLayerNode):
             self.layer._submodule_mlp_dw()


+class MoeSharedExpertNode(TransformerLayerNode):
+    def forward_impl(self):
+        if self.layer.mlp.use_shared_expert:
+            pre_mlp_layernorm_output = self.common_state.pre_mlp_layernorm_output
+        else:
+            pre_mlp_layernorm_output = None
+
+        token_dispatcher = self.layer.mlp.token_dispatcher
+        with token_dispatcher.per_batch_state_context(self.common_state):
+            shared_expert_output = self.layer._submodule_shared_expert_forward(pre_mlp_layernorm_output)
+
+        # pre_mlp_layernorm_output used
+        self.common_state.pre_mlp_layernorm_output = None
+        return shared_expert_output
+        # self.common_state.shared_expert_output = self.detach(shared_expert_output)
+
+    def dw(self):
+        with torch.cuda.nvtx.range(f"{self.name} wgrad"):
+            self.layer._submodule_shared_expert_dw()
+
+
+class MoeRoutedExpertNode(TransformerLayerNode):
+    def forward_impl(self, global_input_tokens):
+        token_dispatcher = self.layer.mlp.token_dispatcher
+        with token_dispatcher.per_batch_state_context(self.common_state):
+            expert_output, mlp_bias = self.layer._submodule_routed_expert_forward(
+                self.common_state.tokens_per_expert, global_input_tokens)
+        assert mlp_bias is None
+        return expert_output
+
+    def dw(self):
+        with torch.cuda.nvtx.range(f"{self.name} wgrad"):
+            self.layer._submodule_routed_expert_dw()
+
+
 class MoeCombineNode(TransformerLayerNode):
     def forward_impl(self, expert_output, shared_expert_output=None):
         # TODO(lhb): if dw use grad of residual and probs, necessary synchronization should be add
         residual = self.common_state.residual
         # shared_expert_output = None
         # if self.layer.mlp.use_shared_expert:
         #     shared_expert_output = self.common_state.shared_expert_output
         token_dispatcher = self.layer.mlp.token_dispatcher
         with token_dispatcher.per_batch_state_context(self.common_state):
             permutated_local_input_tokens = token_dispatcher.combine_all_to_all(
...
@@ -371,8 +415,10 @@ class MoeCombineNode(TransformerLayerNode):
             cur_stream = torch.cuda.current_stream()
             self.common_state.residual.record_stream(cur_stream)
             self.common_state.probs.record_stream(cur_stream)
+            # self.common_state.shared_expert_output.record_stream(cur_stream)
             self.common_state.residual = None
             self.common_state.probs = None
+            # self.common_state.shared_expert_output = None
         return output
...
@@ -443,13 +489,22 @@ def build_layer_schedule_plan(layer, event, chunk_state, comp_stream, com_stream
     common_state = TransformerLayerState()
     attn = MoeAttnNode(chunk_state, common_state, layer, comp_stream, event)
     attn.name = "attn"
     dispatch = MoeDispatchNode(chunk_state, common_state, layer, com_stream, event, True)
     dispatch.name = "dispatch"
     mlp = MoeMlPNode(chunk_state, common_state, layer, comp_stream, event, True)
     mlp.name = "mlp"
+    routed_expert = MoeRoutedExpertNode(chunk_state, common_state, layer, comp_stream, event, True)
+    routed_expert.name = "routed_expert"
+    shared_expert = MoeSharedExpertNode(chunk_state, common_state, layer, comp_stream, event, True)
+    shared_expert.name = "shared_expert"
     combine = MoeCombineNode(chunk_state, common_state, layer, com_stream, event, True)
     combine.name = "combine"
-    return TransformerLayerSchedulePlan(attn, dispatch, mlp, combine)
+    return TransformerLayerSchedulePlan(attn, dispatch, mlp, combine,
+                                        shared_expert=shared_expert, routed_expert=routed_expert)


 class TransformerLayerState(MoEAlltoAllPerBatchState):
...
@@ -462,11 +517,13 @@ class ModelChunkSate:
 class TransformerLayerSchedulePlan:
-    def __init__(self, attn, dispatch, mlp, combine):
+    def __init__(self, attn, dispatch, mlp, combine, shared_expert=None, routed_expert=None):
         self.attn = attn
         self.dispatch = dispatch
         self.mlp = mlp
         self.combine = combine
+        self.shared_expert = shared_expert
+        self.routed_expert = routed_expert


 class ModelChunkSchedulePlan(AbstractSchedulePlan):
...
@@ -577,7 +634,7 @@ def schedule_layer_1f1b(
     if b_layer is not None:
         with b_context:
-            b_grad = b_layer.combine.backward(b_grad)
+            routed_expert_output_grad, shared_expert_output_grad = b_layer.combine.backward(b_grad)

     if pre_backward_dw is not None:
         pre_backward_dw()
...
@@ -593,22 +650,36 @@ def schedule_layer_1f1b(
     if f_layer is not None:
         with f_context:
+            shared_expert_output = f_layer.shared_expert.forward()
             f_input = f_layer.dispatch.forward(f_input, stream_record_event=f_dispatch_b_mlp_sync_event)

+    # if f_layer is not None:
+    #     with f_context:
+    #         f_input = f_layer.dispatch.forward(f_input, stream_record_event=f_dispatch_b_mlp_sync_event)
+
     if b_layer is not None:
         with b_context:
-            b_grad = b_layer.mlp.backward(b_grad, stream_wait_event=f_dispatch_b_mlp_sync_event)
+            # routed_expert_output_grad, shared_expert_output_grad = b_grad
+            b_grad = b_layer.routed_expert.backward(routed_expert_output_grad, stream_wait_event=f_dispatch_b_mlp_sync_event)
+            b_layer.shared_expert.backward(shared_expert_output_grad)
             b_grad = b_layer.dispatch.backward(b_grad)
-            b_layer.mlp.dw()
+            b_layer.routed_expert.dw()

     if f_layer is not None:
         with f_context:
-            f_input = f_layer.mlp.forward(f_input)
+            f_input = f_layer.routed_expert.forward(f_input)

+    # if b_layer is not None:
+    #     with b_context:
+    #         # b_grad = b_layer.dispatch.backward(b_grad)
+    #         b_layer.shared_expert.backward(shared_expert_output_grad)
+    #         b_layer.routed_expert.dw()
+
     def next_iter_pre_forward():
         if f_layer is not None:
             with f_context:
-                output = f_layer.combine.forward(f_input)
+                output = f_layer.combine.forward((f_input, shared_expert_output))
                 return output

     def next_iter_pre_backward():
...
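The scheduling payoff of the new nodes is visible in schedule_layer_1f1b above: the shared-expert forward of one micro-batch and the routed-expert backward/wgrad of another are issued around the dispatch and combine all-to-alls instead of inside a single monolithic mlp node. A CPU-only toy that replays that ordering (Node, make_layer and run_1f1b_step are illustrative stand-ins, not the classes in this file, which wrap CUDA streams, events and dispatcher state):

class Node:
    def __init__(self, name, log):
        self.name, self.log = name, log

    def forward(self, x=None):
        self.log.append(f"F:{self.name}")
        return x

    def backward(self, g=None):
        self.log.append(f"B:{self.name}")
        return g

    def dw(self):
        self.log.append(f"W:{self.name}")


def make_layer(log):
    names = ("dispatch", "shared_expert", "routed_expert", "combine")
    return {n: Node(n, log) for n in names}


def run_1f1b_step(f_layer, b_layer, f_input, b_grad):
    # Mirrors the rough ordering of schedule_layer_1f1b after this commit.
    routed_grad = shared_grad = b_layer["combine"].backward(b_grad)
    shared_out = f_layer["shared_expert"].forward()   # off the all-to-all path
    x = f_layer["dispatch"].forward(f_input)
    g = b_layer["routed_expert"].backward(routed_grad)
    b_layer["shared_expert"].backward(shared_grad)
    g = b_layer["dispatch"].backward(g)
    b_layer["routed_expert"].dw()                     # routed-expert wgrad
    x = f_layer["routed_expert"].forward(x)
    return f_layer["combine"].forward((x, shared_out))


log = []
run_1f1b_step(make_layer(log), make_layer(log), "tokens", "grad")
print(" -> ".join(log))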
dcu_megatron/core/transformer/moe/moe_layer.py

@@ -3,3 +3,10 @@ class MoELayer():
         self.experts.backward_dw()
         if self.use_shared_expert and not self.shared_expert_overlap:
             self.shared_experts.backward_dw()
+
+    def backward_routed_expert_dw(self):
+        self.experts.backward_dw()
+
+    def backward_shared_expert_dw(self):
+        if self.use_shared_expert and not self.shared_expert_overlap:
+            self.shared_experts.backward_dw()
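Splitting backward_dw this way lets a scheduler trigger the two weight-gradient passes independently. A minimal self-contained sketch of that contract (_FakeExperts and MoELayerSketch are stand-ins used only for illustration):

class _FakeExperts:
    def __init__(self):
        self.dw_calls = 0

    def backward_dw(self):
        self.dw_calls += 1


class MoELayerSketch:
    def __init__(self, use_shared_expert=True, shared_expert_overlap=False):
        self.experts = _FakeExperts()          # routed experts
        self.shared_experts = _FakeExperts()   # shared experts
        self.use_shared_expert = use_shared_expert
        self.shared_expert_overlap = shared_expert_overlap

    # Same structure as the methods added above.
    def backward_routed_expert_dw(self):
        self.experts.backward_dw()

    def backward_shared_expert_dw(self):
        if self.use_shared_expert and not self.shared_expert_overlap:
            self.shared_experts.backward_dw()


layer = MoELayerSketch()
layer.backward_routed_expert_dw()   # e.g. right after routed-expert dgrad
layer.backward_shared_expert_dw()   # e.g. later, off the critical path
assert layer.experts.dw_calls == 1 and layer.shared_experts.dw_calls == 1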
dcu_megatron/core/transformer/moe/token_dispatcher.py

@@ -91,7 +91,7 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
         self.collect_per_batch_state(state)
         self.apply_per_batch_state(origin_state)

-    def meta_prepare(self, hidden_states: torch.Tensor, probs: torch.Tensor, routing_map: torch.Tensor):
+    def dispatch_preprocess(self, hidden_states: torch.Tensor, probs: torch.Tensor, routing_map: torch.Tensor):
         self.hidden_shape = hidden_states.shape
...
@@ -103,9 +103,6 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
         tokens_per_expert = self.preprocess(self.routing_map)
-        return tokens_per_expert
-
-    def dispatch_preprocess(self, hidden_states: torch.Tensor, routing_map: torch.Tensor, tokens_per_expert: torch.Tensor):
         hidden_states = hidden_states.view(-1, self.hidden_shape[-1])

         if self.shared_experts is not None:
             self.shared_experts.pre_forward_comm(hidden_states.view(self.hidden_shape))
...
@@ -206,8 +203,7 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
         """
         # Preprocess: Get the metadata for communication, permutation and computation operations.
         # Permutation 1: input to AlltoAll input
-        tokens_per_expert = self.meta_prepare(hidden_states, probs, routing_map)
-        tokens_per_expert, permutated_local_input_tokens = self.dispatch_preprocess(hidden_states, routing_map, tokens_per_expert)
+        tokens_per_expert, permutated_local_input_tokens = self.dispatch_preprocess(hidden_states, probs, routing_map)

         # Perform expert parallel AlltoAll communication
         tokens_per_expert, global_input_tokens = self.dispatch_all_to_all(tokens_per_expert, permutated_local_input_tokens)
...
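After this merge, callers pass hidden_states, probs and routing_map to a single dispatch_preprocess and no longer call meta_prepare. A caller-side sketch of the resulting dispatch path, assembled from the call sites in this diff (dispatch_tokens is a hypothetical helper, not repository code):

def dispatch_tokens(dispatcher, hidden_states, probs, routing_map):
    # One call now computes the routing metadata and permutes the local tokens
    # (the old meta_prepare step is folded in).
    tokens_per_expert, permuted = dispatcher.dispatch_preprocess(
        hidden_states, probs, routing_map)
    # Expert-parallel all-to-all of the permuted tokens.
    tokens_per_expert, global_tokens = dispatcher.dispatch_all_to_all(
        tokens_per_expert, permuted)
    # Receiver-side post-processing before the expert GEMMs.
    return dispatcher.dispatch_postprocess(tokens_per_expert, global_tokens)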
dcu_megatron/core/transformer/transformer_layer.py

@@ -189,11 +189,8 @@ class TransformerLayer(MegatronCoreTransformerLayer):
         pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states)

         probs, routing_map = self.mlp.router(pre_mlp_layernorm_output)
-        tokens_per_expert = self.mlp.token_dispatcher.meta_prepare(pre_mlp_layernorm_output, probs, routing_map)
         tokens_per_expert, permutated_local_input_tokens = self.mlp.token_dispatcher.dispatch_preprocess(
-            pre_mlp_layernorm_output, routing_map, tokens_per_expert
-        )
+            pre_mlp_layernorm_output, probs, routing_map)

         outputs = [
...
@@ -205,15 +202,6 @@ class TransformerLayer(MegatronCoreTransformerLayer):
         ]

         return tuple(outputs)

-    def _submodule_shared_expert_forward(self, pre_mlp_layernorm_output):
-        """
-        Performs a forward pass for shared experts.
-        """
-        shared_expert_output = None
-        if self.mlp.use_shared_expert and not self.mlp.shared_expert_overlap:
-            shared_expert_output = self.mlp.shared_experts(pre_mlp_layernorm_output)
-        return shared_expert_output
-
     def _submodule_dispatch_forward(self, tokens_per_expert, permutated_local_input_tokens):
         """
         Dispatches tokens to the appropriate experts based on the router output.
...
@@ -253,6 +241,27 @@ class TransformerLayer(MegatronCoreTransformerLayer):
         expert_output = self.mlp.token_dispatcher.combine_preprocess(expert_output)

         return expert_output, shared_expert_output, mlp_bias

+    def _submodule_shared_expert_forward(self, pre_mlp_layernorm_output):
+        """
+        Performs a forward pass for shared experts.
+        """
+        shared_expert_output = None
+        if self.mlp.use_shared_expert and not self.mlp.shared_expert_overlap:
+            shared_expert_output = self.mlp.shared_experts(pre_mlp_layernorm_output)
+        return shared_expert_output
+
+    def _submodule_routed_expert_forward(self, tokens_per_expert, global_input_tokens):
+        """
+        Performs a forward pass for the MLP submodule, including both expert-based
+        and optional shared-expert computations.
+        """
+        (dispatched_input, tokens_per_expert) = (
+            self.mlp.token_dispatcher.dispatch_postprocess(tokens_per_expert, global_input_tokens)
+        )
+        expert_output, mlp_bias = self.mlp.experts(dispatched_input, tokens_per_expert)
+        expert_output = self.mlp.token_dispatcher.combine_preprocess(expert_output)
+        return expert_output, mlp_bias
+
     def _submodule_combine_forward(self, hidden_states):
         return [self.mlp.token_dispatcher.combine_all_to_all(hidden_states)]
...
@@ -295,3 +304,9 @@ class TransformerLayer(MegatronCoreTransformerLayer):
     def _submodule_mlp_dw(self):
         self.mlp.backward_dw()
+
+    def _submodule_routed_expert_dw(self):
+        self.mlp.backward_routed_expert_dw()
+
+    def _submodule_shared_expert_dw(self):
+        self.mlp.backward_shared_expert_dw()
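Taken together, the submodule hooks let a schedule drive one MoE layer forward in separately overlappable pieces. A hedged sketch of how they compose, with residual handling and the combine-side un-permutation omitted (fine_grained_moe_forward is a hypothetical helper, not repository code; layer stands for a TransformerLayer-like object):

def fine_grained_moe_forward(layer, hidden_states):
    pre_mlp = layer.pre_mlp_layernorm(hidden_states)
    probs, routing_map = layer.mlp.router(pre_mlp)
    tokens_per_expert, permuted = layer.mlp.token_dispatcher.dispatch_preprocess(
        pre_mlp, probs, routing_map)

    # Shared experts no longer sit on the all-to-all critical path...
    shared_expert_output = layer._submodule_shared_expert_forward(pre_mlp)

    # ...while routed tokens go through dispatch, the expert MLPs, and combine.
    tokens_per_expert, global_tokens = layer.mlp.token_dispatcher.dispatch_all_to_all(
        tokens_per_expert, permuted)
    expert_output, mlp_bias = layer._submodule_routed_expert_forward(
        tokens_per_expert, global_tokens)
    (combined,) = layer._submodule_combine_forward(expert_output)
    return combined, shared_expert_output, mlp_bias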