evt_fugx1 / dcu_megatron · Commits

Commit 12b56c98, authored Apr 30, 2025 by dongcl · parent 8551c38e

    support a2a overlap

Showing 6 changed files with 2139 additions and 125 deletions (+2139 / -125)
dcu_megatron/core/models/gpt/fine_grained_schedule.py    +778  -0
dcu_megatron/core/models/gpt/gpt_model.py                +191  -125
dcu_megatron/core/pipeline_parallel/combined_1f1b.py     +625  -0
dcu_megatron/core/transformer/moe/token_dispatcher.py    +301  -0
dcu_megatron/core/transformer/transformer_block.py       +20   -0
dcu_megatron/core/transformer/transformer_layer.py       +224  -0
dcu_megatron/core/models/gpt/fine_grained_schedule.py (new file, 0 → 100644)

# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

import contextlib
import weakref
from typing import Optional

import torch
from torch import Tensor

from megatron.core.pipeline_parallel.combined_1f1b import (
    AbstractSchedulePlan,
    FakeScheduleNode,
    FreeInputsMemoryStrategy,
    NoOpMemoryStrategy,
    ScheduleNode,
    get_com_stream,
    get_comp_stream,
    make_viewless,
)
from megatron.core.transformer import transformer_layer
from megatron.core.transformer.module import float16_to_fp32


def weak_method(method):
    """Creates a weak reference to a method to prevent circular references.

    This function creates a weak reference to a method and returns a wrapper function
    that calls the method when invoked. This helps prevent memory leaks from circular
    references.
    """
    method_ref = weakref.WeakMethod(method)
    del method

    def wrapped_func(*args, **kwarg):
        # nonlocal object_ref
        return method_ref()(*args, **kwarg)

    return wrapped_func
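As a quick, self-contained illustration of the wrapper's behavior (the `Owner` class below is invented for the example): once the owning object is gone, the weak reference resolves to `None`, so the wrapper fails loudly instead of keeping the object alive.

import gc


class Owner:
    def hello(self):
        return "hi"


o = Owner()
f = weak_method(o.hello)
print(f())  # "hi" -- the owner is still alive
del o
gc.collect()
# method_ref() now returns None, so calling f() raises TypeError rather than
# silently extending the owner's lifetime.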
class MemoryStrategyRegistry:
    """Registry for memory management strategies based on node names.

    This class centralizes the definition of which memory strategy
    should be used for each type of node in the computation graph.
    """

    @classmethod
    def get_strategy_by_name(cls, name, is_moe, is_deepep):
        """Gets the appropriate memory strategy for a node based on its name and MoE status.

        Args:
            name: The name of the node, which determines which strategy to use.
            is_moe: Whether the node is part of a Mixture of Experts model.
            is_deepep: Whether the node uses the DeepEP dispatcher.

        Returns:
            The memory strategy to use for the node.
        """
        strategies = {
            "default": NoOpMemoryStrategy(),
            "attn": NoOpMemoryStrategy(),  # Attention nodes keep their inputs
            "dispatch": (
                FreeInputsMemoryStrategy() if not is_deepep else NoOpMemoryStrategy()
            ),  # deepep dispatch inputs share same storage with moe inputs
            "mlp": FreeInputsMemoryStrategy(),  # MLP nodes free inputs after use
            "combine": FreeInputsMemoryStrategy(),  # Combine nodes free inputs after use
        }
        if is_moe:
            return strategies.get(name, strategies["default"])
        # For dense layers [attn, fake, mlp, fake], the inputs of mlp are required for backward
        return NoOpMemoryStrategy()
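To make the lookup concrete, here is a small runnable sketch of the same resolution logic with stand-in strategy classes (the names are illustrative; the real classes live in combined_1f1b.py):

class NoOp:  # stands in for NoOpMemoryStrategy
    pass


class FreeInputs:  # stands in for FreeInputsMemoryStrategy
    pass


def resolve(name, is_moe, is_deepep):
    strategies = {
        "default": NoOp(),
        "attn": NoOp(),
        "dispatch": FreeInputs() if not is_deepep else NoOp(),
        "mlp": FreeInputs(),
        "combine": FreeInputs(),
    }
    return strategies.get(name, strategies["default"]) if is_moe else NoOp()


print(type(resolve("mlp", is_moe=True, is_deepep=False)).__name__)      # FreeInputs
print(type(resolve("mlp", is_moe=False, is_deepep=False)).__name__)     # NoOp: dense MLP inputs are needed for backward
print(type(resolve("dispatch", is_moe=True, is_deepep=True)).__name__)  # NoOp: DeepEP shares storage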
class PreProcessNode(ScheduleNode):
    """Node responsible for preprocessing operations in the model.

    This node handles embedding and rotary positional embedding computations
    before the main transformer layers.
    """

    def __init__(self, gpt_model, model_chunk_state, event, stream):
        """Initializes a preprocessing node.

        Args:
            gpt_model: The GPT model instance.
            model_chunk_state: State shared across the model chunk.
            event: CUDA event for synchronization.
            stream: CUDA stream for execution.
        """
        super().__init__(weak_method(self.forward_impl), stream, event, name="pre_process")
        self.gpt_model = gpt_model
        self.model_chunk_state = model_chunk_state

    def forward_impl(self):
        """Implements the forward pass for preprocessing.

        This method handles:
        1. Decoder embedding computation
        2. Rotary positional embedding computation
        3. Sequence length offset computation for flash decoding

        Returns:
            The processed decoder input tensor.
        """
        gpt_model = self.gpt_model
        decoder_input = self.model_chunk_state.decoder_input
        input_ids = self.model_chunk_state.input_ids
        position_ids = self.model_chunk_state.position_ids
        inference_params = self.model_chunk_state.inference_params
        packed_seq_params = self.model_chunk_state.packed_seq_params

        # Decoder embedding.
        if decoder_input is not None:
            pass
        elif gpt_model.pre_process:
            decoder_input = gpt_model.embedding(input_ids=input_ids, position_ids=position_ids)
        else:
            # intermediate stage of pipeline
            # decoder will get hidden_states from encoder.input_tensor
            decoder_input = gpt_model.decoder.input_tensor

        # Rotary positional embeddings (embedding is None for PP intermediate devices)
        rotary_pos_emb = None
        rotary_pos_cos = None
        rotary_pos_sin = None
        if (
            gpt_model.position_embedding_type == 'rope'
            and not gpt_model.config.multi_latent_attention
        ):
            if not gpt_model.training and gpt_model.config.flash_decode and inference_params:
                # Flash decoding uses precomputed cos and sin for RoPE
                rotary_pos_cos, rotary_pos_sin = gpt_model.rotary_pos_emb_cache.setdefault(
                    inference_params.max_sequence_length,
                    gpt_model.rotary_pos_emb.get_cos_sin(inference_params.max_sequence_length),
                )
            else:
                rotary_seq_len = gpt_model.rotary_pos_emb.get_rotary_seq_len(
                    inference_params,
                    gpt_model.decoder,
                    decoder_input,
                    gpt_model.config,
                    packed_seq_params,
                )
                rotary_pos_emb = gpt_model.rotary_pos_emb(
                    rotary_seq_len,
                    packed_seq=packed_seq_params is not None
                    and packed_seq_params.qkv_format == 'thd',
                )
        if (
            (gpt_model.config.enable_cuda_graph or gpt_model.config.flash_decode)
            and rotary_pos_cos is not None
            and inference_params
        ):
            sequence_len_offset = torch.tensor(
                [inference_params.sequence_len_offset] * inference_params.current_batch_size,
                dtype=torch.int32,
                device=rotary_pos_cos.device,  # Co-locate this with the rotary tensors
            )
        else:
            sequence_len_offset = None

        # saved for later use
        self.model_chunk_state.rotary_pos_emb = rotary_pos_emb
        self.model_chunk_state.rotary_pos_cos = rotary_pos_cos
        self.model_chunk_state.rotary_pos_sin = rotary_pos_sin
        self.model_chunk_state.sequence_len_offset = sequence_len_offset

        return decoder_input
class PostProcessNode(ScheduleNode):
    """Node responsible for postprocessing operations in the model.

    This node handles final layer normalization and output layer computation
    after the main transformer layers.
    """

    def __init__(self, gpt_model, model_chunk_state, event, stream):
        """Initializes a postprocessing node.

        Args:
            gpt_model: The GPT model instance.
            model_chunk_state: State shared across the model chunk.
            event: CUDA event for synchronization.
            stream: CUDA stream for execution.
        """
        super().__init__(weak_method(self.forward_impl), stream, event, name="post_process")
        self.gpt_model = gpt_model
        self.model_chunk_state = model_chunk_state

    def forward_impl(self, hidden_states):
        """Implements the forward pass for postprocessing.

        This method handles:
        1. Final layer normalization
        2. Output layer computation
        3. Loss computation if labels are provided

        Args:
            hidden_states: The hidden states from the transformer layers.

        Returns:
            The logits or loss depending on whether labels are provided.
        """
        # Final layer norm.
        if self.gpt_model.decoder.final_layernorm is not None:
            hidden_states = self.gpt_model.decoder.final_layernorm(hidden_states)
            # TENorm produces a "viewed" tensor. This will result in schedule.py's
            # deallocate_output_tensor() throwing an error, so a viewless tensor is
            # created to prevent this.
            hidden_states = transformer_layer.make_viewless_tensor(
                inp=hidden_states, requires_grad=True, keep_graph=True
            )

        gpt_model = self.gpt_model
        runtime_gather_output = self.model_chunk_state.runtime_gather_output
        labels = self.model_chunk_state.labels
        output_weight = None
        if gpt_model.share_embeddings_and_output_weights:
            output_weight = gpt_model.shared_embedding_or_output_weight()
        logits, _ = gpt_model.output_layer(
            hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
        )

        if labels is None:
            # [s b h] => [b s h]
            return float16_to_fp32(logits.transpose(0, 1).contiguous())

        loss = float16_to_fp32(gpt_model.compute_language_model_loss(labels, logits))
        return loss
class TransformerLayerNode(ScheduleNode):
    """Base class for transformer layer computation nodes.

    This class provides common functionality for different types of
    transformer layer nodes (attention, MLP, etc.)
    """

    def __init__(self, stream, event, state, callables, name="default"):
        """Initialize a transformer layer node.

        Args:
            stream (torch.cuda.Stream): CUDA stream for execution
            event (torch.cuda.Event): Synchronization event
            state (TransformerLayerState): State shared within a transformer layer
            callables (Callable): The callables containing the forward and dw functions
            name (str): Node name, also used to determine memory strategy
        """
        # Get memory strategy based on node name
        memory_strategy = MemoryStrategyRegistry.get_strategy_by_name(
            name, callables.is_moe, callables.is_deepep
        )
        super().__init__(
            weak_method(self.forward_impl),
            stream,
            event,
            weak_method(self.backward_impl),
            memory_strategy=memory_strategy,
            name=name,
        )
        self.common_state = state
        self.callables = callables
        self.detached = tuple()
        self.before_detached = tuple()

    def detach(self, t):
        """Detaches a tensor and stores it for backward computation."""
        detached = make_viewless(t).detach()
        detached.requires_grad = t.requires_grad
        self.before_detached = self.before_detached + (t,)
        self.detached = self.detached + (detached,)
        return detached

    def forward_impl(self, *args):
        """Implements the forward pass for the transformer layer node."""
        return self.callables.forward(self, *args)

    def backward_impl(self, outputs, output_grad):
        """Implements the backward pass for the transformer layer node."""
        detached_grad = tuple([e.grad for e in self.detached])
        grads = output_grad + detached_grad
        self.default_backward_func(outputs + self.before_detached, grads)
        self.before_detached = None
        self.detached = None
        # return grads for record stream
        return grads

    def dw(self):
        """Computes the weight gradients for the transformer layer node."""
        with torch.cuda.nvtx.range(f"{self.name} wgrad"):
            self.callables.dw()
class TransformerLayerState:
    """State shared within a transformer layer.

    This class holds state that is shared between different nodes
    within a transformer layer.
    """

    pass


class ModelChunkState:
    """State shared across a model chunk.

    This class holds state that is shared between different components
    of a model chunk, such as input tensors, parameters, and configuration.
    """

    pass
class TransformerLayerSchedulePlan:
    """Schedule plan for a transformer layer.

    This class organizes the computation nodes for a transformer layer,
    including attention, MLP, dispatch, and combine nodes.
    """

    def __init__(self, layer, event, chunk_state, comp_stream, com_stream):
        """Initializes a transformer layer schedule plan.

        Args:
            layer (TransformerLayer): The transformer layer to schedule.
            event (torch.cuda.Event): CUDA event for synchronization.
            chunk_state (ModelChunkState): State shared across the model chunk.
            comp_stream (torch.cuda.Stream): CUDA stream for computation.
            com_stream (torch.cuda.Stream): CUDA stream for communication.
        """
        self.common_state = TransformerLayerState()
        # get callables for transformer layer
        attn_callable, dispatch_callable, mlp_callable, combine_callable = (
            layer.get_submodule_callables(chunk_state).as_array()
        )

        # Create nodes for different operations in the layer
        # Each node type has a predefined name that determines its memory strategy
        self.attn = TransformerLayerNode(
            comp_stream, event, self.common_state, attn_callable, name="attn"
        )
        self.mlp = TransformerLayerNode(
            comp_stream, event, self.common_state, mlp_callable, name="mlp"
        )
        if attn_callable.is_moe:
            self.dispatch = TransformerLayerNode(
                com_stream, event, self.common_state, dispatch_callable, name="dispatch"
            )
            self.combine = TransformerLayerNode(
                com_stream, event, self.common_state, combine_callable, name="combine"
            )
        else:
            self.dispatch = FakeScheduleNode()
            self.combine = FakeScheduleNode()
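A small runnable summary of where the constructor above places each node and which memory strategy it ends up with in the MoE case (dense layers instead get FakeScheduleNode for dispatch/combine and NoOp strategies throughout):

placement = {
    "attn":     ("comp_stream", "NoOp (inputs kept for backward)"),
    "dispatch": ("com_stream",  "FreeInputs, or NoOp with DeepEP (shared storage)"),
    "mlp":      ("comp_stream", "FreeInputs"),
    "combine":  ("com_stream",  "FreeInputs"),
}
for node, (stream, strategy) in placement.items():
    print(f"{node:8s} -> {stream:11s} | {strategy}")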
class ModelChunkSchedulePlan(AbstractSchedulePlan):
    """Schedule plan for a model chunk.

    This class organizes the computation nodes for a model chunk,
    including preprocessing, transformer layers, and postprocessing.
    """

    def __init__(self):
        """Initializes a model chunk schedule plan."""
        super().__init__()
        self._pre_process = None
        self._post_process = None
        self._model_chunk_state = ModelChunkState()
        self._transformer_layers = []
        self._event = torch.cuda.Event()

    @classmethod
    def forward_backward(
        cls,
        f_schedule_plan,
        b_schedule_plan,
        grad=None,
        f_context=None,
        b_context=None,
        pre_forward=None,
        pre_backward=None,
        post_forward=None,
        post_backward=None,
    ):
        """Schedules forward and backward passes for model chunks.

        Args:
            f_schedule_plan (ModelChunkSchedulePlan): Forward schedule plan.
            b_schedule_plan (ModelChunkSchedulePlan): Backward schedule plan.
            grad (Tensor): Gradient for backward computation.
            f_context (VppContextManager or None): The VppContextManager for the forward pass.
            b_context (VppContextManager or None): The VppContextManager for the backward pass.
            pre_forward (Callable): Callback for preprocessing in forward pass.
            pre_backward (Callable): Callback for preprocessing in backward pass.
            post_forward (Callable): Callback for postprocessing in forward pass.
            post_backward (Callable): Callback for postprocessing in backward pass.

        Returns:
            The output of the forward pass.
        """
        return schedule_chunk_1f1b(
            f_schedule_plan,
            b_schedule_plan,
            grad=grad,
            f_context=f_context,
            b_context=b_context,
            pre_forward=pre_forward,
            pre_backward=pre_backward,
            post_forward=post_forward,
            post_backward=post_backward,
        )

    @property
    def event(self):
        """Gets the CUDA event for synchronization."""
        return self._event

    def record_current_stream(self):
        """Records the current CUDA stream in the event."""
        stream = torch.cuda.current_stream()
        self.event.record(stream)

    def wait_current_stream(self):
        """Waits for the event to complete on the current CUDA stream."""
        stream = torch.cuda.current_stream()
        self.event.wait(stream)

    @property
    def pre_process(self):
        """Gets the preprocessing node."""
        return self._pre_process

    @pre_process.setter
    def pre_process(self, value):
        """Sets the preprocessing node."""
        self._pre_process = value

    @property
    def post_process(self):
        """Gets the postprocessing node."""
        return self._post_process

    @post_process.setter
    def post_process(self, value):
        """Sets the postprocessing node."""
        self._post_process = value

    def get_layer(self, i):
        """Gets the transformer layer at the specified index."""
        assert i < self.num_layers()
        return self._transformer_layers[i]

    def num_layers(self):
        """Gets the number of transformer layers."""
        return len(self._transformer_layers)

    def add_layer(self, layer):
        """Adds a transformer layer to the schedule plan."""
        self._transformer_layers.append(layer)

    @property
    def state(self):
        """Gets the model chunk state."""
        return self._model_chunk_state
def schedule_layer_1f1b(
    f_layer,
    b_layer,
    f_input=None,
    b_grad=None,
    pre_forward=None,
    pre_backward=None,
    pre_backward_dw=None,
    f_context=None,
    b_context=None,
):
    """Schedule one-forward-one-backward operations for a single layer.

    This function interleaves forward and backward operations to maximize
    parallelism and efficiency.

    Args:
        f_layer (TransformerLayerSchedulePlan): Forward layer (for current microbatch)
        b_layer (TransformerLayerSchedulePlan): Backward layer (for previous microbatch)
        f_input (Tensor): Input for forward computation
        b_grad (Tensor): Gradient for backward computation
        pre_forward (Callable): Callback to get forward input if not provided
        pre_backward (Callable): Callback to get backward gradient if not provided
        pre_backward_dw (Callable): Callback for weight gradient computation
        f_context (VppContextManager or None): The VppContextManager for the forward pass.
        b_context (VppContextManager or None): The VppContextManager for the backward pass.

    Returns:
        Functions or values for the next iteration's computation.
    """
    f_context = f_context if f_context is not None else contextlib.nullcontext()
    b_context = b_context if b_context is not None else contextlib.nullcontext()

    if pre_forward is not None:
        assert f_input is None
        # combine from last iter
        f_input = pre_forward()
        del pre_forward

    if pre_backward is not None:
        # attn backward from last iter
        assert b_grad is None
        b_grad = pre_backward()
        del pre_backward

    if b_layer is not None:
        with b_context:
            b_grad = b_layer.combine.backward(b_grad)

    if pre_backward_dw is not None:
        pre_backward_dw()
        del pre_backward_dw

    if f_layer is not None:
        with f_context:
            f_input = f_layer.attn.forward(f_input)

    if f_layer is not None:
        with f_context:
            f_input = f_layer.dispatch.forward(f_input)

    if b_layer is not None:
        with b_context:
            b_grad = b_layer.mlp.backward(b_grad)
            b_grad = b_layer.dispatch.backward(b_grad)
            b_layer.mlp.dw()

    if f_layer is not None:
        with f_context:
            f_input = f_layer.mlp.forward(f_input)

    def next_iter_pre_forward():
        if f_layer is not None:
            with f_context:
                output = f_layer.combine.forward(f_input)
                return output

    def next_iter_pre_backward():
        if b_layer is not None:
            with b_context:
                grad = b_layer.attn.backward(b_grad)
                return grad

    def next_iter_pre_backward_dw():
        if b_layer is not None:
            with b_context:
                b_layer.attn.dw()

    if f_layer and b_layer:
        return next_iter_pre_forward, next_iter_pre_backward, next_iter_pre_backward_dw
    else:
        return next_iter_pre_forward(), next_iter_pre_backward(), next_iter_pre_backward_dw()
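To make the interleaving visible, here is a CUDA-free sketch that mirrors one overlapped step of the function above, using stand-in nodes that only record their execution order (all classes and names here are invented for the illustration):

class _Node:
    def __init__(self, tag, log):
        self.tag, self.log = tag, log

    def forward(self, x):
        self.log.append(f"{self.tag}.fwd")
        return x

    def backward(self, g):
        self.log.append(f"{self.tag}.bwd")
        return g

    def dw(self):
        self.log.append(f"{self.tag}.dw")


class _Layer:
    def __init__(self, name, log):
        for part in ("attn", "dispatch", "mlp", "combine"):
            setattr(self, part, _Node(f"{name}.{part}", log))


log = []
f_layer, b_layer = _Layer("f", log), _Layer("b", log)

# One overlapped step, in the order schedule_layer_1f1b issues the work:
g = b_layer.combine.backward("grad")  # backward combine (com stream)
x = f_layer.attn.forward("acts")      # forward attention (comp stream)
x = f_layer.dispatch.forward(x)       # forward dispatch (com stream)
g = b_layer.mlp.backward(g)           # backward MLP (comp stream)
g = b_layer.dispatch.backward(g)      # backward dispatch (com stream)
b_layer.mlp.dw()                      # MLP weight gradients
x = f_layer.mlp.forward(x)            # forward MLP (comp stream)
# combine.forward, attn.backward and attn.dw are deferred to the next iteration.
print(log)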
def schedule_chunk_1f1b(
    f_schedule_plan,
    b_schedule_plan,
    grad=None,
    f_context=None,
    b_context=None,
    pre_forward=None,
    pre_backward=None,
    post_forward=None,
    post_backward=None,
):
    """Schedules one-forward-one-backward operations for a model chunk.

    This function interleaves forward and backward operations across multiple layers
    to maximize parallelism and efficiency.

    Args:
        f_schedule_plan: Forward schedule plan.
        b_schedule_plan: Backward schedule plan.
        grad: Gradient for backward computation.
        f_context: Context for forward computation.
        b_context: Context for backward computation.
        pre_forward: Callback for preprocessing in forward pass.
        pre_backward: Callback for preprocessing in backward pass.
        post_forward: Callback for postprocessing in forward pass.
        post_backward: Callback for postprocessing in backward pass.

    Returns:
        The output of the forward pass.
    """
    f_context = f_context if f_context is not None else contextlib.nullcontext()
    b_context = b_context if b_context is not None else contextlib.nullcontext()

    if f_schedule_plan:
        # pp output send/receive sync
        if pre_forward is not None:
            with f_context:  # virtual pipeline parallel context
                pre_forward()
        f_schedule_plan.record_current_stream()
    if b_schedule_plan:
        b_schedule_plan.record_current_stream()

    f_input = None

    def layer_pre_forward():
        tmp = f_input
        if f_schedule_plan is not None:
            tmp = f_schedule_plan.pre_process.forward()
        return tmp

    def layer_pre_backward():
        tmp = grad
        if b_schedule_plan is not None:
            assert grad is not None
            if b_schedule_plan.post_process is not None:
                with b_context:  # virtual pipeline parallel context
                    tmp = b_schedule_plan.post_process.backward(grad)
            if pre_backward is not None:
                # pp grad send receive sync here, safe for now, maybe not safe in the future
                with torch.cuda.stream(get_com_stream()):
                    b_schedule_plan.wait_current_stream()
                    with b_context:  # virtual pipeline parallel context
                        pre_backward()
                    b_schedule_plan.record_current_stream()
        return tmp

    def layer_pre_backward_dw():
        pass

    f_num_layers = f_schedule_plan.num_layers() if f_schedule_plan is not None else 0
    b_num_layers = b_schedule_plan.num_layers() if b_schedule_plan is not None else 0
    overlaped_layers = min(f_num_layers, b_num_layers)

    for i in range(overlaped_layers):
        f_layer = f_schedule_plan.get_layer(i)
        b_layer = b_schedule_plan.get_layer(b_num_layers - 1 - i)
        torch.cuda.nvtx.range_push(f"layer_{i}f-layer_{b_num_layers - 1 - i}b")
        layer_pre_forward, layer_pre_backward, layer_pre_backward_dw = schedule_layer_1f1b(
            f_layer,
            b_layer,
            pre_forward=layer_pre_forward,
            pre_backward=layer_pre_backward,
            pre_backward_dw=layer_pre_backward_dw,
            f_context=f_context,
            b_context=b_context,
        )
        torch.cuda.nvtx.range_pop()

    # tail forward
    f_input = layer_pre_forward()
    del layer_pre_forward
    # tail backward
    grad = layer_pre_backward()
    del layer_pre_backward

    with b_context:
        for i in range(overlaped_layers, b_num_layers):
            b_layer = b_schedule_plan.get_layer(b_num_layers - 1 - i)
            torch.cuda.nvtx.range_push(f"layer_{b_num_layers - 1 - i}b")
            tmp, grad, _ = schedule_layer_1f1b(None, b_layer, b_grad=grad)
            torch.cuda.nvtx.range_pop()
    # if b_schedule_plan is not None:
    #     b_schedule_plan.pre_process.backward(grad)

    # # tail forward
    # f_input = layer_pre_forward()
    # del layer_pre_forward

    with f_context:
        for i in range(overlaped_layers, f_num_layers):
            f_layer = f_schedule_plan.get_layer(i)
            torch.cuda.nvtx.range_push(f"layer_{i}f")
            f_input, tmp, _ = schedule_layer_1f1b(f_layer, None, f_input=f_input)
            torch.cuda.nvtx.range_pop()
    # if f_schedule_plan is not None and f_schedule_plan.post_process is not None:
    #     f_input = f_schedule_plan.post_process.forward(f_input)

    # output pp send receive, overlapped with attn backward
    if f_schedule_plan is not None and post_forward is not None:
        with f_context:
            f_schedule_plan.wait_current_stream()
            post_forward(f_input)

    # pp grad send / receive, overlapped with attn dw of cur micro-batch
    # and forward attn of next micro-batch
    if b_schedule_plan is not None and post_backward is not None:
        with b_context:
            b_schedule_plan.wait_current_stream()
            post_backward(grad)

    # The last wgrad of attention
    layer_pre_backward_dw()
    del layer_pre_backward_dw

    with f_context:
        if f_schedule_plan is not None and f_schedule_plan.post_process is not None:
            f_input = f_schedule_plan.post_process.forward(f_input)
    with b_context:
        if b_schedule_plan is not None:
            b_schedule_plan.pre_process.backward(grad)

    if f_schedule_plan:
        f_schedule_plan.wait_current_stream()
    if b_schedule_plan:
        b_schedule_plan.wait_current_stream()

    return f_input
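The loops above pair forward layer i with backward layer b_num_layers - 1 - i during the overlapped phase, then drain whichever side has layers left. A tiny runnable illustration with hypothetical layer counts:

f_num_layers, b_num_layers = 4, 6  # hypothetical chunk sizes
overlapped = min(f_num_layers, b_num_layers)
for i in range(overlapped):
    print(f"overlap:  f-layer {i}  with  b-layer {b_num_layers - 1 - i}")
for i in range(overlapped, b_num_layers):
    print(f"tail backward: b-layer {b_num_layers - 1 - i}")
for i in range(overlapped, f_num_layers):
    print(f"tail forward:  f-layer {i}")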
def build_model_chunk_schedule_plan(
    model,
    input_ids: Tensor,
    position_ids: Tensor,
    attention_mask: Tensor,
    decoder_input: Tensor = None,
    labels: Tensor = None,
    inference_params=None,
    packed_seq_params=None,
    extra_block_kwargs=None,
    runtime_gather_output: Optional[bool] = None,
):
    """Builds a schedule plan for a model chunk.

    This function creates a schedule plan for a model chunk, including
    preprocessing, transformer layers, and postprocessing.

    Args:
        model: The model to build a schedule plan for.
        input_ids: Input token IDs.
        position_ids: Position IDs.
        attention_mask: Attention mask.
        decoder_input: Decoder input tensor.
        labels: Labels for loss computation.
        inference_params: Parameters for inference.
        packed_seq_params: Parameters for packed sequences.
        extra_block_kwargs: Additional keyword arguments for blocks.
        runtime_gather_output: Whether to gather output at runtime.

    Returns:
        The model chunk schedule plan.
    """
    comp_stream = get_comp_stream()
    com_stream = get_com_stream()
    model_chunk_schedule_plan = ModelChunkSchedulePlan()
    event = model_chunk_schedule_plan.event
    state = model_chunk_schedule_plan.state

    # save for later use
    state.input_ids = input_ids
    state.position_ids = position_ids
    state.attention_mask = attention_mask
    state.decoder_input = decoder_input
    state.labels = labels
    state.inference_params = inference_params
    state.packed_seq_params = packed_seq_params
    state.extra_block_kwargs = extra_block_kwargs
    state.runtime_gather_output = runtime_gather_output
    state.context = None
    state.context_mask = None
    state.attention_bias = None

    # build preprocess
    model_chunk_schedule_plan.pre_process = PreProcessNode(model, state, event, comp_stream)
    # build for layers
    for layer_idx in range(model.decoder.num_layers_per_pipeline_rank):
        layer = model.decoder._get_layer(layer_idx)
        layer_plan = TransformerLayerSchedulePlan(layer, event, state, comp_stream, com_stream)
        model_chunk_schedule_plan.add_layer(layer_plan)
    # build post process
    if model.post_process:
        model_chunk_schedule_plan.post_process = PostProcessNode(model, state, event, comp_stream)

    return model_chunk_schedule_plan
dcu_megatron/core/models/gpt/gpt_model.py

...
@@ -10,6 +10,7 @@ from torch import Tensor
from megatron.core import InferenceParams, tensor_parallel
from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.models.gpt import GPTModel as MegatronCoreGPTModel
from dcu_megatron.core.tensor_parallel import FluxColumnParallelLinear
...
@@ -45,100 +46,143 @@ def gpt_model_init_wrapper(fn):
    return wrapper


class GPTModel(MegatronCoreGPTModel):
    """patch megatron GPTModel"""

    def get_transformer_callables_by_layer(self, layer_number: int):
        """Get the callables for the layer at the given transformer layer number."""
        return self.decoder.get_layer_callables(layer_number)

    def build_schedule_plan(
        self,
        input_ids: Tensor,
        position_ids: Tensor,
        attention_mask: Tensor,
        decoder_input: Tensor = None,
        labels: Tensor = None,
        inference_params: InferenceParams = None,
        packed_seq_params: PackedSeqParams = None,
        extra_block_kwargs: dict = None,
        runtime_gather_output: Optional[bool] = None,
        loss_mask: Optional[Tensor] = None,
    ):
        """Builds a computation schedule plan for the model.

        This function creates a schedule plan for a model chunk, including
        preprocessing, transformer layers, and postprocessing.
        The schedule plan is used to optimize computation and memory usage
        in distributed environments.

        Args:
            input_ids (Tensor): Input token IDs.
            position_ids (Tensor): Position IDs.
            attention_mask (Tensor): Attention mask.
            decoder_input (Tensor, optional): Decoder input tensor. Defaults to None.
            labels (Tensor, optional): Labels for loss computation. Defaults to None.
            inference_params (InferenceParams, optional):
                Parameters for inference. Defaults to None.
            packed_seq_params (PackedSeqParams, optional):
                Parameters for packed sequences. Defaults to None.
            extra_block_kwargs (dict, optional):
                Additional keyword arguments for blocks. Defaults to None.
            runtime_gather_output (Optional[bool], optional):
                Whether to gather output at runtime. Defaults to None.
            loss_mask (Optional[Tensor], optional): Loss mask. Defaults to None.

        Returns:
            ModelChunkSchedulePlan: The model chunk schedule plan.
        """
        from .fine_grained_schedule import build_model_chunk_schedule_plan

        return build_model_chunk_schedule_plan(
            self,
            input_ids,
            position_ids,
            attention_mask,
            decoder_input=decoder_input,
            labels=labels,
            inference_params=inference_params,
            packed_seq_params=packed_seq_params,
            extra_block_kwargs=extra_block_kwargs,
            runtime_gather_output=runtime_gather_output,
        )

    def forward(
        self,
        input_ids: Tensor,
        position_ids: Tensor,
        attention_mask: Tensor,
        decoder_input: Tensor = None,
        labels: Tensor = None,
        inference_params: InferenceParams = None,
        packed_seq_params: PackedSeqParams = None,
        extra_block_kwargs: dict = None,
        runtime_gather_output: Optional[bool] = None,
        loss_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """Forward function of the GPT Model. This function passes the input tensors
        through the embedding layer, then the decoder, and finally into the post
        processing layer (optional).

        It either returns the loss values if labels are given, or the final hidden units.

        Args:
            runtime_gather_output (bool): Gather output at runtime. Default None means
                `parallel_output` arg in the constructor will be used.
        """
        # If decoder_input is provided (not None), then input_ids and position_ids are ignored.
        # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.

        # Decoder embedding.
        if decoder_input is not None:
            pass
        elif self.pre_process:
            decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)
        else:
            # intermediate stage of pipeline
            # decoder will get hidden_states from encoder.input_tensor
            decoder_input = None

        # Rotary positional embeddings (embedding is None for PP intermediate devices)
        rotary_pos_emb = None
        rotary_pos_cos = None
        rotary_pos_sin = None
        if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
            if not self.training and self.config.flash_decode and inference_params:
                # Flash decoding uses precomputed cos and sin for RoPE
                rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb_cache.setdefault(
                    inference_params.max_sequence_length,
                    self.rotary_pos_emb.get_cos_sin(inference_params.max_sequence_length),
                )
            else:
                rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
                    inference_params, self.decoder, decoder_input, self.config, packed_seq_params
                )
                rotary_pos_emb = self.rotary_pos_emb(
                    rotary_seq_len,
                    packed_seq=packed_seq_params is not None
                    and packed_seq_params.qkv_format == 'thd',
                )
        if (
            (self.config.enable_cuda_graph or self.config.flash_decode)
            and rotary_pos_cos is not None
            and inference_params
        ):
            sequence_len_offset = torch.tensor(
                [inference_params.sequence_len_offset] * inference_params.current_batch_size,
                dtype=torch.int32,
                device=rotary_pos_cos.device,  # Co-locate this with the rotary tensors
            )
        else:
            sequence_len_offset = None

        # Run decoder.
        hidden_states = self.decoder(
            hidden_states=decoder_input,
            attention_mask=attention_mask,
            inference_params=inference_params,
            rotary_pos_emb=rotary_pos_emb,
            rotary_pos_cos=rotary_pos_cos,
            rotary_pos_sin=rotary_pos_sin,
            packed_seq_params=packed_seq_params,
            sequence_len_offset=sequence_len_offset,
            **(extra_block_kwargs or {}),
        )
...
@@ -146,44 +190,66 @@ def gpt_model_forward(
        # logits and loss
        output_weight = None
        if self.share_embeddings_and_output_weights:
            output_weight = self.shared_embedding_or_output_weight()

        if self.mtp_process:
            hidden_states = self.mtp(
                input_ids=input_ids,
                position_ids=position_ids,
                labels=labels,
                loss_mask=loss_mask,
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                inference_params=inference_params,
                rotary_pos_emb=rotary_pos_emb,
                rotary_pos_cos=rotary_pos_cos,
                rotary_pos_sin=rotary_pos_sin,
                packed_seq_params=packed_seq_params,
                sequence_len_offset=sequence_len_offset,
                embedding=self.embedding,
                output_layer=self.output_layer,
                output_weight=output_weight,
                runtime_gather_output=runtime_gather_output,
                compute_language_model_loss=self.compute_language_model_loss,
                **(extra_block_kwargs or {}),
            )

        if (
            self.mtp_process is not None
            and getattr(self.decoder, "main_final_layernorm", None) is not None
        ):
            # move block main model final norms here
            hidden_states = self.decoder.main_final_layernorm(hidden_states)

        if not self.post_process:
            return hidden_states

        logits, _ = self.output_layer(
            hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
        )

        if has_config_logger_enabled(self.config):
            payload = OrderedDict(
                {
                    'input_ids': input_ids,
                    'position_ids': position_ids,
                    'attention_mask': attention_mask,
                    'decoder_input': decoder_input,
                    'logits': logits,
                }
            )
            log_config_to_disk(self.config, payload, prefix='input_and_logits')

        if labels is None:
            # [s b h] => [b s h]
            return logits.transpose(0, 1).contiguous()

        loss = self.compute_language_model_loss(labels, logits)

        return loss
dcu_megatron/core/pipeline_parallel/combined_1f1b.py (new file, 0 → 100644)

# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

import contextlib
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import List, Union

import torch
from torch import Tensor
from torch.autograd.variable import Variable

from megatron.core import parallel_state
from megatron.core.distributed import DistributedDataParallel
from megatron.core.models.gpt.gpt_model import GPTModel
from megatron.core.transformer.module import Float16Module
from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler
from megatron.core.utils import get_attr_wrapped_model, make_viewless_tensor

# Types
Shape = Union[List[int], torch.Size]


def make_viewless(e):
    """Make_viewless util func"""
    e = make_viewless_tensor(inp=e, requires_grad=e.requires_grad, keep_graph=True)
    return e


@contextmanager
def stream_acquire_context(stream, event):
    """Stream acquire context"""
    event.wait(stream)
    try:
        yield
    finally:
        event.record(stream)
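The contract here is pure ordering: wait on the shared event before the body, and re-record it afterwards so the next node, possibly on another stream, is serialized behind this one. A CPU-only sketch with a stand-in event (invented for the illustration) shows the sequence:

class _FakeEvent:
    def wait(self, stream):
        print(f"wait on {stream}")

    def record(self, stream):
        print(f"record on {stream}")


with stream_acquire_context("comp_stream", _FakeEvent()):
    print("run node body")
# Prints: wait on comp_stream / run node body / record on comp_stream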
class FakeScheduleNode:
    """A placeholder node in the computation graph that simply passes through inputs and outputs.

    This class is used as a no-op node in the scheduling system when a real computation node
    is not needed but the interface must be maintained. It simply returns its inputs unchanged
    in both forward and backward passes.
    """

    def forward(self, inputs):
        """Passes through inputs unchanged in the forward pass."""
        return inputs

    def backward(self, outgrads):
        """Passes through gradients unchanged in the backward pass."""
        return outgrads
class ScheduleNode:
    """Base node for fine-grained scheduling.

    This class represents a computational node in the pipeline schedule.
    It handles the execution of forward and backward operations on a stream.
    """

    def __init__(
        self,
        forward_func,
        stream,
        event,
        backward_func=None,
        memory_strategy=None,
        name="schedule_node",
    ):
        """Initialize a schedule node.

        Args:
            forward_func (callable): Function to execute during forward pass
            stream (torch.cuda.Stream): CUDA stream for computation
            event (torch.cuda.Event): Event for synchronization
            backward_func (callable, optional): Function for backward pass
            memory_strategy (MemoryManagementStrategy, optional): Strategy for memory management
            name (str): Name of the node for debugging
        """
        self.name = name
        self.forward_func = forward_func
        self.backward_func = backward_func if backward_func else self.default_backward_func
        self.stream = stream
        self.event = event
        self.memory_strategy = memory_strategy or NoOpMemoryStrategy()
        self.inputs = None
        self.output = None

    def default_backward_func(self, outputs, output_grad):
        """Default backward function"""
        Variable._execution_engine.run_backward(
            tensors=outputs,
            grad_tensors=output_grad,
            keep_graph=False,
            create_graph=False,
            inputs=tuple(),
            allow_unreachable=True,
            accumulate_grad=True,
        )
        return output_grad

    def forward(self, inputs=()):
        """Schedule node forward"""
        if not isinstance(inputs, tuple):
            inputs = (inputs,)
        return self._forward(*inputs)

    def _forward(self, *inputs):
        with stream_acquire_context(self.stream, self.event):
            torch.cuda.nvtx.range_push(f"{self.name} forward")
            with torch.cuda.stream(self.stream):
                self.inputs = [
                    make_viewless(e).detach() if e is not None else None for e in inputs
                ]
                for i, input in enumerate(self.inputs):
                    if input is not None:
                        input.requires_grad = inputs[i].requires_grad
                data = tuple(self.inputs)
                data = self.forward_func(*data)
                if not isinstance(data, tuple):
                    data = make_viewless(data)
                else:
                    data = tuple(
                        [make_viewless(e) if isinstance(e, Tensor) else e for e in data]
                    )
                self.output = data
            torch.cuda.nvtx.range_pop()
            # Handle inputs using the memory strategy
            self.memory_strategy.handle_inputs(inputs, self.stream)
        return self.output

    def get_output(self):
        """Get the forward output"""
        return self.output

    def backward(self, output_grad):
        """Schedule node backward"""
        if not isinstance(output_grad, tuple):
            output_grad = (output_grad,)
        return self._backward(*output_grad)

    def _backward(self, *output_grad):
        with stream_acquire_context(self.stream, self.event):
            torch.cuda.nvtx.range_push(f"{self.name} backward")
            with torch.cuda.stream(self.stream):
                outputs = self.output
                if not isinstance(outputs, tuple):
                    outputs = (outputs,)
                assert len(outputs) == len(output_grad), (
                    f"{len(outputs)} of {type(outputs[0])} is not equal to "
                    f"{len(output_grad)} of {type(output_grad[0])}"
                )
                output_grad = self.backward_func(outputs, output_grad)
            torch.cuda.nvtx.range_pop()
            # output_grad maybe from another stream
            for g in output_grad:
                g.record_stream(self.stream)
        return self.get_grad()

    def get_grad(self):
        """Get the grad of inputs"""
        grad = tuple([e.grad if e is not None else None for e in self.inputs])
        # clear state
        self.inputs = None
        self.output = None
        # multiple in, multiple out
        if len(grad) == 1:
            grad = grad[0]
        return grad
class AbstractSchedulePlan(ABC):
    """To use combined 1f1b, the model must implement build_schedule_plan, which takes the same
    signature as the model forward but returns an instance of AbstractSchedulePlan."""

    @classmethod
    @abstractmethod
    def forward_backward(
        cls,
        f_schedule_plan,
        b_schedule_plan,
        grad=None,
        f_context=None,
        b_context=None,
        pre_forward=None,
        pre_backward=None,
        post_forward=None,
        post_backward=None,
    ):
        """forward_backward is the protocol between our schedule logic and the model"""
        ...


def schedule_chunk_1f1b(
    f_schedule_plan,
    b_schedule_plan,
    grad=None,
    f_context=None,
    b_context=None,
    pre_forward=None,
    pre_backward=None,
    post_forward=None,
    post_backward=None,
):
    """Model level 1f1b fine-grained schedule.

    This function schedules the forward and backward passes for a chunk of the model.
    It takes the forward schedule plan, backward schedule plan, gradient, and optional
    context managers for the forward and backward passes.

    Args:
        f_schedule_plan (subclass of AbstractSchedulePlan): The forward schedule plan
        b_schedule_plan (subclass of AbstractSchedulePlan): The backward schedule plan
        grad (Tensor or None): The gradient of the loss function
        f_context (VppContextManager or None): The VppContextManager for the forward pass
        b_context (VppContextManager or None): The VppContextManager for the backward pass
        pre_forward (callable or None): The function to call before the forward pass
        pre_backward (callable or None): The function to call before the backward pass
        post_forward (callable or None): The function to call after the forward pass
        post_backward (callable or None): The function to call after the backward pass

    Returns:
        The output of the forward pass
    """
    # Calls fine_grained_schedule.py::ModelChunkSchedulePlan.forward_backward(),
    # which calls fine_grained_schedule.py::schedule_chunk_1f1b()
    return type(f_schedule_plan or b_schedule_plan).forward_backward(
        f_schedule_plan,
        b_schedule_plan,
        grad=grad,
        f_context=f_context,
        b_context=b_context,
        pre_forward=pre_forward,
        pre_backward=pre_backward,
        post_forward=post_forward,
        post_backward=post_backward,
    )
_COMP_STREAM = None
_COM_STREAM = None


def set_streams(comp_stream=None, com_stream=None):
    """Set the streams for communication and computation"""
    global _COMP_STREAM
    global _COM_STREAM
    if _COMP_STREAM is not None:
        return
    if comp_stream is None:
        comp_stream = torch.cuda.current_stream()
    if com_stream is None:
        com_stream = torch.cuda.Stream(device="cuda")
    assert _COMP_STREAM is None
    assert _COM_STREAM is None
    _COMP_STREAM = comp_stream
    _COM_STREAM = com_stream


def get_comp_stream():
    """Get the stream for computation"""
    global _COMP_STREAM
    return _COMP_STREAM


def get_com_stream():
    """Get the stream for communication"""
    global _COM_STREAM
    return _COM_STREAM
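A minimal sketch of the intended setup, assuming a CUDA build (the guard keeps it runnable elsewhere); note that set_streams keeps the first configuration it sees and silently ignores later calls:

import torch

if torch.cuda.is_available():
    comp = torch.cuda.current_stream()      # computation stays on the current stream
    com = torch.cuda.Stream(device="cuda")  # A2A communication gets a dedicated stream
    set_streams(comp_stream=comp, com_stream=com)
    assert get_comp_stream() is comp and get_com_stream() is com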
class VppContextManager:
    """A reusable context manager for switching the vpp stage"""

    def __init__(self, vpp_rank):
        self.vpp_rank = vpp_rank

    def __enter__(self):
        self.origin_vpp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank()
        parallel_state.set_virtual_pipeline_model_parallel_rank(self.vpp_rank)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        parallel_state.set_virtual_pipeline_model_parallel_rank(self.origin_vpp_rank)
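A stand-in sketch of the same save/switch/restore pattern, using a plain dict in place of megatron's parallel_state purely for illustration:

state = {"vpp_rank": 0}


class _VppCtx:
    def __init__(self, vpp_rank):
        self.vpp_rank = vpp_rank

    def __enter__(self):
        self.origin = state["vpp_rank"]
        state["vpp_rank"] = self.vpp_rank
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        state["vpp_rank"] = self.origin


with _VppCtx(1):
    assert state["vpp_rank"] == 1  # work for vpp stage 1 runs here
assert state["vpp_rank"] == 0      # restored on exit, even after exceptions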
def forward_backward_step(
    forward_step_func,
    data_iterator,
    f_model,
    num_microbatches,
    input_tensor,
    forward_data_store,
    b_model,
    b_input_tensor,
    b_output_tensor,
    b_output_tensor_grad,
    config,
    f_context=None,
    b_context=None,
    pre_forward=None,
    pre_backward=None,
    post_forward=None,
    post_backward=None,
    collect_non_loss_data=False,
    checkpoint_activations_microbatch=None,
    is_first_microbatch=False,
    current_microbatch=None,
    encoder_decoder_xattn=False,
):
    """Merged forward and backward step for combined_1f1b.

    Args:
        Needs to accept the arguments of both forward_step() and backward_step().
        forward_step_func (callable): wrapped by wrap_forward_func(), so it now returns
            a forward schedule plan, which is an input of the schedule_chunk_1f1b function.
        f_context (VppContextManager or nullcontext): The context manager for setting vpp ranks.
        b_context (VppContextManager or nullcontext): The context manager for setting vpp ranks.
            Only exists in 1f1b steady state with p2p overlap.
        pre_forward (callable): The function to call before the forward_step.
        pre_backward (callable): The function to call before the backward_step.
        post_forward (callable): The function to call after the forward_step.
        post_backward (callable): The function to call after the backward_step.

    Returns:
        forward_output_tensor (Tensor or list[Tensor]): The output object(s) from the forward step.
        forward_num_tokens (Tensor): The number of tokens.
        backward_input_tensor_grad (Tensor): The grad of the input tensor.

    Description:
        This method merges the forward_step() and backward_step() methods in the schedules.py file.
        Assuming that:
            def forward_step():
                # forward_preprocess()
                # forward_compute()
                # forward_postprocess()
            def backward_step():
                # backward_preprocess()
                # backward_compute()
                # backward_postprocess()
        Then the forward_backward_step() method will be:
            def forward_backward_step():
                # forward_preprocess()   // the same as the forward_step()
                # GENERATE f_schedule_plan  // scheduling happens in schedule_chunk_1f1b()
                # backward_preprocess()  // the same as the backward_step()
                # COMBINED_FORWARD_BACKWARD_COMPUTE()  // by calling schedule_chunk_1f1b()
                # forward_postprocess()  // the same as the forward_step()
                # backward_postprocess() // the same as the backward_step()
    """
    assert (
        checkpoint_activations_microbatch is None
    ), "checkpoint_activations_microbatch is not supported for combined_1f1b"
    if config.combined_1f1b_recipe != "ep_a2a":
        raise NotImplementedError(
            f"combined_1f1b_recipe {config.combined_1f1b_recipe} not supported yet"
        )

    from .schedules import set_current_microbatch

    if f_model is not None and config.timers is not None:
        config.timers('forward-compute', log_level=2).start()

    if config.enable_autocast:
        context_manager = torch.autocast("cuda", dtype=config.autocast_dtype)
    else:
        context_manager = contextlib.nullcontext()

    # forward preprocess
    unwrap_output_tensor = False
    f_schedule_plan = None
    if f_model is not None:
        with f_context:
            if is_first_microbatch and hasattr(f_model, 'set_is_first_microbatch'):
                f_model.set_is_first_microbatch()
            if current_microbatch is not None:
                set_current_microbatch(f_model, current_microbatch)
            if not isinstance(input_tensor, list):
                input_tensor = [input_tensor]
                unwrap_output_tensor = True
            set_input_tensor = get_attr_wrapped_model(f_model, "set_input_tensor")
            set_input_tensor(input_tensor)
            with context_manager:  # autocast context
                f_schedule_plan, loss_func = forward_step_func(data_iterator, f_model)
            assert isinstance(
                f_schedule_plan, AbstractSchedulePlan
            ), "first output of forward_step_func must be one instance of AbstractSchedulePlan"

    # backward preprocess
    unwrap_input_tensor_grad = False
    b_schedule_plan = None
    if b_model is not None:
        # Retain the grad on the input_tensor.
        if not isinstance(b_input_tensor, list):
            b_input_tensor = [b_input_tensor]
            unwrap_input_tensor_grad = True
        for x in b_input_tensor:
            if x is not None:
                x.retain_grad()
        if not isinstance(b_output_tensor, list):
            b_output_tensor = [b_output_tensor]
        if not isinstance(b_output_tensor_grad, list):
            b_output_tensor_grad = [b_output_tensor_grad]

        # Backward pass for loss function
        b_schedule_plan = b_output_tensor[0].schedule_plan
        b_output_tensor[0].schedule_plan = None
        if b_output_tensor_grad[0] is None and config.grad_scale_func is not None:
            # backward schedule plan
            loss_node = b_output_tensor[0].loss_func
            b_output_tensor[0].loss_func = None
            b_output_tensor[0] = config.grad_scale_func(b_output_tensor[0])
            torch.autograd.backward(b_output_tensor[0], grad_tensors=b_output_tensor_grad[0])
            b_output_tensor_grad[0] = loss_node.get_grad()

    grad = b_output_tensor_grad[0] if b_model else None
    with context_manager:  # autocast context
        # schedule forward and backward
        output_tensor = schedule_chunk_1f1b(
            f_schedule_plan,
            b_schedule_plan,
            grad,
            f_context=f_context,
            b_context=b_context,
            pre_forward=pre_forward,
            pre_backward=pre_backward,
            post_forward=post_forward,
            post_backward=post_backward,
        )

    # forward post process
    num_tokens = None
    if f_model is not None:
        with f_context:
            num_tokens = torch.tensor(0, dtype=torch.int)
            if parallel_state.is_pipeline_last_stage():
                if not collect_non_loss_data:
                    loss_node = ScheduleNode(
                        loss_func,
                        torch.cuda.current_stream(),
                        f_schedule_plan.event,
                        name="loss_func",
                    )
                    loss_func = loss_node.forward
                    outputs = loss_func(output_tensor)
                    if len(outputs) == 3:
                        output_tensor, num_tokens, loss_reduced = outputs
                        if not config.calculate_per_token_loss:
                            output_tensor /= num_tokens
                            output_tensor /= num_microbatches
                    else:
                        # preserve legacy loss averaging behavior
                        # (ie, over the number of microbatches)
                        assert len(outputs) == 2
                        output_tensor, loss_reduced = outputs
                        output_tensor = output_tensor / num_microbatches
                    forward_data_store.append(loss_reduced)
                    # attach loss_func on output_tensor
                    output_tensor.loss_func = loss_node
                else:
                    data = loss_func(output_tensor, non_loss_data=True)
                    forward_data_store.append(data)
            # attach schedule plan on output tensor
            output_tensor.schedule_plan = f_schedule_plan

            if config.timers is not None:
                config.timers('forward-compute').stop()

            # Set the loss scale for the auxiliary loss of the MoE layer.
            # Since we use a trick to do backward on the auxiliary loss, we need to set the scale
            # explicitly.
            if hasattr(config, 'num_moe_experts') and config.num_moe_experts is not None:
                # Calculate the loss scale based on the grad_scale_func if available,
                # else default to 1.
                loss_scale = (
                    config.grad_scale_func(torch.ones(1, device=output_tensor.device))
                    if config.grad_scale_func is not None
                    else torch.tensor(1.0)
                )
                # Set the loss scale
                MoEAuxLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)

            if not unwrap_output_tensor:
                output_tensor, num_tokens = [output_tensor], num_tokens

    # backward post process
    input_tensor_grad = None
    if b_model is not None:
        input_tensor_grad = [None]
        if b_input_tensor is not None:
            input_tensor_grad = []
            for x in b_input_tensor:
                if x is None:
                    input_tensor_grad.append(None)
                else:
                    input_tensor_grad.append(x.grad)
        if unwrap_input_tensor_grad:
            input_tensor_grad = input_tensor_grad[0]

    return output_tensor, num_tokens, input_tensor_grad
def get_default_cls_for_unwrap():
    """Returns the default classes to unwrap from a model.

    This function provides a tuple of classes that should be unwrapped from a model
    to access the underlying GPTModel instance. It includes DistributedDataParallel
    and Float16Module by default, and also attempts to include LegacyFloat16Module
    if available for backward compatibility.

    Returns:
        tuple: A tuple of classes to unwrap from a model.
    """
    cls = (DistributedDataParallel, Float16Module)
    try:
        # legacy should not be used in core, but for backward compatibility, we support it here
        from megatron.legacy.model import Float16Module as LegacyFloat16Module

        cls = cls + (LegacyFloat16Module,)
    except ImportError:
        pass
    return cls


def unwrap_model(model, module_instances=get_default_cls_for_unwrap()):
    """Unwrap a DistributedDataParallel- and Float16Module-wrapped model
    to return the GPTModel instance.
    """
    return_list = True
    if not isinstance(model, list):
        model = [model]
        return_list = False
    unwrapped_model = []
    for model_module in model:
        while isinstance(model_module, module_instances):
            model_module = model_module.module
        assert isinstance(
            model_module, GPTModel
        ), "The final unwrapped model must be a GPTModel instance"
        unwrapped_model.append(model_module)
    if not return_list:
        return unwrapped_model[0]
    return unwrapped_model
def wrap_forward_func(forward_step_func):
    """Wrap the input to forward_step_func.

    The wrapped function will return the forward schedule plan and the loss function.
    """

    def wrapped_func(data_iterator, model):
        # Model is unwrapped to get the GPTModel instance.
        # GPTModel.build_schedule_plan(model_forward_inputs) is called in the forward_step.
        # The return value becomes (forward_schedule_plan, loss_function),
        # which used to be (forward_output_tensor, loss_function).
        return forward_step_func(data_iterator, unwrap_model(model).build_schedule_plan)

    return wrapped_func
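A hedged sketch of the resulting contract: inside the user's forward_step_func, the `model` argument is now the bound build_schedule_plan method, so "calling the model" produces a schedule plan rather than output tensors. The batch fields and loss function below are invented for the illustration:

def my_loss_func(output_tensor):
    # placeholder loss function for the sketch
    return output_tensor


def my_forward_step(data_iterator, model):
    batch = next(data_iterator)  # hypothetical batch layout
    plan = model(  # `model` is actually GPTModel.build_schedule_plan here
        batch["tokens"], batch["position_ids"], batch["attention_mask"], labels=batch["labels"]
    )
    return plan, my_loss_func


# wrapped = wrap_forward_func(my_forward_step)
# f_schedule_plan, loss_func = wrapped(data_iterator, ddp_wrapped_gpt_model)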
class MemoryManagementStrategy:
    """Base class for memory management strategies.

    Different memory management strategies can be implemented by subclassing this class.
    These strategies control how tensors are handled in memory during the computation.
    """

    def handle_inputs(self, inputs, stream):
        """Process input tensors after computation.

        Args:
            inputs (tuple): Input tensors that have been used
            stream (torch.cuda.Stream): Current CUDA stream
        """
        pass

    def handle_outputs(self, outputs, stream):
        """Process output tensors after computation.

        Args:
            outputs (tuple): Output tensors produced by the computation
            stream (torch.cuda.Stream): Current CUDA stream
        """
        pass


class NoOpMemoryStrategy(MemoryManagementStrategy):
    """Strategy that performs no memory management operations.

    This is the default strategy - it doesn't free any memory.
    """

    pass


class FreeInputsMemoryStrategy(MemoryManagementStrategy):
    """Strategy that immediately frees input tensors after they are used.

    This strategy is useful for nodes where inputs are no longer needed
    after computation, helping to reduce memory usage.
    """

    def handle_inputs(self, inputs, stream):
        """Free input tensors by resizing their storage to zero.

        Args:
            inputs (tuple): Input tensors to be freed
            stream (torch.cuda.Stream): Current CUDA stream
        """
        for input in inputs:
            if input is not None:
                input.record_stream(stream)
                input.untyped_storage().resize_(0)
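A small, self-contained illustration of the freeing trick: resizing a tensor's untyped storage to zero releases the backing memory while the Python tensor object (shape, autograd metadata) stays alive; the record_stream call above additionally defers the free until in-flight work on other CUDA streams has finished.

import torch

t = torch.ones(1024)                         # 1024 * 4 bytes of float32 storage
print(t.untyped_storage().size())            # 4096
t.untyped_storage().resize_(0)               # backing memory released
print(t.shape, t.untyped_storage().size())   # torch.Size([1024]) 0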
dcu_megatron/core/transformer/moe/token_dispatcher.py
0 → 100644
View file @
12b56c98
# NOTE: the supporting imports below are assumed (the rendered diff omits them);
# adjust the paths to match the megatron-core version this commit targets.
from typing import Optional, Tuple

import torch

from megatron.core.tensor_parallel.mappings import (
    all_to_all,
    gather_from_sequence_parallel_region,
)
from megatron.core.transformer.moe.fused_a2a import fused_combine, fused_dispatch
from megatron.core.transformer.moe.moe_utils import permute, sort_chunks_by_idxs
from megatron.core.transformer.moe.token_dispatcher import (
    MoETokenDispatcher,
    _DeepepManager as MegatronCoreDeepepManager,
)


class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
    def token_permutation(
        self, hidden_states: torch.Tensor, probs: torch.Tensor, routing_map: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Dispatch tokens to local experts using AlltoAll communication.

        This method performs the following steps:
        1. Preprocess the routing map to get metadata for communication and permutation.
        2. Permute input tokens for AlltoAll communication.
        3. Perform expert parallel AlltoAll communication.
        4. Sort tokens by local expert (if multiple local experts exist).

        Args:
            hidden_states (torch.Tensor): Input token embeddings.
            probs (torch.Tensor): The probabilities of token-to-expert assignment.
            routing_map (torch.Tensor): The mapping of token-to-expert assignment.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                - Permuted token embeddings for local experts.
                - Number of tokens per expert.
                - Permuted probs of each token produced by the router.
        """
        # Preprocess: Get the metadata for communication, permutation and computation operations.
        self.hidden_shape = hidden_states.shape
        self.probs = probs
        self.routing_map = routing_map
        assert probs.dim() == 2, "Expected 2D tensor for probs"
        assert routing_map.dim() == 2, "Expected 2D tensor for token2expert mask"
        assert routing_map.dtype == torch.bool, "Expected bool tensor for mask"
        hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
        tokens_per_expert = self.preprocess(self.routing_map)

        if self.shared_experts is not None:
            self.shared_experts.pre_forward_comm(hidden_states.view(self.hidden_shape))

        # Permutation 1: input to AlltoAll input
        tokens_per_expert = self._maybe_dtoh_and_synchronize(
            "before_permutation_1", tokens_per_expert
        )
        self.hidden_shape_before_permute = hidden_states.shape
        (
            permutated_local_input_tokens,
            permuted_probs,
            self.reversed_local_input_permutation_mapping,
        ) = permute(
            hidden_states,
            routing_map,
            probs=probs,
            num_out_tokens=self.num_out_tokens,
            fused=self.config.moe_permute_fusion,
            drop_and_pad=self.drop_and_pad,
        )

        # Perform expert parallel AlltoAll communication
        tokens_per_expert = self._maybe_dtoh_and_synchronize(
            "before_ep_alltoall", tokens_per_expert
        )
        global_input_tokens = all_to_all(
            self.ep_group, permutated_local_input_tokens, self.output_splits, self.input_splits
        )
        global_probs = all_to_all(
            self.ep_group, permuted_probs, self.output_splits, self.input_splits
        )

        if self.shared_experts is not None:
            self.shared_experts.linear_fc1_forward_and_act(global_input_tokens)

        if self.tp_size > 1:
            if self.output_splits_tp is None:
                output_split_sizes = None
            else:
                output_split_sizes = self.output_splits_tp.tolist()
            global_input_tokens = gather_from_sequence_parallel_region(
                global_input_tokens, group=self.tp_group, output_split_sizes=output_split_sizes
            )
            global_probs = gather_from_sequence_parallel_region(
                global_probs, group=self.tp_group, output_split_sizes=output_split_sizes
            )

        # Permutation 2: Sort tokens by local expert.
        tokens_per_expert = self._maybe_dtoh_and_synchronize(
            "before_permutation_2", tokens_per_expert
        )
        if self.num_local_experts > 1:
            if self.drop_and_pad:
                global_input_tokens = (
                    global_input_tokens.view(
                        self.tp_size * self.ep_size,
                        self.num_local_experts,
                        self.capacity,
                        *global_input_tokens.size()[1:],
                    )
                    .transpose(0, 1)
                    .contiguous()
                    .flatten(start_dim=0, end_dim=2)
                )
                global_probs = (
                    global_probs.view(
                        self.tp_size * self.ep_size,
                        self.num_local_experts,
                        self.capacity,
                        *global_probs.size()[1:],
                    )
                    .transpose(0, 1)
                    .contiguous()
                    .flatten(start_dim=0, end_dim=2)
                )
            else:
                global_input_tokens, global_probs = sort_chunks_by_idxs(
                    global_input_tokens,
                    self.num_global_tokens_per_local_expert.ravel(),
                    self.sort_input_by_local_experts,
                    probs=global_probs,
                    fused=self.config.moe_permute_fusion,
                )

        tokens_per_expert = self._maybe_dtoh_and_synchronize("before_finish", tokens_per_expert)

        return global_input_tokens, tokens_per_expert, global_probs
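The drop-and-pad branch above regroups tokens that arrive ordered by source rank into contiguous per-expert chunks purely with a view/transpose/flatten. A small self-contained check of that reshape (toy sizes, not part of this commit):

    import torch

    world, local_experts, capacity, hidden = 2, 2, 3, 4  # world = tp_size * ep_size
    x = torch.arange(world * local_experts * capacity * hidden, dtype=torch.float32)
    x = x.view(world * local_experts * capacity, hidden)  # rank-major layout after a2a
    regrouped = (
        x.view(world, local_experts, capacity, hidden)
        .transpose(0, 1)      # make the layout expert-major
        .contiguous()
        .flatten(start_dim=0, end_dim=2)
    )
    # All `capacity` slots of expert 0 (from every rank) now come first, then expert 1.
    assert regrouped.shape == x.shape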
class _DeepepManager(MegatronCoreDeepepManager):
    """
    Patch of megatron-core's _DeepepManager that exposes async dispatch/combine.
    """

    def dispatch(
        self,
        hidden_states: torch.Tensor,
        async_finish: bool = False,
        allocate_on_comm_stream: bool = False,
    ) -> torch.Tensor:
        # DeepEP only supports float32 probs
        if self.token_probs.dtype != torch.float32:
            if self.token_probs.dtype in [torch.bfloat16, torch.float16]:
                print("DeepEP only supports float32 probs, please set --moe-router-dtype=fp32")
            self.token_probs = self.token_probs.float()  # downcast or upcast
        hidden_states, dispatched_indices, dispatched_probs, num_tokens_per_expert, handle = (
            fused_dispatch(
                hidden_states,
                self.token_indices,
                self.token_probs,
                self.num_experts,
                self.group,
                async_finish=async_finish,
                allocate_on_comm_stream=allocate_on_comm_stream,
            )
        )
        self.handle = handle
        self.tokens_per_expert = num_tokens_per_expert
        self.dispatched_indices = dispatched_indices
        self.dispatched_probs = dispatched_probs
        return hidden_states

    def combine(
        self,
        hidden_states: torch.Tensor,
        async_finish: bool = False,
        allocate_on_comm_stream: bool = False,
    ) -> torch.Tensor:
        hidden_states, _ = fused_combine(
            hidden_states,
            self.group,
            self.handle,
            async_finish=async_finish,
            allocate_on_comm_stream=allocate_on_comm_stream,
        )
        # Release the handle after the combine operation
        self.handle = None
        return hidden_states
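The `async_finish` / `allocate_on_comm_stream` flags exist so dispatch/combine can be launched on a communication stream while independent compute proceeds on the default stream. A self-contained two-stream toy of that overlap pattern (a plain device copy stands in for the async all-to-all; not part of this commit):

    import torch

    if torch.cuda.is_available():
        comm_stream = torch.cuda.Stream()
        x = torch.randn(1024, 1024, device="cuda")
        w = torch.randn(1024, 1024, device="cuda")
        dst = torch.empty_like(x)
        comm_stream.wait_stream(torch.cuda.current_stream())  # x must be ready first
        with torch.cuda.stream(comm_stream):
            dst.copy_(x, non_blocking=True)   # stand-in for fused_dispatch(async_finish=True)
        y = w @ w                             # independent compute, overlapped with the copy
        torch.cuda.current_stream().wait_stream(comm_stream)  # join before consuming dst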
class MoEFlexTokenDispatcher(MoETokenDispatcher):
    """
    Flex token dispatcher using DeepEP.
    """

    def dispatch_preprocess(
        self, hidden_states: torch.Tensor, routing_map: torch.Tensor, probs: torch.Tensor
    ):
        """
        Preprocesses the hidden states and routing information before dispatching tokens to experts.

        Args:
            hidden_states (torch.Tensor): Input hidden states to be processed
            routing_map (torch.Tensor): Map indicating which expert each token should be routed to
            probs (torch.Tensor): Routing probabilities for each token-expert pair

        Returns:
            Tuple containing:
            - torch.Tensor: Reshaped hidden states
            - torch.Tensor: Token probabilities from the communication manager
            - None: Placeholder for compatibility
        """
        self.hidden_shape = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_shape[-1])

        # Initialize metadata
        routing_map, probs = self._initialize_metadata(routing_map, probs)

        self._comm_manager.setup_metadata(routing_map, probs)
        return hidden_states, self._comm_manager.token_probs, None

    def dispatch_all_to_all(
        self,
        hidden_states: torch.Tensor,
        probs: torch.Tensor = None,
        async_finish: bool = True,
        allocate_on_comm_stream: bool = True,
    ):
        """
        Performs all-to-all communication to dispatch tokens across expert parallel ranks.
        """
        return (
            self._comm_manager.dispatch(hidden_states, async_finish, allocate_on_comm_stream),
            self._comm_manager.dispatched_probs,
        )

    def dispatch_postprocess(self, hidden_states: torch.Tensor):
        """
        Post-processes the dispatched hidden states after all-to-all communication.

        This method retrieves the permuted hidden states by experts, calculates the number of
        tokens per expert, and returns the processed data ready for expert processing.
        """
        global_input_tokens, permuted_probs = (
            self._comm_manager.get_permuted_hidden_states_by_experts(hidden_states)
        )
        tokens_per_expert = self._comm_manager.get_number_of_tokens_per_expert()
        return global_input_tokens, tokens_per_expert, permuted_probs

    def token_permutation(
        self, hidden_states: torch.Tensor, probs: torch.Tensor, routing_map: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Permutes tokens according to the routing map and dispatches them to experts.

        This method implements the token permutation process in three steps:
        1. Preprocess the hidden states and routing information
        2. Perform all-to-all communication to dispatch tokens
        3. Post-process the dispatched tokens for expert processing
        """
        hidden_states, _, _ = self.dispatch_preprocess(hidden_states, routing_map, probs)
        hidden_states, _ = self.dispatch_all_to_all(
            hidden_states, async_finish=False, allocate_on_comm_stream=False
        )
        global_input_tokens, tokens_per_expert, permuted_probs = self.dispatch_postprocess(
            hidden_states
        )
        return global_input_tokens, tokens_per_expert, permuted_probs

    def combine_preprocess(self, hidden_states: torch.Tensor):
        """
        Pre-processes the hidden states before combining them after expert processing.

        This method restores the hidden states to their original ordering before expert
        processing by using the communication manager's restoration function.
        """
        hidden_states = self._comm_manager.get_restored_hidden_states_by_experts(hidden_states)
        return hidden_states

    def combine_all_to_all(
        self,
        hidden_states: torch.Tensor,
        async_finish: bool = True,
        allocate_on_comm_stream: bool = True,
    ):
        """
        Performs all-to-all communication to combine tokens after expert processing.
        """
        return self._comm_manager.combine(hidden_states, async_finish, allocate_on_comm_stream)

    def combine_postprocess(self, hidden_states: torch.Tensor):
        """
        Post-processes the combined hidden states after all-to-all communication.

        This method reshapes the combined hidden states to match the original input shape.
        """
        return hidden_states.view(self.hidden_shape)

    def token_unpermutation(
        self, hidden_states: torch.Tensor, bias: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Reverses the token permutation process to restore the original token order.

        This method implements the token unpermutation process in three steps:
        1. Pre-process the hidden states to restore their original ordering
        2. Perform all-to-all communication to combine tokens
        3. Post-process the combined tokens to match the original input shape
        """
        assert bias is None, "Bias is not supported in MoEFlexTokenDispatcher"
        hidden_states = self.combine_preprocess(hidden_states)
        hidden_states = self.combine_all_to_all(hidden_states, False, False)
        hidden_states = self.combine_postprocess(hidden_states)
        return hidden_states, None
dcu_megatron/core/transformer/transformer_block.py
View file @ 12b56c98
from functools import wraps

from megatron.core.transformer.transformer_block import (
    TransformerBlock as MegatronCoreTransformerBlock,
)


def transformer_block_init_wrapper(fn):
    @wraps(fn)
    ...
@@ -13,3 +14,22 @@ def transformer_block_init_wrapper(fn):
        self.final_layernorm = None
    return wrapper


class TransformerBlock(MegatronCoreTransformerBlock):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # MTP requires separate layernorms for the main model and the MTP modules,
        # so the final norm is moved out of the block.
        config = args[0] if len(args) > 0 else kwargs['config']
        if getattr(config, "mtp_num_layers", 0) > 0:
            self.main_final_layernorm = self.final_layernorm
            self.final_layernorm = None

    def get_layer_callables(self, layer_number: int):
        """
        Get the callables for the layer at the given layer number.
        """
        return self.layers[layer_number].get_submodule_callables()
dcu_megatron/core/transformer/transformer_layer.py
0 → 100644
View file @ 12b56c98
from functools import partial

import torch

# NOTE: the supporting imports below are assumed (the rendered diff omits them);
# SubmoduleCallables / TransformerLayerSubmoduleCallables may instead live in this
# repo's combined_1f1b module, depending on the targeted megatron-core version.
from megatron.core.transformer.transformer_layer import (
    SubmoduleCallables,
    TransformerLayerSubmoduleCallables,
    TransformerLayer as MegatronCoreTransformerLayer,
)
from megatron.core.utils import make_viewless_tensor


class TransformerLayer(MegatronCoreTransformerLayer):
    def _submodule_attn_router_forward(
        self,
        hidden_states,
        attention_mask=None,
        inference_params=None,
        rotary_pos_emb=None,
        rotary_pos_cos=None,
        rotary_pos_sin=None,
        attention_bias=None,
        packed_seq_params=None,
        sequence_len_offset=None,
        state=None,
    ):
        """
        Performs a combined forward pass that includes self-attention and MLP routing logic.
        """
        hidden_states, _ = self._forward_attention(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            rotary_pos_emb=rotary_pos_emb,
            rotary_pos_cos=rotary_pos_cos,
            rotary_pos_sin=rotary_pos_sin,
            attention_bias=attention_bias,
            packed_seq_params=packed_seq_params,
            sequence_len_offset=sequence_len_offset,
            inference_params=inference_params,
        )
        pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states)
        probs, routing_map = self.mlp.router(pre_mlp_layernorm_output)
        local_tokens, probs, tokens_per_expert = self.mlp.token_dispatcher.dispatch_preprocess(
            pre_mlp_layernorm_output, routing_map, probs
        )
        return (local_tokens, probs, hidden_states, pre_mlp_layernorm_output, tokens_per_expert)

    def _submodule_dispatch_forward(self, local_tokens, probs, state=None):
        """
        Dispatches tokens to the appropriate experts based on the router output.
        """
        token_dispatcher = self.mlp.token_dispatcher
        if self.is_deepep:
            token_dispatcher._comm_manager.token_probs = probs
        return token_dispatcher.dispatch_all_to_all(local_tokens, probs)

    def _submodule_moe_forward(self, dispatched_tokens, probs=None, state=None):
        """
        Performs a forward pass for the MLP submodule, including both expert-based
        and optional shared-expert computations.
        """
        shared_expert_output = None
        token_dispatcher = self.mlp.token_dispatcher
        if self.is_deepep:
            token_dispatcher._comm_manager.dispatched_probs = state.dispatched_probs
            dispatched_tokens, tokens_per_expert, permuted_probs = (
                token_dispatcher.dispatch_postprocess(dispatched_tokens)
            )
        else:
            dispatched_tokens, permuted_probs = token_dispatcher.dispatch_postprocess(
                dispatched_tokens, probs
            )
            tokens_per_expert = state.tokens_per_expert
        expert_output, mlp_bias = self.mlp.experts(
            dispatched_tokens, tokens_per_expert, permuted_probs
        )
        assert mlp_bias is None, f"Bias is not supported in {token_dispatcher.__class__.__name__}"
        if self.mlp.use_shared_expert and not self.mlp.shared_expert_overlap:
            shared_expert_output = self.mlp.shared_experts(state.pre_mlp_layernorm_output)
        expert_output = self.mlp.token_dispatcher.combine_preprocess(expert_output)
        return expert_output, shared_expert_output, mlp_bias

    def _submodule_combine_forward(self, output, shared_expert_output=None, state=None):
        residual = state.residual
        token_dispatcher = self.mlp.token_dispatcher
        output = token_dispatcher.combine_all_to_all(output)
        output = token_dispatcher.combine_postprocess(output)
        if shared_expert_output is not None:
            output = output + shared_expert_output
        mlp_output_with_bias = (output, None)
        with self.bias_dropout_add_exec_handler():
            hidden_states = self.mlp_bda(self.training, self.config.bias_dropout_fusion)(
                mlp_output_with_bias, residual, self.hidden_dropout
            )
        output = make_viewless_tensor(
            inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True
        )
        return output

    def _submodule_attn_router_dw(self):
        self.self_attention.backward_dw()

    def _submodule_mlp_dw(self):
        self.mlp.backward_dw()

    def _submodule_attn_router_postprocess(
        self, node, local_tokens, probs, residual, pre_mlp_layernorm_output, tokens_per_expert
    ):
        node.common_state.residual = node.detach(residual)
        if self.mlp.use_shared_expert:
            node.common_state.pre_mlp_layernorm_output = node.detach(pre_mlp_layernorm_output)
        if not self.is_deepep:
            node.common_state.tokens_per_expert = tokens_per_expert
        return local_tokens, probs

    def _submodule_dispatch_postprocess(self, node, dispatched_tokens, probs):
        if self.is_deepep:
            node.common_state.dispatched_probs = node.detach(probs)
            return dispatched_tokens
        else:
            return dispatched_tokens, probs

    def _submodule_mlp_postprocess(self, node, expert_output, shared_expert_output, mlp_bias):
        assert mlp_bias is None
        node.common_state.pre_mlp_layernorm_output = None
        if shared_expert_output is None:
            return expert_output
        return expert_output, shared_expert_output

    def _submodule_combine_postprocess(self, node, output):
        cur_stream = torch.cuda.current_stream()
        node.common_state.residual.record_stream(cur_stream)
        node.common_state.residual = None
        return output

    def _submodule_attn_postprocess(self, node, hidden_states, context):
        return hidden_states

    def _submodule_dense_postprocess(self, node, hidden_states):
        return hidden_states

    def _submodule_not_implemented(self, *args):
        raise NotImplementedError("This callable is not implemented.")

    def get_submodule_callables(self, chunk_state):
        """
        The forward callables take two kinds of inputs:
        1. The ScheduleNode object.
        2. The input tensors.
        """
        from megatron.core.transformer.moe.moe_layer import MoELayer
        from megatron.core.transformer.moe.token_dispatcher import MoEFlexTokenDispatcher

        self.is_moe = isinstance(self.mlp, MoELayer)
        self.is_deepep = False
        if self.is_moe:
            self.is_deepep = isinstance(self.mlp.token_dispatcher, MoEFlexTokenDispatcher)

        def get_func_with_default(func, default_func):
            if self.is_moe:
                return func
            return default_func

        def callable_wrapper(forward_func, postprocess_func, node, *args):
            state = getattr(node, 'common_state', None)
            callable_outputs = forward_func(*args, state=state)
            if isinstance(callable_outputs, tuple):
                outputs = postprocess_func(node, *callable_outputs)
            else:
                outputs = postprocess_func(node, callable_outputs)
            return outputs

        attn_func = get_func_with_default(
            self._submodule_attn_router_forward, self._forward_attention
        )

        def attn_wrapper(hidden_states, state=None):
            return attn_func(
                hidden_states=hidden_states,
                attention_mask=chunk_state.attention_mask,
                attention_bias=chunk_state.attention_bias,
                inference_params=chunk_state.inference_params,
                packed_seq_params=chunk_state.packed_seq_params,
                sequence_len_offset=chunk_state.sequence_len_offset,
                rotary_pos_emb=chunk_state.rotary_pos_emb,
                rotary_pos_cos=chunk_state.rotary_pos_cos,
                rotary_pos_sin=chunk_state.rotary_pos_sin,
                state=state,
            )

        attn_postprocess_func = get_func_with_default(
            self._submodule_attn_router_postprocess, self._submodule_attn_postprocess
        )
        dispatch_func = get_func_with_default(
            self._submodule_dispatch_forward, self._submodule_not_implemented
        )
        dispatch_postprocess_func = get_func_with_default(
            self._submodule_dispatch_postprocess, self._submodule_not_implemented
        )
        mlp_func = get_func_with_default(self._submodule_moe_forward, self._forward_mlp)
        mlp_postprocess_func = get_func_with_default(
            self._submodule_mlp_postprocess, self._submodule_dense_postprocess
        )
        combine_func = get_func_with_default(
            self._submodule_combine_forward, self._submodule_not_implemented
        )
        combine_postprocess_func = get_func_with_default(
            self._submodule_combine_postprocess, self._submodule_not_implemented
        )

        attn_forward = partial(callable_wrapper, attn_wrapper, attn_postprocess_func)
        dispatch_forward = partial(callable_wrapper, dispatch_func, dispatch_postprocess_func)
        mlp_forward = partial(callable_wrapper, mlp_func, mlp_postprocess_func)
        combine_forward = partial(callable_wrapper, combine_func, combine_postprocess_func)

        callables = TransformerLayerSubmoduleCallables(
            attention=SubmoduleCallables(forward=attn_forward, dw=self._submodule_attn_router_dw),
            dispatch=SubmoduleCallables(forward=dispatch_forward),
            mlp=SubmoduleCallables(forward=mlp_forward, dw=self._submodule_mlp_dw),
            combine=SubmoduleCallables(forward=combine_forward),
            is_moe=self.is_moe,
            is_deepep=self.is_deepep,
        )
        return callables
\ No newline at end of file
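callable_wrapper is the glue of this file: it threads the node's common_state into the forward and fans tuple outputs into the postprocess. A self-contained toy of that pattern (toy node and functions, not part of this commit):

    from functools import partial

    class _ToyNode:
        class _State:
            pass

        def __init__(self):
            self.common_state = self._State()

    def _fwd(x, state=None):
        return x + 1, x * 2               # a forward may return a tuple

    def _post(node, a, b):
        node.common_state.saved = b       # stash side outputs on the node
        return a

    def _wrapper(fwd, post, node, *args):  # mirrors callable_wrapper above
        outs = fwd(*args, state=getattr(node, "common_state", None))
        return post(node, *outs) if isinstance(outs, tuple) else post(node, outs)

    step = partial(_wrapper, _fwd, _post)  # same shape as attn_forward etc.
    node = _ToyNode()
    assert step(node, 3) == 4 and node.common_state.saved == 6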