evt_fugx1 / dcu_megatron · Commits

Commit 6a579b17, authored Jun 16, 2025 by dongcl
    rewrite combined_1f1b
Parent: e103a256

Showing 4 changed files with 611 additions and 525 deletions (+611 / -525)
Changed files:
    dcu_megatron/core/models/gpt/fine_grained_schedule.py    +261 / -223
    dcu_megatron/core/pipeline_parallel/combined_1f1b.py     +182 / -69
    dcu_megatron/core/transformer/moe/token_dispatcher.py    +2   / -6
    dcu_megatron/core/transformer/transformer_layer.py       +166 / -227
dcu_megatron/core/models/gpt/fine_grained_schedule.py
    (diff collapsed; contents not shown)
dcu_megatron/core/pipeline_parallel/combined_1f1b.py

@@ -10,7 +10,7 @@ from torch.autograd.variable import Variable
 from megatron.training import get_args
 from megatron.core import parallel_state
 from megatron.core.distributed import DistributedDataParallel
 from megatron.core.models.gpt.gpt_model import GPTModel
 from megatron.core.transformer.module import Float16Module
 from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler
 from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler
@@ -20,13 +20,14 @@ from dcu_megatron.core.parallel_state import get_dualpipe_chunk
 def make_viewless(e):
-    """make_viewless util func"""
+    """Make_viewless util func"""
     e = make_viewless_tensor(inp=e, requires_grad=e.requires_grad, keep_graph=True)
     return e


 @contextmanager
 def stream_acquire_context(stream, event):
     """Stream acquire context"""
     event.wait(stream)
     try:
         yield
@@ -34,8 +35,29 @@ def stream_acquire_context(stream, event):
         event.record(stream)


+class FakeScheduleNode:
+    """A placeholder node in the computation graph that simply passes through inputs and outputs.
+
+    This class is used as a no-op node in the scheduling system when a real computation node
+    is not needed but the interface must be maintained. It simply returns its inputs unchanged
+    in both forward and backward passes.
+    """
+
+    def forward(self, inputs):
+        """Passes through inputs unchanged in the forward pass."""
+        return inputs
+
+    def backward(self, outgrads):
+        """Passes through gradients unchanged in the backward pass."""
+        return outgrads
+
+
 class ScheduleNode:
-    """base node for fine-grained schedule"""
+    """Base node for fine-grained scheduling.
+
+    This class represents a computational node in the pipeline schedule.
+    It handles the execution of forward and backward operations on a stream.
+    """

     def __init__(
         self,
@@ -43,24 +65,30 @@ class ScheduleNode:
         stream,
         event,
         backward_func=None,
-        free_inputs=False,
+        memory_strategy=None,
         name="schedule_node",
     ):
+        """Initialize a schedule node.
+
+        Args:
+            forward_func (callable): Function to execute during forward pass
+            stream (torch.cuda.Stream): CUDA stream for computation
+            event (torch.cuda.Event): Event for synchronization
+            backward_func (callable, optional): Function for backward pass
+            memory_strategy (MemoryManagementStrategy, optional): Strategy for memory management
+            name (str): Name of the node for debugging
+        """
         self.name = name
         self.forward_func = forward_func
-        self.backward_func = backward_func
+        self.backward_func = backward_func if backward_func else self.default_backward_func
         self.stream = stream
         self.event = event
-        self.free_inputs = free_inputs
+        self.memory_strategy = memory_strategy or NoOpMemoryStrategy()
         self.inputs = None
         self.outputs = None

     def default_backward_func(self, outputs, output_grad):
-        # Handle scalar output
-        if output_grad is None:
-            assert outputs.numel() == 1, "implicit grad requires scalar output."
-            output_grad = torch.ones_like(outputs, memory_format=torch.preserve_format)
+        """Default backward function"""
         Variable._execution_engine.run_backward(
             tensors=outputs,
             grad_tensors=output_grad,
@@ -72,20 +100,16 @@ class ScheduleNode:
         )
         return output_grad

-    def forward(self, inputs=(), stream_wait_event=None, stream_record_event=None):
-        """schedule node forward"""
+    def forward(self, inputs=()):
+        """Schedule node forward"""
         if not isinstance(inputs, tuple):
             inputs = (inputs,)
-        return self._forward(
-            *inputs, stream_wait_event=stream_wait_event, stream_record_event=stream_record_event
-        )
+        return self._forward(*inputs)

-    def _forward(self, *inputs, stream_wait_event=None, stream_record_event=None):
+    def _forward(self, *inputs):
         with stream_acquire_context(self.stream, self.event):
             torch.cuda.nvtx.range_push(f"{self.name} forward")
             with torch.cuda.stream(self.stream):
-                if stream_wait_event is not None:
-                    stream_wait_event.wait(self.stream)
                 self.inputs = [
                     make_viewless(e).detach() if e is not None else None for e in inputs
                 ]
                 for i, input in enumerate(self.inputs):
                     if input is not None:
@@ -100,50 +124,35 @@ class ScheduleNode:
             data = tuple([make_viewless(e) if isinstance(e, Tensor) else e for e in data])
             self.output = data
-            if stream_record_event is not None:
-                stream_record_event.record(self.stream)
             torch.cuda.nvtx.range_pop()
-        if self.free_inputs:
-            for input in inputs:
-                input.record_stream(self.stream)
-                input.untyped_storage().resize_(0)
+        # Handle inputs using the memory strategy
+        self.memory_strategy.handle_inputs(inputs, self.stream)
         return self.output

     def get_output(self):
-        """get the forward output"""
+        """Get the forward output"""
         return self.output

-    def backward(self, output_grad, stream_wait_event=None, stream_record_event=None):
-        """schedule node backward"""
+    def backward(self, output_grad):
+        """Schedule node backward"""
         if not isinstance(output_grad, tuple):
             output_grad = (output_grad,)
-        return self._backward(
-            *output_grad,
-            stream_wait_event=stream_wait_event,
-            stream_record_event=stream_record_event,
-        )
+        return self._backward(*output_grad)

-    def _backward(self, *output_grad, stream_wait_event=None, stream_record_event=None):
+    def _backward(self, *output_grad):
         with stream_acquire_context(self.stream, self.event):
             torch.cuda.nvtx.range_push(f"{self.name} backward")
             with torch.cuda.stream(self.stream):
-                if stream_wait_event is not None:
-                    stream_wait_event.wait(self.stream)
                 outputs = self.output
                 if not isinstance(outputs, tuple):
                     outputs = (outputs,)
-                assert len(outputs) == len(output_grad), (
-                    f"{len(outputs)} of {type(outputs[0])} vs "
-                    f"{len(output_grad)} of {type(output_grad[0])}"
-                )
-                if self.backward_func is not None:
-                    output_grad = self.backward_func(outputs, output_grad)
-                else:
-                    output_grad = self.default_backward_func(outputs, output_grad)
-                if stream_record_event is not None:
-                    stream_record_event.record(self.stream)
+                assert len(outputs) == len(output_grad), (
+                    f"{len(outputs)} of {type(outputs[0])} is not equal to "
+                    f"{len(output_grad)} of {type(output_grad[0])}"
+                )
+                output_grad = self.backward_func(outputs, output_grad)
                 torch.cuda.nvtx.range_pop()
         # output_grad maybe from another stream
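Note on the stream/event handshake that _forward and _backward now rely on: instead of threading per-call wait/record events through every node, each node serializes its work through one shared event via stream_acquire_context. A minimal, self-contained sketch of that pattern (hypothetical names, not the commit's code) is:

    # Illustrative sketch only: the event-guarded stream pattern wrapped by
    # stream_acquire_context(). The names `acquire` and `demo` are hypothetical.
    from contextlib import contextmanager

    import torch


    @contextmanager
    def acquire(stream, event):
        # Make `stream` wait for whatever work the event last recorded,
        # run the body, then record the event again so the next user of
        # the same event serializes after this work.
        event.wait(stream)
        try:
            yield
        finally:
            event.record(stream)


    def demo():
        if not torch.cuda.is_available():
            return
        stream = torch.cuda.Stream()
        event = torch.cuda.Event()
        x = torch.randn(4, 4, device="cuda", requires_grad=True)
        with acquire(stream, event):
            with torch.cuda.stream(stream):
                y = (x * 2).sum()
        # Later work on another stream still synchronizes explicitly.
        torch.cuda.current_stream().wait_event(event)
        y.backward()
        print(x.grad.shape)


    if __name__ == "__main__":
        demo()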
@@ -153,7 +162,7 @@ class ScheduleNode:
         return self.get_grad()

     def get_grad(self):
-        """get the grad of inputs"""
+        """Get the grad of inputs"""
         grad = tuple([e.grad if e is not None else None for e in self.inputs])
         # clear state
         self.inputs = None
@@ -165,7 +174,7 @@ class ScheduleNode:
 class AbstractSchedulePlan(ABC):
-    """to use combined 1f1b, model must implement build_schedule_plan while take the same
+    """To use combined 1f1b, model must implement build_schedule_plan while take the same
     signature as model forward but return an instance of AbstractSchedulePlan"""

     @classmethod
@@ -197,7 +206,29 @@ def schedule_chunk_1f1b(
     post_forward=None,
     post_backward=None,
 ):
-    """model level 1f1b fine-grained schedule"""
+    """Model level 1f1b fine-grained schedule
+
+    This function schedules the forward and backward passes for a chunk of the model.
+    It takes in the forward schedule plan, backward schedule plan, gradient, and optional
+    context managers for the forward and backward passes.
+
+    Args:
+        f_schedule_plan (subclass of AbstractSchedulePlan): The forward schedule plan
+        b_schedule_plan (subclass of AbstractSchedulePlan): The backward schedule plan
+        grad (Tensor or None): The gradient of the loss function
+        f_context (VppContextManager or None): The VppContextManager for the forward pass
+        b_context (VppContextManager or None): The VppContextManager for the backward pass
+        pre_forward (callable or None): The function to call before the forward pass
+        pre_backward (callable or None): The function to call before the backward pass
+        post_forward (callable or None): The function to call after the forward pass
+        post_backward (callable or None): The function to call after the backward pass
+
+    Returns:
+        The output of the forward pass
+    """
+    # Calls fine_grained_schedule.py::ModelChunkSchedulePlan.forward_backward(),
+    # which calls fine_grained_schedule.py::schedule_chunk_1f1b()
     return type(f_schedule_plan or b_schedule_plan).forward_backward(
         f_schedule_plan,
         b_schedule_plan,
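The body above dispatches through `type(f_schedule_plan or b_schedule_plan).forward_backward(...)`. A minimal sketch (toy class, not the commit's code) of why that expression still resolves when one of the two plans is None, as happens in warmup and cooldown steps:

    # Hedged sketch: `type(f_plan or b_plan)` picks the class of whichever plan
    # exists, so its classmethod can be reached even with a single plan.
    class DemoSchedulePlan:
        def __init__(self, tag):
            self.tag = tag

        @classmethod
        def forward_backward(cls, f_plan, b_plan):
            fwd = f"forward({f_plan.tag})" if f_plan is not None else "no-forward"
            bwd = f"backward({b_plan.tag})" if b_plan is not None else "no-backward"
            return fwd, bwd


    f_plan = DemoSchedulePlan("mb0")
    b_plan = None
    print(type(f_plan or b_plan).forward_backward(f_plan, b_plan))
    # ('forward(mb0)', 'no-backward')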
@@ -216,7 +247,7 @@ _COM_STREAM = None
 def set_streams(comp_stream=None, com_stream=None):
-    """set the streams for communication and computation"""
+    """Set the streams for communication and computation"""
     global _COMP_STREAM
     global _COM_STREAM
     if _COMP_STREAM is not None:
@@ -234,19 +265,19 @@ def set_streams(comp_stream=None, com_stream=None):
 def get_comp_stream():
-    """get the stream for computation"""
+    """Get the stream for computation"""
     global _COMP_STREAM
     return _COMP_STREAM


 def get_com_stream():
-    """get the stream for communication"""
+    """Get the stream for communication"""
     global _COM_STREAM
     return _COM_STREAM


 class VppContextManager:
-    """a reusable context manager for switch vpp stage"""
+    """A reusable context manager for switch vpp stage"""

     def __init__(self, vpp_rank):
         self.vpp_rank = vpp_rank
@@ -353,9 +384,17 @@ def forward_backward_step(
         Tensor or list[Tensor]: The output object(s) from the forward step.
         Tensor: The number of tokens.
     """
+    assert (
+        checkpoint_activations_microbatch is None
+    ), "checkpoint_activations_microbatch is not supported for combined_1f1b"
+    if config.combined_1f1b_recipe != "ep_a2a":
+        raise NotImplementedError(
+            f"combined_1f1b_recipe {config.combined_1f1b_recipe} not supported yet"
+        )

     from megatron.core.pipeline_parallel.schedules import set_current_microbatch

-    if config.timers is not None:
+    if f_model is not None and config.timers is not None:
         config.timers('forward-compute', log_level=2).start()

     if config.enable_autocast:
@@ -364,6 +403,8 @@ def forward_backward_step(
context_manager
=
contextlib
.
nullcontext
()
# forward preprocess
unwrap_output_tensor
=
False
f_schedule_plan
=
None
if
f_model
is
not
None
:
with
f_context
:
if
is_first_microbatch
and
hasattr
(
f_model
,
'set_is_first_microbatch'
):
...
...
@@ -371,7 +412,6 @@ def forward_backward_step(
if
current_microbatch
is
not
None
:
set_current_microbatch
(
f_model
,
current_microbatch
)
unwrap_output_tensor
=
False
if
not
isinstance
(
input_tensor
,
list
):
input_tensor
=
[
input_tensor
]
unwrap_output_tensor
=
True
...
...
@@ -381,20 +421,20 @@ def forward_backward_step(
         with context_manager:
             if checkpoint_activations_microbatch is None:
-                output_tensor, loss_func = forward_step_func(data_iterator, f_model)
+                f_schedule_plan, loss_func = forward_step_func(data_iterator, f_model)
             else:
-                output_tensor, loss_func = forward_step_func(
+                f_schedule_plan, loss_func = forward_step_func(
                     data_iterator, f_model, checkpoint_activations_microbatch
                 )

         assert isinstance(
-            output_tensor, AbstractSchedulePlan
+            f_schedule_plan, AbstractSchedulePlan
         ), "first output of forward_step_func must be one instance of AbstractSchedulePlan"

     # backward preprocess
     unwrap_input_tensor_grad = False
+    b_schedule_plan = None
     if b_model is not None:
         # Retain the grad on the input_tensor.
-        unwrap_input_tensor_grad = False
         if not isinstance(b_input_tensor, list):
             b_input_tensor = [b_input_tensor]
             unwrap_input_tensor_grad = True
@@ -418,9 +458,8 @@ def forward_backward_step(
         torch.autograd.backward(b_output_tensor[0], grad_tensors=b_output_tensor_grad[0])
         b_output_tensor_grad[0] = loss_node.get_grad()

-    f_schedule_plan = output_tensor if f_model else None
     grad = b_output_tensor_grad[0] if b_model else None
-    with context_manager:
+    with context_manager:  # autocast context
         # schedule forward and backward
         output_tensor = schedule_chunk_1f1b(
             f_schedule_plan,
@@ -436,7 +475,7 @@ def forward_backward_step(
     # forward post process
     num_tokens = None
-    if f_model:
+    if f_model is not None:
         with f_context:
             model_vp_stage = getattr(f_model, "vp_stage", None)
             if vp_stage is not None and model_vp_stage is not None:
@@ -535,7 +574,18 @@ def forward_backward_step(
     return output_tensor, num_tokens, input_tensor_grad


 def get_default_cls_for_unwrap():
+    """Returns the default classes to unwrap from a model.
+
+    This function provides a tuple of classes that should be unwrapped from a model
+    to access the underlying GPTModel instance. It includes DistributedDataParallel
+    and Float16Module by default, and also attempts to include LegacyFloat16Module
+    if available for backward compatibility.
+
+    Returns:
+        tuple: A tuple of classes to unwrap from a model.
+    """
     cls = (DistributedDataParallel, Float16Module)
     try:
         # legacy should not be used in core, but for backward compatibility, we support it here
@@ -547,7 +597,9 @@ def get_default_cls_for_unwrap():
 def unwrap_model(model, module_instances=get_default_cls_for_unwrap()):
-    """unwrap_model DistributedDataParallel and Float16Module wrapped model"""
+    """Unwrap_model DistributedDataParallel and Float16Module wrapped model
+    to return GPTModel instance
+    """
     return_list = True
     if not isinstance(model, list):
         model = [model]
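A hedged usage sketch of the unwrap pattern above (toy wrapper classes, not Megatron's): wrappers are peeled off through their `.module` attribute until the inner model that provides build_schedule_plan remains.

    # Toy stand-ins only; the real code unwraps DistributedDataParallel / Float16Module.
    class ToyWrapper:
        def __init__(self, module):
            self.module = module


    class ToyGPT:
        def build_schedule_plan(self, *args, **kwargs):
            return "schedule-plan"


    def unwrap(model, wrapper_types=(ToyWrapper,)):
        while isinstance(model, wrapper_types):
            model = model.module
        return model


    wrapped = ToyWrapper(ToyWrapper(ToyGPT()))
    inner = unwrap(wrapped)
    print(type(inner).__name__, inner.build_schedule_plan())
    # ToyGPT schedule-plan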
@@ -556,19 +608,80 @@ def unwrap_model(model, module_instances=get_default_cls_for_unwrap()):
     for model_module in model:
         while isinstance(model_module, module_instances):
             model_module = model_module.module
+        assert isinstance(model_module, GPTModel), "The final unwrapped model must be a GPTModel instance"
         unwrapped_model.append(model_module)
     if not return_list:
         return unwrapped_model[0]
     return unwrapped_model


-def wrap_forward_func(config, forward_step_func):
-    """wrap the input to forward_step_func, to make forward_step_func return schedule plan"""
+def wrap_forward_func(forward_step_func):
+    """Wrap the input to forward_step_func.
+
+    The wrapped function will return forward_schedule_plan and the loss_function.
+    """

     def wrapped_func(data_iterator, model):
+        # Model is unwrapped to get GPTModel instance.
+        # GPTModel.build_schedule_plan(model_forward_inputs) is called in the forward_step.
+        # The return value becomes (forward_schedule_plan, loss_function),
+        # which is used to be (forward_output_tensor, loss_function).
         return forward_step_func(data_iterator, unwrap_model(model).build_schedule_plan)

-    if config.combined_1f1b and config.combined_1f1b_recipe == "ep_a2a":
-        return wrapped_func
-    else:
-        return forward_step_func
+    return wrapped_func


+class MemoryManagementStrategy:
+    """Base class for memory management strategies.
+
+    Different memory management strategies can be implemented by subclassing this class.
+    These strategies control how tensors are handled in memory during the computation.
+    """
+
+    def handle_inputs(self, inputs, stream):
+        """Process input tensors after computation.
+
+        Args:
+            inputs (tuple): Input tensors that have been used
+            stream (torch.cuda.Stream): Current CUDA stream
+        """
+        pass
+
+    def handle_outputs(self, outputs, stream):
+        """Process output tensors after computation.
+
+        Args:
+            outputs (tuple): Output tensors produced by the computation
+            stream (torch.cuda.Stream): Current CUDA stream
+        """
+        pass
+
+
+class NoOpMemoryStrategy(MemoryManagementStrategy):
+    """Strategy that performs no memory management operations.
+
+    This is the default strategy - it doesn't free any memory.
+    """
+
+    pass
+
+
+class FreeInputsMemoryStrategy(MemoryManagementStrategy):
+    """Strategy that immediately frees input tensors after they are used.
+
+    This strategy is useful for nodes where inputs are no longer needed
+    after computation, helping to reduce memory usage.
+    """
+
+    def handle_inputs(self, inputs, stream):
+        """Free input tensors by resizing their storage to zero.
+
+        Args:
+            inputs (tuple): Input tensors to be freed
+            stream (torch.cuda.Stream): Current CUDA stream
+        """
+        for input in inputs:
+            if input is not None:
+                input.record_stream(stream)
+                input.untyped_storage().resize_(0)
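A minimal sketch (stand-in class, not the commit's code) of what a FreeInputsMemoryStrategy-style handle_inputs does to a consumed input tensor: keep the allocation alive for the consuming stream, then drop its backing storage so the activation memory is released early.

    # Illustrative sketch only; `FreeInputs` and `demo` are hypothetical names.
    import torch


    class FreeInputs:
        def handle_inputs(self, inputs, stream):
            for t in inputs:
                if t is not None:
                    t.record_stream(stream)          # keep the allocation alive for `stream`
                    t.untyped_storage().resize_(0)   # then drop the backing storage


    def demo():
        if not torch.cuda.is_available():
            return
        stream = torch.cuda.current_stream()
        x = torch.randn(1024, 1024, device="cuda")
        y = x * 2                                    # pretend this is the node's forward
        FreeInputs().handle_inputs((x,), stream)
        print(x.untyped_storage().size(), y.shape)   # 0 torch.Size([1024, 1024])


    if __name__ == "__main__":
        demo()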
dcu_megatron/core/transformer/moe/token_dispatcher.py
@@ -94,7 +94,7 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
         self.collect_per_batch_state(state)
         self.apply_per_batch_state(origin_state)

-    def meta_prepare(
+    def dispatch_preprocess(
         self, hidden_states: torch.Tensor, probs: torch.Tensor, routing_map: torch.Tensor
     ):
         self.hidden_shape = hidden_states.shape
@@ -112,9 +112,6 @@ class MoEAlltoAllTokenDispatcher(MegatronCoreMoEAlltoAllTokenDispatcher):
             self.routing_map = pad_routing_map(self.routing_map, pad_multiple)
         tokens_per_expert = self.preprocess(self.routing_map)

-        return tokens_per_expert
-
-    def dispatch_preprocess(
-        self, hidden_states: torch.Tensor, routing_map: torch.Tensor, tokens_per_expert: torch.Tensor
-    ):
         hidden_states = hidden_states.view(-1, self.hidden_shape[-1])

         if self.shared_experts is not None:
             self.shared_experts.pre_forward_comm(hidden_states.view(self.hidden_shape))
"""
# Preprocess: Get the metadata for communication, permutation and computation operations.
# Permutation 1: input to AlltoAll input
tokens_per_expert
=
self
.
meta_prepare
(
hidden_states
,
probs
,
routing_map
)
tokens_per_expert
,
permutated_local_input_tokens
,
permuted_probs
=
self
.
dispatch_preprocess
(
hidden_states
,
routing_map
,
tokens_per_expert
)
tokens_per_expert
,
permutated_local_input_tokens
,
permuted_probs
=
self
.
dispatch_preprocess
(
hidden_states
,
probs
,
routing_map
)
# Perform expert parallel AlltoAll communication
tokens_per_expert
,
global_input_tokens
,
global_probs
=
self
.
dispatch_all_to_all
(
tokens_per_expert
,
permutated_local_input_tokens
,
permuted_probs
)
...
...
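A hedged sketch (toy dispatcher, not Megatron's) of the staged API shape after this change: dispatch_preprocess now also covers the old meta_prepare step, and its outputs feed dispatch_all_to_all directly.

    # Toy stand-in that mirrors the call sequence, not the real MoE dispatcher.
    class ToyDispatcher:
        def dispatch_preprocess(self, hidden_states, probs, routing_map):
            # routing metadata and local permutation now happen in one step
            tokens_per_expert = [len(routing_map)]
            permuted_tokens = list(reversed(hidden_states))
            permuted_probs = list(reversed(probs))
            return tokens_per_expert, permuted_tokens, permuted_probs

        def dispatch_all_to_all(self, tokens_per_expert, permuted_tokens, permuted_probs):
            # placeholder for the expert-parallel AlltoAll exchange
            return tokens_per_expert, permuted_tokens, permuted_probs


    d = ToyDispatcher()
    meta = d.dispatch_preprocess(["t0", "t1"], [0.9, 0.1], ["e1", "e0"])
    print(d.dispatch_all_to_all(*meta))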
dcu_megatron/core/transformer/transformer_layer.py
@@ -193,7 +193,7 @@ def get_transformer_layer_offset(config: TransformerConfig, vp_stage: Optional[i
 class TransformerLayer(MegatronCoreTransformerLayer):

-    def forward(
+    def _submodule_attn_router_forward(
         self,
         hidden_states: Tensor,
         attention_mask: Optional[Tensor] = None,
@@ -209,275 +209,214 @@ class TransformerLayer(MegatronCoreTransformerLayer):
         *,
         inference_params: Optional[Any] = None,
     ):
-        if (
-            not isinstance(self.mlp, MoELayer)
-            or not isinstance(self.mlp.token_dispatcher, MoEAlltoAllTokenDispatcher)
-        ):
-            return super().forward(
-                hidden_states=hidden_states,
-                attention_mask=attention_mask,
-                context=context,
-                context_mask=context_mask,
-                rotary_pos_emb=rotary_pos_emb,
-                rotary_pos_cos=rotary_pos_cos,
-                rotary_pos_sin=rotary_pos_sin,
-                attention_bias=attention_bias,
-                inference_context=inference_context,
-                packed_seq_params=packed_seq_params,
-                sequence_len_offset=sequence_len_offset,
-                inference_params=inference_params,
-            )
-
-        (
-            hidden_states,
-            pre_mlp_layernorm_output,
-            tokens_per_expert,
-            permutated_local_input_tokens,
-            permuted_probs,
-        ) = self._submodule_attention_router_compound_forward(
-            hidden_states,
-            attention_mask,
-            rotary_pos_emb,
-            rotary_pos_cos,
-            rotary_pos_sin,
-            attention_bias,
-            inference_context,
-            packed_seq_params,
-            sequence_len_offset,
-            inference_params=inference_params,
-        )
-        (tokens_per_expert, global_input_tokens, global_probs) = self._submodule_dispatch_forward(
-            tokens_per_expert,
-            permutated_local_input_tokens,
-            permuted_probs,
-        )
-        (expert_output, shared_expert_output, mlp_bias) = self._submodule_moe_forward(
-            tokens_per_expert, global_input_tokens, global_probs, pre_mlp_layernorm_output
-        )
-        expert_output = self._submodule_combine_forward(expert_output)[0]
-        output = self._submodule_post_combine_forward(
-            expert_output, shared_expert_output, mlp_bias, hidden_states
-        )
-        return output, None
-
-    def _submodule_attention_forward(
-        self,
-        hidden_states: Tensor,
-        attention_mask: Optional[Tensor] = None,
-        rotary_pos_emb: Optional[Tensor] = None,
-        rotary_pos_cos: Optional[Tensor] = None,
-        rotary_pos_sin: Optional[Tensor] = None,
-        attention_bias: Optional[Tensor] = None,
-        inference_context: Optional[Any] = None,
-        packed_seq_params: Optional[PackedSeqParams] = None,
-        sequence_len_offset: Optional[Tensor] = None,
-        *,
-        inference_params: Optional[Any] = None,
-    ):
-        # todo
-        inference_context = deprecate_inference_params(inference_context, inference_params)
-
-        # Residual connection.
-        residual = hidden_states
-
-        # Optional Input Layer norm
-        if self.recompute_input_layernorm:
-            self.input_layernorm_checkpoint = tensor_parallel.CheckpointWithoutOutput()
-            input_layernorm_output = self.input_layernorm_checkpoint.checkpoint(
-                self.input_layernorm, hidden_states
-            )
-        else:
-            input_layernorm_output = self.input_layernorm(hidden_states)
-
-        # Self attention.
-        nvtx_range_push(suffix="self_attention")
-        attention_output_with_bias = self.self_attention(
-            input_layernorm_output,
"""
Performs a combined forward pass that includes self-attention and MLP routing logic.
"""
pre_mlp_layernorm_output
,
residual
,
context
=
self
.
_forward_attention
(
hidden_states
=
hidden_states
,
attention_mask
=
attention_mask
,
inference_context
=
inference_context
,
context
=
context
,
context_mask
=
context_mask
,
rotary_pos_emb
=
rotary_pos_emb
,
rotary_pos_cos
=
rotary_pos_cos
,
rotary_pos_sin
=
rotary_pos_sin
,
attention_bias
=
attention_bias
,
inference_context
=
inference_context
,
packed_seq_params
=
packed_seq_params
,
sequence_len_offset
=
sequence_len_offset
,
)
nvtx_range_pop
(
suffix
=
"self_attention"
)
if
self
.
recompute_input_layernorm
:
# discard the output of the input layernorm and register the recompute
# as a gradient hook of attention_output_with_bias[0]
self
.
input_layernorm_checkpoint
.
discard_output_and_register_recompute
(
attention_output_with_bias
[
0
]
)
# TODO: could we move `bias_dropout_add_exec_handler` itself
# inside the module provided in the `bias_dropout_add_spec` module?
nvtx_range_push
(
suffix
=
"self_attn_bda"
)
with
self
.
bias_dropout_add_exec_handler
():
hidden_states
=
self
.
self_attn_bda
(
self
.
training
,
self
.
config
.
bias_dropout_fusion
)(
attention_output_with_bias
,
residual
,
self
.
hidden_dropout
)
nvtx_range_pop
(
suffix
=
"self_attn_bda"
)
return
hidden_states
-    def _submodule_attention_router_compound_forward(
-        self,
-        hidden_states: Tensor,
-        attention_mask: Optional[Tensor] = None,
-        rotary_pos_emb: Optional[Tensor] = None,
-        rotary_pos_cos: Optional[Tensor] = None,
-        rotary_pos_sin: Optional[Tensor] = None,
-        attention_bias: Optional[Tensor] = None,
-        inference_context: Optional[Any] = None,
-        packed_seq_params: Optional[PackedSeqParams] = None,
-        sequence_len_offset: Optional[Tensor] = None,
-        *,
-        inference_params: Optional[Any] = None,
-    ):
-        """
-        Performs a combined forward pass that includes self-attention and MLP routing logic.
-        """
-        hidden_states = self._submodule_attention_forward(
-            hidden_states,
-            attention_mask,
-            rotary_pos_emb,
-            rotary_pos_cos,
-            rotary_pos_sin,
-            attention_bias,
-            inference_context,
-            packed_seq_params,
-            sequence_len_offset,
-            inference_params=inference_params,
-        )
-
-        # Optional Layer norm post the cross-attention.
-        if self.recompute_pre_mlp_layernorm:
-            self.pre_mlp_norm_checkpoint = tensor_parallel.CheckpointWithoutOutput()
-            pre_mlp_layernorm_output = self.pre_mlp_norm_checkpoint.checkpoint(
-                self.pre_mlp_layernorm, hidden_states
-            )
-        else:
-            pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states)
-
         probs, routing_map = self.mlp.router(pre_mlp_layernorm_output)
-        tokens_per_expert = self.mlp.token_dispatcher.meta_prepare(
-            pre_mlp_layernorm_output, probs, routing_map
-        )
         tokens_per_expert, permutated_local_input_tokens, permuted_probs = (
             self.mlp.token_dispatcher.dispatch_preprocess(
-                pre_mlp_layernorm_output, routing_map, tokens_per_expert
+                pre_mlp_layernorm_output, probs, routing_map
             )
         )
-        outputs = [
-            hidden_states,
-            pre_mlp_layernorm_output,
-            tokens_per_expert,
-            permutated_local_input_tokens,
-            permuted_probs,
-        ]
-        return tuple(outputs)
+        return (
+            tokens_per_expert,
+            permutated_local_input_tokens,
+            permuted_probs,
+            pre_mlp_layernorm_output,
+            residual,
+            context,
+        )
+    def _submodule_attn_router_postprocess(
+        self,
+        node,
+        tokens_per_expert,
+        permutated_local_input_tokens,
+        permuted_probs,
+        pre_mlp_layernorm_output,
+        residual,
+        context,
+    ):
+        node.common_state.residual = node.detach(residual)
+        if self.mlp.use_shared_expert:
+            node.common_state.pre_mlp_layernorm_output = node.detach(pre_mlp_layernorm_output)
+        return tokens_per_expert, permutated_local_input_tokens, permuted_probs

-    def _submodule_dispatch_forward(
-        self, tokens_per_expert, permutated_local_input_tokens, permuted_probs
-    ):
-        tokens_per_expert, global_input_tokens, global_probs = (
-            self.mlp.token_dispatcher.dispatch_all_to_all(
-                tokens_per_expert, permutated_local_input_tokens, permuted_probs
-            )
-        )
-        return [tokens_per_expert, global_input_tokens, global_probs]
+    def _submodule_dispatch_forward(
+        self, tokens_per_expert, permutated_local_input_tokens, permuted_probs, state=None
+    ):
+        """
+        Dispatches tokens to the appropriate experts based on the router output.
+        """
+        token_dispatcher = self.mlp.token_dispatcher
+        tokens_per_expert, global_input_tokens, global_probs = token_dispatcher.dispatch_all_to_all(
+            tokens_per_expert, permutated_local_input_tokens, permuted_probs
+        )
+        return tokens_per_expert, global_input_tokens, global_probs
+
+    def _submodule_dense_forward(self, hidden_states):
+        residual = hidden_states
+        pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states)
+        mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output)
+        with self.bias_dropout_add_exec_handler():
+            hidden_states = self.mlp_bda(self.training, self.config.bias_dropout_fusion)(
+                mlp_output_with_bias, residual, self.hidden_dropout
+            )
+        output = make_viewless_tensor(
+            inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True
+        )
+        return output
+
+    def _submodule_dispatch_postprocess(self, node, tokens_per_expert, global_input_tokens, global_probs):
+        return tokens_per_expert, global_input_tokens, global_probs
-    def _submodule_moe_forward(
-        self, tokens_per_expert, global_input_tokens, global_probs, pre_mlp_layernorm_output
-    ):
+    def _submodule_moe_forward(self, tokens_per_expert, global_input_tokens, global_probs, state=None):
+        """
+        Performs a forward pass for the MLP submodule, including both expert-based
+        and optional shared-expert computations.
+        """
         shared_expert_output = None
-        (dispatched_input, tokens_per_expert, permuted_probs) = (
-            self.mlp.token_dispatcher.dispatch_postprocess(
-                tokens_per_expert, global_input_tokens, global_probs
-            )
-        )
-        expert_output, mlp_bias = self.mlp.experts(
-            dispatched_input, tokens_per_expert, permuted_probs
-        )
-        expert_output = self.mlp.token_dispatcher.combine_preprocess(expert_output)
+        token_dispatcher = self.mlp.token_dispatcher
+        dispatched_input, tokens_per_expert, permuted_probs = token_dispatcher.dispatch_postprocess(
+            tokens_per_expert, global_input_tokens, global_probs
+        )
+        expert_output, mlp_bias = self.mlp.experts(
+            dispatched_tokens, tokens_per_expert, permuted_probs
+        )
+        assert mlp_bias is None, f"Bias is not supported in {token_dispatcher.__class__.__name__}"
         if self.mlp.use_shared_expert and not self.mlp.shared_expert_overlap:
             # if shared_expert_overlap is True, the expert calculation happens in
             # the token_dispatcher to overlap communications and computations
-            shared_expert_output = self.mlp.shared_experts(pre_mlp_layernorm_output)
-        return expert_output, shared_expert_output, mlp_bias
+            assert state is not None
+            shared_expert_output = self.mlp.shared_experts(state.pre_mlp_layernorm_output)

-    def _submodule_combine_forward(self, hidden_states):
-        return [self.mlp.token_dispatcher.combine_all_to_all(hidden_states)]
+        expert_output = self.mlp.token_dispatcher.combine_preprocess(expert_output)
+        return expert_output, shared_expert_output, mlp_bias
-    def _submodule_post_combine_forward(
-        self, expert_output, shared_expert_output, mlp_bias, residual
-    ):
-        """
-        Re-combines the expert outputs (and optional shared_expert_output) into the same order
-        as the original input tokens, applying any required bias.
-        """
-        output = self.mlp.token_dispatcher.combine_postprocess(expert_output)
+    def _submodule_mlp_postprocess(self, node, expert_output, shared_expert_output, mlp_bias):
+        assert mlp_bias is None
+        node.common_state.pre_mlp_layernorm_output = None
+        if shared_expert_output is None:
+            return expert_output
+        return expert_output, shared_expert_output
+
+    def _submodule_combine_forward(self, expert_output, shared_expert_output=None, state=None):
+        residual = state.residual
+        token_dispatcher = self.mlp.token_dispatcher
+        permutated_local_input_tokens = token_dispatcher.combine_all_to_all(expert_output)
+        output = token_dispatcher.combine_postprocess(permutated_local_input_tokens)
         if shared_expert_output is not None:
-            output = output + shared_expert_output
-        mlp_output_with_bias = (output, mlp_bias)
-        if self.recompute_pre_mlp_layernorm:
-            # discard the output of the pre-mlp layernorm and register the recompute
-            # as a gradient hook of mlp_output_with_bias[0]
-            self.pre_mlp_norm_checkpoint.discard_output_and_register_recompute(
-                mlp_output_with_bias[0]
-            )
-        # TODO: could we move `bias_dropout_add_exec_handler` itself
-        # inside the module provided in the `bias_dropout_add_spec` module?
-        nvtx_range_push(suffix="mlp_bda")
+            output += shared_expert_output
+        mlp_output_with_bias = (output, None)
         with self.bias_dropout_add_exec_handler():
             hidden_states = self.mlp_bda(self.training, self.config.bias_dropout_fusion)(
                 mlp_output_with_bias, residual, self.hidden_dropout
             )
-        nvtx_range_pop(suffix="mlp_bda")
         # Jit compiled function creates 'view' tensor. This tensor
         # potentially gets saved in the MPU checkpoint function context,
         # which rejects view tensors. While making a viewless tensor here
         # won't result in memory savings (like the data loader, or
         # p2p_communication), it serves to document the origin of this
         # 'view' tensor.
         output = make_viewless_tensor(
             inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True
         )
         return output
-    def _submodule_attention_dw(self):
-        self.self_attention.backward_dw()
+    def _submodule_combine_postprocess(self, node, output):
+        cur_stream = torch.cuda.current_stream()
+        node.common_state.residual.record_stream(cur_stream)
+        node.common_state.residual = None
+        return output

-    def _submodule_attention_router_compound_dw(self):
-        self._submodule_attention_dw()
+    def _submodule_attn_router_dw(self):
+        self.self_attention.backward_dw()

     def _submodule_mlp_dw(self):
         self.mlp.backward_dw()

+    def _submodule_attn_postprocess(self, node, pre_mlp_layernorm_output, residual, context):
+        return pre_mlp_layernorm_output, residual
+
+    def _submodule_dense_postprocess(self, node, hidden_states):
+        return hidden_states
+
+    def _submodule_not_implemented(self, *args):
+        raise NotImplementedError("This callable is not implemented.")
+    def get_submodule_callables(self, chunk_state):
+        """
+        The forward callables take 2 parts of inputs:
+        1. The ScheduleNode object.
+        2. The input tensors.
+        """
+        from megatron.core.transformer.moe.moe_layer import MoELayer
+        from megatron.core.transformer.moe.token_dispatcher import MoEFlexTokenDispatcher
+
+        self.is_moe = isinstance(self.mlp, MoELayer)
+        self.is_deepep = False
+        if self.is_moe:
+            self.is_deepep = isinstance(self.mlp.token_dispatcher, MoEFlexTokenDispatcher)
+
+        def get_func_with_default(func, default_func):
+            if self.is_moe:
+                return func
+            return default_func
+
+        def callable_wrapper(forward_func, postprocess_func, node, *args):
+            state = getattr(node, 'common_state', None)
+            callable_outputs = forward_func(*args, state=state)
+            if isinstance(callable_outputs, tuple):
+                outputs = postprocess_func(node, *callable_outputs)
+            else:
+                outputs = postprocess_func(node, callable_outputs)
+            return outputs
+
+        attn_func = get_func_with_default(self._submodule_attn_router_forward, self._forward_attention)
+
+        def attn_wrapper(hidden_states, state=None):
+            """
+            state (Any, optional): Placeholder for submodule callable wrapper.
+            """
+            return attn_func(
+                hidden_states=hidden_states,
+                attention_mask=chunk_state.attention_mask,
+                content=chunk_state.context,
+                context_mask=chunk_state.context_mask,
+                rotary_pos_emb=chunk_state.rotary_pos_emb,
+                rotary_pos_cos=chunk_state.rotary_pos_cos,
+                rotary_pos_sin=chunk_state.rotary_pos_sin,
+                attention_bias=chunk_state.attention_bias,
+                inference_context=chunk_state.inference_context,
+                packed_seq_params=chunk_state.packed_seq_params,
+                sequence_len_offset=chunk_state.sequence_len_offset,
+                inference_params=chunk_state.inference_params,
+            )
+
+        attn_postprocess_func = get_func_with_default(
+            self._submodule_attn_router_postprocess, self._submodule_attn_postprocess
+        )
+        dispatch_func = get_func_with_default(
+            self._submodule_dispatch_forward, self._submodule_not_implemented
+        )
+        dispatch_postprocess_func = get_func_with_default(
+            self._submodule_dispatch_postprocess, self._submodule_not_implemented
+        )
+        mlp_func = get_func_with_default(self._submodule_moe_forward, self._forward_mlp)
+        mlp_postprocess_func = get_func_with_default(
+            self._submodule_mlp_postprocess, self._submodule_dense_postprocess
+        )
+        combine_func = get_func_with_default(
+            self._submodule_combine_forward, self._submodule_not_implemented
+        )
+        combine_postprocess_func = get_func_with_default(
+            self._submodule_combine_postprocess, self._submodule_not_implemented
+        )
+
+        attn_forward = partial(callable_wrapper, attn_wrapper, attn_postprocess_func)
+        dispatch_forward = partial(callable_wrapper, dispatch_func, dispatch_postprocess_func)
+        mlp_forward = partial(callable_wrapper, mlp_func, mlp_postprocess_func)
+        combine_forward = partial(callable_wrapper, combine_func, combine_postprocess_func)
+
+        callables = TransformerLayerSubmoduleCallables(
+            attention=SubmoduleCallables(forward=attn_forward, dw=self._submodule_attn_router_dw),
+            dispatch=SubmoduleCallables(forward=dispatch_forward),
+            mlp=SubmoduleCallables(forward=mlp_forward, dw=self._submodule_mlp_dw),
+            combine=SubmoduleCallables(forward=combine_forward),
+            is_moe=self.is_moe,
+            is_deepep=self.is_deepep,
+        )
+        return callables
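A hedged sketch (toy code, not the commit's) of the callable_wrapper pattern used by get_submodule_callables: each submodule forward is paired with a postprocess step that receives the schedule node, so per-node state such as the residual can be stashed on node.common_state between submodules.

    from functools import partial
    from types import SimpleNamespace


    def callable_wrapper(forward_func, postprocess_func, node, *args):
        # Mirrors the shape of the wrapper above: fetch per-node state, run the
        # submodule forward, then hand node plus outputs to the postprocess step.
        state = getattr(node, "common_state", None)
        outputs = forward_func(*args, state=state)
        if isinstance(outputs, tuple):
            return postprocess_func(node, *outputs)
        return postprocess_func(node, outputs)


    def attn_forward(hidden, state=None):
        return hidden + "->attn", "residual"


    def attn_postprocess(node, out, residual):
        node.common_state.residual = residual   # stash per-node state
        return out


    node = SimpleNamespace(common_state=SimpleNamespace(residual=None))
    attn = partial(callable_wrapper, attn_forward, attn_postprocess)
    print(attn(node, "x"), node.common_state.residual)
    # x->attn residual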