Project: wuxk1 / Megatron-LM

Commit 4b097dee, authored Oct 29, 2024 by liangjing

    update to core_v0.9

Parent: 3aca1415

Changes: 341 files. Showing 20 changed files with 6878 additions and 332 deletions (+6878 -332).
Changed files listed on this page (the page notes there are too many changes to show, so only a subset of the 341 changed files is rendered):

    megatron/core/models/vision/vit_layer_specs.py          +95    -0
    megatron/core/num_microbatches_calculator.py            +491   -0
    megatron/core/optimizer/__init__.py                     +445   -0
    megatron/core/optimizer/clip_grads.py                   +193   -0
    megatron/core/optimizer/distrib_optimizer.py            +1839  -0
    megatron/core/optimizer/grad_scaler.py                  +49    -27
    megatron/core/optimizer/optimizer.py                    +1067  -0
    megatron/core/optimizer/optimizer_config.py             +126   -0
    megatron/core/optimizer_param_scheduler.py              +297   -0
    megatron/core/package_info.py                           +1     -1
    megatron/core/packed_seq_params.py                      +14    -0
    megatron/core/parallel_state.py                         +1141  -145
    megatron/core/pipeline_parallel/__init__.py             +1     -0
    megatron/core/pipeline_parallel/p2p_communication.py    +113   -54
    megatron/core/pipeline_parallel/schedules.py            +395   -104
    megatron/core/requirements.txt                          +2     -1
    megatron/core/ssm/__init__.py                           +0     -0
    megatron/core/ssm/mamba_block.py                        +337   -0
    megatron/core/ssm/mamba_hybrid_layer_allocation.py      +191   -0
    megatron/core/ssm/mamba_layer.py                        +81    -0
megatron/core/models/vision/vit_layer_specs.py (new file, mode 100644, +95 -0; view file @ 4b097dee)

# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from megatron.core.extensions.transformer_engine import (
    TEDotProductAttention,
    TELayerNormColumnParallelLinear,
    TERowParallelLinear,
)
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import (
    TransformerLayer,
    TransformerLayerSubmodules,
)

try:
    import apex  # pylint: disable=unused-import

    from megatron.core.fusions.fused_layer_norm import FusedLayerNorm

    HAVE_APEX = True
    LNImpl = FusedLayerNorm
except ImportError:
    import warnings

    from megatron.core.transformer.torch_layer_norm import WrappedTorchLayerNorm

    warnings.warn(f'Apex is not installed. Falling back to Torch LayerNorm')
    LNImpl = WrappedTorchLayerNorm


# Use this spec to use lower level Transformer Engine modules (required for fp8 training)
def get_vit_layer_with_transformer_engine_spec() -> ModuleSpec:
    '''
    Returns ViT layer spec with Transformer Engine layers
    '''
    mlp = _get_mlp_module_spec(use_te=True)
    return ModuleSpec(
        module=TransformerLayer,
        submodules=TransformerLayerSubmodules(
            self_attention=ModuleSpec(
                module=SelfAttention,
                params={"attn_mask_type": AttnMaskType.no_mask},
                submodules=SelfAttentionSubmodules(
                    linear_qkv=TELayerNormColumnParallelLinear,
                    core_attention=TEDotProductAttention,
                    linear_proj=TERowParallelLinear,
                ),
            ),
            self_attn_bda=get_bias_dropout_add,
            pre_mlp_layernorm=IdentityOp,
            mlp=mlp,
            mlp_bda=get_bias_dropout_add,
        ),
    )


def get_vit_layer_with_local_spec() -> ModuleSpec:
    '''
    Returns ViT layer spec with Mcore local layers
    '''
    mlp = _get_mlp_module_spec(use_te=False)
    return ModuleSpec(
        module=TransformerLayer,
        submodules=TransformerLayerSubmodules(
            input_layernorm=LNImpl,
            self_attention=ModuleSpec(
                module=SelfAttention,
                params={"attn_mask_type": AttnMaskType.causal},
                submodules=SelfAttentionSubmodules(
                    linear_qkv=ColumnParallelLinear,
                    core_attention=DotProductAttention,
                    linear_proj=RowParallelLinear,
                ),
            ),
            self_attn_bda=get_bias_dropout_add,
            pre_mlp_layernorm=LNImpl,
            mlp=mlp,
            mlp_bda=get_bias_dropout_add,
        ),
    )


# Helper function to get module spec for MLP/MoE
def _get_mlp_module_spec(use_te: bool = True) -> ModuleSpec:
    # Dense MLP w/ or w/o TE modules.
    return ModuleSpec(
        module=MLP,
        submodules=MLPSubmodules(
            linear_fc1=TELayerNormColumnParallelLinear if use_te else ColumnParallelLinear,
            linear_fc2=TERowParallelLinear if use_te else RowParallelLinear,
        ),
    )
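The two spec builders above differ only in whether they wire in Transformer Engine modules or the local Megatron Core implementations. As a rough, hypothetical usage sketch (not part of this commit), a caller could pick between them based on whether Transformer Engine imports successfully; the try/except wrapper below is an assumption for illustration only, while the two function names come from the file itself.

# Hypothetical selection logic, assuming Transformer Engine availability decides the path.
try:
    import transformer_engine  # noqa: F401

    vit_layer_spec = get_vit_layer_with_transformer_engine_spec()  # TE path (needed for fp8)
except ImportError:
    vit_layer_spec = get_vit_layer_with_local_spec()  # local PyTorch/Apex path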
megatron/core/num_microbatches_calculator.py (new file, mode 100644, +491 -0; view file @ 4b097dee)

# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Megatron Core number of microbatches calculators."""

import logging
from abc import ABC, abstractmethod
from typing import List, Optional, Union

logger = logging.getLogger(__name__)

# TODO: global_var merge into mcore?
_GLOBAL_NUM_MICROBATCHES_CALCULATOR: Union[
    'ConstantNumMicroBatchesCalculator', 'RampupBatchsizeNumMicroBatchesCalculator'
] = None


def get_num_microbatches() -> int:
    """Get number of microbatches."""
    return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get()


def get_current_global_batch_size() -> int:
    """Get current global batch size."""
    return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get_current_global_batch_size()


def get_micro_batch_size() -> int:
    """Get micro batch size."""
    return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get_micro_batch_size()


def get_current_running_global_batch_size() -> int:
    """Get current running global batch size, taking into account number of DP replicas might be
    incompatible with true global batch size if `decrease_batch_size_if_needed` is True."""
    return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get_current_running_global_batch_size()


def update_num_microbatches(
    consumed_samples: int, consistency_check: bool = True, verbose: bool = False
) -> None:
    """Update number of microbatches.

    Args:
        consumed_samples (int):
            Number of samples consumed.
        consistency_check (bool, optional):
            Option to check current schedule's consistency. Defaults to True.
        verbose (bool, optional):
            Option to control logging. Defaults to False.
    """
    _GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples, consistency_check, verbose)


def init_num_microbatches_calculator(
    rank: int,
    rampup_batch_size: Optional[List[int]],
    global_batch_size: int,
    micro_batch_size: int,
    data_parallel_size: int,
    decrease_batch_size_if_needed: bool = False,
) -> None:
    """Initialize number of microbatches calculator. Supporting backward compatibility.

    Args:
        rank (int):
            Rank of the GPU, only rank 0 will log the information.
        rampup_batch_size (Optional[List[int]]):
            Rampup batch size, should be in format of [start_global_batch_size,
            batch_size_increment, ramup_samples].
        global_batch_size (int):
            Global batch size for the model.
        micro_batch_size (int):
            Micro batch size at initialization.
        data_parallel_size (int):
            Data parallel size.
        decrease_batch_size_if_needed (bool, optional):
            If true, scale down batch size to ensure divisibility by DP size * microbatch size.
            Defaults to False.
    """
    _configure_global_num_microbatches_calculator(
        rank,
        rampup_batch_size,
        global_batch_size,
        micro_batch_size,
        data_parallel_size,
        decrease_batch_size_if_needed,
        init=True,
    )


def destroy_num_microbatches_calculator():
    """Destroy number of microbatches calculator."""
    global _GLOBAL_NUM_MICROBATCHES_CALCULATOR
    _GLOBAL_NUM_MICROBATCHES_CALCULATOR = None


def reconfigure_num_microbatches_calculator(
    rank: int,
    rampup_batch_size: Optional[List[int]],
    global_batch_size: int,
    micro_batch_size: int,
    data_parallel_size: int,
    decrease_batch_size_if_needed: bool = False,
) -> None:
    """Reconfigure number of microbatches calculator. Supporting backward compatibility.

    Args:
        rank (int):
            Rank of the GPU, only rank 0 will log the information.
        rampup_batch_size (Optional[List[int]]):
            Rampup batch size, should be in format of
            [start_global_batch_size, batch_size_increment, ramup_samples].
        global_batch_size (int):
            Global batch size for the model.
        micro_batch_size (int):
            Micro batch size at initialization.
        data_parallel_size (int):
            Data parallel size.
        decrease_batch_size_if_needed (bool, optional):
            If true, scale down batch size to ensure divisibility by DP size * microbatch size.
            Defaults to False.
    """
    _configure_global_num_microbatches_calculator(
        rank,
        rampup_batch_size,
        global_batch_size,
        micro_batch_size,
        data_parallel_size,
        decrease_batch_size_if_needed,
        init=False,
    )


def _configure_global_num_microbatches_calculator(
    rank: int,
    rampup_batch_size: Optional[List[int]],
    global_batch_size: int,
    micro_batch_size: int,
    data_parallel_size: int,
    decrease_batch_size_if_needed: bool = False,
    init: bool = False,
) -> None:
    """Configure number of microbatches calculator. Can be used for initialization and
    reconfiguration.

    Args:
        rank (int):
            Rank of the GPU, only rank 0 will log the information.
        rampup_batch_size (Optional[List[int]]):
            Rampup batch size, should be in format of
            [start_global_batch_size, batch_size_increment, ramup_samples].
        global_batch_size (int):
            Global batch size for the model.
        micro_batch_size (int):
            Micro batch size at initialization.
        data_parallel_size (int):
            Data parallel size.
        decrease_batch_size_if_needed (bool, optional):
            If true, scale down batch size to ensure divisibility by DP size * microbatch size.
            Defaults to False.
        init (bool, optional):
            If true, initialize the calculator. Defaults to False.
    """
    global _GLOBAL_NUM_MICROBATCHES_CALCULATOR

    if init:
        assert (
            _GLOBAL_NUM_MICROBATCHES_CALCULATOR is None
        ), 'num microbatches calculator is already initialized.'

    _GLOBAL_NUM_MICROBATCHES_CALCULATOR = _build_num_microbatches_calculator(
        rank,
        rampup_batch_size,
        global_batch_size,
        micro_batch_size,
        data_parallel_size,
        decrease_batch_size_if_needed,
    )


def _build_num_microbatches_calculator(
    rank: int,
    rampup_batch_size: Optional[List[int]],
    global_batch_size: int,
    micro_batch_size: int,
    data_parallel_size: int,
    decrease_batch_size_if_needed: bool,
) -> Union['ConstantNumMicroBatchesCalculator', 'RampupBatchsizeNumMicroBatchesCalculator']:
    """Build number of microbatches calculator. Internal helper method.

    Args:
        rank (int):
            Rank of the GPU, only rank 0 will log the information.
        rampup_batch_size (Optional[List[int]]):
            Rampup batch size, should be in format of
            [start_global_batch_size, batch_size_increment, ramup_samples].
        global_batch_size (int):
            Global batch size for the model.
        micro_batch_size (int):
            Micro batch size at initialization.
        data_parallel_size (int):
            Data parallel size.
        decrease_batch_size_if_needed (bool):
            If true, scale down batch size to ensure divisibility by DP size * microbatch size.
    """

    # Constant batch size.
    if rampup_batch_size is None:
        num_microbatches_calculator = ConstantNumMicroBatchesCalculator(
            global_batch_size,
            micro_batch_size,
            data_parallel_size,
            decrease_batch_size_if_needed,
            rank,
        )
        if rank == 0:
            logger.info(
                f'setting number of microbatches to constant {num_microbatches_calculator.get()}'
            )
    # Batch size ramp up.
    else:
        assert len(rampup_batch_size) == 3, (
            'expected the following '
            'format: --rampup-batch-size <start batch size> '
            '<batch size incerement> <ramp-up samples>'
        )
        start_global_batch_size = int(rampup_batch_size[0])
        batch_size_increment = int(rampup_batch_size[1])
        ramup_samples = int(rampup_batch_size[2])
        if rank == 0:
            logger.info(
                f'will use batch size rampup starting from global batch size '
                f'{start_global_batch_size} to global batch size {global_batch_size} with batch'
                f'size increments {batch_size_increment} over {ramup_samples} samples.'
            )
        num_microbatches_calculator = RampupBatchsizeNumMicroBatchesCalculator(
            global_batch_size,
            micro_batch_size,
            data_parallel_size,
            decrease_batch_size_if_needed,
            rank,
            start_global_batch_size,
            batch_size_increment,
            ramup_samples,
        )

    return num_microbatches_calculator


def _round(batch_size: int, divisor: int) -> int:
    """Round `batch_size` down to nearest batch size divisible by `divisor`."""
    return (batch_size // divisor) * divisor


class NumMicroBatchesCalculator(ABC):
    """Base class for number of microbatches calculator."""

    def __init__(self) -> None:
        self.num_micro_batches = None
        self.current_global_batch_size = None
        self.micro_batch_size = None
        self.current_running_global_batch_size = None

    def get(self) -> int:
        """Get number of microbatches."""
        return self.num_micro_batches

    def get_current_global_batch_size(self) -> int:
        """Get current global batch size."""
        return self.current_global_batch_size

    def get_micro_batch_size(self) -> int:
        """Get current global batch size."""
        return self.micro_batch_size

    def get_current_running_global_batch_size(self) -> int:
        """Get current running global batch size. If decrease_batch_size_if_needed is False,
        this just equals global batch size."""
        return self.current_running_global_batch_size

    @abstractmethod
    def update(self, consumed_samples, consistency_check, verbose=False) -> None:
        """Update number of microbatches depending on batch size rampup."""
        pass


class ConstantNumMicroBatchesCalculator(NumMicroBatchesCalculator):
    """Calculator of number of microbatches with constant global batch size.

    Args:
        global_batch_size (int):
            Global batch size.
        micro_batch_size (int):
            Micro batch size.
        data_parallel_size (int):
            Data parallel size.
        decrease_batch_size_if_needed (bool):
            If true, decrease batch size to ensure divisibility by DP size * microbatch size
            (if needed).
        rank (int):
            Rank (to determine whether logging should be performed).
    """

    def __init__(
        self,
        global_batch_size: int,
        micro_batch_size: int,
        data_parallel_size: int,
        decrease_batch_size_if_needed: bool,
        rank: int,
    ) -> None:

        micro_batch_times_data_parallel_size = micro_batch_size * data_parallel_size
        if decrease_batch_size_if_needed:
            running_global_batch_size = _round(
                global_batch_size, micro_batch_times_data_parallel_size
            )
            assert running_global_batch_size % micro_batch_times_data_parallel_size == 0
            if rank == 0:
                logger.info(
                    f'decreasing batch size from {global_batch_size} to {running_global_batch_size}'
                )
            self.num_micro_batches = (
                running_global_batch_size // micro_batch_times_data_parallel_size
            )
        else:
            assert global_batch_size % micro_batch_times_data_parallel_size == 0, (
                'global batch size ({}) is not divisible by micro batch size ({})'
                ' times data parallel size ({})'.format(
                    global_batch_size, micro_batch_size, data_parallel_size
                )
            )
            running_global_batch_size = global_batch_size
            self.num_micro_batches = global_batch_size // micro_batch_times_data_parallel_size
        assert (
            self.num_micro_batches >= 1
        ), 'number of microbatches should be at least 1, got {}.'.format(self.num_micro_batches)

        self.current_global_batch_size = global_batch_size
        self.current_running_global_batch_size = running_global_batch_size
        self.micro_batch_size = micro_batch_size

    def update(self, consumed_samples, consistency_check, verbose=False) -> None:
        pass


class RampupBatchsizeNumMicroBatchesCalculator(NumMicroBatchesCalculator):
    """Calculator of number of microbatches with batch size rampup.

    Over `steps = (global-batch-size - start-batch-size) / batch_size_increment` increment batch
    size from start-batch-size to global-batch-size using rampup-samples / steps
    samples.

    Args:
        global_batch_size (int):
            Global batch size post rampup.
        micro_batch_size (int):
            Micro batch size.
        data_parallel_size (int):
            Data parallel size.
        decrease_batch_size_if_needed (bool):
            If true, decrease batch size to ensure divisibility by DP size * microbatch size
            (if needed).
        rank (int):
            Rank (to determine whether logging should be performed).
        start_global_batch_size (int):
            Global batch size to start with.
        batch_size_increment (int):
            Global batch size increments.
        ramup_samples (int):
            Number of samples to use ramp up global
            batch size from `start_global_batch_size` to `global_batch_size`.
    """

    def __init__(
        self,
        global_batch_size: int,
        micro_batch_size: int,
        data_parallel_size: int,
        decrease_batch_size_if_needed: bool,
        rank: int,
        start_global_batch_size: int,
        batch_size_increment: int,
        ramup_samples: int,
    ) -> None:
        assert global_batch_size > 0, 'global batch size should be positive, got {}.'.format(
            global_batch_size
        )
        assert start_global_batch_size > 0, 'start batch size should be positive, got {}.'.format(
            start_global_batch_size
        )
        assert batch_size_increment > 0, 'batch size increment should be positive, got {}.'.format(
            batch_size_increment
        )
        assert ramup_samples >= 0, 'ramp-up samples should be non-negative, got {}.'.format(
            ramup_samples
        )

        self.global_batch_size = global_batch_size
        self.micro_batch_size = micro_batch_size
        self.data_parallel_size = data_parallel_size
        self.decrease_batch_size_if_needed = decrease_batch_size_if_needed
        self.rank = rank
        self.start_global_batch_size = start_global_batch_size
        self.batch_size_increment = batch_size_increment
        self.ramup_samples = ramup_samples

        self.micro_batch_times_data_parallel_size = self.micro_batch_size * self.data_parallel_size
        assert self.micro_batch_times_data_parallel_size > 0

        self.current_global_batch_size = None

        diff_batch_size = self.global_batch_size - self.start_global_batch_size
        assert diff_batch_size >= 0, (
            'expected global batch size to be greater than or equal to start batch size, '
            f'got {self.global_batch_size} and {self.start_global_batch_size}'
        )
        assert diff_batch_size % batch_size_increment == 0, (
            'expected '
            f'global batch size interval ({diff_batch_size}) to be divisible by global batch '
            f'size increment ({batch_size_increment})'
        )

        num_increments = diff_batch_size // self.batch_size_increment
        self.rampup_samples_per_increment = self.ramup_samples / num_increments

        # Initialize number of microbatches.
        self.update(0, False)

    def update(self, consumed_samples: int, consistency_check: bool, verbose: bool = False) -> None:
        """Update number of microbatches.

        Args:
            consumed_samples (int): Number of samples consumed.
            consistency_check (bool): Option to check current schedule's consistency.
            verbose (bool, optional): Option to control logging. Defaults to False.
        """

        # Update current global batch size.
        global_batch_size_changed = False
        old_current_global_batch_size = self.current_global_batch_size
        if consumed_samples > self.ramup_samples:
            self.current_global_batch_size = self.global_batch_size
        else:
            steps = int(consumed_samples / self.rampup_samples_per_increment)
            self.current_global_batch_size = (
                self.start_global_batch_size + steps * self.batch_size_increment
            )
            assert self.current_global_batch_size <= self.global_batch_size
        if old_current_global_batch_size != self.current_global_batch_size:
            global_batch_size_changed = True
        if self.rank == 0 and global_batch_size_changed and verbose:
            logger.info(
                f'ramping up batch size from {old_current_global_batch_size} to '
                f'{self.current_global_batch_size}'
            )

        # Check consistency of the current global batch size.
        if consistency_check and not self.decrease_batch_size_if_needed:
            assert (
                self.current_global_batch_size % self.micro_batch_times_data_parallel_size == 0
            ), (
                'current global '
                'batch size ({}) is not divisible by micro-batch-size ({}) times'
                'data parallel size ({})'.format(
                    self.current_global_batch_size, self.micro_batch_size, self.data_parallel_size
                )
            )

        if (
            self.decrease_batch_size_if_needed
            and self.current_global_batch_size % self.micro_batch_times_data_parallel_size != 0
        ):
            self.current_running_global_batch_size = _round(
                self.current_global_batch_size, self.micro_batch_times_data_parallel_size
            )
            if self.rank == 0 and global_batch_size_changed and verbose:
                logger.info(
                    f'decreasing batch size from {self.current_global_batch_size} to '
                    f'{self.current_running_global_batch_size}'
                )
            assert (
                self.current_running_global_batch_size % self.micro_batch_times_data_parallel_size
                == 0
            )
        else:
            self.current_running_global_batch_size = self.current_global_batch_size

        self.num_micro_batches = (
            self.current_running_global_batch_size // self.micro_batch_times_data_parallel_size
        )
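To make the calculator's behavior concrete, here is a minimal, hypothetical driver loop using only the public functions defined in this file. The batch-size numbers are placeholders chosen to satisfy the divisibility checks the calculator enforces (micro batch 2, data-parallel size 8, so global batch sizes must be multiples of 16).

# Hypothetical example: ramp the global batch size from 256 to 1024 in steps of 128
# over the first 1,000,000 consumed samples.
init_num_microbatches_calculator(
    rank=0,
    rampup_batch_size=[256, 128, 1_000_000],
    global_batch_size=1024,
    micro_batch_size=2,
    data_parallel_size=8,
)

consumed_samples = 0
for step in range(5):
    update_num_microbatches(consumed_samples, consistency_check=True, verbose=True)
    # Each data-parallel rank now runs get_num_microbatches() microbatches this step.
    print(step, get_current_global_batch_size(), get_num_microbatches())
    consumed_samples += get_current_global_batch_size()

destroy_num_microbatches_calculator()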
megatron/core/optimizer/__init__.py (new file, mode 100644, +445 -0; view file @ 4b097dee)

# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

import logging
from typing import Callable, Dict, List, Optional, Tuple

import torch

try:
    from transformer_engine.pytorch.optimizers import FusedAdam as Adam
    from transformer_engine.pytorch.optimizers import FusedSGD as SGD
except ImportError:
    try:
        from apex.optimizers import FusedAdam as Adam
        from apex.optimizers import FusedSGD as SGD
    except ImportError:
        import warnings

        warnings.warn(
            f'Transformer Engine and Apex are not installed. Falling back to Torch optimizers.'
        )

        # Apex's FusedAdam is a drop-in replacement for torch's AdamW.
        # pylint: disable-next=line-too-long.
        # See https://github.com/NVIDIA/apex/blob/7b73b12361068a10b0f44844534613f252a5ea75/apex/optimizers/fused_adam.py#L16.
        from torch.optim import AdamW as Adam, SGD

from megatron.core import mpu

from ..distributed.param_and_grad_buffer import _ParamAndGradBuffer
from ..transformer.module import MegatronModule
from ..utils import log_single_rank
from .distrib_optimizer import DistributedOptimizer
from .grad_scaler import ConstantGradScaler, DynamicGradScaler
from .optimizer import (
    ChainedOptimizer,
    Float16OptimizerWithFloat16Params,
    FP32Optimizer,
    MegatronOptimizer,
)
from .optimizer_config import OptimizerConfig

logger = logging.getLogger(__name__)


def _get_param_groups(
    model_chunks: List[MegatronModule],
    no_weight_decay_cond: Optional[Callable],
    scale_lr_cond: Optional[Callable],
    lr_mult: float,
    lr: float,
    min_lr: float,
    decoupled_lr: Optional[float],
    decoupled_min_lr: Optional[float],
) -> List[Dict]:
    """Create parameter groups for optimizer.

    Creates parameter groups based on weight decay condition (regularized vs
    non regularized), learning rate scale condition (lr vs lr_mult * lr),
    and whether it is expert parameters. scale_lr_cond is used during finetuning
    where head of the network requires a scaled version of the base learning rate.

    Args:
        model_chunks (List[MegatronModule]): model chunks to create parameter
            groups for.
        no_weight_decay_cond (func, optional): function to determine whether a
            parameter should not perform weight decay.
        scale_lr_cond (func, optional): function to determine whether a parameter
            should have a scaled learning rate.
        lr_mult (float): learning rate multiplier for parameters that
            satisfy scale_lr_cond.
        lr (float): learning rate.
        min_lr (float): minimum learning rate.
        decoupled_lr (Optional[float]): optional decoupled learning rate.
        decoupled_min_lr (Optional[float]): optional decoupled minimum learning rate.

    Returns:
        List of parameter groups.
    """

    use_decoupled_learning_rate = decoupled_lr is not None

    # Map (wd_mult, lr_mult, is_expert_parallel, is_decoupled_lr) to params.
    params_map = {}
    for model_chunk in model_chunks:
        for name, param in model_chunk.named_parameters():
            if not param.requires_grad:
                continue

            is_expert_parallel = not getattr(param, 'allreduce', True)

            if no_weight_decay_cond is not None:
                no_wd = no_weight_decay_cond(name, param)
            else:
                # Do not regularize biases and norm parameters.
                no_wd = name.endswith(".bias") or len(param.shape) == 1

            if scale_lr_cond is not None:
                scale_lr = scale_lr_cond(name, param)
            else:
                scale_lr = False

            if not no_wd and not scale_lr:
                wd_mult, _lr_mult = 1.0, 1.0
            elif not no_wd and scale_lr:
                wd_mult, _lr_mult = 1.0, lr_mult
            elif no_wd and not scale_lr:
                wd_mult, _lr_mult = 0.0, 1.0
            else:
                wd_mult, _lr_mult = 0.0, lr_mult

            is_decoupled_lr = False
            # For input/embedding and output layer: embedding.word_embeddings.weight /
            # output_layer.weight.
            if use_decoupled_learning_rate and getattr(
                param, 'is_embedding_or_output_parameter', False
            ):
                is_decoupled_lr = True

            key = (wd_mult, _lr_mult, is_expert_parallel, is_decoupled_lr)
            if key not in params_map:
                params_map[key] = []
            params_map[key].append(param)

    param_groups = []
    for (wd_mult, _lr_mult, is_expert_parallel, is_decoupled_lr), params in params_map.items():
        assert len(params) > 0
        param_group = {
            'params': params,
            'wd_mult': wd_mult,
            'lr_mult': _lr_mult,
            'is_expert_parallel': is_expert_parallel,
            'is_decoupled_lr': is_decoupled_lr,
        }
        param_groups.append(param_group)

    param_groups = _update_min_and_max_lr_in_param_groups(
        param_groups,
        lr=lr,
        min_lr=min_lr,
        decoupled_lr=decoupled_lr,
        decoupled_min_lr=decoupled_min_lr,
    )

    return param_groups


def _update_min_and_max_lr_in_param_groups(
    param_groups: List[Dict],
    lr: float,
    min_lr: float,
    decoupled_lr: Optional[float],
    decoupled_min_lr: Optional[float],
) -> List[Dict]:
    """
    Updates `max_lr` and `min_lr` values in each parameter group, and returns new list.
    By default, each group will use `lr` / `min_lr` as `max_lr` / `min_lr`.
    If `decoupled_lr` is provided, then `decoupled_lr` / `decoupled_min_lr` will be used
    as `max_lr` / `min_lr` for the input and output layer.

    Args:
        param_groups (List): parameter groups whose 'max_lr' and `min_lr` fields need to
            be adjusted.
        lr (float): learning rate.
        min_lr (float): minimum learning rate.
        decoupled_lr (Optional[float]): optional decoupled learning rate.
        decoupled_min_lr (Optional[float]): optional decoupled minimum learning rate.

    Returns:
        List of adjusted parameter groups.
    """

    if decoupled_min_lr is None:
        decoupled_min_lr = min_lr

    for param_group in param_groups:
        if param_group['is_decoupled_lr']:
            assert decoupled_lr is not None
            param_group['max_lr'] = decoupled_lr
            param_group['min_lr'] = decoupled_min_lr
        else:
            param_group['max_lr'] = lr
            param_group['min_lr'] = min_lr
    return param_groups


def _get_param_groups_and_buffers(
    model_chunks: List[MegatronModule],
    model_chunk_offset: int,
    config: OptimizerConfig,
    no_weight_decay_cond: Optional[Callable],
    scale_lr_cond: Optional[Callable],
    lr_mult: float,
    filter_fn: Callable,
    buffer_name: str,
) -> Tuple[List[Dict], Dict[int, List[_ParamAndGradBuffer]]]:
    """Returns parameter groups and buffer for optimizer.

    Args:
        model_chunks (List[MegatronModule]): model chunks to create parameter
            groups for.
        model_chunk_offset (int): offset of model_chunks in global model_chunks list.
        config (OptimizerConfig): optimizer configuration object.
        no_weight_decay_cond (func, optional): function to determine whether a
            parameter should not perform weight decay.
        scale_lr_cond (func, optional): function to determine whether a parameter
            should have a scaled learning rate.
        lr_mult (float): learning rate multiplier for parameters that
            satisfy scale_lr_cond.
        lr (float): learning rate.
        min_lr (float): minimum learning rate.
        filter_fn (callable): filtering function for param_groups.
        buffer_name (str): name of buffer.

    Returns:
        List of parameter groups and dictionary of model chunk IDs to buffers.
    """
    param_groups = _get_param_groups(
        model_chunks,
        no_weight_decay_cond,
        scale_lr_cond,
        lr_mult,
        lr=config.lr,
        min_lr=config.min_lr,
        decoupled_lr=config.decoupled_lr,
        decoupled_min_lr=config.decoupled_min_lr,
    )
    param_groups = list(filter(filter_fn, param_groups))
    buffers = {}
    for model_chunk_idx, model_chunk in enumerate(model_chunks):
        if hasattr(model_chunk, buffer_name):
            buffers[model_chunk_idx + model_chunk_offset] = getattr(model_chunk, buffer_name)

    return param_groups, buffers


def _get_megatron_optimizer_based_on_param_groups(
    config: OptimizerConfig,
    model_chunks: List[MegatronModule],
    param_groups: List,
    per_model_buffers: Optional[Dict[int, List[_ParamAndGradBuffer]]] = None,
    model_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
    data_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
    data_parallel_group_gloo: Optional[torch.distributed.ProcessGroup] = None,
    data_parallel_group_idx: Optional[int] = None,
) -> MegatronOptimizer:
    """Get Megatron optimizer based on parameter groups.

    Args:
        config (OptimizerConfig): optimizer configuration object.
        model_chunks (list): list of model chunks.
        param_groups (list): list of parameter groups.
        per_model_buffers (dict, optional): buffers for distributed optimizer. Defaults to None.
        data_parallel_group (torch.distributed.ProcessGroup, optional): data-parallel group for
            distributed optimizer. Defaults to None.
        data_parallel_group_gloo (torch.distributed.ProcessGroup, optional): gloo data-parallel
            group for distributed optimizer. Defaults to None.
        data_parallel_group_idx (int, optional): data-parallel group index for distributed
            optimizer. Defaults to None.

    Returns:
        Instance of MegatronOptimizer.
    """
    if config.optimizer == 'adam':
        optimizer = Adam(
            param_groups,
            lr=config.lr,
            weight_decay=config.weight_decay,
            betas=(config.adam_beta1, config.adam_beta2),
            eps=config.adam_eps,
        )

        def init_state_fn(opt):
            for group in opt.param_groups:
                for p in group['params']:
                    if len(opt.state[p]) == 0:
                        opt.state[p]['exp_avg'] = torch.zeros_like(p.data)
                        opt.state[p]['exp_avg_sq'] = torch.zeros_like(p.data)

    elif config.optimizer == 'sgd':
        optimizer = SGD(
            param_groups,
            lr=config.lr,
            weight_decay=config.weight_decay,
            momentum=config.sgd_momentum,
        )
        init_state_fn = None
    else:
        raise Exception('{} optimizer is not supported.'.format(config.optimizer))

    # Mixed precision optimizer.
    # - Note: both the Float16Optimizer and the DistributedOptimizer inherit
    #   from the MixedPrecisionOptimizer, which manages any optimizer where
    #   the model params and main params are distinct.
    if config.fp16 or config.bf16 or config.use_distributed_optimizer:

        # Grad scaler:
        #    if loss-scale is provided, instantiate the constant scaler.
        #    if we are using fp16 and loss-scale is not present, use a
        #       dynamic scaler.
        #    otherwise we are running in bf16 with no loss-scale so
        #       leave it as None.
        grad_scaler = None

        # Constant loss scale.
        if config.loss_scale:
            grad_scaler = ConstantGradScaler(config.loss_scale)

        # Dynamic loss scale.
        else:
            if config.fp16:
                grad_scaler = DynamicGradScaler(
                    initial_scale=config.initial_loss_scale,
                    min_scale=config.min_loss_scale,
                    growth_factor=2.0,
                    backoff_factor=0.5,
                    growth_interval=config.loss_scale_window,
                    hysteresis=config.hysteresis,
                )

        optimizer_args = [optimizer, config, grad_scaler, init_state_fn]
        if config.use_distributed_optimizer:
            optimizer = DistributedOptimizer(
                *optimizer_args,
                model_chunks=model_chunks,
                per_model_buffers=per_model_buffers,
                data_parallel_group=data_parallel_group,
                data_parallel_group_gloo=data_parallel_group_gloo,
                data_parallel_group_idx=data_parallel_group_idx,
            )
        else:
            optimizer = Float16OptimizerWithFloat16Params(*optimizer_args)
        setattr(optimizer, 'model_parallel_group', model_parallel_group)
    else:
        # FP32 optimizer.
        optimizer = FP32Optimizer(optimizer, config, init_state_fn)
        setattr(optimizer, 'model_parallel_group', model_parallel_group)

    return optimizer


def get_megatron_optimizer(
    config: OptimizerConfig,
    model_chunks: List[MegatronModule],
    no_weight_decay_cond: Optional[Callable] = None,
    scale_lr_cond: Optional[Callable] = None,
    lr_mult: float = 1.0,
) -> MegatronOptimizer:
    """Retrieve the Megatron optimizer for model chunks.

    We use separate optimizers for expert parameters and non-expert parameters.

    Args:
        config (OptimizerConfig): optimizer configuration object.
        model_chunks (List[MegatronModule]): model chunks to get optimizer for.
        no_weight_decay_cond (func, optional): function to determine whether a parameter
            should not perform weight decay. Defaults to None.
        scale_lr_cond (func, optional): function to determine whether a parameter
            should have a scaled learning rate. Defaults to None.
        lr_mult (float, optional): learning rate multiplier for parameters that
            satisfy scale_lr_cond. Defaults to 1.0.

    Returns:
        Instance of MegatronOptimizer.
    """

    log_single_rank(logger, logging.INFO, f'Setting up optimizer with config {config}')

    # Separate out first model chunk if overlapping param AG with optimizer step.
    if config.overlap_param_gather_with_optimizer_step:
        all_dense_model_chunks = [[model_chunks[0]], model_chunks[1:]]
        overlap_param_gather_with_optimizer_step_flags = [True, False]
    else:
        all_dense_model_chunks = [model_chunks]
        overlap_param_gather_with_optimizer_step_flags = [False]
    model_parallel_rank = torch.distributed.get_rank(mpu.get_model_parallel_group())

    optimizers = []
    model_chunk_offset = 0
    for dense_model_chunks, overlap_param_gather_with_optimizer_step in zip(
        all_dense_model_chunks, overlap_param_gather_with_optimizer_step_flags
    ):
        param_groups, buffers = _get_param_groups_and_buffers(
            dense_model_chunks,
            model_chunk_offset=model_chunk_offset,
            config=config,
            no_weight_decay_cond=no_weight_decay_cond,
            scale_lr_cond=scale_lr_cond,
            lr_mult=lr_mult,
            filter_fn=lambda g: not g['is_expert_parallel'],
            buffer_name='buffers',
        )
        for model_chunk in dense_model_chunks:
            model_chunk.overlap_param_gather_with_optimizer_step = (
                overlap_param_gather_with_optimizer_step
            )
        optimizers.append(
            _get_megatron_optimizer_based_on_param_groups(
                config,
                model_chunks=dense_model_chunks,
                param_groups=param_groups,
                per_model_buffers=buffers,
                model_parallel_group=mpu.get_model_parallel_group(),
                data_parallel_group=mpu.get_data_parallel_group(with_context_parallel=True),
                data_parallel_group_gloo=mpu.get_data_parallel_group_gloo(
                    with_context_parallel=True
                ),
                data_parallel_group_idx=model_parallel_rank,
            )
        )
        model_chunk_offset += 1

    moe_param_groups, moe_buffers = _get_param_groups_and_buffers(
        model_chunks,
        model_chunk_offset=0,
        config=config,
        no_weight_decay_cond=no_weight_decay_cond,
        scale_lr_cond=scale_lr_cond,
        lr_mult=lr_mult,
        filter_fn=lambda g: g['is_expert_parallel'],
        buffer_name='expert_parallel_buffers',
    )
    if len(moe_param_groups) > 0:
        model_parallel_world_size = torch.distributed.get_world_size(mpu.get_model_parallel_group())
        expert_parallel_rank = mpu.get_expert_model_parallel_rank()
        optimizers.append(
            _get_megatron_optimizer_based_on_param_groups(
                config,
                model_chunks=model_chunks,
                param_groups=moe_param_groups,
                per_model_buffers=moe_buffers,
                model_parallel_group=mpu.get_model_parallel_group(with_expert_parallel=True),
                data_parallel_group=mpu.get_data_modulo_expert_parallel_group(
                    with_context_parallel=True
                ),
                data_parallel_group_gloo=mpu.get_data_modulo_expert_parallel_group_gloo(
                    with_context_parallel=True
                ),
                data_parallel_group_idx=expert_parallel_rank * model_parallel_world_size
                + model_parallel_rank,
            )
        )

    if len(optimizers) == 1:
        return optimizers[0]

    return ChainedOptimizer(optimizers)
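A hypothetical sketch of how `get_megatron_optimizer` is typically driven, assuming model-parallel state (`mpu`) is already initialized and `model` is a list of DDP-wrapped MegatronModule chunks. The hyperparameter values are placeholders; the field names mirror the OptimizerConfig attributes referenced in the file above (lr, min_lr, weight_decay, adam_beta1/2, bf16, use_distributed_optimizer).

# Hypothetical configuration and optimizer construction.
config = OptimizerConfig(
    optimizer='adam',
    lr=3.0e-4,
    min_lr=3.0e-5,
    weight_decay=0.1,
    adam_beta1=0.9,
    adam_beta2=0.95,
    bf16=True,
    use_distributed_optimizer=True,
)
optimizer = get_megatron_optimizer(config, model_chunks=model)
# MoE models additionally get a separate optimizer for expert-parallel parameters;
# in that case the returned object is a ChainedOptimizer wrapping both.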
megatron/optimizer/clip_grads.py → megatron/core/optimizer/clip_grads.py (renamed and updated; view file @ 4b097dee)

This section is GitLab's inline diff view: where a line changed, the pre-change and post-change versions appear on consecutive lines, and '@@' headers and '...' lines mark collapsed, unchanged context.

# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

"""Gradient clipping."""

import os
from typing import List, Optional, Union

import torch
from torch import inf

from apex.multi_tensor_apply import multi_tensor_applier
import amp_C

from megatron.model.module import param_is_not_shared
from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate


def clip_grad_norm_fp32(parameters, grads_for_norm, max_norm, norm_type=2, model_parallel_group=None):
    """Clips gradient norm of an iterable of parameters whose gradients
    are in fp32.

try:
    from transformer_engine.pytorch.optimizers import (
        multi_tensor_applier,
        multi_tensor_l2norm,
        multi_tensor_scale,
    )

    l2_norm_impl = multi_tensor_l2norm
    multi_tensor_scale_impl = multi_tensor_scale
except ImportError:
    try:
        import amp_C
        from apex.multi_tensor_apply import multi_tensor_applier

        l2_norm_impl = amp_C.multi_tensor_l2norm
        multi_tensor_scale_impl = amp_C.multi_tensor_scale
    except ImportError:
        import warnings

        warnings.warn(
            f'Transformer Engine and Apex are not installed. '
            'Falling back to local implementations of multi_tensor_applier, '
            'multi_tensor_l2norm, and multi_tensor_scale'
        )

        from megatron.core.utils import (
            local_multi_tensor_applier,
            local_multi_tensor_l2_norm,
            local_multi_tensor_scale,
        )

        multi_tensor_applier = local_multi_tensor_applier
        l2_norm_impl = local_multi_tensor_l2_norm
        multi_tensor_scale_impl = local_multi_tensor_scale

from ..tensor_parallel import param_is_not_tensor_parallel_duplicate
from ..transformer.module import param_is_not_shared


def get_grad_norm_fp32(
    grads_for_norm: Union[List[torch.Tensor], torch.Tensor],
    norm_type: Union[int, float] = 2,
    model_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
) -> float:
    """Calculate the norm of gradients in fp32.

    This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
    added functionality to handle model parallel parameters. Note that
    the gradients are modified in place.
    added functionality to handle model parallel parameters.

    Arguments:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        grads_for_norm (Iterable[Tensor]): an iterable of Tensors or a single
        grads_for_norm (Iterable[Tensor] or Tensor): an iterable of Tensors or a single
            Tensor that will be used for calculating the grad norm.
        max_norm (float or int): max norm of the gradients
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.
        model_parallel_group (group): given the nature of the distributed
    ...
@@ -37,76 +70,101 @@ def clip_grad_norm_fp32(parameters, grads_for_norm,
        Total norm of the parameters (viewed as a single vector).
    """

    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    if isinstance(grads_for_norm, torch.Tensor):
        grads_for_norm = [grads_for_norm]

    # Grads.
    grads = []
    for param in parameters:
        if param.grad is not None:
            assert param.grad.type() == 'torch.cuda.FloatTensor'
            grads.append(param.grad.detach())

    # Norm parameters.
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    total_norm = 0.0

    # Calculate norm.
    if norm_type == inf:
        total_norm = max(grad.abs().max() for grad in grads_for_norm)
        total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
        total_norm_cuda = torch.tensor([float(total_norm)], dtype=torch.float, device='cuda')
        # Take max across all model-parallel GPUs.
        torch.distributed.all_reduce(total_norm_cuda,
                                     op=torch.distributed.ReduceOp.MAX,
                                     group=model_parallel_group)
        torch.distributed.all_reduce(
            total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=model_parallel_group
        )
        total_norm = total_norm_cuda[0].item()

    else:
        if norm_type == 2.0:
            dummy_overflow_buf = torch.cuda.IntTensor([0])
            dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device='cuda')
            # Use apex's multi-tensor applier for efficiency reasons.
            # Multi-tensor applier takes a function and a list of list
            # and performs the operation on that list all in one kernel.
            if grads_for_norm:
                grad_norm, _ = multi_tensor_applier(
                    amp_C.multi_tensor_l2norm,
                    l2_norm_impl,
                    dummy_overflow_buf,
                    [grads_for_norm],
                    False # no per-parameter norm
                    False,  # no per-parameter norm
                )
            else:
                grad_norm = torch.cuda.FloatTensor([0])
                grad_norm = torch.tensor([0], dtype=torch.float, device='cuda')
            # Since we will be summing across data parallel groups,
            # we need the pow(norm-type).
            total_norm = grad_norm ** norm_type
            total_norm = grad_norm**norm_type
        else:
            for grad in grads_for_norm:
                grad_norm = torch.norm(grad, norm_type)
                total_norm += grad_norm ** norm_type
                total_norm += grad_norm**norm_type

        # Sum across all model-parallel GPUs.
        torch.distributed.all_reduce(total_norm,
                                     op=torch.distributed.ReduceOp.SUM,
                                     group=model_parallel_group)
        torch.distributed.all_reduce(
            total_norm, op=torch.distributed.ReduceOp.SUM, group=model_parallel_group
        )
        total_norm = total_norm.item() ** (1.0 / norm_type)

    return total_norm


def clip_grad_by_total_norm_fp32(
    parameters: Union[List[torch.Tensor], torch.Tensor],
    max_norm: Union[int, float],
    total_norm: float,
):
    """Clips gradient of an iterable of parameters in fp32 by total norm.

    Note that the gradients are modified in place.

    Args:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized.
        max_norm (float or int): max norm of the gradients.
        total_norm (float): total norm of the gradients.
    """
    # Grads.
    grads = []
    for param in parameters:
        if param.grad is not None:
            assert param.grad.type() == 'torch.cuda.FloatTensor'
            grads.append(param.grad.detach())

    # Scale.
    clip_coeff = max_norm / (total_norm + 1.0e-6)
    if clip_coeff < 1.0:
        dummy_overflow_buf = torch.cuda.IntTensor([0])
        multi_tensor_applier(amp_C.multi_tensor_scale,
                             dummy_overflow_buf,
                             [grads, grads],
                             clip_coeff)
        dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device='cuda')
        multi_tensor_applier(
            multi_tensor_scale_impl, dummy_overflow_buf, [grads, grads], clip_coeff
        )

    return total_norm


def count_zeros_fp32(
    parameters: Union[List[torch.Tensor], torch.Tensor],
    model_parallel_group: torch.distributed.ProcessGroup,
) -> float:
    """Counts the number of zeros in gradients associated with the passed-in list of
    parameters.

def count_zeros_fp32(parameters, model_parallel_group):

    Args:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have the number of zeros in its corresponding
            gradient counted.
        model_parallel_group (torch.distributed.ProcessGroup, optional): model-parallel
            group over which grad norm needs to be aggregated.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    ...
@@ -115,7 +173,7 @@ def count_zeros_fp32(parameters, model_parallel_group):
    # - grad should not be none
    # - parameter should not be shared
    # - should not be a replica due to tensor model parallelism
    total_num_zeros = torch.cuda.FloatTensor([0.0])
    total_num_zeros = torch.tensor([0.0], dtype=torch.float, device='cuda')
    for param in parameters:
        grad_not_none = param.grad is not None
        is_not_shared = param_is_not_shared(param)
    ...
@@ -126,9 +184,9 @@ def count_zeros_fp32(parameters, model_parallel_group):
        total_num_zeros = num_zeros + total_num_zeros

    # Sum across all model-parallel GPUs.
    torch.distributed.all_reduce(total_num_zeros,
                                 op=torch.distributed.ReduceOp.SUM,
                                 group=model_parallel_group)
    torch.distributed.all_reduce(
        total_num_zeros, op=torch.distributed.ReduceOp.SUM, group=model_parallel_group
    )
    total_num_zeros = total_num_zeros.item()
    ...
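The split into a norm computation (`get_grad_norm_fp32`) and a separate clipping step (`clip_grad_by_total_norm_fp32`) lets a caller reduce the norm across model-parallel ranks once, then scale gradients by `max_norm / (total_norm + 1e-6)` only when that coefficient is below 1. A hedged sketch of how the two functions shown in this diff could compose on the caller side; the gradient lists and process group are assumed to be provided by the caller.

# Hypothetical caller-side composition of the two functions above.
total_norm = get_grad_norm_fp32(
    grads_for_norm,                 # deduplicated fp32 grads used for the norm
    norm_type=2,
    model_parallel_group=mp_group,  # assumed model-parallel process group
)
clip_grad_by_total_norm_fp32(params_with_fp32_grads, max_norm=1.0, total_norm=total_norm)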
megatron/core/optimizer/distrib_optimizer.py
0 → 100644
View file @
4b097dee
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Megatron distributed optimizer."""
import
itertools
import
warnings
from
dataclasses
import
replace
from
logging
import
getLogger
from
typing
import
Callable
,
Dict
,
List
,
Optional
,
Tuple
import
torch
HAVE_APEX_OR_TE
=
True
try
:
from
transformer_engine.pytorch.optimizers
import
FusedAdam
as
Adam
except
ImportError
:
try
:
from
apex.optimizers
import
FusedAdam
as
Adam
except
ImportError
:
from
torch.optim
import
Adam
HAVE_APEX_OR_TE
=
False
from
..
import
tensor_parallel
from
..config_logger
import
has_config_logger_enabled
,
log_config_to_disk
from
..dist_checkpointing
import
ShardedTensor
from
..dist_checkpointing.dict_utils
import
nested_values
from
..dist_checkpointing.mapping
import
(
LocalNonpersistentObject
,
ShardedObject
,
ShardedStateDict
,
ShardedTensorFactory
,
)
from
..dist_checkpointing.utils
import
extract_sharded_tensors_and_factories
from
..distributed.param_and_grad_buffer
import
_ParamAndGradBuffer
,
partition_buckets
from
..transformer.module
import
MegatronModule
from
..utils
import
is_float8tensor
from
.grad_scaler
import
MegatronGradScaler
from
.optimizer
import
(
MixedPrecisionOptimizer
,
_multi_tensor_copy_this_to_that
,
_zero_grad_group_helper
,
)
from
.optimizer_config
import
OptimizerConfig
try
:
# This will be used when "--fp8-param-gather" is enabled.
# When BF16/FP16 parameters don't exist, we need to cast the FP32 main parameters to
# FP8 directly in the optimizer.
from
transformer_engine.pytorch.cpp_extensions
import
cast_to_fp8
except
:
pass
logger
=
getLogger
(
__name__
)
class
Range
:
"""
A range represents a start and end points for indexing a shard
from a full tensor.
Args:
start (int): Start index.
end (int): End index.
"""
def
__init__
(
self
,
start
:
int
,
end
:
int
):
self
.
start
=
start
self
.
end
=
end
self
.
size
=
end
-
start
def
normalize
(
self
,
start
:
int
=
0
):
"""Shift start/end indexes to start at new start index.
Both start and end indexes will be shifted by [new start] - [old start].
Args:
start (int): New start index.
"""
return
Range
(
start
,
start
+
self
.
size
)
def
__str__
(
self
):
return
"%d,%d [%d]"
%
(
self
.
start
,
self
.
end
,
self
.
size
)
def
__len__
(
self
):
return
self
.
end
-
self
.
start
class
DistributedOptimizer
(
MixedPrecisionOptimizer
):
"""Distributed optimizer, for all data types (fp16, bf16, and fp32).
See __init__() below for argument details.
"""
@
classmethod
def
_build_model_gbuf_param_range_map
(
cls
,
param_world_index_map
:
Dict
[
torch
.
nn
.
Parameter
,
Tuple
],
gbuf_world_range
:
Range
,
bucket_offset
:
int
,
):
"""
Build mapping from param reference to grad buffer shard ranges.
This method builds a mapping from parameter references to grad
buffer shard ranges, specific to each data-parallel (DP) rank's
set of 'owned' parameters. Each grad buffer (padded to be an even
multiple of DP-world-size) is conceptually divided into DP-world-size
contiguous regions, where each DP rank 'owns' a contiguous region.
Ownership in this sense means DP rank is responsible for reducing
the relevant subset of grads, and updating the relevant subset of
params.
This conceptual partitioning of the grad buffer does NOT respect
parameter boundaries, and as such it is assumed that each created
range references a shard (or subset) of the full parameter. It is
easiest to think of each DP rank as operating (i.e., reducing,
gathering) purely on views into the grad buffer, for all model-to-
main & main-to-model operations.
This method creates four ranges:
- The param's range within the entire grad buffer (i.e., world index).
- The param's range within the relevant grad bucket's buffer.
- The param's range within the DP rank's local view of the grad buffer.
- The param's range within itself (i.e., its shard).
"""
# Param range map.
param_range_map
=
{}
for
param
,
param_world_indexes
in
param_world_index_map
.
items
():
# Param range.
param_world_start
,
param_world_end
,
_
=
param_world_indexes
param_local_start
=
max
(
0
,
param_world_start
-
gbuf_world_range
.
start
)
param_local_end
=
min
(
gbuf_world_range
.
size
,
param_world_end
-
gbuf_world_range
.
start
)
# Add param, if within local gbuf range.
if
param_local_end
>
param_local_start
:
param_local_range
=
Range
(
param_local_start
,
param_local_end
)
param_world_range
=
param_local_range
.
normalize
(
param_local_start
+
gbuf_world_range
.
start
)
param_world_range_in_bucket
=
Range
(
param_world_range
.
start
-
bucket_offset
,
param_world_range
.
end
-
bucket_offset
)
sub_param_start
=
max
(
0
,
gbuf_world_range
.
start
-
param_world_start
)
sub_param_range
=
param_local_range
.
normalize
(
sub_param_start
)
param_range_map
[
param
]
=
{
"gbuf_world"
:
param_world_range
,
"gbuf_world_in_bucket"
:
param_world_range_in_bucket
,
"gbuf_local"
:
param_local_range
,
"param"
:
sub_param_range
,
}
return
param_range_map
@
classmethod
def
_build_model_gbuf_range
(
cls
,
param_and_grad_buffer
:
_ParamAndGradBuffer
,
bucket_index
:
int
):
"""
Build mapping between params and their grad buffers.
This method does the initial setup for the method above. This setup
includes determining the shard ranges into the param_and_grad_buffer
for each data-parallel (DP) rank. Each DP rank keeps range info for
all other DP ranks, for the purpose of creating args for
reduce-scatter and all-gather.
"""
data_parallel_rank
=
torch
.
distributed
.
get_rank
(
param_and_grad_buffer
.
data_parallel_group
)
data_parallel_world_size
=
param_and_grad_buffer
.
data_parallel_group
.
size
()
bucket
=
param_and_grad_buffer
.
buckets
[
bucket_index
]
gbuf_size
=
bucket
.
grad_data
.
numel
()
assert
(
gbuf_size
%
data_parallel_world_size
==
0
),
f
"Each bucket's buffer size should be divisible by
{
data_parallel_world_size
}
"
max_gbuf_range_size
=
gbuf_size
//
data_parallel_world_size
# All world ranges (i.e., across all data parallel ranks).
gbuf_world_all_ranges
=
[]
for
r
in
range
(
data_parallel_world_size
):
# Compute start of chunk in this bucket.
gbuf_world_start
=
r
*
max_gbuf_range_size
gbuf_world_end
=
min
(
gbuf_size
,
gbuf_world_start
+
max_gbuf_range_size
)
# Add bucket's offset in grad buffer.
gbuf_world_range
=
Range
(
gbuf_world_start
+
bucket
.
offset
,
gbuf_world_end
+
bucket
.
offset
)
gbuf_world_all_ranges
.
append
(
gbuf_world_range
)
# Local DP's ranges.
gbuf_world_range
=
gbuf_world_all_ranges
[
data_parallel_rank
]
# Get each param's ranges.
param_range_map
=
cls
.
_build_model_gbuf_param_range_map
(
param_and_grad_buffer
.
param_index_map
,
gbuf_world_range
,
bucket
.
offset
)
# Group into dict.
data
=
{
"param_map"
:
param_range_map
}
return
data
@
classmethod
def
_build_gbuf_range_map
(
cls
,
param_and_grad_buffer
:
_ParamAndGradBuffer
):
"""
Build mapping between params and their grad buffers. These mappings are
partitioned according to data type.
Iterate through all buckets of grad buffer to construct param ranges
that this rank "owns" (the dp_rank'th shard of each bucket, where each
shard is 1/dp_world_size of the bucket).
Args:
param_and_grad_buffer (_ParamAndGradBuffer): buffer to build mapping for.
"""
return
{
(
param_and_grad_buffer
.
param_dtype
,
param_and_grad_buffer
.
grad_dtype
):
[
cls
.
_build_model_gbuf_range
(
param_and_grad_buffer
,
bucket_index
)
for
bucket_index
in
range
(
len
(
param_and_grad_buffer
.
buckets
))
]
}
@
classmethod
def
_build_model_param_gbuf_map
(
cls
,
gbuf_ranges
:
List
[
Dict
]
)
->
Dict
[
torch
.
nn
.
Parameter
,
Tuple
]:
"""
Create a reverse of the gbuf_ranges, for referencing in opposite direction.
"""
param_gbuf_map
=
{}
for
gbuf_index
,
gbuf_range_map
in
enumerate
(
gbuf_ranges
):
for
dtype
,
gbuf_range_map_for_all_buckets
in
gbuf_range_map
.
items
():
for
bucket_index
,
gbuf_range_map
in
enumerate
(
gbuf_range_map_for_all_buckets
):
for
param
,
_
in
gbuf_range_map
[
"param_map"
].
items
():
assert
param
not
in
param_gbuf_map
,
(
"Param should not be in param_gbuf_map; each param only belongs "
"to a single bucket."
)
param_gbuf_map
[
param
]
=
(
gbuf_index
,
dtype
,
bucket_index
)
return
param_gbuf_map
@
classmethod
def
_build_optimizer_group_ranges
(
cls
,
param_groups
:
List
[
Dict
],
gbuf_ranges
:
List
[
Dict
]):
"""
Create optimizer groups.
Given the set of parameter shard ranges that are owned by the current
data-parallel (DP) rank, gather the set of parameters that will be
used (in the method below) to create the current DP's optimizer
groups.
"""
# Param group map.
# World param group map.
# - Store a mapping of <model_parameter:group_index> for all parameters
# across all DP ranks. This is necessary because it is our first
# cross reference between the DDP mappings and the optimizer group
# parameters. This mapping only for use in the next step of building
# the local mapping over this DP rank's parameters.
world_param_group_map
=
{}
for
group_index
,
group
in
enumerate
(
param_groups
):
for
param
in
group
[
"params"
]:
assert
param
.
requires_grad
world_param_group_map
[
param
]
=
group_index
# Optimizer group ranges & param-group mapping.
# - Build a mapping from groups to their contained parameters, and also
# from parameters to their containing group index and order within
# the group. The group index and order are particularly important for
# saving and loading checkpoints.
local_param_group_map
=
{}
group_ranges
=
[{
"params"
:
[]}
for
_
in
param_groups
]
for
gbuf_range_map
in
gbuf_ranges
:
for
dtype
,
gbuf_range_map_for_all_buckets
in
gbuf_range_map
.
items
():
for
gbuf_range_map
in
gbuf_range_map_for_all_buckets
:
for
param
in
gbuf_range_map
[
"param_map"
]:
group_index
=
world_param_group_map
[
param
]
group_range
=
group_ranges
[
group_index
]
group_range
[
"params"
].
append
(
param
)
local_param_group_map
[
param
]
=
(
group_index
,
len
(
group_range
[
"params"
])
-
1
)
# Squeeze zero-size group ranges.
for
group_index
,
group_range
in
enumerate
(
group_ranges
):
group_range
[
"orig_group"
]
=
param_groups
[
group_index
]
group_range
[
"orig_group_idx"
]
=
param_groups
[
group_index
]
return
local_param_group_map
,
group_ranges
    @classmethod
    def _build_model_and_main_param_groups(
        cls,
        gbuf_ranges: List[Dict],
        param_gbuf_map: Dict[torch.nn.Parameter, Tuple],
        opt_group_ranges: List,
    ):
        """
        Create main parameter groups needed for the optimizer step.

        These groups encompass both: 1) groups used by this class, for
        reducing/gather, and 2) groups used by the inner optimizer for the
        parameter update. Given that the conceptual grad buffer partitioning
        (created in earlier method) doesn't respect parameter boundaries,
        the optimizer operates on shards of the model parameters, rather than
        the full parameters.
        """

        # Parameter groups:
        #   model_float16_groups: original float16 parameters
        #   model_fp32_groups: original fp32 parameters
        #   shard_float16_groups: shards of original float16 parameters
        #   shard_fp32_groups: shards of original fp32 parameters
        #   shard_fp32_from_float16_groups: fp32 copy of float16 parameters
        model_float16_groups = []
        model_fp32_groups = []
        shard_float16_groups = []
        shard_fp32_groups = []
        shard_fp32_from_float16_groups = []

        # Allocate (or slice) each group's param shard.
        for group_range in opt_group_ranges:

            # Params of this group.
            model_float16_params_this_group = []
            model_fp32_params_this_group = []
            shard_float16_params_this_group = []
            shard_fp32_params_this_group = []
            shard_fp32_from_float16_params_this_group = []
            model_float16_groups.append(model_float16_params_this_group)
            model_fp32_groups.append(model_fp32_params_this_group)
            shard_float16_groups.append(shard_float16_params_this_group)
            shard_fp32_groups.append(shard_fp32_params_this_group)
            shard_fp32_from_float16_groups.append(shard_fp32_from_float16_params_this_group)

            for model_param in group_range["params"]:

                assert model_param.requires_grad

                gbuf_index, dtype, bucket_index = param_gbuf_map[model_param]
                gbuf_range = gbuf_ranges[gbuf_index][dtype][bucket_index]
                param_range = gbuf_range["param_map"][model_param]["param"]

                # fp16, bf16 params.
                if model_param.type() in ['torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor']:

                    # Clone model -> main.
                    shard_model_param = model_param.detach().view(-1)[
                        param_range.start : param_range.end
                    ]

                    # If we use FP8 params to initialize FP32 main params (compared to using the
                    # bf16/fp16 params to initialize the main params), there will be a loss of
                    # precision at the beginning of training (this problem will not occur if the
                    # training is long enough or if the main params are loaded from a checkpoint).
                    if is_float8tensor(model_param) and hasattr(
                        model_param, 'get_high_precision_init_val'
                    ):
                        shard_main_param = (
                            model_param.get_high_precision_init_val()
                            .view(-1)[param_range.start : param_range.end]
                            .clone()
                            .to(shard_model_param.device)
                            .float()
                        )
                        model_param.clear_high_precision_init_val()
                    else:
                        shard_main_param = shard_model_param.clone().float()

                    tensor_parallel.copy_tensor_model_parallel_attributes(
                        shard_model_param, model_param
                    )
                    tensor_parallel.copy_tensor_model_parallel_attributes(
                        shard_main_param, model_param
                    )
                    if hasattr(model_param, 'shared'):
                        shard_model_param.shared = model_param.shared
                        shard_main_param.shared = model_param.shared

                    # Add to group.
                    model_float16_params_this_group.append(model_param)
                    shard_float16_params_this_group.append(shard_model_param)
                    shard_fp32_from_float16_params_this_group.append(shard_main_param)

                # fp32 params.
                elif model_param.type() == 'torch.cuda.FloatTensor':
                    shard_model_param = model_param.view(-1)[param_range.start : param_range.end]
                    model_fp32_params_this_group.append(model_param)
                    shard_fp32_params_this_group.append(shard_model_param)
                    tensor_parallel.copy_tensor_model_parallel_attributes(
                        shard_model_param, model_param
                    )
                    if hasattr(model_param, 'shared'):
                        shard_model_param.shared = model_param.shared

                else:
                    raise TypeError(
                        'Wrapped parameters must be one of '
                        'torch.cuda.FloatTensor, '
                        'torch.cuda.HalfTensor, or '
                        'torch.cuda.BFloat16Tensor. '
                        'Received {}'.format(model_param.type())
                    )

            # Update optimizer's params.
            group_range["orig_group"]["params"] = [
                *shard_fp32_params_this_group,
                *shard_fp32_from_float16_params_this_group,
            ]

        return (
            model_float16_groups,
            model_fp32_groups,
            shard_float16_groups,
            shard_fp32_groups,
            shard_fp32_from_float16_groups,
        )

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        config: OptimizerConfig,
        grad_scaler: MegatronGradScaler,
        init_state_fn: Optional[Callable],
        model_chunks: List[MegatronModule],
        per_model_buffers: Dict[int, List[_ParamAndGradBuffer]],
        data_parallel_group: torch.distributed.ProcessGroup,
        data_parallel_group_gloo: torch.distributed.ProcessGroup,
        data_parallel_group_idx: int,
    ):
        """
        Distributed optimizer, for all data types (fp16, bf16, and fp32).

        The steps in this method create the core mapping between param and grad buffers,
        parameters, and parameter shard ranges, that is needed for converting between model
        param indexes and main parameter shard indexes. This method also updates the optimizer
        parameter groups with the newly created shards.

        Args:
            optimizer (torch.optim.Optimizer): base optimizer such as Adam or SGD.
            config (OptimizerConfig): configuration object for optimizer.
            grad_scaler (MegatronGradScaler): used for scaling gradients. Note that
                this can be None. This case happens when `bf16 = True` and we don't
                use any loss scale. Note that for `bf16 = True`, we can have
                a constant gradient scaler. Also for `bf16 = False`, we
                always require a grad scaler.
            init_state_fn (Callable, optional): function to initialize state in the optimizer.
            model_chunks (List[MegatronModule]): list of model chunks.
            per_model_buffers (Dict[int, List[ParamAndGradBuffer]]): the implementation of the
                distributed optimizer is centered on using a contiguous buffer for
                communicating grads & params between the model state and the optimizer state.
                You can find a more detailed description in
                https://github.com/NVIDIA/Megatron-LM/blob/main/docs/source/distrib_optimizer.md.
            data_parallel_group (torch.distributed.ProcessGroup): data-parallel group to use to
                all-gather params after optimizer.step().
            data_parallel_group_gloo (torch.distributed.ProcessGroup): gloo data-parallel group
                (used in checkpoint loading and saving).
            data_parallel_group_idx (int): index in data-parallel group (used by
                distributed checkpointing logic).
        """

        if has_config_logger_enabled(config):
            log_config_to_disk(config, locals(), prefix=type(self).__name__)

        assert (
            HAVE_APEX_OR_TE
        ), f'Please install Apex or Transformer Engine to use DistributedOptimizer.'

        super().__init__(optimizer, config, grad_scaler, init_state_fn)
        self.model_chunks = model_chunks
        self.ddp_config = self.model_chunks[0].ddp_config
        for model_chunk in self.model_chunks:
            assert self.ddp_config == model_chunk.ddp_config

        assert isinstance(
            optimizer, Adam
        ), "Only Adam currently supported, due to checkpointing requirements."

        # Model grad buffer ranges.
        assert per_model_buffers is not None, "per_model_buffers must be provided"
        self.buffers = list(itertools.chain(*per_model_buffers.values()))
        self.per_model_buffers = per_model_buffers
        self.data_parallel_group = data_parallel_group
        self.data_parallel_group_gloo = data_parallel_group_gloo
        self.data_parallel_group_idx = data_parallel_group_idx

        self.gbuf_idx_to_model_idx_map = {}
        gbuf_idx = 0
        for model_idx, buffers in self.per_model_buffers.items():
            for _ in buffers:
                self.gbuf_idx_to_model_idx_map[gbuf_idx] = model_idx
                gbuf_idx += 1

        self.per_model_bucket_groups = {}
        for model_idx, buffers in self.per_model_buffers.items():
            self.per_model_bucket_groups[model_idx] = partition_buckets(buffers)

        self.gbuf_ranges = []
        self.per_bucket_numel = []
        self.per_bucket_numel_unpadded = []
        for buffer in self.buffers:

            self.per_bucket_numel.append(
                {
                    (buffer.param_dtype, buffer.grad_dtype): [
                        bucket.grad_data.numel() for bucket in buffer.buckets
                    ]
                }
            )
            self.per_bucket_numel_unpadded.append(
                {
                    (buffer.param_dtype, buffer.grad_dtype): [
                        bucket.numel_unpadded for bucket in buffer.buckets
                    ]
                }
            )
            self.gbuf_ranges.append(self._build_gbuf_range_map(buffer))
        self.model_param_gbuf_map = self._build_model_param_gbuf_map(self.gbuf_ranges)

        # Optimizer ranges.
        (self.model_param_group_index_map, self.opt_group_ranges) = (
            self._build_optimizer_group_ranges(self.optimizer.param_groups, self.gbuf_ranges)
        )

        # Allocate main param shards.
        (
            self.model_float16_groups,
            self.model_fp32_groups,
            self.shard_float16_groups,
            self.shard_fp32_groups,
            self.shard_fp32_from_float16_groups,
        ) = self._build_model_and_main_param_groups(
            self.gbuf_ranges, self.model_param_gbuf_map, self.opt_group_ranges
        )

        # Update optimizer groups.
        # - Also, leverage state_dict() and load_state_dict() to
        #   recast preexisting per-param state tensors.
        self.optimizer.param_groups = [g["orig_group"] for g in self.opt_group_ranges]
        self.optimizer.load_state_dict(self.optimizer.state_dict())

    def enable_pre_hook(self):
        """
        Enable forward pre-hook needed for param all-gather overlap with forward compute.
        """
        warnings.warn(
            "`DistributedOptimizer.enable_pre_hook` will be deprecated in a future release. "
            "Use `DistributedDataParallel.enable_forward_pre_hook` directly."
        )
        for model_chunk in self.model_chunks:
            model_chunk.enable_forward_pre_hook()

    def disable_pre_hook(self):
        """
        Disable forward pre-hook needed for param all-gather overlap with forward compute.
        """
        warnings.warn(
            "`DistributedOptimizer.disable_pre_hook` will be deprecated in a future release. "
            "Use `DistributedDataParallel.disable_forward_pre_hook` directly."
        )
        for model_chunk in self.model_chunks:
            model_chunk.disable_forward_pre_hook()

    def _get_model_param_range_map(self, param: torch.nn.Parameter):
        """
        Given a model param, get the index sub-range of the param that this
        data-parallel rank owns.
        """
        gbuf_index, dtype, bucket_index = self.model_param_gbuf_map[param]
        gbuf_range_map = self.gbuf_ranges[gbuf_index][dtype][bucket_index]
        param_range_map = gbuf_range_map["param_map"][param]
        return param_range_map

    def get_model_parallel_group(self) -> torch.distributed.ProcessGroup:
        """
        With the distributed optimizer, the model parallel group is the
        entire world.
        """
        return None

    def state_dict(self):
        """
        The state dict contains all non-DP-rank-dependent (i.e., non-parameter-
        related) optimizer variables. The returned state dict can be stored in
        the standard model/RNG checkpoint file. The parameter and dependent
        optimizer state (e.g., exp_avg, exp_avg_sq) are stored in a separate
        checkpoint file by calling 'save_parameter_state()'.
        """

        inner_state_dict = self.optimizer.state_dict()
        state_dict = {}

        # Extract 'step', for non-Apex/TE support.
        if not HAVE_APEX_OR_TE:
            steps = list(set([s["step"].item() for s in inner_state_dict["state"].values()]))
            assert len(steps) == 1
            step = steps[0]

        # Optimizer state (do not store parameter state here).
        state_dict['optimizer'] = {k: v for k, v in inner_state_dict.items() if k != "state"}
        for param_group in state_dict["optimizer"]["param_groups"]:
            del param_group["params"]
            if not HAVE_APEX_OR_TE:
                # Native PyTorch param group requires step (i.e., iteration).
                param_group["step"] = step

        # Grad scaler state.
        if self.grad_scaler:
            state_dict['grad_scaler'] = self.grad_scaler.state_dict()

        return state_dict

    def load_state_dict(self, state_dict):
        """Load the state dict.

        As detailed in state_dict(), the state dict contains all non-
        parameter-related variables. This method is notably longer than
        state_dict(), because the Torch optimizers state has yet to be
        allocated at this point, and so we must do a cross referencing between
        the optimizers state (and the ordering it expects for parameter state)
        and this DP rank's shards. The optimizer at this point does not contain
        any tensor dimension information, so we must get these dimensions from
        the DP shards mapped during DistributedOptimizer.__init__().

        The tensor parameter state is loaded via load_parameter_state(), and
        so this method also must populate the loaded state dict with dummy
        tensor data (i.e., via torch.empty() below). This will be overwritten
        during load_parameter_state().

        ** Note: Torch optimizer's state structure. **
        The Torch optimizer stores its state in two levels. The top level is a
        list of groups, where each group contains a list of integer indexes
        (corresponding to parameters) that index into a master parameter list
        that is shared by all groups. As such, three values are necessary for
        maintaining this ordering:

        - group_index : The group to which a parameter belongs.
        - group_order : The index of a parameter within its group.
        - state_order : The index of a parameter within the shared parameter
          list.
        """

        # Get the Torch optimizer's state dict.
        # - This 'inner' optimizer at this point is unallocated, and only
        #   contains an integer ordering of parameters within each group, and
        #   the ordering of parameters within its flattened parameter state
        #   list.
        inner_state_dict = self.optimizer.state_dict()
        state_dict_param_groups = [
            {**group, "params": list(inner_state_dict["param_groups"][idx]["params"])}
            for idx, group in enumerate(state_dict["optimizer"]["param_groups"])
        ]

        # Allocate or retrieve optimizer state (i.e., tensors).
        if len(self.optimizer.state) == 0:
            # Allocate empty optimizer state if not previously initialized.
            # - If len(self.optimizer.state) == 0, this means that the optimizer
            #   state has not been previously initialized. Once it has been
            #   initialized, we skip this code block to avoid reallocating
            #   empty tensors (i.e., torch.empty), which in turn reduces memory
            #   fragmentation.
            # - Real data is overwritten during load_parameter_state().
            state_dict_state = []
            for gbuf_range_maps in self.gbuf_ranges:
                for gbuf_range_map_for_all_buckets in gbuf_range_maps.values():
                    for gbuf_range_map in gbuf_range_map_for_all_buckets:
                        for model_param, param_range_map in gbuf_range_map["param_map"].items():

                            # Get parameter ordering information (see method docstring
                            # for details).
                            group_index, group_order = self.model_param_group_index_map[
                                model_param
                            ]
                            state_order = inner_state_dict["param_groups"][group_index]["params"][
                                group_order
                            ]

                            # Allocate dummy tensors.
                            numel = len(param_range_map["gbuf_world"])
                            init_shard = lambda: torch.empty(
                                (numel,), dtype=torch.float32, device=torch.cuda.current_device()
                            )

                            state_dict_state.append(
                                (state_order, {"exp_avg": init_shard(), "exp_avg_sq": init_shard()})
                            )

            # Sort by state order (see method docstring for details).
            state_dict_state.sort(key=lambda s: s[0])
            state_dict_state = {s[0]: s[1] for s in state_dict_state}

        else:
            # Retrieve existing optimizer state.
            state_dict_state = inner_state_dict["state"]

        # Extract 'step', for non-Apex/TE support.
        if not HAVE_APEX_OR_TE:
            steps = list(set([g["step"] for g in state_dict["optimizer"]["param_groups"]]))
            assert len(steps) == 1
            step = torch.tensor(steps[0], dtype=torch.float)

            for s in state_dict_state.values():
                # Native PyTorch state dict requires step (i.e., iteration).
                s["step"] = step

        # Optimizer.
        self.optimizer.load_state_dict(
            {"state": state_dict_state, "param_groups": state_dict_param_groups}
        )

        # Grad scaler.
        if 'grad_scaler' not in state_dict:
            if self.config.fp16:
                logger.info(
                    '***WARNING*** found an old checkpoint, will not '
                    'load grad scaler ...'
                )
        else:
            if self.grad_scaler:
                self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
            else:
                logger.info(
                    '***WARNING*** found the grad scaler in the '
                    'checkpoint but it is None in the class. '
                    'Skipping loading grad scaler ...'
                )
        if 'param_state' in state_dict:
            assert 'param_state_sharding_type' in state_dict, state_dict.keys()
            param_state = state_dict['param_state']
            sharding_type = state_dict['param_state_sharding_type']
            logger.info(f'Loading distributed optimizer sharded state of type {sharding_type}')
            if sharding_type == 'dp_zero_gather_scatter':
                self.load_parameter_state_from_dp_zero(param_state)
            elif sharding_type == 'fully_sharded_bucket_space':
                self.load_parameter_state_from_fs_bucket_space(param_state)
            elif sharding_type == 'fully_sharded_model_space':
                self.load_parameter_state_from_fs_model_space(param_state)
            else:
                raise NotImplementedError(f'Unknown sharding_type: {sharding_type}')

    def get_parameter_state_fs_bucket_space(self):
        """Get internal representation of parameter state without any copies and modifications.

        This is referred to as "fully sharded bucket space" because the optimizer state is
        fully sharded (e.g. no gather involved) and bucket-centric (the state
        follows the internal structure of the Distributed Optimizer buckets)
        as opposed to model-centric (typical structure of PyT optimizers)
        """
        state = {
            "per_bucket_numel": self.per_bucket_numel,
            "per_bucket_numel_unpadded": self.per_bucket_numel_unpadded,
        }
        for gbuf_idx, gbuf_range_maps in enumerate(self.gbuf_ranges):

            # Iterate grad buffers (by data type).
            dtype_state = {}
            assert len(gbuf_range_maps) == 1, "single dtype supported, for now."
            for dtype, gbuf_range_map_for_all_buckets in gbuf_range_maps.items():
                buckets_state = []
                for bucket_idx, gbuf_range_map in enumerate(gbuf_range_map_for_all_buckets):
                    bucket_state = []
                    for model_param, param_range_map in gbuf_range_map["param_map"].items():

                        # Main param & optimizer states.
                        group_index, group_order = self.model_param_group_index_map[model_param]
                        main_param = self.optimizer.param_groups[group_index]["params"][
                            group_order
                        ]
                        optim_state = self.optimizer.state[main_param]

                        tensors = {
                            "param": main_param,
                            **optim_state,
                            "gbuf_local_start": param_range_map["gbuf_local"].start,
                            "gbuf_local_end": param_range_map["gbuf_local"].end,
                        }
                        bucket_state.append(tensors)
                    buckets_state.append(bucket_state)
                dtype_state[dtype] = buckets_state
            state[gbuf_idx] = dtype_state
        return state

    def get_parameter_state_dp_zero(self):
        """Get parameter state (i.e., parameter & optimizer tensors).

        This method performs two steps:
        - For each DP rank, copy param & optimizer shards to contiguous CPU
          buffers (e.g., one buffer each for main_param, exp_avg, and
          exp_avg_sq).
        - Gather contiguous buffers on DP rank 0 and concatenate to world
          buffers.
        """

        # Data parallelism variables.
        data_parallel_world_size = self.data_parallel_group_gloo.size()
        data_parallel_rank = torch.distributed.get_rank(self.data_parallel_group_gloo)
        data_parallel_group_gloo = self.data_parallel_group_gloo
        data_parallel_global_ranks = torch.distributed.get_process_group_ranks(
            self.data_parallel_group_gloo
        )

        # Collect param states.
        state = {"buckets_coalesced": True}
        for gbuf_idx, gbuf_range_maps in enumerate(self.gbuf_ranges):

            # Iterate grad buffers (by data type).
            dtype_state = {}
            assert len(gbuf_range_maps) == 1, "single dtype supported, for now."
            for dtype, gbuf_range_map_for_all_buckets in gbuf_range_maps.items():
                buffer_numel_unpadded = self.buffers[gbuf_idx].numel_unpadded
                # Create coalesced tensors for all state related to parameters in this buffer.
                world_tensors = {}
                if data_parallel_rank == 0:
                    world_tensors = {
                        key: torch.zeros(
                            (buffer_numel_unpadded,), dtype=torch.float32, device="cpu"
                        )
                        for key in ("param", "exp_avg", "exp_avg_sq")
                    }
                    world_tensors["numel_unpadded"] = buffer_numel_unpadded
                offset_in_world_tensors = 0
                for bucket_idx, gbuf_range_map in enumerate(gbuf_range_map_for_all_buckets):

                    # Compute local DP contiguous shard's size.
                    gbuf_world_numel = self.buffers[gbuf_idx].buckets[bucket_idx].grad_data.numel()
                    assert gbuf_world_numel % data_parallel_world_size == 0
                    gbuf_local_numel = gbuf_world_numel // data_parallel_world_size

                    gbuf_world_numel_unpadded = (
                        self.buffers[gbuf_idx].buckets[bucket_idx].numel_unpadded
                    )
                    assert gbuf_world_numel_unpadded <= gbuf_world_numel

                    local_shards = {
                        key: torch.zeros((gbuf_local_numel,), dtype=torch.float32, device="cpu")
                        for key in ("param", "exp_avg", "exp_avg_sq")
                    }

                    # Build contiguous DP rank shards (for param + optim states).
                    for model_param, param_range_map in gbuf_range_map["param_map"].items():

                        # Main param & optimizer states.
                        group_index, group_order = self.model_param_group_index_map[model_param]
                        main_param = self.optimizer.param_groups[group_index]["params"][
                            group_order
                        ]
                        optim_state = self.optimizer.state[main_param]

                        tensors = {"param": main_param, **optim_state}

                        # Copy states into contiguous shard.
                        gbuf_local_start = param_range_map["gbuf_local"].start
                        gbuf_local_end = param_range_map["gbuf_local"].end
                        for key in local_shards:
                            local_shards[key][gbuf_local_start:gbuf_local_end].data.copy_(
                                tensors[key].detach().cpu()
                            )

                    # Gather contiguous shards on DP rank 0.
                    for key, send_tensor in local_shards.items():

                        # Gather tensor list.
                        if data_parallel_rank == 0:
                            recv_tensors = [
                                torch.zeros((gbuf_local_numel,), dtype=torch.float32, device="cpu")
                                for _ in range(data_parallel_world_size)
                            ]
                        else:
                            recv_tensors = None

                        # Gather.
                        torch.distributed.gather(
                            send_tensor,
                            recv_tensors,
                            data_parallel_global_ranks[0],
                            data_parallel_group_gloo,
                        )

                        # Concatenate.
                        if data_parallel_rank == 0:
                            recv_tensors_concatenated = torch.cat(recv_tensors)
                            # Copy this bucket's collected all-gather tensors into the right place
                            # in the tensor for the buffer. The tensor for the buffer gets rid of
                            # the padding between buckets.
                            start = offset_in_world_tensors
                            end = offset_in_world_tensors + gbuf_world_numel_unpadded
                            world_tensors[key][start:end].copy_(
                                recv_tensors_concatenated[:gbuf_world_numel_unpadded]
                            )

                    offset_in_world_tensors += gbuf_world_numel_unpadded

                # Collect world state.
                dtype_state[dtype] = world_tensors
            state[gbuf_idx] = dtype_state

        return state

    def save_parameter_state(self, filename: str):
        """Save the distributed parameter state on DP rank 0.

        Args:
            filename (str): path to save parameter state to.
        """
        state_dict = self.get_parameter_state_dp_zero()
        if torch.distributed.get_rank(self.data_parallel_group) == 0:
            torch.save(state_dict, filename)

    def sharded_state_dict(
        self,
        model_sharded_state_dict: ShardedStateDict,
        is_loading: bool = False,
        sharding_type: str = 'fully_sharded_model_space',
    ):
        """
        Chooses between 3 param state sharding implementations as requested by `sharding_type`.

        Regular state dict parameters are saved on DP rank 0 and loaded on all ranks.
        """
        if not is_loading and sharding_type == 'fully_sharded_bucket_space':
            logger.warning(
                '`fully_sharded_bucket_space` sharding for DistributedOptimizer'
                ' checkpoint is deprecated and will be removed in the future.'
                ' Please switch to `full_sharded_model_space`.'
            )

        state_dict = self.state_dict()
        if sharding_type != 'fully_sharded_model_space':
            # State dict differs between different model parallel groups
            state_dict = {
                k: ShardedObject(
                    f'optimizer.distributed.dp_group_idx_{self.data_parallel_group_idx}.{k}',
                    v,
                    (1,),
                    (0,),
                    replica_id=torch.distributed.get_rank(self.data_parallel_group),
                )
                for k, v in state_dict.items()
            }

        if is_loading:
            # Call the distributed optimizer's specialized load_state_dict(),
            # which conditionally skips re-allocating the optimizer's state if
            # already initialized, which in turn reduces memory fragmentation.
            self.load_state_dict(self.state_dict())

        if sharding_type == 'fully_sharded_bucket_space':
            param_state = self.sharded_param_state_fs_bucket_space(
                model_sharded_state_dict, is_loading
            )
        elif sharding_type == 'dp_zero_gather_scatter':
            param_state = self.sharded_param_state_dp_zero(model_sharded_state_dict, is_loading)
        elif sharding_type == 'fully_sharded_model_space':
            param_state = self.sharded_param_state_fs_model_space(
                model_sharded_state_dict, is_loading
            )
        else:
            raise NotImplementedError(f'Unknown sharding_type: {sharding_type}')

        state_dict['param_state'] = param_state
        state_dict['param_state_sharding_type'] = sharding_type
        return state_dict

    def sharded_param_state_dp_zero(
        self, model_sharded_state_dict: ShardedStateDict, is_loading: bool = False
    ):
        """Naive implementation which reuses gather/scatter from the legacy ckpt format.

        During saving, gathers the parameters state on DP rank 0 and saves a ShardedObject
        with fixed TPxPP structure. During loading, loads the saved data on DP rank 0
        (None on other ranks). Relies on the parameters scatter done in load_state_dict.
        """
        if is_loading:
            param_state_data = None
        else:
            # Gather on rank 0
            param_state_data = self.get_parameter_state_dp_zero()

        if torch.distributed.get_rank(self.data_parallel_group) == 0:
            # Fixed TPxPP. Save on DP rank 0 only
            param_state = ShardedObject(
                f'optimizer.distributed.dp_group_idx_{self.data_parallel_group_idx}.param_state',
                param_state_data,
                (1,),
                (0,),
            )
        else:
            # DP ranks > 0 don't save. During loading, the param_state needs to be None.
            param_state = LocalNonpersistentObject(None)

        return param_state

    def sharded_param_state_fs_bucket_space(
        self, model_sharded_state_dict: ShardedStateDict, is_loading: bool = False
    ):
        """Sharded state dict where each noncontiguous buffer is a separate ShardedTensor.

        Results in fully parallel save and load without any inter-process
        communication or intermediate buffers/copies.
        """
        data_parallel_rank = torch.distributed.get_rank(self.data_parallel_group)
        data_parallel_world_size = torch.distributed.get_world_size(self.data_parallel_group)

        state = self.get_parameter_state_fs_bucket_space()
        # per_bucket_numel metadata is saved separately for each TPxPP domain.
        for per_bucket_key in ('per_bucket_numel', 'per_bucket_numel_unpadded'):
            key = (
                f'optimizer.distributed.dp_group_idx_{self.data_parallel_group_idx}'
                f'.{per_bucket_key}'
            )
            state[per_bucket_key] = ShardedObject(
                key, state[per_bucket_key], (1,), (0,), replica_id=data_parallel_rank
            )

        for gbuf_idx, gbuf_range_maps in enumerate(self.gbuf_ranges):
            for dtype, gbuf_range_map_for_all_buckets in state[gbuf_idx].items():
                for bucket_idx, bucket_state in enumerate(gbuf_range_map_for_all_buckets):

                    # Compute local DP contiguous shard's size.
                    gbuf_world_numel = self.buffers[gbuf_idx].buckets[bucket_idx].grad_data.numel()
                    assert gbuf_world_numel % data_parallel_world_size == 0
                    gbuf_local_numel = gbuf_world_numel // data_parallel_world_size

                    sharded_bucket_key = (
                        f'optimizer.distributed.dp_group_idx_{self.data_parallel_group_idx}'
                        f'.gbuf_idx_{gbuf_idx}.dtype_{dtype}.bucket_idx_{bucket_idx}'
                    )

                    # The global ckpt tensors must be fully covered.
                    # We add extra empty padding if necessary
                    assert bucket_state, 'empty bucket encountered'

                    # Insert padding between parameter tensors to ensure full coverage as needed.
                    all_pad_tensors = {}
                    for i in range(len(bucket_state) - 1):
                        next_param_start = bucket_state[i + 1]['gbuf_local_start']
                        cur_param_end = bucket_state[i]['gbuf_local_end']
                        if next_param_start != cur_param_end:
                            pad_tensors = {
                                k: torch.empty(
                                    next_param_start - cur_param_end,
                                    dtype=v.dtype,
                                    device=v.device,
                                )
                                for k, v in bucket_state[i].items()
                                if isinstance(v, torch.Tensor)
                            }
                            all_pad_tensors[i + 1] = {
                                **pad_tensors,
                                'gbuf_local_start': cur_param_end,
                                'gbuf_local_end': next_param_start,
                                'padding': True,
                            }

                    # Insert from end so that insertion positions are still correct.
                    indices_to_insert = sorted(list(all_pad_tensors.keys()))
                    for index_to_insert in reversed(indices_to_insert):
                        bucket_state.insert(index_to_insert, all_pad_tensors[index_to_insert])

                    if bucket_state[-1]['gbuf_local_end'] != gbuf_local_numel:
                        pad_tensors = {
                            k: torch.empty(
                                gbuf_local_numel - bucket_state[-1]['gbuf_local_end'],
                                dtype=v.dtype,
                                device=v.device,
                            )
                            for k, v in bucket_state[-1].items()
                            if isinstance(v, torch.Tensor)
                        }
                        bucket_state.append(
                            {
                                **pad_tensors,
                                'gbuf_local_start': bucket_state[-1]['gbuf_local_end'],
                                'gbuf_local_end': gbuf_local_numel,
                                'padding': True,
                            }
                        )

                    # Each tensor is mapped to a slice (`flattened_range`)
                    # of a DP-local shard of size `gbuf_local_numel`.
                    for bucket_params_idx in range(len(bucket_state)):
                        tensors = bucket_state[bucket_params_idx]
                        gbuf_local_start = tensors.pop('gbuf_local_start')
                        gbuf_local_end = tensors.pop('gbuf_local_end')
                        if 'padding' not in tensors:
                            tensors['padding'] = False

                        for key in tensors:
                            if key == 'padding':
                                tensors[key] = LocalNonpersistentObject(tensors[key])
                                continue

                            assert tensors[key].shape == (gbuf_local_end - gbuf_local_start,), (
                                tensors[key].shape,
                                gbuf_local_start,
                                gbuf_local_end,
                            )

                            tensors[key] = ShardedTensor(
                                f'{sharded_bucket_key}.{key}',
                                tensors[key],
                                tensors[key].dtype,
                                (gbuf_local_numel,),
                                (data_parallel_world_size * gbuf_local_numel,),
                                (data_parallel_rank * gbuf_local_numel,),
                                axis_fragmentations=(data_parallel_world_size,),
                                flattened_range=slice(gbuf_local_start, gbuf_local_end),
                                allow_shape_mismatch=True,
                            )
        return state

    def sharded_param_state_fs_model_space(
        self, model_sharded_state_dict: ShardedStateDict, is_loading: bool = False
    ):
        """Sharded state dict where each buffer is mapped to corresponding model param.

        In this approach the optimizer state tensors are directly related to model parameters
        by linking them with metadata from `model_sharded_state_dict`.
        This will allow changing TP and PP while using DistOpt (as with other optimizers).
        """

        param_to_sharded_metadata = {}
        model_sharded_state_dict, _ = extract_sharded_tensors_and_factories(
            model_sharded_state_dict
        )
        for sh_base in nested_values(model_sharded_state_dict):
            param_to_sharded_metadata[sh_base.data] = sh_base

        prefix = 'optimizer.state'
        state = {}

        # Not stored in the checkpoint, used only to identify params in
        # `sharded_param_state_fs_model_space`.
        param_idx = 0
        for gbuf_range_maps in self.gbuf_ranges:
            for gbuf_range_map_for_all_buckets in gbuf_range_maps.values():
                for gbuf_range_map in gbuf_range_map_for_all_buckets:
                    for model_param, param_range_map in gbuf_range_map["param_map"].items():
                        group_index, group_order = self.model_param_group_index_map[model_param]
                        param_range = param_range_map['param']

                        main_param = self.optimizer.param_groups[group_index]["params"][
                            group_order
                        ]
                        optim_state = self.optimizer.state[main_param]

                        tensors = {"fp32_param": main_param, **optim_state}

                        # Match optimizer parameter with model ShardedTensor (or
                        # ShardedTensorFactory).
                        try:
                            sharded_metadata = param_to_sharded_metadata[model_param]
                        except KeyError as e:
                            raise ValueError(
                                f'Model param {model_param} not in model_sharded_state_dict'
                            ) from e

                        # Set DP corresponding replica_id coordinate to 0.
                        assert (
                            len(sharded_metadata.replica_id) == 3
                        ), f'Expected replica_id format (PP, TP, DP), got: {sharded_metadata}'
                        replica_id = (*sharded_metadata.replica_id[:2], 0)

                        # Instantiate ShardedTensor (or ShardedTensorFactory) for optimizer
                        # params.
                        for state_key, state_ten in tensors.items():
                            replace_kwargs = dict(
                                key=f'{prefix}.{state_key}.{sharded_metadata.key}',
                                data=state_ten,
                                dtype=state_ten.dtype,
                                flattened_range=slice(param_range.start, param_range.end),
                                replica_id=replica_id,
                            )
                            if isinstance(sharded_metadata, ShardedTensorFactory):
                                replace_kwargs.pop('dtype')
                            tensors[state_key] = replace(sharded_metadata, **replace_kwargs)
                            tensors[state_key].validate_metadata_integrity()
                        state[param_idx] = tensors
                        param_idx += 1
        return state

    def load_parameter_state_from_fs_bucket_space(self, state_dict):
        """Loads the parameter state from an internal representation.

        Inverse of the `get_parameter_state_fs_bucket_space` method.
        """
        logger.warning(
            '`fully_sharded_bucket_space` sharding for DistributedOptimizer'
            ' checkpoint is deprecated. Please switch to `full_sharded_model_space`'
        )

        if state_dict is not None and "per_bucket_numel_unpadded" in state_dict:
            per_bucket_numel_unpadded_in_checkpoint = state_dict["per_bucket_numel_unpadded"]
            assert self.per_bucket_numel_unpadded == per_bucket_numel_unpadded_in_checkpoint, (
                f"Number of unpadded elements in each bucket need to be the same in current run "
                f"({self.per_bucket_numel_unpadded}) and checkpoint "
                f"({per_bucket_numel_unpadded_in_checkpoint})"
            )

        for gbuf_idx, gbuf_range_maps in enumerate(self.gbuf_ranges):
            assert len(gbuf_range_maps) == 1, "single dtype supported, for now."
            for dtype, gbuf_range_map_for_all_buckets in gbuf_range_maps.items():
                for bucket_idx, gbuf_range_map in enumerate(gbuf_range_map_for_all_buckets):
                    bucket_state = state_dict[gbuf_idx][dtype][bucket_idx]
                    bucket_state = [
                        bucket_state_elem
                        for bucket_state_elem in bucket_state
                        if not bucket_state_elem['padding']
                    ]

                    assert len(bucket_state) == len(gbuf_range_map["param_map"]), (
                        len(bucket_state),
                        len(gbuf_range_map["param_map"]),
                    )
                    for src_tensors, (model_param, param_range_map) in zip(
                        bucket_state, gbuf_range_map["param_map"].items()
                    ):
                        # Main param & optimizer states.
                        group_index, group_order = self.model_param_group_index_map[model_param]
                        main_param = self.optimizer.param_groups[group_index]["params"][
                            group_order
                        ]
                        optim_state = self.optimizer.state[main_param]

                        dst_tensors = {"param": main_param, **optim_state}
                        for key in dst_tensors:
                            dst_tensors[key].copy_(src_tensors[key])

    @torch.no_grad()
    def load_parameter_state_from_fs_model_space(self, state_dict):
        """Loads the parameter state from a "model space" representation.

        Inverse of the `sharded_param_state_fs_model_space` method.
        """
        param_idx = 0
        # matching order with `sharded_param_state_fs_model_space`
        for gbuf_range_maps in self.gbuf_ranges:
            for gbuf_range_map_for_all_buckets in gbuf_range_maps.values():
                for gbuf_range_map in gbuf_range_map_for_all_buckets:
                    for model_param, param_range_map in gbuf_range_map["param_map"].items():
                        group_index, group_order = self.model_param_group_index_map[model_param]
                        main_param = self.optimizer.param_groups[group_index]["params"][
                            group_order
                        ]
                        optim_state = self.optimizer.state[main_param]

                        src_tensors = state_dict[param_idx]
                        dst_tensors = {"fp32_param": main_param, **optim_state}
                        for key in dst_tensors:
                            dst_tensors[key].copy_(src_tensors[key])

                        param_idx += 1

    @classmethod
    def _update_legacy_world_tensors(cls, old_tensors, new_numels):
        '''Reshard buckets (where each bucket is a tensor) to new target
        numels, where the total numel remains the same.'''

        old_total = sum([t.numel() for t in old_tensors])
        new_total = sum(new_numels)
        assert old_total == new_total

        unified_tensor = torch.cat(old_tensors, dim=0)

        new_tensors = []
        start_idx = 0
        for new_numel in new_numels:
            new_tensors.append(unified_tensor[start_idx : (start_idx + new_numel)])
            start_idx += new_numel

        return new_tensors

    def load_parameter_state_from_dp_zero_legacy(self, state_dict):
        """Load parameter state (i.e., parameter & optimizer tensors) from DP 0 rank,
        using the legacy checkpoint format as described below.

        The difference between this method and `load_parameter_state_from_dp_zero_modern()`
        is that this method is used for updating the format of checkpoints that
        were saved using code from before Feb 13, 2024. Starting on this date, a
        new format was used (i.e., different format for the parameter mapping and
        bucket sharding).

        Use arg `--ckpt-convert-update-legacy-dist-opt-format` to call this
        method, along with `--ckpt-convert-format` and `--ckpt-convert-save` to
        update a legacy-format checkpoint to the modern format.
        """

        # Data parallelism variables.
        data_parallel_world_size = self.data_parallel_group_gloo.size()
        data_parallel_rank = torch.distributed.get_rank(self.data_parallel_group_gloo)
        data_parallel_group_gloo = self.data_parallel_group_gloo
        data_parallel_global_ranks = torch.distributed.get_process_group_ranks(
            self.data_parallel_group_gloo
        )

        # Scatter tensors to all DP ranks.
        for gbuf_idx, gbuf_range_maps in enumerate(self.gbuf_ranges):
            for dtype, gbuf_range_map_for_all_buckets in gbuf_range_maps.items():
                if data_parallel_rank == 0:
                    buffer_numel_unpadded = self.buffers[gbuf_idx].numel_unpadded
                    model_numels = [b.numel_unpadded for b in self.buffers[gbuf_idx].buckets]
                    checkpoint_numels = [
                        t.numel() for t in state_dict[gbuf_idx][torch.float32]["param"]
                    ]
                    assert sum(model_numels) == sum(checkpoint_numels)
                for key in ("param", "exp_avg", "exp_avg_sq"):
                    legacy_world_tensors = self._update_legacy_world_tensors(
                        state_dict[gbuf_idx][torch.float32][key],
                        [
                            self.buffers[gbuf_idx].buckets[bi].numel_unpadded
                            for bi in range(len(gbuf_range_map_for_all_buckets))
                        ],
                    )
                    offset_in_world_tensors = 0
                    for bucket_idx, gbuf_range_map in enumerate(gbuf_range_map_for_all_buckets):

                        # Compute local DP contiguous shard's size.
                        gbuf_world_numel = (
                            self.buffers[gbuf_idx].buckets[bucket_idx].grad_data.numel()
                        )
                        assert gbuf_world_numel % data_parallel_world_size == 0
                        gbuf_local_numel = gbuf_world_numel // data_parallel_world_size

                        gbuf_world_numel_unpadded = (
                            self.buffers[gbuf_idx].buckets[bucket_idx].numel_unpadded
                        )
                        assert gbuf_world_numel_unpadded <= gbuf_world_numel

                        # Contiguous local shards (received from DP rank 0).
                        recv_tensor = torch.empty(
                            (gbuf_local_numel,), dtype=torch.float32, device="cpu"
                        )

                        # Scatter tensor list.
                        if data_parallel_rank == 0:
                            start = offset_in_world_tensors
                            end = offset_in_world_tensors + gbuf_world_numel_unpadded

                            world_tensor = legacy_world_tensors[bucket_idx]
                            assert (
                                world_tensor.numel() == gbuf_world_numel_unpadded
                            ), "%d vs. %d." % (world_tensor.numel(), gbuf_world_numel_unpadded)
                            offset_in_world_tensors += gbuf_world_numel_unpadded

                            # Pad world_tensor to gbuf_world_numel. Don't pad at the front,
                            # pad at the back.
                            world_tensor = torch.nn.functional.pad(
                                world_tensor, (0, gbuf_world_numel - gbuf_world_numel_unpadded)
                            )
                            assert world_tensor.numel() == gbuf_world_numel
                            gbuf_start_idxs = list(range(0, gbuf_world_numel, gbuf_local_numel))
                            send_tensors = [
                                world_tensor[i : (i + gbuf_local_numel)] for i in gbuf_start_idxs
                            ]
                        else:
                            send_tensors = None

                        # Scatter.
                        torch.distributed.scatter(
                            recv_tensor,
                            send_tensors,
                            data_parallel_global_ranks[0],
                            data_parallel_group_gloo,
                        )

                        # Copy local contiguous shards to param/optim shards.
                        for model_param, param_range_map in gbuf_range_map["param_map"].items():

                            # Main param & optimizer states.
                            group_index, group_order = self.model_param_group_index_map[
                                model_param
                            ]
                            main_param = self.optimizer.param_groups[group_index]["params"][
                                group_order
                            ]
                            if key == "param":
                                tensor_to_copy_into = main_param
                            else:
                                optim_state = self.optimizer.state[main_param]
                                tensor_to_copy_into = optim_state[key]

                            # Copy states into contiguous shard.
                            gbuf_local_start = param_range_map["gbuf_local"].start
                            gbuf_local_end = param_range_map["gbuf_local"].end
                            tensor_to_copy_into.data.copy_(
                                recv_tensor[gbuf_local_start:gbuf_local_end]
                            )

    def load_parameter_state_from_dp_zero(self, state_dict, *, update_legacy_format=False):
        """Load parameter state (i.e., parameter & optimizer tensors) from DP 0 rank,
        using the new checkpoint format with coalesced state across buckets.

        This method performs the reverse of get_parameter_state_dp_zero():
        - Scatter contiguous buffers from DP rank 0 to each DP rank (each DP
          rank receives its relevant subset of the world buffers).
        - For each DP rank, copy param & optimizer shards from contiguous CPU
          buffers. (e.g., one buffer each for main_param, exp_avg, and
          exp_avg_sq).
        """

        # Selectively load from a legacy checkpoint. The legacy format was used
        # prior to Feb 13, 2024.
        if update_legacy_format:
            return self.load_parameter_state_from_dp_zero_legacy(state_dict)

        # Data parallelism variables.
        data_parallel_world_size = self.data_parallel_group_gloo.size()
        data_parallel_rank = torch.distributed.get_rank(self.data_parallel_group_gloo)
        data_parallel_group_gloo = self.data_parallel_group_gloo
        data_parallel_global_ranks = torch.distributed.get_process_group_ranks(
            self.data_parallel_group_gloo
        )

        if data_parallel_rank == 0:
            # Do nothing if "--fp8-param-gather" is not used.
            self.split_state_dict_if_needed(state_dict)

        # Scatter tensors to all DP ranks.
        for gbuf_idx, gbuf_range_maps in enumerate(self.gbuf_ranges):
            for dtype, gbuf_range_map_for_all_buckets in gbuf_range_maps.items():
                if data_parallel_rank == 0:
                    buffer_numel_unpadded = self.buffers[gbuf_idx].numel_unpadded
                    checkpoint_numel_unpadded = state_dict[gbuf_idx][dtype]["numel_unpadded"]
                    assert buffer_numel_unpadded == checkpoint_numel_unpadded, (
                        f"Number of unpadded elements must be same in current run "
                        f"({buffer_numel_unpadded}) and checkpoint ({checkpoint_numel_unpadded})"
                    )
                for key in ("param", "exp_avg", "exp_avg_sq"):
                    offset_in_world_tensors = 0
                    for bucket_idx, gbuf_range_map in enumerate(gbuf_range_map_for_all_buckets):

                        # Compute local DP contiguous shard's size.
                        gbuf_world_numel = (
                            self.buffers[gbuf_idx].buckets[bucket_idx].grad_data.numel()
                        )
                        assert gbuf_world_numel % data_parallel_world_size == 0
                        gbuf_local_numel = gbuf_world_numel // data_parallel_world_size

                        gbuf_world_numel_unpadded = (
                            self.buffers[gbuf_idx].buckets[bucket_idx].numel_unpadded
                        )
                        assert gbuf_world_numel_unpadded <= gbuf_world_numel

                        # Contiguous local shards (received from DP rank 0).
                        recv_tensor = torch.zeros(
                            (gbuf_local_numel,), dtype=torch.float32, device="cpu"
                        )

                        # Scatter tensor list.
                        if data_parallel_rank == 0:
                            world_tensors = state_dict[gbuf_idx][dtype][key]

                            start = offset_in_world_tensors
                            end = offset_in_world_tensors + gbuf_world_numel_unpadded
                            assert 0 <= start < end <= world_tensors.numel()
                            world_tensor = world_tensors[start:end]
                            offset_in_world_tensors += gbuf_world_numel_unpadded

                            # Pad world_tensor to gbuf_world_numel. Don't pad at the front,
                            # pad at the back.
                            world_tensor = torch.nn.functional.pad(
                                world_tensor, (0, gbuf_world_numel - gbuf_world_numel_unpadded)
                            )
                            assert world_tensor.numel() == gbuf_world_numel
                            gbuf_start_idxs = list(range(0, gbuf_world_numel, gbuf_local_numel))
                            send_tensors = [
                                world_tensor[i : (i + gbuf_local_numel)] for i in gbuf_start_idxs
                            ]
                        else:
                            send_tensors = None

                        # Scatter.
                        torch.distributed.scatter(
                            recv_tensor,
                            send_tensors,
                            data_parallel_global_ranks[0],
                            data_parallel_group_gloo,
                        )

                        # Copy local contiguous shards to param/optim shards.
                        for model_param, param_range_map in gbuf_range_map["param_map"].items():

                            # Main param & optimizer states.
                            group_index, group_order = self.model_param_group_index_map[
                                model_param
                            ]
                            main_param = self.optimizer.param_groups[group_index]["params"][
                                group_order
                            ]
                            if key == "param":
                                tensor_to_copy_into = main_param
                            else:
                                optim_state = self.optimizer.state[main_param]
                                tensor_to_copy_into = optim_state[key]

                            # Copy states into contiguous shard.
                            gbuf_local_start = param_range_map["gbuf_local"].start
                            gbuf_local_end = param_range_map["gbuf_local"].end
                            tensor_to_copy_into.data.copy_(
                                recv_tensor[gbuf_local_start:gbuf_local_end]
                            )

    def split_state_dict_if_needed(self, state_dict):
        """
        When "--fp8-param-gather" is disabled, weights and biases are stored in the same
        `ParamAndGradBuffer`. So, when saving a checkpoint, the optimizer's main parameters are
        saved in a single continuous tensor (this also applies to "exp_avg" and "exp_avg_sq").

        However, when "--fp8-param-gather" is enabled, weights (in fp8 dtype) and biases (in
        bf16/fp16 dtype) are stored in separate `ParamAndGradBuffer`s. Therefore, when
        "--fp8-param-gather" is enabled and we want to load a checkpoint saved without
        "--fp8-param-gather", we need to split the weights (fp8) and biases (bf16/fp16) in the
        state_dict into two separate tensors.
        """
        # Skip if there are no fp8 buffers.
        fp8_gbuf_indices = []
        for gbuf_idx, gbuf_range_maps in enumerate(self.gbuf_ranges):
            for dtype, _ in gbuf_range_maps.items():
                if is_float8tensor(self.buffers[gbuf_idx].params[0]):
                    fp8_gbuf_indices.append(gbuf_idx)
        if len(fp8_gbuf_indices) == 0:
            return

        dtype_to_gbuf_idx = {}
        for key in state_dict.keys():
            if key != 'buckets_coalesced':
                for dtype in state_dict[key].keys():
                    assert dtype not in dtype_to_gbuf_idx
                    if dtype[0] == torch.uint8:
                        # If the `state_dict` already contains a torch.uint8 buffer, we assume
                        # that the fp8 weights and fp16/bf16 biases in the checkpoint are already
                        # separated. In this case, no action is required, so we can return
                        # directly.
                        return
                    dtype_to_gbuf_idx[dtype] = key

        # 1. Replace the gbuf_idx in the checkpoint with the new gbuf_idx.
        # 2. Copy the non-tensor data (i.e., the "buckets_coalesced") to `new_state_dict`.
        new_state_dict = {'buckets_coalesced': state_dict['buckets_coalesced']}
        for gbuf_idx, gbuf_range_maps in enumerate(self.gbuf_ranges):
            for dtype, _ in gbuf_range_maps.items():
                if not is_float8tensor(self.buffers[gbuf_idx].params[0]):
                    new_state_dict[gbuf_idx] = state_dict[dtype_to_gbuf_idx[dtype]]

        for fp8_gbuf_idx in fp8_gbuf_indices:
            # Note that `self.buffers[fp8_gbuf_idx].params[0].dtype` is the dummy dtype of
            # `Float8Tensor`, not torch.uint8.
            non_fp8_param_and_grad_dtype = (
                self.buffers[fp8_gbuf_idx].params[0].dtype,
                self.buffers[fp8_gbuf_idx].grad_dtype,
            )

            # Iterate through all buffers to find the one that needs to be split.
            non_fp8_gbuf_idx = None
            for gbuf_idx, gbuf_range_maps in enumerate(self.gbuf_ranges):
                for dtype, _ in gbuf_range_maps.items():
                    if dtype == non_fp8_param_and_grad_dtype:
                        non_fp8_gbuf_idx = gbuf_idx
            assert non_fp8_gbuf_idx is not None

            # We need the fp8_flags to determine the order of weight (fp8) and bias (fp16/bf16) in
            # the buffer.
            index_to_fp8_map = {}
            for index in self.buffers[fp8_gbuf_idx].param_indices:
                assert index not in index_to_fp8_map
                index_to_fp8_map[index] = True
            for index in self.buffers[non_fp8_gbuf_idx].param_indices:
                assert index not in index_to_fp8_map
                index_to_fp8_map[index] = False
            param_indices = (
                self.buffers[fp8_gbuf_idx].param_indices
                + self.buffers[non_fp8_gbuf_idx].param_indices
            )
            assert min(param_indices) == 0
            assert max(param_indices) == len(param_indices) - 1

            fp8_flags = []
            for i in range(len(param_indices)):
                fp8_flags.append(index_to_fp8_map[i])

            fp8_buffer = self.buffers[fp8_gbuf_idx]
            non_fp8_buffer = self.buffers[non_fp8_gbuf_idx]

            fp8_idx = len(fp8_buffer.params) - 1
            non_fp8_idx = len(non_fp8_buffer.params) - 1
            offsets, fp8_offsets, non_fp8_offsets = [0], [0], [0]

            # Because the parameters in `ParamAndGradBuffer` are traversed in reverse order, the
            # flag here also needs to be traversed in reverse order.
            for fp8_flag in fp8_flags[::-1]:
                if fp8_flag:
                    numel = fp8_buffer.params[fp8_idx].nelement()
                    fp8_idx -= 1
                    offsets.append(offsets[-1] + numel)
                    fp8_offsets.append(fp8_offsets[-1] + numel)
                else:
                    numel = non_fp8_buffer.params[non_fp8_idx].nelement()
                    non_fp8_idx -= 1
                    offsets.append(offsets[-1] + numel)
                    non_fp8_offsets.append(non_fp8_offsets[-1] + numel)

            # Split the target buffer into two separate buffers.
            fp8_state_dict, non_fp8_state_dict = {}, {}
            for key in ['param', 'exp_avg', 'exp_avg_sq']:
                tensor = state_dict[non_fp8_gbuf_idx][non_fp8_param_and_grad_dtype][key]
                fp8_tensor = torch.empty([fp8_offsets[-1]], dtype=tensor.dtype)
                non_fp8_tensor = torch.empty([non_fp8_offsets[-1]], dtype=tensor.dtype)

                fp8_idx, non_fp8_idx = 0, 0
                for i in range(len(offsets) - 1):
                    if fp8_flags[-(i + 1)]:
                        fp8_tensor[fp8_offsets[fp8_idx] : fp8_offsets[fp8_idx + 1]].copy_(
                            tensor[offsets[i] : offsets[i + 1]]
                        )
                        fp8_idx += 1
                    else:
                        non_fp8_tensor[
                            non_fp8_offsets[non_fp8_idx] : non_fp8_offsets[non_fp8_idx + 1]
                        ].copy_(tensor[offsets[i] : offsets[i + 1]])
                        non_fp8_idx += 1

                fp8_state_dict[key] = fp8_tensor
                non_fp8_state_dict[key] = non_fp8_tensor

            fp8_state_dict['numel_unpadded'] = fp8_offsets[-1]
            non_fp8_state_dict['numel_unpadded'] = non_fp8_offsets[-1]

            # Add the two separate buffers into `new_state_dict`.
            new_state_dict[fp8_gbuf_idx] = {}
            new_state_dict[fp8_gbuf_idx][(torch.uint8, fp8_buffer.grad_dtype)] = fp8_state_dict
            new_state_dict[non_fp8_gbuf_idx][non_fp8_param_and_grad_dtype] = non_fp8_state_dict

        # Inplace update state_dict
        state_dict.clear()
        for key, value in new_state_dict.items():
            state_dict[key] = value

    def load_parameter_state(self, filename: str, *, update_legacy_format=False):
        """Load the distributed parameter state from disk.

        Args:
            filename (str): path to load parameter state from.
        """
        state_dict = None
        if torch.distributed.get_rank(self.data_parallel_group) == 0:
            state_dict = torch.load(filename)

        self.load_parameter_state_from_dp_zero(
            state_dict, update_legacy_format=update_legacy_format
        )

    def zero_grad(self, set_to_none: bool = True):
        """
        Zeroes grads for the model related parameters, i.e., model_float16_groups
        and model_fp32_groups. We additionally zero the remaining groups as a
        memory optimization to reduce fragmentation; in the case of
        set_to_none==True, the space used by this field can be safely deallocated.

        Args:
            set_to_none (bool): if true, set grads to None.
        """
        for groups in (
            self.model_float16_groups,
            self.model_fp32_groups,
            self.shard_float16_groups,  # grad empty/unused here?
            self.shard_fp32_groups,  # throws grad-access warning
            self.shard_fp32_from_float16_groups,
        ):
            for group in groups:
                _zero_grad_group_helper(group, set_to_none)

    def _collect_main_grad_data_for_unscaling(self):
        """
        Note: this should be equivalent to the float-16 optimizer's method,
        but written differently, so the two should be combined.
        """
        return [
            param.grad.data
            for group in self.optimizer.param_groups
            for param in group["params"]
        ]

    def _get_model_and_main_params_data_float16(self):
        """
        Get aligned list of model and main params.
        """
        model_data = []
        main_data = []
        for model_group, main_group in zip(
            self.shard_float16_groups, self.shard_fp32_from_float16_groups
        ):
            for model_param, main_param in zip(model_group, main_group):
                model_data.append(model_param.data)
                main_data.append(main_param.data)
        return model_data, main_data

    def _copy_model_grads_to_main_grads(self):
        """
        Copy model grads to main grads.

        Since this step follows a reduce-scatter through the DDP's grad
        buffer, this method is responsible for copying the updated grads
        from the grad buffer to the main shard's grad field.
        """

        # Utility method for copying group grads.
        def copy_group_grads(model_groups, shard_main_groups):
            for model_group, shard_main_group in zip(model_groups, shard_main_groups):
                for model_param, shard_main_param in zip(model_group, shard_main_group):

                    param_range_map = self._get_model_param_range_map(model_param)
                    param_range = param_range_map["param"]
                    assert param_range.size == shard_main_param.nelement()

                    model_grad = model_param.main_grad
                    shard_model_grad = model_grad.view(-1)[param_range.start : param_range.end]
                    shard_main_param.grad = shard_model_grad.float()

        # Copy model groups to shard groups.
        copy_group_grads(self.model_float16_groups, self.shard_fp32_from_float16_groups)
        copy_group_grads(self.model_fp32_groups, self.shard_fp32_groups)

    def _copy_main_params_to_model_params(self):
        """
        Copy main params to model params.

        Since this step is followed by an all-gather through the DDP's grad
        buffer, this method is responsible for copying the updated params
        from the main shards into the correct position in the grad buffer.
        """

        # Utility method for copying group params.
        def copy_group_params(shard_main_groups, model_groups):
            for shard_main_group, model_group in zip(shard_main_groups, model_groups):
                for shard_main_param, model_param in zip(shard_main_group, model_group):

                    param_range_map = self._get_model_param_range_map(model_param)
                    world_range = param_range_map["gbuf_world_in_bucket"]

                    assert world_range.size == shard_main_param.nelement()

                    gbuf_index, _, bucket_id = self.model_param_gbuf_map[model_param]
                    model_param_buffer = self.buffers[gbuf_index].buckets[bucket_id].param_data

                    shard_model_param = model_param_buffer.view(-1)[
                        world_range.start : world_range.end
                    ]

                    if is_float8tensor(model_param):
                        # 1. When "--fp8-param-gather" is disabled, the main param is first cast to
                        #    BF16/FP16, and then cast to FP8, so the amax_history is calculated
                        #    using BF16/FP16 param.
                        # 2. When "--fp8-param-gather" is enabled, we can cast the FP32 main param
                        #    to FP8 directly, which results in slightly different results with
                        #    higher speed. In theory, this does not affect convergence.
                        # TODO: The following code maintains the logic of the point-1 above. It can
                        # be deleted if it is not necessary.
                        shard_main_param = shard_main_param.to(model_param.dtype)
                        cast_to_fp8(
                            shard_main_param.view(1, -1),
                            model_param._fp8_meta['scaling_fwd'],
                            model_param._fp8_meta_index,
                            model_param._fp8_dtype,
                            out=shard_model_param.view(1, -1),
                        )
                    else:
                        shard_model_param.data.copy_(shard_main_param)

        # Copy shard groups to model groups.
        copy_group_params(self.shard_fp32_from_float16_groups, self.model_float16_groups)
        copy_group_params(self.shard_fp32_groups, self.model_fp32_groups)

    def _copy_model_params_to_main_params(self):
        """
        Copy model params to main params.

        During finetuning, this method is used to reload the main params from
        the model params. This copy does not make use of the grad buffer as
        an intermediary.
        """

        # Utility method for copying group params.
        def copy_group_params(model_groups, shard_main_groups):
            for model_group, shard_main_group in zip(model_groups, shard_main_groups):
                for model_param, shard_main_param in zip(model_group, shard_main_group):

                    param_range_map = self._get_model_param_range_map(model_param)
                    param_range = param_range_map["param"]
                    assert param_range.size == shard_main_param.nelement()

                    shard_model_param = model_param.view(-1)[param_range.start : param_range.end]
                    shard_main_param.data.copy_(shard_model_param)

        # Copy model groups to shard groups.
        copy_group_params(self.model_float16_groups, self.shard_fp32_from_float16_groups)
        copy_group_params(self.model_fp32_groups, self.shard_fp32_groups)

    def _update_fp8_scale_inv_and_amax(self):
        """
        If FP8 parameters are detected, update their `_scale_inv` and do reduce-max for their
        `amax_history`.
        """
        amaxes = []
        scales = []
        scale_invs = []
        # Iterate over all parameters inside this optimizer to find FP8 parameters.
        for buffer in self.buffers:
            for bucket in buffer.buckets:
                for param in bucket.params_list:
                    if is_float8tensor(param):
                        fp8_meta = param._fp8_meta['scaling_fwd']
                        fp8_meta_index = param._fp8_meta_index
                        amaxes.append(fp8_meta.amax_history[0][fp8_meta_index].view(1))
                        scales.append(fp8_meta.scale[fp8_meta_index].view(1))
                        scale_invs.append(param._scale_inv.view(1))
                        # Reset transpose cache
                        param._reset_caches()

        # If there are no FP8 parameters, skip all operations.
        if len(scales) > 0:
            dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device='cuda')

            # Update scaling factors.
            packed_scales = torch.empty(len(scales), dtype=torch.float32, device=scales[0].device)
            packed_scale_views = [packed_scales[i].view(1) for i in range(len(scales))]
            _multi_tensor_copy_this_to_that(scales, packed_scale_views, dummy_overflow_buf)
            torch.reciprocal(packed_scales, out=packed_scales)
            _multi_tensor_copy_this_to_that(packed_scale_views, scale_invs, dummy_overflow_buf)

            # Reduce amaxes.
            # Note: Assume each param has a separate amax.
            packed_amaxes = torch.empty(len(amaxes), dtype=torch.float32, device=amaxes[0].device)
            packed_amax_views = [packed_amaxes[i].view(1) for i in range(len(amaxes))]
            _multi_tensor_copy_this_to_that(amaxes, packed_amax_views, dummy_overflow_buf)
            torch.distributed.all_reduce(
                packed_amaxes, op=torch.distributed.ReduceOp.MAX, group=self.data_parallel_group
            )
            _multi_tensor_copy_this_to_that(packed_amax_views, amaxes, dummy_overflow_buf)

    @torch.no_grad()
    def step_with_ready_grads(self) -> bool:
        """Step the optimizer with ready gradients, return successful.

        Under the hood, either launch synchronous param all-gathers or get ready to launch
        asynchronous all-gathers that get overlapped with the next forward pass.
        """
        update_successful = super().step_with_ready_grads()

        # If there are no FP8 parameters, this will do nothing.
        self._update_fp8_scale_inv_and_amax()

        timers = self.config.timers
        if timers is not None:
            timers('params-all-gather', log_level=1).start(
                barrier=self.config.barrier_with_L1_time
            )
        # If not overlapping all-gather for parameters, launch synchronous all-gather
        # communication calls here. If overlapping all-gather for parameters, the first
        # all-gather is launched asynchronously in the next optimizer.zero_grad() call and
        # subsequent all-gathers are launched in the forward pre-hook.
        if not self.ddp_config.overlap_param_gather:
            for model_chunk in self.model_chunks:
                model_chunk.start_param_sync()
        if timers is not None:
            timers('params-all-gather').stop()

        return update_successful

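A minimal sketch of how the checkpointing methods above fit together (assuming `optimizer` is an already-constructed DistributedOptimizer and the file names are hypothetical): the small, DP-rank-independent state goes through state_dict()/load_state_dict(), while the DP-sharded parameter and Adam state goes through save_parameter_state()/load_parameter_state().

    # Hypothetical checkpointing sketch; not part of the committed file.
    import torch

    # Save: common (non-parameter) optimizer state, plus the parameter state
    # that save_parameter_state() gathers onto DP rank 0 before writing.
    common_state = optimizer.state_dict()
    torch.save(common_state, 'optimizer_common.pt')            # hypothetical path
    optimizer.save_parameter_state('optimizer_param_state.pt')  # hypothetical path

    # Load: restore the common state first, then scatter the parameter state
    # back to each DP rank's shards.
    optimizer.load_state_dict(torch.load('optimizer_common.pt'))
    optimizer.load_parameter_state('optimizer_param_state.pt')
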
megatron/optimizer/grad_scaler.py
→
megatron/
core/
optimizer/grad_scaler.py
View file @
4b097dee
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

 """Megatron grad scaler."""

-from abc import ABC
-from abc import abstractmethod
+from abc import ABC, abstractmethod
+from typing import Dict

 import torch


 class MegatronGradScaler(ABC):
-    def __init__(self, initial_scale):
+    def __init__(self, initial_scale: float):
         """Initialize scale value with the input initial scale."""
         assert initial_scale > 0.0
-        self._scale = torch.cuda.FloatTensor([initial_scale])
+        self._scale = torch.tensor([initial_scale], dtype=torch.float, device='cuda')

     @property
     def scale(self):
...
@@ -24,7 +23,7 @@ class MegatronGradScaler(ABC):
         return self._scale.double().reciprocal().float()

     @abstractmethod
-    def update(self, found_inf):
+    def update(self, found_inf: bool):
         pass

     @abstractmethod
...
@@ -32,14 +31,16 @@ class MegatronGradScaler(ABC):
         pass

     @abstractmethod
-    def load_state_dict(self, state_dict):
+    def load_state_dict(self, state_dict: Dict):
         pass


 class ConstantGradScaler(MegatronGradScaler):
+    """
+    Constant grad scaler (loss scale is never adjusted regardless of NaNs seen in gradients).
+    """

-    def update(self, found_inf):
+    def update(self, found_inf: bool):
         pass

     def state_dict(self):
...
@@ -49,26 +50,48 @@ class ConstantGradScaler(MegatronGradScaler):
         pass


 class DynamicGradScaler(MegatronGradScaler):
-    def __init__(self, initial_scale, min_scale, growth_factor, backoff_factor,
-                 growth_interval, hysteresis):
-        """"Grad scaler with dynamic scale that gets adjusted
-        during training."""
+    """
+    Grad scaler with dynamic scale that gets adjusted during training.
+
+    Reduces loss scale by `backoff_factor` if `hysteresis` number of NaNs are seen in a row. Increases
+    loss scale by `growth_factor` if NaNs are not seen for `growth_interval` iterations.
+    """
+
+    def __init__(
+        self,
+        initial_scale: float,
+        min_scale: float,
+        growth_factor: float,
+        backoff_factor: float,
+        growth_interval: int,
+        hysteresis: int,
+    ):
+        """
+        Grad scaler with dynamic scale that gets adjusted during training.
+
+        Args:
+            initial_scale (float): Initial loss scale value.
+            min_scale (float): Minimum loss scale value.
+            growth_factor (float): Factor to grow loss scale by if NaNs are not seen in `growth_interval`
+                training iterations. Must be greater than 1.
+            backoff_factor (float): Factor to decrease loss scale by if NaNs are seen in `hysteresis`
+                consecutive training iterations. Must be between 0 and 1.
+            growth_interval (int): Number of training iterations of no NaNs before loss scale is increased.
+            hysteresis (int): Number of training iterations of consecutive NaNs before loss scale is decreased.
+        """
         super(DynamicGradScaler, self).__init__(initial_scale)

         # Lower bound on the scale.
         assert min_scale > 0.0
         assert min_scale <= initial_scale
-        self.min_scale = torch.cuda.FloatTensor([min_scale])
+        self.min_scale = torch.tensor([min_scale], dtype=torch.float, device='cuda')
         # Growth and backoff factors for the scale.
         assert growth_factor > 1.0
-        self.growth_factor = torch.cuda.FloatTensor([growth_factor])
+        self.growth_factor = torch.tensor([growth_factor], dtype=torch.float, device='cuda')
         assert backoff_factor < 1.0
         assert backoff_factor > 0.0
-        self.backoff_factor = torch.cuda.FloatTensor([backoff_factor])
+        self.backoff_factor = torch.tensor([backoff_factor], dtype=torch.float, device='cuda')
         # Interval over which if we don't see any inf/nan,
         # we will scale the grad scale by the growth factor.
         assert growth_interval > 0
...
@@ -82,8 +105,10 @@ class DynamicGradScaler(MegatronGradScaler):
         self._growth_tracker = 0
         self._hysteresis_tracker = self.hysteresis

-    def update(self, found_inf):
+    def update(self, found_inf: bool):
+        """
+        Updates internal state in grad scaler based on whether NaNs are seen in grads or not.
+        """
         # If we have an inf/nan, growth tracker is set to 0
         # and hysterisis tracker is reduced by 1.
...
@@ -92,8 +117,7 @@ class DynamicGradScaler(MegatronGradScaler):
             self._hysteresis_tracker -= 1
             # Now if we are out of hysteresis count, scale down the loss.
             if self._hysteresis_tracker <= 0:
-                self._scale = torch.max(self._scale * self.backoff_factor,
-                                        self.min_scale)
+                self._scale = torch.max(self._scale * self.backoff_factor, self.min_scale)
         else:
             # If there is no nan/inf, increment the growth tracker.
             self._growth_tracker += 1
...
@@ -105,7 +129,6 @@ class DynamicGradScaler(MegatronGradScaler):
                 # and scale up the loss scale.
                 self._scale = self._scale * self.growth_factor

     def state_dict(self):
         state_dict = {}
         state_dict['scale'] = self._scale
...
@@ -113,8 +136,7 @@ class DynamicGradScaler(MegatronGradScaler):
         state_dict['hysteresis_tracker'] = self._hysteresis_tracker
         return state_dict

-    def load_state_dict(self, state_dict):
+    def load_state_dict(self, state_dict: Dict):
         self._scale = state_dict['scale'].cuda(torch.cuda.current_device())
         self._growth_tracker = state_dict['growth_tracker']
         self._hysteresis_tracker = state_dict['hysteresis_tracker']
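As a quick illustration of the dynamic scaling policy described in the docstring above, the following hedged sketch drives a `DynamicGradScaler` with a synthetic sequence of inf/no-inf flags. The constructor values are arbitrary examples, not recommended settings, and a CUDA device is required since the scaler keeps its scale on GPU.

import torch
from megatron.core.optimizer.grad_scaler import DynamicGradScaler

scaler = DynamicGradScaler(
    initial_scale=2.0**16,
    min_scale=1.0,
    growth_factor=2.0,
    backoff_factor=0.5,
    growth_interval=4,
    hysteresis=2,
)

# Two consecutive iterations with inf/nan exhaust the hysteresis budget and
# halve the loss scale (bounded below by min_scale).
scaler.update(found_inf=True)
scaler.update(found_inf=True)
print(scaler.scale)          # expected: ~32768.0

# Four clean iterations in a row (growth_interval) double the scale again.
for _ in range(4):
    scaler.update(found_inf=False)
print(scaler.scale)          # expected: ~65536.0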
megatron/optimizer/optimizer.py → megatron/core/optimizer/optimizer.py
View file @ 4b097dee
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

 """Megatron optimizer."""

-from abc import ABC
-from abc import abstractmethod
-from apex.multi_tensor_apply import multi_tensor_applier
-import amp_C
-import torch
-from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
-from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
-
-from megatron import get_timers
-from megatron import print_rank_0
-from megatron.core import mpu, tensor_parallel
-from megatron.model import DistributedDataParallel as LocalDDP
-from megatron.model import Float16Module
-from megatron.model.module import param_is_not_shared
-from megatron.utils import unwrap_model
-
-from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32
+import copy
+import math
+import warnings
+from abc import ABC, abstractmethod
+from itertools import chain
+from logging import getLogger
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch

-def _zero_grad_group_helper(group, set_to_none):
-    """Zero out the gradient for a group of parameters.
-
-    Note: copied from torch.optim.optimizer."""
+try:
+    from transformer_engine.pytorch.optimizers import multi_tensor_applier
+except ImportError:
+    try:
+        from apex.multi_tensor_apply import multi_tensor_applier
+    except ImportError:
+        from megatron.core.utils import local_multi_tensor_applier
+
+        multi_tensor_applier = local_multi_tensor_applier
+try:
+    import amp_C
+
+    l2_norm_impl = amp_C.multi_tensor_l2norm
+    multi_tensor_scale_impl = amp_C.multi_tensor_scale
+except ImportError:
+    from megatron.core.utils import local_multi_tensor_l2_norm, local_multi_tensor_scale
+
+    l2_norm_impl = local_multi_tensor_l2_norm
+    multi_tensor_scale_impl = local_multi_tensor_scale
+
+from .. import parallel_state, tensor_parallel
+from ..config_logger import has_config_logger_enabled, log_config_to_disk
+from ..dist_checkpointing.mapping import ShardedStateDict
+from ..dist_checkpointing.optimizer import (
+    get_param_id_to_sharded_param_map,
+    make_sharded_optimizer_tensor,
+    optim_state_to_sharding_state,
+)
+from ..dist_checkpointing.utils import add_prefix_for_sharding
+from ..transformer.module import param_is_not_shared
+
+from .clip_grads import clip_grad_by_total_norm_fp32, count_zeros_fp32, get_grad_norm_fp32
+from .grad_scaler import MegatronGradScaler
+from .optimizer_config import OptimizerConfig
+
+logger = getLogger(__name__)
+def _zero_grad_group_helper(group: List[torch.nn.Parameter], set_to_none: bool):
+    """
+    Zero out the gradient for a group of parameters.
+
+    Note: copied from torch.optim.optimizer.
+    """
     for param in group:
         if param.grad is not None:
             if set_to_none:
...
@@ -36,65 +66,65 @@ def _zero_grad_group_helper(group, set_to_none):
                 param.grad.zero_()


-def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
-    """Use multi-tensor-applier to copy values from one list to another.
-    We don't have a blfoat16 implementation so for now if the overflow_buf
+def _multi_tensor_copy_this_to_that(
+    this: List[torch.Tensor], that: List[torch.Tensor], overflow_buf: Optional[torch.Tensor] = None
+):
+    """
+    Use multi-tensor-applier to copy values from one list to another.
+
+    We don't have a bfloat16 implementation so for now if the overflow_buf
     is not provided, we default back to simple loop copy to be compatible
-    with bfloat16."""
+    with bfloat16.
+    """
     if overflow_buf:
         overflow_buf.fill_(0)
         # Scaling with factor `1.0` is equivalent to copy.
-        multi_tensor_applier(amp_C.multi_tensor_scale, overflow_buf, [this, that], 1.0)
+        multi_tensor_applier(multi_tensor_scale_impl, overflow_buf, [this, that], 1.0)
     else:
         for this_, that_ in zip(this, that):
            that_.copy_(this_)
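# Illustrative usage sketch for the helper above (not part of the diff); the
# tensors are made-up examples and a CUDA device is assumed:
import torch

src = [torch.randn(4, device='cuda') for _ in range(3)]
dst = [torch.empty(4, device='cuda') for _ in range(3)]
dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device='cuda')

# Fused multi-tensor path; passing overflow_buf=None instead falls back to the
# simple per-tensor loop copy (the bf16-compatible path).
_multi_tensor_copy_this_to_that(this=src, that=dst, overflow_buf=dummy_overflow_buf)
assert all(torch.equal(a, b) for a, b in zip(src, dst))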
 class MegatronOptimizer(ABC):
+    """
+    Base class for all Megatron optimizers.
+
+    Args:
+        optimizer (torch.optim.Optimizer): base optimizer such as Adam or SGD.
+        config (OptimizerConfig): configuration object for optimizer.
+        init_state_fn (Callable, optional): function to initialize state in the optimizer.
+    """

-    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
-                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
-                 models):
-        """Input optimizer is the base optimizer for example Adam."""
+    def __init__(
+        self,
+        optimizer: torch.optim.Optimizer,
+        config: OptimizerConfig,
+        init_state_fn: Callable = lambda x: None,
+    ):
+        """Input optimizer is the base optimizer (e.g., Adam)."""
         self.optimizer = optimizer
         assert self.optimizer, 'no optimizer is provided.'
-        # Set gradient clipping and logging params.
-        self.clip_grad = clip_grad
-        self.log_num_zeros_in_grad = log_num_zeros_in_grad
-        self.params_have_main_grad = params_have_main_grad
-        self.use_contiguous_buffers_in_local_ddp = use_contiguous_buffers_in_local_ddp
-
-        # 'models' are retained for access to the contiguous grad buffers.
-        # (see distributed optimizer)
-        self.models = models
-
-        if self.use_contiguous_buffers_in_local_ddp:
-            assert self.params_have_main_grad, \
-                "use of contiguous buffer requires that params have main grad"
+        self.config = config
+        self.init_state_fn = init_state_fn

-    def get_parameters(self):
+    def get_parameters(self) -> List[torch.nn.Parameter]:
+        """
+        Get list of parameters wrapped in optimizer.
+        """
         params = []
         for param_group in self.optimizer.param_groups:
             for param in param_group['params']:
                 params.append(param)
         return params

-    def get_main_grads_for_grad_norm(self):
-        # Filter parameters based on:
-        #   - grad should not be none
-        #   - parameter should not be shared
-        #   - should not be a replica due to tensor model parallelism
+    def get_main_grads_for_grad_norm(self) -> List[torch.Tensor]:
+        """
+        Get main_grads that should be taken into account to compute the grad norm.
+        Filter parameters based on:
+          - grad should not be None.
+          - parameter should not be shared (i.e., grads shouldn't be double counted while
+            computing norms).
+          - should not be a replica due to tensor model parallelism.
+        """
         params = self.get_parameters()
         grads_for_norm = []
         for param in params:
...
@@ -107,41 +137,69 @@ class MegatronOptimizer(ABC):
         return grads_for_norm

-    def get_model_parallel_group(self):
-        return mpu.get_model_parallel_group()
+    def get_model_parallel_group(self) -> torch.distributed.ProcessGroup:
+        """Default returned here, but the distributed optimizer overrides this."""
+        if hasattr(self, 'model_parallel_group'):
+            return self.model_parallel_group
+        return parallel_state.get_model_parallel_group()

-    def clip_grad_norm(self, clip_grad):
-        params = self.get_parameters()
-        grads_for_norm = self.get_main_grads_for_grad_norm()
-        return clip_grad_norm_fp32(
-            params, grads_for_norm, clip_grad,
-            model_parallel_group=self.get_model_parallel_group())
-
-    def count_zeros(self):
-        params = self.get_parameters()
-        return count_zeros_fp32(params,
-                                model_parallel_group=self.get_model_parallel_group())
+    @abstractmethod
+    def prepare_grads(self) -> bool:
+        """Pre-processing gradients before the optimizer step, returns whether inf/nan is found."""
+        return False
+
+    @abstractmethod
+    def step_with_ready_grads(self) -> bool:
+        """Step the optimizer with ready gradients, return successful."""
+        return True
+
+    @torch.no_grad()
+    def get_grad_norm(self):
+        """Compute and return grad norm."""
+        grads_for_norm = self.get_main_grads_for_grad_norm()
+        total_norm = get_grad_norm_fp32(
+            grads_for_norm, model_parallel_group=self.get_model_parallel_group()
+        )
+        return total_norm
+
+    def clip_grad_norm(self, clip_grad: float) -> float:
+        """Compute and return grad norm, also clip grads."""
+        params = self.get_parameters()
+        grads_for_norm = self.get_main_grads_for_grad_norm()
+        grad_norm = get_grad_norm_fp32(
+            grads_for_norm, model_parallel_group=self.get_model_parallel_group()
+        )
+        clip_grad_by_total_norm_fp32(params, clip_grad, grad_norm)
+        return grad_norm
+
+    def count_zeros(self) -> float:
+        """Count number of zeros in model's gradients."""
+        params = self.get_parameters()
+        return count_zeros_fp32(params, model_parallel_group=self.get_model_parallel_group())
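# Illustrative sketch (not part of this file): `clip_grad_norm` above computes a
# *global* L2 norm by summing squared local norms and all-reducing the partial
# sums over the model-parallel group before taking the square root; gradients are
# then scaled by min(1, clip_grad / (norm + eps)). The group argument here is a
# placeholder for whatever process group the caller uses.
import torch

def global_grad_norm(local_grads, model_parallel_group=None):
    total_sq = torch.zeros(1, dtype=torch.float32, device=local_grads[0].device)
    for g in local_grads:
        total_sq += (g.float() ** 2).sum()
    if model_parallel_group is not None:  # skip when running without torch.distributed
        torch.distributed.all_reduce(
            total_sq, op=torch.distributed.ReduceOp.SUM, group=model_parallel_group
        )
    return total_sq.sqrt()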
     @abstractmethod
-    def zero_grad(self, set_to_none=True):
+    def zero_grad(self, set_to_none: bool = True):
+        """Zero gradients and prepare for next forward pass."""
         pass

     @abstractmethod
-    def get_loss_scale(self):
-        """The output should be a cuda tensor of size 1."""
+    def get_loss_scale(self) -> torch.Tensor:
+        """
+        Get current loss scale factor.
+        NOTE: The output should be a CUDA tensor of size 1.
+        """
         pass

-    def scale_loss(self, loss):
+    def scale_loss(self, loss: torch.Tensor) -> torch.Tensor:
         """Simple scaling."""
         return self.get_loss_scale() * loss

+    def start_param_sync(self, model_index: int, *unused):
+        """
+        Start parameter synchronization for all optimizers.
+        This is a no-op for all non-distributed optimizers.
+        """
+        pass

     @abstractmethod
     def reload_model_params(self):
...
@@ -152,17 +210,16 @@ class MegatronOptimizer(ABC):
        with main parameters, the main parameters need to also be updated."""
         pass

     @abstractmethod
     def state_dict(self):
         """Return state_dict."""
         pass

     @abstractmethod
     def load_state_dict(self, state_dict):
         """Load pass-in `state_dict`."""
         pass

     # Promote state so it can be retrieved or set via
     # "optimizer_instance.state"
     def _get_state(self):
...
@@ -173,7 +230,6 @@ class MegatronOptimizer(ABC):
     state = property(_get_state, _set_state)

     # Promote param_groups so it can be retrieved or set via
     # "optimizer_instance.param_groups"
     # (for example, to adjust the learning rate)
...
@@ -185,200 +241,104 @@ class MegatronOptimizer(ABC):
     param_groups = property(_get_param_groups, _set_param_groups)

     @abstractmethod
-    def step(self, args, timers):
-        pass
-
-    def gather_model_params(self, args, timers):
-        """
-        For the case of a non-distributed-optimizer, there is nothing to
-        do here.
-        """
+    def step(self):
+        """Step the optimizer."""
+        pass

+    @abstractmethod
+    def sharded_state_dict(
+        self, model_sharded_state_dict: ShardedStateDict, is_loading: bool = False
+    ) -> ShardedStateDict:
+        """Builds sharded state dict for the optimizer, based on model's sharded state dict.
+
+        Args:
+            model_sharded_state_dict (ShardedStateDict): sharded state dict of the model
+            is_loading (bool, optional): flag indicating whether the state dict will be
+                used to save or load the optimizer state. Defaults to False.
+
+        Returns: optimizer sharded state dict
+        """

-    def allreduce_word_embedding_grads(self, args):
-        """
-        All-reduce word embedding grads.
-
-        Reduce grads across first and last stages to ensure that word_embeddings
-        parameters stay in sync. This should only run for models that support
-        pipelined model parallelism (BERT and GPT-2).
-        """
-        if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
-                mpu.get_pipeline_model_parallel_world_size() > 1:
-            if mpu.is_pipeline_first_stage(ignore_virtual=True):
-                unwrapped_model = self.models[0]
-            elif mpu.is_pipeline_last_stage(ignore_virtual=True):
-                unwrapped_model = self.models[-1]
-            else:  # We do not support the interleaved schedule for T5 yet.
-                unwrapped_model = self.models[0]
-            unwrapped_model = unwrap_model(unwrapped_model, (torchDDP, LocalDDP, Float16Module))
-
-            if unwrapped_model.share_embeddings_and_output_weights:
-                weight = unwrapped_model.shared_embedding_or_output_weight()
-                if args.DDP_impl == 'local':
-                    grad = weight.main_grad
-                else:
-                    grad = weight.grad
-                torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
-
-    def allreduce_position_embedding_grads(self, args):
-        """
-        All-reduce position_embeddings grad across first (encoder) and
-        split (decoder) stages to ensure that position embeddings parameters
-        stay in sync. This should only run for T5 models with pipeline
-        parallelism.
-        """
-        if mpu.is_rank_in_position_embedding_group() and \
-                mpu.get_pipeline_model_parallel_world_size() > 1 and \
-                args.pipeline_model_parallel_split_rank is not None:
-            unwrapped_model = self.models[0]
-            unwrapped_model = unwrap_model(unwrapped_model, (torchDDP, LocalDDP, Float16Module))
-            assert args.DDP_impl == 'local', \
-                'T5 model is only supported with local DDP mode'
-            grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
-            torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
-
-    def allreduce_embedding_grads(self, args):
-        """All-reduce both word and position embeddings."""
-        self.allreduce_word_embedding_grads(args)
-        self.allreduce_position_embedding_grads(args)
-
-    def allreduce_layernorm_grads(self, args):
-        """All-reduce layernorm grads (for sequence parallelism)."""
-        # All-reduce layernorm parameters across model parallel nodes
-        # when sequence parallelism is used
-        if mpu.get_tensor_model_parallel_world_size() > 1 and \
-                args.sequence_parallel:
-            grads = []
-            for model_module in self.models:
-                unwrapped_model = unwrap_model(model_module, (torchDDP, LocalDDP, Float16Module))
-                for param in unwrapped_model.parameters():
-                    if getattr(param, 'sequence_parallel', False):
-                        grad = param.main_grad if args.DDP_impl == 'local' else param.grad
-                        grads.append(grad.data)
-            coalesced = _flatten_dense_tensors(grads)
-            torch.distributed.all_reduce(coalesced, group=mpu.get_tensor_model_parallel_group())
-            for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
-                buf.copy_(synced)
-
-    def reduce_model_grads(self, args, timers):
-        """All-reduce all grads, and all-reduce embeddings."""
-        # All-reduce layer-norm grads (for sequence parallelism).
-        timers('layernorm-grads-all-reduce', log_level=1).start(barrier=args.barrier_with_L1_time)
-        self.allreduce_layernorm_grads(args)
-        timers('layernorm-grads-all-reduce').stop()
-
-        # All-reduce if needed.
-        if args.DDP_impl == 'local':
-            timers('grads-all-reduce', log_level=1).start(barrier=args.barrier_with_L1_time)
-            for model in self.models:
-                model.allreduce_gradients()
-            timers('grads-all-reduce').stop()
-
-        # All-reduce embedding grads.
-        timers('embedding-grads-all-reduce', log_level=1).start(barrier=args.barrier_with_L1_time)
-        self.allreduce_embedding_grads(args)
-        timers('embedding-grads-all-reduce').stop()
+    @staticmethod
+    def _extract_common_per_param_step(state_dict) -> Union[int, torch.Tensor]:
+        common_step = None
+        for param_idx, param_state in state_dict['state'].items():
+            param_step = param_state.get('step', None)
+            if param_step is not None:
+                if common_step is None:
+                    common_step = param_step
+                elif common_step != param_step:
+                    raise ValueError(
+                        "The optimizer step differs per parameter. Mcore only supports "
+                        "optimizers whose step is shared across all parameters."
+                    )
+        return common_step
+
+    @staticmethod
+    def _restore_common_per_param_step(state_dict: Dict, step: Union[int, torch.Tensor]):
+        for param_idx, param_state in state_dict['state'].items():
+            param_state['step'] = copy.deepcopy(step)
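# Illustrative sketch (not part of this file) of the two step helpers above,
# using a toy Adam-style state dict; the import path is an assumption based on
# this diff's file layout.
import torch
from megatron.core.optimizer.optimizer import MegatronOptimizer

state_dict = {
    'state': {
        0: {'step': 100, 'exp_avg': torch.zeros(2)},
        1: {'step': 100, 'exp_avg': torch.zeros(2)},
    },
    'param_groups': [{'params': [0, 1]}],
}
common_step = MegatronOptimizer._extract_common_per_param_step(state_dict)
assert common_step == 100
# On load, every per-parameter entry gets the shared step value back:
MegatronOptimizer._restore_common_per_param_step(state_dict, common_step)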
 class MixedPrecisionOptimizer(MegatronOptimizer):
     """Base class for both the float-16 and the distributed optimizer.

-    Arguments:
-        optimizer: base optimizer such as Adam or SGD
-        clip_grad: clip gradeints with this global L2 norm. Note
-            that clipping is ignored if clip_grad == 0
-        log_num_zeros_in_grad: return number of zeros in the gradients.
-        params_have_main_grad: flag indicating if parameters have
-            a `main_grad` field. If this is set, we are assuming
-            that the model parameters are store in the `main_grad`
-            field instead of the typical `grad` field. This happens
-            for the DDP cases where there is a continuous buffer
-            holding the gradients. For example for bfloat16, we want
-            to do gradient accumulation and all-reduces in float32
-            and as a result we store those gradients in the main_grad.
-            Note that main grad is not necessarily in float32.
-        use_contiguous_buffers_in_local_ddp: if true, the local DDP model
-            is using a contiguous buffer to hold the model grads.
-        fp16: if true, the model is running in fp16.
-        bf16: if true, the model is running in bfloat16.
-        params_dtype: used by distributed optimizer.
-        grad_scaler: used for scaling gradients. Note that this can be
-            None. This case happens when `bf16 = True` and we don't
-            use any loss scale. Note that for `bf16 = True`, we can have
-            a constnat gradient scaler. Also for `bf16 = False`, we
-            always require a grad scaler.
-        models: list of models (i.e., the virtual pipelining models). This
-            is used by the distributed optimizer for mapping parameters.
+    Args:
+        optimizer (torch.optim.Optimizer): base optimizer such as Adam or SGD.
+        config (OptimizerConfig): configuration object for optimizer.
+        grad_scaler (MegatronGradScaler): used for scaling gradients. Note that
+            this can be None. This case happens when `bf16 = True` and we don't
+            use any loss scale. Note that for `bf16 = True`, we can have
+            a constant gradient scaler. Also for `bf16 = False`, we
+            always require a grad scaler.
+        init_state_fn (Callable, optional): function to initialize state in the optimizer.
     """

-    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
-                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
-                 fp16, bf16, params_dtype, grad_scaler, models):
-        super().__init__(optimizer, clip_grad, log_num_zeros_in_grad,
-                         params_have_main_grad, use_contiguous_buffers_in_local_ddp, models)
-        self.fp16 = fp16
-        self.bf16 = bf16
-        self.params_dtype = params_dtype
+    def __init__(
+        self,
+        optimizer: torch.optim.Optimizer,
+        config: OptimizerConfig,
+        grad_scaler: Optional[MegatronGradScaler],
+        init_state_fn: Callable,
+    ):
+        if has_config_logger_enabled(config):
+            log_config_to_disk(config, locals(), prefix=type(self).__name__)
+
+        super().__init__(optimizer, config, init_state_fn)
         self.grad_scaler = grad_scaler

         # None grad scaler is only supported for bf16.
         if self.grad_scaler is None:
-            assert not self.fp16, 'fp16 expects a grad scaler.'
+            assert not self.config.fp16, 'fp16 expects a grad scaler.'

         # Tensor used to determine if a nan/inf has happened.
         # Any non-zero value indicates inf/nan.
         # Note that we keep this for the cases that grad scaler is none.
         # We still record nan/inf if we have a bfloat16 with a grad scaler.
         if self.grad_scaler:
-            self.found_inf = torch.cuda.FloatTensor([0.0])
+            self.found_inf = torch.tensor([0.0], dtype=torch.float, device='cuda')

         # Dummy tensor needed for apex multi-apply tensor.
         # For bfloat, we don't have multi-tensor apply and for now
         # we set it to none so the multi-tensor apply gets ignored.
-        if bf16:
+        if self.config.bf16:
             self._dummy_overflow_buf = None
         else:
-            self._dummy_overflow_buf = torch.cuda.IntTensor([0])
+            self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device='cuda')

         # In case grad scaler is not passed, define the unity scale.
         if self.grad_scaler is None:
-            self._scale_one = torch.cuda.FloatTensor([1.0])
+            self._scale_one = torch.tensor([1.0], dtype=torch.float, device='cuda')

     def get_loss_scale(self):
         if self.grad_scaler is None:
             return self._scale_one
         return self.grad_scaler.scale

     def reload_model_params(self):
         self._copy_model_params_to_main_params()

     def _unscale_main_grads_and_check_for_nan(self):
         # Collect main grads.
...
@@ -389,119 +349,137 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
         # Unscale and set found inf/nan
         torch._amp_foreach_non_finite_check_and_unscale_(
             main_grads, self.found_inf, self.grad_scaler.inv_scale
         )

         # Update across all model parallel instances.
         torch.distributed.all_reduce(
             self.found_inf, op=torch.distributed.ReduceOp.MAX, group=self.get_model_parallel_group()
         )

         # Check for nan.
-        found_inf_flag = (self.found_inf.item() > 0)
+        found_inf_flag = self.found_inf.item() > 0
         return found_inf_flag

     @torch.no_grad()
-    def step(self, args, timers):
-        # Copy gradients from model params to main params.
-        timers('optimizer-copy-to-main-grad', log_level=1).start(barrier=args.barrier_with_L1_time)
-        self._copy_model_grads_to_main_grads()
-        timers('optimizer-copy-to-main-grad').stop()
-
-        # Do unscale, check for inf, and update grad scaler only for
-        # the case that grad scaler is provided.
-        if self.grad_scaler:
-            # Unscale and check for inf/nan.
-            timers('optimizer-unscale-and-check-inf', log_level=1).start(barrier=args.barrier_with_L1_time)
-            found_inf_flag = self._unscale_main_grads_and_check_for_nan()
-            timers('optimizer-unscale-and-check-inf').stop()
-
-            # We are done with scaling gradients
-            # so we can update the loss scale.
-            self.grad_scaler.update(found_inf_flag)
-
-            # If we found inf/nan, skip the update.
-            if found_inf_flag:
-                return False, None, None
-
-        # Clip the main gradients.
-        timers('optimizer-clip-main-grad', log_level=1).start(barrier=args.barrier_with_L1_time)
-        grad_norm = None
-        if self.clip_grad > 0.0:
-            grad_norm = self.clip_grad_norm(self.clip_grad)
-        timers('optimizer-clip-main-grad').stop()
-
-        # Count the zeros in the grads.
-        timers('optimizer-count-zeros', log_level=1).start(barrier=args.barrier_with_L1_time)
-        num_zeros_in_grad = self.count_zeros() if \
-            self.log_num_zeros_in_grad else None
-        timers('optimizer-count-zeros').stop()
-
-        # Step the optimizer.
-        timers('optimizer-inner-step', log_level=1).start(barrier=args.barrier_with_L1_time)
-        self.optimizer.step()
-        timers('optimizer-inner-step').stop()
-
-        # Update params from main params.
-        timers('optimizer-copy-main-to-model-params', log_level=1).start(barrier=args.barrier_with_L1_time)
-        self._copy_main_params_to_model_params()
-        timers('optimizer-copy-main-to-model-params').stop()
-
-        # Successful update.
-        return True, grad_norm, num_zeros_in_grad
+    def prepare_grads(self) -> bool:
+        """Pre-processing gradients before the optimizer step, returns whether inf/nan is found."""
+        timers = self.config.timers
+
+        # Copy gradients from model params to main params.
+        if timers is not None:
+            timers('optimizer-copy-to-main-grad', log_level=1).start(barrier=self.config.barrier_with_L1_time)
+        self._copy_model_grads_to_main_grads()
+        if timers is not None:
+            timers('optimizer-copy-to-main-grad').stop()
+
+        # Do unscale, check for inf, and update grad scaler only for
+        # the case that grad scaler is provided.
+        if self.grad_scaler:
+            # Unscale and check for inf/nan.
+            if timers is not None:
+                timers('optimizer-unscale-and-check-inf', log_level=1).start(barrier=self.config.barrier_with_L1_time)
+            found_inf_flag = self._unscale_main_grads_and_check_for_nan()
+            if timers is not None:
+                timers('optimizer-unscale-and-check-inf').stop()
+
+            # We are done with scaling gradients
+            # so we can update the loss scale.
+            self.grad_scaler.update(found_inf_flag)
+
+            return found_inf_flag
+
+        return False
+
+    @torch.no_grad()
+    def step_with_ready_grads(self) -> bool:
+        """Step the optimizer with ready gradients, return successful."""
+        timers = self.config.timers
+
+        # Step the optimizer.
+        if timers is not None:
+            timers('optimizer-inner-step', log_level=1).start(barrier=self.config.barrier_with_L1_time)
+        self.optimizer.step()
+        if timers is not None:
+            timers('optimizer-inner-step').stop()
+
+        # Update params from main params.
+        if timers is not None:
+            timers('optimizer-copy-main-to-model-params', log_level=1).start(barrier=self.config.barrier_with_L1_time)
+        self._copy_main_params_to_model_params()
+        if timers is not None:
+            timers('optimizer-copy-main-to-model-params').stop()
+
+        return True
+
+    @torch.no_grad()
+    def step(self):
+        timers = self.config.timers
+
+        found_inf_flag = self.prepare_grads()
+        if found_inf_flag:
+            return False, None, None
+
+        # Clip the main gradients.
+        if timers is not None:
+            timers('optimizer-clip-main-grad', log_level=1).start(barrier=self.config.barrier_with_L1_time)
+        grad_norm = None
+        if self.config.clip_grad > 0.0:
+            grad_norm = self.clip_grad_norm(self.config.clip_grad)
+        if timers is not None:
+            timers('optimizer-clip-main-grad').stop()
+
+        # Count the zeros in the grads.
+        if timers is not None:
+            timers('optimizer-count-zeros', log_level=1).start(barrier=self.config.barrier_with_L1_time)
+        num_zeros_in_grad = self.count_zeros() if self.config.log_num_zeros_in_grad else None
+        if timers is not None:
+            timers('optimizer-count-zeros').stop()
+
+        success = self.step_with_ready_grads()
+
+        return success, grad_norm, num_zeros_in_grad
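# Illustrative sketch (not part of this file): the training loop consumes the
# (success, grad_norm, num_zeros_in_grad) tuple returned by `step()` above;
# `optimizer`, `opt_param_scheduler`, and `global_batch_size` are assumed to
# exist in the surrounding training script.
#
#   update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
#   if update_successful:
#       opt_param_scheduler.step(increment=global_batch_size)  # hypothetical increment
#   else:
#       pass  # loss scale was reduced; this iteration's update is skipped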
 class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
     """Float16 optimizer for fp16 and bf16 data types.

-    Arguments:
-        optimizer: base optimizer such as Adam or SGD
-        clip_grad: clip gradeints with this global L2 norm. Note
-            that clipping is ignored if clip_grad == 0
-        log_num_zeros_in_grad: return number of zeros in the gradients.
-        params_have_main_grad: flag indicating if parameters have
-            a `main_grad` field. If this is set, we are assuming
-            that the model parameters are store in the `main_grad`
-            field instead of the typical `grad` field. This happens
-            for the DDP cases where there is a continuous buffer
-            holding the gradients. For example for bfloat16, we want
-            to do gradient accumulation and all-reduces in float32
-            and as a result we store those gradients in the main_grad.
-            Note that main grad is not necessarily in float32.
-        use_contiguous_buffers_in_local_ddp: if true, the local DDP model
-            is using a contiguous buffer to hold the model grads.
-        fp16: if true, the model is running in fp16.
-        bf16: if true, the model is running in bfloat16.
-        grad_scaler: used for scaling gradients. Note that this can be
-            None. This case happens when `bf16 = True` and we don't
-            use any loss scale. Note that for `bf16 = True`, we can have
-            a constnat gradient scaler. Also for `bf16 = False`, we
-            always require a grad scaler.
-        models: list of models (i.e., the virtual pipelining models). This
-            is used by the distributed optimizer for mapping parameters.
+    Args:
+        optimizer (torch.optim.Optimizer): base optimizer such as Adam or SGD.
+        config (OptimizerConfig): configuration object for optimizer.
+        grad_scaler (MegatronGradScaler): used for scaling gradients. Note that
+            this can be None. This case happens when `bf16 = True` and we don't
+            use any loss scale. Note that for `bf16 = True`, we can have
+            a constant gradient scaler. Also for `bf16 = False`, we
+            always require a grad scaler.
+        init_state_fn (Callable, optional): function to initialize state in the optimizer.
     """

-    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
-                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
-                 fp16, bf16, params_dtype, grad_scaler, models):
-        super().__init__(optimizer, clip_grad, log_num_zeros_in_grad,
-                         params_have_main_grad, use_contiguous_buffers_in_local_ddp,
-                         fp16, bf16, params_dtype, grad_scaler, models)
+    def __init__(
+        self,
+        optimizer: torch.optim.Optimizer,
+        config: OptimizerConfig,
+        grad_scaler: MegatronGradScaler,
+        init_state_fn: Callable,
+    ):
+        super().__init__(optimizer, config, grad_scaler, init_state_fn)

-        # ======================
-        # main parameter stuff
-        # ======================
+        # Handle main parameters.

         # Three groups of parameters:
         #   float16_groups: original float16 parameters
...
@@ -521,14 +499,12 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
                 if param.requires_grad:
                     # float16 params:
                     if param.type() in ['torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor']:
                         float16_params_this_group.append(param)
                         # Create a copy
                         main_param = param.detach().clone().float()
                         # Copy tensor model parallel attributes.
                         tensor_parallel.copy_tensor_model_parallel_attributes(main_param, param)
                         if hasattr(param, 'shared'):
                             main_param.shared = param.shared
                         # Replace the optimizer params with the new fp32 copy.
...
@@ -537,26 +513,25 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
                         fp32_from_float16_params_this_group.append(main_param)
                         # Reset existing state dict key to the new main param.
                         if param in self.optimizer.state:
-                            self.optimizer.state[main_param] \
-                                = self.optimizer.state.pop(param)
+                            self.optimizer.state[main_param] = self.optimizer.state.pop(param)
                     # fp32 params.
                     elif param.type() == 'torch.cuda.FloatTensor':
                         fp32_params_this_group.append(param)
                         param_group['params'][i] = param
                     else:
                         raise TypeError(
                             'Wrapped parameters must be one of '
                             'torch.cuda.FloatTensor, '
                             'torch.cuda.HalfTensor, or '
                             'torch.cuda.BFloat16Tensor. '
                             'Received {}'.format(param.type())
                         )

             self.float16_groups.append(float16_params_this_group)
             self.fp32_from_float16_groups.append(fp32_from_float16_params_this_group)
             self.fp32_from_fp32_groups.append(fp32_params_this_group)
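# Illustrative sketch (not part of this file): the core move in the constructor
# above is cloning each fp16/bf16 parameter into an fp32 "main" copy and
# re-keying the inner optimizer onto that copy. With a plain torch optimizer
# (CUDA assumed, names made up):
import torch

param = torch.nn.Parameter(torch.randn(4, 4, device='cuda', dtype=torch.float16))
opt = torch.optim.Adam([param], lr=1e-3)

main_param = param.detach().clone().float()      # fp32 main copy
opt.param_groups[0]['params'][0] = main_param    # optimizer now updates the copy
if param in opt.state:                           # carry over any existing state
    opt.state[main_param] = opt.state.pop(param)

# One training step: fp16 grad -> fp32 main grad -> fp32 update -> copy back.
param.grad = torch.randn_like(param)
main_param.grad = param.grad.float()
opt.step()
param.data.copy_(main_param.data)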
     def zero_grad(self, set_to_none=True):
         """We only need to zero the model related parameters, i.e.,
         float16_groups & fp32_from_fp32_groups. We additionally zero
...
@@ -570,7 +545,6 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
         for group in self.fp32_from_fp32_groups:
             _zero_grad_group_helper(group, set_to_none)

     def _collect_main_grad_data_for_unscaling(self):
         main_grads = []
...
@@ -586,27 +560,23 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
             for main_param in main_group:
                 if main_param.grad is not None:
                     main_grads.append(main_param.grad.data)
         return main_grads

     def _get_model_and_main_params_data_float16(self):
         model_data = []
         main_data = []
         for model_group, main_group in zip(self.float16_groups, self.fp32_from_float16_groups):
             for model_param, main_param in zip(model_group, main_group):
                 model_data.append(model_param.data)
                 main_data.append(main_param.data)
         return model_data, main_data

     def _copy_model_grads_to_main_grads(self):
         # This only needs to be done for the float16 group.
         for model_group, main_group in zip(self.float16_groups, self.fp32_from_float16_groups):
             for model_param, main_param in zip(model_group, main_group):
-                if self.params_have_main_grad and hasattr(model_param, 'main_grad'):
+                if hasattr(model_param, 'main_grad'):
                     main_param.grad = model_param.main_grad.float()
                 else:
                     if model_param.grad is not None:
...
@@ -616,36 +586,25 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
                 # (If using contiguous buffers, main_grad's memory should
                 # persist and therefore should not be deallocated.)
                 model_param.grad = None
-                if self.params_have_main_grad and \
-                        not self.use_contiguous_buffers_in_local_ddp:
-                    model_param.main_grad = None

         # For fp32 grads, we need to reset the grads to main grad.
-        if self.params_have_main_grad:
-            for model_group in self.fp32_from_fp32_groups:
-                for model_param in model_group:
-                    model_param.grad = model_param.main_grad
-
-                    # Safe to de-reference model's main_grad after copying.
-                    # (If using contiguous buffers, main_grad's memory should
-                    # persist and therefore should not be deallocated.)
-                    if not self.use_contiguous_buffers_in_local_ddp:
-                        model_param.main_grad = None
+        for model_group in self.fp32_from_fp32_groups:
+            for model_param in model_group:
+                model_param.grad = model_param.main_grad

     def _copy_main_params_to_model_params(self):
         # Only needed for the float16 params.
         model_data, main_data = self._get_model_and_main_params_data_float16()
         _multi_tensor_copy_this_to_that(this=main_data, that=model_data, overflow_buf=self._dummy_overflow_buf)

     def _copy_model_params_to_main_params(self):
         # Only needed for the float16 params.
         model_data, main_data = self._get_model_and_main_params_data_float16()
         _multi_tensor_copy_this_to_that(this=model_data, that=main_data, overflow_buf=self._dummy_overflow_buf)

     def state_dict(self):
         state_dict = {}
...
@@ -655,119 +614,454 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
         state_dict['fp32_from_fp16_params'] = self.fp32_from_float16_groups
         return state_dict

+    def sharded_state_dict(self, model_sharded_state_dict: ShardedStateDict, is_loading: bool = False):
+        if is_loading:
+            self.init_state_fn(self.optimizer)
+
+        state_dict = self.state_dict()
+
+        id_to_sharded_param_map = get_param_id_to_sharded_param_map(
+            model_sharded_state_dict, chain.from_iterable(g for g in self.float16_groups)
+        )
+
+        # Convert fp32_from_fp16_params
+        assert len(state_dict['fp32_from_fp16_params']) == len(
+            state_dict['optimizer']['param_groups']
+        )
+        state_dict['fp32_from_fp16_params'] = [
+            [
+                make_sharded_optimizer_tensor(
+                    id_to_sharded_param_map[param_id],
+                    fp32_param,
+                    prefix=f'optimizer.state.fp32_param',
+                )
+                for param_id, fp32_param in zip(state_group['params'], fp32_group)
+            ]
+            for fp32_group, state_group in zip(
+                state_dict['fp32_from_fp16_params'], state_dict['optimizer']['param_groups']
+            )
+        ]
+
+        step = self._extract_common_per_param_step(state_dict['optimizer'])
+
+        # Convert regular optimizer state
+        # all optimizer parameters passed to optim_state_to_sharding_state are
+        # expected to have the same shape as the model parameters,
+        # so we save the step separately and ignore it here
+        optim_state_to_sharding_state(state_dict['optimizer'], id_to_sharded_param_map, exclude_keys="step")
+        # save step as a shared step among all parameters. Separate per-parameter
+        # steps are not supported
+        state_dict['optimizer']['state']['common_step'] = step
+        return state_dict

     def load_state_dict(self, state_dict):
+        pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
         # Optimizer.
         optimizer_key = 'optimizer'
         if optimizer_key not in state_dict:
             optimizer_key = 'optimizer_state_dict'
-            print_rank_0('***WARNING*** loading optimizer from '
-                         'an old checkpoint ...')
+            logger.info('***WARNING*** loading optimizer from an old checkpoint ...')
+        if 'common_step' in state_dict[optimizer_key]['state']:
+            common_step = state_dict[optimizer_key]['state'].pop('common_step')
+            self._restore_common_per_param_step(state_dict[optimizer_key], common_step)
         self.optimizer.load_state_dict(state_dict[optimizer_key])

         # Grad scaler.
         if 'grad_scaler' not in state_dict:
-            if self.fp16:
-                print_rank_0('***WARNING*** found an old checkpoint, will not '
-                             'load grad scaler ...')
+            if self.config.fp16:
+                logger.info('***WARNING*** found an old checkpoint, will not load grad scaler ...')
         else:
             if self.grad_scaler:
                 self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
             else:
-                print_rank_0('***WARNING*** fould the grad scaler in the '
-                             'checkpoint but it is None in the class. '
-                             'Skipping loading grad scaler ...')
+                logger.info(
+                    '***WARNING*** fould the grad scaler in the checkpoint but it is None in '
+                    'the class. Skipping loading grad scaler ...'
+                )

         # Copy data for the main params.
         fp32_from_float16_params_key = 'fp32_from_fp16_params'
         if fp32_from_float16_params_key not in state_dict:
             fp32_from_float16_params_key = 'fp32_from_fp16'
         for current_group, saved_group in zip(self.fp32_from_float16_groups, state_dict[fp32_from_float16_params_key]):
             for current_param, saved_param in zip(current_group, saved_group):
                 current_param.data.copy_(saved_param.data)
 class FP32Optimizer(MegatronOptimizer):
+    """Float32 optimizer.
+
+    Args:
+        optimizer (torch.optim.Optimizer): base optimizer such as Adam or SGD.
+        config (OptimizerConfig): configuration object for optimizer.
+        init_state_fn (Callable, optional): function to initialize state in the optimizer.
+    """

-    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
-                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
-                 models):
-        super(FP32Optimizer, self).__init__(optimizer, clip_grad, log_num_zeros_in_grad,
-                                            params_have_main_grad,
-                                            use_contiguous_buffers_in_local_ddp, models)
-        self._scale = torch.cuda.FloatTensor([1.0])
+    def __init__(self, optimizer: torch.optim.Optimizer, config: OptimizerConfig, init_state_fn: Callable):
+        if has_config_logger_enabled(config):
+            log_config_to_disk(config, locals(), prefix=type(self).__name__)
+
+        super(FP32Optimizer, self).__init__(optimizer, config, init_state_fn)
+        self._scale = torch.tensor([1.0], dtype=torch.float, device='cuda')

     def zero_grad(self, set_to_none=True):
         """Copied from torch.optim.optimizer"""
         for group in self.optimizer.param_groups:
             _zero_grad_group_helper(group['params'], set_to_none)

     def get_loss_scale(self):
         """FP32 optimizer does not do any scaling."""
         return self._scale

+    @torch.no_grad()
+    def prepare_grads(self) -> bool:
+        """Pre-processing gradients before the optimizer step, returns whether inf/nan is found."""
+        timers = self.config.timers
+
+        # Copy main_grads to grads.
+        if timers is not None:
+            timers('optimizer-copy-to-main-grad', log_level=1).start(barrier=self.config.barrier_with_L1_time)
+        for param_group in self.optimizer.param_groups:
+            for param in param_group['params']:
+                param.grad = param.main_grad
+        if timers is not None:
+            timers('optimizer-copy-to-main-grad').stop()
+
+        return False
+
+    @torch.no_grad()
+    def step_with_ready_grads(self) -> bool:
+        """Step the optimizer with ready gradients, return successful."""
+        timers = self.config.timers
+
+        # Update parameters.
+        if timers is not None:
+            timers('optimizer-inner-step', log_level=1).start(barrier=self.config.barrier_with_L1_time)
+        self.optimizer.step()
+        if timers is not None:
+            timers('optimizer-inner-step').stop()
+
+        return True

     @torch.no_grad()
-    def step(self, args, timers):
+    def step(self):
         """Clip gradients (if needed) and step the base optimizer.
         Always return successful since there is no overflow."""
+        timers = self.config.timers

-        # Copy main_grads to grads.
-        timers('optimizer-copy-to-main-grad', log_level=1).start(barrier=args.barrier_with_L1_time)
-        if self.params_have_main_grad:
-            for param_group in self.optimizer.param_groups:
-                for param in param_group['params']:
-                    param.grad = param.main_grad
-
-                    # Safe to de-reference model's main_grad after copying.
-                    # (If using contiguous buffers, main_grad's memory should
-                    # persist and therefore should not be deallocated.)
-                    if not self.use_contiguous_buffers_in_local_ddp:
-                        param.main_grad = None
-        timers('optimizer-copy-to-main-grad').stop()
+        found_inf_flag = self.prepare_grads()
+        if found_inf_flag:
+            return False, None, None

         # Clip gradients.
-        timers('optimizer-clip-main-grad', log_level=1).start(barrier=args.barrier_with_L1_time)
+        if timers is not None:
+            timers('optimizer-clip-main-grad', log_level=1).start(barrier=self.config.barrier_with_L1_time)
         grad_norm = None
-        if self.clip_grad > 0.0:
-            grad_norm = self.clip_grad_norm(self.clip_grad)
-        timers('optimizer-clip-main-grad').stop()
+        if self.config.clip_grad > 0.0:
+            grad_norm = self.clip_grad_norm(self.config.clip_grad)
+        if timers is not None:
+            timers('optimizer-clip-main-grad').stop()

-        # count the zeros in the grads
-        timers('optimizer-count-zeros', log_level=1).start(barrier=args.barrier_with_L1_time)
-        num_zeros_in_grad = self.count_zeros() if \
-            self.log_num_zeros_in_grad else None
-        timers('optimizer-count-zeros').stop()
+        # Count the zeros in the grads.
+        if timers is not None:
+            timers('optimizer-count-zeros', log_level=1).start(barrier=self.config.barrier_with_L1_time)
+        num_zeros_in_grad = self.count_zeros() if self.config.log_num_zeros_in_grad else None
+        if timers is not None:
+            timers('optimizer-count-zeros').stop()

-        # Update parameters.
-        timers('optimizer-inner-step', log_level=1).start(barrier=args.barrier_with_L1_time)
-        self.optimizer.step()
-        timers('optimizer-inner-step').stop()
+        success = self.step_with_ready_grads()

-        # No overflow for FP32 optimizer.
-        return True, grad_norm, num_zeros_in_grad
+        return success, grad_norm, num_zeros_in_grad

     def reload_model_params(self):
         pass

     def state_dict(self):
         return self.optimizer.state_dict()

     def load_state_dict(self, state_dict):
+        pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
+        if 'common_step' in state_dict['state']:
+            common_step = state_dict['state'].pop('common_step')
+            self._restore_common_per_param_step(state_dict, common_step)
         self.optimizer.load_state_dict(state_dict)

+    def sharded_state_dict(self, model_sharded_state_dict: ShardedStateDict, is_loading: bool = False):
+        if is_loading:
+            self.init_state_fn(self.optimizer)
+
+        state_dict = self.state_dict()
+        id_to_sharded_param_map = get_param_id_to_sharded_param_map(
+            model_sharded_state_dict, self.get_parameters()
+        )
+        step = self._extract_common_per_param_step(state_dict)
+
+        # all optimizer parameters passed to optim_state_to_sharding_state are
+        # expected to have the same shape as the model parameters,
+        # so we save the step separately and ignore it here
+        optim_state_to_sharding_state(state_dict, id_to_sharded_param_map, exclude_keys="step")
+        # save step as a shared step among all parameters. Separate per-parameter
+        # steps are not supported
+        state_dict['state']['common_step'] = step
+        return state_dict
+class ProxyDict:
+    """
+    A dictionary-like object that proxies to a list of dictionaries.
+
+    e.g., ProxyDict([{'a': 1}, {'b': 2}]) behaves like:
+        {
+            (0, 'a'): 1,
+            (1, 'b'): 2,
+        }
+    We use tuples as keys to avoid ambiguity with the keys of the inner dicts.
+    """
+
+    def __init__(self, inner_dicts: List[dict]):
+        self._inner_dicts = inner_dicts
+
+    def __getitem__(self, key: Tuple[int, str]):
+        idx, inner_key = key
+        return self._inner_dicts[idx].get(inner_key)
+
+    def __setitem__(self, key: Tuple[int, str], value: Any):
+        idx, inner_key = key
+        self._inner_dicts[idx][inner_key] = value
+
+    def __len__(self) -> int:
+        return sum([len(inner_dict) for inner_dict in self._inner_dicts])
+
+    def __iter__(self):
+        for idx, inner_dict in enumerate(self._inner_dicts):
+            for inner_key in inner_dict:
+                yield (idx, inner_key)
+
+    def items(self):
+        """Return generator over underlying items."""
+        for idx, inner_dict in enumerate(self._inner_dicts):
+            for inner_key, value in inner_dict.items():
+                yield (idx, inner_key), value
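# Illustrative usage of ProxyDict above (not part of the diff):
inner = [{'a': 1}, {'b': 2}]
proxy = ProxyDict(inner)
assert proxy[(0, 'a')] == 1
proxy[(1, 'c')] = 3                 # writes through to inner[1]
assert len(proxy) == 3
assert list(proxy)[0] == (0, 'a')   # iteration yields (optimizer_idx, key) tuples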
+class ChainedOptimizer(MegatronOptimizer):
+    """ChainedOptimizer is designed for a collection of optimizers.
+
+    These optimizers are responsible for different parts of multiple models for
+    a training task and will be executed one-by-one when the model is updated.
+
+    Args:
+        chained_optimizers: a list of optimizers.
+    """
+
+    def __init__(self, chained_optimizers: List[MegatronOptimizer]):
+        self.model_chunks = []
+        self.config = getattr(chained_optimizers[0], 'config', None)
+        for optimizer in chained_optimizers:
+            if hasattr(optimizer, 'model_chunks'):
+                for model_chunk in optimizer.model_chunks:
+                    if model_chunk not in self.model_chunks:
+                        self.model_chunks.append(model_chunk)
+            assert self.config == getattr(optimizer, 'config', None)
+        self.chained_optimizers = chained_optimizers
+
+    @property
+    def param_groups(self) -> List[dict]:
+        """Get param_groups aggregated over underlying optimizers."""
+        param_groups = []
+        for optimizer in self.chained_optimizers:
+            param_groups += optimizer.param_groups
+        return param_groups
+
+    @property
+    def state(self) -> ProxyDict:
+        """
+        Return optimizer state with tuple keys, where the first element is the
+        index of the optimizer in the list of chained optimizers.
+        """
+        return ProxyDict([opt.state for opt in self.chained_optimizers])
+
+    def zero_grad(self, set_to_none=True):
+        for optimizer in self.chained_optimizers:
+            optimizer.zero_grad(set_to_none)
+
+    def get_loss_scale(self):
+        return self.chained_optimizers[0].get_loss_scale()
+
+    def reload_model_params(self):
+        for optimizer in self.chained_optimizers:
+            optimizer.reload_model_params()
+
+    def state_dict(self):
+        return [optimizer.state_dict() for optimizer in self.chained_optimizers]
+
+    def sharded_state_dict(self, model_sharded_state_dict: ShardedStateDict, is_loading: bool = False, **kwargs):
+        sharded_state_dict = {}
+        for optimizer_idx, optimizer in enumerate(self.chained_optimizers):
+            optim_state_dict = optimizer.sharded_state_dict(model_sharded_state_dict, is_loading, **kwargs)
+            add_prefix_for_sharding(optim_state_dict, f'chained_{optimizer_idx}.')
+            sharded_state_dict[optimizer_idx] = optim_state_dict
+        return sharded_state_dict
+
+    def load_state_dict(self, state_dict):
+        if len(self.chained_optimizers) != len(state_dict):
+            raise RuntimeError(
+                f'Expected {len(self.chained_optimizers)} entries'
+                f' in state dict, but got {len(state_dict)}.'
+            )
+        if isinstance(state_dict, dict):
+            state_dict = (v for k, v in sorted(state_dict.items()))
+        for optimizer, state in zip(self.chained_optimizers, state_dict):
+            optimizer.load_state_dict(state)
+
+    @torch.no_grad()
+    def prepare_grads(self) -> bool:
+        """Pre-processing gradients before the optimizer step, returns whether inf/nan is found."""
+        found_inf_flag = False
+        for optimizer in self.chained_optimizers:
+            found_inf_flag |= optimizer.prepare_grads()
+        return found_inf_flag
+
+    @torch.no_grad()
+    def step_with_ready_grads(self) -> bool:
+        """Step the optimizer with ready gradients, return successful."""
+        success = True
+        for optimizer_idx, optimizer in enumerate(self.chained_optimizers):
+            success &= optimizer.step_with_ready_grads()
+            if self.config.overlap_param_gather_with_optimizer_step and optimizer_idx == 0:
+                assert success
+                assert len(optimizer.model_chunks) == 1
+                optimizer.model_chunks[0].start_param_sync(force_dispatch=True)
+        return success
+
+    def disable_pre_hook(self):
+        """Disable pre-hooks for underlying distributed optimizers."""
+        warnings.warn(
+            "`ChainedOptimizer.disable_pre_hook` will be deprecated in a future release. "
+            "Use `DistributedDataParallel.disable_forward_pre_hook` directly."
+        )
+        for model_chunk in self.model_chunks:
+            model_chunk.disable_forward_pre_hook()
+
+    def enable_pre_hook(self):
+        """Enable pre-hooks for underlying distributed optimizers."""
+        warnings.warn(
+            "`ChainedOptimizer.enable_pre_hook` will be deprecated in a future release. "
+            "Use `DistributedDataParallel.enable_forward_pre_hook` directly."
+        )
+        for model_chunk in self.model_chunks:
+            model_chunk.enable_forward_pre_hook()
+
+    @torch.no_grad()
+    def step(self):
+        """ChainedOptimizer will step all optimizers one by one."""
+        found_inf_flag = self.prepare_grads()
+        if found_inf_flag:
+            return False, None, None
+
+        # Get grad norm.
+        grad_norms = []
+        for optimizer in self.chained_optimizers:
+            _grad_norm = optimizer.get_grad_norm()
+            grad_norms += [_grad_norm if _grad_norm else 0.0]
+        grad_norm = math.sqrt(sum([x**2 for x in grad_norms]))
+
+        # Clip gradients.
+        for optimizer in self.chained_optimizers:
+            if optimizer.config.clip_grad > 0.0:
+                clip_grad_by_total_norm_fp32(
+                    optimizer.get_parameters(),
+                    max_norm=optimizer.config.clip_grad,
+                    total_norm=grad_norm,
+                )
+
+        # Count the zeros in the grads.
+        num_zeros_in_grad = 0
+        for optimizer in self.chained_optimizers:
+            num_zeros_in_grad += optimizer.count_zeros() if optimizer.config.log_num_zeros_in_grad else 0
+
+        update_successful = self.step_with_ready_grads()
+
+        return update_successful, grad_norm, num_zeros_in_grad
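# Illustrative sketch (not part of this file): the chained step above combines
# per-optimizer L2 norms into one global norm, sqrt(n1^2 + n2^2 + ...), and each
# optimizer then clips against that shared total. Toy numbers:
import math

grad_norms = [3.0, 4.0]                              # norms from two chained optimizers
global_norm = math.sqrt(sum(x ** 2 for x in grad_norms))
assert global_norm == 5.0
# Effective clip factor is min(1, clip_grad / global_norm), identical across the chain.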
+    def save_parameter_state(self, filename: str):
+        """Save the distributed parameter states of all optimizers to a file.
+
+        Args:
+            filename (str): path to save parameter state to.
+        """
+        save_states = False
+        states = []
+        for optimizer in self.chained_optimizers:
+            if hasattr(optimizer, 'get_parameter_state_dp_zero'):
+                state_dict = optimizer.get_parameter_state_dp_zero()
+
+                # Save checkpoint economically, only when DP rank = 0, state dict
+                # needs to be saved.
+                if torch.distributed.get_rank(optimizer.data_parallel_group) == 0:
+                    states.append(state_dict)
+                    save_states = True
+                else:
+                    states.append(None)
+            else:
+                states.append(None)
+
+        if save_states:
+            torch.save(states, filename)
+
+    def load_parameter_state(self, filename: str, *, update_legacy_format: bool = False):
+        """Load the distributed parameter states of all optimizers from a file.
+
+        Args:
+            filename (str): path to load parameter state from.
+        """
+        states = None
+        for idx, optimizer in enumerate(self.chained_optimizers):
+            if not hasattr(optimizer, 'load_parameter_state_from_dp_zero'):
+                continue
+
+            # Lazy loading checkpoint, state dict is needed only when DP rank = 0.
+            if torch.distributed.get_rank(optimizer.data_parallel_group) == 0 and states is None:
+                states = torch.load(filename)
+
+            state_dict = states[idx] if states else None
+            optimizer.load_parameter_state_from_dp_zero(state_dict, update_legacy_format=update_legacy_format)
+
+    def start_param_sync(self, model_index: int, *unused):
+        """Start parameter synchronization for all optimizers."""
+        for optimizer in self.chained_optimizers:
+            optimizer.start_param_sync(model_index, *unused)
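The `save_parameter_state` / `load_parameter_state` pair above checkpoints distributed optimizer shards through data-parallel rank 0 only. A hedged usage sketch follows: `dense_opt` and `adapter_opt` are assumed to be MegatronOptimizer instances built elsewhere (for example via the `get_megatron_optimizer` entry point in this directory's `__init__.py`) over disjoint parameter sets, and the file path is just an example.

from megatron.core.optimizer.optimizer import ChainedOptimizer  # path per this diff

chained_opt = ChainedOptimizer([dense_opt, adapter_opt])

# Every rank calls these; only data-parallel rank 0 actually writes/reads the file,
# and optimizers without distributed state contribute a None placeholder.
chained_opt.save_parameter_state('checkpoints/optim_param_state.pt')
chained_opt.load_parameter_state('checkpoints/optim_param_state.pt')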
megatron/core/optimizer/optimizer_config.py 0 → 100644
View file @ 4b097dee
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from dataclasses import dataclass
from typing import Callable, Optional

import torch


@dataclass
class OptimizerConfig:
    """Configuration for optimizer."""

    ##############
    # General
    ##############
    optimizer: str = 'adam'
    """Optimizer to use (one of Adam or SGD)."""

    lr: Optional[float] = None
    """Initial learning rate. Depending on decay style and initial warmup, the learning rate at each
       iteration would be different."""

    min_lr: Optional[float] = None
    """Minimum value for learning rate. The scheduler clips values below this threshold."""

    decoupled_lr: Optional[float] = None
    """Separate learning rate for the input and output layer."""

    decoupled_min_lr: Optional[float] = None
    """Minimum value for learning rate for the input and output layer. The scheduler clips values
       below this threshold."""

    weight_decay: float = 0.01
    """Weight decay coefficient for L2 regularization."""

    ##############
    # Precision
    ##############
    fp16: bool = False
    """If true, train with fp16 mixed precision training. Defaults to False."""

    bf16: bool = False
    """If true, train with bf16 mixed precision training. Defaults to False."""

    params_dtype: torch.dtype = torch.float32
    """dtype used when initializing the weights. Defaults to torch.float32."""

    ###############
    # Loss scaling
    ###############
    loss_scale: Optional[float] = None
    """Static loss scaling, positive power of 2 values can improve fp16 convergence. If None,
       dynamic loss scaling is used."""

    initial_loss_scale: float = 2**32
    """Initial loss-scale for dynamic loss scaling."""

    min_loss_scale: float = 1.0
    """Minimum loss scale for dynamic loss scaling."""

    loss_scale_window: float = 1000
    """Window over which to raise/lower dynamic scale."""

    hysteresis: int = 2
    """Hysteresis for dynamic loss scaling."""

    ##############
    # Optimizer
    ##############
    # Adam
    adam_beta1: float = 0.9
    """First coefficient for computing running averages of gradient and its square in Adam
       optimizer."""

    adam_beta2: float = 0.999
    """Second coefficient for computing running averages of gradient and its square in Adam
       optimizer."""

    adam_eps: float = 1e-08
    """Term added to the denominator to improve numerical stability in Adam optimizer."""

    # SGD.
    sgd_momentum: float = 0.9
    """Momentum factor for SGD optimizer."""

    #######################
    # Distributed optimizer
    #######################
    use_distributed_optimizer: bool = False
    """Distribute optimizer state over data-parallel replicas."""

    overlap_grad_reduce: bool = False
    """If true, overlap grad reduce-scatter with backward compute in distributed optimizer.
       NOTE: This parameter will be deprecated in a future release. Use `overlap_grad_reduce`
       in `megatron/core/distributed/distributed_data_parallel_config.py` instead."""

    overlap_param_gather: bool = False
    """If true, overlap param all-gather with forward compute in distributed optimizer.
       NOTE: This parameter will be deprecated in a future release. Use `overlap_param_gather`
       in `megatron/core/distributed/distributed_data_parallel_config.py` instead."""

    overlap_param_gather_with_optimizer_step: bool = False
    """If true, overlap param all-gather of first bucket with optimizer step."""

    ################
    # Miscellaneous
    ################
    clip_grad: float = 1.0
    """Gradient clipping based on global L2 norm."""

    log_num_zeros_in_grad: bool = False
    """If true, calculate and log the number of zeros in gradient."""

    barrier_with_L1_time: bool = False
    """If true, use barrier with level 1 time measurements."""

    timers: Callable = None
    """Function to get timers."""

    config_logger_dir: str = ""
    """When non-empty, dumps entry-point configs to config_logger_dir"""
megatron/optimizer_param_scheduler.py → megatron/core/optimizer_param_scheduler.py
View file @ 4b097dee

# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Learning rate decay and weight decay incr functions."""

import logging
import math

-from megatron import print_rank_0
+from typing import Optional
+
+from megatron.core.optimizer import MegatronOptimizer
+from megatron.core.utils import log_single_rank
+
+logger = logging.getLogger(__name__)

-class OptimizerParamScheduler(object):
-    """Anneals learning rate and weight decay"""
-
-    def __init__(self, optimizer, init_lr, max_lr, min_lr, lr_warmup_steps,
-                 lr_decay_steps, lr_decay_style, start_wd, end_wd,
-                 wd_incr_steps, wd_incr_style,
-                 use_checkpoint_opt_param_scheduler=True,
-                 override_opt_param_scheduler=False):
+class OptimizerParamScheduler:
+    """Anneals learning rate and weight decay
+
+    Args:
+        optimizer (MegatronOptimizer): the optimizer to be used
+        init_lr (float): initial learning rate
+        max_lr (float): maximum learning rate
+        min_lr (float): minimum learning rate
+        lr_warmup_steps (int): number of warmup steps
+        lr_decay_steps (int): number of decay steps
+        lr_decay_style (str): decay style for learning rate
+        start_wd (float): initial weight decay
+        end_wd (float): final weight decay
+        wd_incr_steps (int): number of weight decay increment steps
+        wd_incr_style (str): weight decay increment style
+        use_checkpoint_opt_param_scheduler (bool, optional): whether to use the checkpoint values
+            for the optimizer param scheduler
+        override_opt_param_scheduler (bool, optional): whether to override the optimizer param
+            scheduler values with the class values
+        wsd_decay_steps (int, optional): number of weight decay decay steps
+        lr_wsd_decay_style (str, optional): decay style for learning rate during weight decay
+            decay steps
+    """
+
+    def __init__(
+        self,
+        optimizer: MegatronOptimizer,
+        init_lr: float,
+        max_lr: float,
+        min_lr: float,
+        lr_warmup_steps: int,
+        lr_decay_steps: int,
+        lr_decay_style: str,
+        start_wd: float,
+        end_wd: float,
+        wd_incr_steps: int,
+        wd_incr_style: str,
+        use_checkpoint_opt_param_scheduler: Optional[bool] = True,
+        override_opt_param_scheduler: Optional[bool] = False,
+        wsd_decay_steps: Optional[int] = None,
+        lr_wsd_decay_style: Optional[str] = None,
+    ) -> None:

        # Class values.
        self.optimizer = optimizer
...
...
@@ -28,10 +68,14 @@ class OptimizerParamScheduler(object):
        self.lr_warmup_steps = lr_warmup_steps
        self.num_steps = 0
        self.lr_decay_steps = lr_decay_steps
+        self.wsd_decay_steps = wsd_decay_steps
+        self.lr_wsd_decay_style = lr_wsd_decay_style
        assert self.lr_decay_steps > 0
        assert self.lr_warmup_steps < self.lr_decay_steps

        self.lr_decay_style = lr_decay_style
+        if self.lr_decay_style == "WSD":
+            assert self.wsd_decay_steps is not None

        self.start_wd = start_wd
        self.end_wd = end_wd
...
...
@@ -43,16 +87,16 @@ class OptimizerParamScheduler(object):
        self.override_opt_param_scheduler = override_opt_param_scheduler
        self.use_checkpoint_opt_param_scheduler = use_checkpoint_opt_param_scheduler
        if self.override_opt_param_scheduler:
-            assert not self.use_checkpoint_opt_param_scheduler, 'both override and ' \
-                'use-checkpoint are set.'
+            assert not self.use_checkpoint_opt_param_scheduler, (
+                'both override and ' 'use-checkpoint are set.'
+            )

        # Set the learning rate
        self.step(0)
-        print_rank_0('> learning rate decay style: {}'.format(self.lr_decay_style))
+        log_single_rank(logger, logging.INFO, f"> learning rate decay style: {self.lr_decay_style}")

-    def get_wd(self):
-        """
-        Weight decay incr functions"""
+    def get_wd(self) -> float:
+        """Weight decay incr functions"""
        if self.num_steps > self.wd_incr_steps:
            return self.end_wd
...
...
@@ -70,71 +114,86 @@ class OptimizerParamScheduler(object):
        elif self.wd_incr_style == 'cosine':
            coeff = 0.5 * (math.cos(math.pi * (1 - incr_ratio)) + 1.0)
        else:
-            raise Exception('{} weight decay increment style is not supported.'.format(
-                self.wd_incr_style))
+            raise Exception(f'{self.wd_incr_style} weight decay increment style is not supported.')

        return self.start_wd + coeff * delta_wd

-    def get_lr(self):
+    def get_lr(self, param_group: dict) -> float:
        """Learning rate decay functions from:
-        https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""
+        https://openreview.net/pdf?id=BJYwwY9ll pg. 4
+
+        Args:
+            param_group (dict): parameter group from the optimizer.
+        """
+
+        max_lr = param_group.get('max_lr', self.max_lr)
+        min_lr = param_group.get('min_lr', self.min_lr)

        # Use linear warmup for the initial part.
        if self.lr_warmup_steps > 0 and self.num_steps <= self.lr_warmup_steps:
-            return (
-                self.init_lr
-                + ((self.max_lr - self.init_lr) * float(self.num_steps) / float(self.lr_warmup_steps))
-            )
+            return self.init_lr + (
+                (max_lr - self.init_lr) * float(self.num_steps) / float(self.lr_warmup_steps)
+            )

        # If the learning rate is constant, just return the initial value.
        if self.lr_decay_style == 'constant':
-            return self.max_lr
+            return max_lr

-        # For any steps larger than `self.lr_decay_steps`, use `self.min_lr`.
+        # For any steps larger than `self.lr_decay_steps`, use `min_lr`.
        if self.num_steps > self.lr_decay_steps:
-            return self.min_lr
+            return min_lr

        # If we are done with the warmup period, use the decay style.
        if self.lr_decay_style == 'inverse-square-root':
            warmup_steps = max(self.lr_warmup_steps, 1)
            num_steps = max(self.num_steps, 1)
-            lr = self.max_lr * warmup_steps**0.5 / (num_steps**0.5)
-            return max(self.min_lr, lr)
+            lr = max_lr * warmup_steps**0.5 / (num_steps**0.5)
+            return max(min_lr, lr)

        num_steps_ = self.num_steps - self.lr_warmup_steps
        decay_steps_ = self.lr_decay_steps - self.lr_warmup_steps
        decay_ratio = float(num_steps_) / float(decay_steps_)
        assert decay_ratio >= 0.0
        assert decay_ratio <= 1.0
-        delta_lr = self.max_lr - self.min_lr
+        delta_lr = max_lr - min_lr

        if self.lr_decay_style == 'linear':
-            coeff = (1.0 - decay_ratio)
+            coeff = 1.0 - decay_ratio
        elif self.lr_decay_style == 'cosine':
            coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
+        elif self.lr_decay_style == 'WSD':
+            wsd_anneal_start_ = self.lr_decay_steps - self.wsd_decay_steps
+            if self.num_steps <= wsd_anneal_start_:
+                coeff = 1.0
+            else:
+                wsd_steps = self.num_steps - wsd_anneal_start_
+                wsd_decay_ratio = float(wsd_steps) / float(self.wsd_decay_steps)
+                if self.lr_wsd_decay_style == "linear":
+                    coeff = 1.0 - wsd_decay_ratio
+                elif self.lr_wsd_decay_style == "cosine":
+                    coeff = 0.5 * (math.cos(math.pi * wsd_decay_ratio) + 1.0)
+                elif self.lr_wsd_decay_style == "exponential":
+                    coeff = (2.0 * math.pow(0.5, wsd_decay_ratio)) - 1.0
        else:
-            raise Exception('{} decay style is not supported.'.format(self.lr_decay_style))
+            raise Exception(f'{self.lr_decay_style} decay style is not supported.')

-        return self.min_lr + coeff * delta_lr
+        return min_lr + coeff * delta_lr
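
As an aside, the decay logic above reduces to a small coefficient computation; the following self-contained sketch (illustrative names, independent of this class) reproduces the linear, cosine, and WSD coefficients so a schedule can be plotted or unit-tested outside Megatron.

import math

def decay_coeff(step, warmup, decay_steps, style='cosine', wsd_decay_steps=None, wsd_style='linear'):
    # Multiplier applied to (max_lr - min_lr) once warmup is over.
    ratio = (step - warmup) / float(decay_steps - warmup)
    if style == 'linear':
        return 1.0 - ratio
    if style == 'cosine':
        return 0.5 * (math.cos(math.pi * ratio) + 1.0)
    if style == 'WSD':
        # Stable (constant) phase, then anneal over the final wsd_decay_steps.
        anneal_start = decay_steps - wsd_decay_steps
        if step <= anneal_start:
            return 1.0
        r = (step - anneal_start) / float(wsd_decay_steps)
        return 1.0 - r if wsd_style == 'linear' else 0.5 * (math.cos(math.pi * r) + 1.0)
    raise ValueError(f'unknown style {style}')

# lr(step) = min_lr + decay_coeff(step, ...) * (max_lr - min_lr), mirroring get_lr() above.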
-    def step(self, increment):
-        """Set lr for all parameters groups."""
+    def step(self, increment: int) -> None:
+        """Set lr for all parameters groups.
+
+        Args:
+            increment (int): number of steps to increment
+        """
        self.num_steps += increment
-        new_lr = self.get_lr()
        new_wd = self.get_wd()
-        for group in self.optimizer.param_groups:
-            group['lr'] = new_lr * group.get('lr_mult', 1.0)
-            group['weight_decay'] = new_wd * group.get('wd_mult', 1.0)
+        for param_group in self.optimizer.param_groups:
+            new_lr = self.get_lr(param_group)
+            param_group['lr'] = new_lr * param_group.get('lr_mult', 1.0)
+            param_group['weight_decay'] = new_wd * param_group.get('wd_mult', 1.0)

-    def state_dict(self):
+    def state_dict(self) -> dict:
        """Return the state dict."""
        state_dict = {
            'max_lr': self.max_lr,
            'lr_warmup_steps': self.lr_warmup_steps,
...
...
@@ -145,91 +204,94 @@ class OptimizerParamScheduler(object):
            'start_wd': self.start_wd,
            'end_wd': self.end_wd,
            'wd_incr_style': self.wd_incr_style,
-            'wd_incr_steps': self.wd_incr_steps
+            'wd_incr_steps': self.wd_incr_steps,
        }
        return state_dict

-    def _check_and_set(self, cls_value, sd_value, name):
+    def _check_and_set(self, cls_value: float, sd_value: float, name: str) -> float:
        """Auxiliary function for checking the values in the checkpoint and
-        setting them."""
+        setting them.
+
+        Args:
+            cls_value (float): class value
+            sd_value (float): checkpoint value
+            name (str): name of the parameter
+        """
        if self.override_opt_param_scheduler:
-            print_rank_0(' > overriding {} value to {}'.format(name, cls_value))
+            log_single_rank(logger, logging.INFO, f" > overriding {name} value to {cls_value}")
            return cls_value

        if not self.use_checkpoint_opt_param_scheduler:
-            assert cls_value == sd_value, \
-                f'OptimizerParamScheduler: class input value {cls_value} and checkpoint' \
-                f'value {sd_value} for {name} do not match'
-        print_rank_0(' > using checkpoint value {} for {}'.format(sd_value, name))
+            assert cls_value == sd_value, (
+                f'OptimizerParamScheduler: class input value {cls_value} and checkpoint'
+                f'value {sd_value} for {name} do not match'
+            )
+        log_single_rank(logger, logging.INFO, f" > using checkpoint value {sd_value} for {name}")
        return sd_value

-    def load_state_dict(self, sd):
+    def load_state_dict(self, state_dict: dict) -> None:
+        """Load the state dict.
+
+        Args:
+            state_dict (dict): state dict to be load
+        """

-        if 'start_lr' in sd:
-            max_lr_ = sd['start_lr']
+        if 'start_lr' in state_dict:
+            max_lr_ = state_dict['start_lr']
        else:
-            max_lr_ = sd['max_lr']
+            max_lr_ = state_dict['max_lr']
        self.max_lr = self._check_and_set(self.max_lr, max_lr_, 'learning rate')
-        self.min_lr = self._check_and_set(self.min_lr, sd['min_lr'], 'minimum learning rate')
+        self.min_lr = self._check_and_set(self.min_lr, state_dict['min_lr'], 'minimum learning rate')

-        if 'warmup_iter' in sd:
-            lr_warmup_steps_ = sd['warmup_iter']
-        elif 'warmup_steps' in sd:
-            lr_warmup_steps_ = sd['warmup_steps']
+        if 'warmup_iter' in state_dict:
+            lr_warmup_steps_ = state_dict['warmup_iter']
+        elif 'warmup_steps' in state_dict:
+            lr_warmup_steps_ = state_dict['warmup_steps']
        else:
-            lr_warmup_steps_ = sd['lr_warmup_steps']
-        self.lr_warmup_steps = self._check_and_set(self.lr_warmup_steps, lr_warmup_steps_, 'warmup iterations')
+            lr_warmup_steps_ = state_dict['lr_warmup_steps']
+        self.lr_warmup_steps = self._check_and_set(
+            self.lr_warmup_steps, lr_warmup_steps_, 'warmup iterations'
+        )

-        if 'end_iter' in sd:
-            lr_decay_steps_ = sd['end_iter']
-        elif 'decay_steps' in sd:
-            lr_decay_steps_ = sd['decay_steps']
+        if 'end_iter' in state_dict:
+            lr_decay_steps_ = state_dict['end_iter']
+        elif 'decay_steps' in state_dict:
+            lr_decay_steps_ = state_dict['decay_steps']
        else:
-            lr_decay_steps_ = sd['lr_decay_steps']
-        self.lr_decay_steps = self._check_and_set(self.lr_decay_steps, lr_decay_steps_, 'total number of iterations')
+            lr_decay_steps_ = state_dict['lr_decay_steps']
+        self.lr_decay_steps = self._check_and_set(
+            self.lr_decay_steps, lr_decay_steps_, 'total number of iterations'
+        )

-        if 'decay_style' in sd:
-            lr_decay_style_ = sd['decay_style']
+        if 'decay_style' in state_dict:
+            lr_decay_style_ = state_dict['decay_style']
        else:
-            lr_decay_style_ = sd['lr_decay_style']
-        self.lr_decay_style = self._check_and_set(self.lr_decay_style, lr_decay_style_, 'learning rate decay style')
+            lr_decay_style_ = state_dict['lr_decay_style']
+        self.lr_decay_style = self._check_and_set(
+            self.lr_decay_style, lr_decay_style_, 'learning rate decay style'
+        )

-        if 'num_iters' in sd:
-            num_steps = sd['num_iters']
+        if 'num_iters' in state_dict:
+            num_steps = state_dict['num_iters']
        else:
-            num_steps = sd['num_steps']
+            num_steps = state_dict['num_steps']
        self.step(increment=num_steps)

-        if 'start_wd' in sd:
-            self.start_wd = self._check_and_set(self.start_wd, sd['start_wd'], "start weight decay")
-            self.end_wd = self._check_and_set(self.end_wd, sd['end_wd'], "end weight decay")
-            self.wd_incr_steps = self._check_and_set(
-                self.wd_incr_steps, sd['wd_incr_steps'], "total number of weight decay iterations"
-            )
-            self.wd_incr_style = self._check_and_set(
-                self.wd_incr_style, sd['wd_incr_style'], "weight decay incr style"
-            )
+        if 'start_wd' in state_dict:
+            self.start_wd = self._check_and_set(
+                self.start_wd, state_dict['start_wd'], "start weight decay"
+            )
+            self.end_wd = self._check_and_set(self.end_wd, state_dict['end_wd'], "end weight decay")
+            self.wd_incr_steps = self._check_and_set(
+                self.wd_incr_steps,
+                state_dict['wd_incr_steps'],
+                "total number of weight decay iterations",
+            )
+            self.wd_incr_style = self._check_and_set(
+                self.wd_incr_style, state_dict['wd_incr_style'], "weight decay incr style"
+            )
megatron/core/package_info.py
View file @ 4b097dee
...
...
@@ -2,7 +2,7 @@

MAJOR = 0
-MINOR = 3
+MINOR = 9
PATCH = 0
PRE_RELEASE = ''
...
...
megatron/core/packed_seq_params.py 0 → 100644
View file @ 4b097dee

# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from dataclasses import dataclass

from torch import Tensor


@dataclass
class PackedSeqParams:
    # parameters to TEDotProductAttention and fused rope kernels for the
    # `thd` (packed) sequence format
    qkv_format: str = None
    cu_seqlens_q: Tensor = None
    cu_seqlens_kv: Tensor = None
    max_seqlen_q: Tensor = None
    max_seqlen_kv: Tensor = None
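
For context, the `thd` (packed) format concatenates variable-length sequences along the token dimension and describes the boundaries with cumulative sequence lengths. A hypothetical instantiation for two packed sequences of lengths 3 and 5 (8 tokens total) might look like the sketch below; the values are made up for illustration.

import torch
from megatron.core.packed_seq_params import PackedSeqParams

packed = PackedSeqParams(
    qkv_format='thd',
    cu_seqlens_q=torch.tensor([0, 3, 8], dtype=torch.int32),   # prefix sums of sequence lengths
    cu_seqlens_kv=torch.tensor([0, 3, 8], dtype=torch.int32),
    max_seqlen_q=torch.tensor(5),
    max_seqlen_kv=torch.tensor(5),
)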
megatron/core/parallel_state.py
View file @ 4b097dee
...
...
@@ -3,7 +3,11 @@
"""Model and data parallel groups."""

import os
-from typing import Optional
+import warnings
+from datetime import timedelta
+from functools import partial
+from itertools import cycle
+from typing import Callable, List, Optional

import torch

...
...
@@ -15,6 +19,8 @@ _TENSOR_MODEL_PARALLEL_GROUP = None
_PIPELINE_MODEL_PARALLEL_GROUP = None
# Model parallel group (both intra- and pipeline) that the current rank belongs to.
_MODEL_PARALLEL_GROUP = None
+# Model parallel group (both intra-, pipeline, and expert) that the current rank belongs to.
+_MODEL_AND_EXPERT_PARALLEL_GROUP = None
# Embedding group.
_EMBEDDING_GROUP = None
# Position embedding group.
...
...
@@ -22,18 +28,33 @@ _POSITION_EMBEDDING_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None
_DATA_PARALLEL_GROUP_GLOO = None
-# FP8 amax reduction group.
-_AMAX_REDUCTION_GROUP = None
+# tensor model parallel group and data parallel group combined
+# used for fp8 and moe training
+_TENSOR_AND_DATA_PARALLEL_GROUP = None
+# Expert parallel group that the current rank belongs to.
+_EXPERT_MODEL_PARALLEL_GROUP = None
+_TENSOR_AND_EXPERT_PARALLEL_GROUP = None
+_DATA_MODULO_EXPERT_PARALLEL_GROUP = None
+_DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO = None
+_DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP = None
+_DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP_GLOO = None

_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None
+_PIPELINE_MODEL_PARALLEL_DECODER_START = None

# These values enable us to change the mpu sizes on the fly.
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
+_MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE = None
+_MPU_DATA_PARALLEL_WORLD_SIZE = None
+_MPU_DATA_PARALLEL_RANK = None
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None
+_MPU_EXPERT_MODEL_PARALLEL_RANK = None

# A list of ranks that have a copy of the embedding.
_EMBEDDING_GLOBAL_RANKS = None
...
...
@@ -49,21 +70,305 @@ _PIPELINE_GLOBAL_RANKS = None
# rank when broadcasting weights from src to all other data parallel ranks
_DATA_PARALLEL_GLOBAL_RANKS = None

+# A list of global ranks for each tensor model parallel group to ease calculation of
+# the first local rank in the tensor model parallel group
+_TENSOR_MODEL_PARALLEL_GLOBAL_RANKS = None

# Context parallel group that the current rank belongs to
_CONTEXT_PARALLEL_GROUP = None
# A list of global ranks for each context parallel group to ease calculation of the
# destination rank when exchanging KV/dKV between context parallel_ranks
_CONTEXT_PARALLEL_GLOBAL_RANKS = None

# Data parallel group information with context parallel combined.
_DATA_PARALLEL_GROUP_WITH_CP = None
_DATA_PARALLEL_GROUP_WITH_CP_GLOO = None
_DATA_PARALLEL_GLOBAL_RANKS_WITH_CP = None

+# combined parallel group of TP and CP
+_TENSOR_AND_CONTEXT_PARALLEL_GROUP = None
+
+# combined parallel group of TP, DP, and CP used for fp8
+_TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP = None

# Memory buffers to avoid dynamic memory allocation
_GLOBAL_MEMORY_BUFFER = None

+# MOE logging
+_MOE_LAYER_WISE_LOGGING_TRACKER = {}
def get_nccl_options(pg_name, nccl_comm_cfgs):
    """Set the NCCL process group options.

    Args:
        pg_name (str): process group name
        nccl_comm_cfgs (dict): nccl communicator configurations

    When an option (e.g., max_ctas) is not found in the config, use the NCCL default setting.
    """
    if pg_name in nccl_comm_cfgs:
        nccl_options = torch.distributed.ProcessGroupNCCL.Options()
        nccl_options.config.cga_cluster_size = nccl_comm_cfgs[pg_name].get('cga_cluster_size', 4)
        nccl_options.config.max_ctas = nccl_comm_cfgs[pg_name].get('max_ctas', 32)
        nccl_options.config.min_ctas = nccl_comm_cfgs[pg_name].get('min_ctas', 1)
        return nccl_options
    else:
        return None
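
To make the config lookup concrete, here is a sketch of the kind of dictionary that results from loading the YAML file passed via nccl_communicator_config_path; the keys match the pg_name values used when groups are built below, and the numbers are illustrative only.

# Parsed nccl_comm_cfgs (illustrative values); unspecified fields fall back to the defaults above.
nccl_comm_cfgs = {
    'tp': {'cga_cluster_size': 2, 'max_ctas': 8, 'min_ctas': 1},
    'dp': {'max_ctas': 16},
}
tp_opts = get_nccl_options('tp', nccl_comm_cfgs)  # ProcessGroupNCCL.Options with the values above
pp_opts = get_nccl_options('pp', nccl_comm_cfgs)  # None -> NCCL defaults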
def generate_masked_orthogonal_rank_groups(
    world_size: int, parallel_size: List[int], mask: List[bool]
) -> List[List[int]]:
    """Generate orthogonal parallel groups based on the parallel size and mask.

    Arguments:
        world_size (int): world size

        parallel_size (List[int]):
            The parallel size of each orthogonal parallel type. For example, if
            tensor_parallel_size = 2, pipeline_model_parallel_group = 3, data_parallel_size = 4,
            and the parallel mapping order is tp-pp-dp, then the parallel_size = [2, 3, 4].

        mask (List[bool]):
            The mask controls which parallel methods the generated groups represent. If mask[i] is
            True, it means the generated group contains the i-th parallelism method. For example,
            if parallel_size = [tp_size, pp_size, dp_size], and mask = [True, False, True], then
            the generated group is the `tp-dp` group, if the mask = [False, True, False], then the
            generated group is the `pp` group.

    Algorithm:
        For orthogonal parallelism, such as tp/dp/pp/cp, the global_rank and
        local_rank satisfy the following equation:
            global_rank = tp_rank + dp_rank * tp_size + pp_rank * tp_size * dp_size    (1)
                tp_rank \in [0, tp_size)
                dp_rank \in [0, dp_size)
                pp_rank \in [0, pp_size)

        If we want to get the `dp_group` (tp_size * pp_size groups of dp_size ranks each.
        For example, if the gpu size is 8 and order is 'tp-pp-dp', size is '2-2-2', and the
        dp_group here is [[0, 4], [1, 5], [2, 6], [3, 7]].)
        The tp_rank and pp_rank will be combined to form the `dp_group_index`.
            dp_group_index = tp_rank + pp_rank * tp_size    (2)

        So, given that tp_rank and pp_rank satisfy equation (2), and dp_rank in
        range(0, dp_size), the ranks in dp_group[dp_group_index] satisfy
        equation (1).

        This function solves this math problem.

        For example, if the parallel_size = [tp_size, dp_size, pp_size] = [2, 3, 4],
        and the mask = [False, True, False]. Then,
            dp_group_index(0) = tp_rank(0) + pp_rank(0) * 2
            dp_group_index(1) = tp_rank(1) + pp_rank(0) * 2
            ...
            dp_group_index(7) = tp_rank(1) + pp_rank(3) * 2

            dp_group[0] = 0 + range(0, 3) * 2 + 0 = [0, 2, 4]
            dp_group[1] = 1 + range(0, 3) * 2 + 0 = [1, 3, 5]
            ...
            dp_group[7] = 1 + range(0, 3) * 2 + 3 * 2 * 3 = [19, 21, 23]
    """

    def prefix_product(a: List[int], init=1) -> List[int]:
        r = [init]
        for v in a:
            init = init * v
            r.append(init)
        return r

    def inner_product(a: List[int], b: List[int]) -> int:
        return sum([x * y for x, y in zip(a, b)])

    def decompose(index, shape, stride=None):
        '''
        This function solves the math problem below:
            There is an equation: index = sum(idx[i] * stride[i]).
            Given the values of index and stride, return idx.
        This function will be used to get the tp/dp/pp rank
        from group_index and rank_in_group.
        '''
        if stride is None:
            stride = prefix_product(shape)
        idx = [(index // d) % s for s, d in zip(shape, stride)]
        # stride is a prefix_product result. And the value of stride[-1]
        # is not used.
        assert (
            sum([x * y for x, y in zip(idx, stride[:-1])]) == index
        ), "idx {} with shape {} mismatch the return idx {}".format(index, shape, idx)
        return idx

    masked_shape = [s for s, m in zip(parallel_size, mask) if m]
    unmasked_shape = [s for s, m in zip(parallel_size, mask) if not m]

    global_stride = prefix_product(parallel_size)
    masked_stride = [d for d, m in zip(global_stride, mask) if m]
    unmasked_stride = [d for d, m in zip(global_stride, mask) if not m]

    group_size = prefix_product(masked_shape)[-1]
    num_of_group = world_size // group_size

    ranks = []
    for group_index in range(num_of_group):
        # get indices from unmasked for group_index.
        decomposed_group_idx = decompose(group_index, unmasked_shape)
        rank = []
        for rank_in_group in range(group_size):
            # get indices from masked for rank_in_group.
            decomposed_rank_idx = decompose(rank_in_group, masked_shape)
            rank.append(
                inner_product(decomposed_rank_idx, masked_stride)
                + inner_product(decomposed_group_idx, unmasked_stride)
            )
        ranks.append(rank)
    return ranks
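
A quick sanity check of the 8-GPU example from the docstring (order 'tp-pp-dp' with sizes 2-2-2, asking only for the dp axis) reproduces the expected dp groups:

groups = generate_masked_orthogonal_rank_groups(
    world_size=8,
    parallel_size=[2, 2, 2],     # [tp, pp, dp] in rank-initialization order
    mask=[False, False, True],   # keep only the dp dimension
)
assert groups == [[0, 4], [1, 5], [2, 6], [3, 7]]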
class RankGenerator(object):
    """A class for generating rank groups for different modes of parallelism."""

    def __init__(
        self, tp: int, ep: int, dp: int, pp: int, cp: int, order: str, rank_offset: int = 0
    ) -> None:
        self.tp = tp
        self.ep = ep
        self.dp = dp
        self.pp = pp
        self.cp = cp
        self.rank_offset = rank_offset
        self.world_size = tp * dp * pp * cp

        self.name_to_size = {
            "tp": self.tp,
            "pp": self.pp,
            "dp": self.dp,
            "ep": self.ep,
            "cp": self.cp,
        }
        self.order = order
        order = order.lower()

        if 'ep' in order:
            if 'ep-dp' not in order and 'dp-ep' not in order:
                raise RuntimeError(f"The ep and dp must be adjacent in order ({self.order}).")

        for name in self.name_to_size.keys():
            if name not in order and self.name_to_size[name] != 1:
                raise RuntimeError(
                    f"The size of ({name}) is ({self.name_to_size[name]}), but you haven't "
                    f"specified the order ({self.order})."
                )
            elif name not in order:
                order = order + '-' + name

        self.order_w_ep = order
        self.order_wo_ep = '-'.join([token for token in order.split('-') if token != 'ep'])
        self.ordered_size_wo_ep = []
        self.ordered_size_w_ep = []

        for token in order.split('-'):
            if token == 'dp':
                self.ordered_size_w_ep.append(self.dp // self.ep)
                self.ordered_size_wo_ep.append(self.dp)
            elif token == 'ep':
                self.ordered_size_w_ep.append(self.ep)
            else:
                self.ordered_size_w_ep.append(self.name_to_size[token])
                self.ordered_size_wo_ep.append(self.name_to_size[token])

    def get_mask(self, order: str, token: str):
        """Create a mask for the specified tokens based on the given order.

        Args:
            order (str): The order of parallelism types (e.g., 'tp-dp-pp').
            token (str): The specific parallelism types to include in the mask,
                separated by hyphens (e.g., 'tp-dp').
        """
        ordered_token = order.split('-')
        token = token.split('-')
        mask = [False] * len(ordered_token)
        for t in token:
            mask[ordered_token.index(t)] = True
        return mask

    def get_ranks(self, token, independent_ep=False):
        """Get rank group by input token.

        Args:
            token (str):
                Specify the ranks type that want to get. If we want
                to obtain multiple parallel types, we can use a hyphen
                '-' to separate them. For example, if we want to obtain
                the TP_DP group, the token should be 'tp-dp'.

            independent_ep (bool: True):
                This flag controls whether we treat EP and DP independently.
                EP shares ranks with DP, if we want to get ranks related to
                EP, we should set the flag. For example, get_ranks('dp', True)
                will get DP modulo EP group, and get_ranks('dp', False) will
                get full DP group.
        """
        if independent_ep:
            parallel_size = self.ordered_size_w_ep
            order = self.order_w_ep
        else:
            parallel_size = self.ordered_size_wo_ep
            order = self.order_wo_ep
        mask = self.get_mask(order, token)
        ranks = generate_masked_orthogonal_rank_groups(self.world_size, parallel_size, mask)
        if self.rank_offset > 0:
            for rank_group in ranks:
                for i in range(len(rank_group)):
                    rank_group[i] += self.rank_offset
        return ranks
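
As a small illustration of how `get_ranks` composes with the rank math above (outside of initialize_model_parallel, with made-up sizes): 8 ranks, tp=2, pp=2, dp=2, no expert or context parallelism, default ordering.

gen = RankGenerator(tp=2, ep=1, dp=2, pp=2, cp=1, order="tp-cp-ep-dp-pp")
print(gen.get_ranks('tp'))     # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(gen.get_ranks('dp'))     # [[0, 2], [1, 3], [4, 6], [5, 7]]
print(gen.get_ranks('tp-pp'))  # model-parallel groups spanning both tp and pp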
def default_embedding_ranks(pp_ranks, split_rank=None):
    """Return the default ranks that constitute the stages on which the word embeddings live.
    For most models, these are the first and last pipeline stages.

    We also support the deprecated split rank argument for backwards compatibility."""
    if len(pp_ranks) == 1:
        return [pp_ranks[0]]
    elif split_rank is not None and pp_ranks[split_rank] not in (pp_ranks[0], pp_ranks[-1]):
        return [pp_ranks[0], pp_ranks[split_rank], pp_ranks[-1]]
    else:
        return [pp_ranks[0], pp_ranks[-1]]


def default_position_embedding_ranks(pp_ranks, split_rank=None):
    """Return the default ranks that constitute the stages on which the position embeddings live.
    For most models, this is only the first pipeline stage.

    We also support the deprecated split rank argument for backwards compatibility."""
    if split_rank is not None and pp_ranks[0] != pp_ranks[split_rank]:
        return [pp_ranks[0], pp_ranks[split_rank]]
    else:
        return [pp_ranks[0]]
def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    virtual_pipeline_model_parallel_size: Optional[int] = None,
    pipeline_model_parallel_split_rank: Optional[int] = None,
    use_fp8: bool = False,
    use_sharp: bool = False,
    context_parallel_size: int = 1,
    expert_model_parallel_size: int = 1,
    nccl_communicator_config_path: Optional[str] = None,
    distributed_timeout_minutes: int = 30,
    order: str = "tp-cp-ep-dp-pp",
    encoder_tensor_model_parallel_size: Optional[int] = 0,
    encoder_pipeline_model_parallel_size: Optional[int] = 0,
    get_embedding_ranks: Optional[Callable[[List[int], Optional[int]], List[int]]] = None,
    get_position_embedding_ranks: Optional[Callable[[List[int], Optional[int]], List[int]]] = None,
) -> None:
    # pylint: disable=line-too-long
    """Initialize model data parallel groups.

-    Arguments:
+    Args:
        tensor_model_parallel_size (int, default = 1):
            The number of GPUs to split individual tensors across.
...
...
@@ -90,7 +395,7 @@ def initialize_model_parallel(
            GPU 3:  [7, 8]   [15, 16]

        pipeline_model_parallel_split_rank (int, optional):
-            For models with both an encoder and decoder, the rank in
+            DEPRECATED.
+            For models with both an encoder and decoder, the rank in
            pipeline to switch between encoder and decoder (i.e. the
            first rank of the decoder). This allows the user to set
            the pipeline parallel size of the encoder and decoder
...
...
@@ -99,17 +404,73 @@ def initialize_model_parallel(
            pipeline_model_parallel_split_rank is 3, then ranks 0-2
            will be the encoder and ranks 3-7 will be the decoder.

        use_fp8 (bool, default = False):
            Construct GPU groups needed for FP8 training, namely for
            amax reduction across the product of the data-parallel and
            tensor-parallel groups.

        use_sharp (bool, default = False):
            Set the use of SHARP for the collective communications of
            data-parallel process groups. When `True`, run barrier
            within each data-parallel process group, which specifies
            the SHARP application target groups.

        context_parallel_size (int, default = 1):
            The number of tensor parallel GPU groups to split the
            network input sequence length across. Compute of attention
            module requires tokens of full sequence length, so GPUs
            in a context parallel group need to communicate with each
            other to exchange information of other sequence chunks.
            Each GPU and its counterparts in other tensor parallel
            groups compose a context parallel group.

            For example, assume we have 8 GPUs, if tensor model parallel
            size is 4 and context parallel size is 2, the network input
            will be split into two sequence chunks, which are processed
            by 2 different groups of 4 GPUs. One chunk is processed by
            GPU0-3, the other chunk is processed by GPU4-7. Four groups
            are built to do context parallel communications: [GPU0, GPU4],
            [GPU1, GPU5], [GPU2, GPU6], and [GPU3, GPU7].

            Context parallelism partitions sequence length, so it has no
            impact on weights, which means weights are duplicated among
            GPUs in a context parallel group. Hence, weight gradients
            all-reduce is required in backward. For simplicity, we piggyback
            GPUs of context parallelism on data parallel group for
            weight gradient all-reduce.

        expert_model_parallel_size (int, default = 1):
            The number of Mixture of Experts parallel GPUs in each expert
            parallel group.

        nccl_communicator_config_path (str, default = None):
            Path to the yaml file of NCCL communicator configurations.
            `min_ctas`, `max_ctas`, and `cga_cluster_size` can be set
            for each communicator.

        distributed_timeout_minutes (int, default = 30): Timeout, in
            minutes, for operations executed against distributed
            process groups. See PyTorch documentation at
            https://pytorch.org/docs/stable/distributed.html for
            caveats.

        order (str, default=tp-dp-pp):
            The rank initialization order of parallelism. Now we support
            tp-dp-pp and tp-pp-dp orders.

        encoder_tensor_model_parallel_size (int, default = 0):
            The number of GPUs to split individual tensors across in the encoder. If 0,
            then we use the default, decoder's tensor model parallel size.

        encoder_pipeline_model_parallel_size (int, default = 0):
            The number of tensor parallel GPU groups to allocate to the encoder. As an example,
            if pipeline_model_parallel_size is 4 and encoder_pipeline_model_parallel_size is 2,
            then the encoder will use the first two pipeline stages for its layers, and the total
            amount of pipelining is 6.

        get_embedding_ranks (Callable[[List[int], Optional[int]], List[int]], optional, default=None):
            A function that takes in a list of ranks for a pipeline group and returns
            those ranks that should have embeddings.

        get_position_embedding_ranks (Callable[[List[int], Optional[int]], List[int]], optional, default=None):
            A function that takes in a list of ranks for a pipeline group, and returns
            those ranks that should have position embeddings.

    Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
...
...
@@ -127,28 +488,68 @@ def initialize_model_parallel(
    ranks 8 to 15 belong to the second box.
    """
    if encoder_pipeline_model_parallel_size is None:
        encoder_pipeline_model_parallel_size = 0

    if encoder_tensor_model_parallel_size == 0 and encoder_pipeline_model_parallel_size > 0:
        encoder_tensor_model_parallel_size = tensor_model_parallel_size

    if get_embedding_ranks is None:
        get_embedding_ranks = partial(
            default_embedding_ranks, split_rank=pipeline_model_parallel_split_rank
        )

    if get_position_embedding_ranks is None:
        get_position_embedding_ranks = partial(
            default_position_embedding_ranks, split_rank=pipeline_model_parallel_split_rank
        )

    if encoder_pipeline_model_parallel_size > 0:
        global _PIPELINE_MODEL_PARALLEL_DECODER_START
        _PIPELINE_MODEL_PARALLEL_DECODER_START = encoder_pipeline_model_parallel_size

    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()

-    if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0:
-        raise RuntimeError(
-            f"world_size ({world_size}) is not divisible by tensor_model_parallel_size "
-            f"({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size})"
-        )
-
-    data_parallel_size: int = world_size // (
-        tensor_model_parallel_size * pipeline_model_parallel_size
-    )
-
-    num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size
-    num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
-    num_data_parallel_groups: int = world_size // data_parallel_size
+    if encoder_tensor_model_parallel_size > 0:
+        assert encoder_pipeline_model_parallel_size > 0
+        assert (
+            encoder_tensor_model_parallel_size <= tensor_model_parallel_size
+        ), "We do not support encoders with more TP than the decoder."
+
+    encoder_model_size = (
+        encoder_tensor_model_parallel_size
+        * encoder_pipeline_model_parallel_size
+        * context_parallel_size
+    )
+    decoder_model_size = (
+        tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size
+    )
+    total_model_size = encoder_model_size + decoder_model_size
+
+    if world_size % total_model_size != 0:
+        raise RuntimeError(f"world_size ({world_size}) is not divisible by {total_model_size}")
+
+    data_parallel_size: int = world_size // total_model_size
+
+    if data_parallel_size % expert_model_parallel_size != 0:
+        raise RuntimeError(
+            f"data_parallel_size ({data_parallel_size}) is not divisible by "
+            "expert_model_parallel_size"
+        )
+
+    encoder_world_size = encoder_model_size * data_parallel_size
+    decoder_world_size = decoder_model_size * data_parallel_size
+
+    assert (
+        encoder_world_size + decoder_world_size == world_size
+    ), f"{encoder_world_size=} + {decoder_world_size=} != {world_size=}"

    if virtual_pipeline_model_parallel_size is not None:
-        if not pipeline_model_parallel_size > 2:
+        if not pipeline_model_parallel_size > 1:
            raise RuntimeError(
-                "pipeline-model-parallel size should be greater than 2 with interleaved schedule"
+                "pipeline-model-parallel size should be greater than 1 with interleaved schedule"
            )
        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
...
...
@@ -161,24 +562,103 @@ def initialize_model_parallel(
    rank = torch.distributed.get_rank()

    nccl_comm_cfgs = {}
    if nccl_communicator_config_path is not None:
        try:
            import yaml
        except ImportError:
            raise RuntimeError(
                "Cannot import `yaml`. Setting custom nccl communicator configs "
                "requires the yaml package."
            )

        with open(nccl_communicator_config_path, "r") as stream:
            nccl_comm_cfgs = yaml.safe_load(stream)

+    if encoder_world_size > 0:
+        encoder_rank_generator = RankGenerator(
+            tp=encoder_tensor_model_parallel_size,
+            ep=1,
+            dp=data_parallel_size,
+            pp=encoder_pipeline_model_parallel_size,
+            cp=context_parallel_size,
+            order=order,
+            rank_offset=0,
+        )
+    else:
+        encoder_rank_generator = None
+
+    decoder_rank_generator = RankGenerator(
+        tp=tensor_model_parallel_size,
+        ep=expert_model_parallel_size,
+        dp=data_parallel_size,
+        pp=pipeline_model_parallel_size,
+        cp=context_parallel_size,
+        order=order,
+        rank_offset=encoder_world_size,
+    )
+
+    def generator_wrapper(group_type, **kwargs):
+        """The `RankGenerator` class produces a hyper-rectangle for a given set of
+        tensor, pipeline, data, expert, and context parallelism. If we have an encoder,
+        in addition to the default decoder, we essentially instantiate two `RankGenerator`
+        classes to construct the parallelism for each module separately, and we then have
+        to stitch them together for the right groups. For now, this means pp and tp-pp."""
+        d_ranks = decoder_rank_generator.get_ranks(group_type, **kwargs)
+        if encoder_rank_generator is None:
+            for x in d_ranks:
+                yield x
+            return
+        e_ranks = encoder_rank_generator.get_ranks(group_type, **kwargs)
+        if group_type == 'pp':
+            # Map 1 encoder tp rank to several decoder tp ranks, because
+            # these won't be the same size.
+            for x, y in zip(cycle(e_ranks), d_ranks):
+                yield x + y
+        elif group_type == 'tp-pp':
+            # For this group, we can just return the concatenated
+            # groups together, because their sizes are the same.
+            assert len(e_ranks) == len(d_ranks)
+            for x, y in zip(e_ranks, d_ranks):
+                yield x + y
+        else:
+            for x in e_ranks:
+                yield x
+            for x in d_ranks:
+                yield x
+
+    timeout = timedelta(minutes=distributed_timeout_minutes)

    # Build the data-parallel groups.
    global _DATA_PARALLEL_GROUP
    global _DATA_PARALLEL_GROUP_GLOO
    global _DATA_PARALLEL_GLOBAL_RANKS
+    global _DATA_PARALLEL_GROUP_WITH_CP
+    global _DATA_PARALLEL_GROUP_WITH_CP_GLOO
+    global _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP
    assert _DATA_PARALLEL_GROUP is None, 'data parallel group is already initialized'
-    all_data_parallel_group_ranks = []
-    for i in range(pipeline_model_parallel_size):
-        start_rank = i * num_pipeline_model_parallel_groups
-        end_rank = (i + 1) * num_pipeline_model_parallel_groups
-        for j in range(tensor_model_parallel_size):
-            ranks = range(start_rank + j, end_rank, tensor_model_parallel_size)
-            all_data_parallel_group_ranks.append(list(ranks))
-            group = torch.distributed.new_group(ranks)
-            group_gloo = torch.distributed.new_group(ranks, backend="gloo")
-            if rank in ranks:
-                _DATA_PARALLEL_GROUP = group
-                _DATA_PARALLEL_GROUP_GLOO = group_gloo
-                _DATA_PARALLEL_GLOBAL_RANKS = ranks
+    for ranks in generator_wrapper('dp'):
+        group = torch.distributed.new_group(
+            ranks, timeout=timeout, pg_options=get_nccl_options('dp', nccl_comm_cfgs)
+        )
+        group_gloo = torch.distributed.new_group(ranks, timeout=timeout, backend="gloo")
+        if rank in ranks:
+            _DATA_PARALLEL_GROUP = group
+            _DATA_PARALLEL_GROUP_GLOO = group_gloo
+            _DATA_PARALLEL_GLOBAL_RANKS = ranks
+    for ranks_with_cp in generator_wrapper('dp-cp'):
+        group_with_cp = torch.distributed.new_group(
+            ranks_with_cp, timeout=timeout, pg_options=get_nccl_options('dp_cp', nccl_comm_cfgs)
+        )
+        group_with_cp_gloo = torch.distributed.new_group(
+            ranks_with_cp, timeout=timeout, backend="gloo"
+        )
+        if rank in ranks_with_cp:
+            _DATA_PARALLEL_GROUP_WITH_CP = group_with_cp
+            _DATA_PARALLEL_GROUP_WITH_CP_GLOO = group_with_cp_gloo
+            _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP = ranks_with_cp

    # Apply SHARP to DP process groups
    if use_sharp:
...
...
@@ -194,33 +674,59 @@ def initialize_model_parallel(
            "`#SBATCH_NETWORK=sharp` should be set in the sbatch script."
        )
        torch.distributed.barrier(
-            group=get_data_parallel_group(), device_ids=[torch.cuda.current_device()]
+            group=get_data_parallel_group(with_context_parallel=True),
+            device_ids=[torch.cuda.current_device()],
        )
-        # Set `NCCL_SHARP_DISABLE=1` to restrict SHARP application to DP process groups
-        os.environ["NCCL_SHARP_DISABLE"] = "1"
+        # Set `NCCL_COLLNET_ENABLE=0` to restrict SHARP application to DP process groups
+        os.environ["NCCL_COLLNET_ENABLE"] = "0"

+    # Build the context-parallel groups.
+    global _CONTEXT_PARALLEL_GROUP
+    global _CONTEXT_PARALLEL_GLOBAL_RANKS
+    assert _CONTEXT_PARALLEL_GROUP is None, 'context parallel group is already initialized'
+    for ranks in generator_wrapper('cp'):
+        group = torch.distributed.new_group(
+            ranks, timeout=timeout, pg_options=get_nccl_options('cp', nccl_comm_cfgs)
+        )
+        if rank in ranks:
+            _CONTEXT_PARALLEL_GROUP = group
+            _CONTEXT_PARALLEL_GLOBAL_RANKS = ranks

    # Build the model-parallel groups.
    global _MODEL_PARALLEL_GROUP
    assert _MODEL_PARALLEL_GROUP is None, 'model parallel group is already initialized'
-    for i in range(data_parallel_size):
-        ranks = [
-            data_parallel_group_ranks[i]
-            for data_parallel_group_ranks in all_data_parallel_group_ranks
-        ]
-        group = torch.distributed.new_group(ranks)
+    for ranks in generator_wrapper('tp-pp'):
+        group = torch.distributed.new_group(
+            ranks, timeout=timeout, pg_options=get_nccl_options('mp', nccl_comm_cfgs)
+        )
        if rank in ranks:
            _MODEL_PARALLEL_GROUP = group

+    # Build the model-parallel groups with expert parallel
+    global _MODEL_AND_EXPERT_PARALLEL_GROUP
+    assert (
+        _MODEL_AND_EXPERT_PARALLEL_GROUP is None
+    ), 'model and expert parallel group is already initialized'
+    for ranks in generator_wrapper('tp-ep-pp', independent_ep=True):
+        group = torch.distributed.new_group(
+            ranks, timeout=timeout, pg_options=get_nccl_options('mp_exp', nccl_comm_cfgs)
+        )
+        if rank in ranks:
+            _MODEL_AND_EXPERT_PARALLEL_GROUP = group

    # Build the tensor model-parallel groups.
    global _TENSOR_MODEL_PARALLEL_GROUP
+    global _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS
    assert (
        _TENSOR_MODEL_PARALLEL_GROUP is None
    ), 'tensor model parallel group is already initialized'
-    for i in range(num_tensor_model_parallel_groups):
-        ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)
-        group = torch.distributed.new_group(ranks)
+    for ranks in generator_wrapper('tp'):
+        group = torch.distributed.new_group(
+            ranks, timeout=timeout, pg_options=get_nccl_options('tp', nccl_comm_cfgs)
+        )
        if rank in ranks:
            _TENSOR_MODEL_PARALLEL_GROUP = group
+            _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS = ranks

    # Build the pipeline model-parallel groups and embedding groups
    # (first and last rank in each pipeline model-parallel group).
...
...
@@ -235,55 +741,125 @@ def initialize_model_parallel(
    global _POSITION_EMBEDDING_GROUP
    global _POSITION_EMBEDDING_GLOBAL_RANKS
    assert _POSITION_EMBEDDING_GROUP is None, 'position embedding group is already initialized'
-    for i in range(num_pipeline_model_parallel_groups):
-        ranks = range(i, world_size, num_pipeline_model_parallel_groups)
-        group = torch.distributed.new_group(ranks)
-        if rank in ranks:
-            _PIPELINE_MODEL_PARALLEL_GROUP = group
-            _PIPELINE_GLOBAL_RANKS = ranks
-        # Setup embedding group (to exchange gradients between
-        # first and last stages).
-        if len(ranks) > 1:
-            embedding_ranks = [ranks[0], ranks[-1]]
-            position_embedding_ranks = [ranks[0]]
-            if pipeline_model_parallel_split_rank is not None:
-                if ranks[pipeline_model_parallel_split_rank] not in embedding_ranks:
-                    embedding_ranks = [
-                        ranks[0],
-                        ranks[pipeline_model_parallel_split_rank],
-                        ranks[-1],
-                    ]
-                if ranks[pipeline_model_parallel_split_rank] not in position_embedding_ranks:
-                    position_embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank]]
-        else:
-            embedding_ranks = ranks
-            position_embedding_ranks = ranks
-
-        group = torch.distributed.new_group(embedding_ranks)
+    for ranks in generator_wrapper('pp'):
+        group = torch.distributed.new_group(
+            ranks, timeout=timeout, pg_options=get_nccl_options('pp', nccl_comm_cfgs)
+        )
+        if rank in ranks:
+            if _PIPELINE_MODEL_PARALLEL_GROUP is None:
+                _PIPELINE_MODEL_PARALLEL_GROUP = group
+                _PIPELINE_GLOBAL_RANKS = ranks
+            elif isinstance(_PIPELINE_GLOBAL_RANKS[0], list):
+                _PIPELINE_MODEL_PARALLEL_GROUP.append(group)
+                _PIPELINE_GLOBAL_RANKS.append(ranks)
+            else:
+                _PIPELINE_MODEL_PARALLEL_GROUP = [_PIPELINE_MODEL_PARALLEL_GROUP, group]
+                _PIPELINE_GLOBAL_RANKS = [_PIPELINE_GLOBAL_RANKS, ranks]
+
+        embedding_ranks = get_embedding_ranks(ranks)
+        group = torch.distributed.new_group(
+            embedding_ranks, timeout=timeout, pg_options=get_nccl_options('embd', nccl_comm_cfgs)
+        )
        if rank in embedding_ranks:
            _EMBEDDING_GROUP = group
        if rank in ranks:
            _EMBEDDING_GLOBAL_RANKS = embedding_ranks

-        group = torch.distributed.new_group(position_embedding_ranks)
+        position_embedding_ranks = get_position_embedding_ranks(ranks)
+        group = torch.distributed.new_group(
+            position_embedding_ranks,
+            timeout=timeout,
+            pg_options=get_nccl_options('embd', nccl_comm_cfgs),
+        )
        if rank in position_embedding_ranks:
            _POSITION_EMBEDDING_GROUP = group
        if rank in ranks:
            _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks

-    # Build the FP8 groups.
-    global _AMAX_REDUCTION_GROUP
-    assert _AMAX_REDUCTION_GROUP is None, 'FP8 amax reduction group is already initialized'
-    if use_fp8:
-        amax_group_size: int = tensor_model_parallel_size * data_parallel_size
-        num_amax_groups: int = world_size // amax_group_size
-        for i in range(num_amax_groups):
-            start_rank = i * amax_group_size
-            end_rank = (i + 1) * amax_group_size
-            ranks = range(start_rank, end_rank)
-            group = torch.distributed.new_group(ranks)
-            if rank in ranks:
-                _AMAX_REDUCTION_GROUP = group
-
    # Build the tensor + data parallel groups.
    global _TENSOR_AND_DATA_PARALLEL_GROUP
+    global _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP
    assert (
        _TENSOR_AND_DATA_PARALLEL_GROUP is None
    ), 'Tensor + data parallel group is already initialized'
+    for ranks in generator_wrapper('tp-dp-cp'):
+        group = torch.distributed.new_group(
+            ranks, timeout=timeout, pg_options=get_nccl_options('tp_dp_cp', nccl_comm_cfgs)
+        )
+        if rank in ranks:
+            _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP = group
    for ranks in generator_wrapper('tp-dp'):
        group = torch.distributed.new_group(
            ranks, timeout=timeout, pg_options=get_nccl_options('tp_dp', nccl_comm_cfgs)
        )
        if rank in ranks:
            _TENSOR_AND_DATA_PARALLEL_GROUP = group

+    global _TENSOR_AND_CONTEXT_PARALLEL_GROUP
+    assert (
+        _TENSOR_AND_CONTEXT_PARALLEL_GROUP is None
+    ), 'Tensor + context parallel group is already initialized'
+    for ranks in generator_wrapper('tp-cp'):
+        group = torch.distributed.new_group(
+            ranks, timeout=timeout, pg_options=get_nccl_options('tp_cp', nccl_comm_cfgs)
+        )
+        if rank in ranks:
+            _TENSOR_AND_CONTEXT_PARALLEL_GROUP = group

    # Build the tensor + expert parallel groups
    global _EXPERT_MODEL_PARALLEL_GROUP
    assert _EXPERT_MODEL_PARALLEL_GROUP is None, 'Expert parallel group is already initialized'
    global _TENSOR_AND_EXPERT_PARALLEL_GROUP
    assert (
        _TENSOR_AND_EXPERT_PARALLEL_GROUP is None
    ), 'Tensor + expert parallel group is already initialized'
    global _DATA_MODULO_EXPERT_PARALLEL_GROUP
    assert (
        _DATA_MODULO_EXPERT_PARALLEL_GROUP is None
    ), 'Data modulo expert group is already initialized'
+    global _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP
+    assert (
+        _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP is None
+    ), 'Data modulo expert group with context parallel is already initialized'
    global _DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO
+    global _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP_GLOO

    for ranks in generator_wrapper('tp-ep', independent_ep=True):
        group = torch.distributed.new_group(
            ranks, timeout=timeout, pg_options=get_nccl_options('tp_exp', nccl_comm_cfgs)
        )
        if rank in ranks:
            _TENSOR_AND_EXPERT_PARALLEL_GROUP = group

    for ranks in generator_wrapper('ep', independent_ep=True):
        group = torch.distributed.new_group(
            ranks, pg_options=get_nccl_options('exp', nccl_comm_cfgs)
        )
        if rank in ranks:
            _EXPERT_MODEL_PARALLEL_GROUP = group

    for ranks in generator_wrapper('dp', independent_ep=True):
        group = torch.distributed.new_group(
            ranks, timeout=timeout, pg_options=get_nccl_options('dp_modulo_exp', nccl_comm_cfgs)
        )
        group_gloo = torch.distributed.new_group(ranks, backend="gloo")
        if rank in ranks:
            _DATA_MODULO_EXPERT_PARALLEL_GROUP = group
            _DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO = group_gloo

    for ranks in generator_wrapper('dp-cp', independent_ep=True):
        # Lazy initialization of the group
        if get_context_parallel_world_size() > 1:
            group = torch.distributed.new_group(
                ranks,
                timeout=timeout,
                pg_options=get_nccl_options('dp_modulo_exp_cp', nccl_comm_cfgs),
            )
            group_gloo = torch.distributed.new_group(ranks, backend="gloo")
        else:
            group = _DATA_MODULO_EXPERT_PARALLEL_GROUP
            group_gloo = _DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO
        if rank in ranks:
            _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP = group
            _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP_GLOO = group_gloo

    # Initialize global memory buffer
    # This isn't really "parallel state" but there isn't another good place to
...
...
@@ -292,13 +868,23 @@ def initialize_model_parallel(
    _set_global_memory_buffer()


-def is_unitialized():
-    """Useful for code segments that may be accessed with or without mpu initialization"""
-    return _DATA_PARALLEL_GROUP is None
+def is_initialized():
+    """Useful for code segments that may be accessed with or without mpu initialization"""
+    return _DATA_PARALLEL_GROUP is not None
+
+
+def is_unitialized() -> bool:
+    """Check if parallel state has been initialized
+
+    Deprecated. Use is_initialized instead.
+    """
+    warnings.warn("is_unitialized is deprecated, use is_initialized instead", DeprecationWarning)
+    return not is_initialized()


def model_parallel_is_initialized():
-    """Check if model and data parallel groups are initialized."""
+    """Check if model- and data-parallel groups are initialized."""
    if (
        _TENSOR_MODEL_PARALLEL_GROUP is None
        or _PIPELINE_MODEL_PARALLEL_GROUP is None
...
...
@@ -308,14 +894,19 @@ def model_parallel_is_initialized():
    return True


-def get_model_parallel_group():
-    """Get the model parallel group the caller rank belongs to."""
+def get_model_parallel_group(with_expert_parallel=False):
+    """Get the model-parallel group the caller rank belongs to."""
+    if with_expert_parallel:
+        assert (
+            _MODEL_AND_EXPERT_PARALLEL_GROUP is not None
+        ), 'model parallel group is not initialized'
+        return _MODEL_AND_EXPERT_PARALLEL_GROUP
    assert _MODEL_PARALLEL_GROUP is not None, 'model parallel group is not initialized'
    return _MODEL_PARALLEL_GROUP


def get_tensor_model_parallel_group(check_initialized=True):
-    """Get the tensor model parallel group the caller rank belongs to."""
+    """Get the tensor-model-parallel group the caller rank belongs to."""
    if check_initialized:
        assert (
            _TENSOR_MODEL_PARALLEL_GROUP is not None
...
...
@@ -324,23 +915,51 @@ def get_tensor_model_parallel_group(check_initialized=True):


def get_pipeline_model_parallel_group():
-    """Get the pipeline model parallel group the caller rank belongs to."""
+    """Get the pipeline-model-parallel group the caller rank belongs to."""
    assert (
        _PIPELINE_MODEL_PARALLEL_GROUP is not None
    ), 'pipeline_model parallel group is not initialized'
    return _PIPELINE_MODEL_PARALLEL_GROUP


-def get_data_parallel_group():
-    """Get the data parallel group the caller rank belongs to."""
-    assert _DATA_PARALLEL_GROUP is not None, 'data parallel group is not initialized'
-    return _DATA_PARALLEL_GROUP
+def get_data_parallel_group(with_context_parallel=False):
+    """Get the data-parallel group the caller rank belongs to."""
+    if with_context_parallel:
+        assert (
+            _DATA_PARALLEL_GROUP_WITH_CP is not None
+        ), 'data parallel group with context parallel combined is not initialized'
+        return _DATA_PARALLEL_GROUP_WITH_CP
+    else:
+        assert _DATA_PARALLEL_GROUP is not None, 'data parallel group is not initialized'
+        return _DATA_PARALLEL_GROUP


-def get_data_parallel_group_gloo():
-    """Get the data parallel group-gloo the caller rank belongs to."""
-    assert _DATA_PARALLEL_GROUP_GLOO is not None, 'data parallel group-gloo is not initialized'
-    return _DATA_PARALLEL_GROUP_GLOO
+def get_data_parallel_group_gloo(with_context_parallel=False):
+    """Get the Gloo data-parallel group the caller rank belongs to."""
+    if with_context_parallel:
+        assert (
+            _DATA_PARALLEL_GROUP_WITH_CP_GLOO is not None
+        ), 'data parallel group-gloo with context parallel combined is not initialized'
+        return _DATA_PARALLEL_GROUP_WITH_CP_GLOO
+    else:
+        assert _DATA_PARALLEL_GROUP_GLOO is not None, 'data parallel group-gloo is not initialized'
+        return _DATA_PARALLEL_GROUP_GLOO


+def get_context_parallel_group(check_initialized=True):
+    """Get the context-parallel group the caller rank belongs to."""
+    if check_initialized:
+        assert _CONTEXT_PARALLEL_GROUP is not None, 'context parallel group is not initialized'
+    return _CONTEXT_PARALLEL_GROUP
+
+
+def get_context_parallel_global_ranks(check_initialized=True):
+    """Get all global ranks of the context-parallel group that the caller rank belongs to."""
+    if check_initialized:
+        assert (
+            _CONTEXT_PARALLEL_GLOBAL_RANKS is not None
+        ), 'context parallel group is not initialized'
+    return _CONTEXT_PARALLEL_GLOBAL_RANKS


def get_embedding_group():
...
...
@@ -355,32 +974,124 @@ def get_position_embedding_group():
    return _POSITION_EMBEDDING_GROUP


-def get_amax_reduction_group():
+def get_amax_reduction_group(with_context_parallel=False, tp_only_amax_red=False):
    """Get the FP8 amax reduction group the caller rank belongs to."""
-    assert _AMAX_REDUCTION_GROUP is not None, 'FP8 amax reduction group is not initialized'
-    return _AMAX_REDUCTION_GROUP
+    if with_context_parallel:
+        if not tp_only_amax_red:
+            assert (
+                _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP is not None
+            ), 'FP8 amax reduction group is not initialized'
+            return _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP
+        else:
+            assert (
+                _TENSOR_AND_CONTEXT_PARALLEL_GROUP is not None
+            ), 'FP8 amax reduction group is not initialized'
+            return _TENSOR_AND_CONTEXT_PARALLEL_GROUP
+    else:
+        if not tp_only_amax_red:
+            assert (
+                _TENSOR_AND_DATA_PARALLEL_GROUP is not None
+            ), 'FP8 amax reduction group is not initialized'
+            return _TENSOR_AND_DATA_PARALLEL_GROUP
+        else:
+            assert (
+                _TENSOR_MODEL_PARALLEL_GROUP is not None
+            ), 'FP8 amax reduction group is not initialized'
+            return _TENSOR_MODEL_PARALLEL_GROUP
+
+
+def get_tensor_and_data_parallel_group(with_context_parallel=False):
+    """Get the tensor- and data-parallel group the caller rank belongs to."""
+    if with_context_parallel:
+        assert (
+            _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP is not None
+        ), 'tensor and data parallel group is not initialized'
+        return _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP
+    else:
+        assert (
+            _TENSOR_AND_DATA_PARALLEL_GROUP is not None
+        ), 'tensor and data parallel group is not initialized'
+        return _TENSOR_AND_DATA_PARALLEL_GROUP
+
+
+def get_tensor_and_context_parallel_group():
+    """Get the tensor- and context-parallel group the caller rank belongs to."""
+    assert (
+        _TENSOR_AND_CONTEXT_PARALLEL_GROUP is not None
+    ), 'tensor and context parallel group is not initialized'
+    return _TENSOR_AND_CONTEXT_PARALLEL_GROUP
+
+
+def get_expert_model_parallel_group():
+    """Get the expert-model-parallel group the caller rank belongs to."""
+    assert (
+        _EXPERT_MODEL_PARALLEL_GROUP is not None
+    ), 'expert model parallel group is not initialized'
+    return _EXPERT_MODEL_PARALLEL_GROUP
+
+
+def get_tensor_and_expert_parallel_group():
+    """Get the tensor- and expert-parallel group the caller rank belongs to."""
+    assert (
+        _TENSOR_AND_EXPERT_PARALLEL_GROUP is not None
+    ), 'tensor and expert parallel group is not initialized'
+    return _TENSOR_AND_EXPERT_PARALLEL_GROUP
+
+
+def get_data_modulo_expert_parallel_group(with_context_parallel=False):
+    """Get the data-modulo-expert-parallel group the caller rank belongs to."""
+    if with_context_parallel:
+        assert (
+            _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP is not None
+        ), 'data modulo expert parallel group with context parallel is not initialized'
+        return _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP
+    else:
+        assert (
+            _DATA_MODULO_EXPERT_PARALLEL_GROUP is not None
+        ), 'data modulo expert parallel group is not initialized'
+        return _DATA_MODULO_EXPERT_PARALLEL_GROUP
+
+
+def get_data_modulo_expert_parallel_group_gloo(with_context_parallel=False):
+    """Get the Gloo data-modulo-expert-parallel group the caller rank belongs to."""
+    if with_context_parallel:
+        assert (
+            _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP_GLOO is not None
+        ), 'data modulo expert parallel group-gloo with context parallel is not initialized'
+        return _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP_GLOO
+    else:
+        assert (
+            _DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO is not None
+        ), 'data modulo expert parallel group-gloo is not initialized'
+        return _DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO
+
+
+def set_expert_model_parallel_world_size(world_size):
+    """Sets the expert-model-parallel world size."""
+    global _MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE
+    _MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE = world_size


def set_tensor_model_parallel_world_size(world_size):
-    """Set the tensor model parallel size"""
+    """Set the tensor-model-parallel size"""
    global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size


def set_pipeline_model_parallel_world_size(world_size):
-    """Set the pipeline model parallel size"""
+    """Set the pipeline-model-parallel size"""
    global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size


def set_virtual_pipeline_model_parallel_world_size(world_size):
-    """Set the pipeline model parallel size"""
+    """Set the pipeline-model-parallel size"""
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size


def get_tensor_model_parallel_world_size():
-    """Return world size for the tensor model parallel group."""
+    """Return world size for the tensor-model-parallel group."""
    global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    if _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None:
        return _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
...
...
@@ -388,33 +1099,49 @@ def get_tensor_model_parallel_world_size():
def get_pipeline_model_parallel_world_size():
    """Return world size for the pipeline-model-parallel group."""
    global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    if _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None:
        return _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE

    pp_group = get_pipeline_model_parallel_group()
    if isinstance(pp_group, list):
        # Implicit assumption that each PP group is the same size.
        sizes = []
        for group in _PIPELINE_GLOBAL_RANKS:
            sizes.append(len(group))
        assert all(x == sizes[0] for x in sizes)
        return torch.distributed.get_world_size(group=pp_group[0])
    else:
        return torch.distributed.get_world_size(group=pp_group)


def set_expert_model_parallel_rank(rank):
    """Set expert-model-parallel rank."""
    global _MPU_EXPERT_MODEL_PARALLEL_RANK
    _MPU_EXPERT_MODEL_PARALLEL_RANK = rank


def set_tensor_model_parallel_rank(rank):
    """Set tensor-model-parallel rank."""
    global _MPU_TENSOR_MODEL_PARALLEL_RANK
    _MPU_TENSOR_MODEL_PARALLEL_RANK = rank


def set_pipeline_model_parallel_rank(rank):
    """Set pipeline-model-parallel rank."""
    global _MPU_PIPELINE_MODEL_PARALLEL_RANK
    _MPU_PIPELINE_MODEL_PARALLEL_RANK = rank


def set_pipeline_model_parallel_split_rank(rank):
    """Set pipeline-model-parallel split rank. DEPRECATED."""
    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = rank


def get_tensor_model_parallel_rank():
    """Return caller's rank for the tensor-model-parallel group."""
    global _MPU_TENSOR_MODEL_PARALLEL_RANK
    if _MPU_TENSOR_MODEL_PARALLEL_RANK is not None:
        return _MPU_TENSOR_MODEL_PARALLEL_RANK
...
...
@@ -422,15 +1149,27 @@ def get_tensor_model_parallel_rank():
def get_pipeline_model_parallel_rank():
    """Return caller's rank for the pipeline-model-parallel group."""
    global _MPU_PIPELINE_MODEL_PARALLEL_RANK
    if _MPU_PIPELINE_MODEL_PARALLEL_RANK is not None:
        return _MPU_PIPELINE_MODEL_PARALLEL_RANK

    rank = torch.distributed.get_rank()
    pp_group = get_pipeline_model_parallel_group()
    if isinstance(pp_group, list):
        # Assume that if the caller exists in multiple PP groups, then it has the same index.
        indices = []
        for group in _PIPELINE_GLOBAL_RANKS:
            for i, r in enumerate(group):
                if r == rank:
                    indices.append(i)
        assert all(x == indices[0] for x in indices)
        return torch.distributed.get_rank(group=pp_group[0])
    else:
        return torch.distributed.get_rank(group=pp_group)


def get_pipeline_model_parallel_split_rank():
    """Return pipeline-model-parallel split rank."""
    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    return _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
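The multi-group branch of get_pipeline_model_parallel_rank above only has to locate the caller's index inside each ranks list and check that the positions agree. A small, hypothetical sketch of that lookup in plain Python (names and rank values are invented for illustration):

# Hypothetical illustration of the multi-group rank lookup used above:
# every pipeline group the rank belongs to must place it at the same index.
def rank_index_in_groups(rank, pipeline_global_ranks):
    indices = []
    for group in pipeline_global_ranks:
        for i, r in enumerate(group):
            if r == rank:
                indices.append(i)
    assert all(i == indices[0] for i in indices), "rank sits at different pipeline positions"
    return indices[0]

# Example: global rank 6 sits at pipeline position 1 in both groups it belongs to.
print(rank_index_in_groups(6, [[2, 6, 10], [4, 6, 8]]))  # -> 1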
...
...
@@ -447,13 +1186,15 @@ def is_pipeline_first_stage(ignore_virtual=False):
def is_pipeline_last_stage(ignore_virtual=False):
    """Return True if in the last pipeline-model-parallel stage, False otherwise."""
    if not ignore_virtual:
        virtual_pipeline_model_parallel_world_size = (
            get_virtual_pipeline_model_parallel_world_size()
        )
        if (
            virtual_pipeline_model_parallel_world_size is not None
            and get_virtual_pipeline_model_parallel_rank()
            != (virtual_pipeline_model_parallel_world_size - 1)
        ):
            return False
    return get_pipeline_model_parallel_rank() == (get_pipeline_model_parallel_world_size() - 1)
...
...
@@ -463,6 +1204,8 @@ def is_rank_in_embedding_group(ignore_virtual=False):
"""Return true if current rank is in embedding group, False otherwise."""
    rank = torch.distributed.get_rank()
    global _EMBEDDING_GLOBAL_RANKS
    if _EMBEDDING_GLOBAL_RANKS is None:
        return False
    if ignore_virtual:
        return rank in _EMBEDDING_GLOBAL_RANKS
    if rank in _EMBEDDING_GLOBAL_RANKS:
...
...
@@ -479,7 +1222,7 @@ def is_rank_in_position_embedding_group():
"""Return true if current rank is in position embedding group, False otherwise."""
    rank = torch.distributed.get_rank()
    global _POSITION_EMBEDDING_GLOBAL_RANKS
    return _POSITION_EMBEDDING_GLOBAL_RANKS is not None and rank in _POSITION_EMBEDDING_GLOBAL_RANKS


def is_pipeline_stage_before_split(rank=None):
...
...
@@ -512,6 +1255,36 @@ def is_pipeline_stage_after_split(rank=None):
    return False


def is_inside_encoder(rank=None):
    """Return True if pipeline stage executes encoder block for a model
    with both encoder and decoder."""
    if get_pipeline_model_parallel_world_size() == 1:
        return True
    if rank is None:
        rank = get_pipeline_model_parallel_rank()
    global _PIPELINE_MODEL_PARALLEL_DECODER_START
    if _PIPELINE_MODEL_PARALLEL_DECODER_START is None:
        return True
    if rank < _PIPELINE_MODEL_PARALLEL_DECODER_START:
        return True
    return False


def is_inside_decoder(rank=None):
    """Return True if pipeline stage executes decoder block for a model
    with both encoder and decoder."""
    if get_pipeline_model_parallel_world_size() == 1:
        return True
    if rank is None:
        rank = get_pipeline_model_parallel_rank()
    global _PIPELINE_MODEL_PARALLEL_DECODER_START
    if _PIPELINE_MODEL_PARALLEL_DECODER_START is None:
        return True
    if rank >= _PIPELINE_MODEL_PARALLEL_DECODER_START:
        return True
    return False


def is_pipeline_stage_at_split():
    """Return true if pipeline stage executes decoder block and next
    stage executes encoder block for a model with both encoder and
...
...
@@ -541,67 +1314,203 @@ def get_virtual_pipeline_model_parallel_world_size():
def get_tensor_model_parallel_src_rank():
    """Calculate the global rank corresponding to the first local rank
    in the tensor model parallel group."""
    assert (
        _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS is not None
    ), "Tensor model parallel group is not initialized"
    return _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS[0]


def get_data_parallel_src_rank(with_context_parallel=False):
    """Calculate the global rank corresponding to the first local rank
    in the data parallel group."""
    if with_context_parallel:
        assert (
            _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP is not None
        ), "Data parallel group with context parallel combined is not initialized"
        return _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP[0]
    else:
        assert _DATA_PARALLEL_GLOBAL_RANKS is not None, "Data parallel group is not initialized"
        return _DATA_PARALLEL_GLOBAL_RANKS[0]


def get_pipeline_model_parallel_first_rank():
    """Return the global rank of the first stage in the current rank's pipeline."""
    assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
    if isinstance(_PIPELINE_GLOBAL_RANKS[0], list):
        # I assume the first rank is the same for all pp groups right now.
        for rank_group in _PIPELINE_GLOBAL_RANKS:
            assert rank_group[0] == _PIPELINE_GLOBAL_RANKS[0][0]
        return _PIPELINE_GLOBAL_RANKS[0][0]
    else:
        return _PIPELINE_GLOBAL_RANKS[0]


def get_pipeline_model_parallel_last_rank():
    """Return the global rank of the last stage in the current rank's pipeline."""
    assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
    last_rank_local = get_pipeline_model_parallel_world_size() - 1
    return _PIPELINE_GLOBAL_RANKS[last_rank_local]


def get_pipeline_model_parallel_next_rank():
    """Return the global rank that follows the caller in the pipeline, for each
    pipeline-parallel group that the rank is part of.
    If it is just part of one group, an int is returned, otherwise a list of ints.
    """
    assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
    rank_in_pipeline = get_pipeline_model_parallel_rank()
    world_size = get_pipeline_model_parallel_world_size()
    if isinstance(_PIPELINE_GLOBAL_RANKS[0], list):
        to_return = []
        for group in _PIPELINE_GLOBAL_RANKS:
            to_return.append(group[(rank_in_pipeline + 1) % world_size])
        return to_return
    else:
        return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]


def get_pipeline_model_parallel_prev_rank():
    """Return the global rank that precedes the caller in the pipeline, for each
    pipeline-parallel group that the rank is part of.
    If it is just part of one group, an int is returned, otherwise a list of ints.
    """
    assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
    rank_in_pipeline = get_pipeline_model_parallel_rank()
    world_size = get_pipeline_model_parallel_world_size()
    if isinstance(_PIPELINE_GLOBAL_RANKS[0], list):
        to_return = []
        for group in _PIPELINE_GLOBAL_RANKS:
            to_return.append(group[(rank_in_pipeline - 1) % world_size])
        return to_return
    else:
        return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]


def get_data_parallel_world_size(with_context_parallel=False):
    """Return world size for the data parallel group."""
    global _MPU_DATA_PARALLEL_WORLD_SIZE
    if _MPU_DATA_PARALLEL_WORLD_SIZE is not None:
        return _MPU_DATA_PARALLEL_WORLD_SIZE
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return torch.distributed.get_world_size(
            group=get_data_parallel_group(with_context_parallel=with_context_parallel)
        )
    else:
        return 0


def set_data_parallel_rank(rank):
    """Set the data parallel rank."""
    global _MPU_DATA_PARALLEL_RANK
    _MPU_DATA_PARALLEL_RANK = rank


def get_data_parallel_rank(with_context_parallel=False):
    """Return caller's rank in the data-parallel group."""
    global _MPU_DATA_PARALLEL_RANK
    if _MPU_DATA_PARALLEL_RANK is not None:
        return _MPU_DATA_PARALLEL_RANK
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return torch.distributed.get_rank(
            group=get_data_parallel_group(with_context_parallel=with_context_parallel)
        )
    else:
        return 0


def get_context_parallel_world_size():
    """Return world size for the context parallel group."""
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return torch.distributed.get_world_size(group=get_context_parallel_group())
    else:
        return 0


def get_context_parallel_rank():
    """Return caller's rank in the context-parallel group."""
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return torch.distributed.get_rank(group=get_context_parallel_group())
    else:
        return 0


def get_tensor_and_context_parallel_world_size():
    """Return world size for the tensor and context-parallel group."""
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return torch.distributed.get_world_size(group=get_tensor_and_context_parallel_group())
    else:
        return 0


def get_tensor_and_context_parallel_rank():
    """Return caller's rank in the joint tensor-model-parallel and context-parallel group."""
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return torch.distributed.get_rank(group=get_tensor_and_context_parallel_group())
    else:
        return 0


def get_expert_model_parallel_world_size():
    """Return world size for the expert-model-parallel group."""
    if _MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE is not None:
        return _MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        tensor_and_expert_parallel_world_size = torch.distributed.get_world_size(
            group=get_tensor_and_expert_parallel_group()
        )
        return tensor_and_expert_parallel_world_size // get_tensor_model_parallel_world_size()
    else:
        return 0


def get_tensor_and_expert_parallel_world_size():
    """Return world size for the expert model parallel group times model parallel group.
    Currently, each expert will also be distributed across TP group by default.
    """
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        tensor_and_expert_parallel_world_size = torch.distributed.get_world_size(
            group=get_tensor_and_expert_parallel_group()
        )
        return tensor_and_expert_parallel_world_size
    else:
        return 0


def get_expert_model_parallel_rank():
    """Return caller's rank in the expert-model-parallel group."""
    if _MPU_EXPERT_MODEL_PARALLEL_RANK is not None:
        return _MPU_EXPERT_MODEL_PARALLEL_RANK
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        tensor_and_expert_parallel_rank = torch.distributed.get_rank(
            group=get_tensor_and_expert_parallel_group()
        )
        return tensor_and_expert_parallel_rank // get_tensor_model_parallel_world_size()
    else:
        return 0


def get_data_modulo_expert_parallel_rank(with_context_parallel=False):
    """Return caller's rank in the data-modulo-expert-parallel group."""
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return torch.distributed.get_rank(
            group=get_data_modulo_expert_parallel_group(with_context_parallel=with_context_parallel)
        )
    else:
        return 0


def get_tensor_and_expert_parallel_rank():
    """Return caller's rank in the joint tensor- and expert-model-parallel group."""
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return torch.distributed.get_rank(group=get_tensor_and_expert_parallel_group())
    else:
        return 0


def _set_global_memory_buffer():
    """Initialize global buffer."""
    global _GLOBAL_MEMORY_BUFFER
    assert _GLOBAL_MEMORY_BUFFER is None, 'global memory buffer is already initialized'
    _GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer()
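Several of the expert-parallel helpers above derive their answers from the joint tensor-and-expert group rather than from a dedicated group. A small sketch of that arithmetic (the numeric values are illustrative only):

# Expert-model-parallel size/rank are recovered from the joint
# tensor-and-expert group by dividing out the tensor-model-parallel size.
tp_world_size = 2            # tensor-model-parallel size (illustrative)
tp_ep_world_size = 8         # size of the joint tensor-and-expert group (illustrative)
tp_ep_rank = 5               # caller's rank inside that joint group (illustrative)

expert_model_parallel_world_size = tp_ep_world_size // tp_world_size  # -> 4
expert_model_parallel_rank = tp_ep_rank // tp_world_size              # -> 2

print(expert_model_parallel_world_size, expert_model_parallel_rank)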
...
...
@@ -619,33 +1528,120 @@ def destroy_global_memory_buffer():
    _GLOBAL_MEMORY_BUFFER = None


def get_all_ranks():
    """Get caller's rank in tensor-model-parallel, data-parallel, context-parallel,
    pipeline-model-parallel and expert-model-parallel groups."""
    ranks = [
        get_tensor_model_parallel_rank(),
        get_data_parallel_rank(),
        get_context_parallel_rank(),
        get_pipeline_model_parallel_rank(),
        get_expert_model_parallel_rank(),
    ]
    return '_'.join(map(lambda x: str(x or 0), ranks))


def get_moe_layer_wise_logging_tracker():
    """Return the MoE layer-wise tracker."""
    global _MOE_LAYER_WISE_LOGGING_TRACKER
    return _MOE_LAYER_WISE_LOGGING_TRACKER


def destroy_model_parallel():
    """Set the groups to None."""
    global _MODEL_PARALLEL_GROUP
    _MODEL_PARALLEL_GROUP = None
    global _MODEL_AND_EXPERT_PARALLEL_GROUP
    _MODEL_AND_EXPERT_PARALLEL_GROUP = None
    global _TENSOR_MODEL_PARALLEL_GROUP
    _TENSOR_MODEL_PARALLEL_GROUP = None
    global _PIPELINE_MODEL_PARALLEL_GROUP
    _PIPELINE_MODEL_PARALLEL_GROUP = None
    global _DATA_PARALLEL_GROUP
    _DATA_PARALLEL_GROUP = None
    global _DATA_PARALLEL_GROUP_WITH_CP
    _DATA_PARALLEL_GROUP_WITH_CP = None
    global _CONTEXT_PARALLEL_GROUP
    _CONTEXT_PARALLEL_GROUP = None
    global _CONTEXT_PARALLEL_GLOBAL_RANKS
    _CONTEXT_PARALLEL_GLOBAL_RANKS = None
    global _EMBEDDING_GROUP
    _EMBEDDING_GROUP = None
    global _POSITION_EMBEDDING_GROUP
    _POSITION_EMBEDDING_GROUP = None
    global _AMAX_REDUCTION_GROUP
    _AMAX_REDUCTION_GROUP = None
    global _TENSOR_AND_DATA_PARALLEL_GROUP
    _TENSOR_AND_DATA_PARALLEL_GROUP = None
    global _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP
    _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP = None
    global _TENSOR_AND_CONTEXT_PARALLEL_GROUP
    _TENSOR_AND_CONTEXT_PARALLEL_GROUP = None
    global _EXPERT_MODEL_PARALLEL_GROUP
    _EXPERT_MODEL_PARALLEL_GROUP = None
    global _TENSOR_AND_EXPERT_PARALLEL_GROUP
    _TENSOR_AND_EXPERT_PARALLEL_GROUP = None
    global _DATA_MODULO_EXPERT_PARALLEL_GROUP
    _DATA_MODULO_EXPERT_PARALLEL_GROUP = None
    global _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP
    _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP = None
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
    global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
    global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
    global _MPU_TENSOR_MODEL_PARALLEL_RANK
    _MPU_TENSOR_MODEL_PARALLEL_RANK = None
    global _MPU_PIPELINE_MODEL_PARALLEL_RANK
    _MPU_PIPELINE_MODEL_PARALLEL_RANK = None
    global _GLOBAL_MEMORY_BUFFER
    _GLOBAL_MEMORY_BUFFER = None
    global _MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE
    _MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE = None
    global _MPU_EXPERT_MODEL_PARALLEL_RANK
    _MPU_EXPERT_MODEL_PARALLEL_RANK = None
    global _DATA_PARALLEL_GROUP_GLOO
    if _DATA_PARALLEL_GROUP_GLOO is not None:
        torch.distributed.destroy_process_group(_DATA_PARALLEL_GROUP_GLOO)
    _DATA_PARALLEL_GROUP_GLOO = None
    global _DATA_PARALLEL_GROUP_WITH_CP_GLOO
    _DATA_PARALLEL_GROUP_WITH_CP_GLOO = None
    global _DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO
    if _DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO is not None:
        torch.distributed.destroy_process_group(_DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO)
    _DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO = None
    global _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP_GLOO
    _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP_GLOO = None
    global _MOE_LAYER_WISE_LOGGING_TRACKER
    _MOE_LAYER_WISE_LOGGING_TRACKER = {}
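For orientation, the query helpers in this file are normally used after the parallel groups have been built; a rough usage sketch, assuming torch.distributed is already initialized (e.g. under torchrun with a matching world size) and that the initialization sizes shown here are illustrative:

# Rough usage sketch of the parallel_state query helpers shown above.
from megatron.core import parallel_state

parallel_state.initialize_model_parallel(
    tensor_model_parallel_size=2,      # illustrative sizes, pick to match the launch
    pipeline_model_parallel_size=2,
)

tp_rank = parallel_state.get_tensor_model_parallel_rank()
pp_rank = parallel_state.get_pipeline_model_parallel_rank()
dp_size = parallel_state.get_data_parallel_world_size()
tag = parallel_state.get_all_ranks()     # e.g. "0_1_0_1_0" (tp, dp, cp, pp, ep ranks joined)

parallel_state.destroy_model_parallel()  # reset every group, as defined above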
megatron/core/pipeline_parallel/__init__.py
View file @
4b097dee
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from .schedules import get_forward_backward_func
megatron/core/pipeline_parallel/p2p_communication.py
View file @
4b097dee
...
...
@@ -13,6 +13,7 @@ from megatron.core.parallel_state import (
    get_pipeline_model_parallel_next_rank,
    get_pipeline_model_parallel_prev_rank,
    get_pipeline_model_parallel_rank,
    get_pipeline_model_parallel_world_size,
)

# Types
...
...
@@ -25,7 +26,7 @@ def _communicate_shapes(tensor_send_next, tensor_send_prev, recv_prev, recv_next
    This is required when the sequence lengths across micro batches
    are not uniform.

    Args:
        tensor_send_next: tensor to send to next rank (no tensor sent if
                          set to None).
        tensor_send_prev: tensor to send to prev rank (no tensor sent if
...
...
@@ -123,39 +124,29 @@ def _batched_p2p_ops(
    tensor_recv_prev: Optional[torch.Tensor],
    tensor_send_next: Optional[torch.Tensor],
    tensor_recv_next: Optional[torch.Tensor],
    group: torch.distributed.ProcessGroup,
    prev_pipeline_rank: int,
    next_pipeline_rank: int,
):
    ops = []
    if tensor_send_prev is not None:
        send_prev_op = torch.distributed.P2POp(
            torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, group
        )
        ops.append(send_prev_op)
    if tensor_recv_prev is not None:
        recv_prev_op = torch.distributed.P2POp(
            torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, group
        )
        ops.append(recv_prev_op)
    if tensor_send_next is not None:
        send_next_op = torch.distributed.P2POp(
            torch.distributed.isend, tensor_send_next, next_pipeline_rank, group
        )
        ops.append(send_next_op)
    if tensor_recv_next is not None:
        recv_next_op = torch.distributed.P2POp(
            torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, group
        )
        ops.append(recv_next_op)
    if len(ops) > 0:
...
...
@@ -171,57 +162,69 @@ def _p2p_ops(
    tensor_recv_prev: Optional[torch.Tensor],
    tensor_send_next: Optional[torch.Tensor],
    tensor_recv_next: Optional[torch.Tensor],
    group: torch.distributed.ProcessGroup,
    prev_pipeline_rank: int,
    next_pipeline_rank: int,
):
    reqs = []
    rank = get_pipeline_model_parallel_rank()
    even_send_odd_recv_group = group
    if get_pipeline_model_parallel_world_size() == 2:
        # Use the global process group for one of the two p2p communications
        # to allow the overlap of the independent communications.
        # Using the global process group is compatible because the pipeline-parallel
        # communications set the source and destination by global rank.
        even_recv_odd_send_group = torch.distributed.group.WORLD
    else:
        even_recv_odd_send_group = group

    if get_pipeline_model_parallel_rank() % 2 == 0:
        if tensor_send_next is not None:
            send_next_req = torch.distributed.isend(
                tensor=tensor_send_next, dst=next_pipeline_rank, group=even_send_odd_recv_group
            )
            reqs.append(send_next_req)
        if tensor_recv_prev is not None:
            recv_prev_req = torch.distributed.irecv(
                tensor=tensor_recv_prev, src=prev_pipeline_rank, group=even_recv_odd_send_group
            )
            reqs.append(recv_prev_req)
        if tensor_send_prev is not None:
            send_prev_req = torch.distributed.isend(
                tensor=tensor_send_prev, dst=prev_pipeline_rank, group=even_send_odd_recv_group
            )
            reqs.append(send_prev_req)
        if tensor_recv_next is not None:
            recv_next_req = torch.distributed.irecv(
                tensor=tensor_recv_next, src=next_pipeline_rank, group=even_recv_odd_send_group
            )
            reqs.append(recv_next_req)
    else:
        if tensor_recv_prev is not None:
            recv_prev_req = torch.distributed.irecv(
                tensor=tensor_recv_prev, src=prev_pipeline_rank, group=even_send_odd_recv_group
            )
            reqs.append(recv_prev_req)
        if tensor_send_next is not None:
            send_next_req = torch.distributed.isend(
                tensor=tensor_send_next, dst=next_pipeline_rank, group=even_recv_odd_send_group
            )
            reqs.append(send_next_req)
        if tensor_recv_next is not None:
            recv_next_req = torch.distributed.irecv(
                tensor=tensor_recv_next, src=next_pipeline_rank, group=even_send_odd_recv_group
            )
            reqs.append(recv_next_req)
        if tensor_send_prev is not None:
            send_prev_req = torch.distributed.isend(
                tensor=tensor_send_prev, dst=prev_pipeline_rank, group=even_recv_odd_send_group
            )
            reqs.append(send_prev_req)
    return reqs
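The even/odd branching above exists so that, for each pair of neighboring pipeline ranks, one side posts its send before the matching receive while the other side does the reverse, so no pair of ranks blocks on unmatched operations. A toy illustration of that ordering in plain Python (communication calls elided, rank count illustrative):

# Toy illustration of the operation ordering used by _p2p_ops: an even rank and
# its odd neighbor issue the same four operations in complementary orders, so
# every isend on one side is paired with an already-posted irecv on the other.
def op_order(rank):
    if rank % 2 == 0:
        return ["send_next", "recv_prev", "send_prev", "recv_next"]
    return ["recv_prev", "send_next", "recv_next", "send_prev"]

for rank in range(4):
    print(rank, op_order(rank))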
...
...
@@ -235,12 +238,12 @@ def _communicate(
    recv_next: bool,
    tensor_shape: Shape,
    config: ModelParallelConfig,
    wait_on_reqs: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Communicate tensors between stages. Used as helper method in other
    communication methods that are used in megatron/schedules.py.

    Args:
        tensor_send_next (torch.Tensor, optional):
            Tensor to send to next rank (no tensor sent if None)
...
...
@@ -270,10 +273,8 @@ def _communicate(
"""
    # Create placeholder tensors for receive in forward and backward directions
    # if needed.
    tensor_recv_prev_func = None
    tensor_recv_next_func = None

    if not config.variable_seq_lengths:
        recv_prev_shape = tensor_shape
...
...
@@ -283,6 +284,22 @@ def _communicate(
            tensor_send_next, tensor_send_prev, recv_prev, recv_next, config
        )

    def create_tensor_recv_prev():
        return torch.empty(
            recv_prev_shape,
            requires_grad=True,
            device=torch.cuda.current_device(),
            dtype=config.pipeline_dtype,
        )

    def create_tensor_recv_next():
        return torch.empty(
            recv_next_shape,
            requires_grad=True,
            device=torch.cuda.current_device(),
            dtype=config.pipeline_dtype,
        )

    if recv_prev:
        if config.pipeline_dtype is None:
            raise RuntimeError("pipeline_dtype must be provided if recv_prev is True")
...
...
@@ -291,12 +308,8 @@ def _communicate(
"tensor_shape must be specified if recv_prev is True. "
"Common tensor_shape is (seq_length, micro_batch_size, hidden_size)"
            )
        tensor_recv_prev_func = create_tensor_recv_prev

    if recv_next:
        if config.pipeline_dtype is None:
            raise RuntimeError("dtype must be provided if recv_next is True")
...
...
@@ -305,12 +318,7 @@ def _communicate(
"tensor_shape must be specified if recv_next is True. "
"Common tensor_shape is (seq_length, micro_batch_size, hidden_size)"
            )
        tensor_recv_next_func = create_tensor_recv_next

    # Send tensors in both the forward and backward directions as appropriate.
    if config.use_ring_exchange_p2p:
...
...
@@ -326,13 +334,49 @@ def _communicate(
    else:
        p2p_func = _p2p_ops

    # Each rank can now be part of several different pipeline parallel groups
    # (specifically, this can occur when encoder tensor parallelism != decoder
    # tensor parallelism, and hence a rank in the encoder is going to feed
    # several different decoder ranks. We therefore have to receive or send tensors
    # from several groups. For convenience, I wrap everything into lists.
    pp_group = get_pipeline_model_parallel_group()
    next_rank = get_pipeline_model_parallel_next_rank()
    prev_rank = get_pipeline_model_parallel_prev_rank()
    if not isinstance(pp_group, list):
        pp_group = [pp_group]
        assert not isinstance(next_rank, list)
        next_rank = [next_rank]
        assert not isinstance(prev_rank, list)
        prev_rank = [prev_rank]

    reqs = []
    tensor_recv_prev_list = []
    tensor_recv_next_list = []

    for group, nr, pr in zip(pp_group, next_rank, prev_rank):
        if tensor_recv_prev_func is not None:
            tensor_recv_prev = tensor_recv_prev_func()
            tensor_recv_prev_list.append(tensor_recv_prev)
        else:
            tensor_recv_prev = None

        if tensor_recv_next_func is not None:
            tensor_recv_next = tensor_recv_next_func()
            tensor_recv_next_list.append(tensor_recv_next)
        else:
            tensor_recv_next = None

        reqs.extend(
            p2p_func(
                tensor_send_prev=tensor_send_prev,
                tensor_recv_prev=tensor_recv_prev,
                tensor_send_next=tensor_send_next,
                tensor_recv_next=tensor_recv_next,
                group=group,
                prev_pipeline_rank=pr,
                next_pipeline_rank=nr,
            )
        )

    if wait_on_reqs and len(reqs) > 0:
        for req in reqs:
...
...
@@ -344,12 +388,27 @@ def _communicate(
        # User should assert that we have a modern enough PyTorch to not need this
        torch.cuda.synchronize()

    def _handle_tensor_list(x):
        """This basically handles all the cases that we expect to see. Either the list is None,
        or it's a singleton (the usual case, since most ranks only belong to one pipeline group),
        or everything returned is None, or everything returned is not None and has to be summed
        together."""
        if len(x) == 0:
            return None
        if len(x) == 1:
            return x[0]
        if all(xx is None for xx in x):
            return None
        return torch.stack(x, dim=0).sum(dim=0, dtype=torch.float32).to(x[0].dtype)

    tensor_recv_prev = _handle_tensor_list(tensor_recv_prev_list)
    tensor_recv_next = _handle_tensor_list(tensor_recv_next_list)

    return tensor_recv_prev, tensor_recv_next, reqs
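When a rank belongs to more than one pipeline group, _handle_tensor_list collapses the per-group receives into a single tensor by summing in float32 and casting back. A small self-contained check of that reduction on toy CPU tensors:

# Toy check of the multi-group reduction performed by _handle_tensor_list.
import torch

received = [torch.ones(2, 3, dtype=torch.bfloat16), 2 * torch.ones(2, 3, dtype=torch.bfloat16)]
combined = torch.stack(received, dim=0).sum(dim=0, dtype=torch.float32).to(received[0].dtype)
print(combined.dtype, combined[0, 0].item())  # torch.bfloat16 3.0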
def recv_forward(tensor_shape: Shape, config: ModelParallelConfig) -> torch.Tensor:
    """Receive tensor from previous rank in pipeline (forward receive).

    See _communicate for argument details.
    """
...
...
megatron/core/pipeline_parallel/schedules.py
View file @
4b097dee
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import contextlib
from typing import Callable, Iterator, List, Optional, Union

import torch
from torch.autograd.variable import Variable
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

from megatron import core
from megatron.core import parallel_state
from megatron.core.enums import ModelType
from megatron.core.pipeline_parallel import p2p_communication
from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler
from megatron.core.utils import (
    drain_embedding_wgrad_compute,
    get_attr_wrapped_model,
    get_model_config,
    get_model_type,
    get_model_xattn,
)

# Types
Shape = Union[List[int], torch.Size]
...
...
@@ -90,6 +95,10 @@ def get_forward_backward_func():
        collect_non_loss_data (optional, bool, default=False): TODO
        first_val_step (bool, optional): Is the first step of the validation phase. Used by
            Transformer Engine modules to only update their fp8 weights on the first
            validation step.

    """
    pipeline_model_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
    if pipeline_model_parallel_size > 1:
...
...
@@ -113,7 +122,7 @@ def deallocate_output_tensor(out, deallocate_pipeline_outputs=False):
        return
    assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__
    assert out._base is None, "counter-productive to free a view of another tensor."
    out.data = torch.empty((1,), device=out.device, dtype=out.dtype)


def custom_backward(output, grad_output):
...
...
@@ -134,7 +143,7 @@ def custom_backward(output, grad_output):
    # Handle scalar output
    if grad_output is None:
        assert output.numel() == 1, "implicit grad requires scalar output."
        grad_output = torch.ones_like(output, memory_format=torch.preserve_format)

    # Call c++ engine [ see torch/csrc/autograd/python_engine.cpp ]
    Variable._execution_engine.run_backward(
...
...
@@ -148,6 +157,17 @@ def custom_backward(output, grad_output):
    )


def set_current_microbatch(model, microbatch_id):
    decoder_exists = True
    decoder = None
    try:
        decoder = get_attr_wrapped_model(model, "decoder")
    except RuntimeError:
        decoder_exists = False
    if decoder_exists and decoder is not None:
        decoder.current_microbatch = microbatch_id


def forward_step(
    forward_step_func,
    data_iterator,
...
...
@@ -158,16 +178,84 @@ def forward_step(
    config,
    collect_non_loss_data=False,
    checkpoint_activations_microbatch=None,
    is_first_microbatch=False,
    current_microbatch=None,
    encoder_decoder_xattn=False,
):
    """Forward step for passed-in model.

    If it is the first stage, the input tensor is obtained from the data_iterator.
    Otherwise, the passed-in input_tensor is used.

    Args:
        forward_step_func (callable):
            The forward step function for the model that takes the
            data iterator as the first argument, and model as the second.
            This user's forward step is expected to output a tuple of two elements:

                1. The output object from the forward step. This output object needs to be a
                   tensor or some kind of collection of tensors. The only hard requirement
                   for this object is that it needs to be acceptable as input into the second
                   function.
                2. A function to reduce (optionally) the output from the forward step. This
                   could be a reduction over the loss from the model, it could be a function
                   that grabs the output from the model and reformats it, or it could be a
                   function that just passes through the model output. This function must have
                   one of the following patterns, and depending on the pattern different things
                   happen internally:

                       a. A tuple of reduced loss and some other data. Note that in this case
                          the first argument is divided by the number of global microbatches,
                          assuming it is a loss, so that the loss is stable as a function of
                          the number of devices the step is split across.
                       b. A triple of reduced loss, number of tokens, and some other data. This
                          is similar to case (a), but the loss is further averaged across the
                          number of tokens in the batch. If the user is not already averaging
                          across the number of tokens, this pattern is useful to use.
                       c. Any arbitrary data the user wants (e.g. a dictionary of tensors, a list
                          of tensors, etc. in the case of inference). To trigger case (c) you need
                          to specify `collect_non_loss_data=True` and you may also want to
                          specify `forward_only=True` in the call to the parent forward_backward
                          function.
        data_iterator (iterator):
            The data iterator.
        model (nn.Module):
            The model to perform the forward step on.
        num_microbatches (int):
            The number of microbatches.
        input_tensor (Tensor or list[Tensor]):
            The input tensor(s) for the forward step.
        forward_data_store (list):
            The list to store the forward data. If you go down path 2.a or
            2.b for the return of your forward reduction function then this will store only the
            final dimension of the output, for example the metadata output by the loss function.
            If you go down the path of 2.c then this will store the entire output of the forward
            reduction function applied to the model output.
        config (object):
            The configuration object.
        collect_non_loss_data (bool, optional):
            Whether to collect non-loss data. This is the path to use if you want to collect
            arbitrary output from the model forward, such as with inference use cases.
            Defaults to False.
        checkpoint_activations_microbatch (int, optional):
            The microbatch to checkpoint activations. Defaults to None.
        is_first_microbatch (bool, optional):
            Whether it is the first microbatch. Defaults to False.
        current_microbatch (int, optional):
            The current microbatch. Defaults to None.

    Returns:
        Tensor or list[Tensor]: The output object(s) from the forward step.
        Tensor: The number of tokens.
    """
    if config.timers is not None:
        config.timers('forward-compute', log_level=2).start()

    if is_first_microbatch and hasattr(model, 'set_is_first_microbatch'):
        model.set_is_first_microbatch()
    if current_microbatch is not None:
        set_current_microbatch(model, current_microbatch)

    unwrap_output_tensor = False
    if not isinstance(input_tensor, list):
        input_tensor = [input_tensor]
...
...
@@ -188,11 +276,20 @@ def forward_step(
            data_iterator, model, checkpoint_activations_microbatch
        )

    num_tokens = torch.tensor(0, dtype=torch.int)
    if parallel_state.is_pipeline_last_stage():
        if not collect_non_loss_data:
            outputs = loss_func(output_tensor)
            if len(outputs) == 3:
                output_tensor, num_tokens, loss_reduced = outputs
                if not config.calculate_per_token_loss:
                    output_tensor /= num_tokens
                    output_tensor /= num_microbatches
            else:
                # preserve legacy loss averaging behavior (ie, over the number of microbatches)
                assert len(outputs) == 2
                output_tensor, loss_reduced = outputs
                output_tensor /= num_microbatches
            forward_data_store.append(loss_reduced)
        else:
            data = loss_func(output_tensor, non_loss_data=True)
...
...
@@ -201,18 +298,32 @@ def forward_step(
    if config.timers is not None:
        config.timers('forward-compute').stop()

    # Set the loss scale for the auxiliary loss of the MoE layer.
    # Since we use a trick to do backward on the auxiliary loss, we need to set the scale
    # explicitly.
    if hasattr(config, 'num_moe_experts') and config.num_moe_experts is not None:
        # Calculate the loss scale based on the grad_scale_func if available, else default to 1.
        loss_scale = (
            config.grad_scale_func(torch.ones(1, device=output_tensor.device))
            if config.grad_scale_func is not None
            else torch.tensor(1.0)
        )
        # Set the loss scale
        MoEAuxLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)

    # If T5 model and in decoder stack, then send encoder_hidden_state
    # downstream as well.
    model_type = get_model_type(model)
    if (
        model_type == ModelType.encoder_and_decoder
        and encoder_decoder_xattn
        and parallel_state.is_inside_decoder()
    ):
        return [output_tensor, input_tensor[-1]], num_tokens

    if unwrap_output_tensor:
        return output_tensor, num_tokens
    return [output_tensor], num_tokens
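The docstring above describes two accepted shapes for the user's reduction function. A hypothetical loss_func following pattern (b), returning the summed loss, the non-padded token count, and logging metadata (the function name, arguments, and metric key are illustrative, not taken from this diff), could look like this:

# Hypothetical loss_func following pattern (b) from the forward_step docstring;
# in practice it is usually bound over the batch's loss mask with functools.partial
# before being returned from the user's forward_step_func.
import torch

def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
    losses = output_tensor.float().view(-1)         # per-token losses
    loss_mask = loss_mask.view(-1).float()          # 1 for real tokens, 0 for padding
    total_loss = torch.sum(losses * loss_mask)      # summed, not averaged
    num_tokens = loss_mask.sum().to(torch.int)      # denominator handled by the schedule
    return total_loss, num_tokens, {'lm loss': total_loss.clone().detach()}

print(loss_func(torch.ones(2, 4), torch.rand(2, 4)))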
def backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config):
...
...
@@ -268,10 +379,11 @@ def backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, c
    # model with encoder and decoder).
    if (
        parallel_state.get_pipeline_model_parallel_world_size() > 1
        and parallel_state.is_pipeline_stage_after_split()
        and model_type == ModelType.encoder_and_decoder
        and len(output_tensor_grad) > 1  # excludes models that lack a skip connection.
    ):
        if output_tensor_grad[1] is not None:
            assert input_tensor_grad[-1] is not None
            input_tensor_grad[-1].add_(output_tensor_grad[1])
    if unwrap_input_tensor_grad:
        input_tensor_grad = input_tensor_grad[0]
...
...
@@ -282,6 +394,13 @@ def backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, c
    return input_tensor_grad


def check_first_val_step(first_val_step, forward_only, cond):
    if (first_val_step is not None) and forward_only:
        return first_val_step and cond
    else:
        return cond


def forward_backward_no_pipelining(
    *,
    forward_step_func,
...
...
@@ -293,6 +412,7 @@ def forward_backward_no_pipelining(
    decoder_seq_length: int = None,  # unused
    forward_only: bool = False,
    collect_non_loss_data: bool = False,
    first_val_step: bool = None,
):
    """Run forward and backward passes with no pipeline parallelism
    (no inter-stage communication).
...
...
@@ -313,10 +433,10 @@ def forward_backward_no_pipelining(
        data_iterator = data_iterator[0]

    config = get_model_config(model)
    if config.timers is not None:
        config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time)

    no_sync_func = config.no_sync_func
    if no_sync_func is None and isinstance(model, torchDDP):
        no_sync_func = model.no_sync
    if no_sync_func is None:
        no_sync_func = contextlib.nullcontext
...
...
@@ -324,9 +444,10 @@ def forward_backward_no_pipelining(
    forward_data_store = []
    input_tensor, output_tensor_grad = None, None
    total_num_tokens = torch.zeros([], dtype=torch.int, device="cuda")
    with no_sync_func():
        for i in range(num_microbatches - 1):
            output_tensor, num_tokens = forward_step(
                forward_step_func,
                data_iterator,
                model,
...
...
@@ -335,13 +456,16 @@ def forward_backward_no_pipelining(
                forward_data_store,
                config,
                collect_non_loss_data,
                is_first_microbatch=check_first_val_step(first_val_step, forward_only, i == 0),
                current_microbatch=i,
            )
            total_num_tokens += num_tokens.item()
            if not forward_only:
                backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config)

    # Run computation for last microbatch out of context handler (want to
    # synchronize gradients).
    output_tensor, num_tokens = forward_step(
        forward_step_func,
        data_iterator,
        model,
...
...
@@ -350,14 +474,68 @@ def forward_backward_no_pipelining(
        forward_data_store,
        config,
        collect_non_loss_data,
        is_first_microbatch=check_first_val_step(first_val_step, forward_only, num_microbatches == 1),
        current_microbatch=num_microbatches - 1,
    )
    total_num_tokens += num_tokens.item()

    if not forward_only:
        backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config)

    if config.finalize_model_grads_func is not None and not forward_only:
        # Finalize model grads (perform full grad all-reduce / reduce-scatter for
        # data parallelism and layernorm all-reduce for sequence parallelism).
        config.finalize_model_grads_func(
            [model], total_num_tokens if config.calculate_per_token_loss else None
        )

    if config.timers is not None:
        config.timers('forward-backward').stop()

    return forward_data_store


def clear_embedding_activation_buffer(config, model):
    if (
        parallel_state.is_pipeline_last_stage(ignore_virtual=True)
        and config.defer_embedding_wgrad_compute
    ):
        if isinstance(model, list):
            embedding_module = get_attr_wrapped_model(
                model[-1], 'post_process', return_model_obj=True
            )
        else:
            embedding_module = get_attr_wrapped_model(model, 'post_process', return_model_obj=True)

        # Need to ensure no stray activations exist in this buffer
        embedding_module.embedding_activation_buffer.clear()

        return embedding_module
    else:
        return None


def finish_embedding_wgrad_compute(config, embedding_module):
    if (
        parallel_state.is_pipeline_last_stage(ignore_virtual=True)
        and config.defer_embedding_wgrad_compute
    ):
        embedding_activation_buffer = embedding_module.embedding_activation_buffer
        grad_output_buffer = embedding_module.grad_output_buffer
        weight = (
            embedding_module.output_layer.weight
            if embedding_module.share_embeddings_and_output_weights
            else embedding_module.shared_embedding_or_output_weight()
        )

        drain_embedding_wgrad_compute(
            config, embedding_activation_buffer, grad_output_buffer, weight
        )
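The total_num_tokens value accumulated above is only consumed when calculate_per_token_loss is set; otherwise the legacy per-microbatch averaging applies. A toy contrast of the two normalizations (values are illustrative only):

# Toy contrast of the two loss-normalization modes referenced above: the legacy
# path divides each (already token-averaged) microbatch loss by the number of
# microbatches, while the per-token path divides the summed loss by the total
# token count gathered across microbatches.
summed_micro_losses = [12.0, 8.0, 10.0]   # per-microbatch loss sums (toy values)
micro_token_counts = [3, 2, 5]            # per-microbatch non-padded token counts (toy values)

num_microbatches = len(summed_micro_losses)
legacy_normalized = sum(
    l / t / num_microbatches for l, t in zip(summed_micro_losses, micro_token_counts)
)
per_token_normalized = sum(summed_micro_losses) / sum(micro_token_counts)

print(round(legacy_normalized, 3), per_token_normalized)  # -> 3.333 3.0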
def forward_backward_pipelining_with_interleaving(
    *,
    forward_step_func,
...
...
@@ -369,6 +547,7 @@ def forward_backward_pipelining_with_interleaving(
    decoder_seq_length: int = None,
    forward_only: bool = False,
    collect_non_loss_data: bool = False,
    first_val_step: bool = None,
):
    """Run interleaved 1F1B schedule (model split into model chunks), with
    communication between pipeline stages as needed.
...
...
@@ -384,14 +563,21 @@ def forward_backward_pipelining_with_interleaving(
    if config.overlap_p2p_comm and config.batch_p2p_comm:
        raise ValueError("Can not use both overlap_p2p_comm and batch_p2p_comm")

    # Needed only when gradients are finalized in M-Core
    if config.finalize_model_grads_func is not None and not forward_only:
        embedding_module = clear_embedding_activation_buffer(config, model)

    if config.timers is not None:
        config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time)

    # Disable async grad reductions
    no_sync_func = config.no_sync_func
    if isinstance(no_sync_func, list):

        def multi_no_sync():
            stack = contextlib.ExitStack()
            for model_chunk_no_sync_func in config.no_sync_func:
                stack.enter_context(model_chunk_no_sync_func())
            return stack

        no_sync_func = multi_no_sync
...
@@ -399,6 +585,19 @@ def forward_backward_pipelining_with_interleaving(
        no_sync_func = contextlib.nullcontext
    no_sync_context = None

    if config.grad_sync_func is not None and not isinstance(config.grad_sync_func, list):
        config.grad_sync_func = [config.grad_sync_func for _ in model]

    if config.param_sync_func is not None and not isinstance(config.param_sync_func, list):
        config.param_sync_func = [config.param_sync_func for _ in model]

    # Disable config.grad_sync_func and config.param_sync_func if only running forward passes.
    # They will be re-enabled at the end of this function.
    grad_sync_func, param_sync_func = None, None
    if forward_only:
        grad_sync_func, param_sync_func = config.grad_sync_func, config.param_sync_func
        config.grad_sync_func, config.param_sync_func = None, None

    def disable_grad_sync():
        """Disable asynchronous grad reductions"""
        nonlocal no_sync_context
...
@@ -420,6 +619,8 @@ def forward_backward_pipelining_with_interleaving(
    input_tensors = [[] for _ in range(len(model))]
    output_tensors = [[] for _ in range(len(model))]
    total_num_tokens = torch.tensor(0, dtype=torch.int).cuda()

    forward_data_store = []
    if not forward_only:
        output_tensor_grads = [[] for _ in range(len(model))]
...
...
@@ -443,6 +644,7 @@ def forward_backward_pipelining_with_interleaving(
    )

    tensor_shape = [seq_length, micro_batch_size, config.hidden_size]
    tensor_shape[0] = tensor_shape[0] // parallel_state.get_context_parallel_world_size()
    if config.sequence_parallel:
        tensor_shape[0] = tensor_shape[0] // parallel_state.get_tensor_model_parallel_world_size()
...
...
@@ -482,8 +684,8 @@ def forward_backward_pipelining_with_interleaving(
    # Synchronize params for first two model chunks
    if config.param_sync_func is not None:
        config.param_sync_func[0](model[0].parameters())
        config.param_sync_func[1](model[1].parameters())

    def get_model_chunk_id(microbatch_id, forward):
        """Helper method to get the model chunk ID given the iteration number."""
...
...
@@ -493,10 +695,18 @@ def forward_backward_pipelining_with_interleaving(
            model_chunk_id = num_model_chunks - model_chunk_id - 1
        return model_chunk_id

    def get_microbatch_id_in_model_chunk(iteration_id, forward):
        """Helper method to get the microbatch_id within model chunk given the iteration number."""
        assert forward
        iteration_group_id = iteration_id // (pipeline_parallel_size * num_model_chunks)
        microbatch_id_in_model_chunk = (iteration_group_id * pipeline_parallel_size) + (
            iteration_id % pipeline_parallel_size
        )
        return microbatch_id_in_model_chunk

    def is_first_microbatch_for_model_chunk(microbatch_id: int) -> bool:
        """Check if an iteration is the first for a model chunk."""
        microbatch_group_size = pipeline_parallel_size * num_model_chunks
        num_microbatch_groups = total_num_microbatches // microbatch_group_size
        microbatch_group_id = microbatch_id // microbatch_group_size
        microbatch_id_in_group = microbatch_id % microbatch_group_size
        if microbatch_group_id == 0:
...
...
@@ -515,7 +725,7 @@ def forward_backward_pipelining_with_interleaving(
        else:
            return False

    def forward_step_helper(microbatch_id, current_microbatch, checkpoint_activations_microbatch):
        """Helper method to run forward step with model split into chunks
        (run set_virtual_pipeline_model_parallel_rank() before calling
        forward_step())."""
...
...
@@ -535,14 +745,17 @@ def forward_backward_pipelining_with_interleaving(
            ):
                param_sync_chunk_id = get_model_chunk_id(param_sync_microbatch_id, forward=True) + 1
                if 1 < param_sync_chunk_id < num_model_chunks:
                    config.param_sync_func[param_sync_chunk_id](
                        model[param_sync_chunk_id].parameters()
                    )

        # forward step
        if parallel_state.is_pipeline_first_stage():
            if len(input_tensors[model_chunk_id]) == len(output_tensors[model_chunk_id]):
                input_tensors[model_chunk_id].append(None)
        input_tensor = input_tensors[model_chunk_id][-1]

        output_tensor, num_tokens = forward_step(
            forward_step_func,
            data_iterator[model_chunk_id],
            model[model_chunk_id],
...
...
@@ -552,9 +765,16 @@ def forward_backward_pipelining_with_interleaving(
            config,
            collect_non_loss_data,
            checkpoint_activations_microbatch,
            check_first_val_step(
                first_val_step, forward_only, is_first_microbatch_for_model_chunk(microbatch_id)
            ),
            current_microbatch=current_microbatch,
        )
        output_tensors[model_chunk_id].append(output_tensor)

        nonlocal total_num_tokens
        total_num_tokens += num_tokens.item()

        # if forward-only, no need to save tensors for a backward pass
        if forward_only:
            input_tensors[model_chunk_id].pop()
...
@@ -596,7 +816,7 @@ def forward_backward_pipelining_with_interleaving(
            ):
                grad_sync_chunk_id = get_model_chunk_id(grad_sync_microbatch_id, forward=False)
                enable_grad_sync()
                config.grad_sync_func[grad_sync_chunk_id](model[grad_sync_chunk_id].parameters())
                synchronized_model_chunks.add(grad_sync_chunk_id)
            disable_grad_sync()
...
...
@@ -624,7 +844,10 @@ def forward_backward_pipelining_with_interleaving(
        else:
            checkpoint_activations_microbatch = None

        current_microbatch = get_microbatch_id_in_model_chunk(k, forward=True)
        output_tensor = forward_step_helper(k, current_microbatch, checkpoint_activations_microbatch)

        # Determine if tensor should be received from previous stage.
        next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True)
...
...
@@ -651,16 +874,15 @@ def forward_backward_pipelining_with_interleaving(
            recv_next = True
            if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                recv_next = False

            (input_tensor, output_tensor_grad) = (
                p2p_communication.send_forward_backward_recv_forward_backward(
                    output_tensor,
                    input_tensor_grad,
                    recv_prev=recv_prev,
                    recv_next=recv_next,
                    tensor_shape=tensor_shape,
                    config=config,
                )
            )
            output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
        else:
...
...
@@ -687,15 +909,14 @@ def forward_backward_pipelining_with_interleaving(
            if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
                recv_next = False

            (output_tensor_grad, bwd_wait_handles) = (
                p2p_communication.send_backward_recv_backward(
                    input_tensor_grad,
                    recv_next=recv_next,
                    tensor_shape=tensor_shape,
                    config=config,
                    overlap_p2p_comm=True,
                )
            )

            output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
...
...
@@ -717,6 +938,7 @@ def forward_backward_pipelining_with_interleaving(
        else:
            checkpoint_activations_microbatch = None

        current_microbatch = get_microbatch_id_in_model_chunk(forward_k, forward=True)
        if config.overlap_p2p_comm:
            if fwd_wait_handles is not None:
                for req in fwd_wait_handles:
...
...
@@ -724,7 +946,9 @@ def forward_backward_pipelining_with_interleaving(
            deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)

            output_tensor = forward_step_helper(
                forward_k, current_microbatch, checkpoint_activations_microbatch
            )

            # Determine if current stage has anything to send in either direction,
            # otherwise set tensor to None.
...
...
@@ -802,7 +1026,9 @@ def forward_backward_pipelining_with_interleaving(
            )
        else:  # no p2p overlap
            output_tensor = forward_step_helper(
                forward_k, current_microbatch, checkpoint_activations_microbatch
            )

        # Backward pass.
        backward_k = k
...
...
@@ -855,16 +1081,15 @@ def forward_backward_pipelining_with_interleaving(
                recv_prev = False

            # Communicate tensors.
            (input_tensor, output_tensor_grad) = (
                p2p_communication.send_forward_backward_recv_forward_backward(
                    output_tensor,
                    input_tensor_grad,
                    recv_prev=recv_prev,
                    recv_next=recv_next,
                    tensor_shape=tensor_shape,
                    config=config,
                )
            )
            deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs)
...
...
@@ -902,16 +1127,33 @@ def forward_backward_pipelining_with_interleaving(
                )
            )

    # Launch any remaining grad reductions.
    enable_grad_sync()
    if config.grad_sync_func is not None:
        for model_chunk_id in range(num_model_chunks):
            if model_chunk_id not in synchronized_model_chunks:
                config.grad_sync_func[model_chunk_id](model[model_chunk_id].parameters())
                synchronized_model_chunks.add(model_chunk_id)

    if config.finalize_model_grads_func is not None and not forward_only:
        # If defer_embedding_wgrad_compute is enabled we need to do the
        # weight gradient GEMM's here.
        finish_embedding_wgrad_compute(config, embedding_module)

        # Finalize model grads (perform full grad all-reduce / reduce-scatter for
        # data parallelism, layernorm all-reduce for sequence parallelism, and
        # embedding all-reduce for pipeline parallelism).
        config.finalize_model_grads_func(
            model, total_num_tokens if config.calculate_per_token_loss else None
        )

    # Restore config.grad_sync_func and config.param_sync_func.
    if forward_only:
        config.grad_sync_func, config.param_sync_func = grad_sync_func, param_sync_func

    if config.timers is not None:
        config.timers('forward-backward').stop()

    return forward_data_store
...
...
@@ -924,17 +1166,23 @@ def get_tensor_shapes(
     micro_batch_size: int,
     decoder_seq_length: int,
     config,
+    encoder_decoder_xattn: bool,
 ):
-    # Determine right tensor sizes (based on position of rank with respect to split
-    # rank) and model size.
-    # Send two tensors if model is T5 and rank is in decoder stage:
-    #     first tensor is decoder (pre-transpose),
-    #     second tensor is encoder (post-transpose).
-    # If model is T5 and rank is at the boundary:
-    #     send one tensor (post-transpose from encoder).
-    # Otherwise, send one tensor (pre-transpose).
+    # Determine right tensor sizes (based on position of rank with
+    # respect to split rank) and model size.
+    # Send two tensors if model decoder requires the encoder's output
+    # (via cross-attention) and rank is in decoder stage.
+    #     first tensor is decoder.
+    #     second tensor is encoder.
+    # If model has an encoder & decoder and rank is at the boundary:
+    #     send one tensor.
+    # Otherwise, send one tensor.
     tensor_shapes = []

     seq_length = seq_length // parallel_state.get_context_parallel_world_size()
+    if model_type == ModelType.encoder_and_decoder:
+        decoder_seq_length = decoder_seq_length // parallel_state.get_context_parallel_world_size()
+
     if config.sequence_parallel:
         seq_length = seq_length // parallel_state.get_tensor_model_parallel_world_size()
         if model_type == ModelType.encoder_and_decoder:
...
...
@@ -943,12 +1191,14 @@ def get_tensor_shapes(
     )

     if model_type == ModelType.encoder_and_decoder:
-        if parallel_state.is_pipeline_stage_before_split(rank):
+        if parallel_state.is_inside_encoder(rank):
             tensor_shapes.append((seq_length, micro_batch_size, config.hidden_size))
-        else:
+        elif encoder_decoder_xattn:
             tensor_shapes.append((decoder_seq_length, micro_batch_size, config.hidden_size))
             tensor_shapes.append((seq_length, micro_batch_size, config.hidden_size))
-    else:
+        else:
+            tensor_shapes.append((decoder_seq_length, micro_batch_size, config.hidden_size))
+    else:  # model_type == ModelType.encoder_or_decoder
         tensor_shapes.append((seq_length, micro_batch_size, config.hidden_size))
     return tensor_shapes
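Note on the hunks above: the divisions shrink the nominal sequence length to what one rank actually sends, first across context-parallel ranks, then across tensor-parallel ranks when sequence parallelism is on. A quick arithmetic sketch with hypothetical sizes:

# Illustrative only: per-rank tensor shape implied by the divisions above.
seq_length = 4096          # hypothetical global sequence length
micro_batch_size = 1
hidden_size = 8192
context_parallel_size = 2
tensor_parallel_size = 4
sequence_parallel = True

seq_length //= context_parallel_size      # 4096 -> 2048 per context-parallel rank
if sequence_parallel:
    seq_length //= tensor_parallel_size   # 2048 -> 512 per tensor-parallel rank

print((seq_length, micro_batch_size, hidden_size))   # (512, 1, 8192)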
...
...
@@ -976,7 +1226,7 @@ def recv_backward(tensor_shapes, config):
 def send_forward(output_tensors, tensor_shapes, config):
     if not isinstance(output_tensors, list):
         output_tensors = [output_tensors]
-    for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes):
+    for output_tensor, tensor_shape in zip(output_tensors, tensor_shapes):
         if tensor_shape is None:
             continue
         p2p_communication.send_forward(output_tensor, config)
...
...
@@ -985,7 +1235,7 @@ def send_forward(output_tensors, tensor_shapes, config):
 def send_backward(input_tensor_grads, tensor_shapes, config):
     if not isinstance(input_tensor_grads, list):
         input_tensor_grads = [input_tensor_grads]
-    for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes):
+    for input_tensor_grad, tensor_shape in zip(input_tensor_grads, tensor_shapes):
         if tensor_shape is None:
             continue
         p2p_communication.send_backward(input_tensor_grad, config)
...
...
@@ -995,7 +1245,7 @@ def send_forward_recv_backward(output_tensors, tensor_shapes, config):
     if not isinstance(output_tensors, list):
         output_tensors = [output_tensors]
     output_tensor_grads = []
-    for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes):
+    for output_tensor, tensor_shape in zip(output_tensors, tensor_shapes):
         if tensor_shape is None:
             output_tensor_grads.append(None)
             continue
...
...
@@ -1010,7 +1260,7 @@ def send_backward_recv_forward(input_tensor_grads, tensor_shapes, config):
     if not isinstance(input_tensor_grads, list):
         input_tensor_grads = [input_tensor_grads]
     input_tensors = []
-    for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes):
+    for input_tensor_grad, tensor_shape in zip(input_tensor_grads, tensor_shapes):
         if tensor_shape is None:
             input_tensors.append(None)
             continue
...
...
@@ -1032,11 +1282,10 @@ def forward_backward_pipelining_without_interleaving(
     decoder_seq_length: int = None,
     forward_only: bool = False,
     collect_non_loss_data: bool = False,
     first_val_step: bool = None,
 ):
     """Run non-interleaved 1F1B schedule, with communication between pipeline
-    stages.
-    Returns dictionary with losses if the last stage, empty dict otherwise."""
+    stages. Returns dictionary with losses if the last stage, empty dict otherwise."""

     if isinstance(model, list):
         assert (
...
...
@@ -1055,10 +1304,15 @@ def forward_backward_pipelining_without_interleaving(
             "Non-interleaved pipeline parallelism does not support overlapping p2p communication"
         )

+    # Needed only when gradients are finalized in M-Core
+    if config.finalize_model_grads_func is not None and not forward_only:
+        embedding_module = clear_embedding_activation_buffer(config, model)
+
     if config.timers is not None:
         config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time)

     # Disable async grad reductions
     no_sync_func = config.no_sync_func
     if no_sync_func is None and isinstance(model, torchDDP):
         no_sync_func = model.no_sync
     if no_sync_func is None:
         no_sync_func = contextlib.nullcontext
     no_sync_context = None
...
...
@@ -1101,6 +1355,7 @@ def forward_backward_pipelining_without_interleaving(
         max_outstanding_backprops = num_warmup_microbatches + 1

     model_type = get_model_type(model)
+    encoder_decoder_xattn = get_model_xattn(model)

     rank = parallel_state.get_pipeline_model_parallel_rank()
     recv_tensor_shapes = get_tensor_shapes(
...
...
@@ -1110,6 +1365,7 @@ def forward_backward_pipelining_without_interleaving(
         micro_batch_size=micro_batch_size,
         decoder_seq_length=decoder_seq_length,
         config=config,
+        encoder_decoder_xattn=encoder_decoder_xattn,
     )
     send_tensor_shapes = get_tensor_shapes(
         rank=rank,
...
...
@@ -1118,11 +1374,14 @@ def forward_backward_pipelining_without_interleaving(
         micro_batch_size=micro_batch_size,
         decoder_seq_length=decoder_seq_length,
         config=config,
+        encoder_decoder_xattn=encoder_decoder_xattn,
     )

     # Input, output tensors only need to be saved when doing backward passes
     input_tensors = None
     output_tensors = None
+    total_num_tokens = torch.tensor(0, dtype=torch.int).cuda()
+
     if not forward_only:
         input_tensors = []
         output_tensors = []
...
...
@@ -1140,7 +1399,7 @@ def forward_backward_pipelining_without_interleaving(
             checkpoint_activations_microbatch = None

         input_tensor = recv_forward(recv_tensor_shapes, config)
-        output_tensor = forward_step(
+        output_tensor, num_tokens = forward_step(
             forward_step_func,
             data_iterator,
             model,
...
...
@@ -1150,8 +1409,12 @@ def forward_backward_pipelining_without_interleaving(
             config,
             collect_non_loss_data,
             checkpoint_activations_microbatch,
             check_first_val_step(first_val_step, forward_only, i == 0),
+            current_microbatch=i,
+            encoder_decoder_xattn=encoder_decoder_xattn,
         )
         send_forward(output_tensor, send_tensor_shapes, config)
+        total_num_tokens += num_tokens.item()

         if not forward_only:
             input_tensors.append(input_tensor)
...
...
@@ -1176,7 +1439,7 @@ def forward_backward_pipelining_without_interleaving(
         else:
             checkpoint_activations_microbatch = None

-        output_tensor = forward_step(
+        output_tensor, num_tokens = forward_step(
             forward_step_func,
             data_iterator,
             model,
...
...
@@ -1186,7 +1449,13 @@ def forward_backward_pipelining_without_interleaving(
             config,
             collect_non_loss_data,
             checkpoint_activations_microbatch,
             check_first_val_step(
                 first_val_step, forward_only, (i == 0) and (num_warmup_microbatches == 0)
             ),
+            current_microbatch=i + num_warmup_microbatches,
+            encoder_decoder_xattn=encoder_decoder_xattn,
         )
+        total_num_tokens += num_tokens.item()

         if forward_only:
             send_forward(output_tensor, send_tensor_shapes, config)
...
...
@@ -1209,6 +1478,12 @@ def forward_backward_pipelining_without_interleaving(
             input_tensor = input_tensors.pop(0)
             output_tensor = output_tensors.pop(0)

+            # Enable grad sync for the last microbatch in the batch if the full
+            # backward pass completes in the 1F1B stage.
+            if num_warmup_microbatches == 0 and last_iteration:
+                if config.grad_sync_func is None or rank == 0:
+                    enable_grad_sync()
+
             input_tensor_grad = backward_step(
                 input_tensor, output_tensor, output_tensor_grad, model_type, config
             )
...
...
@@ -1245,10 +1520,26 @@ def forward_backward_pipelining_without_interleaving(
                 send_backward(input_tensor_grad, recv_tensor_shapes, config)

-    # Launch any remaining grad reductions
-    if no_sync_context is not None:
-        enable_grad_sync()
-        if config.grad_sync_func is not None:
-            config.grad_sync_func(model.parameters())
+    # Launch any remaining grad reductions.
+    if no_sync_context is not None:
+        enable_grad_sync()
+        if config.grad_sync_func is not None:
+            config.grad_sync_func(model.parameters())
+
+    if config.finalize_model_grads_func is not None and not forward_only:
+        # If defer_embedding_wgrad_compute is enabled we need to do the
+        # weight gradient GEMM's here.
+        finish_embedding_wgrad_compute(config, embedding_module)
+
+        # Finalize model grads (perform full grad all-reduce / reduce-scatter for
+        # data parallelism, layernorm all-reduce for sequence parallelism, and
+        # embedding all-reduce for pipeline parallelism).
+        config.finalize_model_grads_func(
+            [model], total_num_tokens if config.calculate_per_token_loss else None
+        )

     if config.timers is not None:
         config.timers('forward-backward').stop()

     return forward_data_store
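Note on the schedules above: both now accumulate total_num_tokens across microbatches and pass it to config.finalize_model_grads_func when calculate_per_token_loss is enabled. An illustrative sketch (not the actual finalize implementation) of why the token count matters when microbatches hold unequal numbers of tokens:

# Illustrative only: averaging a loss sum per token instead of per microbatch,
# which is what the total_num_tokens plumbing above enables. Numbers are hypothetical.
loss_sum = 120.0
tokens_per_microbatch = [300, 280, 310, 290]

total_num_tokens = sum(tokens_per_microbatch)
per_microbatch_avg = loss_sum / len(tokens_per_microbatch)   # ignores unequal token counts
per_token_avg = loss_sum / total_num_tokens                  # correct per-token normalization

print(per_microbatch_avg, per_token_avg)   # 30.0 vs. ~0.102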
megatron/core/requirements.txt
View file @ 4b097dee
-torch
\ No newline at end of file
+torch
+packaging
megatron/core/ssm/__init__.py
0 → 100644
View file @ 4b097dee
megatron/core/ssm/mamba_block.py
0 → 100644
View file @ 4b097dee
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Some of this code was adopted from https://github.com/state-spaces/mamba/
# This source code is licensed under the Apache license found in the
# LICENSE file in the root directory of this source tree.
import math
from dataclasses import dataclass
from functools import partial
from typing import Union

from torch import Tensor, nn

from megatron.core import parallel_state
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding
from megatron.core.extensions.transformer_engine import TENorm
from megatron.core.ssm.mamba_hybrid_layer_allocation import Symbols as LayerSymbols
from megatron.core.ssm.mamba_hybrid_layer_allocation import allocate_layers
from megatron.core.tensor_parallel import get_cuda_rng_tracker
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.utils import sharded_state_dict_default
from megatron.core.utils import make_viewless_tensor


# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(
    module,
    n_layer,
    initializer_range=0.02,  # Now only used for embedding layer.
    rescale_prenorm_residual=True,
    n_residuals_per_layer=1,  # Change to 2 if we have MLP
):
    with get_cuda_rng_tracker().fork():
        if isinstance(module, nn.Linear):
            if not getattr(module.weight, "_no_reinit", False):
                nn.init.normal_(module.weight, std=initializer_range)
            if module.bias is not None:
                if not getattr(module.bias, "_no_reinit", False):
                    nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, std=initializer_range)

        for name, p in module.named_parameters():
            if name in ["in_proj.weight", "x_proj.weight", "conv1d.weight", "out_proj.weight"]:
                nn.init.kaiming_uniform_(p, a=math.sqrt(5))

        if rescale_prenorm_residual:
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
            #   > A modified initialization which accounts for the accumulation on the
            #   > residual path with model depth. Scale
            #   > the weights of residual layers at initialization by a factor of
            #   > 1/√N where N is the # of residual layers.
            #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
            #
            # Reference (Megatron-LM):
            # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            for name, p in module.named_parameters():
                if name in ["out_proj.weight", "fc2.weight"]:
                    # Special Scaled Initialization
                    nn.init.normal_(
                        p,
                        mean=0.0,
                        std=initializer_range / math.sqrt(n_residuals_per_layer * n_layer),
                    )
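# Illustrative sketch (editor's note, not part of this file): the rescaled init
# above shrinks the std of residual-output weights by 1/sqrt(n_residuals_per_layer
# * n_layer), per the GPT-2 scheme referenced in the comment. With hypothetical
# values:
#
#     initializer_range = 0.02
#     n_layer = 56                                              # e.g. a 56-layer hybrid stack
#     scaled_std = initializer_range / math.sqrt(1 * n_layer)   # ~0.00267, vs. 0.02 elsewhere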
@dataclass
class MambaStackSubmodules:
    """
    A class for the module specs for the MambaStack.
    """

    mamba_layer: Union[ModuleSpec, type] = IdentityOp
    attention_layer: Union[ModuleSpec, type] = IdentityOp
    mlp_layer: Union[ModuleSpec, type] = IdentityOp


class MambaStack(MegatronModule):
    """
    Constructor for the MambaStack class.

    Args:
        config (TransformerConfig): the transformer configuration
        submodules (MambaStackSubmodules): the submodules for the stack
        mamba_ssm_ngroups (int, optional): the number of groups for the
            MAMBA SSM. Defaults to 8.
        residual_in_fp32 (bool, optional): whether to do residual connections
            in fp32. Defaults to False.
        pre_process (bool, optional): whether to include an embedding layer.
            Defaults to True.
        hybrid_attention_ratio (float, optional): the target ratio of attention layers to
            total layers. Defaults to 0.0.
        hybrid_mlp_ratio (float, optional): the target ratio of mlp layers to total
            layers. Defaults to 0.0.
        hybrid_override_pattern (str, optional): the hybrid layer pattern to override
            with. Defaults to None.
        post_layer_norm (bool, optional): whether to include a final layer norm.
            Defaults to True.
        post_process (bool, optional): whether to include an output layer.
            Defaults to True.
        device (optional): the device to use. Defaults to None.
        dtype (optional): the data type to use. Defaults to None.
    """

    def __init__(
        self,
        config: TransformerConfig,
        submodules: MambaStackSubmodules,
        mamba_ssm_ngroups: int = 8,
        residual_in_fp32=False,
        pre_process: bool = True,
        hybrid_attention_ratio: float = 0.0,
        hybrid_mlp_ratio: float = 0.0,
        hybrid_override_pattern: str = None,
        post_layer_norm: bool = True,
        post_process: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__(config=config)
        self.residual_in_fp32 = residual_in_fp32
        self.pre_process = pre_process
        self.post_layer_norm = post_layer_norm
        self.post_process = post_process

        # Required for pipeline parallel schedules
        self.input_tensor = None

        self.hybrid_attention_ratio = hybrid_attention_ratio
        self.hybrid_mlp_ratio = hybrid_mlp_ratio
        self.hybrid_override_pattern = hybrid_override_pattern

        layer_type_list = allocate_layers(
            self.config.num_layers,
            self.hybrid_attention_ratio,
            self.hybrid_mlp_ratio,
            self.hybrid_override_pattern,
        )

        pp_layer_offset = 0
        if parallel_state.get_pipeline_model_parallel_world_size() > 1:
            pp_layer_offset, layer_type_list = self._select_layers_for_pipeline_parallel(
                layer_type_list
            )

        self.layers = nn.ModuleList()
        for i, layer_type in enumerate(layer_type_list):
            if layer_type == LayerSymbols.MAMBA:
                layer = build_module(
                    submodules.mamba_layer,
                    config=self.config,
                    mamba_ssm_ngroups=mamba_ssm_ngroups,
                    residual_in_fp32=residual_in_fp32,
                    layer_number=i + 1 + pp_layer_offset,
                )
            elif layer_type == LayerSymbols.ATTENTION:
                # Transformer layers apply their own pp_layer_offset
                layer = build_module(
                    submodules.attention_layer, config=self.config, layer_number=i + 1
                )
            elif layer_type == LayerSymbols.MLP:
                # Transformer layers apply their own pp_layer_offset
                layer = build_module(submodules.mlp_layer, config=self.config, layer_number=i + 1)
            else:
                assert True, "unexpected layer_type"
            self.layers.append(layer)

        # Required for activation recomputation
        self.num_layers_per_pipeline_rank = len(self.layers)

        if self.post_process and self.post_layer_norm:
            # Final layer norm before output.
            self.final_norm = TENorm(
                config=self.config,
                hidden_size=self.config.hidden_size,
                eps=self.config.layernorm_epsilon,
            )

        self.apply(partial(_init_weights, n_layer=self.config.num_layers))

    def _select_layers_for_pipeline_parallel(self, layer_type_list):
        pipeline_rank = parallel_state.get_pipeline_model_parallel_rank()
        num_layers_per_pipeline_rank = (
            self.config.num_layers // parallel_state.get_pipeline_model_parallel_world_size()
        )

        assert parallel_state.get_virtual_pipeline_model_parallel_world_size() is None, (
            "The Mamba hybrid model does not currently support "
            "virtual/interleaved pipeline parallelism"
        )

        offset = pipeline_rank * num_layers_per_pipeline_rank
        selected_list = layer_type_list[offset : offset + num_layers_per_pipeline_rank]

        return offset, selected_list

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
        """
        Allocate inference cache for each layer.

        Args:
            batch_size (int): The batch size to use for inference.
            max_seqlen (int): The maximum sequence length to use
                for inference.
            dtype (optional): The data type to use for allocation.
                Defaults to the data type of the model.
        """
        return {
            i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype)
            for i, layer in enumerate(self.layers)
        }

    def set_input_tensor(self, input_tensor: Tensor):
        """Set input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism the input from the previous
        stage comes from communication, not from the input, so the
        model's forward_step_func won't have it. This function is thus
        used by internal code to bypass the input provided by the
        forward_step_func"""
        self.input_tensor = input_tensor

    def forward(
        self,
        hidden_states: Tensor,
        attention_mask: Tensor,
        inference_params=None,
        rotary_pos_emb: Tensor = None,
    ):
        """
        Forward function of the MambaStack class.

        It either returns the Loss values if labels are given or the
        final hidden units

        Args:
            hidden_states (Tensor): the input tensor.
            attention_mask (Tensor): the attention mask.
            inference_params (InferenceParams): the inference parameters.
            rotary_pos_emb (Tensor, optional): the rotary positional embeddings.
                Defaults to None.

        Returns:
            Tensor: the output tensor.
        """
        if not self.pre_process:
            # See set_input_tensor()
            hidden_states = self.input_tensor

        if inference_params:
            # NOTE(bnorick): match InferenceParams attributes for
            # mamba_ssm.utils.generation.InferenceParams,
            # this hack supports eval
            inference_params.max_seqlen = inference_params.max_sequence_length
            inference_params.seqlen_offset = inference_params.sequence_len_offset

        for layer in self.layers:
            hidden_states = layer(
                hidden_states,
                attention_mask,
                inference_params=inference_params,
                rotary_pos_emb=rotary_pos_emb,
            )

            # The attention layer (currently a simplified transformer layer)
            # outputs a tuple of (hidden_states, context). Context is intended
            # for cross-attention, and is not needed in our model.
            if isinstance(hidden_states, tuple):
                hidden_states = hidden_states[0]

        # Final layer norm.
        if self.post_process and self.post_layer_norm:
            hidden_states = self.final_norm(hidden_states)

        # Ensure that the tensor passed between pipeline parallel stages is
        # viewless. See related notes in TransformerBlock and TransformerLayer
        output = make_viewless_tensor(
            inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True
        )

        return hidden_states

    def sharded_state_dict(
        self, prefix: str = '', sharded_offsets: tuple = (), metadata: dict = None
    ) -> ShardedStateDict:
        """
        Returns a sharded state dictionary for the current object.

        This function constructs a sharded state dictionary by iterating over the layers
        in the current object, computing the sharded state dictionary for each layer,
        and combining the results into a single dictionary.

        Parameters:
            prefix (str): The prefix to use for the state dictionary keys.
            sharded_offsets (tuple): The sharded offsets to use for the state dictionary.
            metadata (dict): Additional metadata to use when computing the sharded state dictionary.

        Returns:
            dict: The sharded state dictionary for the current object.
        """
        sharded_state_dict = {}
        layer_prefix = f'{prefix}layers.'

        for local_layer_idx, layer in enumerate(self.layers):
            global_layer_offset = layer.layer_number - 1  # self.layer_number starts at 1
            state_dict_prefix = (
                f'{layer_prefix}{local_layer_idx}.'  # module list index in MambaBlock
            )
            sharded_prefix = f'{layer_prefix}{global_layer_offset}.'
            sharded_pp_offset = []

            layer_sharded_state_dict = layer.sharded_state_dict(
                state_dict_prefix, sharded_pp_offset, metadata
            )
            replace_prefix_for_sharding(layer_sharded_state_dict, state_dict_prefix, sharded_prefix)

            sharded_state_dict.update(layer_sharded_state_dict)

        # Add modules other than self.layers
        for name, module in self.named_children():
            if not module is self.layers:
                sharded_state_dict.update(
                    sharded_state_dict_default(
                        module, f'{prefix}{name}.', sharded_offsets, metadata
                    )
                )

        return sharded_state_dict
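Note on sharded_state_dict above: it remaps each layer's local nn.ModuleList index to its global layer number (layer.layer_number - 1) so that checkpoint keys are consistent across pipeline ranks. A small sketch of that key remapping with hypothetical values:

# Illustrative only: local-to-global layer key remapping as done above.
# Hypothetical setup: pipeline rank 1 of 2, 8 total layers, 4 layers per rank.
prefix = 'decoder.'
layer_prefix = f'{prefix}layers.'
layers_per_rank = 4
pp_rank = 1

for local_layer_idx in range(layers_per_rank):
    global_layer_offset = pp_rank * layers_per_rank + local_layer_idx
    state_dict_prefix = f'{layer_prefix}{local_layer_idx}.'     # key produced locally
    sharded_prefix = f'{layer_prefix}{global_layer_offset}.'    # key written to the checkpoint
    print(state_dict_prefix, '->', sharded_prefix)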
megatron/core/ssm/mamba_hybrid_layer_allocation.py
0 → 100644
View file @ 4b097dee
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import logging

if __name__ != "__main__":
    from megatron.core.utils import log_single_rank
else:
    from typing import Any

    def log_single_rank(logger: logging.Logger, *args: Any, rank: int = 0, **kwargs: Any):
        print(*args[1:], **kwargs)


logger = logging.getLogger(__name__)


class Symbols:
    MAMBA = 'M'
    ATTENTION = '*'
    MLP = '-'
    VALID = {MAMBA, ATTENTION, MLP}


def _allocate_auto(
    total_layers_count: int, target_attention_ratio: float, target_mlp_ratio: float
) -> list:
    # First, allocate attention (evenly spaced, starting and ending with mamba)
    attention_layers_count: int = round(total_layers_count * target_attention_ratio)
    mamba_layers_count: int = total_layers_count - attention_layers_count
    mamba_sections_count: int = attention_layers_count + 1
    mamba_section_length: float = mamba_layers_count / mamba_sections_count

    layer_type_list = [Symbols.MAMBA] * total_layers_count
    x: float = mamba_section_length
    for l in range(total_layers_count):
        if x < 0.5:
            layer_type_list[l] = Symbols.ATTENTION
            x += mamba_section_length
        else:
            x -= 1

    # Next, allocate mlp
    # (evenly distributed, but right-justified, not replacing attention)
    mlp_layers_count: int = round(total_layers_count * target_mlp_ratio)
    if mlp_layers_count > 0:
        mamba_layers_count -= mlp_layers_count
        mamba_to_mlp_ratio: float = mamba_layers_count / mlp_layers_count

        x: float = mamba_to_mlp_ratio
        for l in range(total_layers_count):
            if layer_type_list[l] == Symbols.MAMBA:
                if x < 0.5:
                    layer_type_list[l] = Symbols.MLP
                    x += mamba_to_mlp_ratio
                else:
                    x -= 1

    return layer_type_list


def _allocate_override(total_layers_count: int, override_pattern: str) -> list:
    layer_type_list = list(override_pattern)
    override_pattern_length = len(layer_type_list)
    if override_pattern_length != total_layers_count:
        raise ValueError(
            "The hybrid override pattern is the wrong "
            f"length: got {override_pattern_length}, expected "
            f"{total_layers_count}"
        )
    for l in layer_type_list:
        if l not in Symbols.VALID:
            raise ValueError(
                f"In hybrid override pattern, '{l}' is not " f"one of {Symbols.VALID}"
            )
    return layer_type_list


def _layer_counts_match(a: list, b: list) -> bool:
    for s in Symbols.VALID:
        if a.count(s) != b.count(s):
            return False
    return True


def allocate_layers(
    total_layers_count: int,
    target_attention_ratio: float,
    target_mlp_ratio: float,
    override_pattern: str = None,
) -> list:
    assert total_layers_count > 0
    assert target_attention_ratio >= 0.0 and target_attention_ratio <= 1.0
    assert target_mlp_ratio >= 0.0 and target_mlp_ratio <= 1.0
    assert target_attention_ratio + target_mlp_ratio <= 1.0
    # Note: target_mamba_ratio = 1.0 - target_attention_ratio - target_mlp_ratio

    layer_type_list = _allocate_auto(total_layers_count, target_attention_ratio, target_mlp_ratio)

    if override_pattern is not None:
        layer_type_list_override = _allocate_override(total_layers_count, override_pattern)
        log_single_rank(logger, logging.INFO, "Using hybrid override pattern")
        if (target_attention_ratio > 0.0 or target_mlp_ratio > 0.0) and not _layer_counts_match(
            layer_type_list_override, layer_type_list
        ):
            raise ValueError(
                "The number of each type of layer in the override "
                "pattern must match the number in the overridden "
                "pattern."
            )
        if layer_type_list_override == layer_type_list:
            log_single_rank(
                logger, logging.INFO, "The override pattern matches the overridden pattern"
            )
        else:
            log_single_rank(logger, logging.INFO, "Warning: overriding pattern A with pattern B")
            log_single_rank(logger, logging.INFO, f"A: {''.join(layer_type_list)}")
            log_single_rank(logger, logging.INFO, f"B: {''.join(layer_type_list_override)}")
        layer_type_list = layer_type_list_override

    if target_attention_ratio > 0.0 or target_mlp_ratio > 0.0 or override_pattern is not None:
        actual_attention_layers_count = layer_type_list.count(Symbols.ATTENTION)
        actual_attention_ratio = actual_attention_layers_count / total_layers_count
        actual_mlp_layers_count = layer_type_list.count(Symbols.MLP)
        actual_mlp_ratio = actual_mlp_layers_count / total_layers_count
        allocation_string = ''.join(layer_type_list)
        log_single_rank(
            logger,
            logging.INFO,
            f"Hybrid allocation ({Symbols.MAMBA} is mamba, "
            f"{Symbols.ATTENTION} is attention, "
            f"{Symbols.MLP} is mlp):",
        )
        log_single_rank(logger, logging.INFO, allocation_string)
        log_single_rank(
            logger,
            logging.INFO,
            f"{actual_attention_layers_count} attention layers in "
            f"{total_layers_count} total layers.",
        )
        log_single_rank(
            logger,
            logging.INFO,
            f"Target attention ratio: {target_attention_ratio:.2f}. "
            f"Actual attention ratio: {actual_attention_ratio:.2f}.",
        )
        log_single_rank(
            logger,
            logging.INFO,
            f"{actual_mlp_layers_count} mlp layers in " f"{total_layers_count} total layers.",
        )
        log_single_rank(
            logger,
            logging.INFO,
            f"Target mlp ratio: {target_mlp_ratio:.2f}. "
            f"Actual mlp ratio: {actual_mlp_ratio:.2f}.",
        )

    return layer_type_list


if __name__ == "__main__":
    test_cases = [
        # (10, 0.2, 0.0),
        # (48, 0.0, 0.0), # will not print anything
        # (48, 0.1, 0.0),
        # (48, 0.3, 0.0),
        # (48, 0.5, 0.0),
        # (48, 0.6, 0.0),
        # (48, 0.7, 0.0),
        # (10, 0.0, 0.1),
        # (10, 0.0, 0.3),
        # (10, 0.0, 0.5),
        # (10, 0.1, 0.1),
        # (10, 0.2, 0.2),
        # (10, 0.3, 0.3),
        # (10, 0.5, 0.5),
        # (48, 0.2, 0.3),
        # (48, 0.5, 0.2),
        # (48, 0.5, 0.2, "MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-"),
        # (48, 0.25, 0.25, "MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-"),
        # (48, 0.25, 0.25, "MM-*MM-*MM*-MM*-MM*-MM*-M*M-M*M-M*M-M*M-*MM-*MM-"),
        # (48, 0.0, 0.2, "MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-"),
        # (48, 0.2, 0.0, "MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-"),
        # (48, 0.0, 0.0, "MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-"),
        # (48, 0.5, 0.5),
        # (10, 0.3, 0.2, "MMM*-*M*M-"),
        # (10, 0.3, 0.2, "MM*M-*M*M-"),
        (9, 0.0, 0.0, "M*-M*-M*-"),
        (9, 0.0, 0.0, "MMMMMMMMM"),
    ]
    for t in test_cases:
        print("")
        allocate_layers(*t)
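A usage sketch for allocate_layers as defined above (layer counts and ratios are hypothetical). With an override pattern, the target ratios are only checked for consistency when they are non-zero:

# Illustrative only: hybrid layer allocation with hypothetical sizes.
from megatron.core.ssm.mamba_hybrid_layer_allocation import Symbols, allocate_layers

# Auto allocation: 48 layers, ~25% attention ('*'), ~25% MLP ('-'), rest Mamba ('M').
pattern = allocate_layers(48, 0.25, 0.25)
print(''.join(pattern))

# Explicit override: the pattern string must contain exactly one symbol per layer.
pattern = allocate_layers(8, 0.0, 0.0, override_pattern="MM*-MM*-")
assert pattern.count(Symbols.ATTENTION) == 2 and pattern.count(Symbols.MLP) == 2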
megatron/core/ssm/mamba_layer.py
0 → 100644
View file @ 4b097dee
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Some of this code was adopted from https://github.com/state-spaces/mamba/
# This source code is licensed under the Apache license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Union

import torch
from torch import Tensor

from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_config import TransformerConfig


@dataclass
class MambaLayerSubmodules:
    norm: Union[ModuleSpec, type] = IdentityOp
    mixer: Union[ModuleSpec, type] = IdentityOp
    mamba_bda: Union[ModuleSpec, type] = IdentityOp


class MambaLayer(MegatronModule):
    def __init__(
        self,
        config: TransformerConfig,
        submodules: MambaLayerSubmodules,
        mamba_ssm_ngroups=8,
        layer_number: int = 1,
        residual_in_fp32=False,
    ):
        """
        Top level Mamba Layer
        """
        super().__init__(config)
        self.config = config
        self.layer_number = layer_number
        self.residual_in_fp32 = residual_in_fp32
        self.hidden_dropout = config.hidden_dropout
        self.mixer = build_module(
            submodules.mixer,
            self.config,
            d_model=self.config.hidden_size,
            ngroups=mamba_ssm_ngroups,
            layer_number=layer_number,
        )
        self.norm = build_module(submodules.norm, self.config, self.config.hidden_size)
        self.mamba_bda = build_module(submodules.mamba_bda)
        self.bias_dropout_add_exec_handler = torch.enable_grad

    def forward(
        self,
        hidden_states: Tensor,
        attention_mask: Tensor,  # Not used in MambaLayer
        inference_params=None,
        rotary_pos_emb: Tensor = None,  # Not used in MambaLayer
    ):
        residual = hidden_states
        if self.residual_in_fp32:
            residual = residual.to(torch.float32)

        hidden_states = hidden_states.to(dtype=self.config.params_dtype)
        hidden_states = self.norm(hidden_states)

        mixer_out_with_bias = self.mixer(hidden_states, inference_params=inference_params)

        with self.bias_dropout_add_exec_handler():
            hidden_states = self.mamba_bda(self.training, self.config.bias_dropout_fusion)(
                mixer_out_with_bias, residual, self.hidden_dropout
            )

        return hidden_states

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype)
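Note on MambaLayer.forward above: it is a pre-norm residual block, normalize, run the mixer, then apply a fused bias-dropout-residual-add. A plain-PyTorch sketch of the same dataflow, with a toy linear layer standing in for the Mamba mixer (shapes are hypothetical):

# Illustrative only: the pre-norm residual pattern used by MambaLayer.forward,
# written unfused and with a stand-in mixer.
import torch
from torch import nn

seq_len, batch, hidden_size = 8, 2, 16           # [s, b, h] layout, hypothetical sizes
norm = nn.LayerNorm(hidden_size)
mixer = nn.Linear(hidden_size, hidden_size)      # stand-in for the SSM mixer
dropout = nn.Dropout(p=0.1)

hidden_states = torch.randn(seq_len, batch, hidden_size)
residual = hidden_states                         # optionally kept in fp32 (residual_in_fp32)

out = mixer(norm(hidden_states))                 # norm -> mixer
hidden_states = residual + dropout(out)          # bias-dropout-add, unfused
print(hidden_states.shape)                       # torch.Size([8, 2, 16])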