evt_fugx1 / dcu_megatron · Commits

Commit 4863ddcf
Authored Jun 17, 2025 by dongcl
Parent: 6a579b17

add tokens_per_expert to common_state
Changes: 4 changed files with 47 additions and 31 deletions (+47 −31)

  dcu_megatron/core/models/gpt/fine_grained_schedule.py  +11 −7
  dcu_megatron/core/pipeline_parallel/schedules.py         +4 −2
  dcu_megatron/core/transformer/transformer_layer.py     +16 −15
  dcu_megatron/core/transformer/utils.py                  +16 −7
dcu_megatron/core/models/gpt/fine_grained_schedule.py
@@ -7,16 +7,15 @@ import torch
 from torch import Tensor
 from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
 from megatron.core.inference.contexts import BaseInferenceContext
 from megatron.core.packed_seq_params import PackedSeqParams
 from megatron.core.transformer import transformer_layer
 from megatron.core.transformer.moe.moe_layer import MoELayer
 from megatron.core.utils import WrappedTensor, deprecate_inference_params
-from megatron.core.inference.contexts import BaseInferenceContext
 from dcu_megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllPerBatchState
 from dcu_megatron.core.pipeline_parallel.combined_1f1b import (
     AbstractSchedulePlan,
     FakeScheduleNode,
     FreeInputsMemoryStrategy,
     NoOpMemoryStrategy,
     ScheduleNode,
     get_com_stream,
     get_comp_stream,
@@ -776,10 +775,12 @@ def build_model_chunk_schedule_plan(
     attention_mask: Tensor,
     decoder_input: Tensor = None,
     labels: Tensor = None,
-    inference_params=None,
+    inference_context: BaseInferenceContext = None,
     packed_seq_params=None,
     extra_block_kwargs=None,
     runtime_gather_output: Optional[bool] = None,
+    inference_params=None,
+    loss_mask=None,
 ):
     """Builds a schedule plan for a model chunk.
@@ -797,6 +798,7 @@ def build_model_chunk_schedule_plan(
         packed_seq_params: Parameters for packed sequences.
         extra_block_kwargs: Additional keyword arguments for blocks.
         runtime_gather_output: Whether to gather output at runtime.
+        loss_mask: Loss mask.
     Returns:
         The model chunk schedule plan.
@@ -812,10 +814,12 @@ def build_model_chunk_schedule_plan(
     state.attention_mask = attention_mask
     state.decoder_input = decoder_input
     state.labels = labels
-    state.inference_params = inference_params
+    state.inference_context = inference_context
     state.packed_seq_params = packed_seq_params
     state.extra_block_kwargs = extra_block_kwargs
     state.runtime_gather_output = runtime_gather_output
+    state.inference_params = inference_params
+    state.loss_mask = loss_mask
     state.context = None
     state.context_mask = None
     state.attention_bias = None
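The new state fields above follow a simple pattern: forward-call arguments are recorded once on a per-chunk state object when the schedule plan is built, and read back later by whichever schedule node needs them. A minimal, self-contained sketch of that pattern (ModelChunkState, build_plan and loss_node are illustrative stand-ins, not the real dcu_megatron objects):

    # Illustrative sketch: record forward kwargs on a per-chunk state object at
    # plan-build time, then read them back when a node actually executes.
    from dataclasses import dataclass
    from typing import Any, Optional

    @dataclass
    class ModelChunkState:                      # stand-in for the real chunk state
        attention_mask: Any = None
        labels: Any = None
        inference_context: Any = None
        inference_params: Any = None
        loss_mask: Any = None
        runtime_gather_output: Optional[bool] = None

    def build_plan(state: ModelChunkState, **forward_kwargs) -> ModelChunkState:
        # populate the shared state once, instead of threading every argument
        # through each downstream node's signature
        for name, value in forward_kwargs.items():
            setattr(state, name, value)
        return state

    def loss_node(state: ModelChunkState, losses):
        # a later node consumes the recorded loss_mask
        if state.loss_mask is None:
            return losses
        return [l * m for l, m in zip(losses, state.loss_mask)]

    state = build_plan(ModelChunkState(), labels=[1, 0], loss_mask=[1, 0])
    assert loss_node(state, [0.5, 0.9]) == [0.5, 0.0]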
dcu_megatron/core/pipeline_parallel/schedules.py
@@ -117,8 +117,10 @@ def forward_backward_pipelining_with_interleaving(
     config = get_model_config(model[0])
     set_streams()
-    if not forward_only:
-        forward_step_func = wrap_forward_func(config, forward_step_func)
+    if config.combined_1f1b and not forward_only:
+        # in combined_1f1b, we need to wrap the forward_step_func
+        # to return a schedule plan instead of the forward output tensor
+        forward_step_func = wrap_forward_func(forward_step_func)
     if config.overlap_p2p_comm and config.batch_p2p_comm:
         raise ValueError("Can not use both overlap_p2p_comm and batch_p2p_comm")
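The comment added above states the intent: under combined_1f1b the forward step must hand back a schedule plan rather than the forward output tensor. The body of wrap_forward_func is not part of this diff, so the following is only a rough, hypothetical sketch of what such a wrapper could look like, not the actual dcu_megatron implementation:

    # Hypothetical sketch of a forward-step wrapper for combined 1F1B.
    # Instead of running the model eagerly and returning its output tensor,
    # the wrapped function returns a deferred "plan" that the pipeline
    # schedule can execute and interleave with backward work later.
    def wrap_forward_func(forward_step_func):
        def wrapped(data_iterator, model, *args, **kwargs):
            def build_plan():
                # actual forward computation happens only when the scheduler
                # decides to run this plan
                return forward_step_func(data_iterator, model, *args, **kwargs)
            return build_plan
        return wrapped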
dcu_megatron/core/transformer/transformer_layer.py
 from typing import Any, Optional
 from functools import partial

 import torch
 from torch import Tensor

 from megatron.training import get_args
 from megatron.core import tensor_parallel, parallel_state
 from megatron.core import parallel_state
 from megatron.core.packed_seq_params import PackedSeqParams
 from megatron.core.utils import (
     deprecate_inference_params,
     make_viewless_tensor,
     nvtx_range_pop,
     nvtx_range_push,
 )
 from megatron.core.transformer.moe.moe_layer import MoELayer
 from megatron.core.utils import make_viewless_tensor
 from megatron.core.transformer.transformer_layer import TransformerLayer as MegatronCoreTransformerLayer
 from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher
 from megatron.core.transformer.transformer_config import TransformerConfig

 from dcu_megatron.core.transformer.utils import SubmoduleCallables, TransformerLayerSubmoduleCallables


 def get_transformer_layer_offset(config: TransformerConfig, vp_stage: Optional[int] = None):
     """Get the index offset of current pipeline stage, given the level of pipelining."""
@@ -244,29 +241,33 @@ class TransformerLayer(MegatronCoreTransformerLayer):
         residual,
         context,
     ):
+        node.common_state.tokens_per_expert = tokens_per_expert
         node.common_state.residual = node.detach(residual)
         if self.mlp.use_shared_expert:
             node.common_state.pre_mlp_layernorm_output = node.detach(pre_mlp_layernorm_output)
-        return tokens_per_expert, permutated_local_input_tokens, permuted_probs
+        return permutated_local_input_tokens, permuted_probs

-    def _submodule_dispatch_forward(self, tokens_per_expert, permutated_local_input_tokens, permuted_probs, state=None):
+    def _submodule_dispatch_forward(self, permutated_local_input_tokens, permuted_probs, state=None):
         """
         Dispatches tokens to the appropriate experts based on the router output.
         """
+        tokens_per_expert = state.tokens_per_expert
         token_dispatcher = self.mlp.token_dispatcher
         tokens_per_expert, global_input_tokens, global_probs = token_dispatcher.dispatch_all_to_all(
             tokens_per_expert, permutated_local_input_tokens, permuted_probs
         )
         return tokens_per_expert, global_input_tokens, global_probs

     def _submodule_dispatch_postprocess(self, node, tokens_per_expert, global_input_tokens, global_probs):
-        return tokens_per_expert, global_input_tokens, global_probs
+        node.common_state.tokens_per_expert = tokens_per_expert
+        return global_input_tokens, global_probs

-    def _submodule_moe_forward(self, tokens_per_expert, global_input_tokens, global_probs, state=None):
+    def _submodule_moe_forward(self, global_input_tokens, global_probs, state=None):
         """
         Performs a forward pass for the MLP submodule, including both expert-based
         and optional shared-expert computations.
         """
+        tokens_per_expert = state.tokens_per_expert
         shared_expert_output = None
         token_dispatcher = self.mlp.token_dispatcher
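Taken together, the changes in this hunk move tokens_per_expert out of the submodule call signatures and into the node's common_state: the router postprocess publishes it, and the dispatch and MoE stages read it back from the shared state instead of receiving it as an argument. A small standalone sketch of that hand-off (TinyLayer and its methods are illustrative stand-ins, not the real dcu_megatron classes):

    # Illustrative sketch: stash a value produced by an early pipeline stage in a
    # shared per-layer state so later stages can read it back, instead of
    # threading it through every stage's argument list.
    from types import SimpleNamespace

    class TinyLayer:
        def __init__(self):
            self.common_state = SimpleNamespace(tokens_per_expert=None)

        def router_postprocess(self, tokens_per_expert, permuted_tokens):
            # produce tokens_per_expert once and publish it to the shared state
            self.common_state.tokens_per_expert = tokens_per_expert
            return permuted_tokens

        def dispatch_forward(self, permuted_tokens, state=None):
            # a later stage reads it back from the state it was handed
            tokens_per_expert = state.tokens_per_expert
            return permuted_tokens, tokens_per_expert

    layer = TinyLayer()
    tokens = layer.router_postprocess(tokens_per_expert=[3, 1, 4], permuted_tokens=["t0", "t1"])
    out, counts = layer.dispatch_forward(tokens, state=layer.common_state)
    assert counts == [3, 1, 4]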
@@ -275,7 +276,7 @@ class TransformerLayer(MegatronCoreTransformerLayer):
         )
         expert_output, mlp_bias = self.mlp.experts(
-            dispatched_tokens, tokens_per_expert, permuted_probs
+            dispatched_input, tokens_per_expert, permuted_probs
         )
         assert mlp_bias is None, f"Bias is not supported in {token_dispatcher.__class__.__name__}"
         if self.mlp.use_shared_expert and not self.mlp.shared_expert_overlap:
@@ -371,7 +372,7 @@ class TransformerLayer(MegatronCoreTransformerLayer):
         return attn_func(
             hidden_states=hidden_states,
             attention_mask=chunk_state.attention_mask,
-            content=chunk_state.context,
+            context=chunk_state.context,
             context_mask=chunk_state.context_mask,
             rotary_pos_emb=chunk_state.rotary_pos_emb,
             rotary_pos_cos=chunk_state.rotary_pos_cos,
dcu_megatron/core/transformer/utils.py
 from dataclasses import dataclass
 from typing import Callable, Optional
 from functools import partial


 @dataclass
 class SubmoduleCallables:
@@ -9,10 +9,13 @@ class SubmoduleCallables:
     for a particular submodule.
     """

-    forward: Optional[Callable] = None
-    backward: Optional[Callable] = None
-    dgrad: Optional[Callable] = None
-    dw: Optional[Callable] = None
+    def raise_not_implemented(name: str):
+        raise NotImplementedError(f"{name} not implemented.")
+
+    forward: Optional[Callable] = partial(raise_not_implemented, "forward")
+    dw: Optional[Callable] = partial(raise_not_implemented, "dw")
+    is_moe: bool = False
+    is_deepep: bool = False


 @dataclass
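Defaulting the callables to partial(raise_not_implemented, "...") instead of None means a submodule whose hook was never wired up fails loudly with a named error the moment it is called, rather than with a generic "NoneType is not callable". A self-contained sketch of the same idea, independent of dcu_megatron (Callbacks here is a made-up class for illustration):

    # Self-contained sketch: "fail loudly" defaults for optional callables.
    from dataclasses import dataclass
    from functools import partial
    from typing import Callable, Optional

    def raise_not_implemented(name: str):
        raise NotImplementedError(f"{name} not implemented.")

    @dataclass
    class Callbacks:
        # If nobody assigns a real function, calling the default raises a
        # descriptive error instead of "NoneType object is not callable".
        forward: Optional[Callable] = partial(raise_not_implemented, "forward")
        dw: Optional[Callable] = partial(raise_not_implemented, "dw")

    cb = Callbacks(forward=lambda x: x * 2)
    assert cb.forward(3) == 6
    try:
        cb.dw()
    except NotImplementedError as e:
        assert "dw" in str(e)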
@@ -26,7 +29,13 @@ class TransformerLayerSubmoduleCallables:
     dispatch: SubmoduleCallables
     mlp: SubmoduleCallables
     combine: SubmoduleCallables
     post_combine: SubmoduleCallables
+    is_moe: bool = False
+    is_deepep: bool = False

     def as_array(self):
-        return [self.attention, self.dispatch, self.mlp, self.combine, self.post_combine]
+        return [self.attention, self.dispatch, self.mlp, self.combine]

+    def __post_init__(self):
+        for submodule in self.as_array():
+            submodule.is_moe = self.is_moe
+            submodule.is_deepep = self.is_deepep
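The new __post_init__ fans the layer-level is_moe / is_deepep flags out to every per-submodule callable bundle, so each stage can check the flags locally without going back to the parent. A compact sketch of the same propagation pattern (Parent and Child are simplified stand-ins, not the actual dcu_megatron classes):

    # Sketch: a parent dataclass pushing shared flags down to its children
    # right after construction, via __post_init__.
    from dataclasses import dataclass, field

    @dataclass
    class Child:
        is_moe: bool = False
        is_deepep: bool = False

    @dataclass
    class Parent:
        attention: Child = field(default_factory=Child)
        dispatch: Child = field(default_factory=Child)
        is_moe: bool = False
        is_deepep: bool = False

        def as_array(self):
            return [self.attention, self.dispatch]

        def __post_init__(self):
            # propagate the parent's flags so each child can be inspected alone
            for child in self.as_array():
                child.is_moe = self.is_moe
                child.is_deepep = self.is_deepep

    p = Parent(is_moe=True)
    assert p.dispatch.is_moe is True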