evt_fugx1 / dcu_megatron / Commits / 32ee381a

Commit 32ee381a, authored May 06, 2025 by dongcl
    a2a overlap
Parent: 12b56c98

Changes: 6 changed files with 1025 additions and 879 deletions (+1025 -879)
dcu_megatron/core/models/gpt/fine_grained_schedule.py      +342 -320
dcu_megatron/core/models/gpt/gpt_model.py                   +59  -11
dcu_megatron/core/pipeline_parallel/combined_1f1b.py       +122 -216
dcu_megatron/core/transformer/moe/token_dispatcher.py      +191 -182
dcu_megatron/core/transformer/transformer_layer.py         +279 -150
dcu_megatron/core/transformer/utils.py                      +32   -0
dcu_megatron/core/models/gpt/fine_grained_schedule.py (view file @ 32ee381a)
 # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 import contextlib
 import weakref
-from typing import Optional
+from typing import Any, Callable, Optional, Tuple, Union

 import torch
 from torch import Tensor

 from megatron.core.pipeline_parallel.combined_1f1b import (
     AbstractSchedulePlan,
-    FakeScheduleNode,
-    FreeInputsMemoryStrategy,
-    NoOpMemoryStrategy,
     ScheduleNode,
     get_com_stream,
     get_comp_stream,
...
@@ -19,15 +14,11 @@ from megatron.core.pipeline_parallel.combined_1f1b import (
 )
 from megatron.core.transformer import transformer_layer
 from megatron.core.transformer.module import float16_to_fp32
+from megatron.core.transformer.moe.moe_layer import MoELayer
+from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllPerBatchState


 def weak_method(method):
-    """Creates a weak reference to a method to prevent circular references.
-
-    This function creates a weak reference to a method and returns a wrapper function
-    that calls the method when invoked. This helps prevent memory leaks from circular
-    references.
-    """
     method_ref = weakref.WeakMethod(method)
     del method
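The removed docstring above explains why weak_method exists: a ScheduleNode stores its own bound forward_impl, and holding the bound method directly would keep the node alive through a reference cycle. Below is a minimal, self-contained sketch of the same weakref.WeakMethod pattern; the class and method names here are illustrative, not part of the diff.

import weakref

class Node:
    def __init__(self):
        # Storing self.forward_impl directly would create a cycle:
        # the node owns the wrapper, and the bound method owns the node.
        self.func = self._weak_method(self.forward_impl)

    @staticmethod
    def _weak_method(method):
        ref = weakref.WeakMethod(method)
        del method

        def wrapped(*args, **kwargs):
            target = ref()
            assert target is not None, "node was garbage-collected"
            return target(*args, **kwargs)

        return wrapped

    def forward_impl(self, x):
        return x * 2

node = Node()
print(node.func(21))  # -> 42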
...
@@ -38,78 +29,24 @@ def weak_method(method):
     return wrapped_func


-class MemoryStrategyRegistry:
-    """Registry for memory management strategies based on node names.
-
-    This class centralizes the definition of which memory strategy
-    should be used for each type of node in the computation graph.
-    """
-
-    @classmethod
-    def get_strategy_by_name(cls, name, is_moe, is_deepep):
-        """Gets the appropriate memory strategy for a node based on its name and MoE status.
-
-        Args:
-            name: The name of the node, which determines which strategy to use.
-            is_moe: Whether the node is part of a Mixture of Experts model.
-
-        Returns:
-            The memory strategy to use for the node.
-        """
-        strategies = {
-            "default": NoOpMemoryStrategy(),
-            "attn": NoOpMemoryStrategy(),  # Attention nodes keep their inputs
-            "dispatch": (
-                FreeInputsMemoryStrategy() if not is_deepep else NoOpMemoryStrategy()
-            ),  # deepep dispatch inputs share same storage with moe inputs
-            "mlp": FreeInputsMemoryStrategy(),  # MLP nodes free inputs after use
-            "combine": FreeInputsMemoryStrategy(),  # Combine nodes free inputs after use
-        }
-        if is_moe:
-            return strategies.get(name, strategies["default"])
-        # For dense layers [attn, fake, mlp, fake], the inputs of mlp are required for backward
-        return NoOpMemoryStrategy()
-
-
 class PreProcessNode(ScheduleNode):
-    """Node responsible for preprocessing operations in the model.
-
-    This node handles embedding and rotary positional embedding computations
-    before the main transformer layers.
-    """
-
     def __init__(self, gpt_model, model_chunk_state, event, stream):
-        """Initializes a preprocessing node.
-
-        Args:
-            gpt_model: The GPT model instance.
-            model_chunk_state: State shared across the model chunk.
-            event: CUDA event for synchronization.
-            stream: CUDA stream for execution.
-        """
-        super().__init__(weak_method(self.forward_impl), stream, event, name="pre_process")
+        super().__init__(weak_method(self.forward_impl), stream, event)
         self.gpt_model = gpt_model
         self.model_chunk_state = model_chunk_state

     def forward_impl(self):
-        """Implements the forward pass for preprocessing.
-
-        This method handles:
-        1. Decoder embedding computation
-        2. Rotary positional embedding computation
-        3. Sequence length offset computation for flash decoding
-
-        Returns:
-            The processed decoder input tensor.
-        """
         gpt_model = self.gpt_model
         decoder_input = self.model_chunk_state.decoder_input
         input_ids = self.model_chunk_state.input_ids
         position_ids = self.model_chunk_state.position_ids
         inference_params = self.model_chunk_state.inference_params
+        inference_context = self.model_chunk_state.inference_context
         packed_seq_params = self.model_chunk_state.packed_seq_params
+        inference_context = deprecate_inference_params(inference_context, inference_params)

         # Decoder embedding.
         if decoder_input is not None:
             pass
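For reference, the deleted MemoryStrategyRegistry chose a per-node memory policy purely from the node's name plus MoE/DeepEP flags. The runnable sketch below reproduces that name-keyed lookup with simplified stand-in strategy classes; the real FreeInputsMemoryStrategy/NoOpMemoryStrategy live in combined_1f1b.py and manage tensor storage, not Python lists.

class NoOpMemoryStrategy:
    def release_inputs(self, inputs):
        pass  # keep inputs alive for the backward pass

class FreeInputsMemoryStrategy:
    def release_inputs(self, inputs):
        inputs.clear()  # stand-in for freeing activation storage after use

def get_strategy_by_name(name, is_moe, is_deepep):
    strategies = {
        "default": NoOpMemoryStrategy(),
        "attn": NoOpMemoryStrategy(),
        "dispatch": FreeInputsMemoryStrategy() if not is_deepep else NoOpMemoryStrategy(),
        "mlp": FreeInputsMemoryStrategy(),
        "combine": FreeInputsMemoryStrategy(),
    }
    if is_moe:
        return strategies.get(name, strategies["default"])
    # dense layers keep MLP inputs for backward
    return NoOpMemoryStrategy()

strategy = get_strategy_by_name("mlp", is_moe=True, is_deepep=False)
buffer = ["activation"]
strategy.release_inputs(buffer)
print(type(strategy).__name__, buffer)  # FreeInputsMemoryStrategy []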
...
@@ -118,42 +55,51 @@ class PreProcessNode(ScheduleNode):
         else:
             # intermediate stage of pipeline
             # decoder will get hidden_states from encoder.input_tensor
+            # TODO(dongcl)
             decoder_input = gpt_model.decoder.input_tensor

         # Rotary positional embeddings (embedding is None for PP intermediate devices)
         rotary_pos_emb = None
         rotary_pos_cos = None
         rotary_pos_sin = None
-        if (
-            gpt_model.position_embedding_type == 'rope'
-            and not gpt_model.config.multi_latent_attention
-        ):
-            if not gpt_model.training and gpt_model.config.flash_decode and inference_params:
+        if gpt_model.position_embedding_type == 'rope' and not gpt_model.config.multi_latent_attention:
+            if not gpt_model.training and gpt_model.config.flash_decode and inference_context:
+                assert (
+                    inference_context.is_static_batching()
+                ), "GPTModel currently only supports static inference batching."
                 # Flash decoding uses precomputed cos and sin for RoPE
                 rotary_pos_cos, rotary_pos_sin = gpt_model.rotary_pos_emb_cache.setdefault(
-                    inference_params.max_sequence_length,
-                    gpt_model.rotary_pos_emb.get_cos_sin(inference_params.max_sequence_length),
+                    inference_context.max_sequence_length,
+                    gpt_model.rotary_pos_emb.get_cos_sin(inference_context.max_sequence_length),
                 )
             else:
                 rotary_seq_len = gpt_model.rotary_pos_emb.get_rotary_seq_len(
-                    inference_params,
+                    inference_context,
                     gpt_model.decoder,
                     decoder_input,
                     gpt_model.config,
                     packed_seq_params,
                 )
                 rotary_pos_emb = gpt_model.rotary_pos_emb(
                     rotary_seq_len,
                     packed_seq=packed_seq_params is not None
                     and packed_seq_params.qkv_format == 'thd',
                 )
+        elif gpt_model.position_embedding_type == 'mrope' and not gpt_model.config.multi_latent_attention:
+            if gpt_model.training or not gpt_model.config.flash_decode:
+                rotary_pos_emb = gpt_model.rotary_pos_emb(position_ids, gpt_model.mrope_section)
+            else:
+                # Flash decoding uses precomputed cos and sin for RoPE
+                raise NotImplementedError(
+                    "Flash decoding uses precomputed cos and sin for RoPE, not implmented in "
+                    "MultimodalRotaryEmbedding yet."
+                )

         if (
             (gpt_model.config.enable_cuda_graph or gpt_model.config.flash_decode)
             and rotary_pos_cos is not None
-            and inference_params
+            and inference_context
+            and inference_context.is_static_batching()
+            and not gpt_model.training
         ):
             sequence_len_offset = torch.tensor(
-                [inference_params.sequence_len_offset] * inference_params.current_batch_size,
+                [inference_context.sequence_len_offset] * inference_context.current_batch_size,
                 dtype=torch.int32,
                 device=rotary_pos_cos.device,  # Co-locate this with the rotary tensors
             )
...
@@ -169,42 +115,48 @@ class PreProcessNode(ScheduleNode):

 class PostProcessNode(ScheduleNode):
-    """Node responsible for postprocessing operations in the model.
-
-    This node handles final layer normalization and output layer computation
-    after the main transformer layers.
-    """
-
     def __init__(self, gpt_model, model_chunk_state, event, stream):
-        """Initializes a postprocessing node.
-
-        Args:
-            gpt_model: The GPT model instance.
-            model_chunk_state: State shared across the model chunk.
-            event: CUDA event for synchronization.
-            stream: CUDA stream for execution.
-        """
-        super().__init__(weak_method(self.forward_impl), stream, event, name="post_process")
+        super().__init__(weak_method(self.forward_impl), stream, event)
         self.gpt_model = gpt_model
         self.model_chunk_state = model_chunk_state
+        state.input_ids = input_ids
+        state.position_ids = position_ids
+        state.attention_mask = attention_mask
+        state.decoder_input = decoder_input
+        state.labels = labels
+        state.inference_context = inference_context
+        state.packed_seq_params = packed_seq_params
+        state.extra_block_kwargs = extra_block_kwargs
+        state.runtime_gather_output = runtime_gather_output
+        state.inference_params = inference_params
+        state.loss_mask = loss_mask
+        state.context = None
+        state.context_mask = None
+        state.attention_bias = None

     def forward_impl(self, hidden_states):
-        """Implements the forward pass for postprocessing.
-
-        This method handles:
-        1. Final layer normalization
-        2. Output layer computation
-        3. Loss computation if labels are provided
-
-        Args:
-            hidden_states: The hidden states from the transformer layers.
-
-        Returns:
-            The logits or loss depending on whether labels are provided.
-        """
+        gpt_model = self.gpt_model
+        input_ids = self.model_chunk_state.input_ids
+        position_ids = self.model_chunk_state.position_ids
+        labels = self.model_chunk_state.labels
+        loss_mask = self.model_chunk_state.loss_mask
+        attention_mask = self.model_chunk_state.attention_mask
+        inference_params = self.model_chunk_state.inference_params
+        rotary_pos_emb = self.model_chunk_state.rotary_pos_emb
+        rotary_pos_cos = self.model_chunk_state.rotary_pos_cos
+        rotary_pos_sin = self.model_chunk_state.rotary_pos_sin
+        packed_seq_params = self.model_chunk_state.packed_seq_params
+        sequence_len_offset = self.model_chunk_state.sequence_len_offset
+        runtime_gather_output = self.model_chunk_state.runtime_gather_output
+        inference_context = self.model_chunk_state.inference_context

         # Final layer norm.
-        if self.gpt_model.decoder.final_layernorm is not None:
-            hidden_states = self.gpt_model.decoder.final_layernorm(hidden_states)
+        if gpt_model.decoder.final_layernorm is not None:
+            hidden_states = gpt_model.decoder.final_layernorm(hidden_states)
             # TENorm produces a "viewed" tensor. This will result in schedule.py's
             # deallocate_output_tensor() throwing an error, so a viewless tensor is
             # created to prevent this.
...
@@ -212,73 +164,108 @@ class PostProcessNode(ScheduleNode):
                 inp=hidden_states, requires_grad=True, keep_graph=True
             )

-        gpt_model = self.gpt_model
-        runtime_gather_output = self.model_chunk_state.runtime_gather_output
-        labels = self.model_chunk_state.labels
+        # Process inference output.
+        if inference_context and not inference_context.is_static_batching():
+            hidden_states = inference_context.last_token_logits(
+                hidden_states.squeeze(1).unsqueeze(0)
+            ).unsqueeze(1)

         # logits and loss
         output_weight = None
         if gpt_model.share_embeddings_and_output_weights:
             output_weight = gpt_model.shared_embedding_or_output_weight()

+        if gpt_model.mtp_process:
+            hidden_states = gpt_model.mtp(
+                input_ids=input_ids,
+                position_ids=position_ids,
+                labels=labels,
+                loss_mask=loss_mask,
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                inference_params=inference_params,
+                rotary_pos_emb=rotary_pos_emb,
+                rotary_pos_cos=rotary_pos_cos,
+                rotary_pos_sin=rotary_pos_sin,
+                packed_seq_params=packed_seq_params,
+                sequence_len_offset=sequence_len_offset,
+                embedding=gpt_model.embedding,
+                output_layer=gpt_model.output_layer,
+                output_weight=output_weight,
+                runtime_gather_output=runtime_gather_output,
+                compute_language_model_loss=gpt_model.compute_language_model_loss,
+                **(extra_block_kwargs or {}),
+            )
+        if (
+            gpt_model.mtp_process is not None
+            and getattr(gpt_model.decoder, "main_final_layernorm", None) is not None
+        ):
+            # move block main model final norms here
+            hidden_states = gpt_model.decoder.main_final_layernorm(hidden_states)
+
+        if not gpt_model.post_process:
+            return hidden_states
+
+        if (
+            not gpt_model.training
+            and inference_context is not None
+            and inference_context.is_static_batching()
+            and inference_context.materialize_only_last_token_logits
+        ):
+            hidden_states = hidden_states[-1:, :, :]
+
         logits, _ = gpt_model.output_layer(
             hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
         )

+        if has_config_logger_enabled(gpt_model.config):
+            payload = OrderedDict(
+                {
+                    'input_ids': input_ids,
+                    'position_ids': position_ids,
+                    'attention_mask': attention_mask,
+                    'decoder_input': decoder_input,
+                    'logits': logits,
+                }
+            )
+            log_config_to_disk(gpt_model.config, payload, prefix='input_and_logits')
+
         if labels is None:
             # [s b h] => [b s h]
-            return float16_to_fp32(logits.transpose(0, 1).contiguous())
+            return logits.transpose(0, 1).contiguous()

-        loss = float16_to_fp32(gpt_model.compute_language_model_loss(labels, logits))
+        loss = gpt_model.compute_language_model_loss(labels, logits)

         return loss
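PostProcessNode ultimately projects the final hidden states to vocabulary logits and, when labels are present, returns a per-token loss. The small CPU-only sketch below shows that shape flow ([s, b, h] hidden states, [b, s] labels), using plain cross-entropy as a stand-in for compute_language_model_loss; all sizes and names are illustrative.

import torch
import torch.nn.functional as F

s, b, h, v = 4, 2, 8, 16                        # sequence, batch, hidden, vocab
hidden_states = torch.randn(s, b, h)            # decoder output, [s, b, h]
output_weight = torch.randn(v, h)               # (possibly shared) output projection

logits = hidden_states @ output_weight.t()      # [s, b, v]

# labels is None  ->  return logits.transpose(0, 1).contiguous()   # [b, s, v]
labels = torch.randint(0, v, (b, s))            # [b, s]
loss = F.cross_entropy(
    logits.transpose(0, 1).reshape(b * s, v),   # [b*s, v]
    labels.reshape(b * s),
    reduction="none",
).view(b, s)                                    # per-token loss, [b, s]
print(loss.shape)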
 class TransformerLayerNode(ScheduleNode):
-    """Base class for transformer layer computation nodes.
-
-    This class provides common functionality for different types of
-    transformer layer nodes (attention, MLP, etc.)
-    """
-
-    def __init__(self, stream, event, state, callables, name="default"):
-        """Initialize a transformer layer node.
-
-        Args:
-            stream (torch.cuda.Stream): CUDA stream for execution
-            event (torch.cuda.Event): Synchronization event
-            common_state (TransformerLayerState): State shared within a transformer layer
-            callables (Callable): The callables contain forward and dw function
-                it's the per_batch_state_context, o.w. nullcontext
-            name (str): Node name, also used to determine memory strategy
-        """
-        # Get memory strategy based on node name
-        memory_strategy = MemoryStrategyRegistry.get_strategy_by_name(
-            name, callables.is_moe, callables.is_deepep
-        )
+    def __init__(self, chunk_state, common_state, layer, stream, event, free_inputs=False):
         super().__init__(
             weak_method(self.forward_impl),
             stream,
             event,
             weak_method(self.backward_impl),
-            memory_strategy=memory_strategy,
-            name=name,
+            free_inputs=free_inputs,
         )
-        self.common_state = state  # layer state
-        self.callables = callables
+        self.common_state = common_state
+        self.chunk_state = chunk_state  # model chunk state
+        self.layer = layer
         self.detached = tuple()
         self.before_detached = tuple()

     def detach(self, t):
-        """Detaches a tensor and stores it for backward computation."""
         detached = make_viewless(t).detach()
         detached.requires_grad = t.requires_grad
         self.before_detached = self.before_detached + (t,)
         self.detached = self.detached + (detached,)
         return detached

-    def forward_impl(self, *args):
-        """Implements the forward pass for the transformer layer node."""
-        return self.callables.forward(self, *args)
-
     def backward_impl(self, outputs, output_grad):
-        """Implements the backward pass for the transformer layer node."""
         detached_grad = tuple([e.grad for e in self.detached])
         grads = output_grad + detached_grad
         self.default_backward_func(outputs + self.before_detached, grads)
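detach()/backward_impl implement a split-graph backward: each node detaches the tensors it hands to the next node, and later backpropagates its own segment by pairing the saved pre-detach tensors with the gradients that accumulated on the detached copies. Below is a minimal CPU sketch of that pattern with toy tensors, not the real node classes.

import torch

# upstream segment
x = torch.randn(3, requires_grad=True)
t = x * 2                      # tensor produced inside one node

# cut the graph: the downstream segment sees only a detached copy
detached = t.detach()
detached.requires_grad = True

# downstream segment
y = (detached ** 2).sum()
y.backward()                   # fills detached.grad; x.grad is still None

# re-join: backprop the upstream segment, feeding in the gradient that
# accumulated on the detached copy (what backward_impl does with
# outputs + before_detached and output_grad + detached_grad)
torch.autograd.backward(tensors=(t,), grad_tensors=(detached.grad,))
print(x.grad)                  # now populated for the full chain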
...
@@ -287,84 +274,197 @@ class TransformerLayerNode(ScheduleNode):
         # return grads for record stream
         return grads

-    def dw(self):
-        """Computes the weight gradients for the transformer layer node."""
-        with torch.cuda.nvtx.range(f"{self.name} wgrad"):
-            self.callables.dw()
-
-
-class TransformerLayerState:
-    """State shared within a transformer layer.
-
-    This class holds state that is shared between different nodes
-    within a transformer layer.
-    """
-
-    pass
-
-
-class ModelChunkSate:
-    """State shared across a model chunk.
-
-    This class holds state that is shared between different components
-    of a model chunk, such as input tensors, parameters, and configuration.
-    """
-
-    pass
-
-
-class TransformerLayerSchedulePlan:
-    """Schedule plan for a transformer layer.
-
-    This class organizes the computation nodes for a transformer layer,
-    including attention, MLP, dispatch, and combine nodes.
-    """
-
-    def __init__(self, layer, event, chunk_state, comp_stream, com_stream):
-        """Initializes a transformer layer schedule plan.
-
-        Args:
-            layer (TransformerLayer): The transformer layer to schedule.
-            event (torch.cuda.Event): CUDA event for synchronization.
-            chunk_state (ModelChunkState): State shared across the model chunk.
-            comp_stream (torch.cuda.Stream): CUDA stream for computation.
-            com_stream (torch.cuda.Stream): CUDA stream for communication.
-        """
-        self.common_state = TransformerLayerState()
-        # get callables for transformer layer
-        attn_callable, dispatch_callable, mlp_callable, combine_callable = (
-            layer.get_submodule_callables(chunk_state).as_array()
-        )
-        # Create nodes for different operations in the layer
-        # Each node type has a predefined name that determines its memory strategy
-        self.attn = TransformerLayerNode(
-            comp_stream, event, self.common_state, attn_callable, name="attn"
-        )
-        self.mlp = TransformerLayerNode(
-            comp_stream, event, self.common_state, mlp_callable, name="mlp"
-        )
-        if attn_callable.is_moe:
-            self.dispatch = TransformerLayerNode(
-                com_stream, event, self.common_state, dispatch_callable, name="dispatch"
-            )
-            self.combine = TransformerLayerNode(
-                com_stream, event, self.common_state, combine_callable, name="combine"
-            )
-        else:
-            self.dispatch = FakeScheduleNode()
-            self.combine = FakeScheduleNode()
+class MoeAttnNode(TransformerLayerNode):
+    def forward_impl(self, hidden_states):
+        attention_mask = self.chunk_state.attention_mask
+        context = self.chunk_state.context
+        rotary_pos_emb = self.chunk_state.rotary_pos_emb
+        rotary_pos_cos = self.chunk_state.rotary_pos_cos
+        rotary_pos_sin = self.chunk_state.rotary_pos_sin
+        attention_bias = self.chunk_state.attention_bias
+        inference_context = self.chunk_state.inference_context
+        packed_seq_params = self.chunk_state.packed_seq_params
+        sequence_len_offset = self.chunk_state.sequence_len_offset
+        inference_params = self.chunk_state.inference_params
+
+        token_dispatcher = self.layer.mlp.token_dispatcher
+        with token_dispatcher.per_batch_state_context(self.common_state):
+            (
+                hidden_states,
+                pre_mlp_layernorm_output,
+                tokens_per_expert,
+                permutated_local_input_tokens,
+                permuted_probs,
+                probs,
+            ) = self.layer._submodule_attention_router_compound_forward(
+                hidden_states,
+                attention_mask=attention_mask,
+                rotary_pos_emb=rotary_pos_emb,
+                rotary_pos_cos=rotary_pos_cos,
+                rotary_pos_sin=rotary_pos_sin,
+                attention_bias=attention_bias,
+                inference_context=inference_context,
+                packed_seq_params=packed_seq_params,
+                sequence_len_offset=sequence_len_offset,
+                inference_params=inference_params,
+            )
+            self.common_state.tokens_per_expert = tokens_per_expert
+            # detached here
+            self.common_state.probs = self.detach(probs)
+            self.common_state.residual = self.detach(hidden_states)
+            self.common_state.pre_mlp_layernorm_output = self.detach(pre_mlp_layernorm_output)
+            return permutated_local_input_tokens, permuted_probs
+
+    def dw(self):
+        with torch.cuda.nvtx.range(f"{self.name} wgrad"):
+            self.layer._submodule_attention_router_compound_dw()
+
+
+class MoeDispatchNode(TransformerLayerNode):
+    def forward_impl(self, permutated_local_input_tokens, permuted_probs):
+        token_dispatcher = self.layer.mlp.token_dispatcher
+        with token_dispatcher.per_batch_state_context(self.common_state):
+            inputs = permutated_local_input_tokens
+            tokens_per_expert, global_input_tokens, global_probs = token_dispatcher.dispatch_all_to_all(
+                self.common_state.tokens_per_expert, permutated_local_input_tokens, permuted_probs
+            )
+            # release tensor not used by backward
+            # inputs.untyped_storage().resize_(0)
+            self.common_state.tokens_per_expert == tokens_per_expert
+            return global_input_tokens, global_probs
+
+
+class MoeMlPNode(TransformerLayerNode):
+    def forward_impl(self, global_input_tokens, global_probs):
+        pre_mlp_layernorm_output = self.common_state.pre_mlp_layernorm_output
+        token_dispatcher = self.layer.mlp.token_dispatcher
+        with token_dispatcher.per_batch_state_context(self.common_state):
+            expert_output, shared_expert_output, mlp_bias = self.layer._submodule_moe_forward(
+                self.common_state.tokens_per_expert, global_input_tokens, global_prob, pre_mlp_layernorm_output
+            )
+            assert mlp_bias is None
+            # pre_mlp_layernorm_output used
+            self.common_state.pre_mlp_layernorm_output = None
+            return expert_output, shared_expert_output
+
+    def dw(self):
+        with torch.cuda.nvtx.range(f"{self.name} wgrad"):
+            self.layer._submodule_mlp_dw()
+
+
+class MoeCombineNode(TransformerLayerNode):
+    def forward_impl(self, expert_output, shared_expert_output):
+        # TODO(lhb): if dw use grad of residual and probs, necessary synchronization should be add
+        residual = self.common_state.residual
+        token_dispatcher = self.layer.mlp.token_dispatcher
+        with token_dispatcher.per_batch_state_context(self.common_state):
+            permutated_local_input_tokens = token_dispatcher.combine_all_to_all(expert_output)
+            output = self.layer._submodule_post_combine_forward(
+                permutated_local_input_tokens, shared_expert_output, None, residual
+            )
+            cur_stream = torch.cuda.current_stream()
+            self.common_state.residual.record_stream(cur_stream)
+            self.common_state.probs.record_stream(cur_stream)
+            self.common_state.residual = None
+            self.common_state.probs = None
+            return output
+
+
+class DenseAttnNode(TransformerLayerNode):
+    def forward_impl(self, hidden_states):
+        attention_mask = self.chunk_state.attention_mask
+        rotary_pos_emb = self.chunk_state.rotary_pos_emb
+        rotary_pos_cos = self.chunk_state.rotary_pos_cos
+        rotary_pos_sin = self.chunk_state.rotary_pos_sin
+        attention_bias = self.chunk_state.attention_bias
+        inference_context = self.chunk_state.inference_context
+        packed_seq_params = self.chunk_state.packed_seq_params
+        sequence_len_offset = self.chunk_state.sequence_len_offset
+        inference_params = self.chunk_state.inference_params
+        hidden_states = self.layer._submodule_attention_forward(
+            hidden_states,
+            attention_mask,
+            rotary_pos_emb,
+            rotary_pos_cos,
+            rotary_pos_sin,
+            attention_bias,
+            inference_context,
+            packed_seq_params,
+            sequence_len_offset,
+            inference_params=inference_params,
+        )
+        return hidden_states
+
+
+class FakeScheduleNode:
+    def forward(self, inputs):
+        return inputs
+
+    def backward(self, outgrads):
+        return outgrads
+
+
+class DenseMlpNode(TransformerLayerNode):
+    def forward_impl(self, hidden_states):
+        return self.layer._submodule_dense_forward(hidden_states)
+
+
+def build_non_moe_layer_plan(layer, event, chunk_state, comp_stream, com_stream):
+    common_state = TransformerLayerState()
+    attn = DenseAttnNode(chunk_state, common_state, layer, comp_stream, event)
+    attn.name = "attn"
+    dispatch = FakeScheduleNode()
+    mlp = DenseMlpNode(chunk_state, common_state, layer, comp_stream, event)
+    combine = FakeScheduleNode()
+    return TransformerLayerSchedulePlan(attn, dispatch, mlp, combine)
+
+
+def build_layer_schedule_plan(layer, event, chunk_state, comp_stream, com_stream):
+    if not isinstance(layer.mlp, MoELayer):
+        return build_non_moe_layer_plan(layer, event, chunk_state, comp_stream, com_stream)
+    common_state = TransformerLayerState()
+    attn = MoeAttnNode(chunk_state, common_state, layer, comp_stream, event)
+    attn.name = "attn"
+    dispatch = MoeDispatchNode(chunk_state, common_state, layer, com_stream, event, True)
+    dispatch.name = "dispatch"
+    mlp = MoeMlPNode(chunk_state, common_state, layer, comp_stream, event, True)
+    mlp.name = "mlp"
+    combine = MoeCombineNode(chunk_state, common_state, layer, com_stream, event, True)
+    combine.name = "combine"
+    return TransformerLayerSchedulePlan(attn, dispatch, mlp, combine)
+
+
+class TransformerLayerState(MoEAlltoAllPerBatchState):
+    pass
+
+
+class ModelChunkSate:
+    pass
+
+
+class TransformerLayerSchedulePlan:
+    def __init__(self, attn, dispatch, mlp, combine):
+        self.attn = attn
+        self.dispatch = dispatch
+        self.mlp = mlp
+        self.combine = combine


 class ModelChunkSchedulePlan(AbstractSchedulePlan):
-    """Schedule plan for a model chunk.
-
-    This class organizes the computation nodes for a model chunk,
-    including preprocessing, transformer layers, and postprocessing.
-    """
-
     def __init__(self):
-        """Initializes a model chunk schedule plan."""
         super().__init__()
         self._pre_process = None
         self._post_process = None
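build_layer_schedule_plan assembles four nodes per layer and relies on FakeScheduleNode to keep the attn → dispatch → mlp → combine chain uniform for dense layers. Below is a toy sketch of that chaining with stand-in nodes; ToyNode, LayerPlan, and run_layer_forward are illustrative names, not from the diff.

class FakeScheduleNode:
    def forward(self, inputs):
        return inputs
    def backward(self, outgrads):
        return outgrads

class ToyNode:  # stand-in for a real ScheduleNode
    def __init__(self, name):
        self.name = name
    def forward(self, x):
        print(f"forward {self.name}")
        return x
    def backward(self, g):
        print(f"backward {self.name}")
        return g

class LayerPlan:
    def __init__(self, attn, dispatch, mlp, combine):
        self.attn, self.dispatch, self.mlp, self.combine = attn, dispatch, mlp, combine

def run_layer_forward(plan, x):
    # same chaining order the scheduler uses: attn -> dispatch -> mlp -> combine
    x = plan.attn.forward(x)
    x = plan.dispatch.forward(x)
    x = plan.mlp.forward(x)
    return plan.combine.forward(x)

dense_plan = LayerPlan(ToyNode("attn"), FakeScheduleNode(), ToyNode("mlp"), FakeScheduleNode())
run_layer_forward(dense_plan, object())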
...
@@ -385,22 +485,7 @@ class ModelChunkSchedulePlan(AbstractSchedulePlan):
         post_forward=None,
         post_backward=None,
     ):
-        """Schedules forward and backward passes for model chunks.
-
-        Args:
-            f_schedule_plan (ModelChunkSchedulePlan): Forward schedule plan.
-            b_schedule_plan (ModelChunkSchedulePlan): Backward schedule plan.
-            grad (Tensor): Gradient for backward computation.
-            f_context (VppContextManager or None): The VppContextManager for the forward pass.
-            b_context (VppContextManager or None): The VppContextManager for the backward pass
-            pre_forward (Callable): Callback for preprocessing in forward pass.
-            pre_backward (Callable): Callback for preprocessing in backward pass.
-            post_forward (Callable): Callback for postprocessing in forward pass.
-            post_backward (Callable): Callback for postprocessing in backward pass.
-
-        Returns:
-            The output of the forward pass.
-        """
         return schedule_chunk_1f1b(
             f_schedule_plan,
             b_schedule_plan,
...
@@ -415,55 +500,44 @@ class ModelChunkSchedulePlan(AbstractSchedulePlan):
     @property
     def event(self):
-        """Gets the CUDA event for synchronization."""
         return self._event

     def record_current_stream(self):
-        """Records the current CUDA stream in the event."""
         stream = torch.cuda.current_stream()
         self.event.record(stream)

     def wait_current_stream(self):
-        """Waits for the event to complete on the current CUDA stream."""
         stream = torch.cuda.current_stream()
         self.event.wait(stream)

     @property
     def pre_process(self):
-        """Gets the preprocessing node."""
         return self._pre_process

     @pre_process.setter
     def pre_process(self, value):
-        """Sets the preprocessing node."""
         self._pre_process = value

     @property
     def post_process(self):
-        """Gets the postprocessing node."""
         return self._post_process

     @post_process.setter
     def post_process(self, value):
-        """Sets the postprocessing node."""
         self._post_process = value

     def get_layer(self, i):
-        """Gets the transformer layer at the specified index."""
         assert i < self.num_layers()
         return self._transformer_layers[i]

     def num_layers(self):
-        """Gets the number of transformer layers."""
         return len(self._transformer_layers)

     def add_layer(self, layer):
-        """Adds a transformer layer to the schedule plan."""
         self._transformer_layers.append(layer)

     @property
     def state(self):
-        """Gets the model chunk state."""
         return self._model_chunk_state
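record_current_stream/wait_current_stream order work across streams with a single shared CUDA event. Below is a hedged sketch of the same event handshake between a compute stream and a communication stream; it only runs when CUDA is available, and the variable names are illustrative.

import torch

if torch.cuda.is_available():
    event = torch.cuda.Event()
    comp_stream = torch.cuda.current_stream()
    com_stream = torch.cuda.Stream()

    x = torch.randn(1024, 1024, device="cuda")
    y = x @ x                        # work on the compute stream
    event.record(comp_stream)        # "record_current_stream"

    with torch.cuda.stream(com_stream):
        event.wait(com_stream)       # "wait_current_stream" on the other stream
        z = y + 1                    # ordered after the matmul
        y.record_stream(com_stream)  # keep y's storage alive for this stream

    torch.cuda.synchronize()
    print(z.sum().item())
else:
    print("CUDA not available; nothing to demonstrate.")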
...
@@ -478,40 +552,24 @@ def schedule_layer_1f1b(
     f_context=None,
     b_context=None,
 ):
-    """Schedule one-forward-one-backward operations for a single layer.
-
-    This function interleaves forward and backward operations to maximize
-    parallelism and efficiency.
-
-    Args:
-        f_layer (TransformerLayerSchedulePlan): Forward layer (for current microbatch)
-        b_layer (TransformerLayerSchedulePlan): Backward layer (for previous microbatch)
-        f_input (Tensor): Input for forward computation
-        b_grad (Tensor): Gradient for backward computation
-        pre_forward (Callable): Callback to get forward input if not provided
-        pre_backward (Callable): Callback to get backward gradient if not provided
-        pre_backward_dw (Callable): Callback for weight gradient computation
-        f_context (VppContextManager or None): The VppContextManager for the forward pass.
-        b_context (VppContextManager or None): The VppContextManager for the backward pass
-
-    Returns:
-        Functions or values for next iteration's computation
-    """
     f_context = f_context if f_context is not None else contextlib.nullcontext()
     b_context = b_context if b_context is not None else contextlib.nullcontext()

     if pre_forward is not None:
         assert f_input is None
         # combine from last iter
         f_input = pre_forward()
         del pre_forward
     if pre_backward is not None:
         # attn backward from last iter
         assert b_grad is None
         b_grad = pre_backward()
         del pre_backward

     if b_layer is not None:
         with b_context:
             b_grad = b_layer.combine.backward(b_grad)
...
@@ -520,6 +578,7 @@ def schedule_layer_1f1b(
             pre_backward_dw()
             del pre_backward_dw

     if f_layer is not None:
         with f_context:
             f_input = f_layer.attn.forward(f_input)
...
@@ -534,10 +593,13 @@ def schedule_layer_1f1b(
             b_grad = b_layer.dispatch.backward(b_grad)
             b_layer.mlp.dw()

     if f_layer is not None:
         with f_context:
             f_input = f_layer.mlp.forward(f_input)

     def next_iter_pre_forward():
         if f_layer is not None:
             with f_context:
...
@@ -555,6 +617,7 @@ def schedule_layer_1f1b(
         with b_context:
             b_layer.attn.dw()

     if f_layer and b_layer:
         return next_iter_pre_forward, next_iter_pre_backward, next_iter_pre_backward_dw
     else:
...
@@ -572,32 +635,14 @@ def schedule_chunk_1f1b(
     post_forward=None,
     post_backward=None,
 ):
-    """Schedules one-forward-one-backward operations for a model chunk.
-
-    This function interleaves forward and backward operations across multiple layers
-    to maximize parallelism and efficiency.
-
-    Args:
-        f_schedule_plan: Forward schedule plan.
-        b_schedule_plan: Backward schedule plan.
-        grad: Gradient for backward computation.
-        f_context: Context for forward computation.
-        b_context: Context for backward computation.
-        pre_forward: Callback for preprocessing in forward pass.
-        pre_backward: Callback for preprocessing in backward pass.
-        post_forward: Callback for postprocessing in forward pass.
-        post_backward: Callback for postprocessing in backward pass.
-
-    Returns:
-        The output of the forward pass.
-    """
     f_context = f_context if f_context is not None else contextlib.nullcontext()
     b_context = b_context if b_context is not None else contextlib.nullcontext()

     if f_schedule_plan:
         # pp output send/receive sync
         if pre_forward is not None:
-            with f_context:  # virtual pipeline parallel context
+            with f_context:
                 pre_forward()
             f_schedule_plan.record_current_stream()
...
@@ -617,14 +662,14 @@ def schedule_chunk_1f1b(
     if b_schedule_plan is not None:
         assert grad is not None
         if b_schedule_plan.post_process is not None:
-            with b_context:  # virtual pipeline parallel context
+            with b_context:
                 tmp = b_schedule_plan.post_process.backward(grad)

     if pre_backward is not None:
         # pp grad send receive sync here, safe for now, maybe not safe in the future
         with torch.cuda.stream(get_com_stream()):
             b_schedule_plan.wait_current_stream()
-            with b_context:  # virtual pipeline parallel context
+            with b_context:
                 pre_backward()
             b_schedule_plan.record_current_stream()
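schedule_layer_1f1b interleaves one microbatch's forward with another microbatch's backward so that all-to-all dispatch/combine communication overlaps with attention/MLP compute. The sketch below only traces the step order visible in this hunk (combine backward → attn forward → dispatch backward → mlp dw → mlp forward); steps that sit in the collapsed regions of the diff are omitted rather than guessed.

def visible_layer_1f1b_steps(f_layer, b_layer):
    # Only the operations that appear in this hunk, in the order they appear.
    steps = []
    if b_layer is not None:
        steps.append(f"{b_layer}.combine.backward")   # a2a combine grad (comm stream)
    if f_layer is not None:
        steps.append(f"{f_layer}.attn.forward")       # compute, overlaps the comm above
    if b_layer is not None:
        steps.append(f"{b_layer}.dispatch.backward")  # a2a dispatch grad (comm stream)
        steps.append(f"{b_layer}.mlp.dw")             # weight grads, overlaps the comm above
    if f_layer is not None:
        steps.append(f"{f_layer}.mlp.forward")
    return steps

print(visible_layer_1f1b_steps("f_layer_i", "b_layer_j"))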
@@ -652,9 +697,6 @@ def schedule_chunk_1f1b(
         )
         torch.cuda.nvtx.range_pop()

-    # tail forward
-    f_input = layer_pre_forward()
-    del layer_pre_forward
     # tail backward
     grad = layer_pre_backward()
     del layer_pre_backward
...
@@ -665,12 +707,12 @@ def schedule_chunk_1f1b(
         tmp, grad, _ = schedule_layer_1f1b(None, b_layer, b_grad=grad)
         torch.cuda.nvtx.range_pop()

-    # if b_schedule_plan is not None:
-    #     b_schedule_plan.pre_process.backward(grad)
+    if b_schedule_plan is not None:
+        b_schedule_plan.pre_process.backward(grad)

-    # # tail forward
-    # f_input = layer_pre_forward()
-    # del layer_pre_forward
+    # tail forward
+    f_input = layer_pre_forward()
+    del layer_pre_forward

     with f_context:
         for i in range(overlaped_layers, f_num_layers):
             f_layer = f_schedule_plan.get_layer(i)
...
@@ -678,8 +720,8 @@ def schedule_chunk_1f1b(
             f_input, tmp, _ = schedule_layer_1f1b(f_layer, None, f_input=f_input)
             torch.cuda.nvtx.range_pop()

-    # if f_schedule_plan is not None and f_schedule_plan.post_process is not None:
-    #     f_input = f_schedule_plan.post_process.forward(f_input)
+    if f_schedule_plan is not None and f_schedule_plan.post_process is not None:
+        f_input = f_schedule_plan.post_process.forward(f_input)

     # output pp send receive, overlapped with attn backward
     if f_schedule_plan is not None and post_forward is not None:
...
@@ -687,8 +729,7 @@ def schedule_chunk_1f1b(
         f_schedule_plan.wait_current_stream()
         post_forward(f_input)

-    # pp grad send / receive, overlapped with attn dw of cur micro-batch
-    # and forward attn of next micro-batch
+    # pp grad send / receive, overlapped with attn dw of cur micro-batch and forward attn of next micro-batch
     if b_schedule_plan is not None and post_backward is not None:
         with b_context:
             b_schedule_plan.wait_current_stream()
...
@@ -698,13 +739,6 @@ def schedule_chunk_1f1b(
     layer_pre_backward_dw()
     del layer_pre_backward_dw

-    with f_context:
-        if f_schedule_plan is not None and f_schedule_plan.post_process is not None:
-            f_input = f_schedule_plan.post_process.forward(f_input)
-    with b_context:
-        if b_schedule_plan is not None:
-            b_schedule_plan.pre_process.backward(grad)
-
     if f_schedule_plan:
         f_schedule_plan.wait_current_stream()
     if b_schedule_plan:
@@ -720,32 +754,15 @@ def build_model_chunk_schedule_plan(
...
@@ -720,32 +754,15 @@ def build_model_chunk_schedule_plan(
attention_mask
:
Tensor
,
attention_mask
:
Tensor
,
decoder_input
:
Tensor
=
None
,
decoder_input
:
Tensor
=
None
,
labels
:
Tensor
=
None
,
labels
:
Tensor
=
None
,
inference_
params
=
None
,
inference_
context
:
BaseInferenceContext
=
None
,
packed_seq_params
=
None
,
packed_seq_params
:
PackedSeqParams
=
None
,
extra_block_kwargs
=
None
,
extra_block_kwargs
:
dict
=
None
,
runtime_gather_output
:
Optional
[
bool
]
=
None
,
runtime_gather_output
:
Optional
[
bool
]
=
None
,
inference_params
:
Optional
[
BaseInferenceContext
]
=
None
,
loss_mask
:
Optional
[
Tensor
]
=
None
):
):
"""Builds a schedule plan for a model chunk.
comp_stream
=
torch
.
cuda
.
current_stream
()
This function creates a schedule plan for a model chunk, including
preprocessing, transformer layers, and postprocessing.
Args:
model: The model to build a schedule plan for.
input_ids: Input token IDs.
position_ids: Position IDs.
attention_mask: Attention mask.
decoder_input: Decoder input tensor.
labels: Labels for loss computation.
inference_params: Parameters for inference.
packed_seq_params: Parameters for packed sequences.
extra_block_kwargs: Additional keyword arguments for blocks.
runtime_gather_output: Whether to gather output at runtime.
Returns:
The model chunk schedule plan.
"""
comp_stream
=
get_comp_stream
()
com_stream
=
get_com_stream
()
com_stream
=
get_com_stream
()
model_chunk_schedule_plan
=
ModelChunkSchedulePlan
()
model_chunk_schedule_plan
=
ModelChunkSchedulePlan
()
event
=
model_chunk_schedule_plan
.
event
event
=
model_chunk_schedule_plan
.
event
...
@@ -756,23 +773,28 @@ def build_model_chunk_schedule_plan(
...
@@ -756,23 +773,28 @@ def build_model_chunk_schedule_plan(
state
.
attention_mask
=
attention_mask
state
.
attention_mask
=
attention_mask
state
.
decoder_input
=
decoder_input
state
.
decoder_input
=
decoder_input
state
.
labels
=
labels
state
.
labels
=
labels
state
.
inference_
params
=
inference_
params
state
.
inference_
context
=
inference_
context
state
.
packed_seq_params
=
packed_seq_params
state
.
packed_seq_params
=
packed_seq_params
state
.
extra_block_kwargs
=
extra_block_kwargs
state
.
extra_block_kwargs
=
extra_block_kwargs
state
.
runtime_gather_output
=
runtime_gather_output
state
.
runtime_gather_output
=
runtime_gather_output
state
.
inference_params
=
inference_params
state
.
loss_mask
=
loss_mask
state
.
context
=
None
state
.
context
=
None
state
.
context_mask
=
None
state
.
context_mask
=
None
state
.
attention_bias
=
None
state
.
attention_bias
=
None
# build preprocess
# build preprocess
model_chunk_schedule_plan
.
pre_process
=
PreProcessNode
(
model
,
state
,
event
,
comp_stream
)
model_chunk_schedule_plan
.
pre_process
=
PreProcessNode
(
model
,
state
,
event
,
comp_stream
)
model_chunk_schedule_plan
.
pre_process
.
name
=
"pre_process"
# build for layers
# build for layers
for
layer_idx
in
range
(
model
.
decoder
.
num_layers_per_pipeline_rank
):
for
layer_idx
in
range
(
model
.
decoder
.
num_layers_per_pipeline_rank
):
layer
=
model
.
decoder
.
_get_layer
(
layer_idx
)
layer
=
model
.
decoder
.
_get_layer
(
layer_idx
)
layer_plan
=
TransformerL
ayer
S
chedule
P
lan
(
layer
,
event
,
state
,
comp_stream
,
com_stream
)
layer_plan
=
build_l
ayer
_s
chedule
_p
lan
(
layer
,
event
,
state
,
comp_stream
,
com_stream
)
model_chunk_schedule_plan
.
add_layer
(
layer_plan
)
model_chunk_schedule_plan
.
add_layer
(
layer_plan
)
# build post process
# build post process
if
model
.
post_process
:
if
model
.
post_process
:
model_chunk_schedule_plan
.
post_process
=
PostProcessNode
(
model
,
state
,
event
,
comp_stream
)
model_chunk_schedule_plan
.
post_process
=
PostProcessNode
(
model
,
state
,
event
,
comp_stream
)
model_chunk_schedule_plan
.
post_process
.
name
=
"post_process"
return
model_chunk_schedule_plan
return
model_chunk_schedule_plan
dcu_megatron/core/models/gpt/gpt_model.py (view file @ 32ee381a)
...
@@ -9,6 +9,7 @@ from torch import Tensor
 from megatron.core import InferenceParams, tensor_parallel
 from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
+from megatron.core.inference.contexts import BaseInferenceContext
 from megatron.core.packed_seq_params import PackedSeqParams
 from megatron.core.models.gpt import GPTModel as MegatronCoreGPTModel
...
@@ -64,11 +65,14 @@ class GPTModel(MegatronCoreGPTModel):
         attention_mask: Tensor,
         decoder_input: Tensor = None,
         labels: Tensor = None,
-        inference_params: InferenceParams = None,
+        inference_context: BaseInferenceContext = None,
         packed_seq_params: PackedSeqParams = None,
         extra_block_kwargs: dict = None,
         runtime_gather_output: Optional[bool] = None,
+        *,
+        inference_params: Optional[BaseInferenceContext] = None,
         loss_mask: Optional[Tensor] = None,
     ):
         """Builds a computation schedule plan for the model.
...
@@ -105,10 +109,12 @@ class GPTModel(MegatronCoreGPTModel):
             attention_mask,
             decoder_input=decoder_input,
             labels=labels,
-            inference_params=inference_params,
+            inference_context=inference_context,
             packed_seq_params=packed_seq_params,
             extra_block_kwargs=extra_block_kwargs,
             runtime_gather_output=runtime_gather_output,
+            inference_params=inference_params,
+            loss_mask=loss_mask,
         )

     def forward(
...
@@ -118,14 +124,16 @@ class GPTModel(MegatronCoreGPTModel):
         attention_mask: Tensor,
         decoder_input: Tensor = None,
         labels: Tensor = None,
-        inference_params: InferenceParams = None,
+        inference_context: BaseInferenceContext = None,
         packed_seq_params: PackedSeqParams = None,
         extra_block_kwargs: dict = None,
         runtime_gather_output: Optional[bool] = None,
+        *,
+        inference_params: Optional[BaseInferenceContext] = None,
         loss_mask: Optional[Tensor] = None,
     ) -> Tensor:
         """Forward function of the GPT Model This function passes the input tensors
         through the embedding layer, and then the decoder and finally into the post
         processing layer (optional).

         It either returns the Loss values if labels are given or the final hidden units
...
@@ -137,6 +145,8 @@ class GPTModel(MegatronCoreGPTModel):
         # If decoder_input is provided (not None), then input_ids and position_ids are ignored.
         # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.
+        inference_context = deprecate_inference_params(inference_context, inference_params)
+
         # Decoder embedding.
         if decoder_input is not None:
             pass
...
@@ -152,39 +162,64 @@ class GPTModel(MegatronCoreGPTModel):
         rotary_pos_cos = None
         rotary_pos_sin = None
         if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
-            if not self.training and self.config.flash_decode and inference_params:
+            if not self.training and self.config.flash_decode and inference_context:
+                assert (
+                    inference_context.is_static_batching()
+                ), "GPTModel currently only supports static inference batching."
                 # Flash decoding uses precomputed cos and sin for RoPE
                 rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb_cache.setdefault(
-                    inference_params.max_sequence_length,
-                    self.rotary_pos_emb.get_cos_sin(inference_params.max_sequence_length),
+                    inference_context.max_sequence_length,
+                    self.rotary_pos_emb.get_cos_sin(inference_context.max_sequence_length),
                 )
             else:
                 rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
-                    inference_params, self.decoder, decoder_input, self.config, packed_seq_params
+                    inference_context, self.decoder, decoder_input, self.config, packed_seq_params
                 )
                 rotary_pos_emb = self.rotary_pos_emb(
                     rotary_seq_len,
                     packed_seq=packed_seq_params is not None
                     and packed_seq_params.qkv_format == 'thd',
                 )
+        elif self.position_embedding_type == 'mrope' and not self.config.multi_latent_attention:
+            if self.training or not self.config.flash_decode:
+                rotary_pos_emb = self.rotary_pos_emb(position_ids, self.mrope_section)
+            else:
+                # Flash decoding uses precomputed cos and sin for RoPE
+                raise NotImplementedError(
+                    "Flash decoding uses precomputed cos and sin for RoPE, not implmented in "
+                    "MultimodalRotaryEmbedding yet."
+                )

         if (
             (self.config.enable_cuda_graph or self.config.flash_decode)
             and rotary_pos_cos is not None
-            and inference_params
+            and inference_context
+            and inference_context.is_static_batching()
+            and not self.training
         ):
             sequence_len_offset = torch.tensor(
-                [inference_params.sequence_len_offset] * inference_params.current_batch_size,
+                [inference_context.sequence_len_offset] * inference_context.current_batch_size,
                 dtype=torch.int32,
                 device=rotary_pos_cos.device,  # Co-locate this with the rotary tensors
             )
         else:
             sequence_len_offset = None

+        # Wrap decoder_input to allow the decoder (TransformerBlock) to delete the
+        # reference held by this caller function, enabling early garbage collection for
+        # inference. Skip wrapping if decoder_input is logged after decoder completion.
+        if (
+            inference_context is not None
+            and not self.training
+            and not has_config_logger_enabled(self.config)
+        ):
+            decoder_input = WrappedTensor(decoder_input)
+
         # Run decoder.
         hidden_states = self.decoder(
             hidden_states=decoder_input,
             attention_mask=attention_mask,
-            inference_params=inference_params,
+            inference_context=inference_context,
             rotary_pos_emb=rotary_pos_emb,
             rotary_pos_cos=rotary_pos_cos,
             rotary_pos_sin=rotary_pos_sin,
...
@@ -193,6 +228,12 @@ class GPTModel(MegatronCoreGPTModel):
             **(extra_block_kwargs or {}),
         )

+        # Process inference output.
+        if inference_context and not inference_context.is_static_batching():
+            hidden_states = inference_context.last_token_logits(
+                hidden_states.squeeze(1).unsqueeze(0)
+            ).unsqueeze(1)
+
         # logits and loss
         output_weight = None
         if self.share_embeddings_and_output_weights:
...
@@ -230,6 +271,13 @@ class GPTModel(MegatronCoreGPTModel):
         if not self.post_process:
             return hidden_states

+        if (
+            not self.training
+            and inference_context is not None
+            and inference_context.is_static_batching()
+            and inference_context.materialize_only_last_token_logits
+        ):
+            hidden_states = hidden_states[-1:, :, :]
+
         logits, _ = self.output_layer(
             hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
         )
...
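The signature change above keeps a keyword-only inference_params for callers that have not migrated, funnelling it through deprecate_inference_params. Below is a minimal sketch of such a shim; the helper name comes from the diff, while its exact upstream behavior is an assumption.

import warnings

def deprecate_inference_params(inference_context, inference_params):
    # Map the legacy `inference_params` argument onto `inference_context`.
    # Minimal stand-in for the helper used in the diff; the real Megatron
    # utility may differ in details.
    if inference_context is None and inference_params is not None:
        warnings.warn(
            "`inference_params` is deprecated, pass `inference_context` instead.",
            DeprecationWarning,
        )
        return inference_params
    return inference_context

# usage: new call sites pass inference_context, old ones keep working
ctx = deprecate_inference_params(None, inference_params="legacy-context")
print(ctx)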
dcu_megatron/core/pipeline_parallel/combined_1f1b.py (view file @ 32ee381a)
-# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 import contextlib
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
-from typing import List, Union
+from typing import Any, List, Tuple, Union

 import torch
 from torch import Tensor
...
@@ -11,24 +9,24 @@ from torch.autograd.variable import Variable
 from megatron.core import parallel_state
 from megatron.core.distributed import DistributedDataParallel
-from megatron.core.models.gpt.gpt_model import GPTModel
 from megatron.core.transformer.module import Float16Module
 from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler
 from megatron.core.utils import get_attr_wrapped_model, make_viewless_tensor

 # Types
 Shape = Union[List[int], torch.Size]


 def make_viewless(e):
-    """Make_viewless util func"""
+    """make_viewless util func"""
     e = make_viewless_tensor(inp=e, requires_grad=e.requires_grad, keep_graph=True)
     return e


 @contextmanager
 def stream_acquire_context(stream, event):
+    """Stream acquire context"""
     event.wait(stream)
     try:
         yield
...
@@ -36,29 +34,8 @@ def stream_acquire_context(stream, event):
         event.record(stream)


-class FakeScheduleNode:
-    """A placeholder node in the computation graph that simply passes through inputs and outputs.
-
-    This class is used as a no-op node in the scheduling system when a real computation node
-    is not needed but the interface must be maintained. It simply returns its inputs unchanged
-    in both forward and backward passes.
-    """
-
-    def forward(self, inputs):
-        """Passes through inputs unchanged in the forward pass."""
-        return inputs
-
-    def backward(self, outgrads):
-        """Passes through gradients unchanged in the backward pass."""
-        return outgrads
-
-
 class ScheduleNode:
-    """Base node for fine-grained scheduling.
-
-    This class represents a computational node in the pipeline schedule.
-    It handles the execution of forward and backward operations on a stream.
-    """
+    """base node for fine-grained schedule"""

     def __init__(
         self,
...
@@ -66,30 +43,19 @@ class ScheduleNode:
         stream,
         event,
         backward_func=None,
-        memory_strategy=None,
+        free_inputs=False,
         name="schedule_node",
     ):
-        """Initialize a schedule node.
-
-        Args:
-            forward_func (callable): Function to execute during forward pass
-            stream (torch.cuda.Stream): CUDA stream for computation
-            event (torch.cuda.Event): Event for synchronization
-            backward_func (callable, optional): Function for backward pass
-            memory_strategy (MemoryManagementStrategy, optional): Strategy for memory management
-            name (str): Name of the node for debugging
-        """
         self.name = name
         self.forward_func = forward_func
-        self.backward_func = backward_func if backward_func else self.default_backward_func
+        self.backward_func = backward_func
        self.stream = stream
        self.event = event
-        self.memory_strategy = memory_strategy or NoOpMemoryStrategy()
+        self.free_inputs = free_inputs
        self.inputs = None
        self.outputs = None

     def default_backward_func(self, outputs, output_grad):
+        """Default backward function"""
         Variable._execution_engine.run_backward(
             tensors=outputs,
             grad_tensors=output_grad,
...
@@ -102,7 +68,8 @@ class ScheduleNode:
         return output_grad

     def forward(self, inputs=()):
-        """Schedule node forward"""
+        """schedule node forward"""
         if not isinstance(inputs, tuple):
             inputs = (inputs,)
         return self._forward(*inputs)
...
@@ -127,17 +94,19 @@ class ScheduleNode:
             self.output = data
             torch.cuda.nvtx.range_pop()

-            # Handle inputs using the memory strategy
-            self.memory_strategy.handle_inputs(inputs, self.stream)
+            if self.free_inputs:
+                for input in inputs:
+                    input.record_stream(self.stream)
+                    input.untyped_storage().resize_(0)
             return self.output

     def get_output(self):
-        """Get the forward output"""
+        """get the forward output"""
         return self.output

     def backward(self, output_grad):
-        """Schedule node backward"""
+        """schedule node backward"""
         if not isinstance(output_grad, tuple):
             output_grad = (output_grad,)
         return self._backward(*output_grad)
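The free_inputs path above releases input activations as soon as a node has produced its output: record_stream tells the caching allocator that work queued on self.stream still reads the tensor, and untyped_storage().resize_(0) then drops the backing storage. A standalone sketch of the same pattern, assuming a CUDA device is available:

import torch


def free_after_use(inputs, stream):
    """Release input storages once the kernels queued on `stream` own them."""
    for t in inputs:
        if t is None:
            continue
        # Prevent the allocator from recycling the block while already-queued
        # kernels on `stream` may still be reading it.
        t.record_stream(stream)
        # Drop the storage now that no later consumer needs the values.
        t.untyped_storage().resize_(0)


if __name__ == "__main__" and torch.cuda.is_available():
    s = torch.cuda.Stream()
    x = torch.randn(1024, 1024, device="cuda")
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        y = x.sum()
    free_after_use((x,), s)
    torch.cuda.synchronize()
    print(y.item())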
...
@@ -149,11 +118,13 @@ class ScheduleNode:
             outputs = self.output
             if not isinstance(outputs, tuple):
                 outputs = (outputs,)
-            assert len(outputs) == len(output_grad), (
-                f"{len(outputs)} of {type(outputs[0])} is not equal to "
-                f"{len(output_grad)} of {type(output_grad[0])}"
-            )
-            output_grad = self.backward_func(outputs, output_grad)
+            assert len(outputs) == len(
+                output_grad
+            ), f"{len(outputs)} of {type(outputs[0])} vs {len(output_grad)} of {type(output_grad[0])}"
+            if self.backward_func is not None:
+                output_grad = self.backward_func(outputs, output_grad)
+            else:
+                output_grad = self.default_backward_func(outputs, output_grad)
             torch.cuda.nvtx.range_pop()

             # output_grad maybe from another stream
...
@@ -163,7 +134,7 @@ class ScheduleNode:
         return self.get_grad()

     def get_grad(self):
-        """Get the grad of inputs"""
+        """get the grad of inputs"""
         grad = tuple([e.grad if e is not None else None for e in self.inputs])
         # clear state
         self.inputs = None
...
@@ -175,7 +146,7 @@ class ScheduleNode:


 class AbstractSchedulePlan(ABC):
-    """To use combined 1f1b, model must implement build_schedule_plan while take the same
+    """to use combined 1f1b, model must implement build_schedule_plan while take the same
     signature as model forward but return an instance of AbstractSchedulePlan"""

     @classmethod
...
@@ -207,29 +178,7 @@ def schedule_chunk_1f1b(
     post_forward=None,
     post_backward=None,
 ):
-    """Model level 1f1b fine-grained schedule
-
-    This function schedules the forward and backward passes for a chunk of the model.
-    It takes in the forward schedule plan, backward schedule plan, gradient, and optional
-    context managers for the forward and backward passes.
-
-    Args:
-        f_schedule_plan (subclass of AbstractSchedulePlan): The forward schedule plan
-        b_schedule_plan (subclass of AbstractSchedulePlan): The backward schedule plan
-        grad (Tensor or None): The gradient of the loss function
-        f_context (VppContextManager or None): The VppContextManager for the forward pass
-        b_context (VppContextManager or None): The VppContextManager for the backward pass
-        pre_forward (callable or None): The function to call before the forward pass
-        pre_backward (callable or None): The function to call before the backward pass
-        post_forward (callable or None): The function to call after the forward pass
-        post_backward (callable or None): The function to call after the backward pass
-
-    Returns:
-        The output of the forward pass
-    """
-    # Calls fine_grained_schedule.py::ModelChunkSchedulePlan.forward_backward(),
-    # which calls fine_grained_schedule.py::schedule_chunk_1f1b()
+    """model level 1f1b fine-grained schedule"""
     return type(f_schedule_plan or b_schedule_plan).forward_backward(
         f_schedule_plan,
         b_schedule_plan,
...
@@ -243,19 +192,30 @@ def schedule_chunk_1f1b(
     )


+def schedule_chunk_forward(schedule_plan):
+    """model level fine-grained forward schedule"""
+    f_input = schedule_chunk_1f1b(schedule_plan, None, None)
+    return f_input
+
+
+def schedule_chunk_backward(schedule_plan, grad):
+    """model level fine-grained backward schedule"""
+    tmp = schedule_chunk_1f1b(None, schedule_plan, grad)
+
+
 _COMP_STREAM = None
 _COM_STREAM = None


 def set_streams(comp_stream=None, com_stream=None):
-    """Set the streams for communication and computation"""
+    """set the streams for communication and computation"""
     global _COMP_STREAM
     global _COM_STREAM
     if _COMP_STREAM is not None:
         return
     if comp_stream is None:
-        comp_stream = torch.cuda.current_stream()
+        comp_stream = torch.cuda.Stream(device="cuda")
     if com_stream is None:
         com_stream = torch.cuda.Stream(device="cuda")
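set_streams / get_comp_stream / get_com_stream give the scheduler one dedicated computation stream and one communication stream so all-to-all traffic can run while GEMMs from another micro-batch proceed. The self-contained sketch below shows only the generic overlap idiom, not the scheduler itself; the roll() call is a stand-in for a collective.

import torch


def overlapped_step(a, b, comp_stream, com_stream):
    """Run a matmul on comp_stream while a stand-in 'collective' runs on com_stream."""
    comp_stream.wait_stream(torch.cuda.current_stream())
    com_stream.wait_stream(torch.cuda.current_stream())
    done = torch.cuda.Event()
    with torch.cuda.stream(com_stream):
        shuffled = b.roll(shifts=1, dims=0)   # pretend this is the all-to-all
        done.record(com_stream)
    with torch.cuda.stream(comp_stream):
        prod = a @ a.t()                      # overlaps with the work above
        comp_stream.wait_event(done)          # only now consume the comm result
        out = prod + shuffled
    return out


if __name__ == "__main__" and torch.cuda.is_available():
    comp, com = torch.cuda.Stream(), torch.cuda.Stream()
    a = torch.randn(512, 512, device="cuda")
    b = torch.randn(512, 512, device="cuda")
    print(overlapped_step(a, b, comp, com).shape)
    torch.cuda.synchronize()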
...
@@ -266,19 +226,19 @@ def set_streams(comp_stream=None, com_stream=None):


 def get_comp_stream():
-    """Get the stream for computation"""
+    """get the stream for computation"""
     global _COMP_STREAM
     return _COMP_STREAM


 def get_com_stream():
-    """Get the stream for communication"""
+    """get the stream for communication"""
     global _COM_STREAM
     return _COM_STREAM


 class VppContextManager:
-    """A reusable context manager for switch vpp stage"""
+    """a reusable context manager for switch vpp stage"""

     def __init__(self, vpp_rank):
         self.vpp_rank = vpp_rank
...
@@ -316,58 +276,75 @@ def forward_backward_step(
     current_microbatch=None,
     encoder_decoder_xattn=False,
 ):
-    """Merged forward and backward step for combined_1f1b.
-
-    Args:
-        Need to accept the argument of both forward_step() and backward_step().
-        forward_step_func (callable): is wrapped by wrap_forward_func() which is now returning
-            a forward schedule plan which is an input of schedule_chunk_1f1b function.
-        f_context (VppContextManager or nullcontext): The context manager for setting vpp ranks.
-        b_context (VppContextManager or nullcontext): The context manager for setting vpp ranks.
-            Only exists in 1f1b steady state with p2p overlap.
-        pre_forward (callable): The function to call before the forward_step.
-        pre_backward (callable): The function to call before the backward_step.
-        post_forward (callable): The function to call after the forward_step.
-        post_backward (callable): The function to call after the backward_step.
-
-    Returns:
-        forward_output_tensor (Tensor or list[Tensor]): The output object(s) from the forward step.
-        forward_num_tokens (Tensor): The number of tokens.
-        backward_input_tensor_grad (Tensor): The grad of the input tensor.
-
-    Descriptions:
-        This method merges the forward_step() and backward_step() methods in the schedules.py file.
-        Assuming that:
-            def forward_step():
-                # forward_preprocess()
-                # forward_compute()
-                # forward_postprocess()
-            def backward_step():
-                # backward_preprocess()
-                # backward_compute()
-                # backward_postprocess()
-        Then the forward_backward_step() method will be:
-            def forward_backward_step():
-                # forward_preprocess() // the same as the forward_step()
-                # GENERATE f_schedule_plan // schedule happens in schedule_chunk_1f1b()
-                # backward_preprocess() // the same as the backward_step()
-                # COMBINED_FORWARD_BACKWARD_COMPUTE() // by calling schedule_chunk_1f1b()
-                # forward_postprocess() // the same as the forward_step()
-                # backward_postprocess() // the same as the backward_step()
-    """
-    assert (
-        checkpoint_activations_microbatch is None
-    ), "checkpoint_activations_microbatch is not supported for combined_1f1b"
-    if config.combined_1f1b_recipe != "ep_a2a":
-        raise NotImplementedError(
-            f"combined_1f1b_recipe {config.combined_1f1b_recipe} not supported yet"
-        )
+    """Forward step for passed-in model.
+
+    If it is the first stage, the input tensor is obtained from the data_iterator.
+    Otherwise, the passed-in input_tensor is used.
+
+    Args:
+        forward_step_func (callable):
+            The forward step function for the model that takes the
+            data iterator as the first argument, and model as the second.
+            This user's forward step is expected to output a tuple of two elements:
+
+                1. The output object from the forward step. This output object needs to be a
+                   tensor or some kind of collection of tensors. The only hard requirement
+                   for this object is that it needs to be acceptible as input into the second
+                   function.
+                2. A function to reduce (optionally) the output from the forward step. This
+                   could be a reduction over the loss from the model, it could be a function that
+                   grabs the output from the model and reformats, it could be a function that just
+                   passes through the model output. This function must have one of the following
+                   patterns, and depending on the pattern different things happen internally:
+
+                       a. A tuple of reduced loss and some other data. Note that in this case
+                          the first argument is divided by the number of global microbatches,
+                          assuming it is a loss, so that the loss is stable as a function of
+                          the number of devices the step is split across.
+                       b. A triple of reduced loss, number of tokens, and some other data. This
+                          is similar to case (a), but the loss is further averaged across the
+                          number of tokens in the batch. If the user is not already averaging
+                          across the number of tokens, this pattern is useful to use.
+                       c. Any arbitrary data the user wants (eg a dictionary of tensors, a list
+                          of tensors, etc in the case of inference). To trigger case 3 you need
+                          to specify `collect_non_loss_data=True` and you may also want to
+                          specify `forward_only=True` in the call to the parent forward_backward
+                          function.
+        data_iterator (iterator):
+            The data iterator.
+        model (nn.Module):
+            The model to perform the forward step on.
+        num_microbatches (int):
+            The number of microbatches.
+        input_tensor (Tensor or list[Tensor]):
+            The input tensor(s) for the forward step.
+        forward_data_store (list):
+            The list to store the forward data. If you go down path 2.a or
+            2.b for the return of your forward reduction function then this will store only the
+            final dimension of the output, for example the metadata output by the loss function.
+            If you go down the path of 2.c then this will store the entire output of the forward
+            reduction function applied to the model output.
+        config (object):
+            The configuration object.
+        collect_non_loss_data (bool, optional):
+            Whether to collect non-loss data. Defaults to False.
+            This is the path to use if you want to collect arbitrary output from the model forward,
+            such as with inference use cases. Defaults to False.
+        checkpoint_activations_microbatch (int, optional):
+            The microbatch to checkpoint activations.
+            Defaults to None.
+        is_first_microbatch (bool, optional):
+            Whether it is the first microbatch. Defaults to False.
+        current_microbatch (int, optional):
+            The current microbatch. Defaults to None.
+
+    Returns:
+        Tensor or list[Tensor]: The output object(s) from the forward step.
+        Tensor: The number of tokens.
+    """
     from .schedules import set_current_microbatch

-    if f_model is not None and config.timers is not None:
+    if config.timers is not None:
         config.timers('forward-compute', log_level=2).start()

     if config.enable_autocast:
...
@@ -377,7 +354,6 @@ def forward_backward_step(

     # forward preprocess
     unwrap_output_tensor = False
-    f_schedule_plan = None
     if f_model is not None:
         with f_context:
             if is_first_microbatch and hasattr(f_model, 'set_is_first_microbatch'):
...
@@ -391,10 +367,15 @@ def forward_backward_step(
             set_input_tensor = get_attr_wrapped_model(f_model, "set_input_tensor")
             set_input_tensor(input_tensor)

-            with context_manager:  # autocast context
-                f_schedule_plan, loss_func = forward_step_func(data_iterator, f_model)
+            with context_manager:
+                if checkpoint_activations_microbatch is None:
+                    output_tensor, loss_func = forward_step_func(data_iterator, f_model)
+                else:
+                    output_tensor, loss_func = forward_step_func(
+                        data_iterator, f_model, checkpoint_activations_microbatch
+                    )

             assert isinstance(
-                f_schedule_plan, AbstractSchedulePlan
+                output_tensor, AbstractSchedulePlan
             ), "first output of forward_step_func must be one instance of AbstractSchedulePlan"

     # backward preprocess
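Because the forward step function is wrapped to receive build_schedule_plan in place of the model, the user's training-script forward step can stay unchanged: whatever it calls as "the model" now returns a schedule plan, which is what the isinstance assertion above checks. A hedged sketch of such a forward step follows; the batch keys and the loss reduction are illustrative, not this repository's exact training script.

import torch


def forward_step(data_iterator, model):
    """Standard Megatron-style forward step: returns (output, loss_reduction_fn).

    With combined 1f1b enabled, `model` is actually GPTModel.build_schedule_plan,
    so `output` is a schedule plan rather than logits/losses.
    """
    batch = next(data_iterator)
    output = model(
        input_ids=batch["tokens"],
        position_ids=batch["position_ids"],
        attention_mask=batch["attention_mask"],
        labels=batch["labels"],
    )

    def loss_func(losses):
        loss = losses.float().mean()
        return loss, {"lm loss": loss.detach()}

    return output, loss_func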
...
@@ -425,8 +406,9 @@ def forward_backward_step(
                 torch.autograd.backward(b_output_tensor[0], grad_tensors=b_output_tensor_grad[0])
                 b_output_tensor_grad[0] = loss_node.get_grad()

+    f_schedule_plan = output_tensor if f_model else None
     grad = b_output_tensor_grad[0] if b_model else None
-    with context_manager:  # autocast context
+    with context_manager:
         # schedule forward and backward
         output_tensor = schedule_chunk_1f1b(
             f_schedule_plan,
...
@@ -442,7 +424,7 @@ def forward_backward_step(

     # forward post process
     num_tokens = None
-    if f_model is not None:
+    if f_model:
         with f_context:
             num_tokens = torch.tensor(0, dtype=torch.int)
             if parallel_state.is_pipeline_last_stage():
...
@@ -511,33 +493,18 @@ def forward_backward_step(
     return output_tensor, num_tokens, input_tensor_grad


 def get_default_cls_for_unwrap():
-    """Returns the default classes to unwrap from a model.
-
-    This function provides a tuple of classes that should be unwrapped from a model
-    to access the underlying GPTModel instance. It includes DistributedDataParallel
-    and Float16Module by default, and also attempts to include LegacyFloat16Module
-    if available for backward compatibility.
-
-    Returns:
-        tuple: A tuple of classes to unwrap from a model.
-    """
     cls = (DistributedDataParallel, Float16Module)
     try:
         # legacy should not be used in core, but for backward compatibility, we support it here
         from megatron.legacy.model import Float16Module as LegacyFloat16Module

         cls = cls + (LegacyFloat16Module,)
     except:
         pass
     return cls


 def unwrap_model(model, module_instances=get_default_cls_for_unwrap()):
-    """Unwrap_model DistributedDataParallel and Float16Module wrapped model
-    to return GPTModel instance
-    """
+    """unwrap_model DistributedDataParallel and Float16Module wrapped model"""
     return_list = True
     if not isinstance(model, list):
         model = [model]
...
@@ -546,80 +513,19 @@ def unwrap_model(model, module_instances=get_default_cls_for_unwrap()):
     for model_module in model:
         while isinstance(model_module, module_instances):
             model_module = model_module.module
-        assert isinstance(model_module, GPTModel), "The final unwrapped model must be a GPTModel instance"
         unwrapped_model.append(model_module)

     if not return_list:
         return unwrapped_model[0]
     return unwrapped_model


-def wrap_forward_func(forward_step_func):
-    """Wrap the input to forward_step_func.
-
-    The wrapped function will return forward_schedule_plan and the loss_function.
-    """
+def wrap_forward_func(config, forward_step_func):
+    """wrap the input to forward_step_func, to make forward_step_func return schedule plan"""

     def wrapped_func(data_iterator, model):
-        # Model is unwrapped to get GPTModel instance.
-        # GPTModel.build_schedule_plan(model_forward_inputs) is called in the forward_step.
-        # The return value becomes (forward_schedule_plan, loss_function),
-        # which is used to be (forward_output_tensor, loss_function).
         return forward_step_func(data_iterator, unwrap_model(model).build_schedule_plan)

-    return wrapped_func
-
-
-class MemoryManagementStrategy:
-    """Base class for memory management strategies.
-
-    Different memory management strategies can be implemented by subclassing this class.
-    These strategies control how tensors are handled in memory during the computation.
-    """
-
-    def handle_inputs(self, inputs, stream):
-        """Process input tensors after computation.
-
-        Args:
-            inputs (tuple): Input tensors that have been used
-            stream (torch.cuda.Stream): Current CUDA stream
-        """
-        pass
-
-    def handle_outputs(self, outputs, stream):
-        """Process output tensors after computation.
-
-        Args:
-            outputs (tuple): Output tensors produced by the computation
-            stream (torch.cuda.Stream): Current CUDA stream
-        """
-        pass
-
-
-class NoOpMemoryStrategy(MemoryManagementStrategy):
-    """Strategy that performs no memory management operations.
-
-    This is the default strategy - it doesn't free any memory.
-    """
-
-    pass
-
-
-class FreeInputsMemoryStrategy(MemoryManagementStrategy):
-    """Strategy that immediately frees input tensors after they are used.
-
-    This strategy is useful for nodes where inputs are no longer needed
-    after computation, helping to reduce memory usage.
-    """
-
-    def handle_inputs(self, inputs, stream):
-        """Free input tensors by resizing their storage to zero.
-
-        Args:
-            inputs (tuple): Input tensors to be freed
-            stream (torch.cuda.Stream): Current CUDA stream
-        """
-        for input in inputs:
-            if input is not None:
-                input.record_stream(stream)
-                input.untyped_storage().resize_(0)
+    if config.combined_1f1b and config.combined_1f1b_recipe == "ep_a2a":
+        return wrapped_func
+    else:
+        return forward_step_func
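The reworked wrap_forward_func turns the combined-1f1b path into an opt-in decided by two config switches; when either is off, the user's forward step is returned untouched. A hedged usage sketch follows; the TrainConfig dataclass and wrap_forward_func_sketch name are illustrative, the real flags live on Megatron's config object.

from dataclasses import dataclass


@dataclass
class TrainConfig:
    combined_1f1b: bool = True
    combined_1f1b_recipe: str = "ep_a2a"


def wrap_forward_func_sketch(config, forward_step_func, build_schedule_plan):
    """Mirror of the selection logic: wrap only for the ep_a2a recipe."""

    def wrapped_func(data_iterator, model):
        # Hand the schedule-plan builder to the user code in place of the model,
        # so the first return value becomes a schedule plan.
        return forward_step_func(data_iterator, build_schedule_plan)

    if config.combined_1f1b and config.combined_1f1b_recipe == "ep_a2a":
        return wrapped_func
    return forward_step_func


if __name__ == "__main__":
    step = lambda it, m: ("plan-or-logits", lambda x: x)
    fn = wrap_forward_func_sketch(TrainConfig(), step, build_schedule_plan=lambda *a, **k: "plan")
    print(fn(iter([]), model=None))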
dcu_megatron/core/transformer/moe/token_dispatcher.py  View file @ 32ee381a
-from megatron.core.transformer.moe.token_dispatcher import _DeepepManager as MegatronCoreDeepepManager
+from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher as MegatronCoreMoEAlltoAllTokenDispatcher


-class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
-    def token_permutation(
+# decouple perbatch state from MoEAlltoAllTokenDispatcher
+class MoEAlltoAllPerBatchState:
+    def __init__(self, build_event=False):
+        self.num_global_tokens_per_local_expert = None
+        self.output_splits_tp = None
+        self.output_splits = None
+        self.input_splits = None
+        self.num_out_tokens = None
+        self.capacity = None
+        self.preprocess_event = None
+        self.hidden_shape = None
+        self.probs = None
+        self.routing_map = None
+        self.reversed_local_input_permutation_mapping = None
+        self.cuda_sync_point = None
+        self.hidden_shape_before_permute = None
+
+
+class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
+    def collect_per_batch_state(self, state: MoEAlltoAllPerBatchState):
+        state.num_global_tokens_per_local_expert = getattr(self, "num_global_tokens_per_local_expert", None)
+        state.output_splits_tp = getattr(self, "output_splits_tp", None)
+        state.output_splits = getattr(self, "output_splits", None)
+        state.input_splits = getattr(self, "input_splits", None)
+        state.num_out_tokens = getattr(self, "num_out_tokens", None)
+        state.capacity = getattr(self, "capacity", None)
+        state.preprocess_event = getattr(self, "preprocess_event", None)
+        state.hidden_shape = getattr(self, "hidden_shape", None)
+        state.probs = getattr(self, "probs", None)
+        state.routing_map = getattr(self, "routing_map", None)
+        state.reversed_local_input_permutation_mapping = getattr(self, "reversed_local_input_permutation_mapping", None)
+        state.hidden_shape_before_permute = getattr(self, "hidden_shape_before_permute", None)
+        state.cuda_sync_point = getattr(self, "cuda_sync_point", None)
+
+    def apply_per_batch_state(self, state: MoEAlltoAllPerBatchState):
+        self.num_global_tokens_per_local_expert = state.num_global_tokens_per_local_expert
+        self.output_splits_tp = state.output_splits_tp
+        self.output_splits = state.output_splits
+        self.input_splits = state.input_splits
+        self.num_out_tokens = state.num_out_tokens
+        self.capacity = state.capacity
+        self.preprocess_event = state.preprocess_event
+        self.hidden_shape = state.hidden_shape
+        self.probs = state.probs
+        self.routing_map = state.routing_map
+        self.reversed_local_input_permutation_mapping = (
+            state.reversed_local_input_permutation_mapping
+        )
+        self.hidden_shape_before_permute = state.hidden_shape_before_permute
+        self.cuda_sync_point = state.cuda_sync_point
+
+    @contextmanager
+    def per_batch_state_context(self, state: MoEAlltoAllPerBatchState):
+        origin_state = MoEAlltoAllPerBatchState()
+        self.collect_per_batch_state(origin_state)
+        try:
+            self.apply_per_batch_state(state)
+            yield
+        finally:
+            self.collect_per_batch_state(state)
+            self.apply_per_batch_state(origin_state)
meta_prepare
(
self
,
hidden_states
:
torch
.
Tensor
,
probs
:
torch
.
Tensor
,
routing_map
:
torch
.
Tensor
self
,
hidden_states
:
torch
.
Tensor
,
probs
:
torch
.
Tensor
,
routing_map
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
):
"""
Dispatch tokens to local experts using AlltoAll communication.
This method performs the following steps:
1. Preprocess the routing map to get metadata for communication and permutation.
2. Permute input tokens for AlltoAll communication.
3. Perform expert parallel AlltoAll communication.
4. Sort tokens by local expert (if multiple local experts exist).
Args:
hidden_states (torch.Tensor): Input token embeddings.
probs (torch.Tensor): The probabilities of token to experts assignment.
routing_map (torch.Tensor): The mapping of token to experts assignment.
Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- Permuted token embeddings for local experts.
- Number of tokens per expert.
- Permuted probs of each token produced by the router.
"""
# Preprocess: Get the metadata for communication, permutation and computation operations.
self
.
hidden_shape
=
hidden_states
.
shape
self
.
hidden_shape
=
hidden_states
.
shape
self
.
probs
=
probs
self
.
probs
=
probs
self
.
routing_map
=
routing_map
self
.
routing_map
=
routing_map
assert
probs
.
dim
()
==
2
,
"Expected 2D tensor for probs"
assert
probs
.
dim
()
==
2
,
"Expected 2D tensor for probs"
assert
routing_map
.
dim
()
==
2
,
"Expected 2D tensor for token2expert mask"
assert
routing_map
.
dim
()
==
2
,
"Expected 2D tensor for token2expert mask"
assert
routing_map
.
dtype
==
torch
.
bool
,
"Expected bool tensor for mask"
assert
routing_map
.
dtype
==
torch
.
bool
,
"Expected bool tensor for mask"
hidden_states
=
hidden_states
.
view
(
-
1
,
self
.
hidden_shape
[
-
1
])
tokens_per_expert
=
self
.
preprocess
(
self
.
routing_map
)
tokens_per_expert
=
self
.
preprocess
(
self
.
routing_map
)
return
tokens_per_expert
def
dispatch_preprocess
(
self
,
hidden_states
:
torch
.
Tensor
,
routing_map
:
torch
.
Tensor
,
tokens_per_expert
:
torch
.
Tensor
):
hidden_states
=
hidden_states
.
view
(
-
1
,
self
.
hidden_shape
[
-
1
])
if
self
.
shared_experts
is
not
None
:
if
self
.
shared_experts
is
not
None
:
self
.
shared_experts
.
pre_forward_comm
(
hidden_states
.
view
(
self
.
hidden_shape
))
self
.
shared_experts
.
pre_forward_comm
(
hidden_states
.
view
(
self
.
hidden_shape
))
...
@@ -49,12 +98,15 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
         ) = permute(
             hidden_states,
             routing_map,
-            probs=probs,
+            probs=self.probs,
             num_out_tokens=self.num_out_tokens,
             fused=self.config.moe_permute_fusion,
             drop_and_pad=self.drop_and_pad,
         )
+        return tokens_per_expert, permutated_local_input_tokens, permuted_probs

+    def dispatch_all_to_all(self, tokens_per_expert, permutated_local_input_tokens, permuted_probs):
         # Perform expert parallel AlltoAll communication
         tokens_per_expert = self._maybe_dtoh_and_synchronize(
             "before_ep_alltoall", tokens_per_expert
...
@@ -65,6 +117,10 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
         global_probs = all_to_all(
             self.ep_group, permuted_probs, self.output_splits, self.input_splits
         )
+        return tokens_per_expert, global_input_tokens, global_probs

+    def dispatch_postprocess(self, tokens_per_expert, global_input_tokens, global_probs):
         if self.shared_experts is not None:
             self.shared_experts.linear_fc1_forward_and_act(global_input_tokens)
...
@@ -118,184 +174,137 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
         )
         tokens_per_expert = self._maybe_dtoh_and_synchronize("before_finish", tokens_per_expert)
         return global_input_tokens, tokens_per_expert, global_probs
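Splitting token_permutation into meta_prepare / dispatch_preprocess / dispatch_all_to_all / dispatch_postprocess is what allows the all-to-all leg to be scheduled on the communication stream as its own node, with the permute work staying on the compute stream. A hedged sketch of that wiring is below; make_node is an assumed factory for a ScheduleNode-like object, and the real nodes are presumably built in fine_grained_schedule.py.

def build_moe_dispatch_nodes(dispatcher, comp_stream, com_stream, make_node):
    """Wire the four dispatch stages into three schedule nodes.

    `make_node(fn, stream)` is assumed to return a node whose forward() runs `fn`
    on `stream`; the in-tree constructor differs.
    """

    def preprocess(hidden_states, probs, routing_map):
        tokens_per_expert = dispatcher.meta_prepare(hidden_states, probs, routing_map)
        return dispatcher.dispatch_preprocess(hidden_states, routing_map, tokens_per_expert)

    def a2a(tokens_per_expert, permuted_tokens, permuted_probs):
        return dispatcher.dispatch_all_to_all(tokens_per_expert, permuted_tokens, permuted_probs)

    def postprocess(tokens_per_expert, global_tokens, global_probs):
        return dispatcher.dispatch_postprocess(tokens_per_expert, global_tokens, global_probs)

    return (
        make_node(preprocess, comp_stream),   # permute on the compute stream
        make_node(a2a, com_stream),           # all-to-all on the communication stream
        make_node(postprocess, comp_stream),  # local sort back on the compute stream
    )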
-class _DeepepManager(MegatronCoreDeepepManager):
-    """
-    patch megatron _DeepepManager. async
-    """
-
-    def dispatch(
-        self,
-        hidden_states: torch.Tensor,
-        async_finish: bool = False,
-        allocate_on_comm_stream: bool = False,
-    ) -> torch.Tensor:
-        # DeepEP only supports float32 probs
-        if self.token_probs.dtype != torch.float32:
-            if self.token_probs.dtype in [torch.bfloat16, torch.float16]:
-                print("DeepEP only supports float32 probs, please set --moe-router-dtype=fp32")
-            self.token_probs = self.token_probs.float()  # downcast or upcast
-        hidden_states, dispatched_indices, dispatched_probs, num_tokens_per_expert, handle = (
-            fused_dispatch(
-                hidden_states,
-                self.token_indices,
-                self.token_probs,
-                self.num_experts,
-                self.group,
-                async_finish=async_finish,
-                allocate_on_comm_stream=allocate_on_comm_stream,
-            )
-        )
-        self.handle = handle
-        self.tokens_per_expert = num_tokens_per_expert
-        self.dispatched_indices = dispatched_indices
-        self.dispatched_probs = dispatched_probs
-        return hidden_states
-
-    def combine(
-        self,
-        hidden_states: torch.Tensor,
-        async_finish: bool = False,
-        allocate_on_comm_stream: bool = False,
-    ) -> torch.Tensor:
-        hidden_states, _ = fused_combine(
-            hidden_states,
-            self.group,
-            self.handle,
-            async_finish=async_finish,
-            allocate_on_comm_stream=allocate_on_comm_stream,
-        )
-        # Release the handle after combine operation
-        self.handle = None
-        return hidden_states
-
-
-class MoEFlexTokenDispatcher(MoETokenDispatcher):
-    """
-    Flex token dispatcher using DeepEP.
-    """
-
-    def dispatch_preprocess(
-        self, hidden_states: torch.Tensor, routing_map: torch.Tensor, probs: torch.Tensor
-    ):
-        """
-        Preprocesses the hidden states and routing information before dispatching tokens to experts.
-
-        Args:
-            hidden_states (torch.Tensor): Input hidden states to be processed
-            routing_map (torch.Tensor): Map indicating which expert each token should be routed to
-            probs (torch.Tensor): Routing probabilities for each token-expert pair
-
-        Returns:
-            Tuple containing:
-                - torch.Tensor: Reshaped hidden states
-                - torch.Tensor: Token probabilities from the communication manager
-                - None: Placeholder for compatibility
-        """
-        self.hidden_shape = hidden_states.shape
-        hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
-        # Initialize metadata
-        routing_map, probs = self._initialize_metadata(routing_map, probs)
-        self._comm_manager.setup_metadata(routing_map, probs)
-        return hidden_states, self._comm_manager.token_probs, None
-
-    def dispatch_all_to_all(
-        self,
-        hidden_states: torch.Tensor,
-        probs: torch.Tensor = None,
-        async_finish: bool = True,
-        allocate_on_comm_stream: bool = True,
-    ):
-        """
-        Performs all-to-all communication to dispatch tokens across expert parallel ranks.
-        """
-        return (
-            self._comm_manager.dispatch(hidden_states, async_finish, allocate_on_comm_stream),
-            self._comm_manager.dispatched_probs,
-        )
-
-    def dispatch_postprocess(self, hidden_states: torch.Tensor):
-        """
-        Post-processes the dispatched hidden states after all-to-all communication.
-
-        This method retrieves the permuted hidden states by experts, calculates the number of tokens
-        per expert, and returns the processed data ready for expert processing.
-        """
-        global_input_tokens, permuted_probs = (
-            self._comm_manager.get_permuted_hidden_states_by_experts(hidden_states)
-        )
-        tokens_per_expert = self._comm_manager.get_number_of_tokens_per_expert()
-        return global_input_tokens, tokens_per_expert, permuted_probs
-
-    def token_permutation(
-        self, hidden_states: torch.Tensor, probs: torch.Tensor, routing_map: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """
-        Permutes tokens according to the routing map and dispatches them to experts.
-
-        This method implements the token permutation process in three steps:
-        1. Preprocess the hidden states and routing information
-        2. Perform all-to-all communication to dispatch tokens
-        3. Post-process the dispatched tokens for expert processing
-        """
-        hidden_states, _, _ = self.dispatch_preprocess(hidden_states, routing_map, probs)
-        hidden_states, _ = self.dispatch_all_to_all(
-            hidden_states, async_finish=False, allocate_on_comm_stream=False
-        )
-        global_input_tokens, tokens_per_expert, permuted_probs = self.dispatch_postprocess(
-            hidden_states
-        )
-        return global_input_tokens, tokens_per_expert, permuted_probs
-
-    def combine_preprocess(self, hidden_states: torch.Tensor):
-        """
-        Pre-processes the hidden states before combining them after expert processing.
-
-        This method restores the hidden states to their original ordering before expert processing
-        by using the communication manager's restoration function.
-        """
-        hidden_states = self._comm_manager.get_restored_hidden_states_by_experts(hidden_states)
-        return hidden_states
-
-    def combine_all_to_all(
-        self,
-        hidden_states: torch.Tensor,
-        async_finish: bool = True,
-        allocate_on_comm_stream: bool = True,
-    ):
-        """
-        Performs all-to-all communication to combine tokens after expert processing.
-        """
-        return self._comm_manager.combine(hidden_states, async_finish, allocate_on_comm_stream)
-
-    def combine_postprocess(self, hidden_states: torch.Tensor):
-        """
-        Post-processes the combined hidden states after all-to-all communication.
-
-        This method reshapes the combined hidden states to match the original input shape.
-        """
-        return hidden_states.view(self.hidden_shape)
-
-    def token_unpermutation(
-        self, hidden_states: torch.Tensor, bias: Optional[torch.Tensor] = None
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        """
-        Reverses the token permutation process to restore the original token order.
-
-        This method implements the token unpermutation process in three steps:
-        1. Pre-process the hidden states to restore their original ordering
-        2. Perform all-to-all communication to combine tokens
-        3. Post-process the combined tokens to match the original input shape
-        """
-        assert bias is None, "Bias is not supported in MoEFlexTokenDispatcher"
-        hidden_states = self.combine_preprocess(hidden_states)
-        hidden_states = self.combine_all_to_all(hidden_states, False, False)
-        hidden_states = self.combine_postprocess(hidden_states)
-        return hidden_states, None
+    def token_permutation(
+        self, hidden_states: torch.Tensor, probs: torch.Tensor, routing_map: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Dispatch tokens to local experts using AlltoAll communication.
+
+        This method performs the following steps:
+        1. Preprocess the routing map to get metadata for communication and permutation.
+        2. Permute input tokens for AlltoAll communication.
+        3. Perform expert parallel AlltoAll communication.
+        4. Sort tokens by local expert (if multiple local experts exist).
+
+        Args:
+            hidden_states (torch.Tensor): Input token embeddings.
+            probs (torch.Tensor): The probabilities of token to experts assignment.
+            routing_map (torch.Tensor): The mapping of token to experts assignment.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+                - Permuted token embeddings for local experts.
+                - Number of tokens per expert.
+                - Permuted probs of each token produced by the router.
+        """
+        # Preprocess: Get the metadata for communication, permutation and computation operations.
+        tokens_per_expert = self.meta_prepare(hidden_states, probs, routing_map)
+        # Permutation 1: input to AlltoAll input
+        tokens_per_expert, permutated_local_input_tokens, permuted_probs = self.dispatch_preprocess(
+            hidden_states, routing_map, tokens_per_expert
+        )
+        # Perform expert parallel AlltoAll communication
+        tokens_per_expert, global_input_tokens, global_probs = self.dispatch_all_to_all(
+            tokens_per_expert, permutated_local_input_tokens, permuted_probs
+        )
+        # Permutation 2: Sort tokens by local expert.
+        global_input_tokens, tokens_per_expert, global_probs = self.dispatch_postprocess(
+            tokens_per_expert, global_input_tokens, global_probs
+        )
+        return global_input_tokens, tokens_per_expert, global_probs
+
+    def combine_preprocess(self, hidden_states):
+        # Unpermutation 2: Unsort tokens by local expert.
+        if self.num_local_experts > 1:
+            if self.drop_and_pad:
+                hidden_states = (
+                    hidden_states.view(
+                        self.num_local_experts,
+                        self.tp_size * self.ep_size,
+                        self.capacity,
+                        *hidden_states.size()[1:],
+                    )
+                    .transpose(0, 1)
+                    .contiguous()
+                    .flatten(start_dim=0, end_dim=2)
+                )
+            else:
+                hidden_states, _ = sort_chunks_by_idxs(
+                    hidden_states,
+                    self.num_global_tokens_per_local_expert.T.ravel(),
+                    self.restore_output_by_local_experts,
+                    fused=self.config.moe_permute_fusion,
+                )
+
+        if self.tp_size > 1:
+            if self.output_splits_tp is None:
+                input_split_sizes = None
+            else:
+                input_split_sizes = self.output_splits_tp.tolist()
+            # The precision of TP reduce_scatter should be the same as the router_dtype
+            hidden_states = reduce_scatter_to_sequence_parallel_region(
+                hidden_states.to(self.probs.dtype),
+                group=self.tp_group,
+                input_split_sizes=input_split_sizes,
+            ).to(hidden_states.dtype)
+
+        return hidden_states
+
+    def combine_all_to_all(self, hidden_states):
+        # Perform expert parallel AlltoAll communication
+        # hidden_states: [SEQL, H] -> [SEQL, H/TP]
+        permutated_local_input_tokens = all_to_all(
+            self.ep_group, hidden_states, self.input_splits, self.output_splits
+        )
+        return permutated_local_input_tokens
+
+    def combine_postprocess(self, permutated_local_input_tokens):
+        if self.shared_experts is not None:
+            self.shared_experts.linear_fc2_forward(permutated_local_input_tokens)
+            self.shared_experts.post_forward_comm()
+
+        # Unpermutation 1: AlltoAll output to output
+        output = unpermute(
+            permutated_local_input_tokens,
+            self.reversed_local_input_permutation_mapping,
+            restore_shape=self.hidden_shape_before_permute,
+            routing_map=self.routing_map,
+            fused=self.config.moe_permute_fusion,
+            drop_and_pad=self.drop_and_pad,
+        )
+
+        # Reshape the output tensor
+        output = output.view(self.hidden_shape)
+
+        # Add shared experts output
+        if self.shared_experts is not None:
+            shared_expert_output = self.shared_experts.get_output()
+            output += shared_expert_output
+        return output
+
+    def token_unpermutation(
+        self, hidden_states: torch.Tensor, bias: Optional[torch.Tensor] = None
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """
+        Reverse the token permutation to restore the original order.
+
+        This method performs the following steps:
+        1. Unsort tokens by local expert (if multiple local experts exist).
+        2. Perform expert parallel AlltoAll communication to restore the original order.
+        3. Unpermute tokens to restore the original order.
+
+        Args:
+            hidden_states (torch.Tensor): Output from local experts.
+            bias (torch.Tensor, optional): Bias tensor (not supported).
+
+        Returns:
+            Tuple[torch.Tensor, Optional[torch.Tensor]]:
+                - Unpermuted token embeddings in the original order.
+                - None (bias is not supported).
+        """
+        assert bias is None, "Bias is not supported in MoEAlltoAllTokenDispatcher"
+        hidden_states = self.combine_preprocess(hidden_states)
+        permutated_local_input_tokens = self.combine_all_to_all(hidden_states)
+        output = self.combine_postprocess(permutated_local_input_tokens)
+        return output, None
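combine_all_to_all above reuses the split sizes computed during dispatch, just swapped: tokens that left with output_splits come back with input_splits. The toy below simulates the variable-size exchange between two ranks with plain tensor ops (no process group), to show how the split lists pair up; fake_all_to_all and the example split sizes are illustrative.

import torch


def fake_all_to_all(chunks_per_rank):
    """chunks_per_rank[r][d] is what rank r sends to rank d; returns what each rank receives."""
    world = len(chunks_per_rank)
    return [
        torch.cat([chunks_per_rank[src][dst] for src in range(world)], dim=0)
        for dst in range(world)
    ]


if __name__ == "__main__":
    # Rank 0 routes 3 tokens to itself and 1 to rank 1; rank 1 routes 2 and 2.
    input_splits = [[3, 1], [2, 2]]
    tokens = [torch.arange(4).unsqueeze(1).float(), 10 + torch.arange(4).unsqueeze(1).float()]
    send = [list(torch.split(tokens[r], input_splits[r], dim=0)) for r in range(2)]
    recv = fake_all_to_all(send)
    # output_splits on each rank is the matching column of input_splits: [3, 2] and [1, 2].
    print([t.shape[0] for t in recv])   # [5, 3]
    # The combine direction runs the same exchange with the split lists swapped.
    send_back = [
        list(torch.split(recv[r], [input_splits[s][r] for s in range(2)], dim=0))
        for r in range(2)
    ]
    back = fake_all_to_all(send_back)
    print([t.shape[0] for t in back])   # [4, 4]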
dcu_megatron/core/transformer/transformer_layer.py  View file @ 32ee381a
+from megatron.core import parallel_state, tensor_parallel
+from megatron.core.utils import (
+    deprecate_inference_params,
+    make_viewless_tensor,
+)
 from megatron.core.transformer.transformer_layer import TransformerLayer as MegatronCoreTransformerLayer


 class TransformerLayer(MegatronCoreTransformerLayer):
-    def _submodule_attn_router_forward(
-        self,
-        hidden_states,
-        attention_mask=None,
-        inference_params=None,
-        rotary_pos_emb=None,
-        rotary_pos_cos=None,
-        rotary_pos_sin=None,
-        attention_bias=None,
-        packed_seq_params=None,
-        sequence_len_offset=None,
-        state=None,
-    ):
-        """
-        Performs a combined forward pass that includes self-attention and MLP routing logic.
-        """
-        hidden_states, _ = self._forward_attention(
-            hidden_states=hidden_states,
-            attention_mask=attention_mask,
-            inference_params=inference_params,
-            rotary_pos_emb=rotary_pos_emb,
-            rotary_pos_cos=rotary_pos_cos,
-            rotary_pos_sin=rotary_pos_sin,
-            attention_bias=attention_bias,
-            packed_seq_params=packed_seq_params,
-            sequence_len_offset=sequence_len_offset,
-        )
+    def _callable_wrapper(
+        self, is_forward, func, stream, event, *args, skip_detach=False, **kwargs
+    ):
+        """
+        Wraps a function call so that it waits for a given CUDA event before
+        proceeding and then runs the function on a specified CUDA stream.
+        """
+        torch.cuda.nvtx.range_push(func.__name__)
+        event.wait(stream)
+        with torch.cuda.stream(stream):
+            outputs = func(*args, **kwargs)
+            event.record(stream)
+
+        if skip_detach:
+            torch.cuda.nvtx.range_pop()
+            return outputs
+
+        detached_output_tensors = []
+        if not is_forward:
+            torch.cuda.nvtx.range_pop()
+            return outputs, detached_output_tensors
+        for tensor in outputs:
+            if tensor is None:
+                detached_output_tensors.append(None)
+            elif tensor.dtype.is_floating_point:
+                detached_output_tensors.append(tensor.detach().requires_grad_(True))
+            else:
+                detached_output_tensors.append(tensor.detach())
+        torch.cuda.nvtx.range_pop()
+        return outputs, detached_output_tensors
+
+    def _submodule_attention_forward(
+        self,
+        hidden_states: Tensor,
+        attention_mask: Optional[Tensor] = None,
+        rotary_pos_emb: Optional[Tensor] = None,
+        rotary_pos_cos: Optional[Tensor] = None,
+        rotary_pos_sin: Optional[Tensor] = None,
+        attention_bias: Optional[Tensor] = None,
+        inference_context: Optional[Any] = None,
+        packed_seq_params: Optional[PackedSeqParams] = None,
+        sequence_len_offset: Optional[Tensor] = None,
+        *,
+        inference_params: Optional[Any] = None,
+    ):
+        # todo
+        inference_context = deprecate_inference_params(inference_context, inference_params)
+
+        # Residual connection.
+        residual = hidden_states
+
+        # Optional Input Layer norm
+        if self.recompute_input_layernorm:
+            self.input_layernorm_checkpoint = tensor_parallel.CheckpointWithoutOutput()
+            input_layernorm_output = self.input_layernorm_checkpoint.checkpoint(
+                self.input_layernorm, hidden_states
+            )
+        else:
+            input_layernorm_output = self.input_layernorm(hidden_states)
+
+        # Self attention.
+        attention_output_with_bias = self.self_attention(
+            input_layernorm_output,
+            attention_mask=attention_mask,
+            inference_context=inference_context,
+            rotary_pos_emb=rotary_pos_emb,
+            rotary_pos_cos=rotary_pos_cos,
+            rotary_pos_sin=rotary_pos_sin,
+            attention_bias=attention_bias,
+            packed_seq_params=packed_seq_params,
+            sequence_len_offset=sequence_len_offset,
+        )
+
+        if self.recompute_input_layernorm:
+            # discard the output of the input layernorm and register the recompute
+            # as a gradient hook of attention_output_with_bias[0]
+            self.input_layernorm_checkpoint.discard_output_and_register_recompute(
+                attention_output_with_bias[0]
+            )
+
+        # TODO: could we move `bias_dropout_add_exec_handler` itself
+        # inside the module provided in the `bias_dropout_add_spec` module?
+        with self.bias_dropout_add_exec_handler():
+            hidden_states = self.self_attn_bda(self.training, self.config.bias_dropout_fusion)(
+                attention_output_with_bias, residual, self.hidden_dropout
+            )
+
+        return hidden_states
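The detach().requires_grad_(True) copies produced by _callable_wrapper are what split one layer into independently schedulable autograd islands: each submodule's backward is later driven by calling .backward() on its own outputs with the .grad collected from the detached inputs of the next submodule. A self-contained sketch of that two-island pattern:

import torch

# Island 1: produce an output, then cut the autograd graph at the boundary.
x = torch.randn(4, requires_grad=True)
y = x * 3.0
y_detached = y.detach().requires_grad_(True)   # boundary tensor handed to island 2

# Island 2: continue the computation from the detached copy.
z = (y_detached ** 2).sum()

# Backward runs island by island, in reverse order.
z.backward()                        # fills y_detached.grad only
y.backward(y_detached.grad)         # feed the boundary grad back into island 1
print(torch.allclose(x.grad, 18.0 * x.detach()))   # True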
-        pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states)
-        probs, routing_map = self.mlp.router(pre_mlp_layernorm_output)
-        local_tokens, probs, tokens_per_expert = self.mlp.token_dispatcher.dispatch_preprocess(
-            pre_mlp_layernorm_output, routing_map, probs
-        )
-        return (local_tokens, probs, hidden_states, pre_mlp_layernorm_output, tokens_per_expert)
-
-    def _submodule_dispatch_forward(self, local_tokens, probs, state=None):
-        """
-        Dispatches tokens to the appropriate experts based on the router output.
-        """
-        token_dispatcher = self.mlp.token_dispatcher
-        if self.is_deepep:
-            token_dispatcher._comm_manager.token_probs = probs
-        return token_dispatcher.dispatch_all_to_all(local_tokens, probs)
+    def _submodule_attention_router_compound_forward(
+        self,
+        hidden_states: Tensor,
+        attention_mask: Optional[Tensor] = None,
+        rotary_pos_emb: Optional[Tensor] = None,
+        rotary_pos_cos: Optional[Tensor] = None,
+        rotary_pos_sin: Optional[Tensor] = None,
+        attention_bias: Optional[Tensor] = None,
+        inference_context: Optional[Any] = None,
+        packed_seq_params: Optional[PackedSeqParams] = None,
+        sequence_len_offset: Optional[Tensor] = None,
+        *,
+        inference_params: Optional[Any] = None,
+    ):
+        """
+        Performs a combined forward pass that includes self-attention and MLP routing logic.
+        """
+        hidden_states = self._submodule_attention_forward(
+            hidden_states,
+            attention_mask,
+            rotary_pos_emb,
+            rotary_pos_cos,
+            rotary_pos_sin,
+            attention_bias,
+            inference_context,
+            packed_seq_params,
+            sequence_len_offset,
+            inference_params=inference_params,
+        )
+
+        # Optional Layer norm post the cross-attention.
+        if self.recompute_pre_mlp_layernorm:
+            self.pre_mlp_norm_checkpoint = tensor_parallel.CheckpointWithoutOutput()
+            pre_mlp_layernorm_output = self.pre_mlp_norm_checkpoint.checkpoint(
+                self.pre_mlp_layernorm, hidden_states
+            )
+        else:
+            pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states)
+
+        probs, routing_map = self.mlp.router(pre_mlp_layernorm_output)
+        tokens_per_expert = self.mlp.token_dispatcher.meta_prepare(
+            pre_mlp_layernorm_output, probs, routing_map
+        )
+        tokens_per_expert, permutated_local_input_tokens, permuted_probs = (
+            self.mlp.token_dispatcher.dispatch_preprocess(
+                pre_mlp_layernorm_output, routing_map, tokens_per_expert
+            )
+        )
+        outputs = [
+            hidden_states,
+            pre_mlp_layernorm_output,
+            tokens_per_expert,
+            permutated_local_input_tokens,
+            permuted_probs,
+            probs,
+        ]
+        return tuple(outputs)
+
+    def _submodule_dispatch_forward(
+        self, tokens_per_expert, permutated_local_input_tokens, permuted_probs
+    ):
+        """
+        Dispatches tokens to the appropriate experts based on the router output.
+        """
+        tokens_per_expert, global_input_tokens, global_probs = (
+            self.mlp.token_dispatcher.dispatch_all_to_all(
+                tokens_per_expert, permutated_local_input_tokens, permuted_probs
+            )
+        )
+        return [tokens_per_expert, global_input_tokens, global_probs]
-    def _submodule_moe_forward(self, dispatched_tokens, probs=None, state=None):
-        """
-        Performs a forward pass for the MLP submodule, including both expert-based
-        and optional shared-expert computations.
-        """
-        shared_expert_output = None
-        token_dispatcher = self.mlp.token_dispatcher
-        if self.is_deepep:
-            token_dispatcher._comm_manager.dispatched_probs = state.dispatched_probs
-            dispatched_tokens, tokens_per_expert, permuted_probs = (
-                token_dispatcher.dispatch_postprocess(dispatched_tokens)
-            )
-        else:
-            dispatched_tokens, permuted_probs = token_dispatcher.dispatch_postprocess(
-                dispatched_tokens, probs
-            )
-            tokens_per_expert = state.tokens_per_expert
-        expert_output, mlp_bias = self.mlp.experts(
-            dispatched_tokens, tokens_per_expert, permuted_probs
-        )
-        assert mlp_bias is None, f"Bias is not supported in {token_dispatcher.__class__.__name__}"
-        if self.mlp.use_shared_expert and not self.mlp.shared_expert_overlap:
-            shared_expert_output = self.mlp.shared_experts(state.pre_mlp_layernorm_output)
-        expert_output = self.mlp.token_dispatcher.combine_preprocess(expert_output)
-        return expert_output, shared_expert_output, mlp_bias
-
-    def _submodule_combine_forward(self, output, shared_expert_output=None, state=None):
-        residual = state.residual
-        token_dispatcher = self.mlp.token_dispatcher
-        output = token_dispatcher.combine_all_to_all(output)
-        output = token_dispatcher.combine_postprocess(output)
-        if shared_expert_output is not None:
-            output = output + shared_expert_output
-        mlp_output_with_bias = (output, None)
-        with self.bias_dropout_add_exec_handler():
-            hidden_states = self.mlp_bda(self.training, self.config.bias_dropout_fusion)(
-                mlp_output_with_bias, residual, self.hidden_dropout
+    def _submodule_dense_forward(self, hidden_states):
+        residual = hidden_states
+        pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states)
+        mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output)
+        with self.bias_dropout_add_exec_handler():
+            hidden_states = self.mlp_bda(self.training, self.config.bias_dropout_fusion)(
+                mlp_output_with_bias, residual, self.hidden_dropout
+            )
+        output = make_viewless_tensor(
+            inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True
+        )
+        return output
+
+    def _submodule_moe_forward(
+        self, tokens_per_expert, global_input_tokens, global_prob, hidden_states
+    ):
+        """
+        Performs a forward pass for the MLP submodule, including both expert-based
+        and optional shared-expert computations.
+        """
+        shared_expert_output = None
+        (dispatched_input, tokens_per_expert, permuted_probs) = (
+            self.mlp.token_dispatcher.dispatch_postprocess(
+                tokens_per_expert, global_input_tokens, global_prob
+            )
+        )
+        expert_output, mlp_bias = self.mlp.experts(
+            dispatched_input, tokens_per_expert, permuted_probs
+        )
+        expert_output = self.mlp.token_dispatcher.combine_preprocess(expert_output)
+        if self.mlp.use_shared_expert and not self.mlp.shared_expert_overlap:
+            shared_expert_output = self.mlp.shared_experts(hidden_states)
+        return expert_output, shared_expert_output, mlp_bias
+
+    def _submodule_combine_forward(self, hidden_states):
+        return [self.mlp.token_dispatcher.combine_all_to_all(hidden_states)]
+
+    def _submodule_post_combine_forward(
+        self, expert_output, shared_expert_output, mlp_bias, residual
+    ):
+        """
+        Re-combines the expert outputs (and optional shared_expert_output) into the same order
+        as the original input tokens, applying any required bias.
+        """
+        output = self.mlp.token_dispatcher.combine_postprocess(expert_output)
+        if shared_expert_output is not None:
+            output += shared_expert_output
+        mlp_output_with_bias = (output, mlp_bias)
+        with self.bias_dropout_add_exec_handler():
+            hidden_states = self.mlp_bda(self.training, self.config.bias_dropout_fusion)(
+                mlp_output_with_bias, residual, self.hidden_dropout
...
@@ -92,133 +213,141 @@ class TransformerLayer(MegatronCoreTransformerLayer):
         return output
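_submodule_post_combine_forward ends by handing (output, bias) and the saved residual to mlp_bda, Megatron's bias-dropout-add. Functionally that step is just the fused form of the reference snippet below (training-mode dropout shown; the in-tree version picks a fused kernel via bias_dropout_fusion):

import torch
import torch.nn.functional as F


def bias_dropout_add(x_with_bias, residual, p, training=True):
    """Reference semantics of Megatron-style bias-dropout-add."""
    x, bias = x_with_bias
    if bias is not None:
        x = x + bias
    return residual + F.dropout(x, p=p, training=training)


if __name__ == "__main__":
    out = torch.randn(8, 16)       # combined expert (+ shared expert) output
    res = torch.randn(8, 16)       # residual saved before the MLP block
    print(bias_dropout_add((out, None), res, p=0.1).shape)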
def
_submodule_attn_router_dw
(
self
):
def
_submodule_attention_backward
(
self
,
hidden_states
,
pre_mlp_layernorm_output
,
detached_inputs
):
pre_mlp_layernorm_output
.
backward
(
detached_inputs
[
1
].
grad
)
hidden_states
.
backward
(
detached_inputs
[
0
].
grad
)
def
_submodule_attention_router_compound_backward
(
self
,
hidden_states
,
pre_mlp_layernorm_output
,
tokens_per_expert
,
permutated_local_input_tokens
,
probs
,
detached_inputs
,
):
permutated_local_input_tokens
.
backward
(
detached_inputs
[
3
].
grad
)
probs
.
backward
(
detached_inputs
[
4
].
grad
)
# tokens_per_expert.backward(detached_inputs[2].grad)
pre_mlp_layernorm_output
.
backward
(
detached_inputs
[
1
].
grad
)
hidden_states
.
backward
(
detached_inputs
[
0
].
grad
)
def
_submodule_dispatch_backward
(
self
,
global_input_tokens
,
detached_inputs
):
global_input_tokens
.
backward
(
detached_inputs
[
0
].
grad
)
def
_submodule_dense_backward
(
self
,
output
,
detached_inputs
):
output
.
backward
(
detached_inputs
[
0
].
grad
)
def
_submodule_moe_backward
(
self
,
expert_output
,
shared_expert_output
,
mlp_bias
,
detached_inputs
):
expert_output
.
backward
(
detached_inputs
[
0
].
grad
)
shared_expert_output
.
backward
(
detached_inputs
[
1
].
grad
)
if
mlp_bias
is
not
None
:
mlp_bias
.
backward
(
detached_inputs
[
2
].
grad
)
def
_submodule_combine_backward
(
self
,
hidden_states
,
detached_inputs
):
hidden_states
.
backward
(
detached_inputs
[
0
].
grad
)
def
_submodule_post_combine_backward
(
self
,
output
,
output_grad
):
output
.
backward
(
output_grad
)
def
_submodule_attention_router_compound_dgrad
(
self
):
raise
NotImplementedError
(
"Not implemented"
)
def
_submodule_attention_router_compound_dw
(
self
):
self
.
self_attention
.
backward_dw
()
self
.
self_attention
.
backward_dw
()
# raise NotImplementedError("Not implemented")
def
_submodule_dispatch_dgrad
(
self
):
raise
NotImplementedError
(
"Not implemented"
)
def
_submodule_mlp_dgrad
(
self
):
raise
NotImplementedError
(
"Not implemented"
)
def
_submodule_mlp_dw
(
self
):
def
_submodule_mlp_dw
(
self
):
self
.
mlp
.
backward_dw
()
self
.
mlp
.
backward_dw
()
# raise NotImplementedError("Not implemented")
def
_submodule_attn_router_postprocess
(
def
_submodule_combine_dgrad
(
self
):
self
,
node
,
local_tokens
,
probs
,
residual
,
pre_mlp_layernorm_output
,
tokens_per_expert
raise
NotImplementedError
(
"Not implemented"
)
):
node
.
common_state
.
residual
=
node
.
detach
(
residual
)
if
self
.
mlp
.
use_shared_expert
:
node
.
common_state
.
pre_mlp_layernorm_output
=
node
.
detach
(
pre_mlp_layernorm_output
)
if
not
self
.
is_deepep
:
node
.
common_state
.
tokens_per_expert
=
tokens_per_expert
return
local_tokens
,
probs
-    def _submodule_dispatch_postprocess(self, node, dispatched_tokens, probs):
-        if self.is_deepep:
-            node.common_state.dispatched_probs = node.detach(probs)
-            return dispatched_tokens
-        else:
-            return dispatched_tokens, probs
-
-    def _submodule_mlp_postprocess(self, node, expert_output, shared_expert_output, mlp_bias):
-        assert mlp_bias is None
-        node.common_state.pre_mlp_layernorm_output = None
-        if shared_expert_output is None:
-            return expert_output
-        return expert_output, shared_expert_output
-
-    def _submodule_combine_postprocess(self, node, output):
-        cur_stream = torch.cuda.current_stream()
-        node.common_state.residual.record_stream(cur_stream)
-        node.common_state.residual = None
-        return output
-    def _submodule_attn_postprocess(self, node, hidden_states, context):
-        return hidden_states
-
-    def _submodule_dense_postprocess(self, node, hidden_states):
-        return hidden_states
-
-    def _submodule_not_implemented(self, *args):
-        raise NotImplementedError("This callable is not implemented.")
+    def _submodule_identity_forward(self, *args):
+        return args
+
+    def _submodule_identity_backward(self, *args):
+        pass
-    def get_submodule_callables(self, chunk_state):
+    def get_submodule_callables(self):
        """
-        The forward callables take 2 parts of inputs:
-        1. The ScheduleNode object.
-        2. The input tensors.
+        Returns a dictionary of submodule callables for the transformer layer.
        """
        from megatron.core.transformer.moe.moe_layer import MoELayer
-        from megatron.core.transformer.moe.token_dispatcher import MoEFlexTokenDispatcher
-        self.is_moe = isinstance(self.mlp, MoELayer)
-        self.is_deepep = False
-        if self.is_moe:
-            self.is_deepep = isinstance(self.mlp.token_dispatcher, MoEFlexTokenDispatcher)
        def get_func_with_default(func, default_func):
-            if self.is_moe:
+            if isinstance(self.mlp, MoELayer):
                return func
            return default_func
-        def callable_wrapper(forward_func, postprocess_func, node, *args):
-            state = getattr(node, 'common_state', None)
-            callable_outputs = forward_func(*args, state=state)
-            if isinstance(callable_outputs, tuple):
-                outputs = postprocess_func(node, *callable_outputs)
-            else:
-                outputs = postprocess_func(node, callable_outputs)
-            return outputs
-
-        attn_func = get_func_with_default(self._submodule_attn_router_forward, self._forward_attention)
-
-        def attn_wrapper(hidden_states, state=None):
-            return attn_func(
-                hidden_states=hidden_states,
-                attention_mask=chunk_state.attention_mask,
-                attention_bias=chunk_state.attention_bias,
-                inference_params=chunk_state.inference_params,
-                packed_seq_params=chunk_state.packed_seq_params,
-                sequence_len_offset=chunk_state.sequence_len_offset,
-                rotary_pos_emb=chunk_state.rotary_pos_emb,
-                rotary_pos_cos=chunk_state.rotary_pos_cos,
-                rotary_pos_sin=chunk_state.rotary_pos_sin,
-                state=state,
-            )
+        attention_func = get_func_with_default(
+            self._submodule_attention_router_compound_forward, self._submodule_attention_forward
+        )
-        attn_postprocess_func = get_func_with_default(
-            self._submodule_attn_router_postprocess, self._submodule_attn_postprocess
-        )
+        attention_backward_func = get_func_with_default(
+            self._submodule_attention_router_compound_backward, self._submodule_attention_backward
+        )
        dispatch_func = get_func_with_default(
-            self._submodule_dispatch_forward, self._submodule_not_implemented
+            self._submodule_dispatch_forward, self._submodule_identity_forward
        )
-        dispatch_postprocess_func = get_func_with_default(
-            self._submodule_dispatch_postprocess, self._submodule_not_implemented
-        )
+        dispatch_backward_func = get_func_with_default(
+            self._submodule_dispatch_backward, self._submodule_identity_backward
+        )
-        mlp_func = get_func_with_default(self._submodule_moe_forward, self._forward_mlp)
+        mlp_func = get_func_with_default(self._submodule_moe_forward, self._submodule_dense_forward)
-        mlp_postprocess_func = get_func_with_default(
-            self._submodule_mlp_postprocess, self._submodule_dense_postprocess
-        )
+        mlp_backward_func = get_func_with_default(
+            self._submodule_moe_backward, self._submodule_dense_backward
+        )
        combine_func = get_func_with_default(
-            self._submodule_combine_forward, self._submodule_not_implemented
+            self._submodule_combine_forward, self._submodule_identity_forward
        )
-        combine_postprocess_func = get_func_with_default(
-            self._submodule_combine_postprocess, self._submodule_not_implemented
-        )
+        combine_backward_func = get_func_with_default(
+            self._submodule_combine_backward, self._submodule_identity_backward
+        )
+        post_combine_func = get_func_with_default(self._submodule_post_combine_forward, self._submodule_identity_forward)
+        post_combine_backward_func = get_func_with_default(self._submodule_post_combine_backward, self._submodule_identity_backward)
-        attn_forward = partial(callable_wrapper, attn_wrapper, attn_postprocess_func)
-        dispatch_forward = partial(callable_wrapper, dispatch_func, dispatch_postprocess_func)
-        mlp_forward = partial(callable_wrapper, mlp_func, mlp_postprocess_func)
-        combine_forward = partial(callable_wrapper, combine_func, combine_postprocess_func)
        callables = TransformerLayerSubmoduleCallables(
-            attention=SubmoduleCallables(forward=attn_forward, dw=self._submodule_attn_router_dw),
-            dispatch=SubmoduleCallables(forward=dispatch_forward),
-            mlp=SubmoduleCallables(forward=mlp_forward, dw=self._submodule_mlp_dw),
-            combine=SubmoduleCallables(forward=combine_forward),
-            is_moe=self.is_moe,
-            is_deepep=self.is_deepep,
+            attention=SubmoduleCallables(
+                forward=partial(self._callable_wrapper, True, attention_func, skip_detach=True),
+                backward=partial(self._callable_wrapper, False, attention_backward_func),
+                # dgrad=partial(self._callable_wrapper, False,self._submodule_attention_router_compound_dgrad),
+                dw=partial(
+                    self._callable_wrapper, False, self._submodule_attention_router_compound_dw
+                ),
+            ),
+            dispatch=SubmoduleCallables(
+                forward=partial(self._callable_wrapper, True, dispatch_func),
+                backward=partial(self._callable_wrapper, False, dispatch_backward_func),
+                # dgrad=partial(self._callable_wrapper, False, self._submodule_dispatch_dgrad),
+            ),
+            mlp=SubmoduleCallables(
+                forward=partial(self._callable_wrapper, True, mlp_func),
+                backward=partial(self._callable_wrapper, False, mlp_backward_func),
+                # dgrad=partial(self._callable_wrapper, False, self._submodule_mlp_dgrad),
+                dw=partial(self._callable_wrapper, False, self._submodule_mlp_dw),
+            ),
+            combine=SubmoduleCallables(
+                forward=partial(self._callable_wrapper, True, combine_func),
+                backward=partial(self._callable_wrapper, False, combine_backward_func),
+                # dgrad=partial(self._callable_wrapper, False, self._submodule_combine_dgrad),
+            ),
+            post_combine=SubmoduleCallables(
+                forward=partial(self._callable_wrapper, True, post_combine_func),
+                backward=partial(self._callable_wrapper, False, post_combine_backward_func),
+            ),
        )
        return callables
\ No newline at end of file
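For context on why the callables above split `backward` (activation gradients) from `dw` (weight gradients): the activation gradient is needed immediately by the previous submodule, while the weight gradient only feeds the optimizer, so the schedule is free to defer it and place it where it overlaps communication. A minimal sketch with a plain linear layer follows; the toy shapes and direct `w.grad` assignment are illustrative only, not Megatron code, and the deferred weight GEMM is just conceptually the kind of work a `backward_dw()`-style call performs.

import torch

x = torch.randn(16, 32)          # saved input activation of y = x @ w
w = torch.randn(32, 64, requires_grad=True)
grad_out = torch.randn(16, 64)   # gradient arriving from the next submodule

# "backward" / dgrad: gradient w.r.t. the input, needed right away upstream.
grad_input = grad_out @ w.t()

# "dw": gradient w.r.t. the weight; it only feeds the optimizer, so it can be
# computed later, after communication has been launched.
w.grad = x.t() @ grad_out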
dcu_megatron/core/transformer/utils.py 0 → 100644
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class SubmoduleCallables:
    """
    Holds references to the forward, backward, dgrad, and dw (weight-grad)
    callables for a particular submodule.
    """

    forward: Optional[Callable] = None
    backward: Optional[Callable] = None
    dgrad: Optional[Callable] = None
    dw: Optional[Callable] = None
@dataclass
class TransformerLayerSubmoduleCallables:
    """
    Collects the SubmoduleCallables for each of the submodules:
    'attention', 'dispatch', 'mlp', 'combine', and 'post_combine'.
    """

    attention: SubmoduleCallables
    dispatch: SubmoduleCallables
    mlp: SubmoduleCallables
    combine: SubmoduleCallables
    post_combine: SubmoduleCallables

    def as_array(self):
        return [self.attention, self.dispatch, self.mlp, self.combine, self.post_combine]
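A minimal usage sketch of the two dataclasses defined above. The lambda callables are toy placeholders introduced only for illustration; in the commit itself the fields hold the partial-bound submodule methods built in get_submodule_callables().

# Toy placeholders, not the real Megatron callables.
identity = SubmoduleCallables(forward=lambda x: x, backward=lambda g: g)

layer_callables = TransformerLayerSubmoduleCallables(
    attention=identity,
    dispatch=identity,
    mlp=SubmoduleCallables(forward=lambda x: 2 * x, backward=lambda g: 2 * g, dw=lambda: None),
    combine=identity,
    post_combine=identity,
)

# as_array() fixes the order a schedule walks through:
# attention -> dispatch -> mlp -> combine -> post_combine (reversed for backward).
value = 3.0
for sub in layer_callables.as_array():
    value = sub.forward(value)
assert value == 6.0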