Commit ab2a8334 authored May 16, 2025 by dongcl (evt_fugx1/dcu_megatron)

modify MoEAlltoAllPerBatchState, add tokens_per_expert attr

Parent: 5890bb4c

Showing 4 changed files with 14 additions and 9 deletions
Files changed:
  dcu_megatron/core/models/gpt/fine_grained_schedule.py   +4 -3
  dcu_megatron/core/pipeline_parallel/combined_1f1b.py    +4 -1
  dcu_megatron/core/transformer/moe/token_dispatcher.py   +4 -3
  dcu_megatron/core/transformer/transformer_layer.py      +2 -2
dcu_megatron/core/models/gpt/fine_grained_schedule.py
@@ -6,6 +6,7 @@ from typing import Optional
 import torch
 from torch import Tensor
+from megatron.core import parallel_state
 from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
 from megatron.core.inference.contexts import BaseInferenceContext
@@ -19,6 +20,7 @@ from dcu_megatron.core.pipeline_parallel.combined_1f1b import (
     AbstractSchedulePlan,
     ScheduleNode,
     get_com_stream,
+    get_comp_stream,
     make_viewless,
 )
@@ -620,7 +622,6 @@ def schedule_chunk_1f1b(
     f_context = f_context if f_context is not None else contextlib.nullcontext()
     b_context = b_context if b_context is not None else contextlib.nullcontext()
     if f_schedule_plan:
         # pp output send/receive sync
         if pre_forward is not None:
@@ -709,7 +710,7 @@ def schedule_chunk_1f1b(
     if f_schedule_plan is not None and post_forward is not None:
         with f_context:
             f_schedule_plan.wait_current_stream()
-            post_forward(f_input)
+            post_forward(None if parallel_state.is_pipeline_last_stage(ignore_virtual=False) else f_input)
     # pp grad send / receive, overlapped with attn dw of cur micro-batch and forward attn of next micro-batch
     if b_schedule_plan is not None and post_backward is not None:
@@ -744,7 +745,7 @@ def build_model_chunk_schedule_plan(
     loss_mask: Optional[Tensor] = None
 ):
-    comp_stream = torch.cuda.current_stream()
+    comp_stream = get_comp_stream()
     com_stream = get_com_stream()
     model_chunk_schedule_plan = ModelChunkSchedulePlan()
     event = model_chunk_schedule_plan.event
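The last two hunks route the schedule plan onto a dedicated compute stream (get_comp_stream()) instead of the default torch.cuda.current_stream(), alongside the existing communication stream from get_com_stream(). The repository's stream helpers live in combined_1f1b.py but are not shown in this diff; the sketch below is only a guess at their shape, assuming module-level cached torch.cuda.Stream objects.

import torch

# Hypothetical cached streams; the real helpers in combined_1f1b.py may differ.
_COMP_STREAM = None
_COM_STREAM = None


def get_comp_stream():
    """Return a shared stream for compute kernels, creating it on first use."""
    global _COMP_STREAM
    if _COMP_STREAM is None:
        _COMP_STREAM = torch.cuda.Stream()
    return _COMP_STREAM


def get_com_stream():
    """Return a shared stream for communication kernels (all-to-all, p2p sends)."""
    global _COM_STREAM
    if _COM_STREAM is None:
        _COM_STREAM = torch.cuda.Stream()
    return _COM_STREAM

Keeping one stream per role lets schedule nodes launched on the compute stream overlap with collectives issued on the communication stream, which is the point of the fine-grained 1F1B schedule.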
dcu_megatron/core/pipeline_parallel/combined_1f1b.py
@@ -472,7 +472,10 @@ def forward_backward_step(
             else torch.tensor(1.0)
         )
         # Set the loss scale
-        MoEAuxLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)
+        if config.calculate_per_token_loss:
+            MoEAuxLossAutoScaler.set_loss_scale(loss_scale)
+        else:
+            MoEAuxLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)
     if not unwrap_output_tensor:
         output_tensor, num_tokens = [output_tensor], num_tokens
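The new branch distinguishes per-token from per-microbatch loss accounting: when config.calculate_per_token_loss is enabled, the MoE auxiliary-loss scale is no longer divided by num_microbatches, presumably because normalization by the token count happens later in the loss path. A toy restatement of the branch, with illustrative names only:

def moe_aux_loss_scale(loss_scale, num_microbatches, calculate_per_token_loss):
    """Return the scale passed to MoEAuxLossAutoScaler.set_loss_scale (toy version)."""
    if calculate_per_token_loss:
        # Token-count normalization is assumed to happen elsewhere in the loss path,
        # so the scale is used as-is.
        return loss_scale
    # Per-microbatch averaging: spread the scale across the N microbatches.
    return loss_scale / num_microbatches


assert moe_aux_loss_scale(1.0, 4, calculate_per_token_loss=False) == 0.25
assert moe_aux_loss_scale(1.0, 4, calculate_per_token_loss=True) == 1.0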
dcu_megatron/core/transformer/moe/token_dispatcher.py
@@ -25,13 +25,13 @@ class MoEAlltoAllPerBatchState:
         self.input_splits = None
         self.num_out_tokens = None
         self.capacity = None
-        self.preprocess_event = None
         self.hidden_shape = None
         self.probs = None
         self.routing_map = None
         self.reversed_local_input_permutation_mapping = None
         self.cuda_sync_point = None
         self.hidden_shape_before_permute = None
+        self.tokens_per_expert = None


 class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
@@ -44,7 +44,6 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
         state.input_splits = getattr(self, "input_splits", None)
         state.num_out_tokens = getattr(self, "num_out_tokens", None)
         state.capacity = getattr(self, "capacity", None)
-        state.preprocess_event = getattr(self, "preprocess_event", None)
         state.hidden_shape = getattr(self, "hidden_shape", None)
         state.probs = getattr(self, "probs", None)
         state.routing_map = getattr(self, "routing_map", None)
@@ -53,6 +52,7 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
         )
         state.hidden_shape_before_permute = getattr(self, "hidden_shape_before_permute", None)
         state.cuda_sync_point = getattr(self, "cuda_sync_point", None)
+        state.tokens_per_expert = getattr(self, "tokens_per_expert", None)

     def apply_per_batch_state(self, state: MoEAlltoAllPerBatchState):
         self.num_global_tokens_per_local_expert = state.num_global_tokens_per_local_expert
@@ -61,7 +61,6 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
         self.input_splits = state.input_splits
         self.num_out_tokens = state.num_out_tokens
         self.capacity = state.capacity
-        self.preprocess_event = state.preprocess_event
         self.hidden_shape = state.hidden_shape
         self.probs = state.probs
         self.routing_map = state.routing_map
@@ -70,6 +69,7 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
         )
         self.hidden_shape_before_permute = state.hidden_shape_before_permute
         self.cuda_sync_point = state.cuda_sync_point
+        self.tokens_per_expert = state.tokens_per_expert

     @contextmanager
     def per_batch_state_context(self, state: MoEAlltoAllPerBatchState):
@@ -144,6 +144,7 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
             output_split_sizes = None
         else:
             output_split_sizes = self.output_splits_tp.tolist()
         global_input_tokens = gather_from_sequence_parallel_region(
             global_input_tokens, group=self.tp_group, output_split_sizes=output_split_sizes
         )
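Taken together, the token_dispatcher.py hunks drop preprocess_event from the per-batch snapshot and add tokens_per_expert, so the per-expert token counts survive the capture/apply round-trip performed around per_batch_state_context. A stripped-down sketch of that save/restore pattern follows; the field list is trimmed, the capture method's name is not visible in this diff (capture_per_batch_state is a placeholder), and the entry/exit ordering inside the context manager is an assumption.

from contextlib import contextmanager


class PerBatchState:
    """Toy stand-in for MoEAlltoAllPerBatchState with only two fields."""

    def __init__(self):
        self.input_splits = None
        self.tokens_per_expert = None  # field added by this commit


class ToyDispatcher:
    def capture_per_batch_state(self, state):  # hypothetical name
        # Snapshot whatever the dispatcher currently holds for this micro-batch.
        state.input_splits = getattr(self, "input_splits", None)
        state.tokens_per_expert = getattr(self, "tokens_per_expert", None)

    def apply_per_batch_state(self, state):
        # Restore a previously captured snapshot onto the dispatcher.
        self.input_splits = state.input_splits
        self.tokens_per_expert = state.tokens_per_expert

    @contextmanager
    def per_batch_state_context(self, state):
        # Assumed ordering: restore the snapshot, run the block, re-capture on exit.
        self.apply_per_batch_state(state)
        try:
            yield self
        finally:
            self.capture_per_batch_state(state)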
dcu_megatron/core/transformer/transformer_layer.py
@@ -182,7 +182,7 @@ class TransformerLayer(MegatronCoreTransformerLayer):
         return output

-    def _submodule_moe_forward(self, tokens_per_expert, global_input_tokens, global_prob, hidden_states):
+    def _submodule_moe_forward(self, tokens_per_expert, global_input_tokens, global_prob, pre_mlp_layernorm_output):
         """
         Performs a forward pass for the MLP submodule, including both expert-based
         and optional shared-expert computations.
@@ -194,7 +194,7 @@ class TransformerLayer(MegatronCoreTransformerLayer):
         expert_output, mlp_bias = self.mlp.experts(dispatched_input, tokens_per_expert, permuted_probs)
         expert_output = self.mlp.token_dispatcher.combine_preprocess(expert_output)
         if self.mlp.use_shared_expert and not self.mlp.shared_expert_overlap:
-            shared_expert_output = self.mlp.shared_experts(hidden_states)
+            shared_expert_output = self.mlp.shared_experts(pre_mlp_layernorm_output)
         return expert_output, shared_expert_output, mlp_bias

     def _submodule_combine_forward(self, hidden_states):
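The transformer_layer.py change renames the last parameter of _submodule_moe_forward from hidden_states to pre_mlp_layernorm_output, making explicit that the shared experts consume the normalized pre-MLP activations. A small, self-contained toy of that routed-plus-shared-expert pattern; every module below is a stand-in, not one of the repository's classes.

import torch
import torch.nn as nn


class ToyMoEBlock(nn.Module):
    def __init__(self, hidden=8):
        super().__init__()
        self.routed_expert = nn.Linear(hidden, hidden)  # stands in for self.mlp.experts
        self.shared_expert = nn.Linear(hidden, hidden)  # stands in for self.mlp.shared_experts
        self.use_shared_expert = True

    def forward(self, dispatched_input, pre_mlp_layernorm_output):
        expert_output = self.routed_expert(dispatched_input)
        shared_expert_output = None
        if self.use_shared_expert:
            # As in the commit: shared experts see the pre-MLP layernorm output.
            shared_expert_output = self.shared_expert(pre_mlp_layernorm_output)
        return expert_output, shared_expert_output


block = ToyMoEBlock()
x = torch.randn(4, 8)
routed_out, shared_out = block(x, x)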