Project: evt_fugx1 / dcu_megatron

Commit 12b56c98, authored Apr 30, 2025 by dongcl (parent 8551c38e)

    support a2a overlap

Showing 6 changed files with 2139 additions and 125 deletions (+2139, -125):
  dcu_megatron/core/models/gpt/fine_grained_schedule.py    +778    -0
  dcu_megatron/core/models/gpt/gpt_model.py                 +191  -125
  dcu_megatron/core/pipeline_parallel/combined_1f1b.py      +625    -0
  dcu_megatron/core/transformer/moe/token_dispatcher.py     +301    -0
  dcu_megatron/core/transformer/transformer_block.py         +20    -0
  dcu_megatron/core/transformer/transformer_layer.py        +224    -0
dcu_megatron/core/models/gpt/fine_grained_schedule.py  (new file, mode 100644, +778 -0)

This diff is collapsed in the source view; its contents are not reproduced here.
dcu_megatron/core/models/gpt/gpt_model.py  (+191 -125)

@@ -10,6 +10,7 @@ from torch import Tensor

The import block gains one line, the upstream GPTModel imported under an alias:

from megatron.core import InferenceParams, tensor_parallel
from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.models.gpt import GPTModel as MegatronCoreGPTModel

from dcu_megatron.core.tensor_parallel import FluxColumnParallelLinear

@@ -45,100 +46,143 @@ def gpt_model_init_wrapper(fn):
@@ -146,44 +190,66 @@ def gpt_model_forward(

The former module-level gpt_model_forward patch becomes the forward method of a new GPTModel
subclass (its body is essentially unchanged apart from indentation), and two methods are added
as entry points for the fine-grained schedule: get_transformer_callables_by_layer and
build_schedule_plan. Reconstructed new code:

class GPTModel(MegatronCoreGPTModel):
    """
    patch megatron GPTModel
    """

    def get_transformer_callables_by_layer(self, layer_number: int):
        """
        Get the callables for the layer at the given transformer layer number.
        """
        return self.decoder.get_layer_callables(layer_number)

    def build_schedule_plan(
        self,
        input_ids: Tensor,
        position_ids: Tensor,
        attention_mask: Tensor,
        decoder_input: Tensor = None,
        labels: Tensor = None,
        inference_params: InferenceParams = None,
        packed_seq_params: PackedSeqParams = None,
        extra_block_kwargs: dict = None,
        runtime_gather_output: Optional[bool] = None,
        loss_mask: Optional[Tensor] = None,
    ):
        """Builds a computation schedule plan for the model.

        This function creates a schedule plan for a model chunk, including
        preprocessing, transformer layers, and postprocessing.
        The schedule plan is used to optimize computation and memory usage
        in distributed environments.

        Args:
            input_ids (Tensor): Input token IDs.
            position_ids (Tensor): Position IDs.
            attention_mask (Tensor): Attention mask.
            decoder_input (Tensor, optional): Decoder input tensor. Defaults to None.
            labels (Tensor, optional): Labels for loss computation. Defaults to None.
            inference_params (InferenceParams, optional):
                Parameters for inference. Defaults to None.
            packed_seq_params (PackedSeqParams, optional):
                Parameters for packed sequences. Defaults to None.
            extra_block_kwargs (dict, optional):
                Additional keyword arguments for blocks. Defaults to None.
            runtime_gather_output (Optional[bool], optional):
                Whether to gather output at runtime. Defaults to None.
            loss_mask (Optional[Tensor], optional): Loss mask. Defaults to None.

        Returns:
            ModelChunkSchedulePlan: The model chunk schedule plan.
        """
        from .fine_grained_schedule import build_model_chunk_schedule_plan

        return build_model_chunk_schedule_plan(
            self,
            input_ids,
            position_ids,
            attention_mask,
            decoder_input=decoder_input,
            labels=labels,
            loss_mask=loss_mask,
            inference_params=inference_params,
            packed_seq_params=packed_seq_params,
            extra_block_kwargs=extra_block_kwargs,
            runtime_gather_output=runtime_gather_output,
        )

    def forward(
        self,
        input_ids: Tensor,
        position_ids: Tensor,
        attention_mask: Tensor,
        decoder_input: Tensor = None,
        labels: Tensor = None,
        inference_params: InferenceParams = None,
        packed_seq_params: PackedSeqParams = None,
        extra_block_kwargs: dict = None,
        runtime_gather_output: Optional[bool] = None,
        loss_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """Forward function of the GPT Model. This function passes the input tensors
        through the embedding layer, then the decoder, and finally into the post
        processing layer (optional).

        It either returns the loss values if labels are given, or the final hidden units.

        Args:
            runtime_gather_output (bool): Gather output at runtime. Default None means
                the `parallel_output` arg in the constructor will be used.
        """
        # If decoder_input is provided (not None), then input_ids and position_ids are ignored.
        # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.

        # Decoder embedding.
        if decoder_input is not None:
            pass
        elif self.pre_process:
            decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)
        else:
            # intermediate stage of pipeline
            # decoder will get hidden_states from encoder.input_tensor
            decoder_input = None

        # Rotary positional embeddings (embedding is None for PP intermediate devices)
        rotary_pos_emb = None
        rotary_pos_cos = None
        rotary_pos_sin = None
        if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
            if not self.training and self.config.flash_decode and inference_params:
                # Flash decoding uses precomputed cos and sin for RoPE
                rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb_cache.setdefault(
                    inference_params.max_sequence_length,
                    self.rotary_pos_emb.get_cos_sin(inference_params.max_sequence_length),
                )
            else:
                rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
                    inference_params, self.decoder, decoder_input, self.config, packed_seq_params
                )
                rotary_pos_emb = self.rotary_pos_emb(
                    rotary_seq_len,
                    packed_seq=packed_seq_params is not None
                    and packed_seq_params.qkv_format == 'thd',
                )
        if (
            (self.config.enable_cuda_graph or self.config.flash_decode)
            and rotary_pos_cos is not None
            and inference_params
        ):
            sequence_len_offset = torch.tensor(
                [inference_params.sequence_len_offset] * inference_params.current_batch_size,
                dtype=torch.int32,
                device=rotary_pos_cos.device,  # Co-locate this with the rotary tensors
            )
        else:
            sequence_len_offset = None

        # Run decoder.
        hidden_states = self.decoder(
            hidden_states=decoder_input,
            attention_mask=attention_mask,
            inference_params=inference_params,
            rotary_pos_emb=rotary_pos_emb,
            rotary_pos_cos=rotary_pos_cos,
            rotary_pos_sin=rotary_pos_sin,
            packed_seq_params=packed_seq_params,
            sequence_len_offset=sequence_len_offset,
            **(extra_block_kwargs or {}),
        )

        # logits and loss
        output_weight = None
        if self.share_embeddings_and_output_weights:
            output_weight = self.shared_embedding_or_output_weight()

        if self.mtp_process:
            hidden_states = self.mtp(
                input_ids=input_ids,
                position_ids=position_ids,
                labels=labels,
                loss_mask=loss_mask,
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                inference_params=inference_params,
                rotary_pos_emb=rotary_pos_emb,
                rotary_pos_cos=rotary_pos_cos,
                rotary_pos_sin=rotary_pos_sin,
                packed_seq_params=packed_seq_params,
                sequence_len_offset=sequence_len_offset,
                embedding=self.embedding,
                output_layer=self.output_layer,
                output_weight=output_weight,
                runtime_gather_output=runtime_gather_output,
                compute_language_model_loss=self.compute_language_model_loss,
                **(extra_block_kwargs or {}),
            )

        if (
            self.mtp_process is not None
            and getattr(self.decoder, "main_final_layernorm", None) is not None
        ):
            # move block main model final norms here
            hidden_states = self.decoder.main_final_layernorm(hidden_states)

        if not self.post_process:
            return hidden_states

        logits, _ = self.output_layer(
            hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
        )

        if has_config_logger_enabled(self.config):
            payload = OrderedDict(
                {
                    'input_ids': input_ids,
                    'position_ids': position_ids,
                    'attention_mask': attention_mask,
                    'decoder_input': decoder_input,
                    'logits': logits,
                }
            )
            log_config_to_disk(self.config, payload, prefix='input_and_logits')

        if labels is None:
            # [s b h] => [b s h]
            return logits.transpose(0, 1).contiguous()

        loss = self.compute_language_model_loss(labels, logits)

        return loss
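For orientation, here is a minimal usage sketch of the two execution paths this class exposes. It is an illustration only: the batch dictionary keys are assumed placeholders, and the resulting plan would be handed to the executor in dcu_megatron/core/pipeline_parallel/combined_1f1b.py, whose diff is collapsed above.

def train_step_sketch(model, batch, use_fine_grained_schedule=False):
    # Plain path: the patched forward() behaves like the upstream GPTModel forward and
    # returns the loss when labels are supplied.
    if not use_fine_grained_schedule:
        return model(
            input_ids=batch["input_ids"],
            position_ids=batch["position_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"],
        )

    # Fine-grained path: build a ModelChunkSchedulePlan for this model chunk; the plan is
    # then driven by the combined-1F1B scheduler (see combined_1f1b.py, collapsed above).
    plan = model.build_schedule_plan(
        input_ids=batch["input_ids"],
        position_ids=batch["position_ids"],
        attention_mask=batch["attention_mask"],
        labels=batch["labels"],
    )
    return plan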
dcu_megatron/core/pipeline_parallel/combined_1f1b.py  (new file, mode 100644, +625 -0)

This diff is collapsed in the source view; its contents are not reproduced here.
dcu_megatron/core/transformer/moe/token_dispatcher.py  (new file, mode 100644, +301 -0)

from megatron.core.transformer.moe.token_dispatcher import _DeepepManager as MegatronCoreDeepepManager


class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):

    def token_permutation(
        self, hidden_states: torch.Tensor, probs: torch.Tensor, routing_map: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Dispatch tokens to local experts using AlltoAll communication.

        This method performs the following steps:
        1. Preprocess the routing map to get metadata for communication and permutation.
        2. Permute input tokens for AlltoAll communication.
        3. Perform expert parallel AlltoAll communication.
        4. Sort tokens by local expert (if multiple local experts exist).

        Args:
            hidden_states (torch.Tensor): Input token embeddings.
            probs (torch.Tensor): The probabilities of token to experts assignment.
            routing_map (torch.Tensor): The mapping of token to experts assignment.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                - Permuted token embeddings for local experts.
                - Number of tokens per expert.
                - Permuted probs of each token produced by the router.
        """
        # Preprocess: Get the metadata for communication, permutation and computation operations.
        self.hidden_shape = hidden_states.shape
        self.probs = probs
        self.routing_map = routing_map
        assert probs.dim() == 2, "Expected 2D tensor for probs"
        assert routing_map.dim() == 2, "Expected 2D tensor for token2expert mask"
        assert routing_map.dtype == torch.bool, "Expected bool tensor for mask"
        hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
        tokens_per_expert = self.preprocess(self.routing_map)

        if self.shared_experts is not None:
            self.shared_experts.pre_forward_comm(hidden_states.view(self.hidden_shape))

        # Permutation 1: input to AlltoAll input
        tokens_per_expert = self._maybe_dtoh_and_synchronize(
            "before_permutation_1", tokens_per_expert
        )
        self.hidden_shape_before_permute = hidden_states.shape
        (
            permutated_local_input_tokens,
            permuted_probs,
            self.reversed_local_input_permutation_mapping,
        ) = permute(
            hidden_states,
            routing_map,
            probs=probs,
            num_out_tokens=self.num_out_tokens,
            fused=self.config.moe_permute_fusion,
            drop_and_pad=self.drop_and_pad,
        )

        # Perform expert parallel AlltoAll communication
        tokens_per_expert = self._maybe_dtoh_and_synchronize(
            "before_ep_alltoall", tokens_per_expert
        )
        global_input_tokens = all_to_all(
            self.ep_group, permutated_local_input_tokens, self.output_splits, self.input_splits
        )
        global_probs = all_to_all(
            self.ep_group, permuted_probs, self.output_splits, self.input_splits
        )
        if self.shared_experts is not None:
            self.shared_experts.linear_fc1_forward_and_act(global_input_tokens)

        if self.tp_size > 1:
            if self.output_splits_tp is None:
                output_split_sizes = None
            else:
                output_split_sizes = self.output_splits_tp.tolist()
            global_input_tokens = gather_from_sequence_parallel_region(
                global_input_tokens, group=self.tp_group, output_split_sizes=output_split_sizes
            )
            global_probs = gather_from_sequence_parallel_region(
                global_probs, group=self.tp_group, output_split_sizes=output_split_sizes
            )

        # Permutation 2: Sort tokens by local expert.
        tokens_per_expert = self._maybe_dtoh_and_synchronize(
            "before_permutation_2", tokens_per_expert
        )
        if self.num_local_experts > 1:
            if self.drop_and_pad:
                global_input_tokens = (
                    global_input_tokens.view(
                        self.tp_size * self.ep_size,
                        self.num_local_experts,
                        self.capacity,
                        *global_input_tokens.size()[1:],
                    )
                    .transpose(0, 1)
                    .contiguous()
                    .flatten(start_dim=0, end_dim=2)
                )
                global_probs = (
                    global_probs.view(
                        self.tp_size * self.ep_size,
                        self.num_local_experts,
                        self.capacity,
                        *global_probs.size()[1:],
                    )
                    .transpose(0, 1)
                    .contiguous()
                    .flatten(start_dim=0, end_dim=2)
                )
            else:
                global_input_tokens, global_probs = sort_chunks_by_idxs(
                    global_input_tokens,
                    self.num_global_tokens_per_local_expert.ravel(),
                    self.sort_input_by_local_experts,
                    probs=global_probs,
                    fused=self.config.moe_permute_fusion,
                )

        tokens_per_expert = self._maybe_dtoh_and_synchronize("before_finish", tokens_per_expert)

        return global_input_tokens, tokens_per_expert, global_probs


class _DeepepManager(MegatronCoreDeepepManager):
    """
    patch megatron _DeepepManager. async
    """

    def dispatch(
        self,
        hidden_states: torch.Tensor,
        async_finish: bool = False,
        allocate_on_comm_stream: bool = False,
    ) -> torch.Tensor:
        # DeepEP only supports float32 probs
        if self.token_probs.dtype != torch.float32:
            if self.token_probs.dtype in [torch.bfloat16, torch.float16]:
                print("DeepEP only supports float32 probs, please set --moe-router-dtype=fp32")
            self.token_probs = self.token_probs.float()  # downcast or upcast
        hidden_states, dispatched_indices, dispatched_probs, num_tokens_per_expert, handle = (
            fused_dispatch(
                hidden_states,
                self.token_indices,
                self.token_probs,
                self.num_experts,
                self.group,
                async_finish=async_finish,
                allocate_on_comm_stream=allocate_on_comm_stream,
            )
        )
        self.handle = handle
        self.tokens_per_expert = num_tokens_per_expert
        self.dispatched_indices = dispatched_indices
        self.dispatched_probs = dispatched_probs

        return hidden_states

    def combine(
        self,
        hidden_states: torch.Tensor,
        async_finish: bool = False,
        allocate_on_comm_stream: bool = False,
    ) -> torch.Tensor:
        hidden_states, _ = fused_combine(
            hidden_states,
            self.group,
            self.handle,
            async_finish=async_finish,
            allocate_on_comm_stream=allocate_on_comm_stream,
        )
        # Release the handle after combine operation
        self.handle = None
        return hidden_states


class MoEFlexTokenDispatcher(MoETokenDispatcher):
    """
    Flex token dispatcher using DeepEP.
    """

    def dispatch_preprocess(
        self, hidden_states: torch.Tensor, routing_map: torch.Tensor, probs: torch.Tensor
    ):
        """
        Preprocesses the hidden states and routing information before dispatching tokens to experts.

        Args:
            hidden_states (torch.Tensor): Input hidden states to be processed
            routing_map (torch.Tensor): Map indicating which expert each token should be routed to
            probs (torch.Tensor): Routing probabilities for each token-expert pair

        Returns:
            Tuple containing:
                - torch.Tensor: Reshaped hidden states
                - torch.Tensor: Token probabilities from the communication manager
                - None: Placeholder for compatibility
        """
        self.hidden_shape = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_shape[-1])

        # Initialize metadata
        routing_map, probs = self._initialize_metadata(routing_map, probs)

        self._comm_manager.setup_metadata(routing_map, probs)
        return hidden_states, self._comm_manager.token_probs, None

    def dispatch_all_to_all(
        self,
        hidden_states: torch.Tensor,
        probs: torch.Tensor = None,
        async_finish: bool = True,
        allocate_on_comm_stream: bool = True,
    ):
        """
        Performs all-to-all communication to dispatch tokens across expert parallel ranks.
        """
        return (
            self._comm_manager.dispatch(hidden_states, async_finish, allocate_on_comm_stream),
            self._comm_manager.dispatched_probs,
        )

    def dispatch_postprocess(self, hidden_states: torch.Tensor):
        """
        Post-processes the dispatched hidden states after all-to-all communication.

        This method retrieves the permuted hidden states by experts, calculates the number of
        tokens per expert, and returns the processed data ready for expert processing.
        """
        global_input_tokens, permuted_probs = (
            self._comm_manager.get_permuted_hidden_states_by_experts(hidden_states)
        )
        tokens_per_expert = self._comm_manager.get_number_of_tokens_per_expert()
        return global_input_tokens, tokens_per_expert, permuted_probs

    def token_permutation(
        self, hidden_states: torch.Tensor, probs: torch.Tensor, routing_map: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Permutes tokens according to the routing map and dispatches them to experts.

        This method implements the token permutation process in three steps:
        1. Preprocess the hidden states and routing information
        2. Perform all-to-all communication to dispatch tokens
        3. Post-process the dispatched tokens for expert processing
        """
        hidden_states, _, _ = self.dispatch_preprocess(hidden_states, routing_map, probs)
        hidden_states, _ = self.dispatch_all_to_all(
            hidden_states, async_finish=False, allocate_on_comm_stream=False
        )
        global_input_tokens, tokens_per_expert, permuted_probs = self.dispatch_postprocess(
            hidden_states
        )
        return global_input_tokens, tokens_per_expert, permuted_probs

    def combine_preprocess(self, hidden_states: torch.Tensor):
        """
        Pre-processes the hidden states before combining them after expert processing.

        This method restores the hidden states to their original ordering before expert processing
        by using the communication manager's restoration function.
        """
        hidden_states = self._comm_manager.get_restored_hidden_states_by_experts(hidden_states)
        return hidden_states

    def combine_all_to_all(
        self,
        hidden_states: torch.Tensor,
        async_finish: bool = True,
        allocate_on_comm_stream: bool = True,
    ):
        """
        Performs all-to-all communication to combine tokens after expert processing.
        """
        return self._comm_manager.combine(hidden_states, async_finish, allocate_on_comm_stream)

    def combine_postprocess(self, hidden_states: torch.Tensor):
        """
        Post-processes the combined hidden states after all-to-all communication.

        This method reshapes the combined hidden states to match the original input shape.
        """
        return hidden_states.view(self.hidden_shape)

    def token_unpermutation(
        self, hidden_states: torch.Tensor, bias: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Reverses the token permutation process to restore the original token order.

        This method implements the token unpermutation process in three steps:
        1. Pre-process the hidden states to restore their original ordering
        2. Perform all-to-all communication to combine tokens
        3. Post-process the combined tokens to match the original input shape
        """
        assert bias is None, "Bias is not supported in MoEFlexTokenDispatcher"
        hidden_states = self.combine_preprocess(hidden_states)
        hidden_states = self.combine_all_to_all(hidden_states, False, False)
        hidden_states = self.combine_postprocess(hidden_states)
        return hidden_states, None
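The point of splitting dispatch and combine into preprocess / all-to-all / postprocess phases is that the all-to-all can be launched asynchronously and overlapped with independent compute. A minimal sketch follows, assuming `dispatcher` is a MoEFlexTokenDispatcher backed by the patched _DeepepManager and `other_compute` stands for work from another microbatch that a scheduler wants to run while tokens are in flight.

# Phase 1: reshape inputs and hand the routing metadata to the DeepEP comm manager.
tokens, token_probs, _ = dispatcher.dispatch_preprocess(hidden_states, routing_map, probs)

# Phase 2: launch the expert-parallel all-to-all; async_finish and allocate_on_comm_stream
# default to True here, so the call returns before the communication has completed.
dispatched_tokens, dispatched_probs = dispatcher.dispatch_all_to_all(tokens)

other_compute()  # independent work overlapped with the in-flight all-to-all

# Phase 3: consume the communication result and group tokens per local expert.
global_tokens, tokens_per_expert, permuted_probs = dispatcher.dispatch_postprocess(dispatched_tokens)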
dcu_megatron/core/transformer/transformer_block.py  (+20 -0)

@@ -13,3 +14,22 @@ def transformer_block_init_wrapper(fn):

The change imports the upstream TransformerBlock and appends a subclass after the existing
transformer_block_init_wrapper (whose tail, `self.final_layernorm = None` / `return wrapper`,
is shown as unchanged context):

from megatron.core.transformer.transformer_block import TransformerBlock as MegatronCoreTransformerBlock


class TransformerBlock(MegatronCoreTransformerBlock):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # mtp require seperate layernorms for main model and mtp modules, thus move finalnorm out of block
        config = args[0] if len(args) > 1 else kwargs['config']
        if getattr(config, "mtp_num_layers", 0) > 0:
            self.main_final_layernorm = self.final_layernorm
            self.final_layernorm = None

    def get_layer_callables(self, layer_number: int):
        """
        Get the callables for the layer at the given layer number.
        """
        return self.layers[layer_number].get_submodule_callables()
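A small sketch of the effect of this constructor change, assuming `block` was built from a config with mtp_num_layers > 0: the block no longer applies its own final layernorm, and the caller applies main_final_layernorm explicitly, mirroring the getattr check in the reconstructed GPTModel.forward above.

hidden_states = block(hidden_states=decoder_input, attention_mask=attention_mask)

# block.final_layernorm was set to None in __init__, so the main model's final norm
# is applied outside the block:
if getattr(block, "main_final_layernorm", None) is not None:
    hidden_states = block.main_final_layernorm(hidden_states)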
dcu_megatron/core/transformer/transformer_layer.py  (new file, mode 100644, +224 -0)

from megatron.core.transformer.transformer_layer import TransformerLayer as MegatronCoreTransformerLayer


class TransformerLayer(MegatronCoreTransformerLayer):

    def _submodule_attn_router_forward(
        self,
        hidden_states,
        attention_mask=None,
        inference_params=None,
        rotary_pos_emb=None,
        rotary_pos_cos=None,
        rotary_pos_sin=None,
        attention_bias=None,
        packed_seq_params=None,
        sequence_len_offset=None,
        state=None,
    ):
        """
        Performs a combined forward pass that includes self-attention and MLP routing logic.
        """
        hidden_states, _ = self._forward_attention(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            rotary_pos_emb=rotary_pos_emb,
            rotary_pos_cos=rotary_pos_cos,
            rotary_pos_sin=rotary_pos_sin,
            attention_bias=attention_bias,
            packed_seq_params=packed_seq_params,
            sequence_len_offset=sequence_len_offset,
            inference_params=inference_params,
        )
        pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states)
        probs, routing_map = self.mlp.router(pre_mlp_layernorm_output)
        local_tokens, probs, tokens_per_expert = self.mlp.token_dispatcher.dispatch_preprocess(
            pre_mlp_layernorm_output, routing_map, probs
        )
        return (local_tokens, probs, hidden_states, pre_mlp_layernorm_output, tokens_per_expert)

    def _submodule_dispatch_forward(self, local_tokens, probs, state=None):
        """
        Dispatches tokens to the appropriate experts based on the router output.
        """
        token_dispatcher = self.mlp.token_dispatcher
        if self.is_deepep:
            token_dispatcher._comm_manager.token_probs = probs
        return token_dispatcher.dispatch_all_to_all(local_tokens, probs)

    def _submodule_moe_forward(self, dispatched_tokens, probs=None, state=None):
        """
        Performs a forward pass for the MLP submodule, including both expert-based
        and optional shared-expert computations.
        """
        shared_expert_output = None
        token_dispatcher = self.mlp.token_dispatcher
        if self.is_deepep:
            token_dispatcher._comm_manager.dispatched_probs = state.dispatched_probs
            dispatched_tokens, tokens_per_expert, permuted_probs = (
                token_dispatcher.dispatch_postprocess(dispatched_tokens)
            )
        else:
            dispatched_tokens, permuted_probs = token_dispatcher.dispatch_postprocess(
                dispatched_tokens, probs
            )
            tokens_per_expert = state.tokens_per_expert

        expert_output, mlp_bias = self.mlp.experts(
            dispatched_tokens, tokens_per_expert, permuted_probs
        )
        assert mlp_bias is None, f"Bias is not supported in {token_dispatcher.__class__.__name__}"
        if self.mlp.use_shared_expert and not self.mlp.shared_expert_overlap:
            shared_expert_output = self.mlp.shared_experts(state.pre_mlp_layernorm_output)
        expert_output = self.mlp.token_dispatcher.combine_preprocess(expert_output)

        return expert_output, shared_expert_output, mlp_bias

    def _submodule_combine_forward(self, output, shared_expert_output=None, state=None):
        residual = state.residual
        token_dispatcher = self.mlp.token_dispatcher

        output = token_dispatcher.combine_all_to_all(output)
        output = token_dispatcher.combine_postprocess(output)

        if shared_expert_output is not None:
            output = output + shared_expert_output
        mlp_output_with_bias = (output, None)

        with self.bias_dropout_add_exec_handler():
            hidden_states = self.mlp_bda(self.training, self.config.bias_dropout_fusion)(
                mlp_output_with_bias, residual, self.hidden_dropout
            )

        output = make_viewless_tensor(
            inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True
        )
        return output

    def _submodule_attn_router_dw(self):
        self.self_attention.backward_dw()

    def _submodule_mlp_dw(self):
        self.mlp.backward_dw()

    def _submodule_attn_router_postprocess(
        self, node, local_tokens, probs, residual, pre_mlp_layernorm_output, tokens_per_expert
    ):
        node.common_state.residual = node.detach(residual)
        if self.mlp.use_shared_expert:
            node.common_state.pre_mlp_layernorm_output = node.detach(pre_mlp_layernorm_output)
        if not self.is_deepep:
            node.common_state.tokens_per_expert = tokens_per_expert
        return local_tokens, probs

    def _submodule_dispatch_postprocess(self, node, dispatched_tokens, probs):
        if self.is_deepep:
            node.common_state.dispatched_probs = node.detach(probs)
            return dispatched_tokens
        else:
            return dispatched_tokens, probs

    def _submodule_mlp_postprocess(self, node, expert_output, shared_expert_output, mlp_bias):
        assert mlp_bias is None
        node.common_state.pre_mlp_layernorm_output = None
        if shared_expert_output is None:
            return expert_output
        return expert_output, shared_expert_output

    def _submodule_combine_postprocess(self, node, output):
        cur_stream = torch.cuda.current_stream()
        node.common_state.residual.record_stream(cur_stream)
        node.common_state.residual = None
        return output

    def _submodule_attn_postprocess(self, node, hidden_states, context):
        return hidden_states

    def _submodule_dense_postprocess(self, node, hidden_states):
        return hidden_states

    def _submodule_not_implemented(self, *args):
        raise NotImplementedError("This callable is not implemented.")

    def get_submodule_callables(self, chunk_state):
        """
        The forward callables take 2 parts of inputs:
        1. The ScheduleNode object.
        2. The input tensors.
        """
        from megatron.core.transformer.moe.moe_layer import MoELayer
        from megatron.core.transformer.moe.token_dispatcher import MoEFlexTokenDispatcher

        self.is_moe = isinstance(self.mlp, MoELayer)
        self.is_deepep = False
        if self.is_moe:
            self.is_deepep = isinstance(self.mlp.token_dispatcher, MoEFlexTokenDispatcher)

        def get_func_with_default(func, default_func):
            if self.is_moe:
                return func
            return default_func

        def callable_wrapper(forward_func, postprocess_func, node, *args):
            state = getattr(node, 'common_state', None)
            callable_outputs = forward_func(*args, state=state)
            if isinstance(callable_outputs, tuple):
                outputs = postprocess_func(node, *callable_outputs)
            else:
                outputs = postprocess_func(node, callable_outputs)
            return outputs

        attn_func = get_func_with_default(
            self._submodule_attn_router_forward, self._forward_attention
        )

        def attn_wrapper(hidden_states, state=None):
            return attn_func(
                hidden_states=hidden_states,
                attention_mask=chunk_state.attention_mask,
                attention_bias=chunk_state.attention_bias,
                inference_params=chunk_state.inference_params,
                packed_seq_params=chunk_state.packed_seq_params,
                sequence_len_offset=chunk_state.sequence_len_offset,
                rotary_pos_emb=chunk_state.rotary_pos_emb,
                rotary_pos_cos=chunk_state.rotary_pos_cos,
                rotary_pos_sin=chunk_state.rotary_pos_sin,
                state=state,
            )

        attn_postprocess_func = get_func_with_default(
            self._submodule_attn_router_postprocess, self._submodule_attn_postprocess
        )
        dispatch_func = get_func_with_default(
            self._submodule_dispatch_forward, self._submodule_not_implemented
        )
        dispatch_postprocess_func = get_func_with_default(
            self._submodule_dispatch_postprocess, self._submodule_not_implemented
        )
        mlp_func = get_func_with_default(self._submodule_moe_forward, self._forward_mlp)
        mlp_postprocess_func = get_func_with_default(
            self._submodule_mlp_postprocess, self._submodule_dense_postprocess
        )
        combine_func = get_func_with_default(
            self._submodule_combine_forward, self._submodule_not_implemented
        )
        combine_postprocess_func = get_func_with_default(
            self._submodule_combine_postprocess, self._submodule_not_implemented
        )

        attn_forward = partial(callable_wrapper, attn_wrapper, attn_postprocess_func)
        dispatch_forward = partial(callable_wrapper, dispatch_func, dispatch_postprocess_func)
        mlp_forward = partial(callable_wrapper, mlp_func, mlp_postprocess_func)
        combine_forward = partial(callable_wrapper, combine_func, combine_postprocess_func)

        callables = TransformerLayerSubmoduleCallables(
            attention=SubmoduleCallables(forward=attn_forward, dw=self._submodule_attn_router_dw),
            dispatch=SubmoduleCallables(forward=dispatch_forward),
            mlp=SubmoduleCallables(forward=mlp_forward, dw=self._submodule_mlp_dw),
            combine=SubmoduleCallables(forward=combine_forward),
            is_moe=self.is_moe,
            is_deepep=self.is_deepep,
        )
        return callables
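To make the callable contract concrete, here is a hedged sketch of how a scheduler might drive one MoE layer through the four submodule callables. `chunk_state` and `node` are stand-ins for the chunk-state and ScheduleNode objects that fine_grained_schedule.py / combined_1f1b.py supply (not reproduced here); the node is assumed to expose `common_state` and `detach`, as the postprocess hooks require.

calls = layer.get_submodule_callables(chunk_state)

# 1) attention + router + dispatch_preprocess (the residual is stashed on node.common_state)
local_tokens, probs = calls.attention.forward(node, hidden_states)

# 2) expert-parallel all-to-all; this is the phase the combined 1F1B schedule overlaps
dispatched = calls.dispatch.forward(node, local_tokens, probs)

# 3) experts (+ optional shared experts) and combine_preprocess
dispatched = dispatched if isinstance(dispatched, tuple) else (dispatched,)
expert_out = calls.mlp.forward(node, *dispatched)

# 4) reverse all-to-all, then bias-dropout-add against the saved residual
expert_out = expert_out if isinstance(expert_out, tuple) else (expert_out,)
output = calls.combine.forward(node, *expert_out)

# Weight-gradient (dw) hooks can be invoked later, off the critical path:
calls.attention.dw()
calls.mlp.dw()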