evt_fugx1 / dcu_megatron · Commits

Commit 32ee381a, authored May 06, 2025 by dongcl
Parent: 12b56c98

a2a overlap
Showing 6 changed files with 1025 additions and 879 deletions (+1025 / -879)
dcu_megatron/core/models/gpt/fine_grained_schedule.py    +342 / -320
dcu_megatron/core/models/gpt/gpt_model.py                +59 / -11
dcu_megatron/core/pipeline_parallel/combined_1f1b.py     +122 / -216
dcu_megatron/core/transformer/moe/token_dispatcher.py    +191 / -182
dcu_megatron/core/transformer/transformer_layer.py       +279 / -150
dcu_megatron/core/transformer/utils.py                   +32 / -0
dcu_megatron/core/models/gpt/fine_grained_schedule.py
This diff is collapsed.
dcu_megatron/core/models/gpt/gpt_model.py
...
@@ -9,6 +9,7 @@ from torch import Tensor
 from megatron.core import InferenceParams, tensor_parallel
 from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
+from megatron.core.inference.contexts import BaseInferenceContext
 from megatron.core.packed_seq_params import PackedSeqParams
 from megatron.core.models.gpt import GPTModel as MegatronCoreGPTModel
...
@@ -64,11 +65,14 @@ class GPTModel(MegatronCoreGPTModel):
         attention_mask: Tensor,
         decoder_input: Tensor = None,
         labels: Tensor = None,
-        inference_params: InferenceParams = None,
+        inference_context: BaseInferenceContext = None,
         packed_seq_params: PackedSeqParams = None,
         extra_block_kwargs: dict = None,
         runtime_gather_output: Optional[bool] = None,
+        *,
+        inference_params: Optional[BaseInferenceContext] = None,
+        loss_mask: Optional[Tensor] = None,
     ):
         """Builds a computation schedule plan for the model.
...
@@ -105,10 +109,12 @@ class GPTModel(MegatronCoreGPTModel):
             attention_mask,
             decoder_input=decoder_input,
             labels=labels,
-            inference_params=inference_params,
+            inference_context=inference_context,
             packed_seq_params=packed_seq_params,
             extra_block_kwargs=extra_block_kwargs,
             runtime_gather_output=runtime_gather_output,
+            inference_params=inference_params,
+            loss_mask=loss_mask,
         )

     def forward(
...
@@ -118,14 +124,16 @@ class GPTModel(MegatronCoreGPTModel):
         attention_mask: Tensor,
         decoder_input: Tensor = None,
         labels: Tensor = None,
-        inference_params: InferenceParams = None,
+        inference_context: BaseInferenceContext = None,
         packed_seq_params: PackedSeqParams = None,
         extra_block_kwargs: dict = None,
         runtime_gather_output: Optional[bool] = None,
+        *,
+        inference_params: Optional[BaseInferenceContext] = None,
+        loss_mask: Optional[Tensor] = None,
     ) -> Tensor:
         """Forward function of the GPT Model This function passes the input tensors
         through the embedding layer, and then the decoder and finally into the post
         processing layer (optional).

         It either returns the Loss values if labels are given or the final hidden units
...
@@ -137,6 +145,8 @@ class GPTModel(MegatronCoreGPTModel):
         # If decoder_input is provided (not None), then input_ids and position_ids are ignored.
         # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.

+        inference_context = deprecate_inference_params(inference_context, inference_params)
+
         # Decoder embedding.
         if decoder_input is not None:
             pass
...
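The added deprecate_inference_params call is what keeps the legacy inference_params keyword working while the code above migrates to inference_context. The helper itself is not part of this diff; a minimal sketch of the fallback behaviour it presumably implements (prefer the new argument, fall back to the old one with a warning) is:

import warnings


def deprecate_inference_params(inference_context, inference_params):
    """Sketch of the assumed fallback: prefer `inference_context`, warn on `inference_params`."""
    if inference_context is None and inference_params is not None:
        warnings.warn(
            "`inference_params` is deprecated; pass `inference_context` instead.",
            DeprecationWarning,
        )
        return inference_params
    return inference_context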
@@ -152,39 +162,64 @@ class GPTModel(MegatronCoreGPTModel):
         rotary_pos_cos = None
         rotary_pos_sin = None
         if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
-            if not self.training and self.config.flash_decode and inference_params:
+            if not self.training and self.config.flash_decode and inference_context:
+                assert (
+                    inference_context.is_static_batching()
+                ), "GPTModel currently only supports static inference batching."
                 # Flash decoding uses precomputed cos and sin for RoPE
                 rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb_cache.setdefault(
-                    inference_params.max_sequence_length,
-                    self.rotary_pos_emb.get_cos_sin(inference_params.max_sequence_length),
+                    inference_context.max_sequence_length,
+                    self.rotary_pos_emb.get_cos_sin(inference_context.max_sequence_length),
                 )
             else:
                 rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
-                    inference_params, self.decoder, decoder_input, self.config, packed_seq_params
+                    inference_context, self.decoder, decoder_input, self.config, packed_seq_params
                 )
                 rotary_pos_emb = self.rotary_pos_emb(
                     rotary_seq_len,
                     packed_seq=packed_seq_params is not None
                     and packed_seq_params.qkv_format == 'thd',
                 )
         elif self.position_embedding_type == 'mrope' and not self.config.multi_latent_attention:
             if self.training or not self.config.flash_decode:
                 rotary_pos_emb = self.rotary_pos_emb(position_ids, self.mrope_section)
             else:
                 # Flash decoding uses precomputed cos and sin for RoPE
                 raise NotImplementedError(
                     "Flash decoding uses precomputed cos and sin for RoPE, not implmented in "
                     "MultimodalRotaryEmbedding yet."
                 )

         if (
             (self.config.enable_cuda_graph or self.config.flash_decode)
             and rotary_pos_cos is not None
-            and inference_params
+            and inference_context
+            and inference_context.is_static_batching()
             and not self.training
         ):
             sequence_len_offset = torch.tensor(
-                [inference_params.sequence_len_offset] * inference_params.current_batch_size,
+                [inference_context.sequence_len_offset] * inference_context.current_batch_size,
                 dtype=torch.int32,
                 device=rotary_pos_cos.device,  # Co-locate this with the rotary tensors
             )
         else:
             sequence_len_offset = None

+        # Wrap decoder_input to allow the decoder (TransformerBlock) to delete the
+        # reference held by this caller function, enabling early garbage collection for
+        # inference. Skip wrapping if decoder_input is logged after decoder completion.
+        if (
+            inference_context is not None
+            and not self.training
+            and not has_config_logger_enabled(self.config)
+        ):
+            decoder_input = WrappedTensor(decoder_input)
+
         # Run decoder.
         hidden_states = self.decoder(
             hidden_states=decoder_input,
             attention_mask=attention_mask,
-            inference_params=inference_params,
+            inference_context=inference_context,
             rotary_pos_emb=rotary_pos_emb,
             rotary_pos_cos=rotary_pos_cos,
             rotary_pos_sin=rotary_pos_sin,
...
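WrappedTensor is imported from Megatron-core rather than defined in this commit; its purpose in the block above is to let the decoder drop the caller's reference to decoder_input so the embedding activation can be freed early during inference. A minimal sketch of such a wrapper (an assumption about its behaviour, not the actual implementation) is:

import torch


class WrappedTensor:
    """Sketch: hold a tensor indirectly so a callee can release the caller's reference."""

    def __init__(self, tensor: torch.Tensor):
        self._ref = [tensor]

    def unwrap(self) -> torch.Tensor:
        # Popping drops the wrapper's reference; once the callee's local variable
        # also goes out of scope, the tensor becomes eligible for garbage collection.
        return self._ref.pop()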
@@ -193,6 +228,12 @@ class GPTModel(MegatronCoreGPTModel):
             **(extra_block_kwargs or {}),
         )

+        # Process inference output.
+        if inference_context and not inference_context.is_static_batching():
+            hidden_states = inference_context.last_token_logits(
+                hidden_states.squeeze(1).unsqueeze(0)
+            ).unsqueeze(1)
+
         # logits and loss
         output_weight = None
         if self.share_embeddings_and_output_weights:
...
@@ -230,6 +271,13 @@ class GPTModel(MegatronCoreGPTModel):
         if not self.post_process:
             return hidden_states

+        if (
+            not self.training
+            and inference_context is not None
+            and inference_context.is_static_batching()
+            and inference_context.materialize_only_last_token_logits
+        ):
+            hidden_states = hidden_states[-1:, :, :]
+
         logits, _ = self.output_layer(
             hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
         )
...
dcu_megatron/core/pipeline_parallel/combined_1f1b.py
This diff is collapsed.
dcu_megatron/core/transformer/moe/token_dispatcher.py
-from megatron.core.transformer.moe.token_dispatcher import _DeepepManager as MegatronCoreDeepepManager
+from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher as MegatronCoreMoEAlltoAllTokenDispatcher


+# decouple perbatch state from MoEAlltoAllTokenDispatcher
+class MoEAlltoAllPerBatchState:
+    def __init__(self, build_event=False):
+        self.num_global_tokens_per_local_expert = None
+        self.output_splits_tp = None
+        self.output_splits = None
+        self.input_splits = None
+        self.num_out_tokens = None
+        self.capacity = None
+        self.preprocess_event = None
+        self.hidden_shape = None
+        self.probs = None
+        self.routing_map = None
+        self.reversed_local_input_permutation_mapping = None
+        self.cuda_sync_point = None
+        self.hidden_shape_before_permute = None


-class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
+class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
+
+    def collect_per_batch_state(self, state: MoEAlltoAllPerBatchState):
+        state.num_global_tokens_per_local_expert = getattr(self, "num_global_tokens_per_local_expert", None)
+        state.output_splits_tp = getattr(self, "output_splits_tp", None)
+        state.output_splits = getattr(self, "output_splits", None)
+        state.input_splits = getattr(self, "input_splits", None)
+        state.num_out_tokens = getattr(self, "num_out_tokens", None)
+        state.capacity = getattr(self, "capacity", None)
+        state.preprocess_event = getattr(self, "preprocess_event", None)
+        state.hidden_shape = getattr(self, "hidden_shape", None)
+        state.probs = getattr(self, "probs", None)
+        state.routing_map = getattr(self, "routing_map", None)
+        state.reversed_local_input_permutation_mapping = getattr(self, "reversed_local_input_permutation_mapping", None)
+        state.hidden_shape_before_permute = getattr(self, "hidden_shape_before_permute", None)
+        state.cuda_sync_point = getattr(self, "cuda_sync_point", None)
+
+    def apply_per_batch_state(self, state: MoEAlltoAllPerBatchState):
+        self.num_global_tokens_per_local_expert = state.num_global_tokens_per_local_expert
+        self.output_splits_tp = state.output_splits_tp
+        self.output_splits = state.output_splits
+        self.input_splits = state.input_splits
+        self.num_out_tokens = state.num_out_tokens
+        self.capacity = state.capacity
+        self.preprocess_event = state.preprocess_event
+        self.hidden_shape = state.hidden_shape
+        self.probs = state.probs
+        self.routing_map = state.routing_map
+        self.reversed_local_input_permutation_mapping = (
+            state.reversed_local_input_permutation_mapping
+        )
+        self.hidden_shape_before_permute = state.hidden_shape_before_permute
+        self.cuda_sync_point = state.cuda_sync_point
+
+    @contextmanager
+    def per_batch_state_context(self, state: MoEAlltoAllPerBatchState):
+        origin_state = MoEAlltoAllPerBatchState()
+        self.collect_per_batch_state(origin_state)
+        try:
+            self.apply_per_batch_state(state)
+            yield
+        finally:
+            self.collect_per_batch_state(state)
+            self.apply_per_batch_state(origin_state)
-    def token_permutation(
+    def meta_prepare(
         self, hidden_states: torch.Tensor, probs: torch.Tensor, routing_map: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    ):
         """
         Dispatch tokens to local experts using AlltoAll communication.

         This method performs the following steps:
         1. Preprocess the routing map to get metadata for communication and permutation.
         2. Permute input tokens for AlltoAll communication.
         3. Perform expert parallel AlltoAll communication.
         4. Sort tokens by local expert (if multiple local experts exist).

         Args:
             hidden_states (torch.Tensor): Input token embeddings.
             probs (torch.Tensor): The probabilities of token to experts assignment.
             routing_map (torch.Tensor): The mapping of token to experts assignment.

         Returns:
             Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                 - Permuted token embeddings for local experts.
                 - Number of tokens per expert.
                 - Permuted probs of each token produced by the router.
         """
         # Preprocess: Get the metadata for communication, permutation and computation operations.
         self.hidden_shape = hidden_states.shape
         self.probs = probs
         self.routing_map = routing_map
         assert probs.dim() == 2, "Expected 2D tensor for probs"
         assert routing_map.dim() == 2, "Expected 2D tensor for token2expert mask"
         assert routing_map.dtype == torch.bool, "Expected bool tensor for mask"
         hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
         tokens_per_expert = self.preprocess(self.routing_map)
+        return tokens_per_expert
+
+    def dispatch_preprocess(
+        self, hidden_states: torch.Tensor, routing_map: torch.Tensor, tokens_per_expert: torch.Tensor
+    ):
+        hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
         if self.shared_experts is not None:
             self.shared_experts.pre_forward_comm(hidden_states.view(self.hidden_shape))
...
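A hypothetical usage sketch of the per-batch state decoupling above (not code from this commit): by giving each in-flight micro-batch its own MoEAlltoAllPerBatchState, a fine-grained schedule can interleave dispatch steps of different micro-batches on one dispatcher without their splits, routing maps, or events overwriting each other. The dispatcher and tensors below are placeholders.

# Placeholders: `dispatcher` is a MoEAlltoAllTokenDispatcher, the tensors come from the router.
state_a, state_b = MoEAlltoAllPerBatchState(), MoEAlltoAllPerBatchState()

with dispatcher.per_batch_state_context(state_a):
    tokens_per_expert_a = dispatcher.meta_prepare(hidden_a, probs_a, routing_map_a)

with dispatcher.per_batch_state_context(state_b):
    tokens_per_expert_b = dispatcher.meta_prepare(hidden_b, probs_b, routing_map_b)

# Micro-batch A resumes later with its own metadata restored onto the dispatcher.
with dispatcher.per_batch_state_context(state_a):
    out_a = dispatcher.dispatch_preprocess(hidden_a, routing_map_a, tokens_per_expert_a)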
@@ -49,12 +98,15 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
         ) = permute(
             hidden_states,
             routing_map,
-            probs=probs,
+            probs=self.probs,
             num_out_tokens=self.num_out_tokens,
             fused=self.config.moe_permute_fusion,
             drop_and_pad=self.drop_and_pad,
         )
+        return tokens_per_expert, permutated_local_input_tokens, permuted_probs
+
+    def dispatch_all_to_all(self, tokens_per_expert, permutated_local_input_tokens, permuted_probs):
         # Perform expert parallel AlltoAll communication
         tokens_per_expert = self._maybe_dtoh_and_synchronize("before_ep_alltoall", tokens_per_expert
...
@@ -65,6 +117,10 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
         global_probs = all_to_all(
             self.ep_group, permuted_probs, self.output_splits, self.input_splits
         )
+        return tokens_per_expert, global_input_tokens, global_probs
+
+    def dispatch_postprocess(self, tokens_per_expert, global_input_tokens, global_probs):
         if self.shared_experts is not None:
             self.shared_experts.linear_fc1_forward_and_act(global_input_tokens)
...
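Splitting token_permutation into dispatch_preprocess / dispatch_all_to_all / dispatch_postprocess isolates the expert-parallel all-to-all so a schedule can overlap it with independent computation, which is the point of this commit. The snippet below is only a generic illustration of that comm/compute overlap pattern on a side CUDA stream; it is not code from this repository and uses a stand-in kernel in place of the real all-to-all.

import torch

comm_stream = torch.cuda.Stream()
x = torch.randn(4096, 4096, device="cuda")
y = torch.randn(4096, 4096, device="cuda")

# Make the side stream wait until x has been produced on the default stream.
comm_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(comm_stream):
    shuffled = x.roll(shifts=1, dims=0)  # stand-in for dispatch_all_to_all(...)

z = y @ y  # independent compute that overlaps the side-stream "communication"

# Re-join before consuming the side-stream result.
torch.cuda.current_stream().wait_stream(comm_stream)
result = shuffled + z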
@@ -118,184 +174,137 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
         )
         tokens_per_expert = self._maybe_dtoh_and_synchronize("before_finish", tokens_per_expert)
         return global_input_tokens, tokens_per_expert, global_probs

-class _DeepepManager(MegatronCoreDeepepManager):
-    """
-    patch megatron _DeepepManager. async
-    """
-
-    def dispatch(
-        self,
-        hidden_states: torch.Tensor,
-        async_finish: bool = False,
-        allocate_on_comm_stream: bool = False,
-    ) -> torch.Tensor:
-        # DeepEP only supports float32 probs
-        if self.token_probs.dtype != torch.float32:
-            if self.token_probs.dtype in [torch.bfloat16, torch.float16]:
-                print("DeepEP only supports float32 probs, please set --moe-router-dtype=fp32")
-            self.token_probs = self.token_probs.float()  # downcast or upcast
-        hidden_states, dispatched_indices, dispatched_probs, num_tokens_per_expert, handle = (
-            fused_dispatch(
-                hidden_states,
-                self.token_indices,
-                self.token_probs,
-                self.num_experts,
-                self.group,
-                async_finish=async_finish,
-                allocate_on_comm_stream=allocate_on_comm_stream,
-            )
-        )
-        self.handle = handle
-        self.tokens_per_expert = num_tokens_per_expert
-        self.dispatched_indices = dispatched_indices
-        self.dispatched_probs = dispatched_probs
-        return hidden_states
-
-    def combine(
-        self,
-        hidden_states: torch.Tensor,
-        async_finish: bool = False,
-        allocate_on_comm_stream: bool = False,
-    ) -> torch.Tensor:
-        hidden_states, _ = fused_combine(
-            hidden_states,
-            self.group,
-            self.handle,
-            async_finish=async_finish,
-            allocate_on_comm_stream=allocate_on_comm_stream,
-        )
-        # Release the handle after combine operation
-        self.handle = None
-        return hidden_states
-
-
-class MoEFlexTokenDispatcher(MoETokenDispatcher):
-    """
-    Flex token dispatcher using DeepEP.
-    """
-
-    def dispatch_preprocess(
-        self, hidden_states: torch.Tensor, routing_map: torch.Tensor, probs: torch.Tensor
-    ):
-        """
-        Preprocesses the hidden states and routing information before dispatching tokens to experts.
-
-        Args:
-            hidden_states (torch.Tensor): Input hidden states to be processed
-            routing_map (torch.Tensor): Map indicating which expert each token should be routed to
-            probs (torch.Tensor): Routing probabilities for each token-expert pair
-
-        Returns:
-            Tuple containing:
-                - torch.Tensor: Reshaped hidden states
-                - torch.Tensor: Token probabilities from the communication manager
-                - None: Placeholder for compatibility
-        """
-        self.hidden_shape = hidden_states.shape
-        hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
-        # Initialize metadata
-        routing_map, probs = self._initialize_metadata(routing_map, probs)
-        self._comm_manager.setup_metadata(routing_map, probs)
-        return hidden_states, self._comm_manager.token_probs, None
-
-    def dispatch_all_to_all(
-        self,
-        hidden_states: torch.Tensor,
-        probs: torch.Tensor = None,
-        async_finish: bool = True,
-        allocate_on_comm_stream: bool = True,
-    ):
-        """
-        Performs all-to-all communication to dispatch tokens across expert parallel ranks.
-        """
-        return (
-            self._comm_manager.dispatch(hidden_states, async_finish, allocate_on_comm_stream),
-            self._comm_manager.dispatched_probs,
-        )
-
-    def dispatch_postprocess(self, hidden_states: torch.Tensor):
-        """
-        Post-processes the dispatched hidden states after all-to-all communication.
-
-        This method retrieves the permuted hidden states by experts, calculates the number of tokens
-        per expert, and returns the processed data ready for expert processing.
-        """
-        global_input_tokens, permuted_probs = (
-            self._comm_manager.get_permuted_hidden_states_by_experts(hidden_states)
-        )
-        tokens_per_expert = self._comm_manager.get_number_of_tokens_per_expert()
-        return global_input_tokens, tokens_per_expert, permuted_probs
-
-    def token_permutation(
-        self, hidden_states: torch.Tensor, probs: torch.Tensor, routing_map: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """
-        Permutes tokens according to the routing map and dispatches them to experts.
-
-        This method implements the token permutation process in three steps:
-        1. Preprocess the hidden states and routing information
-        2. Perform all-to-all communication to dispatch tokens
-        3. Post-process the dispatched tokens for expert processing
-        """
-        hidden_states, _, _ = self.dispatch_preprocess(hidden_states, routing_map, probs)
-        hidden_states, _ = self.dispatch_all_to_all(
-            hidden_states, async_finish=False, allocate_on_comm_stream=False
-        )
-        global_input_tokens, tokens_per_expert, permuted_probs = self.dispatch_postprocess(hidden_states)
-        return global_input_tokens, tokens_per_expert, permuted_probs
-
-    def combine_preprocess(self, hidden_states: torch.Tensor):
-        """
-        Pre-processes the hidden states before combining them after expert processing.
-
-        This method restores the hidden states to their original ordering before expert processing
-        by using the communication manager's restoration function.
-        """
-        hidden_states = self._comm_manager.get_restored_hidden_states_by_experts(hidden_states)
-        return hidden_states
-
-    def combine_all_to_all(
-        self,
-        hidden_states: torch.Tensor,
-        async_finish: bool = True,
-        allocate_on_comm_stream: bool = True,
-    ):
-        """
-        Performs all-to-all communication to combine tokens after expert processing.
-        """
-        return self._comm_manager.combine(hidden_states, async_finish, allocate_on_comm_stream)
-
-    def combine_postprocess(self, hidden_states: torch.Tensor):
-        """
-        Post-processes the combined hidden states after all-to-all communication.
-
-        This method reshapes the combined hidden states to match the original input shape.
-        """
-        return hidden_states.view(self.hidden_shape)
-
-    def token_unpermutation(
-        self, hidden_states: torch.Tensor, bias: Optional[torch.Tensor] = None
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        """
-        Reverses the token permutation process to restore the original token order.
-
-        This method implements the token unpermutation process in three steps:
-        1. Pre-process the hidden states to restore their original ordering
-        2. Perform all-to-all communication to combine tokens
-        3. Post-process the combined tokens to match the original input shape
-        """
-        assert bias is None, "Bias is not supported in MoEFlexTokenDispatcher"
-        hidden_states = self.combine_preprocess(hidden_states)
-        hidden_states = self.combine_all_to_all(hidden_states, False, False)
-        hidden_states = self.combine_postprocess(hidden_states)
-        return hidden_states, None

+    def token_permutation(
+        self, hidden_states: torch.Tensor, probs: torch.Tensor, routing_map: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Dispatch tokens to local experts using AlltoAll communication.
+
+        This method performs the following steps:
+        1. Preprocess the routing map to get metadata for communication and permutation.
+        2. Permute input tokens for AlltoAll communication.
+        3. Perform expert parallel AlltoAll communication.
+        4. Sort tokens by local expert (if multiple local experts exist).
+
+        Args:
+            hidden_states (torch.Tensor): Input token embeddings.
+            probs (torch.Tensor): The probabilities of token to experts assignment.
+            routing_map (torch.Tensor): The mapping of token to experts assignment.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+                - Permuted token embeddings for local experts.
+                - Number of tokens per expert.
+                - Permuted probs of each token produced by the router.
+        """
+        # Preprocess: Get the metadata for communication, permutation and computation operations.
+        tokens_per_expert = self.meta_prepare(hidden_states, probs, routing_map)
+        # Permutation 1: input to AlltoAll input
+        tokens_per_expert, permutated_local_input_tokens, permuted_probs = self.dispatch_preprocess(
+            hidden_states, routing_map, tokens_per_expert
+        )
+        # Perform expert parallel AlltoAll communication
+        tokens_per_expert, global_input_tokens, global_probs = self.dispatch_all_to_all(
+            tokens_per_expert, permutated_local_input_tokens, permuted_probs
+        )
+        # Permutation 2: Sort tokens by local expert.
+        global_input_tokens, tokens_per_expert, global_probs = self.dispatch_postprocess(
+            tokens_per_expert, global_input_tokens, global_probs
+        )
+        return global_input_tokens, tokens_per_expert, global_probs
+
+    def combine_preprocess(self, hidden_states):
+        # Unpermutation 2: Unsort tokens by local expert.
+        if self.num_local_experts > 1:
+            if self.drop_and_pad:
+                hidden_states = (
+                    hidden_states.view(
+                        self.num_local_experts,
+                        self.tp_size * self.ep_size,
+                        self.capacity,
+                        *hidden_states.size()[1:],
+                    )
+                    .transpose(0, 1)
+                    .contiguous()
+                    .flatten(start_dim=0, end_dim=2)
+                )
+            else:
+                hidden_states, _ = sort_chunks_by_idxs(
+                    hidden_states,
+                    self.num_global_tokens_per_local_expert.T.ravel(),
+                    self.restore_output_by_local_experts,
+                    fused=self.config.moe_permute_fusion,
+                )
+
+        if self.tp_size > 1:
+            if self.output_splits_tp is None:
+                input_split_sizes = None
+            else:
+                input_split_sizes = self.output_splits_tp.tolist()
+            # The precision of TP reduce_scatter should be the same as the router_dtype
+            hidden_states = reduce_scatter_to_sequence_parallel_region(
+                hidden_states.to(self.probs.dtype),
+                group=self.tp_group,
+                input_split_sizes=input_split_sizes,
+            ).to(hidden_states.dtype)
+
+        return hidden_states
+
+    def combine_all_to_all(self, hidden_states):
+        # Perform expert parallel AlltoAll communication
+        # hidden_states: [SEQL, H] -> [SEQL, H/TP]
+        permutated_local_input_tokens = all_to_all(
+            self.ep_group, hidden_states, self.input_splits, self.output_splits
+        )
+        return permutated_local_input_tokens
+
+    def combine_postprocess(self, permutated_local_input_tokens):
+        if self.shared_experts is not None:
+            self.shared_experts.linear_fc2_forward(permutated_local_input_tokens)
+            self.shared_experts.post_forward_comm()
+
+        # Unpermutation 1: AlltoAll output to output
+        output = unpermute(
+            permutated_local_input_tokens,
+            self.reversed_local_input_permutation_mapping,
+            restore_shape=self.hidden_shape_before_permute,
+            routing_map=self.routing_map,
+            fused=self.config.moe_permute_fusion,
+            drop_and_pad=self.drop_and_pad,
+        )
+
+        # Reshape the output tensor
+        output = output.view(self.hidden_shape)
+
+        # Add shared experts output
+        if self.shared_experts is not None:
+            shared_expert_output = self.shared_experts.get_output()
+            output += shared_expert_output
+        return output
+
+    def token_unpermutation(
+        self, hidden_states: torch.Tensor, bias: Optional[torch.Tensor] = None
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """
+        Reverse the token permutation to restore the original order.
+
+        This method performs the following steps:
+        1. Unsort tokens by local expert (if multiple local experts exist).
+        2. Perform expert parallel AlltoAll communication to restore the original order.
+        3. Unpermute tokens to restore the original order.
+
+        Args:
+            hidden_states (torch.Tensor): Output from local experts.
+            bias (torch.Tensor, optional): Bias tensor (not supported).
+
+        Returns:
+            Tuple[torch.Tensor, Optional[torch.Tensor]]:
+                - Unpermuted token embeddings in the original order.
+                - None (bias is not supported).
+        """
+        assert bias is None, "Bias is not supported in MoEAlltoAllTokenDispatcher"
+        hidden_states = self.combine_preprocess(hidden_states)
+        permutated_local_input_tokens = self.combine_all_to_all(hidden_states)
+        output = self.combine_postprocess(permutated_local_input_tokens)
+        return output, None
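Taken together, the dispatcher's public token_permutation / token_unpermutation remain thin wrappers, while a fine-grained schedule (such as the combined 1F1B code in this commit) can drive the individual steps around expert computation. A hypothetical end-to-end flow for one MoE layer, with dispatcher and experts as placeholders:

# Dispatch side.
tokens_per_expert = dispatcher.meta_prepare(hidden_states, probs, routing_map)
tokens_per_expert, permuted, permuted_probs = dispatcher.dispatch_preprocess(
    hidden_states, routing_map, tokens_per_expert
)
tokens_per_expert, global_tokens, global_probs = dispatcher.dispatch_all_to_all(
    tokens_per_expert, permuted, permuted_probs
)
global_tokens, tokens_per_expert, global_probs = dispatcher.dispatch_postprocess(
    tokens_per_expert, global_tokens, global_probs
)

expert_output = experts(global_tokens, tokens_per_expert, global_probs)  # placeholder

# Combine side.
expert_output = dispatcher.combine_preprocess(expert_output)
expert_output = dispatcher.combine_all_to_all(expert_output)
output = dispatcher.combine_postprocess(expert_output)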
dcu_megatron/core/transformer/transformer_layer.py
This diff is collapsed.
dcu_megatron/core/transformer/utils.py (new file, 0 → 100644)
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class SubmoduleCallables:
    """
    Holds references to forward, dgrad, and dw (weight-grad) callables
    for a particular submodule.
    """

    forward: Optional[Callable] = None
    backward: Optional[Callable] = None
    dgrad: Optional[Callable] = None
    dw: Optional[Callable] = None


@dataclass
class TransformerLayerSubmoduleCallables:
    """
    Collects the SubmoduleCallables for each of the submodules:
    'attention', 'dispatch', 'mlp', 'combine'.
    """

    attention: SubmoduleCallables
    dispatch: SubmoduleCallables
    mlp: SubmoduleCallables
    combine: SubmoduleCallables
    post_combine: SubmoduleCallables

    def as_array(self):
        return [self.attention, self.dispatch, self.mlp, self.combine, self.post_combine]
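A hypothetical example of how these containers might be populated and consumed (the actual wiring lives in fine_grained_schedule.py and transformer_layer.py, whose diffs are collapsed above; the no-op callables here are placeholders):

def noop(*args, **kwargs):
    return None


layer_callables = TransformerLayerSubmoduleCallables(
    attention=SubmoduleCallables(forward=noop, dgrad=noop, dw=noop),
    dispatch=SubmoduleCallables(forward=noop, dgrad=noop),
    mlp=SubmoduleCallables(forward=noop, dgrad=noop, dw=noop),
    combine=SubmoduleCallables(forward=noop, dgrad=noop),
    post_combine=SubmoduleCallables(forward=noop),
)

# A schedule can then walk the submodules in execution order.
for submodule in layer_callables.as_array():
    if submodule.forward is not None:
        submodule.forward()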