Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
evt_fugx1
dcu_megatron
Commits
7c9dc3ec
"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "40e7698a3aac4079033937f4b385eba32fc97065"
Commit
7c9dc3ec
authored
May 07, 2025
by
dongcl
Browse files
forward_backward_pipelining_without_interleaving supports a2a_overlap
parent
649bfbdb
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
433 additions
and
1167 deletions
+433
-1167
dcu_megatron/adaptor/megatron_adaptor.py
dcu_megatron/adaptor/megatron_adaptor.py
+26
-11
dcu_megatron/core/models/gpt/gpt_layer_specs.py
dcu_megatron/core/models/gpt/gpt_layer_specs.py
+10
-8
dcu_megatron/core/models/gpt/gpt_model.py
dcu_megatron/core/models/gpt/gpt_model.py
+9
-10
dcu_megatron/core/pipeline_parallel/schedules.py
dcu_megatron/core/pipeline_parallel/schedules.py
+362
-1133
dcu_megatron/core/tensor_parallel/__init__.py
dcu_megatron/core/tensor_parallel/__init__.py
+0
-4
dcu_megatron/core/tensor_parallel/layers.py
dcu_megatron/core/tensor_parallel/layers.py
+5
-1
dcu_megatron/core/transformer/transformer_config.py
dcu_megatron/core/transformer/transformer_config.py
+7
-0
dcu_megatron/training/arguments.py
dcu_megatron/training/arguments.py
+14
-0
No files found.
dcu_megatron/adaptor/megatron_adaptor.py
View file @
7c9dc3ec
...
@@ -5,6 +5,8 @@ import types
...
@@ -5,6 +5,8 @@ import types
import
argparse
import
argparse
import
torch
import
torch
from
megatron.core.utils
import
is_te_min_version
class
MegatronAdaptation
:
class
MegatronAdaptation
:
"""
"""
...
@@ -89,14 +91,14 @@ class CoreAdaptation(MegatronAdaptationABC):
...
@@ -89,14 +91,14 @@ class CoreAdaptation(MegatronAdaptationABC):
pass
pass
def
patch_core_models
(
self
):
def
patch_core_models
(
self
):
from
..core.models.gpt.gpt_model
import
gpt_model_init_wrapper
,
gpt_model_forward
from
..core.models.gpt.gpt_model
import
gpt_model_init_wrapper
,
GPTModel
# GPT Model
# GPT Model
MegatronAdaptation
.
register
(
'megatron.core.models.gpt.gpt_model.GPTModel.__init__'
,
MegatronAdaptation
.
register
(
'megatron.core.models.gpt.gpt_model.GPTModel.__init__'
,
gpt_model_init_wrapper
,
gpt_model_init_wrapper
,
apply_wrapper
=
True
)
apply_wrapper
=
True
)
MegatronAdaptation
.
register
(
'megatron.core.models.gpt.gpt_model.GPTModel
.forward
'
,
MegatronAdaptation
.
register
(
'megatron.core.models.gpt.gpt_model.GPTModel'
,
gpt_model_forward
)
GPTModel
)
def
patch_core_transformers
(
self
):
def
patch_core_transformers
(
self
):
from
..core
import
transformer_block_init_wrapper
from
..core
import
transformer_block_init_wrapper
...
@@ -116,9 +118,9 @@ class CoreAdaptation(MegatronAdaptationABC):
...
@@ -116,9 +118,9 @@ class CoreAdaptation(MegatronAdaptationABC):
MegatronAdaptation
.
register
(
'megatron.core.transformer.moe.moe_utils.topk_softmax_with_capacity'
,
MegatronAdaptation
.
register
(
'megatron.core.transformer.moe.moe_utils.topk_softmax_with_capacity'
,
torch
.
compile
(
options
=
{
"triton.cudagraphs"
:
True
,
"triton.cudagraph_trees"
:
False
}),
torch
.
compile
(
options
=
{
"triton.cudagraphs"
:
True
,
"triton.cudagraph_trees"
:
False
}),
apply_wrapper
=
True
)
apply_wrapper
=
True
)
#
MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.switch_load_balancing_loss_func',
MegatronAdaptation
.
register
(
'megatron.core.transformer.moe.moe_utils.switch_load_balancing_loss_func'
,
#
torch.compile(options={"triton.cudagraphs": True, "triton.cudagraph_trees": False, "triton.cudagraph_support_input_mutation":True}),
torch
.
compile
(
options
=
{
"triton.cudagraphs"
:
True
,
"triton.cudagraph_trees"
:
False
,
"triton.cudagraph_support_input_mutation"
:
True
}),
#
apply_wrapper=True)
apply_wrapper
=
True
)
MegatronAdaptation
.
register
(
'megatron.core.transformer.moe.moe_utils.permute'
,
MegatronAdaptation
.
register
(
'megatron.core.transformer.moe.moe_utils.permute'
,
torch
.
compile
(
mode
=
'max-autotune-no-cudagraphs'
),
torch
.
compile
(
mode
=
'max-autotune-no-cudagraphs'
),
apply_wrapper
=
True
)
apply_wrapper
=
True
)
...
@@ -132,12 +134,25 @@ class CoreAdaptation(MegatronAdaptationABC):
...
@@ -132,12 +134,25 @@ class CoreAdaptation(MegatronAdaptationABC):
from
..core.extensions.transformer_engine
import
TEDotProductAttentionPatch
from
..core.extensions.transformer_engine
import
TEDotProductAttentionPatch
from
megatron.core.extensions.transformer_engine
import
TEGroupedLinear
from
megatron.core.extensions.transformer_engine
import
TEGroupedLinear
# kv channels, te_min_version 1.10.0 -> 1.9.0
if
not
is_te_min_version
(
"1.10.0"
):
MegatronAdaptation
.
register
(
'megatron.core.extensions.transformer_engine.TEDotProductAttention.__init__'
,
# kv channels, te_min_version 1.10.0 -> 1.9.0
TEDotProductAttentionPatch
.
__init__
)
MegatronAdaptation
.
register
(
'megatron.core.extensions.transformer_engine.TEDotProductAttention.__init__'
,
TEDotProductAttentionPatch
.
__init__
)
if
int
(
os
.
getenv
(
"GROUPED_GEMM_BatchLinear"
,
'0'
)):
if
int
(
os
.
getenv
(
"GROUPED_GEMM_BatchLinear"
,
'0'
)):
TEGroupedLinear
.
__bases__
=
(
te
.
pytorch
.
BatchLinear
,)
TEGroupedLinear
.
__bases__
=
(
te
.
pytorch
.
BatchedLinear
if
is_te_min_version
(
"2.3.0.dev0"
)
else
te
.
pytorch
.
BatchLinear
,)
def
patch_pipeline_parallel
(
self
):
from
..core.pipeline_parallel.schedules
import
get_pp_rank_microbatches
,
forward_backward_pipelining_with_interleaving
# num_warmup_microbatches + 1
MegatronAdaptation
.
register
(
'megatron.core.pipeline_parallel.schedules.get_pp_rank_microbatches'
,
get_pp_rank_microbatches
)
# a2a_overlap
MegatronAdaptation
.
register
(
'megatron.core.pipeline_parallel.schedules.forward_backward_pipelining_with_interleaving'
,
forward_backward_pipelining_with_interleaving
)
def
patch_tensor_parallel
(
self
):
def
patch_tensor_parallel
(
self
):
from
..core.tensor_parallel.cross_entropy
import
VocabParallelCrossEntropy
from
..core.tensor_parallel.cross_entropy
import
VocabParallelCrossEntropy
...
@@ -162,7 +177,7 @@ class CoreAdaptation(MegatronAdaptationABC):
...
@@ -162,7 +177,7 @@ class CoreAdaptation(MegatronAdaptationABC):
# flux
# flux
if
int
(
os
.
getenv
(
"USE_FLUX_OVERLAP"
,
"0"
)):
if
int
(
os
.
getenv
(
"USE_FLUX_OVERLAP"
,
"0"
)):
from
..core.tensor_parallel
import
(
from
..core.tensor_parallel
.layers
import
(
FluxColumnParallelLinear
,
FluxColumnParallelLinear
,
FluxRowParallelLinear
FluxRowParallelLinear
)
)
...
...
dcu_megatron/core/models/gpt/gpt_layer_specs.py
View file @
7c9dc3ec
...
@@ -12,6 +12,7 @@ from megatron.core.transformer.multi_latent_attention import (
...
@@ -12,6 +12,7 @@ from megatron.core.transformer.multi_latent_attention import (
MLASelfAttentionSubmodules
,
MLASelfAttentionSubmodules
,
)
)
from
megatron.core.transformer.spec_utils
import
ModuleSpec
from
megatron.core.transformer.spec_utils
import
ModuleSpec
from
megatron.core.transformer.torch_norm
import
L2Norm
from
megatron.core.transformer.transformer_block
import
TransformerBlockSubmodules
from
megatron.core.transformer.transformer_block
import
TransformerBlockSubmodules
from
megatron.core.transformer.transformer_config
import
TransformerConfig
from
megatron.core.transformer.transformer_config
import
TransformerConfig
from
megatron.core.transformer.transformer_layer
import
(
from
megatron.core.transformer.transformer_layer
import
(
...
@@ -40,12 +41,6 @@ from dcu_megatron.core.tensor_parallel.layers import (
...
@@ -40,12 +41,6 @@ from dcu_megatron.core.tensor_parallel.layers import (
FluxColumnParallelLinear
,
FluxColumnParallelLinear
,
FluxRowParallelLinear
FluxRowParallelLinear
)
)
from
dcu_megatron.core.transformer.multi_token_prediction
import
(
MultiTokenPredictionBlockSubmodules
,
get_mtp_layer_offset
,
get_mtp_layer_spec
,
get_mtp_num_layers_to_build
,
)
def
get_gpt_layer_with_flux_spec
(
def
get_gpt_layer_with_flux_spec
(
...
@@ -55,6 +50,7 @@ def get_gpt_layer_with_flux_spec(
...
@@ -55,6 +50,7 @@ def get_gpt_layer_with_flux_spec(
multi_latent_attention
:
Optional
[
bool
]
=
False
,
multi_latent_attention
:
Optional
[
bool
]
=
False
,
fp8
:
Optional
[
str
]
=
None
,
# pylint: disable=unused-arguments
fp8
:
Optional
[
str
]
=
None
,
# pylint: disable=unused-arguments
moe_use_legacy_grouped_gemm
:
Optional
[
bool
]
=
False
,
moe_use_legacy_grouped_gemm
:
Optional
[
bool
]
=
False
,
qk_l2_norm
:
Optional
[
bool
]
=
False
,
)
->
ModuleSpec
:
)
->
ModuleSpec
:
"""Use this spec to use flux modules (required for fp8 training).
"""Use this spec to use flux modules (required for fp8 training).
...
@@ -66,6 +62,7 @@ def get_gpt_layer_with_flux_spec(
...
@@ -66,6 +62,7 @@ def get_gpt_layer_with_flux_spec(
fp8 (str, optional): Deprecated. For temporary Nemo compatibility.
fp8 (str, optional): Deprecated. For temporary Nemo compatibility.
moe_use_legacy_grouped_gemm (bool, optional): Force use the legacy GroupedMLP.
moe_use_legacy_grouped_gemm (bool, optional): Force use the legacy GroupedMLP.
Defaults to False.
Defaults to False.
qk_l2_norm (bool, optional): To use l2 norm for queries/keys. Defaults to False.
Returns:
Returns:
ModuleSpec: Module specification with flux modules
ModuleSpec: Module specification with flux modules
...
@@ -84,6 +81,7 @@ def get_gpt_layer_with_flux_spec(
...
@@ -84,6 +81,7 @@ def get_gpt_layer_with_flux_spec(
)
)
if
multi_latent_attention
:
if
multi_latent_attention
:
assert
qk_l2_norm
is
False
,
"qk_l2_norm is not supported with MLA."
return
ModuleSpec
(
return
ModuleSpec
(
module
=
TransformerLayer
,
module
=
TransformerLayer
,
submodules
=
TransformerLayerSubmodules
(
submodules
=
TransformerLayerSubmodules
(
...
@@ -127,8 +125,12 @@ def get_gpt_layer_with_flux_spec(
...
@@ -127,8 +125,12 @@ def get_gpt_layer_with_flux_spec(
linear_qkv
=
FluxColumnParallelLinear
,
linear_qkv
=
FluxColumnParallelLinear
,
core_attention
=
TEDotProductAttention
,
core_attention
=
TEDotProductAttention
,
linear_proj
=
FluxRowParallelLinear
,
linear_proj
=
FluxRowParallelLinear
,
q_layernorm
=
qk_norm
if
qk_layernorm
else
IdentityOp
,
q_layernorm
=
(
k_layernorm
=
qk_norm
if
qk_layernorm
else
IdentityOp
,
L2Norm
if
qk_l2_norm
else
(
qk_norm
if
qk_layernorm
else
IdentityOp
)
),
k_layernorm
=
(
L2Norm
if
qk_l2_norm
else
(
qk_norm
if
qk_layernorm
else
IdentityOp
)
),
),
),
),
),
self_attn_bda
=
get_bias_dropout_add
,
self_attn_bda
=
get_bias_dropout_add
,
...
...
dcu_megatron/core/models/gpt/gpt_model.py
View file @
7c9dc3ec
...
@@ -13,8 +13,6 @@ from megatron.core.inference.contexts import BaseInferenceContext
...
@@ -13,8 +13,6 @@ from megatron.core.inference.contexts import BaseInferenceContext
from
megatron.core.packed_seq_params
import
PackedSeqParams
from
megatron.core.packed_seq_params
import
PackedSeqParams
from
megatron.core.models.gpt
import
GPTModel
as
MegatronCoreGPTModel
from
megatron.core.models.gpt
import
GPTModel
as
MegatronCoreGPTModel
from
dcu_megatron.core.tensor_parallel
import
FluxColumnParallelLinear
def
gpt_model_init_wrapper
(
fn
):
def
gpt_model_init_wrapper
(
fn
):
@
wraps
(
fn
)
@
wraps
(
fn
)
...
@@ -22,12 +20,13 @@ def gpt_model_init_wrapper(fn):
...
@@ -22,12 +20,13 @@ def gpt_model_init_wrapper(fn):
fn
(
self
,
*
args
,
**
kwargs
)
fn
(
self
,
*
args
,
**
kwargs
)
# Output
# Output
if
self
.
post_process
or
self
.
mtp_process
:
if
(
if
int
(
os
.
getenv
(
"USE_FLUX_OVERLAP"
,
"0"
)):
(
self
.
post_process
or
self
.
mtp_process
)
parallel_linear_impl
=
FluxColumnParallelLinear
and
int
(
os
.
getenv
(
"USE_FLUX_OVERLAP"
,
"0"
))
else
:
):
parallel_linear_impl
=
tensor_parallel
.
ColumnParallelLinear
from
dcu_megatron.core.tensor_parallel.layers
import
FluxColumnParallelLinear
self
.
output_layer
=
parallel_linear_impl
(
self
.
output_layer
=
FluxColumnParallelLinear
(
self
.
config
.
hidden_size
,
self
.
config
.
hidden_size
,
self
.
vocab_size
,
self
.
vocab_size
,
config
=
self
.
config
,
config
=
self
.
config
,
...
@@ -41,8 +40,8 @@ def gpt_model_init_wrapper(fn):
...
@@ -41,8 +40,8 @@ def gpt_model_init_wrapper(fn):
grad_output_buffer
=
self
.
grad_output_buffer
,
grad_output_buffer
=
self
.
grad_output_buffer
,
)
)
if
self
.
pre_process
or
self
.
post_process
:
if
self
.
pre_process
or
self
.
post_process
:
self
.
setup_embeddings_and_output_layer
()
self
.
setup_embeddings_and_output_layer
()
return
wrapper
return
wrapper
...
...
dcu_megatron/core/pipeline_parallel/schedules.py
View file @
7c9dc3ec
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import
contextlib
import
contextlib
from
typing
import
Callable
,
Iterator
,
List
,
Optional
,
Union
from
typing
import
Callable
,
Iterator
,
List
,
Optional
,
Union
import
torch
import
torch
from
torch.autograd.variable
import
Variable
from
torch.autograd.variable
import
Variable
from
megatron.training
import
get_args
from
megatron.core
import
parallel_state
from
megatron.core
import
parallel_state
from
megatron.core.enums
import
ModelType
from
megatron.core.enums
import
ModelType
from
megatron.core.pipeline_parallel
import
p2p_communication
from
megatron.core.pipeline_parallel
import
p2p_communication
...
@@ -19,574 +18,26 @@ from megatron.core.utils import (
...
@@ -19,574 +18,26 @@ from megatron.core.utils import (
get_model_type
,
get_model_type
,
get_model_xattn
,
get_model_xattn
,
)
)
from
megatron.core.pipeline_parallel.schedules
import
(
forward_step
,
backward_step
,
get_tensor_shapes
,
get_schedule_table
,
check_first_val_step
,
deallocate_output_tensor
,
finish_embedding_wgrad_compute
,
clear_embedding_activation_buffer
,
)
from
.combined_1f1b
import
VppContextManager
,
forward_backward_step
,
set_streams
,
wrap_forward_func
from
.combined_1f1b
import
VppContextManager
,
forward_backward_step
,
set_streams
,
wrap_forward_func
# Types
Shape
=
Union
[
List
[
int
],
torch
.
Size
]
def
get_forward_backward_func
():
"""Retrieves the appropriate forward_backward function given the
configuration of parallel_state.
Returns a function that will perform all of the forward and
backward passes of the model given the pipeline model parallel
world size and virtual pipeline model parallel world size in the
global parallel_state.
Note that if using sequence parallelism, the sequence length component of
the tensor shape is updated to original_sequence_length /
tensor_model_parallel_world_size.
The function returned takes the following arguments:
forward_step_func (required): A function that takes a data
iterator and a model as its arguments and return the model's
forward output and the loss function. The loss function should
take one torch.Tensor and return a torch.Tensor of loss and a
dictionary of string -> torch.Tensor.
A third argument, checkpoint_activations_microbatch, indicates
that the activations for this microbatch should be
checkpointed. A None value for this argument indicates that
the default from the configuration should be used. This is
used when the
num_microbatches_with_partial_activation_checkpoints is used.
For example:
def loss_func(loss_mask, output_tensor):
losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float()
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {'lm loss': averaged_loss[0]}
def forward_step(data_iterator, model):
data, loss_mask = next(data_iterator)
output = model(data)
return output, partial(loss_func, loss_mask)
forward_backward_func(forward_step_func=forward_step, ...)
data_iterator (required): an iterator over the data, will be
passed as is to forward_step_func. Expected to be a list of
iterators in the case of interleaved pipeline parallelism.
model (required): the actual model. Expected to be a list of modules in the case of interleaved
pipeline parallelism. Must be a (potentially wrapped) megatron.core.models.MegatronModule.
num_microbatches (int, required):
The number of microbatches to go through
seq_length (int, required): Sequence length of the current global batch. If this is a dual-stack
transformer, this is the encoder's sequence length. This is ignored if variable_seq_lengths
in the config is True. Otherwise, each microbatch in the current global batch size must use
this sequence length.
micro_batch_size (int, required): The number of sequences in a microbatch.
decoder_seq_length (int, optional): The sequence length for the decoder in a dual-stack
transformer. This is ignored for a single-stack transformer.
forward_only (optional, default = False): Perform only the forward step
collect_non_loss_data (optional, bool, default=False): TODO
first_val_step (bool, optional): Is the first step of the validation phase. Used by
Transformer Engine modules to only update their fp8 weights only on the first validation
step.
adjust_tensor_shapes_fn (Callable, optional): A function that adjusts the receive and send
tensor shapes. Only applicable in forward_backward_pipelining_without_interleaving for now.
Takes in a list of receive shapes and a list of send shapes and returns the adjusted
respective list of shapes. Thus it is not used in the other forward-backward functions
which have different shape handling.
"""
pipeline_model_parallel_size
=
parallel_state
.
get_pipeline_model_parallel_world_size
()
if
pipeline_model_parallel_size
>
1
:
if
parallel_state
.
get_virtual_pipeline_model_parallel_world_size
()
is
not
None
:
forward_backward_func
=
forward_backward_pipelining_with_interleaving
else
:
forward_backward_func
=
forward_backward_pipelining_without_interleaving
else
:
forward_backward_func
=
forward_backward_no_pipelining
return
forward_backward_func
def
deallocate_output_tensor
(
out
,
deallocate_pipeline_outputs
=
False
):
'''Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field.
This method should be called right after the output tensor has been
sent to the next pipeline stage. At this point, the output tensor is
only useful for its '.grad_fn' field, and not its '.data'.
'''
if
(
out
is
None
)
or
(
not
deallocate_pipeline_outputs
):
return
assert
isinstance
(
out
,
torch
.
Tensor
),
"expected Tensor, found %s."
%
type
(
out
).
__name__
assert
out
.
_base
is
None
,
"counter-productive to free a view of another tensor."
out
.
data
=
torch
.
empty
((
1
,),
device
=
out
.
device
,
dtype
=
out
.
dtype
)
def
custom_backward
(
output
,
grad_output
):
'''Directly call C++ autograd engine.
To make the 'deallocate_output_tensor' (above) optimization work, the C++
autograd engine must be called directly, bypassing Pytorch's
torch.autograd.backward. Pytorch's 'backward' checks that the output and
grad have the same shape, while C++'s 'backward' does not.
'''
assert
output
.
numel
()
==
1
,
"output should be pseudo-'freed' in schedule, to optimize memory"
assert
isinstance
(
output
,
torch
.
Tensor
),
"output == '%s'."
%
type
(
output
).
__name__
assert
isinstance
(
grad_output
,
(
torch
.
Tensor
,
type
(
None
))),
(
"grad_output == '%s'."
%
type
(
grad_output
).
__name__
)
# Handle scalar output
if
grad_output
is
None
:
assert
output
.
numel
()
==
1
,
"implicit grad requires scalar output."
grad_output
=
torch
.
ones_like
(
output
,
memory_format
=
torch
.
preserve_format
)
# Call c++ engine [ see torch/csrc/autograd/python_engine.cpp ]
Variable
.
_execution_engine
.
run_backward
(
tensors
=
(
output
,),
grad_tensors
=
(
grad_output
,),
keep_graph
=
False
,
create_graph
=
False
,
inputs
=
tuple
(),
allow_unreachable
=
True
,
accumulate_grad
=
True
,
)
def
set_current_microbatch
(
model
,
microbatch_id
):
"""Set the current microbatch."""
decoder_exists
=
True
decoder
=
None
try
:
decoder
=
get_attr_wrapped_model
(
model
,
"decoder"
)
except
RuntimeError
:
decoder_exists
=
False
if
decoder_exists
and
decoder
is
not
None
:
for
layer
in
decoder
.
layers
:
layer
.
current_microbatch
=
microbatch_id
def
forward_step
(
forward_step_func
,
data_iterator
,
model
,
num_microbatches
,
input_tensor
,
forward_data_store
,
config
,
collect_non_loss_data
=
False
,
checkpoint_activations_microbatch
=
None
,
is_first_microbatch
=
False
,
current_microbatch
=
None
,
encoder_decoder_xattn
=
False
,
):
"""Forward step for passed-in model.
If it is the first stage, the input tensor is obtained from the data_iterator.
Otherwise, the passed-in input_tensor is used.
Args:
forward_step_func (callable):
The forward step function for the model that takes the
data iterator as the first argument, and model as the second.
This user's forward step is expected to output a tuple of two elements:
1. The output object from the forward step. This output object needs to be a
tensor or some kind of collection of tensors. The only hard requirement
for this object is that it needs to be acceptible as input into the second
function.
2. A function to reduce (optionally) the output from the forward step. This
could be a reduction over the loss from the model, it could be a function that
grabs the output from the model and reformats, it could be a function that just
passes through the model output. This function must have one of the following
patterns, and depending on the pattern different things happen internally:
a. A tuple of reduced loss and some other data. Note that in this case
the first argument is divided by the number of global microbatches,
assuming it is a loss, so that the loss is stable as a function of
the number of devices the step is split across.
b. A triple of reduced loss, number of tokens, and some other data. This
is similar to case (a), but the loss is further averaged across the
number of tokens in the batch. If the user is not already averaging
across the number of tokens, this pattern is useful to use.
c. Any arbitrary data the user wants (eg a dictionary of tensors, a list
of tensors, etc in the case of inference). To trigger case 3 you need
to specify `collect_non_loss_data=True` and you may also want to
specify `forward_only=True` in the call to the parent forward_backward
function.
data_iterator (iterator):
The data iterator.
model (nn.Module):
The model to perform the forward step on.
num_microbatches (int):
The number of microbatches.
input_tensor (Tensor or list[Tensor]):
The input tensor(s) for the forward step.
forward_data_store (list):
The list to store the forward data. If you go down path 2.a or
2.b for the return of your forward reduction function then this will store only the
final dimension of the output, for example the metadata output by the loss function.
If you go down the path of 2.c then this will store the entire output of the forward
reduction function applied to the model output.
config (object):
The configuration object.
collect_non_loss_data (bool, optional):
Whether to collect non-loss data. Defaults to False.
This is the path to use if you want to collect arbitrary output from the model forward,
such as with inference use cases. Defaults to False.
checkpoint_activations_microbatch (int, optional):
The microbatch to checkpoint activations.
Defaults to None.
is_first_microbatch (bool, optional):
Whether it is the first microbatch. Defaults to False.
current_microbatch (int, optional):
The current microbatch. Defaults to None.
Returns:
Tensor or list[Tensor]: The output object(s) from the forward step.
Tensor: The number of tokens.
"""
if
config
.
timers
is
not
None
:
config
.
timers
(
'forward-compute'
,
log_level
=
2
).
start
()
if
is_first_microbatch
and
hasattr
(
model
,
'set_is_first_microbatch'
):
model
.
set_is_first_microbatch
()
if
current_microbatch
is
not
None
:
set_current_microbatch
(
model
,
current_microbatch
)
unwrap_output_tensor
=
False
if
not
isinstance
(
input_tensor
,
list
):
input_tensor
=
[
input_tensor
]
unwrap_output_tensor
=
True
set_input_tensor
=
get_attr_wrapped_model
(
model
,
"set_input_tensor"
)
set_input_tensor
(
input_tensor
)
if
config
.
enable_autocast
:
context_manager
=
torch
.
autocast
(
"cuda"
,
dtype
=
config
.
autocast_dtype
)
else
:
context_manager
=
contextlib
.
nullcontext
()
with
context_manager
:
if
checkpoint_activations_microbatch
is
None
:
output_tensor
,
loss_func
=
forward_step_func
(
data_iterator
,
model
)
else
:
output_tensor
,
loss_func
=
forward_step_func
(
data_iterator
,
model
,
checkpoint_activations_microbatch
)
num_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
int
)
if
parallel_state
.
is_pipeline_last_stage
(
ignore_virtual
=
False
):
if
not
collect_non_loss_data
:
outputs
=
loss_func
(
output_tensor
)
if
len
(
outputs
)
==
3
:
output_tensor
,
num_tokens
,
loss_reduced
=
outputs
if
not
config
.
calculate_per_token_loss
:
output_tensor
/=
num_tokens
output_tensor
*=
parallel_state
.
get_context_parallel_world_size
()
output_tensor
/=
num_microbatches
else
:
# preserve legacy loss averaging behavior (ie, over the number of microbatches)
assert
len
(
outputs
)
==
2
output_tensor
,
loss_reduced
=
outputs
output_tensor
*=
parallel_state
.
get_context_parallel_world_size
()
output_tensor
/=
num_microbatches
forward_data_store
.
append
(
loss_reduced
)
else
:
data
=
loss_func
(
output_tensor
,
non_loss_data
=
True
)
forward_data_store
.
append
(
data
)
if
config
.
timers
is
not
None
:
config
.
timers
(
'forward-compute'
).
stop
()
# Set the loss scale for the auxiliary loss of the MoE layer.
# Since we use a trick to do backward on the auxiliary loss, we need to set the scale
# explicitly.
if
hasattr
(
config
,
'num_moe_experts'
)
and
config
.
num_moe_experts
is
not
None
:
# Calculate the loss scale based on the grad_scale_func if available, else default to 1.
loss_scale
=
(
config
.
grad_scale_func
(
torch
.
ones
(
1
,
device
=
output_tensor
.
device
))
if
config
.
grad_scale_func
is
not
None
else
torch
.
ones
(
1
,
device
=
output_tensor
.
device
)
)
# Set the loss scale
if
config
.
calculate_per_token_loss
:
MoEAuxLossAutoScaler
.
set_loss_scale
(
loss_scale
)
else
:
MoEAuxLossAutoScaler
.
set_loss_scale
(
loss_scale
/
num_microbatches
)
# Set the loss scale for Multi-Token Prediction (MTP) loss.
if
hasattr
(
config
,
'mtp_num_layers'
)
and
config
.
mtp_num_layers
is
not
None
:
# Calculate the loss scale based on the grad_scale_func if available, else default to 1.
loss_scale
=
(
config
.
grad_scale_func
(
torch
.
ones
(
1
,
device
=
output_tensor
.
device
))
if
config
.
grad_scale_func
is
not
None
else
torch
.
ones
(
1
,
device
=
output_tensor
.
device
)
)
# Set the loss scale
if
config
.
calculate_per_token_loss
:
MTPLossAutoScaler
.
set_loss_scale
(
loss_scale
)
else
:
MTPLossAutoScaler
.
set_loss_scale
(
loss_scale
/
num_microbatches
)
# If T5 model and in decoder stack, then send encoder_hidden_state
# downstream as well.
model_type
=
get_model_type
(
model
)
if
(
model_type
==
ModelType
.
encoder_and_decoder
and
encoder_decoder_xattn
and
parallel_state
.
is_inside_decoder
()
):
return
[
output_tensor
,
input_tensor
[
-
1
]],
num_tokens
if
unwrap_output_tensor
:
return
output_tensor
,
num_tokens
return
[
output_tensor
],
num_tokens
def
backward_step
(
input_tensor
,
output_tensor
,
output_tensor_grad
,
model_type
,
config
):
"""Backward step through passed-in output tensor.
If last stage, output_tensor_grad is None, otherwise gradient of loss
with respect to stage's output tensor.
Returns gradient of loss with respect to input tensor (None if first
stage)."""
# NOTE: This code currently can handle at most one skip connection. It
# needs to be modified slightly to support arbitrary numbers of skip
# connections.
if
config
.
timers
is
not
None
:
config
.
timers
(
'backward-compute'
,
log_level
=
2
).
start
()
# Retain the grad on the input_tensor.
unwrap_input_tensor_grad
=
False
if
not
isinstance
(
input_tensor
,
list
):
input_tensor
=
[
input_tensor
]
unwrap_input_tensor_grad
=
True
for
x
in
input_tensor
:
if
x
is
not
None
:
x
.
retain_grad
()
if
not
isinstance
(
output_tensor
,
list
):
output_tensor
=
[
output_tensor
]
if
not
isinstance
(
output_tensor_grad
,
list
):
output_tensor_grad
=
[
output_tensor_grad
]
# Backward pass.
if
output_tensor_grad
[
0
]
is
None
and
config
.
grad_scale_func
is
not
None
:
output_tensor
[
0
]
=
config
.
grad_scale_func
(
output_tensor
[
0
])
# In multi-modal models like VLM, some batches may not have images.
# When no image is present, the vision encoder (as a separate pipeline stage)
# will not participate in the computation.
# This results in a tensor that does not require gradients.
# In such cases, we intentionally skip the backward pass while preserving zero gradients.
if
output_tensor
[
0
].
requires_grad
:
if
config
.
deallocate_pipeline_outputs
:
custom_backward
(
output_tensor
[
0
],
output_tensor_grad
[
0
])
else
:
torch
.
autograd
.
backward
(
output_tensor
[
0
],
grad_tensors
=
output_tensor_grad
[
0
])
# Collect the grad of the input_tensor.
input_tensor_grad
=
[
None
]
if
input_tensor
is
not
None
:
input_tensor_grad
=
[]
for
x
in
input_tensor
:
if
x
is
None
:
input_tensor_grad
.
append
(
None
)
else
:
input_tensor_grad
.
append
(
x
.
grad
)
# Handle single skip connection if it exists (encoder_hidden_state in
# model with encoder and decoder).
if
(
parallel_state
.
get_pipeline_model_parallel_world_size
()
>
1
and
model_type
==
ModelType
.
encoder_and_decoder
and
len
(
output_tensor_grad
)
>
1
# excludes models that lack a skip connection.
):
if
output_tensor_grad
[
1
]
is
not
None
:
assert
input_tensor_grad
[
-
1
]
is
not
None
input_tensor_grad
[
-
1
].
add_
(
output_tensor_grad
[
1
])
if
unwrap_input_tensor_grad
:
input_tensor_grad
=
input_tensor_grad
[
0
]
if
config
.
timers
is
not
None
:
config
.
timers
(
'backward-compute'
).
stop
()
return
input_tensor_grad
def
check_first_val_step
(
first_val_step
,
forward_only
,
cond
):
"""Check if it is the first validation step."""
if
(
first_val_step
is
not
None
)
and
forward_only
:
return
first_val_step
and
cond
else
:
return
cond
def
forward_backward_no_pipelining
(
*
,
forward_step_func
,
data_iterator
:
Union
[
Iterator
,
List
[
Iterator
]],
model
:
Union
[
torch
.
nn
.
Module
,
List
[
torch
.
nn
.
Module
]],
num_microbatches
:
int
,
seq_length
:
int
,
# unused
micro_batch_size
:
int
,
# unused
decoder_seq_length
:
Optional
[
int
]
=
None
,
# unused
forward_only
:
bool
=
False
,
collect_non_loss_data
:
bool
=
False
,
first_val_step
:
Optional
[
bool
]
=
None
,
adjust_tensor_shapes_fn
:
Optional
[
Callable
]
=
None
,
# unused
):
"""Run forward and backward passes with no pipeline parallelism
(no inter-stage communication).
Returns dictionary with losses.
See get_forward_backward_func() for argument details
"""
if
isinstance
(
model
,
list
):
assert
len
(
model
)
==
1
,
"non-pipeline-parallel schedule does not support model chunking"
model
=
model
[
0
]
if
isinstance
(
data_iterator
,
list
):
assert
(
len
(
data_iterator
)
==
1
),
"non-pipeline-parallel schedule does not support model chunking"
data_iterator
=
data_iterator
[
0
]
assert
(
adjust_tensor_shapes_fn
is
None
),
"adjust_tensor_shapes_fn is not supported for non-pipeline-parallel schedule"
config
=
get_model_config
(
model
)
if
config
.
timers
is
not
None
:
config
.
timers
(
'forward-backward'
,
log_level
=
1
).
start
(
barrier
=
config
.
barrier_with_L1_time
)
no_sync_func
=
config
.
no_sync_func
if
no_sync_func
is
None
:
no_sync_func
=
contextlib
.
nullcontext
model_type
=
get_model_type
(
model
)
forward_data_store
=
[]
input_tensor
,
output_tensor_grad
=
None
,
None
total_num_tokens
=
torch
.
zeros
([],
dtype
=
torch
.
int
,
device
=
"cuda"
)
with
no_sync_func
():
for
i
in
range
(
num_microbatches
-
1
):
output_tensor
,
num_tokens
=
forward_step
(
forward_step_func
,
data_iterator
,
model
,
num_microbatches
,
input_tensor
,
forward_data_store
,
config
,
collect_non_loss_data
,
is_first_microbatch
=
check_first_val_step
(
first_val_step
,
forward_only
,
i
==
0
),
current_microbatch
=
i
,
)
total_num_tokens
+=
num_tokens
if
not
forward_only
:
backward_step
(
input_tensor
,
output_tensor
,
output_tensor_grad
,
model_type
,
config
)
# Run computation for last microbatch out of context handler (want to
# synchronize gradients).
output_tensor
,
num_tokens
=
forward_step
(
forward_step_func
,
data_iterator
,
model
,
num_microbatches
,
input_tensor
,
forward_data_store
,
config
,
collect_non_loss_data
,
is_first_microbatch
=
check_first_val_step
(
first_val_step
,
forward_only
,
num_microbatches
==
1
),
current_microbatch
=
num_microbatches
-
1
,
)
total_num_tokens
+=
num_tokens
if
not
forward_only
:
backward_step
(
input_tensor
,
output_tensor
,
output_tensor_grad
,
model_type
,
config
)
if
config
.
finalize_model_grads_func
is
not
None
and
not
forward_only
:
# Finalize model grads (perform full grad all-reduce / reduce-scatter for
# data parallelism and layernorm all-reduce for sequence parallelism).
config
.
finalize_model_grads_func
(
[
model
],
total_num_tokens
if
config
.
calculate_per_token_loss
else
None
)
if
config
.
timers
is
not
None
:
config
.
timers
(
'forward-backward'
).
stop
()
if
hasattr
(
config
,
'enable_cuda_graph'
)
and
config
.
enable_cuda_graph
:
create_cudagraphs
()
return
forward_data_store
def
clear_embedding_activation_buffer
(
config
,
model
):
"""Clear embedding activation buffer."""
if
(
parallel_state
.
is_pipeline_last_stage
(
ignore_virtual
=
True
)
and
config
.
defer_embedding_wgrad_compute
):
if
isinstance
(
model
,
list
):
embedding_module
=
get_attr_wrapped_model
(
model
[
-
1
],
'post_process'
,
return_model_obj
=
True
)
else
:
embedding_module
=
get_attr_wrapped_model
(
model
,
'post_process'
,
return_model_obj
=
True
)
# Need to ensure no stray activations exists in this buffer
embedding_module
.
embedding_activation_buffer
.
clear
()
return
embedding_module
else
:
return
None
def
finish_embedding_wgrad_compute
(
config
,
embedding_module
):
"""Finish embedding wgrad compute."""
if
(
parallel_state
.
is_pipeline_last_stage
(
ignore_virtual
=
True
)
and
config
.
defer_embedding_wgrad_compute
):
embedding_activation_buffer
=
embedding_module
.
embedding_activation_buffer
grad_output_buffer
=
embedding_module
.
grad_output_buffer
weight
=
(
embedding_module
.
output_layer
.
weight
if
embedding_module
.
share_embeddings_and_output_weights
else
embedding_module
.
shared_embedding_or_output_weight
()
)
drain_embedding_wgrad_compute
(
config
,
embedding_activation_buffer
,
grad_output_buffer
,
weight
)
def
get_pp_rank_microbatches
(
def
get_pp_rank_microbatches
(
num_microbatches
,
num_model_chunks
,
microbatch_group_size_per_vp_stage
,
forward_only
=
False
num_microbatches
,
num_model_chunks
,
microbatch_group_size_per_vp_stage
,
forward_only
=
False
):
):
"""Get the number of total, warmup, and remaining microbatches in PP scheduling."""
"""Get the number of total, warmup, and remaining microbatches in PP scheduling."""
args
=
get_args
()
pipeline_parallel_size
=
parallel_state
.
get_pipeline_model_parallel_world_size
()
pipeline_parallel_size
=
parallel_state
.
get_pipeline_model_parallel_world_size
()
pipeline_parallel_rank
=
parallel_state
.
get_pipeline_model_parallel_rank
()
pipeline_parallel_rank
=
parallel_state
.
get_pipeline_model_parallel_rank
()
virtual_pipeline_parallel_size
=
parallel_state
.
get_virtual_pipeline_model_parallel_world_size
()
virtual_pipeline_parallel_size
=
parallel_state
.
get_virtual_pipeline_model_parallel_world_size
()
...
@@ -608,6 +59,9 @@ def get_pp_rank_microbatches(
...
@@ -608,6 +59,9 @@ def get_pp_rank_microbatches(
# immediately start with 1F1B).
# immediately start with 1F1B).
num_warmup_microbatches
=
(
pipeline_parallel_size
-
pipeline_parallel_rank
-
1
)
*
2
num_warmup_microbatches
=
(
pipeline_parallel_size
-
pipeline_parallel_rank
-
1
)
*
2
num_warmup_microbatches
+=
(
num_model_chunks
-
1
)
*
microbatch_group_size_per_vp_stage
num_warmup_microbatches
+=
(
num_model_chunks
-
1
)
*
microbatch_group_size_per_vp_stage
if
args
.
combined_1f1b
:
num_warmup_microbatches
=
num_warmup_microbatches
+
1
else
:
else
:
# forward_backward_no_pipelining
# forward_backward_no_pipelining
num_warmup_microbatches
=
1
num_warmup_microbatches
=
1
...
@@ -625,62 +79,6 @@ def get_pp_rank_microbatches(
...
@@ -625,62 +79,6 @@ def get_pp_rank_microbatches(
)
)
def
get_schedule_table
(
num_microbatches
,
num_model_chunks
,
microbatch_group_size_per_vp_stage
):
"""Get the schedule table for PP scheduling."""
schedule_table
=
[]
for
min_microbatch_id_in_group
in
range
(
0
,
num_microbatches
,
microbatch_group_size_per_vp_stage
):
if
min_microbatch_id_in_group
+
microbatch_group_size_per_vp_stage
>=
num_microbatches
:
# Construct schedule for the last microbatch group
schedule_table
.
extend
(
[
(
microbatch_id
,
model_chunk_id
)
for
model_chunk_id
in
range
(
num_model_chunks
)
for
microbatch_id
in
range
(
min_microbatch_id_in_group
,
num_microbatches
)
]
)
else
:
# Construct schedule for other microbatch groups
schedule_table
.
extend
(
[
(
microbatch_id
,
model_chunk_id
)
for
model_chunk_id
in
range
(
num_model_chunks
)
for
microbatch_id
in
range
(
min_microbatch_id_in_group
,
min_microbatch_id_in_group
+
microbatch_group_size_per_vp_stage
,
)
]
)
return
schedule_table
def
convert_schedule_table_to_order
(
num_warmup_microbatches
,
num_model_chunks
,
schedule_table
):
"""Convert a tunable schedule lookup table to the te.make_graphed_callables() accepted
order format. For example, the tunable schedule table for PP2 N3M5 with VP2 is as below:
virtual_microbatch_id | 0 1 2 3 4 5 6 7 8 9
microbatch_id | 0 1 2 0 1 2 3 4 3 4
model_chunk_id | 0 0 0 1 1 1 0 0 1 1
Then the forward backward separated order is:
forward | 1 1 1 2 2 2 1 1 2 2
backward | -2 -2 -2 -1 -1 -1 -2 -2 -1 -1
If num_warmup_microbatches is 5, the output order is:
1 1 1 2 2 2 -2 1 -2 1 -2 2 -1 2 -1 -1 -2 -2 -1 -1
"""
_
,
model_chunk_id_table
=
zip
(
*
schedule_table
)
forward_order
=
[
chunk_id
+
1
for
chunk_id
in
model_chunk_id_table
]
backward_order
=
[
chunk_id
-
num_model_chunks
for
chunk_id
in
model_chunk_id_table
]
order
=
forward_order
[:
num_warmup_microbatches
]
for
i
in
range
(
num_warmup_microbatches
,
len
(
forward_order
)):
order
.
append
(
forward_order
[
i
])
order
.
append
(
backward_order
[
i
-
num_warmup_microbatches
])
if
num_warmup_microbatches
>
0
:
order
.
extend
(
backward_order
[
-
num_warmup_microbatches
:])
return
order
def
forward_backward_pipelining_with_interleaving
(
def
forward_backward_pipelining_with_interleaving
(
*
,
*
,
forward_step_func
,
forward_step_func
,
...
@@ -1057,6 +455,7 @@ def forward_backward_pipelining_with_interleaving(
...
@@ -1057,6 +455,7 @@ def forward_backward_pipelining_with_interleaving(
"""Helper method to run backward step with model split into chunks
"""Helper method to run backward step with model split into chunks
(run set_virtual_pipeline_model_parallel_rank() before calling
(run set_virtual_pipeline_model_parallel_rank() before calling
backward_step())."""
backward_step())."""
nonlocal
output_tensor_grads
# TODO(dongcl)
model_chunk_id
=
get_model_chunk_id
(
virtual_microbatch_id
,
forward
=
False
)
model_chunk_id
=
get_model_chunk_id
(
virtual_microbatch_id
,
forward
=
False
)
parallel_state
.
set_virtual_pipeline_model_parallel_rank
(
model_chunk_id
)
parallel_state
.
set_virtual_pipeline_model_parallel_rank
(
model_chunk_id
)
...
@@ -1099,6 +498,214 @@ def forward_backward_pipelining_with_interleaving(
...
@@ -1099,6 +498,214 @@ def forward_backward_pipelining_with_interleaving(
return
input_tensor_grad
return
input_tensor_grad
def
combined_forward_backward_helper
(
f_virtual_microbatch_id
=
None
,
b_virtual_microbatch_id
=
None
,
pre_forward
=
None
,
pre_backward
=
None
,
post_forward
=
None
,
post_backward
=
None
,
):
"""Helper method to run combined forward and backward step"""
# forward prepare
f_model_chunk_id
=
None
f_microbatch_id
=
None
if
f_virtual_microbatch_id
is
not
None
:
f_microbatch_id
=
get_microbatch_id_in_model_chunk
(
f_virtual_microbatch_id
,
True
)
f_context
=
contextlib
.
nullcontext
()
input_tensor
=
None
if
f_virtual_microbatch_id
is
not
None
:
model_chunk_id
=
get_model_chunk_id
(
f_virtual_microbatch_id
,
forward
=
True
)
f_model_chunk_id
=
model_chunk_id
f_context
=
VppContextManager
(
f_model_chunk_id
)
with
f_context
:
# launch param synchronization for next model chunk
# Note: Asynchronous communication tends to slow down compute.
# To reduce idling from mismatched microbatch times, we launch
# asynchronous communication at the same time across the
# pipeline-parallel group.
if
config
.
param_sync_func
is
not
None
:
param_sync_virtual_microbatch_id
=
(
f_virtual_microbatch_id
+
pipeline_parallel_rank
)
if
(
param_sync_virtual_microbatch_id
<
total_num_microbatches
and
is_first_microbatch_for_model_chunk
(
param_sync_virtual_microbatch_id
)
):
param_sync_chunk_id
=
(
get_model_chunk_id
(
param_sync_virtual_microbatch_id
,
forward
=
True
)
+
1
)
if
1
<
param_sync_chunk_id
<
num_model_chunks
:
config
.
param_sync_func
[
param_sync_chunk_id
](
model
[
param_sync_chunk_id
].
parameters
()
)
# forward step
if
parallel_state
.
is_pipeline_first_stage
():
if
len
(
input_tensors
[
model_chunk_id
])
==
len
(
output_tensors
[
model_chunk_id
]):
input_tensors
[
model_chunk_id
].
append
(
None
)
# For non-depth-first pipeline schedules, the first rank would
# buffer multiple received activation tensors for a model chunk
# until accessed during warmup. This input buffering is needed to overlap
# the computation with the receipt of the next inputs. To index
# the proper buffered inputs for forword_step, we use
# microbatch_id offset with number of released microbatches
# that have completed backprop.
offset
=
num_released_microbatches
(
f_virtual_microbatch_id
,
model_chunk_id
)
input_tensor
=
input_tensors
[
model_chunk_id
][
f_microbatch_id
-
offset
]
# backward prepare
b_model_chunk_id
=
None
b_context
=
contextlib
.
nullcontext
()
b_input_tensor
=
None
b_output_tensor
=
None
b_output_tensor_grad
=
None
if
b_virtual_microbatch_id
is
not
None
:
model_chunk_id
=
get_model_chunk_id
(
b_virtual_microbatch_id
,
forward
=
False
)
b_model_chunk_id
=
model_chunk_id
b_context
=
VppContextManager
(
b_model_chunk_id
)
with
b_context
:
# launch grad synchronization (default)
if
config
.
grad_sync_func
is
None
and
is_last_microbatch_for_model_chunk
(
b_virtual_microbatch_id
):
enable_grad_sync
()
synchronized_model_chunks
.
add
(
model_chunk_id
)
if
parallel_state
.
is_pipeline_last_stage
():
if
len
(
output_tensor_grads
[
model_chunk_id
])
==
0
:
output_tensor_grads
[
model_chunk_id
].
append
(
None
)
b_input_tensor
=
input_tensors
[
model_chunk_id
].
pop
(
0
)
b_output_tensor
=
output_tensors
[
model_chunk_id
].
pop
(
0
)
b_output_tensor_grad
=
output_tensor_grads
[
model_chunk_id
].
pop
(
0
)
output_tensor
,
num_tokens
,
input_tensor_grad
=
forward_backward_step
(
forward_step_func
,
data_iterator
[
f_model_chunk_id
]
if
f_model_chunk_id
is
not
None
else
None
,
model
[
f_model_chunk_id
]
if
f_model_chunk_id
is
not
None
else
None
,
num_microbatches
,
input_tensor
,
forward_data_store
,
model
[
b_model_chunk_id
]
if
b_model_chunk_id
is
not
None
else
None
,
b_input_tensor
,
b_output_tensor
,
b_output_tensor_grad
,
config
,
f_context
=
f_context
,
b_context
=
b_context
,
pre_forward
=
pre_forward
,
pre_backward
=
pre_backward
,
post_forward
=
post_forward
,
post_backward
=
post_backward
,
collect_non_loss_data
=
collect_non_loss_data
,
checkpoint_activations_microbatch
=
None
,
is_first_microbatch
=
check_first_val_step
(
first_val_step
,
forward_only
,
(
is_first_microbatch_for_model_chunk
(
f_virtual_microbatch_id
)
if
f_virtual_microbatch_id
is
not
None
else
None
),
),
current_microbatch
=
f_microbatch_id
,
)
# forward post process
if
f_model_chunk_id
is
not
None
:
with
f_context
:
output_tensors
[
f_model_chunk_id
].
append
(
output_tensor
)
nonlocal
total_num_tokens
total_num_tokens
+=
num_tokens
.
item
()
# If forward-only, no need to save tensors for a backward pass.
if
forward_only
:
# Release the tensor that have completed forward step.
input_tensors
[
f_model_chunk_id
].
pop
(
0
)
output_tensors
[
f_model_chunk_id
].
pop
()
# backward post process
if
b_model_chunk_id
:
with
b_context
:
# launch grad synchronization (custom grad sync)
# Note: Asynchronous communication tends to slow down compute.
# To reduce idling from mismatched microbatch times, we launch
# asynchronous communication at the same time across the
# pipeline-parallel group.
if
config
.
grad_sync_func
is
not
None
:
grad_sync_virtual_microbatch_id
=
(
b_virtual_microbatch_id
-
pipeline_parallel_rank
)
if
grad_sync_virtual_microbatch_id
>=
0
and
is_last_microbatch_for_model_chunk
(
grad_sync_virtual_microbatch_id
):
grad_sync_chunk_id
=
get_model_chunk_id
(
grad_sync_virtual_microbatch_id
,
forward
=
False
)
enable_grad_sync
()
config
.
grad_sync_func
[
grad_sync_chunk_id
](
model
[
grad_sync_chunk_id
].
parameters
()
)
synchronized_model_chunks
.
add
(
grad_sync_chunk_id
)
disable_grad_sync
()
if
input_tensor
is
not
None
:
assert
input_tensor_grad
is
not
None
return
output_tensor
,
input_tensor_grad
def
forward_backward_helper_wrapper
(
f_virtual_microbatch_id
=
None
,
b_virtual_microbatch_id
=
None
,
pre_forward
=
None
,
pre_backward
=
None
,
post_forward
=
None
,
post_backward
=
None
,
checkpoint_activations_microbatch
=
None
,
):
"""
wrap forward_helper、backward_helper、combined_forward_backward_helper in a unified way
"""
if
config
.
combined_1f1b
and
config
.
combined_1f1b_recipe
==
"ep_a2a"
and
not
forward_only
:
assert
(
checkpoint_activations_microbatch
is
None
),
"checkpoint_activations_microbatch not supported when combined_1f1b is true"
return
combined_forward_backward_helper
(
f_virtual_microbatch_id
=
f_virtual_microbatch_id
,
b_virtual_microbatch_id
=
b_virtual_microbatch_id
,
pre_forward
=
pre_forward
,
pre_backward
=
pre_backward
,
post_forward
=
post_forward
,
post_backward
=
post_backward
,
)
else
:
output_tensor
=
None
input_tensor_grad
=
None
if
f_virtual_microbatch_id
is
not
None
:
# forward pass
forward_model_chunk_id
=
get_model_chunk_id
(
f_virtual_microbatch_id
,
forward
=
True
)
parallel_state
.
set_virtual_pipeline_model_parallel_rank
(
forward_model_chunk_id
)
if
pre_forward
is
not
None
:
pre_forward
()
microbatch_id
=
get_microbatch_id_in_model_chunk
(
f_virtual_microbatch_id
,
forward
=
True
)
output_tensor
=
forward_step_helper
(
f_virtual_microbatch_id
,
microbatch_id
,
checkpoint_activations_microbatch
)
if
post_forward
is
not
None
:
output_tensor
=
post_forward
(
output_tensor
)
if
b_virtual_microbatch_id
is
not
None
:
# Backward pass.
backward_model_chunk_id
=
get_model_chunk_id
(
b_virtual_microbatch_id
,
forward
=
False
)
parallel_state
.
set_virtual_pipeline_model_parallel_rank
(
backward_model_chunk_id
)
if
pre_backward
is
not
None
:
pre_backward
()
input_tensor_grad
=
backward_step_helper
(
b_virtual_microbatch_id
)
if
post_backward
is
not
None
:
input_tensor_grad
=
post_backward
(
input_tensor_grad
)
return
output_tensor
,
input_tensor_grad
# Run warmup forward passes.
# Run warmup forward passes.
parallel_state
.
set_virtual_pipeline_model_parallel_rank
(
0
)
parallel_state
.
set_virtual_pipeline_model_parallel_rank
(
0
)
input_tensors
[
0
].
append
(
p2p_communication
.
recv_forward
(
tensor_shape
,
config
))
input_tensors
[
0
].
append
(
p2p_communication
.
recv_forward
(
tensor_shape
,
config
))
...
@@ -1172,8 +779,10 @@ def forward_backward_pipelining_with_interleaving(
...
@@ -1172,8 +779,10 @@ def forward_backward_pipelining_with_interleaving(
else
:
else
:
checkpoint_activations_microbatch
=
None
checkpoint_activations_microbatch
=
None
microbatch_id
=
get_microbatch_id_in_model_chunk
(
k
,
forward
=
True
)
output_tensor
,
_
=
forward_backward_helper_wrapper
(
output_tensor
=
forward_step_helper
(
k
,
microbatch_id
,
checkpoint_activations_microbatch
)
f_virtual_microbatch_id
=
k
,
checkpoint_activations_microbatch
=
checkpoint_activations_microbatch
,
)
# Don't send tensor downstream if on last stage.
# Don't send tensor downstream if on last stage.
if
parallel_state
.
is_pipeline_last_stage
(
ignore_virtual
=
False
):
if
parallel_state
.
is_pipeline_last_stage
(
ignore_virtual
=
False
):
...
@@ -1296,131 +905,153 @@ def forward_backward_pipelining_with_interleaving(
...
@@ -1296,131 +905,153 @@ def forward_backward_pipelining_with_interleaving(
cur_model_chunk_id
=
get_model_chunk_id
(
forward_k
,
forward
=
True
)
cur_model_chunk_id
=
get_model_chunk_id
(
forward_k
,
forward
=
True
)
parallel_state
.
set_virtual_pipeline_model_parallel_rank
(
cur_model_chunk_id
)
parallel_state
.
set_virtual_pipeline_model_parallel_rank
(
cur_model_chunk_id
)
microbatch_id
=
get_microbatch_id_in_model_chunk
(
forward_k
,
forward
=
True
)
if
config
.
overlap_p2p_comm
:
if
config
.
overlap_p2p_comm
:
if
not
parallel_state
.
is_pipeline_first_stage
(
ignore_virtual
=
False
):
backward_k
=
k
if
config
.
overlap_p2p_comm_warmup_flush
:
assert
recv_prev_wait_handles
,
(
# output send / receive sync
f
'pp rank
{
pipeline_parallel_rank
}
, fwd iteration
{
forward
_k
}
, '
def
pp_pre_
forward
():
'should have registered recv handle'
if
not
parallel_state
.
is_pipeline_first_stage
(
ignore_virtual
=
False
):
)
if
config
.
overlap_p2p_comm_warmup_flush
:
recv_prev_wait_handle
=
recv_prev_wait_handles
.
pop
(
0
)
assert
recv_prev_wait_handles
,
(
recv_prev_wait_handle
.
wait
()
f
'pp rank
{
pipeline_parallel_rank
}
, fwd iteration
{
forward_k
}
, '
else
:
'should have registered recv handle'
if
recv_prev_wait_handles
is
not
None
and
recv_prev_wait_handles
:
)
recv_prev_wait_handle
=
recv_prev_wait_handles
.
pop
(
0
)
recv_prev_wait_handle
=
recv_prev_wait_handles
.
pop
(
0
)
recv_prev_wait_handle
.
wait
()
recv_prev_wait_handle
.
wait
()
else
:
if
recv_prev_wait_handles
is
not
None
and
recv_prev_wait_handles
:
recv_prev_wait_handle
=
recv_prev_wait_handles
.
pop
(
0
)
recv_prev_wait_handle
.
wait
()
deallocate_output_tensor
(
output_tensor
,
config
.
deallocate_pipeline_outputs
)
# output async send / receive
def
pp_post_forward
(
output_tensor
):
nonlocal
send_next_wait_handle
nonlocal
fwd_recv_buffer
nonlocal
fwd_wait_handles
nonlocal
recv_prev_wait_handles
# Last virtual stage no activation tensor to send.
if
parallel_state
.
is_pipeline_last_stage
(
ignore_virtual
=
False
):
output_tensor
=
None
recv_prev
,
next_forward_model_chunk_id
=
recv_tensor_from_previous_stage
(
forward_k
,
forward
=
True
)
deallocate_output_tensor
(
output_tensor
,
config
.
deallocate_pipeline_outputs
)
# If last iteration, don't receive; we already received one extra
# before the start of the for loop.
output_tensor
=
forward_step_helper
(
if
k
==
(
num_microbatches_remaining
-
1
):
forward_k
,
microbatch_id
,
checkpoint_activations_microbatch
recv_prev
=
False
)
# Determine if current stage has anything to send in either direction,
# otherwise set tensor to None.
forward_model_chunk_id
=
get_model_chunk_id
(
forward_k
,
forward
=
True
)
parallel_state
.
set_virtual_pipeline_model_parallel_rank
(
forward_model_chunk_id
)
# Last virtual stage no activation tensor to send.
if
parallel_state
.
is_pipeline_last_stage
(
ignore_virtual
=
False
):
output_tensor
=
None
recv_prev
,
next_forward_model_chunk_id
=
recv_tensor_from_previous_stage
(
forward_k
,
forward
=
True
)
# If last iteration, don't receive; we already received one extra
# before the start of the for loop.
if
k
==
(
num_microbatches_remaining
-
1
):
recv_prev
=
False
# Send activation tensor to the next stage and receive activation tensor from the
# Send activation tensor to the next stage and receive activation tensor from the
# previous stage
# previous stage
fwd_recv_buffer
[
forward_k
%
fwd_recv_buffer_size
],
fwd_wait_handles
=
(
fwd_recv_buffer
[
forward_k
%
fwd_recv_buffer_size
],
fwd_wait_handles
=
(
p2p_communication
.
send_forward_recv_forward
(
p2p_communication
.
send_forward_recv_forward
(
output_tensor
,
output_tensor
,
recv_prev
=
recv_prev
,
recv_prev
=
recv_prev
,
tensor_shape
=
tensor_shape
,
tensor_shape
=
tensor_shape
,
config
=
config
,
config
=
config
,
overlap_p2p_comm
=
True
,
overlap_p2p_comm
=
True
,
)
)
)
if
send_next_wait_handle
is
not
None
:
send_next_wait_handle
.
wait
()
if
fwd_wait_handles
is
not
None
:
send_next_wait_handle
=
(
fwd_wait_handles
.
pop
(
"send_next"
)
if
"send_next"
in
fwd_wait_handles
else
None
)
)
if
"recv_prev"
in
fwd_wait_handles
:
if
send_next_wait_handle
is
not
None
:
recv_prev_wait_handles
.
append
(
fwd_wait_handles
.
pop
(
"recv_prev"
))
send_next_wait_handle
.
wait
()
# assert fwd_wait_handles is not None
if
fwd_wait_handles
is
not
None
:
send_next_wait_handle
=
(
# Backward pass.
fwd_wait_handles
.
pop
(
"send_next"
)
backward_k
=
k
if
"send_next"
in
fwd_wait_handles
backward_model_chunk_id
=
get_model_chunk_id
(
backward_k
,
forward
=
False
)
else
None
parallel_state
.
set_virtual_pipeline_model_parallel_rank
(
backward_model_chunk_id
)
if
not
parallel_state
.
is_pipeline_last_stage
(
ignore_virtual
=
False
):
if
config
.
overlap_p2p_comm_warmup_flush
:
assert
recv_next_wait_handles
,
(
f
'pp rank
{
pipeline_parallel_rank
}
, bwd iteration
{
backward_k
}
, '
'should have registered recv next handle'
)
)
recv_next_wait_handle
=
recv_next_wait_handles
.
pop
(
0
)
if
"recv_prev"
in
fwd_wait_handles
:
recv_next_wait_handle
.
wait
()
recv_prev_wait_handles
.
append
(
fwd_wait_handles
.
pop
(
"recv_prev"
))
else
:
# assert fwd_wait_handles is not None
if
recv_next_wait_handles
is
not
None
and
recv_next_wait_handles
:
# Put input_tensor and output_tensor_grad in data structures in the
# right location.
if
recv_prev
:
input_tensors
[
next_forward_model_chunk_id
].
append
(
fwd_recv_buffer
[
forward_k
%
fwd_recv_buffer_size
]
)
fwd_recv_buffer
[(
forward_k
+
1
)
%
fwd_recv_buffer_size
]
=
None
return
output_tensor
# grad send receive sync
def
pp_pre_backward
():
nonlocal
recv_next_wait_handles
if
not
parallel_state
.
is_pipeline_last_stage
(
ignore_virtual
=
False
):
if
config
.
overlap_p2p_comm_warmup_flush
:
assert
recv_next_wait_handles
,
(
f
'pp rank
{
pipeline_parallel_rank
}
, bwd iteration
{
backward_k
}
, '
'should have registered recv next handle'
)
recv_next_wait_handle
=
recv_next_wait_handles
.
pop
(
0
)
recv_next_wait_handle
=
recv_next_wait_handles
.
pop
(
0
)
recv_next_wait_handle
.
wait
()
recv_next_wait_handle
.
wait
()
else
:
if
recv_next_wait_handles
is
not
None
and
recv_next_wait_handles
:
recv_next_wait_handle
=
recv_next_wait_handles
.
pop
(
0
)
recv_next_wait_handle
.
wait
()
# async grad send receive
def
pp_post_backward
(
input_tensor_grad
):
nonlocal
send_prev_wait_handle
nonlocal
bwd_wait_handles
nonlocal
recv_next_wait_handles
# First virtual stage no activation gradient tensor to send.
if
parallel_state
.
is_pipeline_first_stage
(
ignore_virtual
=
False
):
input_tensor_grad
=
None
recv_next
,
next_backward_model_chunk_id
=
recv_tensor_from_previous_stage
(
backward_k
,
forward
=
False
)
input_tensor_grad
=
backward_step_helper
(
backward_k
)
(
bwd_recv_buffer
[
backward_k
%
bwd_recv_buffer_size
],
bwd_wait_handles
)
=
(
p2p_communication
.
send_backward_recv_backward
(
# First virtual stage no activation gradient tensor to send.
input_tensor_grad
,
if
parallel_state
.
is_pipeline_first_stage
(
ignore_virtual
=
False
):
recv_next
=
recv_next
,
input_tensor_grad
=
None
tensor_shape
=
tensor_shape
,
config
=
config
,
overlap_p2p_comm
=
True
,
)
)
if
send_prev_wait_handle
is
not
None
:
send_prev_wait_handle
.
wait
()
if
bwd_wait_handles
is
not
None
:
send_prev_wait_handle
=
(
bwd_wait_handles
.
pop
(
"send_prev"
)
if
"send_prev"
in
bwd_wait_handles
else
None
)
if
"recv_next"
in
bwd_wait_handles
:
recv_next_wait_handles
.
append
(
bwd_wait_handles
.
pop
(
"recv_next"
))
recv_next
,
next_backward_model_chunk_id
=
recv_tensor_from_previous_stage
(
# Put input_tensor and output_tensor_grad in data structures in the
backward_k
,
forward
=
False
# right location.
)
(
bwd_recv_buffer
[
backward_k
%
bwd_recv_buffer_size
],
bwd_wait_handles
)
=
(
if
recv_next
:
p2p_communication
.
send_backward_recv_backward
(
output_tensor_grads
[
next_backward_model_chunk_id
].
append
(
input_tensor_grad
,
bwd_recv_buffer
[
backward_k
%
bwd_recv_buffer_size
]
recv_next
=
recv_next
,
)
tensor_shape
=
tensor_shape
,
bwd_recv_buffer
[(
backward_k
+
1
)
%
bwd_recv_buffer_size
]
=
None
config
=
config
,
return
input_tensor_grad
overlap_p2p_comm
=
True
,
)
output_tensor
,
input_tensor_grad
=
forward_backward_helper_wrapper
(
f_virtual_microbatch_id
=
forward_k
,
b_virtual_microbatch_id
=
backward_k
,
pre_forward
=
pp_pre_forward
,
pre_backward
=
pp_pre_backward
,
post_forward
=
pp_post_forward
,
post_backward
=
pp_post_backward
,
checkpoint_activations_microbatch
=
checkpoint_activations_microbatch
,
)
)
if
send_prev_wait_handle
is
not
None
:
send_prev_wait_handle
.
wait
()
if
bwd_wait_handles
is
not
None
:
send_prev_wait_handle
=
(
bwd_wait_handles
.
pop
(
"send_prev"
)
if
"send_prev"
in
bwd_wait_handles
else
None
)
if
"recv_next"
in
bwd_wait_handles
:
recv_next_wait_handles
.
append
(
bwd_wait_handles
.
pop
(
"recv_next"
))
# Put input_tensor and output_tensor_grad in data structures in the
# right location.
if
recv_prev
:
input_tensors
[
next_forward_model_chunk_id
].
append
(
fwd_recv_buffer
[
forward_k
%
fwd_recv_buffer_size
]
)
fwd_recv_buffer
[(
forward_k
+
1
)
%
fwd_recv_buffer_size
]
=
None
if
recv_next
:
output_tensor_grads
[
next_backward_model_chunk_id
].
append
(
bwd_recv_buffer
[
backward_k
%
bwd_recv_buffer_size
]
)
bwd_recv_buffer
[(
backward_k
+
1
)
%
bwd_recv_buffer_size
]
=
None
else
:
# No p2p overlap.
else
:
# No p2p overlap.
output_tensor
=
forward_step_helper
(
forward_k
,
microbatch_id
,
checkpoint_activations_microbatch
)
# Backward pass.
backward_k
=
k
backward_k
=
k
input_tensor_grad
=
backward_step_helper
(
backward_k
)
output_tensor
,
input_tensor_grad
=
forward_backward_helper_wrapper
(
f_virtual_microbatch_id
=
forward_k
,
b_virtual_microbatch_id
=
backward_k
,
checkpoint_activations_microbatch
=
checkpoint_activations_microbatch
,
)
# Send output_tensor and input_tensor_grad, receive input_tensor
# Send output_tensor and input_tensor_grad, receive input_tensor
# and output_tensor_grad.
# and output_tensor_grad.
...
@@ -1522,7 +1153,7 @@ def forward_backward_pipelining_with_interleaving(
...
@@ -1522,7 +1153,7 @@ def forward_backward_pipelining_with_interleaving(
if
bwd_wait_recv_handles
:
if
bwd_wait_recv_handles
:
recv_next_wait_handles
.
append
(
bwd_wait_recv_handles
.
pop
(
"recv_next"
))
recv_next_wait_handles
.
append
(
bwd_wait_recv_handles
.
pop
(
"recv_next"
))
input_tensor_grad
=
backward_
step_helper
(
k
)
_
,
input_tensor_grad
=
forward_
backward_
helper_wrapper
(
b_virtual_microbatch_id
=
k
)
# First virtual stage no activation gradient tensor to send.
# First virtual stage no activation gradient tensor to send.
if
parallel_state
.
is_pipeline_first_stage
(
ignore_virtual
=
False
):
if
parallel_state
.
is_pipeline_first_stage
(
ignore_virtual
=
False
):
...
@@ -1615,405 +1246,3 @@ def forward_backward_pipelining_with_interleaving(
...
@@ -1615,405 +1246,3 @@ def forward_backward_pipelining_with_interleaving(
return
forward_data_store
return
forward_data_store
def
get_tensor_shapes
(
*
,
rank
:
int
,
model_type
:
ModelType
,
seq_length
:
int
,
micro_batch_size
:
int
,
decoder_seq_length
:
int
,
config
,
encoder_decoder_xattn
:
bool
,
):
"""
Determine right tensor sizes (based on position of rank with respect to split rank) and
model size.
Send two tensors if model decoder requires the encoder's output (via cross-attention) and
rank is in decoder stage.
First tensor is decoder. Second tensor is encoder.
If model has an encoder & decoder and rank is at the boundary, send one tensor.
Otherwise, send one tensor.
"""
tensor_shapes
=
[]
seq_length
=
seq_length
//
parallel_state
.
get_context_parallel_world_size
()
if
model_type
==
ModelType
.
encoder_and_decoder
:
decoder_seq_length
=
decoder_seq_length
//
parallel_state
.
get_context_parallel_world_size
()
if
config
.
sequence_parallel
:
seq_length
=
seq_length
//
parallel_state
.
get_tensor_model_parallel_world_size
()
if
model_type
==
ModelType
.
encoder_and_decoder
:
decoder_seq_length
=
(
decoder_seq_length
//
parallel_state
.
get_tensor_model_parallel_world_size
()
)
if
model_type
==
ModelType
.
encoder_and_decoder
:
if
parallel_state
.
is_inside_encoder
(
rank
)
and
not
parallel_state
.
is_inside_decoder
(
rank
):
tensor_shapes
.
append
((
seq_length
,
micro_batch_size
,
config
.
hidden_size
))
elif
encoder_decoder_xattn
:
tensor_shapes
.
append
((
decoder_seq_length
,
micro_batch_size
,
config
.
hidden_size
))
tensor_shapes
.
append
((
seq_length
,
micro_batch_size
,
config
.
hidden_size
))
else
:
tensor_shapes
.
append
((
decoder_seq_length
,
micro_batch_size
,
config
.
hidden_size
))
else
:
# model_type == ModelType.encoder_or_decoder
tensor_shapes
.
append
((
seq_length
,
micro_batch_size
,
config
.
hidden_size
))
return
tensor_shapes
def
recv_forward
(
tensor_shapes
,
config
):
"""Wrapper for p2p_communication.recv_forward used with non-interleaving schedule."""
input_tensors
=
[]
for
tensor_shape
in
tensor_shapes
:
if
tensor_shape
is
None
:
input_tensors
.
append
(
None
)
else
:
input_tensors
.
append
(
p2p_communication
.
recv_forward
(
tensor_shape
,
config
))
return
input_tensors
def
recv_backward
(
tensor_shapes
,
config
):
"""Wrapper for p2p_communication.recv_backward used with non-interleaving schedule."""
output_tensor_grads
=
[]
for
tensor_shape
in
tensor_shapes
:
if
tensor_shape
is
None
:
output_tensor_grads
.
append
(
None
)
else
:
output_tensor_grads
.
append
(
p2p_communication
.
recv_backward
(
tensor_shape
,
config
))
return
output_tensor_grads
def
send_forward
(
output_tensors
,
tensor_shapes
,
config
):
"""Wrapper for p2p_communication.send_forward used with non-interleaving schedule."""
if
not
isinstance
(
output_tensors
,
list
):
output_tensors
=
[
output_tensors
]
for
output_tensor
,
tensor_shape
in
zip
(
output_tensors
,
tensor_shapes
):
if
tensor_shape
is
None
:
continue
p2p_communication
.
send_forward
(
output_tensor
,
config
)
def
send_backward
(
input_tensor_grads
,
tensor_shapes
,
config
):
"""Wrapper for p2p_communication.send_backward used with non-interleaving schedule."""
if
not
isinstance
(
input_tensor_grads
,
list
):
input_tensor_grads
=
[
input_tensor_grads
]
for
input_tensor_grad
,
tensor_shape
in
zip
(
input_tensor_grads
,
tensor_shapes
):
if
tensor_shape
is
None
:
continue
p2p_communication
.
send_backward
(
input_tensor_grad
,
config
)
def
send_forward_recv_backward
(
output_tensors
,
tensor_shapes
,
config
):
"""Wrapper for p2p_communication.send_forward_recv_backward used
with non-interleaving schedule."""
if
not
isinstance
(
output_tensors
,
list
):
output_tensors
=
[
output_tensors
]
output_tensor_grads
=
[]
for
output_tensor
,
tensor_shape
in
zip
(
output_tensors
,
tensor_shapes
):
if
tensor_shape
is
None
:
output_tensor_grads
.
append
(
None
)
continue
output_tensor_grad
=
p2p_communication
.
send_forward_recv_backward
(
output_tensor
,
tensor_shape
,
config
)
output_tensor_grads
.
append
(
output_tensor_grad
)
return
output_tensor_grads
def
send_backward_recv_forward
(
input_tensor_grads
,
tensor_shapes
,
config
):
"""Wrapper for p2p_communication.send_backward_recv_forward used
with non-interleaving schedule."""
if
not
isinstance
(
input_tensor_grads
,
list
):
input_tensor_grads
=
[
input_tensor_grads
]
input_tensors
=
[]
for
input_tensor_grad
,
tensor_shape
in
zip
(
input_tensor_grads
,
tensor_shapes
):
if
tensor_shape
is
None
:
input_tensors
.
append
(
None
)
continue
input_tensor
=
p2p_communication
.
send_backward_recv_forward
(
input_tensor_grad
,
tensor_shape
,
config
)
input_tensors
.
append
(
input_tensor
)
return
input_tensors
def
forward_backward_pipelining_without_interleaving
(
*
,
forward_step_func
,
data_iterator
:
Union
[
Iterator
,
List
[
Iterator
]],
model
:
Union
[
torch
.
nn
.
Module
,
List
[
torch
.
nn
.
Module
]],
num_microbatches
:
int
,
seq_length
:
int
,
micro_batch_size
:
int
,
decoder_seq_length
:
Optional
[
int
]
=
None
,
forward_only
:
bool
=
False
,
collect_non_loss_data
:
bool
=
False
,
first_val_step
:
Optional
[
bool
]
=
None
,
adjust_tensor_shapes_fn
:
Optional
[
Callable
]
=
None
,
):
"""Run non-interleaved 1F1B schedule, with communication between pipeline
stages. Returns dictionary with losses if the last stage, empty dict otherwise."""
if
isinstance
(
model
,
list
):
assert
(
len
(
model
)
==
1
),
"non-interleaved pipeline-parallel schedule does not support model chunking"
model
=
model
[
0
]
if
isinstance
(
data_iterator
,
list
):
assert
(
len
(
data_iterator
)
==
1
),
"non-interleaved pipeline-parallel schedule does not support model chunking"
data_iterator
=
data_iterator
[
0
]
config
=
get_model_config
(
model
)
if
config
.
overlap_p2p_comm
:
raise
ValueError
(
"Non-interleaved pipeline parallelism does not support overlapping p2p communication"
)
# Needed only when gradients are finalized in M-Core
if
config
.
finalize_model_grads_func
is
not
None
and
not
forward_only
:
embedding_module
=
clear_embedding_activation_buffer
(
config
,
model
)
if
config
.
timers
is
not
None
:
config
.
timers
(
'forward-backward'
,
log_level
=
1
).
start
(
barrier
=
config
.
barrier_with_L1_time
)
# Disable async grad reductions
no_sync_func
=
config
.
no_sync_func
if
no_sync_func
is
None
:
no_sync_func
=
contextlib
.
nullcontext
no_sync_context
=
None
def
disable_grad_sync
():
"""Disable asynchronous grad reductions"""
nonlocal
no_sync_context
if
no_sync_context
is
None
:
no_sync_context
=
no_sync_func
()
no_sync_context
.
__enter__
()
def
enable_grad_sync
():
"""Enable asynchronous grad reductions"""
nonlocal
no_sync_context
if
no_sync_context
is
not
None
:
no_sync_context
.
__exit__
(
None
,
None
,
None
)
no_sync_context
=
None
disable_grad_sync
()
# Compute number of warmup microbatches.
num_warmup_microbatches
=
(
parallel_state
.
get_pipeline_model_parallel_world_size
()
-
parallel_state
.
get_pipeline_model_parallel_rank
()
-
1
)
num_warmup_microbatches
=
min
(
num_warmup_microbatches
,
num_microbatches
)
num_microbatches_remaining
=
num_microbatches
-
num_warmup_microbatches
# Checkpoint the activations of partial Transformer layers in a number of micro-batches
# within the maximum outstanding micro-batch backpropagations.
# Micro-batches with the ids less than 'num_microbatches_with_partial_activation_checkpoints'
# checkpoint partial Transformer layers (or skip checkpointing) and
# the rest of micro-batches within a window of micro-batches checkpoint
# all Transformer layers. The window of micro-batches is set by the maximum
# outstanding backpropagations and becomes smaller at later pipeline stages.
# Please refer the appendix C in https://arxiv.org/pdf/2205.05198.pdf
max_outstanding_backprops
=
None
if
config
.
num_microbatches_with_partial_activation_checkpoints
is
not
None
:
max_outstanding_backprops
=
num_warmup_microbatches
+
1
model_type
=
get_model_type
(
model
)
encoder_decoder_xattn
=
get_model_xattn
(
model
)
rank
=
parallel_state
.
get_pipeline_model_parallel_rank
()
recv_tensor_shapes
=
get_tensor_shapes
(
rank
=
rank
-
1
,
model_type
=
model_type
,
seq_length
=
seq_length
,
micro_batch_size
=
micro_batch_size
,
decoder_seq_length
=
decoder_seq_length
,
config
=
config
,
encoder_decoder_xattn
=
encoder_decoder_xattn
,
)
send_tensor_shapes
=
get_tensor_shapes
(
rank
=
rank
,
model_type
=
model_type
,
seq_length
=
seq_length
,
micro_batch_size
=
micro_batch_size
,
decoder_seq_length
=
decoder_seq_length
,
config
=
config
,
encoder_decoder_xattn
=
encoder_decoder_xattn
,
)
if
adjust_tensor_shapes_fn
is
not
None
:
recv_tensor_shapes
,
send_tensor_shapes
=
adjust_tensor_shapes_fn
(
recv_tensor_shapes
,
send_tensor_shapes
)
# Input, output tensors only need to be saved when doing backward passes
input_tensors
=
None
output_tensors
=
None
total_num_tokens
=
torch
.
tensor
(
0
,
dtype
=
torch
.
int
).
cuda
()
if
not
forward_only
:
input_tensors
=
[]
output_tensors
=
[]
forward_data_store
=
[]
# Run warmup forward passes.
for
i
in
range
(
num_warmup_microbatches
):
# Decide to checkpoint all layers' activations of the current micro-batch
if
max_outstanding_backprops
is
not
None
:
checkpoint_activations_microbatch
=
(
i
%
max_outstanding_backprops
>=
config
.
num_microbatches_with_partial_activation_checkpoints
)
else
:
checkpoint_activations_microbatch
=
None
input_tensor
=
recv_forward
(
recv_tensor_shapes
,
config
)
output_tensor
,
num_tokens
=
forward_step
(
forward_step_func
,
data_iterator
,
model
,
num_microbatches
,
input_tensor
,
forward_data_store
,
config
,
collect_non_loss_data
,
checkpoint_activations_microbatch
,
check_first_val_step
(
first_val_step
,
forward_only
,
i
==
0
),
current_microbatch
=
i
,
encoder_decoder_xattn
=
encoder_decoder_xattn
,
)
send_forward
(
output_tensor
,
send_tensor_shapes
,
config
)
total_num_tokens
+=
num_tokens
if
not
forward_only
:
input_tensors
.
append
(
input_tensor
)
output_tensors
.
append
(
output_tensor
)
deallocate_output_tensor
(
output_tensor
[
0
],
config
.
deallocate_pipeline_outputs
)
# Before running 1F1B, need to receive first forward tensor.
# If all microbatches are run in warmup / cooldown phase, then no need to
# receive this tensor here.
if
num_microbatches_remaining
>
0
:
input_tensor
=
recv_forward
(
recv_tensor_shapes
,
config
)
# Run 1F1B in steady state.
for
i
in
range
(
num_microbatches_remaining
):
last_iteration
=
i
==
(
num_microbatches_remaining
-
1
)
# Decide to checkpoint all layers' activations of the current micro-batch
if
max_outstanding_backprops
is
not
None
:
checkpoint_activations_microbatch
=
(
(
i
+
num_warmup_microbatches
)
%
max_outstanding_backprops
)
>=
config
.
num_microbatches_with_partial_activation_checkpoints
else
:
checkpoint_activations_microbatch
=
None
output_tensor
,
num_tokens
=
forward_step
(
forward_step_func
,
data_iterator
,
model
,
num_microbatches
,
input_tensor
,
forward_data_store
,
config
,
collect_non_loss_data
,
checkpoint_activations_microbatch
,
check_first_val_step
(
first_val_step
,
forward_only
,
(
i
==
0
)
and
(
num_warmup_microbatches
==
0
)
),
current_microbatch
=
i
+
num_warmup_microbatches
,
encoder_decoder_xattn
=
encoder_decoder_xattn
,
)
total_num_tokens
+=
num_tokens
if
forward_only
:
send_forward
(
output_tensor
,
send_tensor_shapes
,
config
)
if
not
last_iteration
:
input_tensor
=
recv_forward
(
recv_tensor_shapes
,
config
)
else
:
output_tensor_grad
=
send_forward_recv_backward
(
output_tensor
,
send_tensor_shapes
,
config
)
# Add input_tensor and output_tensor to end of list.
input_tensors
.
append
(
input_tensor
)
output_tensors
.
append
(
output_tensor
)
deallocate_output_tensor
(
output_tensor
[
0
],
config
.
deallocate_pipeline_outputs
)
# Pop input_tensor and output_tensor from the start of the list for
# the backward pass.
input_tensor
=
input_tensors
.
pop
(
0
)
output_tensor
=
output_tensors
.
pop
(
0
)
# Enable grad sync for the last microbatch in the batch if the full
# backward pass completes in the 1F1B stage.
if
num_warmup_microbatches
==
0
and
last_iteration
:
if
config
.
grad_sync_func
is
None
or
rank
==
0
:
enable_grad_sync
()
input_tensor_grad
=
backward_step
(
input_tensor
,
output_tensor
,
output_tensor_grad
,
model_type
,
config
)
if
last_iteration
:
input_tensor
=
None
send_backward
(
input_tensor_grad
,
recv_tensor_shapes
,
config
)
else
:
input_tensor
=
send_backward_recv_forward
(
input_tensor_grad
,
recv_tensor_shapes
,
config
)
# Run cooldown backward passes.
if
not
forward_only
:
for
i
in
range
(
num_warmup_microbatches
):
# Enable async grad reduction in the last backward pass
# Note: If grad sync function is provided, only enable
# async grad reduction in first pipeline stage. Other
# pipeline stages do grad reduction during pipeline
# bubble.
if
i
==
num_warmup_microbatches
-
1
:
if
config
.
grad_sync_func
is
None
or
rank
==
0
:
enable_grad_sync
()
input_tensor
=
input_tensors
.
pop
(
0
)
output_tensor
=
output_tensors
.
pop
(
0
)
output_tensor_grad
=
recv_backward
(
send_tensor_shapes
,
config
)
input_tensor_grad
=
backward_step
(
input_tensor
,
output_tensor
,
output_tensor_grad
,
model_type
,
config
)
send_backward
(
input_tensor_grad
,
recv_tensor_shapes
,
config
)
# Launch any remaining grad reductions.
if
no_sync_context
is
not
None
:
enable_grad_sync
()
if
config
.
grad_sync_func
is
not
None
:
config
.
grad_sync_func
(
model
.
parameters
())
if
config
.
finalize_model_grads_func
is
not
None
and
not
forward_only
:
# If defer_embedding_wgrad_compute is enabled we need to do the
# weight gradient GEMM's here.
finish_embedding_wgrad_compute
(
config
,
embedding_module
)
# Finalize model grads (perform full grad all-reduce / reduce-scatter for
# data parallelism, layernorm all-reduce for sequence parallelism, and
# embedding all-reduce for pipeline parallelism).
config
.
finalize_model_grads_func
(
[
model
],
total_num_tokens
if
config
.
calculate_per_token_loss
else
None
)
if
config
.
timers
is
not
None
:
config
.
timers
(
'forward-backward'
).
stop
()
if
hasattr
(
config
,
'enable_cuda_graph'
)
and
config
.
enable_cuda_graph
:
create_cudagraphs
()
return
forward_data_store
dcu_megatron/core/tensor_parallel/__init__.py
deleted
100644 → 0
View file @
649bfbdb
from
.layers
import
(
FluxColumnParallelLinear
,
FluxRowParallelLinear
,
)
\ No newline at end of file
dcu_megatron/core/tensor_parallel/layers.py
View file @
7c9dc3ec
...
@@ -740,6 +740,7 @@ class FluxColumnParallelLinear(ColumnParallelLinear):
...
@@ -740,6 +740,7 @@ class FluxColumnParallelLinear(ColumnParallelLinear):
is_expert
:
bool
=
False
,
is_expert
:
bool
=
False
,
tp_comm_buffer_name
:
str
=
None
,
# Not used
tp_comm_buffer_name
:
str
=
None
,
# Not used
disable_grad_reduce
:
bool
=
False
,
disable_grad_reduce
:
bool
=
False
,
tp_group
:
Optional
[
torch
.
distributed
.
ProcessGroup
]
=
None
,
):
):
super
(
FluxColumnParallelLinear
,
self
).
__init__
(
super
(
FluxColumnParallelLinear
,
self
).
__init__
(
input_size
=
input_size
,
input_size
=
input_size
,
...
@@ -757,6 +758,7 @@ class FluxColumnParallelLinear(ColumnParallelLinear):
...
@@ -757,6 +758,7 @@ class FluxColumnParallelLinear(ColumnParallelLinear):
is_expert
=
is_expert
,
is_expert
=
is_expert
,
tp_comm_buffer_name
=
tp_comm_buffer_name
,
tp_comm_buffer_name
=
tp_comm_buffer_name
,
disable_grad_reduce
=
disable_grad_reduce
,
disable_grad_reduce
=
disable_grad_reduce
,
tp_group
=
tp_group
,
)
)
# flux params
# flux params
...
@@ -961,6 +963,7 @@ class FluxRowParallelLinear(RowParallelLinear):
...
@@ -961,6 +963,7 @@ class FluxRowParallelLinear(RowParallelLinear):
keep_master_weight_for_test
:
bool
=
False
,
keep_master_weight_for_test
:
bool
=
False
,
is_expert
:
bool
=
False
,
is_expert
:
bool
=
False
,
tp_comm_buffer_name
:
str
=
None
,
# Not used
tp_comm_buffer_name
:
str
=
None
,
# Not used
tp_group
:
Optional
[
torch
.
distributed
.
ProcessGroup
]
=
None
,
):
):
super
(
FluxRowParallelLinear
,
self
).
__init__
(
super
(
FluxRowParallelLinear
,
self
).
__init__
(
...
@@ -974,7 +977,8 @@ class FluxRowParallelLinear(RowParallelLinear):
...
@@ -974,7 +977,8 @@ class FluxRowParallelLinear(RowParallelLinear):
stride
=
stride
,
stride
=
stride
,
keep_master_weight_for_test
=
keep_master_weight_for_test
,
keep_master_weight_for_test
=
keep_master_weight_for_test
,
is_expert
=
is_expert
,
is_expert
=
is_expert
,
tp_comm_buffer_name
=
tp_comm_buffer_name
tp_comm_buffer_name
=
tp_comm_buffer_name
,
tp_group
=
tp_group
,
)
)
# flux params
# flux params
...
...
dcu_megatron/core/transformer/transformer_config.py
View file @
7c9dc3ec
...
@@ -23,6 +23,7 @@ def transformer_config_post_init_wrapper(fn):
...
@@ -23,6 +23,7 @@ def transformer_config_post_init_wrapper(fn):
##################
##################
self
.
flux_transpose_weight
=
args
.
flux_transpose_weight
self
.
flux_transpose_weight
=
args
.
flux_transpose_weight
return
wrapper
return
wrapper
...
@@ -33,6 +34,12 @@ class ExtraTransformerConfig:
...
@@ -33,6 +34,12 @@ class ExtraTransformerConfig:
##################
##################
flux_transpose_weight
:
bool
=
False
flux_transpose_weight
:
bool
=
False
combined_1f1b
:
bool
=
False
"""If true, use combined 1F1B for communication hiding."""
combined_1f1b_recipe
:
str
=
'ep_a2a'
"""Recipe to use for combined 1F1B. Currently only 'ep_a2a' and 'golden' are supported."""
@
dataclass
@
dataclass
class
TransformerConfigPatch
(
TransformerConfig
,
ExtraTransformerConfig
):
class
TransformerConfigPatch
(
TransformerConfig
,
ExtraTransformerConfig
):
...
...
dcu_megatron/training/arguments.py
View file @
7c9dc3ec
...
@@ -26,6 +26,8 @@ def add_megatron_arguments_patch(parser: argparse.ArgumentParser):
...
@@ -26,6 +26,8 @@ def add_megatron_arguments_patch(parser: argparse.ArgumentParser):
parser
=
_add_extra_training_args
(
parser
)
parser
=
_add_extra_training_args
(
parser
)
parser
=
_add_extra_distributed_args
(
parser
)
parser
=
_add_extra_distributed_args
(
parser
)
parser
=
_add_extra_tokenizer_args
(
parser
)
parser
=
_add_extra_tokenizer_args
(
parser
)
parser
=
_add_extra_moe_args
(
parser
)
parser
=
_add_flux_args
(
parser
)
return
parser
return
parser
...
@@ -128,6 +130,18 @@ def _add_extra_tokenizer_args(parser):
...
@@ -128,6 +130,18 @@ def _add_extra_tokenizer_args(parser):
return
parser
return
parser
def
_add_extra_moe_args
(
parser
):
group
=
parser
.
add_argument_group
(
title
=
"extra moe args"
)
group
.
add_argument
(
'--combined-1f1b'
,
action
=
'store_true'
,
help
=
'Batch-level overlapping in 1f1b stage.'
)
group
.
add_argument
(
'--combined-1f1b-recipe'
,
type
=
str
,
choices
=
[
'ep_a2a'
,
'golden'
],
default
=
'golden'
,
help
=
'Options are "ep_a2a" and "golden".'
)
return
parser
def
_add_flux_args
(
parser
):
def
_add_flux_args
(
parser
):
group
=
parser
.
add_argument_group
(
title
=
'flux args'
)
group
=
parser
.
add_argument_group
(
title
=
'flux args'
)
group
.
add_argument
(
'--flux-transpose-weight'
,
action
=
'store_true'
,
default
=
False
,
group
.
add_argument
(
'--flux-transpose-weight'
,
action
=
'store_true'
,
default
=
False
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment