OpenDAS / Megatron-LM · Commits

Commit 3bec6514, authored Dec 09, 2024 by xingjinliang
Commit message: Upgrade Megatron to v0.10

Parent: 3aca1415
Changes: 371

Showing 11 changed files with 4657 additions and 29 deletions (+4657, -29)
Changed files shown in this excerpt (11 of the 371 changed files):

  megatron/core/models/vision/vit_layer_specs.py      +95    -0
  megatron/core/num_microbatches_calculator.py        +508   -0
  megatron/core/optimizer/__init__.py                 +459   -0
  megatron/core/optimizer/clip_grads.py               +220   -0
  megatron/core/optimizer/distrib_optimizer.py        +1822  -0
  megatron/core/optimizer/grad_scaler.py              +49    -27
  megatron/core/optimizer/optimizer.py                +1069  -0
  megatron/core/optimizer/optimizer_config.py         +116   -0
  megatron/core/optimizer_param_scheduler.py          +297   -0
  megatron/core/package_info.py                       +2     -2
  megatron/core/packed_seq_params.py                  +20    -0

To preserve performance, only 371 of 371+ files are displayed.
megatron/core/models/vision/vit_layer_specs.py (new file, mode 0 → 100755)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from megatron.core.extensions.transformer_engine import (
    TEDotProductAttention,
    TELayerNormColumnParallelLinear,
    TERowParallelLinear,
)
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules

try:
    import apex  # pylint: disable=unused-import

    from megatron.core.fusions.fused_layer_norm import FusedLayerNorm

    HAVE_APEX = True
    LNImpl = FusedLayerNorm
except ImportError:
    import warnings

    from megatron.core.transformer.torch_norm import WrappedTorchNorm

    warnings.warn('Apex is not installed. Falling back to Torch Norm')
    LNImpl = WrappedTorchNorm


# Use this spec to use lower level Transformer Engine modules (required for fp8 training)
def get_vit_layer_with_transformer_engine_spec() -> ModuleSpec:
    '''
    Returns ViT layer spec with Transformer Engine layers
    '''
    mlp = _get_mlp_module_spec(use_te=True)
    return ModuleSpec(
        module=TransformerLayer,
        submodules=TransformerLayerSubmodules(
            self_attention=ModuleSpec(
                module=SelfAttention,
                params={"attn_mask_type": AttnMaskType.no_mask},
                submodules=SelfAttentionSubmodules(
                    linear_qkv=TELayerNormColumnParallelLinear,
                    core_attention=TEDotProductAttention,
                    linear_proj=TERowParallelLinear,
                ),
            ),
            self_attn_bda=get_bias_dropout_add,
            pre_mlp_layernorm=IdentityOp,
            mlp=mlp,
            mlp_bda=get_bias_dropout_add,
        ),
    )


def get_vit_layer_with_local_spec() -> ModuleSpec:
    '''
    Returns ViT layer spec with Mcore local layers
    '''
    mlp = _get_mlp_module_spec(use_te=False)
    return ModuleSpec(
        module=TransformerLayer,
        submodules=TransformerLayerSubmodules(
            input_layernorm=LNImpl,
            self_attention=ModuleSpec(
                module=SelfAttention,
                params={"attn_mask_type": AttnMaskType.causal},
                submodules=SelfAttentionSubmodules(
                    linear_qkv=ColumnParallelLinear,
                    core_attention=DotProductAttention,
                    linear_proj=RowParallelLinear,
                ),
            ),
            self_attn_bda=get_bias_dropout_add,
            pre_mlp_layernorm=LNImpl,
            mlp=mlp,
            mlp_bda=get_bias_dropout_add,
        ),
    )


# Helper function to get module spec for MLP/MoE
def _get_mlp_module_spec(use_te: bool = True) -> ModuleSpec:
    # Dense MLP w/ or w/o TE modules.
    return ModuleSpec(
        module=MLP,
        submodules=MLPSubmodules(
            linear_fc1=TELayerNormColumnParallelLinear if use_te else ColumnParallelLinear,
            linear_fc2=TERowParallelLinear if use_te else RowParallelLinear,
        ),
    )
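The two specs above differ only in where normalization lives (fused into TELayerNormColumnParallelLinear in the TE spec versus an explicit LNImpl input_layernorm in the local spec) and in which attention and linear implementations they wire in. As a rough usage sketch, not part of this commit, a caller might select between them as follows; the helper name choose_vit_layer_spec is made up for illustration.

from megatron.core.models.vision.vit_layer_specs import (
    get_vit_layer_with_local_spec,
    get_vit_layer_with_transformer_engine_spec,
)


def choose_vit_layer_spec(use_te: bool):
    # The TE spec is required for fp8 training; the local spec relies only on
    # Mcore-native layers plus the Apex/Torch layer-norm fallback.
    if use_te:
        return get_vit_layer_with_transformer_engine_spec()
    return get_vit_layer_with_local_spec()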
megatron/core/num_microbatches_calculator.py (new file, mode 0 → 100755)
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Megatron Core number of microbatches calculators."""

import logging
from abc import ABC, abstractmethod
from typing import List, Optional, Union

logger = logging.getLogger(__name__)

# TODO: global_var merge into mcore?
_GLOBAL_NUM_MICROBATCHES_CALCULATOR: Union[
    'ConstantNumMicroBatchesCalculator', 'RampupBatchsizeNumMicroBatchesCalculator'
] = None


def get_num_microbatches() -> int:
    """Get number of microbatches."""
    return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get()


def get_current_global_batch_size() -> int:
    """Get current global batch size."""
    return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get_current_global_batch_size()


def get_micro_batch_size() -> int:
    """Get micro batch size."""
    return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get_micro_batch_size()


def get_current_running_global_batch_size() -> int:
    """Get current running global batch size, taking into account that the number of DP replicas
    might be incompatible with the true global batch size if `decrease_batch_size_if_needed` is
    True."""
    return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get_current_running_global_batch_size()


def update_num_microbatches(
    consumed_samples: int, consistency_check: bool = True, verbose: bool = False
) -> None:
    """Update number of microbatches.

    Args:
        consumed_samples (int): Number of samples consumed.
        consistency_check (bool, optional): Option to check current schedule's consistency.
            Defaults to True.
        verbose (bool, optional): Option to control logging. Defaults to False.
    """
    _GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples, consistency_check, verbose)


def unset_num_microbatches_calculator():
    """Unset microbatches calculator.

    Useful for multiple runs. See `tests/unit_tests/ckpt_converter/test_ckpt_converter.py`
    for an example.
    """
    global _GLOBAL_NUM_MICROBATCHES_CALCULATOR
    _GLOBAL_NUM_MICROBATCHES_CALCULATOR = None


def init_num_microbatches_calculator(
    rank: int,
    rampup_batch_size: Optional[List[int]],
    global_batch_size: int,
    micro_batch_size: int,
    data_parallel_size: int,
    decrease_batch_size_if_needed: bool = False,
) -> None:
    """Initialize number of microbatches calculator. Supporting backward compatibility.

    Args:
        rank (int): Rank of the GPU, only rank 0 will log the information.
        rampup_batch_size (Optional[List[int]]): Rampup batch size, should be in format of
            [start_global_batch_size, batch_size_increment, ramup_samples].
        global_batch_size (int): Global batch size for the model.
        micro_batch_size (int): Micro batch size at initialization.
        data_parallel_size (int): Data parallel size.
        decrease_batch_size_if_needed (bool, optional): If true, scale down batch size to ensure
            divisibility by DP size * microbatch size. Defaults to False.
    """
    _configure_global_num_microbatches_calculator(
        rank,
        rampup_batch_size,
        global_batch_size,
        micro_batch_size,
        data_parallel_size,
        decrease_batch_size_if_needed,
        init=True,
    )


def destroy_num_microbatches_calculator():
    """Destroy number of microbatches calculator."""
    global _GLOBAL_NUM_MICROBATCHES_CALCULATOR
    _GLOBAL_NUM_MICROBATCHES_CALCULATOR = None


def reconfigure_num_microbatches_calculator(
    rank: int,
    rampup_batch_size: Optional[List[int]],
    global_batch_size: int,
    micro_batch_size: int,
    data_parallel_size: int,
    decrease_batch_size_if_needed: bool = False,
) -> None:
    """Reconfigure number of microbatches calculator. Supporting backward compatibility.

    Args:
        rank (int): Rank of the GPU, only rank 0 will log the information.
        rampup_batch_size (Optional[List[int]]): Rampup batch size, should be in format of
            [start_global_batch_size, batch_size_increment, ramup_samples].
        global_batch_size (int): Global batch size for the model.
        micro_batch_size (int): Micro batch size at initialization.
        data_parallel_size (int): Data parallel size.
        decrease_batch_size_if_needed (bool, optional): If true, scale down batch size to ensure
            divisibility by DP size * microbatch size. Defaults to False.
    """
    _configure_global_num_microbatches_calculator(
        rank,
        rampup_batch_size,
        global_batch_size,
        micro_batch_size,
        data_parallel_size,
        decrease_batch_size_if_needed,
        init=False,
    )


def _configure_global_num_microbatches_calculator(
    rank: int,
    rampup_batch_size: Optional[List[int]],
    global_batch_size: int,
    micro_batch_size: int,
    data_parallel_size: int,
    decrease_batch_size_if_needed: bool = False,
    init: bool = False,
) -> None:
    """Configure number of microbatches calculator. Can be used for initialization and
    reconfiguration.

    Args:
        rank (int): Rank of the GPU, only rank 0 will log the information.
        rampup_batch_size (Optional[List[int]]): Rampup batch size, should be in format of
            [start_global_batch_size, batch_size_increment, ramup_samples].
        global_batch_size (int): Global batch size for the model.
        micro_batch_size (int): Micro batch size at initialization.
        data_parallel_size (int): Data parallel size.
        decrease_batch_size_if_needed (bool, optional): If true, scale down batch size to ensure
            divisibility by DP size * microbatch size. Defaults to False.
        init (bool, optional): If true, initialize the calculator. Defaults to False.
    """
    global _GLOBAL_NUM_MICROBATCHES_CALCULATOR

    if init:
        assert (
            _GLOBAL_NUM_MICROBATCHES_CALCULATOR is None
        ), 'num microbatches calculator is already initialized.'

    _GLOBAL_NUM_MICROBATCHES_CALCULATOR = _build_num_microbatches_calculator(
        rank,
        rampup_batch_size,
        global_batch_size,
        micro_batch_size,
        data_parallel_size,
        decrease_batch_size_if_needed,
    )


def _build_num_microbatches_calculator(
    rank: int,
    rampup_batch_size: Optional[List[int]],
    global_batch_size: int,
    micro_batch_size: int,
    data_parallel_size: int,
    decrease_batch_size_if_needed: bool,
) -> Union['ConstantNumMicroBatchesCalculator', 'RampupBatchsizeNumMicroBatchesCalculator']:
    """Build number of microbatches calculator. Internal helper method.

    Args:
        rank (int): Rank of the GPU, only rank 0 will log the information.
        rampup_batch_size (Optional[List[int]]): Rampup batch size, should be in format of
            [start_global_batch_size, batch_size_increment, ramup_samples].
        global_batch_size (int): Global batch size for the model.
        micro_batch_size (int): Micro batch size at initialization.
        data_parallel_size (int): Data parallel size.
        decrease_batch_size_if_needed (bool): If true, scale down batch size to ensure
            divisibility by DP size * microbatch size.
    """
    # Constant batch size.
    if rampup_batch_size is None:
        num_microbatches_calculator = ConstantNumMicroBatchesCalculator(
            global_batch_size,
            micro_batch_size,
            data_parallel_size,
            decrease_batch_size_if_needed,
            rank,
        )
        if rank == 0:
            logger.info(
                f'setting number of microbatches to constant {num_microbatches_calculator.get()}'
            )
    # Batch size ramp up.
    else:
        assert len(rampup_batch_size) == 3, (
            'expected the following '
            'format: --rampup-batch-size <start batch size> '
            '<batch size increment> <ramp-up samples>'
        )
        start_global_batch_size = int(rampup_batch_size[0])
        batch_size_increment = int(rampup_batch_size[1])
        ramup_samples = int(rampup_batch_size[2])
        if rank == 0:
            logger.info(
                f'will use batch size rampup starting from global batch size '
                f'{start_global_batch_size} to global batch size {global_batch_size} with batch '
                f'size increments {batch_size_increment} over {ramup_samples} samples.'
            )
        num_microbatches_calculator = RampupBatchsizeNumMicroBatchesCalculator(
            global_batch_size,
            micro_batch_size,
            data_parallel_size,
            decrease_batch_size_if_needed,
            rank,
            start_global_batch_size,
            batch_size_increment,
            ramup_samples,
        )

    return num_microbatches_calculator


def _round(batch_size: int, divisor: int) -> int:
    """Round `batch_size` down to nearest batch size divisible by `divisor`."""
    return (batch_size // divisor) * divisor


class NumMicroBatchesCalculator(ABC):
    """Base class for number of microbatches calculator."""

    def __init__(self) -> None:
        self.num_micro_batches = None
        self.current_global_batch_size = None
        self.micro_batch_size = None
        self.current_running_global_batch_size = None

    def get(self) -> int:
        """Get number of microbatches."""
        return self.num_micro_batches

    def get_current_global_batch_size(self) -> int:
        """Get current global batch size."""
        return self.current_global_batch_size

    def get_micro_batch_size(self) -> int:
        """Get micro batch size."""
        return self.micro_batch_size

    def get_current_running_global_batch_size(self) -> int:
        """Get current running global batch size. If decrease_batch_size_if_needed is False,
        this just equals global batch size."""
        return self.current_running_global_batch_size

    @abstractmethod
    def update(self, consumed_samples, consistency_check, verbose=False) -> None:
        """Update number of microbatches depending on batch size rampup."""
        pass


class ConstantNumMicroBatchesCalculator(NumMicroBatchesCalculator):
    """Calculator of number of microbatches with constant global batch size.

    Args:
        global_batch_size (int): Global batch size.
        micro_batch_size (int): Micro batch size.
        data_parallel_size (int): Data parallel size.
        decrease_batch_size_if_needed (bool): If true, decrease batch size to ensure
            divisibility by DP size * microbatch size (if needed).
        rank (int): Rank (to determine whether logging should be performed).
    """

    def __init__(
        self,
        global_batch_size: int,
        micro_batch_size: int,
        data_parallel_size: int,
        decrease_batch_size_if_needed: bool,
        rank: int,
    ) -> None:

        micro_batch_times_data_parallel_size = micro_batch_size * data_parallel_size
        if decrease_batch_size_if_needed:
            running_global_batch_size = _round(
                global_batch_size, micro_batch_times_data_parallel_size
            )
            assert running_global_batch_size % micro_batch_times_data_parallel_size == 0
            if rank == 0:
                logger.info(
                    f'decreasing batch size from {global_batch_size} to '
                    f'{running_global_batch_size} to keep divisibility by '
                    f'micro_batch_size={micro_batch_size} * '
                    f'data_parallel_size={data_parallel_size}'
                )
            self.num_micro_batches = (
                running_global_batch_size // micro_batch_times_data_parallel_size
            )
        else:
            assert global_batch_size % micro_batch_times_data_parallel_size == 0, (
                'global batch size ({}) is not divisible by micro batch size ({})'
                ' times data parallel size ({})'.format(
                    global_batch_size, micro_batch_size, data_parallel_size
                )
            )
            running_global_batch_size = global_batch_size
            self.num_micro_batches = global_batch_size // micro_batch_times_data_parallel_size
        assert (
            self.num_micro_batches >= 1
        ), 'number of microbatches should be at least 1, got {}.'.format(self.num_micro_batches)

        self.current_global_batch_size = global_batch_size
        self.current_running_global_batch_size = running_global_batch_size
        self.micro_batch_size = micro_batch_size

    def update(self, consumed_samples, consistency_check, verbose=False) -> None:
        pass


class RampupBatchsizeNumMicroBatchesCalculator(NumMicroBatchesCalculator):
    """Calculator of number of microbatches with batch size rampup.

    Over `steps = (global-batch-size - start-batch-size) / batch_size_increment` increments,
    the batch size grows from start-batch-size to global-batch-size, using
    rampup-samples / steps samples per increment.

    Args:
        global_batch_size (int): Global batch size post rampup.
        micro_batch_size (int): Micro batch size.
        data_parallel_size (int): Data parallel size.
        decrease_batch_size_if_needed (bool): If true, decrease batch size to ensure
            divisibility by DP size * microbatch size (if needed).
        rank (int): Rank (to determine whether logging should be performed).
        start_global_batch_size (int): Global batch size to start with.
        batch_size_increment (int): Global batch size increments.
        ramup_samples (int): Number of samples to use to ramp up the global batch size
            from `start_global_batch_size` to `global_batch_size`.
    """

    def __init__(
        self,
        global_batch_size: int,
        micro_batch_size: int,
        data_parallel_size: int,
        decrease_batch_size_if_needed: bool,
        rank: int,
        start_global_batch_size: int,
        batch_size_increment: int,
        ramup_samples: int,
    ) -> None:
        assert global_batch_size > 0, 'global batch size should be positive, got {}.'.format(
            global_batch_size
        )
        assert start_global_batch_size > 0, 'start batch size should be positive, got {}.'.format(
            start_global_batch_size
        )
        assert batch_size_increment > 0, 'batch size increment should be positive, got {}.'.format(
            batch_size_increment
        )
        assert ramup_samples >= 0, 'ramp-up samples should be non-negative, got {}.'.format(
            ramup_samples
        )

        self.global_batch_size = global_batch_size
        self.micro_batch_size = micro_batch_size
        self.data_parallel_size = data_parallel_size
        self.decrease_batch_size_if_needed = decrease_batch_size_if_needed
        self.rank = rank
        self.start_global_batch_size = start_global_batch_size
        self.batch_size_increment = batch_size_increment
        self.ramup_samples = ramup_samples

        self.micro_batch_times_data_parallel_size = self.micro_batch_size * self.data_parallel_size
        assert self.micro_batch_times_data_parallel_size > 0

        self.current_global_batch_size = None

        diff_batch_size = self.global_batch_size - self.start_global_batch_size
        assert diff_batch_size >= 0, (
            'expected global batch size to be greater than or equal to start batch size, '
            f'got {self.global_batch_size} and {self.start_global_batch_size}'
        )
        assert diff_batch_size % batch_size_increment == 0, (
            'expected '
            f'global batch size interval ({diff_batch_size}) to be divisible by global batch '
            f'size increment ({batch_size_increment})'
        )

        num_increments = diff_batch_size // self.batch_size_increment
        self.rampup_samples_per_increment = self.ramup_samples / num_increments

        # Initialize number of microbatches.
        self.update(0, consistency_check=False, verbose=True)

    def update(self, consumed_samples: int, consistency_check: bool, verbose: bool = False) -> None:
        """Update number of microbatches.

        Args:
            consumed_samples (int): Number of samples consumed.
            consistency_check (bool): Option to check current schedule's consistency.
            verbose (bool, optional): Option to control logging. Defaults to False.
        """
        # Update current global batch size.
        global_batch_size_changed = False
        old_current_global_batch_size = self.current_global_batch_size
        if consumed_samples > self.ramup_samples:
            self.current_global_batch_size = self.global_batch_size
        else:
            steps = int(consumed_samples / self.rampup_samples_per_increment)
            self.current_global_batch_size = (
                self.start_global_batch_size + steps * self.batch_size_increment
            )
            assert self.current_global_batch_size <= self.global_batch_size
        if old_current_global_batch_size != self.current_global_batch_size:
            global_batch_size_changed = True
        if self.rank == 0 and global_batch_size_changed and verbose:
            if old_current_global_batch_size is None:
                logger.info(f'setting initial batch size to {self.current_global_batch_size}')
            else:
                logger.info(
                    f'ramping up batch size from {old_current_global_batch_size} to '
                    f'{self.current_global_batch_size}'
                )

        # Check consistency of the current global batch size.
        if consistency_check and not self.decrease_batch_size_if_needed:
            assert (
                self.current_global_batch_size % self.micro_batch_times_data_parallel_size == 0
            ), (
                'current global '
                'batch size ({}) is not divisible by micro-batch-size ({}) times '
                'data parallel size ({})'.format(
                    self.current_global_batch_size, self.micro_batch_size, self.data_parallel_size
                )
            )

        if (
            self.decrease_batch_size_if_needed
            and self.current_global_batch_size % self.micro_batch_times_data_parallel_size != 0
        ):
            self.current_running_global_batch_size = _round(
                self.current_global_batch_size, self.micro_batch_times_data_parallel_size
            )
            if self.rank == 0 and global_batch_size_changed and verbose:
                logger.info(
                    f'decreasing batch size from {self.current_global_batch_size} to '
                    f'{self.current_running_global_batch_size} to keep divisibility by '
                    f'micro_batch_size={self.micro_batch_size} * '
                    f'data_parallel_size={self.data_parallel_size}'
                )
            assert (
                self.current_running_global_batch_size % self.micro_batch_times_data_parallel_size
                == 0
            )
        else:
            self.current_running_global_batch_size = self.current_global_batch_size

        self.num_micro_batches = (
            self.current_running_global_batch_size // self.micro_batch_times_data_parallel_size
        )
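As a self-contained illustration of the two schedules implemented above (not part of the commit; all numbers are made-up examples), the constant calculator divides the global batch size by micro_batch_size * data_parallel_size, and the rampup calculator steps the global batch size up by a fixed increment every ramup_samples / num_increments consumed samples:

def constant_num_microbatches(global_batch_size, micro_batch_size, data_parallel_size):
    # Mirrors ConstantNumMicroBatchesCalculator without the decrease_batch_size_if_needed path.
    divisor = micro_batch_size * data_parallel_size
    assert global_batch_size % divisor == 0
    return global_batch_size // divisor


def rampup_global_batch_size(consumed_samples, start=32, increment=16, final=128, ramup_samples=1000):
    # Mirrors the batch-size portion of RampupBatchsizeNumMicroBatchesCalculator.update().
    if consumed_samples > ramup_samples:
        return final
    samples_per_increment = ramup_samples / ((final - start) // increment)
    steps = int(consumed_samples / samples_per_increment)
    current = start + steps * increment
    assert current <= final
    return current


assert constant_num_microbatches(1536, micro_batch_size=2, data_parallel_size=64) == 12
assert rampup_global_batch_size(0) == 32        # before any samples are consumed
assert rampup_global_batch_size(500) == 80      # 3 completed increments: 32 + 3 * 16
assert rampup_global_batch_size(2000) == 128    # past the rampup window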
megatron/core/optimizer/__init__.py (new file, mode 0 → 100755)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

import logging
from typing import Callable, Dict, List, Optional, Tuple

import torch

try:
    from transformer_engine.pytorch.optimizers import FusedAdam as Adam
    from transformer_engine.pytorch.optimizers import FusedSGD as SGD
except ImportError:
    try:
        from apex.optimizers import FusedAdam as Adam
        from apex.optimizers import FusedSGD as SGD
    except ImportError:
        import warnings

        warnings.warn(
            'Transformer Engine and Apex are not installed. Falling back to Torch optimizers.'
        )

        # Apex's FusedAdam is a drop-in replacement for torch's AdamW.
        # pylint: disable-next=line-too-long.
        # See https://github.com/NVIDIA/apex/blob/7b73b12361068a10b0f44844534613f252a5ea75/apex/optimizers/fused_adam.py#L16.
        from torch.optim import AdamW as Adam, SGD

from megatron.core import mpu

from ..distributed.param_and_grad_buffer import _ParamAndGradBuffer
from ..transformer.module import MegatronModule
from ..utils import log_single_rank
from .distrib_optimizer import DistributedOptimizer
from .grad_scaler import ConstantGradScaler, DynamicGradScaler
from .optimizer import (
    ChainedOptimizer,
    Float16OptimizerWithFloat16Params,
    FP32Optimizer,
    MegatronOptimizer,
)
from .optimizer_config import OptimizerConfig

logger = logging.getLogger(__name__)


def _get_param_groups(
    model_chunks: List[MegatronModule],
    no_weight_decay_cond: Optional[Callable],
    scale_lr_cond: Optional[Callable],
    lr_mult: float,
    lr: float,
    min_lr: float,
    decoupled_lr: Optional[float],
    decoupled_min_lr: Optional[float],
) -> List[Dict]:
    """Create parameter groups for optimizer.

    Creates parameter groups based on weight decay condition (regularized vs
    non regularized), learning rate scale condition (lr vs lr_mult * lr),
    and whether it is expert parameters. scale_lr_cond is used during finetuning
    where the head of the network requires a scaled version of the base learning rate.

    Args:
        model_chunks (List[MegatronModule]): model chunks to create parameter
            groups for.
        no_weight_decay_cond (func, optional): function to determine whether a
            parameter should not perform weight decay.
        scale_lr_cond (func, optional): function to determine whether a parameter
            should have a scaled learning rate.
        lr_mult (float): learning rate multiplier for parameters that
            satisfy scale_lr_cond.
        lr (float): learning rate.
        min_lr (float): minimum learning rate.
        decoupled_lr (Optional[float]): optional decoupled learning rate.
        decoupled_min_lr (Optional[float]): optional decoupled minimum learning rate.

    Returns:
        List of parameter groups.
    """
    use_decoupled_learning_rate = decoupled_lr is not None

    # Map (wd_mult, lr_mult, is_expert_parallel, is_decoupled_lr) to params.
    params_map = {}
    for model_chunk in model_chunks:
        for name, param in model_chunk.named_parameters():
            if not param.requires_grad:
                continue

            is_expert_parallel = not getattr(param, 'allreduce', True)

            if no_weight_decay_cond is not None:
                no_wd = no_weight_decay_cond(name, param)
            else:
                # Do not regularize biases and norm parameters.
                no_wd = name.endswith(".bias") or len(param.shape) == 1

            if scale_lr_cond is not None:
                scale_lr = scale_lr_cond(name, param)
            else:
                scale_lr = False

            if not no_wd and not scale_lr:
                wd_mult, _lr_mult = 1.0, 1.0
            elif not no_wd and scale_lr:
                wd_mult, _lr_mult = 1.0, lr_mult
            elif no_wd and not scale_lr:
                wd_mult, _lr_mult = 0.0, 1.0
            else:
                wd_mult, _lr_mult = 0.0, lr_mult

            is_decoupled_lr = False
            # For input/embedding and output layer: embedding.word_embeddings.weight /
            # output_layer.weight.
            if use_decoupled_learning_rate and getattr(
                param, 'is_embedding_or_output_parameter', False
            ):
                is_decoupled_lr = True

            key = (wd_mult, _lr_mult, is_expert_parallel, is_decoupled_lr)
            if key not in params_map:
                params_map[key] = []
            params_map[key].append(param)

    param_groups = []
    for (wd_mult, _lr_mult, is_expert_parallel, is_decoupled_lr), params in params_map.items():
        assert len(params) > 0
        param_group = {
            'params': params,
            'wd_mult': wd_mult,
            'lr_mult': _lr_mult,
            'is_expert_parallel': is_expert_parallel,
            'is_decoupled_lr': is_decoupled_lr,
        }
        param_groups.append(param_group)

    param_groups = _update_min_and_max_lr_in_param_groups(
        param_groups,
        lr=lr,
        min_lr=min_lr,
        decoupled_lr=decoupled_lr,
        decoupled_min_lr=decoupled_min_lr,
    )

    return param_groups


def _update_min_and_max_lr_in_param_groups(
    param_groups: List[Dict],
    lr: float,
    min_lr: float,
    decoupled_lr: Optional[float],
    decoupled_min_lr: Optional[float],
) -> List[Dict]:
    """
    Updates `max_lr` and `min_lr` values in each parameter group, and returns new list.

    By default, each group will use `lr` / `min_lr` as `max_lr` / `min_lr`.
    If `decoupled_lr` is provided, then `decoupled_lr` / `decoupled_min_lr` will be used
    as `max_lr` / `min_lr` for the input and output layer.

    Args:
        param_groups (List): parameter groups whose `max_lr` and `min_lr` fields need to
            be adjusted.
        lr (float): learning rate.
        min_lr (float): minimum learning rate.
        decoupled_lr (Optional[float]): optional decoupled learning rate.
        decoupled_min_lr (Optional[float]): optional decoupled minimum learning rate.

    Returns:
        List of adjusted parameter groups.
    """

    if decoupled_min_lr is None:
        decoupled_min_lr = min_lr

    for param_group in param_groups:
        if param_group['is_decoupled_lr']:
            assert decoupled_lr is not None
            param_group['max_lr'] = decoupled_lr
            param_group['min_lr'] = decoupled_min_lr
        else:
            param_group['max_lr'] = lr
            param_group['min_lr'] = min_lr
    return param_groups


def _get_param_groups_and_buffers(
    model_chunks: List[MegatronModule],
    model_chunk_offset: int,
    config: OptimizerConfig,
    no_weight_decay_cond: Optional[Callable],
    scale_lr_cond: Optional[Callable],
    lr_mult: float,
    filter_fn: Callable,
    buffer_name: str,
) -> Tuple[List[Dict], Dict[int, List[_ParamAndGradBuffer]]]:
    """Returns parameter groups and buffers for optimizer.

    Args:
        model_chunks (List[MegatronModule]): model chunks to create parameter
            groups for.
        model_chunk_offset (int): offset of model_chunks in global model_chunks list.
        config (OptimizerConfig): optimizer configuration object.
        no_weight_decay_cond (func, optional): function to determine whether a
            parameter should not perform weight decay.
        scale_lr_cond (func, optional): function to determine whether a parameter
            should have a scaled learning rate.
        lr_mult (float): learning rate multiplier for parameters that
            satisfy scale_lr_cond.
        filter_fn (callable): filtering function for param_groups.
        buffer_name (str): name of buffer.

    Returns:
        List of parameter groups and dictionary of model chunk IDs to buffers.
    """
    param_groups = _get_param_groups(
        model_chunks,
        no_weight_decay_cond,
        scale_lr_cond,
        lr_mult,
        lr=config.lr,
        min_lr=config.min_lr,
        decoupled_lr=config.decoupled_lr,
        decoupled_min_lr=config.decoupled_min_lr,
    )
    param_groups = list(filter(filter_fn, param_groups))
    buffers = {}
    for model_chunk_idx, model_chunk in enumerate(model_chunks):
        if hasattr(model_chunk, buffer_name):
            buffers[model_chunk_idx + model_chunk_offset] = getattr(model_chunk, buffer_name)

    return param_groups, buffers


def _get_megatron_optimizer_based_on_param_groups(
    config: OptimizerConfig,
    model_chunks: List[MegatronModule],
    param_groups: List,
    per_model_buffers: Optional[Dict[int, List[_ParamAndGradBuffer]]] = None,
    model_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
    data_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
    data_parallel_group_gloo: Optional[torch.distributed.ProcessGroup] = None,
    data_parallel_group_idx: Optional[int] = None,
    distributed_optimizer_instance_id: Optional[int] = 0,
) -> MegatronOptimizer:
    """Get Megatron optimizer based on parameter groups.

    Args:
        config (OptimizerConfig): optimizer configuration object.
        model_chunks (list): list of model chunks.
        param_groups (list): list of parameter groups.
        per_model_buffers (dict, optional): buffers for distributed optimizer. Defaults to None.
        model_parallel_group (torch.distributed.ProcessGroup, optional): model-parallel group
            used for reducing gradient statistics. Defaults to None.
        data_parallel_group (torch.distributed.ProcessGroup, optional): data-parallel group for
            distributed optimizer. Defaults to None.
        data_parallel_group_gloo (torch.distributed.ProcessGroup, optional): gloo data-parallel
            group for distributed optimizer. Defaults to None.
        data_parallel_group_idx (int, optional): data-parallel group index for distributed
            optimizer. Defaults to None.
        distributed_optimizer_instance_id (int, optional): Distributed optimizer instance.
            Defaults to 0.

    Returns:
        Instance of MegatronOptimizer.
    """
    if config.optimizer == 'adam':
        optimizer = Adam(
            param_groups,
            lr=config.lr,
            weight_decay=config.weight_decay,
            betas=(config.adam_beta1, config.adam_beta2),
            eps=config.adam_eps,
        )

        def init_state_fn(opt):
            for group in opt.param_groups:
                for p in group['params']:
                    if len(opt.state[p]) == 0:
                        opt.state[p]['exp_avg'] = torch.zeros_like(p.data)
                        opt.state[p]['exp_avg_sq'] = torch.zeros_like(p.data)

    elif config.optimizer == 'sgd':
        optimizer = SGD(
            param_groups,
            lr=config.lr,
            weight_decay=config.weight_decay,
            momentum=config.sgd_momentum,
        )
        init_state_fn = None
    else:
        raise Exception('{} optimizer is not supported.'.format(config.optimizer))

    # Mixed precision optimizer.
    # - Note: both the Float16Optimizer and the DistributedOptimizer inherit
    #   from the MixedPrecisionOptimizer, which manages any optimizer where
    #   the model params and main params are distinct.
    if config.fp16 or config.bf16 or config.use_distributed_optimizer:

        # Grad scaler:
        #    if loss-scale is provided, instantiate the constant scaler.
        #    if we are using fp16 and loss-scale is not present, use a
        #       dynamic scaler.
        #    otherwise we are running in bf16 with no loss-scale so
        #       leave it as None.
        grad_scaler = None

        # Constant loss scale.
        if config.loss_scale:
            grad_scaler = ConstantGradScaler(config.loss_scale)

        # Dynamic loss scale.
        else:
            if config.fp16:
                grad_scaler = DynamicGradScaler(
                    initial_scale=config.initial_loss_scale,
                    min_scale=config.min_loss_scale,
                    growth_factor=2.0,
                    backoff_factor=0.5,
                    growth_interval=config.loss_scale_window,
                    hysteresis=config.hysteresis,
                )

        optimizer_args = [optimizer, config, grad_scaler, init_state_fn]
        if config.use_distributed_optimizer:
            optimizer = DistributedOptimizer(
                *optimizer_args,
                model_chunks=model_chunks,
                per_model_buffers=per_model_buffers,
                data_parallel_group=data_parallel_group,
                data_parallel_group_gloo=data_parallel_group_gloo,
                data_parallel_group_idx=data_parallel_group_idx,
                distributed_optimizer_instance_id=distributed_optimizer_instance_id,
            )
        else:
            optimizer = Float16OptimizerWithFloat16Params(*optimizer_args)
        setattr(optimizer, 'grad_stats_parallel_group', model_parallel_group)
    else:
        # FP32 optimizer.
        optimizer = FP32Optimizer(optimizer, config, init_state_fn)
        setattr(optimizer, 'grad_stats_parallel_group', model_parallel_group)

    return optimizer


def get_megatron_optimizer(
    config: OptimizerConfig,
    model_chunks: List[MegatronModule],
    no_weight_decay_cond: Optional[Callable] = None,
    scale_lr_cond: Optional[Callable] = None,
    lr_mult: float = 1.0,
) -> MegatronOptimizer:
    """Retrieve the Megatron optimizer for model chunks.

    We use separate optimizers for expert parameters and non-expert parameters.

    Args:
        config (OptimizerConfig): optimizer configuration object.
        model_chunks (List[MegatronModule]): model chunks to get optimizer for.
        no_weight_decay_cond (func, optional): function to determine whether a parameter
            should not perform weight decay. Defaults to None.
        scale_lr_cond (func, optional): function to determine whether a parameter
            should have a scaled learning rate. Defaults to None.
        lr_mult (float, optional): learning rate multiplier for parameters that
            satisfy scale_lr_cond. Defaults to 1.0.

    Returns:
        Instance of MegatronOptimizer.
    """

    log_single_rank(logger, logging.INFO, f'Setting up optimizer with config {config}')

    # Separate out first model chunk if overlapping param AG with optimizer step.
    if config.overlap_param_gather_with_optimizer_step:
        all_dense_model_chunks = [[model_chunks[0]], model_chunks[1:]]
        overlap_param_gather_with_optimizer_step_flags = [True, False]
    else:
        all_dense_model_chunks = [model_chunks]
        overlap_param_gather_with_optimizer_step_flags = [False]
    model_parallel_rank = torch.distributed.get_rank(mpu.get_model_parallel_group())

    if torch.distributed.get_world_size(
        mpu.get_data_parallel_group(with_context_parallel=True, partial_data_parallel=False)
    ) > torch.distributed.get_world_size(
        mpu.get_data_parallel_group(with_context_parallel=True, partial_data_parallel=True)
    ):
        distributed_optimizer_instance_id = torch.distributed.get_rank(
            mpu.get_inter_partial_data_parallel_group()
        )
    else:
        distributed_optimizer_instance_id = 0

    optimizers = []
    model_chunk_offset = 0
    for dense_model_chunks, overlap_param_gather_with_optimizer_step in zip(
        all_dense_model_chunks, overlap_param_gather_with_optimizer_step_flags
    ):
        param_groups, buffers = _get_param_groups_and_buffers(
            dense_model_chunks,
            model_chunk_offset=model_chunk_offset,
            config=config,
            no_weight_decay_cond=no_weight_decay_cond,
            scale_lr_cond=scale_lr_cond,
            lr_mult=lr_mult,
            filter_fn=lambda g: not g['is_expert_parallel'],
            buffer_name='buffers',
        )
        for model_chunk in dense_model_chunks:
            model_chunk.overlap_param_gather_with_optimizer_step = (
                overlap_param_gather_with_optimizer_step
            )
        optimizers.append(
            _get_megatron_optimizer_based_on_param_groups(
                config,
                model_chunks=dense_model_chunks,
                param_groups=param_groups,
                per_model_buffers=buffers,
                model_parallel_group=mpu.get_model_parallel_group(),
                data_parallel_group=mpu.get_data_parallel_group(
                    with_context_parallel=True, partial_data_parallel=True
                ),
                data_parallel_group_gloo=mpu.get_data_parallel_group_gloo(
                    with_context_parallel=True, partial_data_parallel=True
                ),
                data_parallel_group_idx=model_parallel_rank,
                distributed_optimizer_instance_id=distributed_optimizer_instance_id,
            )
        )
        model_chunk_offset += 1

    moe_param_groups, moe_buffers = _get_param_groups_and_buffers(
        model_chunks,
        model_chunk_offset=0,
        config=config,
        no_weight_decay_cond=no_weight_decay_cond,
        scale_lr_cond=scale_lr_cond,
        lr_mult=lr_mult,
        filter_fn=lambda g: g['is_expert_parallel'],
        buffer_name='expert_parallel_buffers',
    )
    if len(moe_param_groups) > 0:
        model_parallel_rank = torch.distributed.get_rank(
            mpu.get_expert_tensor_model_pipeline_parallel_group()
        )
        optimizers.append(
            _get_megatron_optimizer_based_on_param_groups(
                config,
                model_chunks=model_chunks,
                param_groups=moe_param_groups,
                per_model_buffers=moe_buffers,
                model_parallel_group=mpu.get_expert_tensor_model_pipeline_parallel_group(),
                data_parallel_group=mpu.get_expert_data_parallel_group(),
                data_parallel_group_gloo=mpu.get_expert_data_parallel_group_gloo(),
                data_parallel_group_idx=model_parallel_rank,
            )
        )

    if len(optimizers) == 1:
        return optimizers[0]

    return ChainedOptimizer(optimizers)
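A minimal usage sketch (not part of the commit), assuming megatron.core parallel state (mpu) and the DDP-wrapped model chunks have already been initialized elsewhere; the hyperparameter values below are placeholders. get_megatron_optimizer returns a single optimizer for dense-only models and a ChainedOptimizer over the dense and expert-parallel optimizers when MoE expert parameters are present.

from megatron.core.optimizer import OptimizerConfig, get_megatron_optimizer


def build_optimizer(model_chunks):
    # Placeholder hyperparameters; the remaining OptimizerConfig fields keep their defaults.
    config = OptimizerConfig(
        optimizer='adam',
        lr=3.0e-4,
        min_lr=3.0e-5,
        weight_decay=0.1,
        adam_beta1=0.9,
        adam_beta2=0.95,
        bf16=True,
        use_distributed_optimizer=True,
    )
    return get_megatron_optimizer(config, model_chunks)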
megatron/optimizer/clip_grads.py → megatron/core/optimizer/clip_grads.py (renamed with changes, mode 100644 → 100755)
- # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

  """Gradient clipping."""

+ from typing import List, Optional, Union
+
  import torch
  from torch import inf

- from apex.multi_tensor_apply import multi_tensor_applier
- import amp_C
-
- from megatron.model.module import param_is_not_shared
- from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate
-
-
- def clip_grad_norm_fp32(parameters, grads_for_norm,
-                         max_norm, norm_type=2,
-                         model_parallel_group=None):
-     """Clips gradient norm of an iterable of parameters whose gradients
-     are in fp32.
+ try:
+     from transformer_engine.pytorch.optimizers import (
+         multi_tensor_applier,
+         multi_tensor_l2norm,
+         multi_tensor_scale,
+     )
+
+     l2_norm_impl = multi_tensor_l2norm
+     multi_tensor_scale_impl = multi_tensor_scale
+ except ImportError:
+     try:
+         import amp_C
+         from apex.multi_tensor_apply import multi_tensor_applier
+
+         l2_norm_impl = amp_C.multi_tensor_l2norm
+         multi_tensor_scale_impl = amp_C.multi_tensor_scale
+     except ImportError:
+         import warnings
+
+         warnings.warn(
+             'Transformer Engine and Apex are not installed. '
+             'Falling back to local implementations of multi_tensor_applier, '
+             'multi_tensor_l2norm, and multi_tensor_scale'
+         )
+
+         from megatron.core.utils import (
+             local_multi_tensor_applier,
+             local_multi_tensor_l2_norm,
+             local_multi_tensor_scale,
+         )
+
+         multi_tensor_applier = local_multi_tensor_applier
+         l2_norm_impl = local_multi_tensor_l2_norm
+         multi_tensor_scale_impl = local_multi_tensor_scale
+
+ from ..tensor_parallel import param_is_not_tensor_parallel_duplicate
+ from ..transformer.module import param_is_not_shared
+ from ..utils import get_data_parallel_group_if_dtensor, to_local_if_dtensor
+
+
+ def get_grad_norm_fp32(
+     grads_for_norm: Union[List[torch.Tensor], torch.Tensor],
+     norm_type: Union[int, float] = 2,
+     grad_stats_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
+ ) -> float:
+     """Calculate the norm of gradients in fp32.

      This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
-     added functionality to handle model parallel parameters. Note that
-     the gradients are modified in place.
+     added functionality to handle model parallel parameters.

      Arguments:
-         parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
-             single Tensor that will have gradients normalized
-         grads_for_norm (Iterable[Tensor]): an iterable of Tensors or a single
+         grads_for_norm (Iterable[Tensor] or Tensor): an iterable of Tensors or a single
              Tensor that will be used for calculating the grad norm.
-         max_norm (float or int): max norm of the gradients
          norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
              infinity norm.
-         model_parallel_group (group): given the nature of the distributed
-             optimizer, this is passed as an argument.
+         grad_stats_parallel_group (group): Process group for reducing the grad norms. This is
+             generally the model-parallel group for non-distributed optimizers, and the entire
+             world for the distributed optimizer.

      Returns:
          Total norm of the parameters (viewed as a single vector).
      """

-     if isinstance(parameters, torch.Tensor):
-         parameters = [parameters]
      if isinstance(grads_for_norm, torch.Tensor):
          grads_for_norm = [grads_for_norm]

-     # Grads.
-     grads = []
-     for param in parameters:
-         if param.grad is not None:
-             assert param.grad.type() == 'torch.cuda.FloatTensor'
-             grads.append(param.grad.detach())
+     data_parallel_group = None
+     for grad in grads_for_norm:
+         data_parallel_group = get_data_parallel_group_if_dtensor(grad, data_parallel_group)
+     grads_for_norm = [to_local_if_dtensor(grad) for grad in grads_for_norm]

      # Norm parameters.
-     max_norm = float(max_norm)
      norm_type = float(norm_type)
      total_norm = 0.0

      # Calculate norm.
      if norm_type == inf:
          total_norm = max(grad.abs().max() for grad in grads_for_norm)
-         total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
-         # Take max across all model-parallel GPUs.
-         torch.distributed.all_reduce(
-             total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=model_parallel_group
-         )
+         total_norm_cuda = torch.tensor([float(total_norm)], dtype=torch.float, device='cuda')
+         # Take max across all data-parallel GPUs if using FSDP and then all model-parallel GPUs.
+         if data_parallel_group:
+             torch.distributed.all_reduce(
+                 total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=data_parallel_group
+             )
+         torch.distributed.all_reduce(
+             total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=grad_stats_parallel_group
+         )
          total_norm = total_norm_cuda[0].item()

      else:
          if norm_type == 2.0:
-             dummy_overflow_buf = torch.cuda.IntTensor([0])
+             dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device='cuda')
              # Use apex's multi-tensor applier for efficiency reasons.
              # Multi-tensor applier takes a function and a list of list
              # and performs the operation on that list all in one kernel.
              if grads_for_norm:
                  grad_norm, _ = multi_tensor_applier(
-                     amp_C.multi_tensor_l2norm,
+                     l2_norm_impl,
                      dummy_overflow_buf,
                      [grads_for_norm],
-                     False # no per-parameter norm
+                     False,  # no per-parameter norm
                  )
              else:
-                 grad_norm = torch.cuda.FloatTensor([0])
+                 grad_norm = torch.tensor([0], dtype=torch.float, device='cuda')
              # Since we will be summing across data parallel groups,
              # we need the pow(norm-type).
              total_norm = grad_norm ** norm_type
          else:
              for grad in grads_for_norm:
                  grad_norm = torch.norm(grad, norm_type)
                  total_norm += grad_norm ** norm_type

-         # Sum across all model-parallel GPUs.
-         torch.distributed.all_reduce(
-             total_norm, op=torch.distributed.ReduceOp.SUM, group=model_parallel_group
-         )
+         # Sum across all data-parallel GPUs if using FSDP and then all model-parallel GPUs.
+         if data_parallel_group:
+             torch.distributed.all_reduce(
+                 total_norm, op=torch.distributed.ReduceOp.SUM, group=data_parallel_group
+             )
+         torch.distributed.all_reduce(
+             total_norm, op=torch.distributed.ReduceOp.SUM, group=grad_stats_parallel_group
+         )
          total_norm = total_norm.item() ** (1.0 / norm_type)

      return total_norm


+ def clip_grad_by_total_norm_fp32(
+     parameters: Union[List[torch.Tensor], torch.Tensor],
+     max_norm: Union[int, float],
+     total_norm: float,
+ ):
+     """Clips gradient of an iterable of parameters in fp32 by total norm.
+
+     Note that the gradients are modified in place.
+
+     Args:
+         parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
+             single Tensor that will have gradients normalized.
+         max_norm (float or int): max norm of the gradients.
+         total_norm (float): total norm of the gradients.
+     """
      # Grads.
      params = []
      grads = []
      for param in parameters:
          if param.grad is not None:
              assert param.grad.type() == 'torch.cuda.FloatTensor'
              params.append(param)
+             grads.append(to_local_if_dtensor(param.grad).detach())

      # Scale.
      clip_coeff = max_norm / (total_norm + 1.0e-6)
      if clip_coeff < 1.0:
-         dummy_overflow_buf = torch.cuda.IntTensor([0])
-         multi_tensor_applier(
-             amp_C.multi_tensor_scale, dummy_overflow_buf, [grads, grads], clip_coeff
-         )
-
-     return total_norm
+         dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device='cuda')
+         multi_tensor_applier(
+             multi_tensor_scale_impl, dummy_overflow_buf, [grads, grads], clip_coeff
+         )


+ def count_zeros_fp32(
+     parameters: Union[List[torch.Tensor], torch.Tensor],
+     grad_stats_parallel_group: torch.distributed.ProcessGroup,
+ ) -> float:
+     """Counts the number of zeros in gradients associated with the passed-in list of
+     parameters.
- def count_zeros_fp32(parameters, model_parallel_group):
+
+     Args:
+         parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
+             single Tensor that will have the number of zeros in its corresponding
+             gradient counted.
+         grad_stats_parallel_group (group): Process group for reducing the num_zeros count. This is
+             generally the model-parallel group for non-distributed optimizers, and the entire
+             world for the distributed optimizer.
+     """
      if isinstance(parameters, torch.Tensor):
          parameters = [parameters]

  ...
  ...

  @@ -115,20 +191,29 @@ def count_zeros_fp32(parameters, model_parallel_group):
      # - grad should not be none
      # - parameter should not be shared
      # - should not be a replica due to tensor model parallelism
-     total_num_zeros = torch.cuda.FloatTensor([0.0])
+     total_num_zeros = torch.tensor([0.0], dtype=torch.float, device='cuda')
+     data_parallel_group = None
      for param in parameters:
          grad_not_none = param.grad is not None
          is_not_shared = param_is_not_shared(param)
          is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
          if grad_not_none and is_not_shared and is_not_tp_duplicate:
-             grad = param.grad.detach()
+             data_parallel_group = get_data_parallel_group_if_dtensor(
+                 param.grad, data_parallel_group
+             )
+             grad = to_local_if_dtensor(param.grad).detach()
              num_zeros = grad.numel() - torch.count_nonzero(grad)
              total_num_zeros = num_zeros + total_num_zeros

+     # Sum across all data-parallel GPUs if using FSDP.
+     if data_parallel_group:
+         torch.distributed.all_reduce(
+             total_num_zeros, op=torch.distributed.ReduceOp.SUM, group=data_parallel_group
+         )
      # Sum across all model-parallel GPUs.
-     torch.distributed.all_reduce(
-         total_num_zeros, op=torch.distributed.ReduceOp.SUM, group=model_parallel_group
-     )
+     torch.distributed.all_reduce(
+         total_num_zeros, op=torch.distributed.ReduceOp.SUM, group=grad_stats_parallel_group
+     )
      total_num_zeros = total_num_zeros.item()

  ...
  ...
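The diff above splits norm computation (get_grad_norm_fp32) from the actual clipping (clip_grad_by_total_norm_fp32). A single-process sketch of that two-step pattern in plain PyTorch, without the multi-tensor kernels or parallel reductions, and not part of the commit:

import torch


def get_grad_norm(grads, norm_type=2.0):
    # L2-style norm over all gradients, analogous to get_grad_norm_fp32 without the all-reduces.
    per_tensor = torch.stack([torch.norm(g.detach(), norm_type) for g in grads])
    return torch.norm(per_tensor, norm_type).item()


def clip_by_total_norm(params, max_norm, total_norm):
    # Scale gradients in place by max_norm / total_norm, as clip_grad_by_total_norm_fp32 does.
    clip_coeff = max_norm / (total_norm + 1.0e-6)
    if clip_coeff < 1.0:
        for p in params:
            p.grad.detach().mul_(clip_coeff)


p = torch.nn.Parameter(torch.ones(4))
p.grad = torch.full((4,), 3.0)
total_norm = get_grad_norm([p.grad])      # 6.0
clip_by_total_norm([p], 1.0, total_norm)  # gradient rescaled to norm ~1.0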
megatron/core/optimizer/distrib_optimizer.py (new file, mode 0 → 100755)
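The class below carves each gradient buffer into DP-world-size contiguous shards and intersects every parameter's range with the shard owned by the local rank (see the _build_model_gbuf_param_range_map docstring). As a rough standalone sketch of that intersection arithmetic, not taken from the file and using made-up sizes:

def dp_shard_overlap(param_start, param_end, dp_rank, gbuf_size=1000, dp_world_size=4):
    # Each DP rank owns a gbuf_size / dp_world_size slice of the (padded) buffer; a parameter
    # contributes only the part of [param_start, param_end) that falls inside that slice.
    shard_size = gbuf_size // dp_world_size
    shard_start = dp_rank * shard_size
    local_start = max(0, param_start - shard_start)
    local_end = min(shard_size, param_end - shard_start)
    return (local_start, local_end) if local_end > local_start else None


assert dp_shard_overlap(200, 600, dp_rank=0) == (200, 250)  # rank 0 owns [0, 250)
assert dp_shard_overlap(200, 600, dp_rank=1) == (0, 250)    # rank 1 owns [250, 500)
assert dp_shard_overlap(200, 600, dp_rank=3) is None        # rank 3 owns [750, 1000)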
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Megatron distributed optimizer."""
import
itertools
from
dataclasses
import
replace
from
logging
import
getLogger
from
typing
import
Callable
,
Dict
,
List
,
Optional
,
Tuple
import
torch
HAVE_APEX_OR_TE
=
True
try
:
from
transformer_engine.pytorch.optimizers
import
FusedAdam
as
Adam
except
ImportError
:
try
:
from
apex.optimizers
import
FusedAdam
as
Adam
except
ImportError
:
from
torch.optim
import
AdamW
as
Adam
HAVE_APEX_OR_TE
=
False
from
..
import
tensor_parallel
from
..config_logger
import
has_config_logger_enabled
,
log_config_to_disk
from
..dist_checkpointing
import
ShardedTensor
from
..dist_checkpointing.dict_utils
import
nested_values
from
..dist_checkpointing.mapping
import
(
LocalNonpersistentObject
,
ShardedObject
,
ShardedStateDict
,
ShardedTensorFactory
,
)
from
..dist_checkpointing.utils
import
extract_sharded_tensors_and_factories
from
..distributed.param_and_grad_buffer
import
_ParamAndGradBuffer
,
partition_buckets
from
..transformer.module
import
MegatronModule
from
..utils
import
is_float8tensor
from
.grad_scaler
import
MegatronGradScaler
from
.optimizer
import
(
MixedPrecisionOptimizer
,
_multi_tensor_copy_this_to_that
,
_zero_grad_group_helper
,
)
from
.optimizer_config
import
OptimizerConfig
try
:
# This will be used when "--fp8-param-gather" is enabled.
# When BF16/FP16 parameters don't exist, we need to cast the FP32 main parameters to
# FP8 directly in the optimizer.
from
transformer_engine.pytorch.cpp_extensions
import
cast_to_fp8
except
:
pass
logger
=
getLogger
(
__name__
)
class
Range
:
"""
A range represents a start and end points for indexing a shard
from a full tensor.
Args:
start (int): Start index.
end (int): End index.
"""
def
__init__
(
self
,
start
:
int
,
end
:
int
):
self
.
start
=
start
self
.
end
=
end
self
.
size
=
end
-
start
def
normalize
(
self
,
start
:
int
=
0
):
"""Shift start/end indexes to start at new start index.
Both start and end indexes will be shifted by [new start] - [old start].
Args:
start (int): New start index.
"""
return
Range
(
start
,
start
+
self
.
size
)
def
__str__
(
self
):
return
"%d,%d [%d]"
%
(
self
.
start
,
self
.
end
,
self
.
size
)
def
__len__
(
self
):
return
self
.
end
-
self
.
start
class
DistributedOptimizer
(
MixedPrecisionOptimizer
):
"""Distributed optimizer, for all data types (fp16, bf16, and fp32).
See __init__() below for argument details.
"""
@
classmethod
def
_build_model_gbuf_param_range_map
(
cls
,
param_world_index_map
:
Dict
[
torch
.
nn
.
Parameter
,
Tuple
],
gbuf_world_range
:
Range
,
bucket_offset
:
int
,
):
"""
Build mapping from param reference to grad buffer shard ranges.
This method builds a mapping from parameter references to grad
buffer shard ranges, specific to each data-parallel (DP) rank's
set of 'owned' parameters. Each grad buffer (padded to be an even
multiple of DP-world-size) is conceptually divided into DP-world-size
contiguous regions, where each DP rank 'owns' a contiguous region.
Ownership in this sense means DP rank is responsible for reducing
the relevant subset of grads, and updating the relevant subset of
params.
This conceptual partitioning of the grad buffer does NOT respect
parameter boundaries, and as such it is assumed that each created
range references a shard (or subset) of the full parameter. It is
easiest to think of each DP rank as operating (i.e., reducing,
gathering) purely on views into the grad buffer, for all model-to-
main & main-to-model operations.
This method creates four ranges:
- The param's range within the entire grad buffer (i.e., world index).
- The param's range within the relevant grad bucket's buffer.
- The param's range within the DP rank's local view of the grad buffer.
- The param's range within itself (i.e., its shard).
"""
# Param range map.
param_range_map
=
{}
for
param
,
param_world_indexes
in
param_world_index_map
.
items
():
# Param range.
param_world_start
,
param_world_end
,
_
=
param_world_indexes
param_local_start
=
max
(
0
,
param_world_start
-
gbuf_world_range
.
start
)
param_local_end
=
min
(
gbuf_world_range
.
size
,
param_world_end
-
gbuf_world_range
.
start
)
# Add param, if within local gbuf range.
if
param_local_end
>
param_local_start
:
param_local_range
=
Range
(
param_local_start
,
param_local_end
)
param_world_range
=
param_local_range
.
normalize
(
param_local_start
+
gbuf_world_range
.
start
)
param_world_range_in_bucket
=
Range
(
param_world_range
.
start
-
bucket_offset
,
param_world_range
.
end
-
bucket_offset
)
sub_param_start
=
max
(
0
,
gbuf_world_range
.
start
-
param_world_start
)
sub_param_range
=
param_local_range
.
normalize
(
sub_param_start
)
param_range_map
[
param
]
=
{
"gbuf_world"
:
param_world_range
,
"gbuf_world_in_bucket"
:
param_world_range_in_bucket
,
"gbuf_local"
:
param_local_range
,
"param"
:
sub_param_range
,
}
return
param_range_map
@
classmethod
def
_build_model_gbuf_range
(
cls
,
param_and_grad_buffer
:
_ParamAndGradBuffer
,
bucket_index
:
int
):
"""
Build mapping between params and their grad buffers.
This method does the initial setup for the method above. This setup
includes determining the shard ranges into the param_and_grad_buffer
for each data-parallel (DP) rank. Each DP rank keeps range info for
all other DP ranks, for the purpose of creating args for
reduce-scatter and all-gather.
"""
data_parallel_rank
=
torch
.
distributed
.
get_rank
(
param_and_grad_buffer
.
data_parallel_group
)
data_parallel_world_size
=
param_and_grad_buffer
.
data_parallel_group
.
size
()
bucket
=
param_and_grad_buffer
.
buckets
[
bucket_index
]
gbuf_size
=
bucket
.
grad_data
.
numel
()
assert
(
gbuf_size
%
data_parallel_world_size
==
0
),
f
"Each bucket's buffer size should be divisible by
{
data_parallel_world_size
}
"
max_gbuf_range_size
=
gbuf_size
//
data_parallel_world_size
# All world ranges (i.e., across all data parallel ranks).
gbuf_world_all_ranges
=
[]
for
r
in
range
(
data_parallel_world_size
):
# Compute start of chunk in this bucket.
gbuf_world_start
=
r
*
max_gbuf_range_size
gbuf_world_end
=
min
(
gbuf_size
,
gbuf_world_start
+
max_gbuf_range_size
)
# Add bucket's offset in grad buffer.
gbuf_world_range
=
Range
(
gbuf_world_start
+
bucket
.
offset
,
gbuf_world_end
+
bucket
.
offset
)
gbuf_world_all_ranges
.
append
(
gbuf_world_range
)
# Local DP's ranges.
gbuf_world_range
=
gbuf_world_all_ranges
[
data_parallel_rank
]
# Get each param's ranges.
param_range_map
=
cls
.
_build_model_gbuf_param_range_map
(
param_and_grad_buffer
.
param_index_map
,
gbuf_world_range
,
bucket
.
offset
)
# Group into dict.
data
=
{
"param_map"
:
param_range_map
}
return
data
@
classmethod
def
_build_gbuf_range_map
(
cls
,
param_and_grad_buffer
:
_ParamAndGradBuffer
):
"""
Build mapping between params and their grad buffers. These mappings are
partitioned according to data type.
Iterate through all buckets of grad buffer to construct param ranges
that this rank "owns" (the dp_rank'th shard of each bucket, where each
shard is 1/dp_world_size of the bucket).
Args:
param_and_grad_buffer (_ParamAndGradBuffer): buffer to build mapping for.
"""
return
{
(
param_and_grad_buffer
.
param_dtype
,
param_and_grad_buffer
.
grad_dtype
):
[
cls
.
_build_model_gbuf_range
(
param_and_grad_buffer
,
bucket_index
)
for
bucket_index
in
range
(
len
(
param_and_grad_buffer
.
buckets
))
]
}
@
classmethod
def
_build_model_param_gbuf_map
(
cls
,
gbuf_ranges
:
List
[
Dict
]
)
->
Dict
[
torch
.
nn
.
Parameter
,
Tuple
]:
"""
Create a reverse of the gbuf_ranges, for referencing in opposite direction.
"""
param_gbuf_map
=
{}
for
gbuf_index
,
gbuf_range_map
in
enumerate
(
gbuf_ranges
):
for
dtype
,
gbuf_range_map_for_all_buckets
in
gbuf_range_map
.
items
():
for
bucket_index
,
gbuf_range_map
in
enumerate
(
gbuf_range_map_for_all_buckets
):
for
param
,
_
in
gbuf_range_map
[
"param_map"
].
items
():
assert
param
not
in
param_gbuf_map
,
(
"Param should not be in param_gbuf_map; each param only belongs "
"to a single bucket."
)
param_gbuf_map
[
param
]
=
(
gbuf_index
,
dtype
,
bucket_index
)
return
param_gbuf_map
@
classmethod
def
_build_optimizer_group_ranges
(
cls
,
param_groups
:
List
[
Dict
],
gbuf_ranges
:
List
[
Dict
]):
"""
Create optimizer groups.
Given the set of parameter shard ranges that are owned by the current
data-parallel (DP) rank, gather the set of parameters that will be
used (in the method below) to create the current DP's optimizer
groups.
"""
# Param group map.
# World param group map.
# - Store a mapping of <model_parameter:group_index> for all parameters
# across all DP ranks. This is necessary because it is our first
# cross reference between the DDP mappings and the optimizer group
# parameters. This mapping only for use in the next step of building
# the local mapping over this DP rank's parameters.
world_param_group_map
=
{}
for
group_index
,
group
in
enumerate
(
param_groups
):
for
param
in
group
[
"params"
]:
assert
param
.
requires_grad
world_param_group_map
[
param
]
=
group_index
# Optimizer group ranges & param-group mapping.
# - Build a mapping from groups to their contained parameters, and also
# from parameters to their containing group index and order within
# the group. The group index and order are particularly important for
# saving and loading checkpoints.
local_param_group_map
=
{}
group_ranges
=
[{
"params"
:
[]}
for
_
in
param_groups
]
for
gbuf_range_map
in
gbuf_ranges
:
for
dtype
,
gbuf_range_map_for_all_buckets
in
gbuf_range_map
.
items
():
for
gbuf_range_map
in
gbuf_range_map_for_all_buckets
:
for
param
in
gbuf_range_map
[
"param_map"
]:
group_index
=
world_param_group_map
[
param
]
group_range
=
group_ranges
[
group_index
]
group_range
[
"params"
].
append
(
param
)
local_param_group_map
[
param
]
=
(
group_index
,
len
(
group_range
[
"params"
])
-
1
)
# Squeeze zero-size group ranges.
for
group_index
,
group_range
in
enumerate
(
group_ranges
):
group_range
[
"orig_group"
]
=
param_groups
[
group_index
]
group_range
[
"orig_group_idx"
]
=
param_groups
[
group_index
]
return
local_param_group_map
,
group_ranges
    @classmethod
    def _build_model_and_main_param_groups(
        cls,
        gbuf_ranges: List[Dict],
        param_gbuf_map: Dict[torch.nn.Parameter, Tuple],
        opt_group_ranges: List,
    ):
        """
        Create main parameter groups needed for the optimizer step.

        These groups encompass both: 1) groups used by this class, for
        reducing/gather, and 2) groups used by the inner optimizer for the
        parameter update. Given that the conceptual grad buffer partitioning
        (created in earlier method) doesn't respect parameter boundaries,
        the optimizer operates on shards of the model parameters, rather than
        the full parameters.
        """

        # Parameter groups:
        #   model_float16_groups: original float16 parameters
        #   model_fp32_groups: original fp32 parameters
        #   shard_float16_groups: shards of original float16 parameters
        #   shard_fp32_groups: shards of original fp32 parameters
        #   shard_fp32_from_float16_groups: fp32 copy of float16 parameters
        model_float16_groups = []
        model_fp32_groups = []
        shard_float16_groups = []
        shard_fp32_groups = []
        shard_fp32_from_float16_groups = []

        # Allocate (or slice) each group's param shard.
        for group_range in opt_group_ranges:

            # Params of this group.
            model_float16_params_this_group = []
            model_fp32_params_this_group = []
            shard_float16_params_this_group = []
            shard_fp32_params_this_group = []
            shard_fp32_from_float16_params_this_group = []
            model_float16_groups.append(model_float16_params_this_group)
            model_fp32_groups.append(model_fp32_params_this_group)
            shard_float16_groups.append(shard_float16_params_this_group)
            shard_fp32_groups.append(shard_fp32_params_this_group)
            shard_fp32_from_float16_groups.append(shard_fp32_from_float16_params_this_group)

            for model_param in group_range["params"]:

                assert model_param.requires_grad

                gbuf_index, dtype, bucket_index = param_gbuf_map[model_param]
                gbuf_range = gbuf_ranges[gbuf_index][dtype][bucket_index]
                param_range = gbuf_range["param_map"][model_param]["param"]

                # fp16, bf16 params.
                if model_param.type() in ['torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor']:

                    # Clone model -> main.
                    shard_model_param = model_param.detach().view(-1)[
                        param_range.start : param_range.end
                    ]

                    # If we use FP8 params to initialize FP32 main params (compared to using the
                    # bf16/fp16 params to initialize the main params), there will be a loss of
                    # precision at the beginning of training (this problem will not occur if the
                    # training is long enough or if the main params are loaded from a checkpoint).
                    if is_float8tensor(model_param) and hasattr(
                        model_param, 'get_high_precision_init_val'
                    ):
                        shard_main_param = (
                            model_param.get_high_precision_init_val()
                            .view(-1)[param_range.start : param_range.end]
                            .clone()
                            .to(shard_model_param.device)
                            .float()
                        )
                        model_param.clear_high_precision_init_val()
                    else:
                        shard_main_param = shard_model_param.clone().float()

                    tensor_parallel.copy_tensor_model_parallel_attributes(
                        shard_model_param, model_param
                    )
                    tensor_parallel.copy_tensor_model_parallel_attributes(
                        shard_main_param, model_param
                    )
                    if hasattr(model_param, 'shared'):
                        shard_model_param.shared = model_param.shared
                        shard_main_param.shared = model_param.shared

                    # Add to group.
                    model_float16_params_this_group.append(model_param)
                    shard_float16_params_this_group.append(shard_model_param)
                    shard_fp32_from_float16_params_this_group.append(shard_main_param)

                # fp32 params.
                elif model_param.type() == 'torch.cuda.FloatTensor':
                    shard_model_param = model_param.view(-1)[param_range.start : param_range.end]
                    model_fp32_params_this_group.append(model_param)
                    shard_fp32_params_this_group.append(shard_model_param)
                    tensor_parallel.copy_tensor_model_parallel_attributes(
                        shard_model_param, model_param
                    )
                    if hasattr(model_param, 'shared'):
                        shard_model_param.shared = model_param.shared

                else:
                    raise TypeError(
                        'Wrapped parameters must be one of '
                        'torch.cuda.FloatTensor, '
                        'torch.cuda.HalfTensor, or '
                        'torch.cuda.BFloat16Tensor. '
                        'Received {}'.format(model_param.type())
                    )

            # Update optimizer's params.
            group_range["orig_group"]["params"] = [
                *shard_fp32_params_this_group,
                *shard_fp32_from_float16_params_this_group,
            ]

        return (
            model_float16_groups,
            model_fp32_groups,
            shard_float16_groups,
            shard_fp32_groups,
            shard_fp32_from_float16_groups,
        )
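    # Illustrative sketch (toy numbers, not taken from the code above): with a bf16 parameter of
    # 1000 elements whose bucket assigns this DP rank the range [250, 500), the loop above builds
    #   shard_model_param = model_param.detach().view(-1)[250:500]   # bf16 view, 250 elements
    #   shard_main_param  = shard_model_param.clone().float()        # fp32 copy for the inner optimizer
    # and only the shard_* tensors (never the full model_param) are handed to the inner optimizer
    # through group_range["orig_group"]["params"].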
    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        config: OptimizerConfig,
        grad_scaler: MegatronGradScaler,
        init_state_fn: Optional[Callable],
        model_chunks: List[MegatronModule],
        per_model_buffers: Dict[int, List[_ParamAndGradBuffer]],
        data_parallel_group: torch.distributed.ProcessGroup,
        data_parallel_group_gloo: torch.distributed.ProcessGroup,
        data_parallel_group_idx: int,
        distributed_optimizer_instance_id: int,
    ):
        """
        Distributed optimizer, for all data types (fp16, bf16, and fp32).

        The steps in this method create the core mapping between param and grad buffers,
        parameters, and parameter shard ranges, that is needed for converting between model
        param indexes and main parameter shard indexes. This method also updates the optimizer
        parameter groups with the newly created shards.

        Args:
            optimizer (torch.optim.Optimizer): base optimizer such as Adam or SGD.
            config (OptimizerConfig): configuration object for optimizer.
            grad_scaler (MegatronGradScaler): used for scaling gradients. Note that
                this can be None. This case happens when `bf16 = True` and we don't
                use any loss scale. Note that for `bf16 = True`, we can have
                a constant gradient scaler. Also for `bf16 = False`, we
                always require a grad scaler.
            init_state_fn (Callable, optional): function to initialize state in the optimizer.
            model_chunks (List[MegatronModule]): list of model chunks.
            per_model_buffers (Dict[int, List[_ParamAndGradBuffer]]): the implementation of the
                distributed optimizer is centered on using a contiguous buffer for
                communicating grads & params between the model state and the optimizer state.
                You can find a more detailed description in
                https://github.com/NVIDIA/Megatron-LM/blob/main/docs/source/distrib_optimizer.md.
            data_parallel_group (torch.distributed.ProcessGroup): data-parallel group to use to
                all-gather params after optimizer.step().
            data_parallel_group_gloo (torch.distributed.ProcessGroup): gloo data-parallel group
                (used in checkpoint loading and saving).
            data_parallel_group_idx (int): index in data-parallel group (used by
                distributed checkpointing logic).
            distributed_optimizer_instance_id (int): index of the Distributed Optimizer instance.
        """

        if has_config_logger_enabled(config):
            log_config_to_disk(config, locals(), prefix=type(self).__name__)

        super().__init__(optimizer, config, grad_scaler, init_state_fn)
        self.model_chunks = model_chunks
        self.ddp_config = self.model_chunks[0].ddp_config
        for model_chunk in self.model_chunks:
            assert self.ddp_config == model_chunk.ddp_config

        assert isinstance(
            optimizer, Adam
        ), "Only Adam currently supported, due to checkpointing requirements."

        # Model grad buffer ranges.
        assert per_model_buffers is not None, "per_model_buffers must be provided"
        self.buffers = list(itertools.chain(*per_model_buffers.values()))
        self.per_model_buffers = per_model_buffers
        self.data_parallel_group = data_parallel_group
        self.data_parallel_group_gloo = data_parallel_group_gloo
        self.data_parallel_group_idx = data_parallel_group_idx
        self.distributed_optimizer_instance_id = distributed_optimizer_instance_id

        self.gbuf_idx_to_model_idx_map = {}
        gbuf_idx = 0
        for model_idx, buffers in self.per_model_buffers.items():
            for _ in buffers:
                self.gbuf_idx_to_model_idx_map[gbuf_idx] = model_idx
                gbuf_idx += 1

        self.per_model_bucket_groups = {}
        for model_idx, buffers in self.per_model_buffers.items():
            self.per_model_bucket_groups[model_idx] = partition_buckets(buffers)

        self.gbuf_ranges = []
        self.per_bucket_numel = []
        self.per_bucket_numel_unpadded = []
        for buffer in self.buffers:
            self.per_bucket_numel.append(
                {
                    (buffer.param_dtype, buffer.grad_dtype): [
                        bucket.grad_data.numel() for bucket in buffer.buckets
                    ]
                }
            )
            self.per_bucket_numel_unpadded.append(
                {
                    (buffer.param_dtype, buffer.grad_dtype): [
                        bucket.numel_unpadded for bucket in buffer.buckets
                    ]
                }
            )
            self.gbuf_ranges.append(self._build_gbuf_range_map(buffer))
        self.model_param_gbuf_map = self._build_model_param_gbuf_map(self.gbuf_ranges)

        # Optimizer ranges.
        (self.model_param_group_index_map, self.opt_group_ranges) = (
            self._build_optimizer_group_ranges(self.optimizer.param_groups, self.gbuf_ranges)
        )

        # Allocate main param shards.
        (
            self.model_float16_groups,
            self.model_fp32_groups,
            self.shard_float16_groups,
            self.shard_fp32_groups,
            self.shard_fp32_from_float16_groups,
        ) = self._build_model_and_main_param_groups(
            self.gbuf_ranges, self.model_param_gbuf_map, self.opt_group_ranges
        )

        # Update optimizer groups.
        # - Also, leverage state_dict() and load_state_dict() to
        #   recast preexisting per-param state tensors.
        self.optimizer.param_groups = [g["orig_group"] for g in self.opt_group_ranges]
        self.optimizer.load_state_dict(self.optimizer.state_dict())
    def _get_model_param_range_map(self, param: torch.nn.Parameter):
        """
        Given a model param, get the index sub-range of the param that this
        data-parallel rank owns.
        """
        gbuf_index, dtype, bucket_index = self.model_param_gbuf_map[param]
        gbuf_range_map = self.gbuf_ranges[gbuf_index][dtype][bucket_index]
        param_range_map = gbuf_range_map["param_map"][param]
        return param_range_map
    def get_grad_stats_parallel_group(self) -> torch.distributed.ProcessGroup:
        """
        With the distributed optimizer, gradient statistics (num_zeros & norm) are reduced over
        all ranks (versus only the model-parallel ranks with the non-distributed optimizer).
        """
        return None
    def state_dict(self):
        """
        The state dict contains all non-DP-rank-dependent (i.e., non-parameter-
        related) optimizer variables. The returned state dict can be stored in
        the standard model/RNG checkpoint file. The parameter and dependent
        optimizer state (e.g., exp_avg, exp_avg_sq) are stored in a separate
        checkpoint file by calling 'save_parameter_state()'.
        """

        inner_state_dict = self.optimizer.state_dict()
        state_dict = {}

        # Extract 'step', for non-Apex/TE support.
        if not HAVE_APEX_OR_TE:
            steps = list(set([s["step"].item() for s in inner_state_dict["state"].values()]))
            assert len(steps) == 1
            step = steps[0]

        # Optimizer state (do not store parameter state here).
        state_dict['optimizer'] = {k: v for k, v in inner_state_dict.items() if k != "state"}
        for param_group in state_dict["optimizer"]["param_groups"]:
            del param_group["params"]
            if not HAVE_APEX_OR_TE:
                # Native PyTorch param group requires step (i.e., iteration).
                param_group["step"] = step

        # Grad scaler state.
        if self.grad_scaler:
            state_dict['grad_scaler'] = self.grad_scaler.state_dict()

        return state_dict
    def load_state_dict(self, state_dict):
        """Load the state dict.

        As detailed in state_dict(), the state dict contains all non-
        parameter-related variables. This method is notably longer than
        state_dict(), because the Torch optimizer's state has yet to be
        allocated at this point, and so we must do cross referencing between
        the optimizer's state (and the ordering it expects for parameter state)
        and this DP rank's shards. The optimizer at this point does not contain
        any tensor dimension information, so we must get these dimensions from
        the DP shards mapped during DistributedOptimizer.__init__().

        The tensor parameter state is loaded via load_parameter_state(), and
        so this method also must populate the loaded state dict with dummy
        tensor data (i.e., via torch.empty() below). This will be overwritten
        during load_parameter_state().

        ** Note: Torch optimizer's state structure. **
        The Torch optimizer stores its state in two levels. The top level is a
        list of groups, where each group contains a list of integer indexes
        (corresponding to parameters) that index into a master parameter list
        that is shared by all groups. As such, three values are necessary for
        maintaining this ordering:

        - group_index : The group to which a parameter belongs.
        - group_order : The index of a parameter within its group.
        - state_order : The index of a parameter within the shared parameter
          list.
        """

        # Get the Torch optimizer's state dict.
        # - This 'inner' optimizer at this point is unallocated, and only
        #   contains an integer ordering of parameters within each group, and
        #   the ordering of parameters within its flattened parameter state
        #   list.
        inner_state_dict = self.optimizer.state_dict()
        state_dict_param_groups = [
            {**group, "params": list(inner_state_dict["param_groups"][idx]["params"])}
            for idx, group in enumerate(state_dict["optimizer"]["param_groups"])
        ]

        # Allocate or retrieve optimizer state (i.e., tensors).
        if len(self.optimizer.state) == 0:
            # Allocate empty optimizer state if not previously initialized.
            # - If len(self.optimizer.state) == 0, this means that the optimizer
            #   state has not been previously initialized. Once it has been
            #   initialized, we skip this code block to avoid reallocating
            #   empty tensors (i.e., torch.empty), which in turn reduces memory
            #   fragmentation.
            # - Real data is overwritten during load_parameter_state().
            state_dict_state = []
            for gbuf_range_maps in self.gbuf_ranges:
                for gbuf_range_map_for_all_buckets in gbuf_range_maps.values():
                    for gbuf_range_map in gbuf_range_map_for_all_buckets:
                        for model_param, param_range_map in gbuf_range_map["param_map"].items():

                            # Get parameter ordering information (see method docstring
                            # for details).
                            group_index, group_order = self.model_param_group_index_map[
                                model_param
                            ]
                            state_order = inner_state_dict["param_groups"][group_index]["params"][
                                group_order
                            ]

                            # Allocate dummy tensors.
                            numel = len(param_range_map["gbuf_world"])
                            init_shard = lambda: torch.empty(
                                (numel,), dtype=torch.float32, device=torch.cuda.current_device()
                            )

                            state_dict_state.append(
                                (state_order, {"exp_avg": init_shard(), "exp_avg_sq": init_shard()})
                            )

            # Sort by state order (see method docstring for details).
            state_dict_state.sort(key=lambda s: s[0])
            state_dict_state = {s[0]: s[1] for s in state_dict_state}

        else:
            # Retrieve existing optimizer state.
            state_dict_state = inner_state_dict["state"]

        # Extract 'step', for non-Apex/TE support.
        if not HAVE_APEX_OR_TE:
            steps = list(set([g["step"] for g in state_dict["optimizer"]["param_groups"]]))
            assert len(steps) == 1
            step = torch.tensor(steps[0], dtype=torch.float)

            for s in state_dict_state.values():
                # Native PyTorch state dict requires step (i.e., iteration).
                s["step"] = step

        # Optimizer.
        self.optimizer.load_state_dict(
            {"state": state_dict_state, "param_groups": state_dict_param_groups}
        )

        # Grad scaler.
        if 'grad_scaler' not in state_dict:
            if self.config.fp16:
                logger.info(
                    '***WARNING*** found an old checkpoint, will not load grad scaler ...'
                )
        else:
            if self.grad_scaler:
                self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
            else:
                logger.info(
                    '***WARNING*** found the grad scaler in the '
                    'checkpoint but it is None in the class. '
                    'Skipping loading grad scaler ...'
                )

        if 'param_state' in state_dict:
            assert 'param_state_sharding_type' in state_dict, state_dict.keys()
            param_state = state_dict['param_state']
            sharding_type = state_dict['param_state_sharding_type']
            logger.info(f'Loading distributed optimizer sharded state of type {sharding_type}')
            if sharding_type == 'dp_zero_gather_scatter':
                self.load_parameter_state_from_dp_zero(param_state)
            elif sharding_type == 'fully_sharded_bucket_space':
                self.load_parameter_state_from_fs_bucket_space(param_state)
            elif sharding_type == 'fully_sharded_model_space':
                self.load_parameter_state_from_fs_model_space(param_state)
            else:
                raise NotImplementedError(f'Unknown sharding_type: {sharding_type}')
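    # Illustrative sketch of the inner Torch optimizer state-dict layout referenced in the
    # docstring above (toy values; two groups sharing one parameter list of length 3):
    #   inner_state_dict = {
    #       "param_groups": [{"params": [0, 2], ...}, {"params": [1], ...}],
    #       "state": {0: {...}, 1: {...}, 2: {...}},
    #   }
    # group_index selects a group, group_order indexes into that group's "params" list, and the
    # integer found there is the state_order used to key into "state".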
    def get_parameter_state_fs_bucket_space(self):
        """Get internal representation of parameter state without any copies and modifications.

        This is referred to as "fully sharded bucket space" because the optimizer state is
        fully sharded (e.g. no gather involved) and bucket-centric (the state
        follows the internal structure of the Distributed Optimizer buckets)
        as opposed to model-centric (typical structure of PyT optimizers).
        """
        state = {
            "per_bucket_numel": self.per_bucket_numel,
            "per_bucket_numel_unpadded": self.per_bucket_numel_unpadded,
        }
        for gbuf_idx, gbuf_range_maps in enumerate(self.gbuf_ranges):

            # Iterate grad buffers (by data type).
            dtype_state = {}
            assert len(gbuf_range_maps) == 1, "single dtype supported, for now."
            for dtype, gbuf_range_map_for_all_buckets in gbuf_range_maps.items():
                buckets_state = []
                for bucket_idx, gbuf_range_map in enumerate(gbuf_range_map_for_all_buckets):
                    bucket_state = []
                    for model_param, param_range_map in gbuf_range_map["param_map"].items():

                        # Main param & optimizer states.
                        group_index, group_order = self.model_param_group_index_map[model_param]
                        main_param = self.optimizer.param_groups[group_index]["params"][group_order]
                        optim_state = self.optimizer.state[main_param]

                        tensors = {
                            "param": main_param,
                            **optim_state,
                            "gbuf_local_start": param_range_map["gbuf_local"].start,
                            "gbuf_local_end": param_range_map["gbuf_local"].end,
                        }
                        bucket_state.append(tensors)
                    buckets_state.append(bucket_state)
                dtype_state[dtype] = buckets_state
            state[gbuf_idx] = dtype_state
        return state
    def get_parameter_state_dp_zero(self):
        """Get parameter state (i.e., parameter & optimizer tensors).

        This method performs two steps:
        - For each DP rank, copy param & optimizer shards to contiguous CPU
          buffers (e.g., one buffer each for main_param, exp_avg, and
          exp_avg_sq).
        - Gather contiguous buffers on DP rank 0 and concatenate to world
          buffers.
        """

        # Data parallelism variables.
        data_parallel_world_size = self.data_parallel_group_gloo.size()
        data_parallel_rank = torch.distributed.get_rank(self.data_parallel_group_gloo)
        data_parallel_group_gloo = self.data_parallel_group_gloo
        data_parallel_global_ranks = torch.distributed.get_process_group_ranks(
            self.data_parallel_group_gloo
        )

        # Collect param states.
        state = {"buckets_coalesced": True}
        for gbuf_idx, gbuf_range_maps in enumerate(self.gbuf_ranges):

            # Iterate grad buffers (by data type).
            dtype_state = {}
            assert len(gbuf_range_maps) == 1, "single dtype supported, for now."
            for dtype, gbuf_range_map_for_all_buckets in gbuf_range_maps.items():
                buffer_numel_unpadded = self.buffers[gbuf_idx].numel_unpadded
                # Create coalesced tensors for all state related to parameters in this buffer.
                world_tensors = {}
                if data_parallel_rank == 0:
                    world_tensors = {
                        key: torch.zeros(
                            (buffer_numel_unpadded,), dtype=torch.float32, device="cpu"
                        )
                        for key in ("param", "exp_avg", "exp_avg_sq")
                    }
                    world_tensors["numel_unpadded"] = buffer_numel_unpadded
                offset_in_world_tensors = 0
                for bucket_idx, gbuf_range_map in enumerate(gbuf_range_map_for_all_buckets):

                    # Compute local DP contiguous shard's size.
                    gbuf_world_numel = self.buffers[gbuf_idx].buckets[bucket_idx].grad_data.numel()
                    assert gbuf_world_numel % data_parallel_world_size == 0
                    gbuf_local_numel = gbuf_world_numel // data_parallel_world_size

                    gbuf_world_numel_unpadded = (
                        self.buffers[gbuf_idx].buckets[bucket_idx].numel_unpadded
                    )
                    assert gbuf_world_numel_unpadded <= gbuf_world_numel

                    local_shards = {
                        key: torch.zeros((gbuf_local_numel,), dtype=torch.float32, device="cpu")
                        for key in ("param", "exp_avg", "exp_avg_sq")
                    }

                    # Build contiguous DP rank shards (for param + optim states).
                    for model_param, param_range_map in gbuf_range_map["param_map"].items():

                        # Main param & optimizer states.
                        group_index, group_order = self.model_param_group_index_map[model_param]
                        main_param = self.optimizer.param_groups[group_index]["params"][group_order]
                        optim_state = self.optimizer.state[main_param]

                        tensors = {"param": main_param, **optim_state}

                        # Copy states into contiguous shard.
                        gbuf_local_start = param_range_map["gbuf_local"].start
                        gbuf_local_end = param_range_map["gbuf_local"].end
                        for key in local_shards:
                            local_shards[key][gbuf_local_start:gbuf_local_end].data.copy_(
                                tensors[key].detach().cpu()
                            )

                    # Gather contiguous shards on DP rank 0.
                    for key, send_tensor in local_shards.items():

                        # Gather tensor list.
                        if data_parallel_rank == 0:
                            recv_tensors = [
                                torch.zeros((gbuf_local_numel,), dtype=torch.float32, device="cpu")
                                for _ in range(data_parallel_world_size)
                            ]
                        else:
                            recv_tensors = None

                        # Gather.
                        torch.distributed.gather(
                            send_tensor,
                            recv_tensors,
                            data_parallel_global_ranks[0],
                            data_parallel_group_gloo,
                        )

                        # Concatenate.
                        if data_parallel_rank == 0:
                            recv_tensors_concatenated = torch.cat(recv_tensors)
                            # Copy this bucket's collected all-gather tensors into the right place
                            # in the tensor for the buffer. The tensor for the buffer gets rid of
                            # the padding between buckets.
                            start = offset_in_world_tensors
                            end = offset_in_world_tensors + gbuf_world_numel_unpadded
                            world_tensors[key][start:end].copy_(
                                recv_tensors_concatenated[:gbuf_world_numel_unpadded]
                            )

                    offset_in_world_tensors += gbuf_world_numel_unpadded

                # Collect world state.
                dtype_state[dtype] = world_tensors
            state[gbuf_idx] = dtype_state

        return state
    def save_parameter_state(self, filename: str):
        """Save the distributed parameter state on DP rank 0.

        Args:
            filename (str): path to save parameter state to.
        """
        state_dict = self.get_parameter_state_dp_zero()
        if torch.distributed.get_rank(self.data_parallel_group) == 0:
            torch.save(state_dict, filename)
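    # Minimal usage sketch (hypothetical path; the surrounding training loop is not part of this
    # file): the dense optimizer variables travel in the regular checkpoint via state_dict(),
    # while the sharded parameter state is written and read separately:
    #   optimizer.save_parameter_state('/ckpts/iter_0001000/distrib_optim.pt')
    #   ...
    #   optimizer.load_parameter_state('/ckpts/iter_0001000/distrib_optim.pt')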
    def sharded_state_dict(
        self,
        model_sharded_state_dict: ShardedStateDict,
        is_loading: bool = False,
        sharding_type: str = 'fully_sharded_model_space',
    ):
        """
        Chooses between 3 param state sharding implementations as requested by `sharding_type`.

        Regular state dict parameters are saved on DP rank 0 and loaded on all ranks.
        """
        if not is_loading and sharding_type == 'fully_sharded_bucket_space':
            logger.warning(
                '`fully_sharded_bucket_space` sharding for DistributedOptimizer'
                ' checkpoint is deprecated and will be removed in the future.'
                ' Please switch to `full_sharded_model_space`.'
            )

        state_dict = self.state_dict()
        if sharding_type != 'fully_sharded_model_space':
            # State dict differs between different model parallel groups
            state_dict = {
                k: ShardedObject(
                    f'optimizer.distributed.dp_group_idx_{self.data_parallel_group_idx}.{k}',
                    v,
                    (1,),
                    (0,),
                    replica_id=torch.distributed.get_rank(self.data_parallel_group),
                )
                for k, v in state_dict.items()
            }

        if is_loading:
            # Call the distributed optimizer's specialized load_state_dict(),
            # which conditionally skips re-allocating the optimizer's state if
            # already initialized, which in turn reduces memory fragmentation.
            self.load_state_dict(self.state_dict())

        if sharding_type == 'fully_sharded_bucket_space':
            param_state = self.sharded_param_state_fs_bucket_space(
                model_sharded_state_dict, is_loading
            )
        elif sharding_type == 'dp_zero_gather_scatter':
            param_state = self.sharded_param_state_dp_zero(model_sharded_state_dict, is_loading)
        elif sharding_type == 'fully_sharded_model_space':
            param_state = self.sharded_param_state_fs_model_space(
                model_sharded_state_dict, is_loading
            )
        else:
            raise NotImplementedError(f'Unknown sharding_type: {sharding_type}')

        state_dict['param_state'] = param_state
        state_dict['param_state_sharding_type'] = sharding_type
        return state_dict
    def sharded_param_state_dp_zero(
        self, model_sharded_state_dict: ShardedStateDict, is_loading: bool = False
    ):
        """Naive implementation which reuses gather/scatter from the legacy ckpt format.

        During saving, gathers the parameters state on DP rank 0 and saves a ShardedObject
        with fixed TPxPP structure. During loading, loads the saved data on DP rank 0
        (None on other ranks). Relies on the parameters scatter done in load_state_dict.
        """
        if is_loading:
            param_state_data = None
        else:
            if self.distributed_optimizer_instance_id == 0:
                # Gather on rank 0
                param_state_data = self.get_parameter_state_dp_zero()

        if (
            torch.distributed.get_rank(self.data_parallel_group) == 0
            and self.distributed_optimizer_instance_id == 0
        ):
            # Fixed TPxPP. Save on DP rank 0 only
            param_state = ShardedObject(
                f'optimizer.distributed.dp_group_idx_{self.data_parallel_group_idx}.param_state',
                param_state_data,
                (1,),
                (0,),
            )
        else:
            # DP ranks > 0 don't save. During loading, the param_state needs to be None.
            param_state = LocalNonpersistentObject(None)

        return param_state
    def sharded_param_state_fs_bucket_space(
        self, model_sharded_state_dict: ShardedStateDict, is_loading: bool = False
    ):
        """Sharded state dict where each noncontiguous buffer is a separate ShardedTensor.

        Results in fully parallel save and load without any inter-process
        communication or intermediate buffers/copies.
        """
        data_parallel_rank = torch.distributed.get_rank(self.data_parallel_group)
        data_parallel_world_size = torch.distributed.get_world_size(self.data_parallel_group)

        state = self.get_parameter_state_fs_bucket_space()
        # per_bucket_numel metadata is saved separately for each TPxPP domain.
        for per_bucket_key in ('per_bucket_numel', 'per_bucket_numel_unpadded'):
            key = (
                f'optimizer.distributed.dp_group_idx_{self.data_parallel_group_idx}'
                f'.{per_bucket_key}'
            )
            state[per_bucket_key] = ShardedObject(
                key, state[per_bucket_key], (1,), (0,), replica_id=data_parallel_rank
            )

        for gbuf_idx, gbuf_range_maps in enumerate(self.gbuf_ranges):
            for dtype, gbuf_range_map_for_all_buckets in state[gbuf_idx].items():
                for bucket_idx, bucket_state in enumerate(gbuf_range_map_for_all_buckets):
                    # Compute local DP contiguous shard's size.
                    gbuf_world_numel = self.buffers[gbuf_idx].buckets[bucket_idx].grad_data.numel()
                    assert gbuf_world_numel % data_parallel_world_size == 0
                    gbuf_local_numel = gbuf_world_numel // data_parallel_world_size

                    sharded_bucket_key = (
                        f'optimizer.distributed.dp_group_idx_{self.data_parallel_group_idx}'
                        f'.gbuf_idx_{gbuf_idx}.dtype_{dtype}.bucket_idx_{bucket_idx}'
                    )

                    # The global ckpt tensors must be fully covered.
                    # We add extra empty padding if necessary
                    assert bucket_state, 'empty bucket encountered'

                    # Insert padding between parameter tensors to ensure full coverage as needed.
                    all_pad_tensors = {}
                    for i in range(len(bucket_state) - 1):
                        next_param_start = bucket_state[i + 1]['gbuf_local_start']
                        cur_param_end = bucket_state[i]['gbuf_local_end']
                        if next_param_start != cur_param_end:
                            pad_tensors = {
                                k: torch.empty(
                                    next_param_start - cur_param_end,
                                    dtype=v.dtype,
                                    device=v.device,
                                )
                                for k, v in bucket_state[i].items()
                                if isinstance(v, torch.Tensor)
                            }
                            all_pad_tensors[i + 1] = {
                                **pad_tensors,
                                'gbuf_local_start': cur_param_end,
                                'gbuf_local_end': next_param_start,
                                'padding': True,
                            }

                    # Insert from end so that insertion positions are still correct.
                    indices_to_insert = sorted(list(all_pad_tensors.keys()))
                    for index_to_insert in reversed(indices_to_insert):
                        bucket_state.insert(index_to_insert, all_pad_tensors[index_to_insert])

                    if bucket_state[-1]['gbuf_local_end'] != gbuf_local_numel:
                        pad_tensors = {
                            k: torch.empty(
                                gbuf_local_numel - bucket_state[-1]['gbuf_local_end'],
                                dtype=v.dtype,
                                device=v.device,
                            )
                            for k, v in bucket_state[-1].items()
                            if isinstance(v, torch.Tensor)
                        }
                        bucket_state.append(
                            {
                                **pad_tensors,
                                'gbuf_local_start': bucket_state[-1]['gbuf_local_end'],
                                'gbuf_local_end': gbuf_local_numel,
                                'padding': True,
                            }
                        )

                    # Each tensor is mapped to a slice (`flattened_range`)
                    # of a DP-local shard of size `gbuf_local_numel`.
                    for bucket_params_idx in range(len(bucket_state)):
                        tensors = bucket_state[bucket_params_idx]
                        gbuf_local_start = tensors.pop('gbuf_local_start')
                        gbuf_local_end = tensors.pop('gbuf_local_end')
                        if 'padding' not in tensors:
                            tensors['padding'] = False

                        for key in tensors:
                            if key == 'padding':
                                tensors[key] = LocalNonpersistentObject(tensors[key])
                                continue

                            assert tensors[key].shape == (gbuf_local_end - gbuf_local_start,), (
                                tensors[key].shape,
                                gbuf_local_start,
                                gbuf_local_end,
                            )

                            tensors[key] = ShardedTensor(
                                f'{sharded_bucket_key}.{key}',
                                tensors[key],
                                tensors[key].dtype,
                                (gbuf_local_numel,),
                                (data_parallel_world_size * gbuf_local_numel,),
                                (data_parallel_rank * gbuf_local_numel,),
                                axis_fragmentations=(data_parallel_world_size,),
                                flattened_range=slice(gbuf_local_start, gbuf_local_end),
                                allow_shape_mismatch=True,
                            )
        return state
    def sharded_param_state_fs_model_space(
        self, model_sharded_state_dict: ShardedStateDict, is_loading: bool = False
    ):
        """Sharded state dict where each buffer is mapped to corresponding model param.

        In this approach the optimizer state tensors are directly related to model parameters
        by linking them with metadata from `model_sharded_state_dict`.
        This will allow changing TP and PP while using DistOpt (as with other optimizers).
        """

        param_to_sharded_metadata = {}
        model_sharded_state_dict, _ = extract_sharded_tensors_and_factories(
            model_sharded_state_dict
        )
        for sh_base in nested_values(model_sharded_state_dict):
            param_to_sharded_metadata[sh_base.data] = sh_base

        prefix = 'optimizer.state'
        state = {}

        # Not stored in the checkpoint, used only to identify params in
        # `sharded_param_state_fs_model_space`.
        param_idx = 0
        for gbuf_range_maps in self.gbuf_ranges:
            for gbuf_range_map_for_all_buckets in gbuf_range_maps.values():
                for gbuf_range_map in gbuf_range_map_for_all_buckets:
                    for model_param, param_range_map in gbuf_range_map["param_map"].items():
                        group_index, group_order = self.model_param_group_index_map[model_param]
                        param_range = param_range_map['param']

                        main_param = self.optimizer.param_groups[group_index]["params"][group_order]
                        optim_state = self.optimizer.state[main_param]

                        tensors = {"fp32_param": main_param, **optim_state}

                        # Match optimizer parameter with model ShardedTensor (or
                        # ShardedTensorFactory).
                        try:
                            sharded_metadata = param_to_sharded_metadata[model_param]
                        except KeyError as e:
                            raise ValueError(
                                f'Model param {model_param} not in model_sharded_state_dict'
                            ) from e

                        # Set DP corresponding replica_id coordinate to 0.
                        assert (
                            len(sharded_metadata.replica_id) == 3
                        ), f'Expected replica_id format (PP, TP, DP), got: {sharded_metadata}'
                        replica_id = (
                            *sharded_metadata.replica_id[:2],
                            self.distributed_optimizer_instance_id,
                        )

                        # Instantiate ShardedTensor (or ShardedTensorFactory) for optimizer
                        # params.
                        for state_key, state_ten in tensors.items():
                            replace_kwargs = dict(
                                key=f'{prefix}.{state_key}.{sharded_metadata.key}',
                                data=state_ten,
                                dtype=state_ten.dtype,
                                flattened_range=slice(param_range.start, param_range.end),
                                replica_id=replica_id,
                            )
                            if isinstance(sharded_metadata, ShardedTensorFactory):
                                replace_kwargs.pop('dtype')
                            tensors[state_key] = replace(sharded_metadata, **replace_kwargs)
                            tensors[state_key].validate_metadata_integrity()
                        state[param_idx] = tensors
                        param_idx += 1
        return state
    def load_parameter_state_from_fs_bucket_space(self, state_dict):
        """Loads the parameter state from an internal representation.

        Inverse of the `get_parameter_state_fs_bucket_space` method.
        """
        logger.warning(
            '`fully_sharded_bucket_space` sharding for DistributedOptimizer'
            ' checkpoint is deprecated. Please switch to `full_sharded_model_space`'
        )

        if state_dict is not None and "per_bucket_numel_unpadded" in state_dict:
            per_bucket_numel_unpadded_in_checkpoint = state_dict["per_bucket_numel_unpadded"]
            assert self.per_bucket_numel_unpadded == per_bucket_numel_unpadded_in_checkpoint, (
                f"Number of unpadded elements in each bucket need to be the same in current run "
                f"({self.per_bucket_numel_unpadded}) and checkpoint "
                f"({per_bucket_numel_unpadded_in_checkpoint})"
            )

        for gbuf_idx, gbuf_range_maps in enumerate(self.gbuf_ranges):
            assert len(gbuf_range_maps) == 1, "single dtype supported, for now."
            for dtype, gbuf_range_map_for_all_buckets in gbuf_range_maps.items():
                for bucket_idx, gbuf_range_map in enumerate(gbuf_range_map_for_all_buckets):
                    bucket_state = state_dict[gbuf_idx][dtype][bucket_idx]

                    # Drop the padding entries that were inserted at save time.
                    bucket_state = [
                        bucket_state_elem
                        for bucket_state_elem in bucket_state
                        if not bucket_state_elem['padding']
                    ]

                    assert len(bucket_state) == len(gbuf_range_map["param_map"]), (
                        len(bucket_state),
                        len(gbuf_range_map["param_map"]),
                    )
                    for src_tensors, (model_param, param_range_map) in zip(
                        bucket_state, gbuf_range_map["param_map"].items()
                    ):
                        # Main param & optimizer states.
                        group_index, group_order = self.model_param_group_index_map[model_param]
                        main_param = self.optimizer.param_groups[group_index]["params"][group_order]
                        optim_state = self.optimizer.state[main_param]

                        dst_tensors = {"param": main_param, **optim_state}
                        for key in dst_tensors:
                            dst_tensors[key].copy_(src_tensors[key])
    @torch.no_grad()
    def load_parameter_state_from_fs_model_space(self, state_dict):
        """Loads the parameter state from a "model space" representation.

        Inverse of the `sharded_param_state_fs_model_space` method.
        """
        param_idx = 0  # matching order with `sharded_param_state_fs_model_space`
        for gbuf_range_maps in self.gbuf_ranges:
            for gbuf_range_map_for_all_buckets in gbuf_range_maps.values():
                for gbuf_range_map in gbuf_range_map_for_all_buckets:
                    for model_param, param_range_map in gbuf_range_map["param_map"].items():
                        group_index, group_order = self.model_param_group_index_map[model_param]
                        main_param = self.optimizer.param_groups[group_index]["params"][group_order]
                        optim_state = self.optimizer.state[main_param]

                        src_tensors = state_dict[param_idx]
                        dst_tensors = {"fp32_param": main_param, **optim_state}
                        for key in dst_tensors:
                            dst_tensors[key].copy_(src_tensors[key])

                        param_idx += 1
    @classmethod
    def _update_legacy_world_tensors(cls, old_tensors, new_numels):
        '''Reshard buckets (where each bucket is a tensor) to new target
        numels, where the total numel remains the same.'''

        old_total = sum([t.numel() for t in old_tensors])
        new_total = sum(new_numels)
        assert old_total == new_total

        unified_tensor = torch.cat(old_tensors, dim=0)

        new_tensors = []
        start_idx = 0
        for new_numel in new_numels:
            new_tensors.append(unified_tensor[start_idx : (start_idx + new_numel)])
            start_idx += new_numel

        return new_tensors
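    # Toy example of the resharding above: old_tensors with numels [6, 10] and new_numels [12, 4]
    # are concatenated into one 16-element tensor and re-split as [0:12] and [12:16]; the total
    # element count is unchanged, only the per-bucket boundaries move.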
    def load_parameter_state_from_dp_zero_legacy(self, state_dict):
        """Load parameter state (i.e., parameter & optimizer tensors) from DP 0 rank,
        using the legacy checkpoint format as described below.

        The difference between this method and `load_parameter_state_from_dp_zero_modern()`
        is that this method is used for updating the format of checkpoints that
        were saved using code from before Feb 13, 2024. Starting on this date, a
        new format was used (i.e., different format for the parameter mapping and
        bucket sharding).

        Use arg `--ckpt-convert-update-legacy-dist-opt-format` to call this
        method, along with `--ckpt-convert-format` and `--ckpt-convert-save` to
        update a legacy-format checkpoint to the modern format.
        """

        # Data parallelism variables.
        data_parallel_world_size = self.data_parallel_group_gloo.size()
        data_parallel_rank = torch.distributed.get_rank(self.data_parallel_group_gloo)
        data_parallel_group_gloo = self.data_parallel_group_gloo
        data_parallel_global_ranks = torch.distributed.get_process_group_ranks(
            self.data_parallel_group_gloo
        )

        # Scatter tensors to all DP ranks.
        for gbuf_idx, gbuf_range_maps in enumerate(self.gbuf_ranges):
            for dtype, gbuf_range_map_for_all_buckets in gbuf_range_maps.items():
                if data_parallel_rank == 0:
                    buffer_numel_unpadded = self.buffers[gbuf_idx].numel_unpadded
                    model_numels = [b.numel_unpadded for b in self.buffers[gbuf_idx].buckets]
                    checkpoint_numels = [
                        t.numel() for t in state_dict[gbuf_idx][torch.float32]["param"]
                    ]
                    assert sum(model_numels) == sum(checkpoint_numels)
                for key in ("param", "exp_avg", "exp_avg_sq"):
                    legacy_world_tensors = self._update_legacy_world_tensors(
                        state_dict[gbuf_idx][torch.float32][key],
                        [
                            self.buffers[gbuf_idx].buckets[bi].numel_unpadded
                            for bi in range(len(gbuf_range_map_for_all_buckets))
                        ],
                    )
                    offset_in_world_tensors = 0
                    for bucket_idx, gbuf_range_map in enumerate(gbuf_range_map_for_all_buckets):
                        # Compute local DP contiguous shard's size.
                        gbuf_world_numel = (
                            self.buffers[gbuf_idx].buckets[bucket_idx].grad_data.numel()
                        )
                        assert gbuf_world_numel % data_parallel_world_size == 0
                        gbuf_local_numel = gbuf_world_numel // data_parallel_world_size

                        gbuf_world_numel_unpadded = (
                            self.buffers[gbuf_idx].buckets[bucket_idx].numel_unpadded
                        )
                        assert gbuf_world_numel_unpadded <= gbuf_world_numel

                        # Contiguous local shards (received from DP rank 0).
                        recv_tensor = torch.empty(
                            (gbuf_local_numel,), dtype=torch.float32, device="cpu"
                        )

                        # Scatter tensor list.
                        if data_parallel_rank == 0:
                            start = offset_in_world_tensors
                            end = offset_in_world_tensors + gbuf_world_numel_unpadded

                            world_tensor = legacy_world_tensors[bucket_idx]
                            assert (
                                world_tensor.numel() == gbuf_world_numel_unpadded
                            ), "%d vs. %d." % (world_tensor.numel(), gbuf_world_numel_unpadded)
                            offset_in_world_tensors += gbuf_world_numel_unpadded

                            # Pad world_tensor to gbuf_world_numel. Don't pad at the front,
                            # pad at the back.
                            world_tensor = torch.nn.functional.pad(
                                world_tensor, (0, gbuf_world_numel - gbuf_world_numel_unpadded)
                            )
                            assert world_tensor.numel() == gbuf_world_numel
                            gbuf_start_idxs = list(range(0, gbuf_world_numel, gbuf_local_numel))
                            send_tensors = [
                                world_tensor[i : (i + gbuf_local_numel)] for i in gbuf_start_idxs
                            ]
                        else:
                            send_tensors = None

                        # Scatter.
                        torch.distributed.scatter(
                            recv_tensor,
                            send_tensors,
                            data_parallel_global_ranks[0],
                            data_parallel_group_gloo,
                        )

                        # Copy local contiguous shards to param/optim shards.
                        for model_param, param_range_map in gbuf_range_map["param_map"].items():

                            # Main param & optimizer states.
                            group_index, group_order = self.model_param_group_index_map[model_param]
                            main_param = self.optimizer.param_groups[group_index]["params"][
                                group_order
                            ]
                            if key == "param":
                                tensor_to_copy_into = main_param
                            else:
                                optim_state = self.optimizer.state[main_param]
                                tensor_to_copy_into = optim_state[key]

                            # Copy states into contiguous shard.
                            gbuf_local_start = param_range_map["gbuf_local"].start
                            gbuf_local_end = param_range_map["gbuf_local"].end
                            tensor_to_copy_into.data.copy_(
                                recv_tensor[gbuf_local_start:gbuf_local_end]
                            )
    def load_parameter_state_from_dp_zero(self, state_dict, *, update_legacy_format=False):
        """Load parameter state (i.e., parameter & optimizer tensors) from DP 0 rank,
        using the new checkpoint format with coalesced state across buckets.

        This method performs the reverse of get_parameter_state_dp_zero():
        - Scatter contiguous buffers from DP rank 0 to each DP rank (each DP
          rank receives its relevant subset of the world buffers).
        - For each DP rank, copy param & optimizer shards from contiguous CPU
          buffers (e.g., one buffer each for main_param, exp_avg, and
          exp_avg_sq).
        """

        # Selectively load from a legacy checkpoint. The legacy format was used
        # prior to Feb 13, 2024.
        if update_legacy_format:
            return self.load_parameter_state_from_dp_zero_legacy(state_dict)

        # Data parallelism variables.
        data_parallel_world_size = self.data_parallel_group_gloo.size()
        data_parallel_rank = torch.distributed.get_rank(self.data_parallel_group_gloo)
        data_parallel_group_gloo = self.data_parallel_group_gloo
        data_parallel_global_ranks = torch.distributed.get_process_group_ranks(
            self.data_parallel_group_gloo
        )

        if data_parallel_rank == 0:
            # Do nothing if "--fp8-param-gather" is not used.
            self.split_state_dict_if_needed(state_dict)

        # Scatter tensors to all DP ranks.
        for gbuf_idx, gbuf_range_maps in enumerate(self.gbuf_ranges):
            for dtype, gbuf_range_map_for_all_buckets in gbuf_range_maps.items():
                if data_parallel_rank == 0:
                    buffer_numel_unpadded = self.buffers[gbuf_idx].numel_unpadded
                    checkpoint_numel_unpadded = state_dict[gbuf_idx][dtype]["numel_unpadded"]
                    assert buffer_numel_unpadded == checkpoint_numel_unpadded, (
                        f"Number of unpadded elements must be same in current run "
                        f"({buffer_numel_unpadded}) and checkpoint ({checkpoint_numel_unpadded})"
                    )
                for key in ("param", "exp_avg", "exp_avg_sq"):
                    offset_in_world_tensors = 0
                    for bucket_idx, gbuf_range_map in enumerate(gbuf_range_map_for_all_buckets):
                        # Compute local DP contiguous shard's size.
                        gbuf_world_numel = (
                            self.buffers[gbuf_idx].buckets[bucket_idx].grad_data.numel()
                        )
                        assert gbuf_world_numel % data_parallel_world_size == 0
                        gbuf_local_numel = gbuf_world_numel // data_parallel_world_size

                        gbuf_world_numel_unpadded = (
                            self.buffers[gbuf_idx].buckets[bucket_idx].numel_unpadded
                        )
                        assert gbuf_world_numel_unpadded <= gbuf_world_numel

                        # Contiguous local shards (received from DP rank 0).
                        recv_tensor = torch.zeros(
                            (gbuf_local_numel,), dtype=torch.float32, device="cpu"
                        )

                        # Scatter tensor list.
                        if data_parallel_rank == 0:
                            world_tensors = state_dict[gbuf_idx][dtype][key]

                            start = offset_in_world_tensors
                            end = offset_in_world_tensors + gbuf_world_numel_unpadded
                            assert 0 <= start < end <= world_tensors.numel()
                            world_tensor = world_tensors[start:end]
                            offset_in_world_tensors += gbuf_world_numel_unpadded

                            # Pad world_tensor to gbuf_world_numel. Don't pad at the front,
                            # pad at the back.
                            world_tensor = torch.nn.functional.pad(
                                world_tensor, (0, gbuf_world_numel - gbuf_world_numel_unpadded)
                            )
                            assert world_tensor.numel() == gbuf_world_numel
                            gbuf_start_idxs = list(range(0, gbuf_world_numel, gbuf_local_numel))
                            send_tensors = [
                                world_tensor[i : (i + gbuf_local_numel)] for i in gbuf_start_idxs
                            ]
                        else:
                            send_tensors = None

                        # Scatter.
                        torch.distributed.scatter(
                            recv_tensor,
                            send_tensors,
                            data_parallel_global_ranks[0],
                            data_parallel_group_gloo,
                        )

                        # Copy local contiguous shards to param/optim shards.
                        for model_param, param_range_map in gbuf_range_map["param_map"].items():

                            # Main param & optimizer states.
                            group_index, group_order = self.model_param_group_index_map[model_param]
                            main_param = self.optimizer.param_groups[group_index]["params"][
                                group_order
                            ]
                            if key == "param":
                                tensor_to_copy_into = main_param
                            else:
                                optim_state = self.optimizer.state[main_param]
                                tensor_to_copy_into = optim_state[key]

                            # Copy states into contiguous shard.
                            gbuf_local_start = param_range_map["gbuf_local"].start
                            gbuf_local_end = param_range_map["gbuf_local"].end
                            tensor_to_copy_into.data.copy_(
                                recv_tensor[gbuf_local_start:gbuf_local_end]
                            )
    def split_state_dict_if_needed(self, state_dict):
        """
        When "--fp8-param-gather" is disabled, weights and biases are stored in the same
        `_ParamAndGradBuffer`. So, when saving a checkpoint, the optimizer's main parameters are
        saved in a single contiguous tensor (this also applies to "exp_avg" and "exp_avg_sq").

        However, when "--fp8-param-gather" is enabled, weights (in fp8 dtype) and biases (in
        bf16/fp16 dtype) are stored in separate `_ParamAndGradBuffer`s. Therefore, when
        "--fp8-param-gather" is enabled and we want to load a checkpoint saved without
        "--fp8-param-gather", we need to split the weights (fp8) and biases (bf16/fp16) in the
        state_dict into two separate tensors.
        """
        # Skip if there are no fp8 buffers.
        fp8_gbuf_indices = []
        for gbuf_idx, gbuf_range_maps in enumerate(self.gbuf_ranges):
            for dtype, _ in gbuf_range_maps.items():
                if is_float8tensor(self.buffers[gbuf_idx].params[0]):
                    fp8_gbuf_indices.append(gbuf_idx)
        if len(fp8_gbuf_indices) == 0:
            return

        dtype_to_gbuf_idx = {}
        for key in state_dict.keys():
            if key != 'buckets_coalesced':
                for dtype in state_dict[key].keys():
                    assert dtype not in dtype_to_gbuf_idx
                    if dtype[0] == torch.uint8:
                        # If the `state_dict` already contains a torch.uint8 buffer, we assume
                        # that the fp8 weights and fp16/bf16 biases in the checkpoint are already
                        # separated. In this case, no action is required, so we can return directly.
                        return
                    dtype_to_gbuf_idx[dtype] = key

        # 1. Replace the gbuf_idx in the checkpoint with the new gbuf_idx.
        # 2. Copy the non-tensor data (i.e., the "buckets_coalesced") to `new_state_dict`.
        new_state_dict = {'buckets_coalesced': state_dict['buckets_coalesced']}
        for gbuf_idx, gbuf_range_maps in enumerate(self.gbuf_ranges):
            for dtype, _ in gbuf_range_maps.items():
                if not is_float8tensor(self.buffers[gbuf_idx].params[0]):
                    new_state_dict[gbuf_idx] = state_dict[dtype_to_gbuf_idx[dtype]]

        for fp8_gbuf_idx in fp8_gbuf_indices:
            # Note that `self.buffers[fp8_gbuf_idx].params[0].dtype` is the dummy dtype of
            # `Float8Tensor`, not torch.uint8.
            non_fp8_param_and_grad_dtype = (
                self.buffers[fp8_gbuf_idx].params[0].dtype,
                self.buffers[fp8_gbuf_idx].grad_dtype,
            )

            # Iterate through all buffers to find the one that needs to be split.
            non_fp8_gbuf_idx = None
            for gbuf_idx, gbuf_range_maps in enumerate(self.gbuf_ranges):
                for dtype, _ in gbuf_range_maps.items():
                    if dtype == non_fp8_param_and_grad_dtype:
                        non_fp8_gbuf_idx = gbuf_idx
            assert non_fp8_gbuf_idx is not None

            # We need the fp8_flags to determine the order of weight (fp8) and bias (fp16/bf16) in
            # the buffer.
            index_to_fp8_map = {}
            for index in self.buffers[fp8_gbuf_idx].param_indices:
                assert index not in index_to_fp8_map
                index_to_fp8_map[index] = True
            for index in self.buffers[non_fp8_gbuf_idx].param_indices:
                assert index not in index_to_fp8_map
                index_to_fp8_map[index] = False
            param_indices = (
                self.buffers[fp8_gbuf_idx].param_indices
                + self.buffers[non_fp8_gbuf_idx].param_indices
            )
            assert min(param_indices) == 0
            assert max(param_indices) == len(param_indices) - 1
            fp8_flags = []
            for i in range(len(param_indices)):
                fp8_flags.append(index_to_fp8_map[i])

            fp8_buffer = self.buffers[fp8_gbuf_idx]
            non_fp8_buffer = self.buffers[non_fp8_gbuf_idx]

            fp8_idx = len(fp8_buffer.params) - 1
            non_fp8_idx = len(non_fp8_buffer.params) - 1
            offsets, fp8_offsets, non_fp8_offsets = [0], [0], [0]

            # Because the parameters in `_ParamAndGradBuffer` are traversed in reverse order, the
            # flag here also needs to be traversed in reverse order.
            for fp8_flag in fp8_flags[::-1]:
                if fp8_flag:
                    numel = fp8_buffer.params[fp8_idx].nelement()
                    fp8_idx -= 1
                    offsets.append(offsets[-1] + numel)
                    fp8_offsets.append(fp8_offsets[-1] + numel)
                else:
                    numel = non_fp8_buffer.params[non_fp8_idx].nelement()
                    non_fp8_idx -= 1
                    offsets.append(offsets[-1] + numel)
                    non_fp8_offsets.append(non_fp8_offsets[-1] + numel)

            # Split the target buffer into two separate buffers.
            fp8_state_dict, non_fp8_state_dict = {}, {}
            for key in ['param', 'exp_avg', 'exp_avg_sq']:
                tensor = state_dict[non_fp8_gbuf_idx][non_fp8_param_and_grad_dtype][key]
                fp8_tensor = torch.empty([fp8_offsets[-1]], dtype=tensor.dtype)
                non_fp8_tensor = torch.empty([non_fp8_offsets[-1]], dtype=tensor.dtype)
                fp8_idx, non_fp8_idx = 0, 0
                for i in range(len(offsets) - 1):
                    if fp8_flags[-(i + 1)]:
                        fp8_tensor[fp8_offsets[fp8_idx] : fp8_offsets[fp8_idx + 1]].copy_(
                            tensor[offsets[i] : offsets[i + 1]]
                        )
                        fp8_idx += 1
                    else:
                        non_fp8_tensor[
                            non_fp8_offsets[non_fp8_idx] : non_fp8_offsets[non_fp8_idx + 1]
                        ].copy_(tensor[offsets[i] : offsets[i + 1]])
                        non_fp8_idx += 1
                fp8_state_dict[key] = fp8_tensor
                non_fp8_state_dict[key] = non_fp8_tensor

            fp8_state_dict['numel_unpadded'] = fp8_offsets[-1]
            non_fp8_state_dict['numel_unpadded'] = non_fp8_offsets[-1]

            # Add the two separate buffers into `new_state_dict`.
            new_state_dict[fp8_gbuf_idx] = {}
            new_state_dict[fp8_gbuf_idx][(torch.uint8, fp8_buffer.grad_dtype)] = fp8_state_dict
            new_state_dict[non_fp8_gbuf_idx][non_fp8_param_and_grad_dtype] = non_fp8_state_dict

        # Inplace update state_dict
        state_dict.clear()
        for key, value in new_state_dict.items():
            state_dict[key] = value
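    # Illustrative layout (toy shapes, not taken from the code above): a checkpoint saved without
    # "--fp8-param-gather" may hold one coalesced bf16-buffer state [w0, b0, w1, b1], while the
    # current run keeps fp8 weights and bf16/fp16 biases in separate buffers. The method above
    # rebuilds per-buffer tensors [w0, w1] (keyed by (torch.uint8, grad_dtype)) and [b0, b1] so
    # that the loaded state lines up with this run's buffer layout.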
    def load_parameter_state(self, filename: str, *, update_legacy_format=False):
        """Load the distributed parameter state from disk.

        Args:
            filename (str): path to load parameter state from.
        """
        state_dict = None
        if torch.distributed.get_rank(self.data_parallel_group) == 0:
            state_dict = torch.load(filename)

        self.load_parameter_state_from_dp_zero(
            state_dict, update_legacy_format=update_legacy_format
        )
    def zero_grad(self, set_to_none: bool = True):
        """
        Zeroes grads for the model related parameters, i.e., model_float16_groups
        and model_fp32_groups. We additionally zero the remaining groups as a
        memory optimization to reduce fragmentation; in the case of
        set_to_none==True, the space used by this field can be safely deallocated.

        Args:
            set_to_none (bool): if true, set grads to None.
        """
        for groups in (
            self.model_float16_groups,
            self.model_fp32_groups,
            self.shard_float16_groups,  # grad empty/unused here?
            self.shard_fp32_groups,  # throws grad-access warning
            self.shard_fp32_from_float16_groups,
        ):
            for group in groups:
                _zero_grad_group_helper(group, set_to_none)
    def _collect_main_grad_data_for_unscaling(self):
        """
        Note: this should be equivalent to the float-16 optimizer's method,
        but written differently, so the two should be combined.
        """
        return [
            param.grad.data
            for group in self.optimizer.param_groups
            for param in group["params"]
        ]
    def _get_model_and_main_params_data_float16(self):
        """
        Get aligned list of model and main params.
        """
        model_data = []
        main_data = []
        for model_group, main_group in zip(
            self.shard_float16_groups, self.shard_fp32_from_float16_groups
        ):
            for model_param, main_param in zip(model_group, main_group):
                model_data.append(model_param.data)
                main_data.append(main_param.data)
        return model_data, main_data
    def _copy_model_grads_to_main_grads(self):
        """
        Copy model grads to main grads.

        Since this step follows a reduce-scatter through the DDP's grad
        buffer, this method is responsible for copying the updated grads
        from the grad buffer to the main shard's grad field.
        """

        # Utility method for copying group grads.
        def copy_group_grads(model_groups, shard_main_groups):
            for model_group, shard_main_group in zip(model_groups, shard_main_groups):
                for model_param, shard_main_param in zip(model_group, shard_main_group):

                    param_range_map = self._get_model_param_range_map(model_param)
                    param_range = param_range_map["param"]
                    assert param_range.size == shard_main_param.nelement()

                    model_grad = model_param.main_grad
                    shard_model_grad = model_grad.view(-1)[param_range.start : param_range.end]
                    shard_main_param.grad = shard_model_grad.float()

        # Copy model groups to shard groups.
        copy_group_grads(self.model_float16_groups, self.shard_fp32_from_float16_groups)
        copy_group_grads(self.model_fp32_groups, self.shard_fp32_groups)
    def _copy_main_params_to_model_params(self):
        """
        Copy main params to model params.

        Since this step is followed by an all-gather through the DDP's grad
        buffer, this method is responsible for copying the updated params
        from the main shards into the correct position in the grad buffer.
        """

        # Utility method for copying group params.
        def copy_group_params(shard_main_groups, model_groups):
            for shard_main_group, model_group in zip(shard_main_groups, model_groups):
                for shard_main_param, model_param in zip(shard_main_group, model_group):

                    param_range_map = self._get_model_param_range_map(model_param)
                    world_range = param_range_map["gbuf_world_in_bucket"]

                    assert world_range.size == shard_main_param.nelement()

                    gbuf_index, _, bucket_id = self.model_param_gbuf_map[model_param]
                    model_param_buffer = self.buffers[gbuf_index].buckets[bucket_id].param_data

                    shard_model_param = model_param_buffer.view(-1)[
                        world_range.start : world_range.end
                    ]

                    if is_float8tensor(model_param):
                        # 1. When "--fp8-param-gather" is disabled, the main param is first cast to
                        #    BF16/FP16, and then cast to FP8, so the amax_history is calculated
                        #    using BF16/FP16 param.
                        # 2. When "--fp8-param-gather" is enabled, we can cast the FP32 main param
                        #    to FP8 directly, which results in slightly different results with
                        #    higher speed. In theory, this does not affect convergence.
                        # TODO: The following code maintains the logic of point 1 above. It can
                        # be deleted if it is not necessary.
                        shard_main_param = shard_main_param.to(model_param.dtype)

                        cast_to_fp8(
                            shard_main_param.view(1, -1),
                            model_param._fp8_meta['scaling_fwd'],
                            model_param._fp8_meta_index,
                            model_param._fp8_dtype,
                            out=shard_model_param.view(1, -1),
                        )
                    else:
                        shard_model_param.data.copy_(shard_main_param)

        # Copy shard groups to model groups.
        copy_group_params(self.shard_fp32_from_float16_groups, self.model_float16_groups)
        copy_group_params(self.shard_fp32_groups, self.model_fp32_groups)
    def _copy_model_params_to_main_params(self):
        """
        Copy model params to main params.

        During finetuning, this method is used to reload the main params from
        the model params. This copy does not make use of the grad buffer as
        an intermediary.
        """

        # Utility method for copying group params.
        def copy_group_params(model_groups, shard_main_groups):
            for model_group, shard_main_group in zip(model_groups, shard_main_groups):
                for model_param, shard_main_param in zip(model_group, shard_main_group):

                    param_range_map = self._get_model_param_range_map(model_param)
                    param_range = param_range_map["param"]
                    assert param_range.size == shard_main_param.nelement()

                    shard_model_param = model_param.view(-1)[param_range.start : param_range.end]
                    shard_main_param.data.copy_(shard_model_param)

        # Copy model groups to shard groups.
        copy_group_params(self.model_float16_groups, self.shard_fp32_from_float16_groups)
        copy_group_params(self.model_fp32_groups, self.shard_fp32_groups)
    def _update_fp8_scale_inv_and_amax(self):
        """
        If FP8 parameters are detected, update their `_scale_inv` and do a reduce-max over their
        `amax_history`.
        """
        amaxes = []
        scales = []
        scale_invs = []
        # Iterate over all parameters inside this optimizer to find FP8 parameters.
        for buffer in self.buffers:
            for bucket in buffer.buckets:
                for param in bucket.params_list:
                    if is_float8tensor(param):
                        fp8_meta = param._fp8_meta['scaling_fwd']
                        fp8_meta_index = param._fp8_meta_index
                        amaxes.append(fp8_meta.amax_history[0][fp8_meta_index].view(1))
                        scales.append(fp8_meta.scale[fp8_meta_index].view(1))
                        scale_invs.append(param._scale_inv.view(1))
                        # Reset transpose cache
                        param._reset_caches()

        # If there are no FP8 parameters, skip all operations.
        if len(scales) > 0:
            dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device='cuda')

            # Update scaling factors.
            packed_scales = torch.empty(len(scales), dtype=torch.float32, device=scales[0].device)
            packed_scale_views = [packed_scales[i].view(1) for i in range(len(scales))]
            _multi_tensor_copy_this_to_that(scales, packed_scale_views, dummy_overflow_buf)
            torch.reciprocal(packed_scales, out=packed_scales)
            _multi_tensor_copy_this_to_that(packed_scale_views, scale_invs, dummy_overflow_buf)

            # Reduce amaxes.
            # Note: Assume each param has a separate amax.
            packed_amaxes = torch.empty(len(amaxes), dtype=torch.float32, device=amaxes[0].device)
            packed_amax_views = [packed_amaxes[i].view(1) for i in range(len(amaxes))]
            _multi_tensor_copy_this_to_that(amaxes, packed_amax_views, dummy_overflow_buf)
            torch.distributed.all_reduce(
                packed_amaxes, op=torch.distributed.ReduceOp.MAX, group=self.data_parallel_group
            )
            _multi_tensor_copy_this_to_that(packed_amax_views, amaxes, dummy_overflow_buf)
@
torch
.
no_grad
()
def
step_with_ready_grads
(
self
)
->
bool
:
"""Step the optimizer with ready gradients, return successful.
Under the hood, either launch synchronous param all-gathers or get ready to launch
asynchorous all-gathers that get overlapped with the next forward pass.
"""
update_successful
=
super
().
step_with_ready_grads
()
# If there is no FP8 parameters, this will do nothing.
self
.
_update_fp8_scale_inv_and_amax
()
timers
=
self
.
config
.
timers
if
timers
is
not
None
:
timers
(
'params-all-gather'
,
log_level
=
1
).
start
(
barrier
=
self
.
config
.
barrier_with_L1_time
)
# If not overlapping all-gather for parameters, launch synchronous all-gather
# communication calls here. If overlapping all-gather for parameters, the following
# the first all-gather is launched asynchronously in the next optimizer.zero_grad()
# call and subsequent all-gathers are launched in the forward pre-hook.
if
not
self
.
ddp_config
.
overlap_param_gather
:
for
model_chunk
in
self
.
model_chunks
:
model_chunk
.
start_param_sync
()
if
timers
is
not
None
:
timers
(
'params-all-gather'
).
stop
()
return
update_successful
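The FP8 path above packs many one-element scale and amax tensors into a single contiguous tensor so that the reciprocal and the all-reduce each run once instead of once per parameter. A small sketch of the pack / reciprocal / unpack pattern, without the distributed call or the fused copy kernel:

# Illustrative sketch: pack per-parameter 1-element scales into one tensor,
# take a single reciprocal, then scatter the results back as scale_invs.
import torch

scales = [torch.tensor([2.0]), torch.tensor([8.0]), torch.tensor([0.5])]
scale_invs = [torch.empty(1) for _ in scales]

packed = torch.empty(len(scales), dtype=torch.float32)
for i, s in enumerate(scales):           # stand-in for _multi_tensor_copy_this_to_that
    packed[i].copy_(s[0])
torch.reciprocal(packed, out=packed)     # one kernel instead of one per scale
for i, inv in enumerate(scale_invs):
    inv.copy_(packed[i].view(1))

print([float(x) for x in scale_invs])    # [0.5, 0.125, 2.0]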
megatron/optimizer/grad_scaler.py → megatron/core/optimizer/grad_scaler.py  100644 → 100755
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

"""Megatron grad scaler."""

-from abc import ABC
-from abc import abstractmethod
from abc import ABC, abstractmethod
from typing import Dict

import torch


class MegatronGradScaler(ABC):
-    def __init__(self, initial_scale):
    def __init__(self, initial_scale: float):
        """Initialize scale value with the input initial scale."""
        assert initial_scale > 0.0
-        self._scale = torch.cuda.FloatTensor([initial_scale])
        self._scale = torch.tensor([initial_scale], dtype=torch.float, device='cuda')

    @property
    def scale(self):
...
@@ -24,7 +23,7 @@ class MegatronGradScaler(ABC):
        return self._scale.double().reciprocal().float()

    @abstractmethod
-    def update(self, found_inf):
    def update(self, found_inf: bool):
        pass

    @abstractmethod
...
@@ -32,14 +31,16 @@ class MegatronGradScaler(ABC):
        pass

    @abstractmethod
-    def load_state_dict(self, state_dict):
    def load_state_dict(self, state_dict: Dict):
        pass


class ConstantGradScaler(MegatronGradScaler):
    """
    Constant grad scaler (loss scale is never adjusted regardless of NaNs seen in gradients).
    """

-    def update(self, found_inf):
    def update(self, found_inf: bool):
        pass

    def state_dict(self):
...
@@ -49,26 +50,48 @@ class ConstantGradScaler(MegatronGradScaler):
        pass


class DynamicGradScaler(MegatronGradScaler):
-    def __init__(self, initial_scale, min_scale, growth_factor,
-                 backoff_factor, growth_interval, hysteresis):
-        """"Grad scaler with dynamic scale that gets adjusted
-        during training."""
    """
    Grad scaler with dynamic scale that gets adjusted during training.

    Reduces loss scale by `backoff_factor` if `hysteresis` number of NaNs are seen in a row. Increases
    loss scale by `growth_factor` if NaNs are not seen for `growth_interval` iterations.
    """

    def __init__(
        self,
        initial_scale: float,
        min_scale: float,
        growth_factor: float,
        backoff_factor: float,
        growth_interval: int,
        hysteresis: int,
    ):
        """
        Grad scaler with dynamic scale that gets adjusted during training.

        Args:
            initial_scale (float): Initial loss scale value.
            min_scale (float): Minimum loss scale value.
            growth_factor (float): Factor to grow loss scale by if NaNs are not seen in `growth_interval`
                training iterations. Must be greater than 1.
            backoff_factor (float): Factor to decrease loss scale by if NaNs are seen in `hysteresis`
                consecutive training iterations. Must be between 0 and 1.
            growth_interval (int): Number of training iterations of no NaNs before loss scale is increased.
            hysteresis (int): Number of training iterations of consecutive NaNs before loss scale is decreased.
        """
        super(DynamicGradScaler, self).__init__(initial_scale)

        # Lower bound on the scale.
        assert min_scale > 0.0
        assert min_scale <= initial_scale
-        self.min_scale = torch.cuda.FloatTensor([min_scale])
        self.min_scale = torch.tensor([min_scale], dtype=torch.float, device='cuda')
        # Growth and backoff factors for the scale.
        assert growth_factor > 1.0
-        self.growth_factor = torch.cuda.FloatTensor([growth_factor])
        self.growth_factor = torch.tensor([growth_factor], dtype=torch.float, device='cuda')
        assert backoff_factor < 1.0
        assert backoff_factor > 0.0
-        self.backoff_factor = torch.cuda.FloatTensor([backoff_factor])
        self.backoff_factor = torch.tensor([backoff_factor], dtype=torch.float, device='cuda')
        # Interval over which if we don't see any inf/nan,
        # we will scale the grad scale by the growth factor.
        assert growth_interval > 0
...
@@ -82,8 +105,10 @@ class DynamicGradScaler(MegatronGradScaler):
        self._growth_tracker = 0
        self._hysteresis_tracker = self.hysteresis

-    def update(self, found_inf):
    def update(self, found_inf: bool):
        """
        Updates internal state in grad scaler based on whether NaNs are seen in grads or not.
        """
        # If we have an inf/nan, growth tracker is set to 0
        # and hysterisis tracker is reduced by 1.
...
@@ -92,8 +117,7 @@ class DynamicGradScaler(MegatronGradScaler):
            self._hysteresis_tracker -= 1
            # Now if we are out of hysteresis count, scale down the loss.
            if self._hysteresis_tracker <= 0:
-                self._scale = torch.max(self._scale * self.backoff_factor,
-                                        self.min_scale)
                self._scale = torch.max(self._scale * self.backoff_factor, self.min_scale)
        else:
            # If there is no nan/inf, increment the growth tracker.
            self._growth_tracker += 1
...
@@ -105,7 +129,6 @@ class DynamicGradScaler(MegatronGradScaler):
            # and scale up the loss scale.
            self._scale = self._scale * self.growth_factor

    def state_dict(self):
        state_dict = {}
        state_dict['scale'] = self._scale
...
@@ -113,8 +136,7 @@ class DynamicGradScaler(MegatronGradScaler):
        state_dict['hysteresis_tracker'] = self._hysteresis_tracker
        return state_dict

-    def load_state_dict(self, state_dict):
    def load_state_dict(self, state_dict: Dict):
        self._scale = state_dict['scale'].cuda(torch.cuda.current_device())
        self._growth_tracker = state_dict['growth_tracker']
        self._hysteresis_tracker = state_dict['hysteresis_tracker']
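The interaction between `hysteresis` and `growth_interval` in the scaler above can be illustrated with a plain-Python sketch of the same bookkeeping (toy values, not the class itself, and no CUDA tensors):

# Standalone sketch of the DynamicGradScaler update rule shown above.
def update(state, found_inf, *, min_scale=1.0, growth_factor=2.0,
           backoff_factor=0.5, growth_interval=3, hysteresis=2):
    if found_inf:
        state['growth_tracker'] = 0
        state['hysteresis_tracker'] -= 1
        if state['hysteresis_tracker'] <= 0:
            state['scale'] = max(state['scale'] * backoff_factor, min_scale)
    else:
        state['growth_tracker'] += 1
        if state['growth_tracker'] == growth_interval:
            state['growth_tracker'] = 0
            state['hysteresis_tracker'] = hysteresis
            state['scale'] *= growth_factor
    return state

s = {'scale': 2.0 ** 16, 'growth_tracker': 0, 'hysteresis_tracker': 2}
for found_inf in [True, True, False, False, False]:
    s = update(s, found_inf)
print(s['scale'])  # backed off once after two consecutive infs, then grown once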
megatron/optimizer/optimizer.py → megatron/core/optimizer/optimizer.py  100644 → 100755
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

"""Megatron optimizer."""

-from abc import ABC
-from abc import abstractmethod
-from apex.multi_tensor_apply import multi_tensor_applier
-import amp_C
import copy
import math
from abc import ABC, abstractmethod
from itertools import chain
from logging import getLogger
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
-from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
-from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

-from megatron import get_timers
-from megatron import print_rank_0
-from megatron.core import mpu, tensor_parallel
-from megatron.model import DistributedDataParallel as LocalDDP
-from megatron.model import Float16Module
-from megatron.model.module import param_is_not_shared
-from megatron.utils import unwrap_model

try:
    from transformer_engine.pytorch.optimizers import multi_tensor_applier, multi_tensor_scale

    multi_tensor_scale_impl = multi_tensor_scale
except ImportError:
    try:
        import amp_C
        from apex.multi_tensor_apply import multi_tensor_applier

        multi_tensor_scale_impl = amp_C.multi_tensor_scale
    except ImportError:
        import warnings

        warnings.warn(
            'Transformer Engine and Apex are not installed. '
            'Falling back to local implementations of '
            'multi_tensor_applier and multi_tensor_scale'
        )
        from megatron.core.utils import local_multi_tensor_applier, local_multi_tensor_scale

        multi_tensor_applier = local_multi_tensor_applier
        multi_tensor_scale_impl = local_multi_tensor_scale

-from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32
from .. import parallel_state, tensor_parallel
from ..config_logger import has_config_logger_enabled, log_config_to_disk
from ..dist_checkpointing.mapping import ShardedStateDict
from ..dist_checkpointing.optimizer import (
    get_param_id_to_sharded_param_map,
    make_sharded_optimizer_tensor,
    optim_state_to_sharding_state,
)
from ..dist_checkpointing.utils import add_prefix_for_sharding
from ..transformer.module import param_is_not_shared
from .clip_grads import clip_grad_by_total_norm_fp32, count_zeros_fp32, get_grad_norm_fp32
from .grad_scaler import MegatronGradScaler
from .optimizer_config import OptimizerConfig

logger = getLogger(__name__)
-def _zero_grad_group_helper(group, set_to_none):
-    """Zero out the gradient for a group of parameters.
-    Note: copied from torch.optim.optimizer."""
def _zero_grad_group_helper(group: List[torch.nn.Parameter], set_to_none: bool):
    """
    Zero out the gradient for a group of parameters.
    Note: copied from torch.optim.optimizer.
    """
    for param in group:
        if param.grad is not None:
            if set_to_none:
...
@@ -36,65 +69,65 @@ def _zero_grad_group_helper(group, set_to_none):
                param.grad.zero_()


-def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
-    """Use multi-tensor-applier to copy values from one list to another.
-    We don't have a blfoat16 implementation so for now if the overflow_buf
def _multi_tensor_copy_this_to_that(
    this: List[torch.Tensor], that: List[torch.Tensor], overflow_buf: Optional[torch.Tensor] = None
):
    """
    Use multi-tensor-applier to copy values from one list to another.
    We don't have a bfloat16 implementation so for now if the overflow_buf
    is not provided, we default back to simple loop copy to be compatible
-    with bfloat16."""
-    if overflow_buf:
    with bfloat16.
    """
    if overflow_buf is not None:
        overflow_buf.fill_(0)
        # Scaling with factor `1.0` is equivalent to copy.
-        multi_tensor_applier(amp_C.multi_tensor_scale, overflow_buf, [this, that], 1.0)
        multi_tensor_applier(multi_tensor_scale_impl, overflow_buf, [this, that], 1.0)
    else:
        for this_, that_ in zip(this, that):
            that_.copy_(this_)
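When no overflow buffer is supplied, the helper above simply copies list-to-list, element by element. A minimal sketch of that fallback path (toy tensors, not the fused kernel):

# Minimal sketch of the loop fallback in _multi_tensor_copy_this_to_that:
# with no overflow buffer, values are copied pairwise between the two lists.
import torch

this = [torch.tensor([1.0, 2.0]), torch.tensor([3.0])]
that = [torch.zeros(2), torch.zeros(1)]

for this_, that_ in zip(this, that):
    that_.copy_(this_)

assert all(torch.equal(a, b) for a, b in zip(this, that))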
class MegatronOptimizer(ABC):
    """
    Base class for all Megatron optimizers.

    Args:
        optimizer (torch.optim.Optimizer): base optimizer such as Adam or SGD.
        config (OptimizerConfig): configuration object for optimizer.
        init_state_fn (Callable, optional): function to initialize state in the optimizer.
    """

-    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
-                 params_have_main_grad, use_contiguous_buffers_in_local_ddp, models):
-        """Input optimizer is the base optimizer for example Adam."""
    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        config: OptimizerConfig,
        init_state_fn: Callable = lambda x: None,
    ):
        """Input optimizer is the base optimizer (e.g., Adam)."""
        self.optimizer = optimizer
        assert self.optimizer, 'no optimizer is provided.'
-        # Set gradient clipping and logging params.
-        self.clip_grad = clip_grad
-        self.log_num_zeros_in_grad = log_num_zeros_in_grad
-        self.params_have_main_grad = params_have_main_grad
-        self.use_contiguous_buffers_in_local_ddp = use_contiguous_buffers_in_local_ddp
-        # 'models' are retained for access to the contiguous grad buffers.
-        # (see distributed optimizer)
-        self.models = models
        self.config = config
        self.init_state_fn = init_state_fn

-        if self.use_contiguous_buffers_in_local_ddp:
-            assert self.params_have_main_grad, \
-                "use of contiguous buffer requires that params have main grad"

-    def get_parameters(self):
    def get_parameters(self) -> List[torch.nn.Parameter]:
        """
        Get list of parameters wrapped in optimizer.
        """
        params = []
        for param_group in self.optimizer.param_groups:
            for param in param_group['params']:
                params.append(param)
        return params

-    def get_main_grads_for_grad_norm(self):
-        # Filter parameters based on:
-        #   - grad should not be none
-        #   - parameter should not be shared
-        #   - should not be a replica due to tensor model parallelism
    def get_main_grads_for_grad_norm(self) -> List[torch.Tensor]:
        """
        Get main_grads that should be taken into account to compute the grad norm.
        Filter parameters based on:
          - grad should not be None.
          - parameter should not be shared (i.e., grads shouldn't be double counted while
            computing norms).
          - should not be a replica due to tensor model parallelism.
        """
        params = self.get_parameters()
        grads_for_norm = []
        for param in params:
...
@@ -107,41 +140,85 @@ class MegatronOptimizer(ABC):
        return grads_for_norm

-    def get_model_parallel_group(self):
-        """Default returned here, but the distributed optimizer overrides this."""
-        return mpu.get_model_parallel_group()
    def get_grad_stats_parallel_group(self) -> torch.distributed.ProcessGroup:
        """Process group for reducing gradient statistics (num_zeros & norm).

        The two most common cases are:
        - Non-distributed optimizer (default): Return the model-parallel group.
        - Distributed optimizer (overridden in distrib_optimizer.py): Return the entire world.
        """
        if hasattr(self, 'model_parallel_group'):
            warnings.warn(
                "WARNING: `optimizer.model_parallel_group` deprecated and renamed to "
                "`optimizer.grad_stats_parallel_group`. The previous name will be "
                "removed in a future release."
            )
            self.grad_stats_parallel_group = self.model_parallel_group
            delattr(self, "model_parallel_group")
            return self.grad_stats_parallel_group
        if hasattr(self, 'grad_stats_parallel_group'):
            return self.grad_stats_parallel_group
        return parallel_state.get_model_parallel_group()

    @abstractmethod
    def prepare_grads(self) -> bool:
        """Pre-processing gradients before the optimizer step, returns whether inf/nan is found."""
        return False

-    def clip_grad_norm(self, clip_grad):
-        params = self.get_parameters()
-        grads_for_norm = self.get_main_grads_for_grad_norm()
-        return clip_grad_norm_fp32(
-            params, grads_for_norm, clip_grad,
-            model_parallel_group=self.get_model_parallel_group())
    @abstractmethod
    def step_with_ready_grads(self) -> bool:
        """Step the optimizer with ready gradients, return successful."""
        return True

    @torch.no_grad()
    def get_grad_norm(self):
        """Compute and return grad norm."""
        grads_for_norm = self.get_main_grads_for_grad_norm()
        total_norm = get_grad_norm_fp32(
            grads_for_norm, grad_stats_parallel_group=self.get_grad_stats_parallel_group()
        )
        return total_norm

-    def count_zeros(self):
-        params = self.get_parameters()
-        return count_zeros_fp32(params,
-                                model_parallel_group=self.get_model_parallel_group())
    def clip_grad_norm(self, clip_grad: float) -> float:
        """Compute and return grad norm, also clip grads."""
        params = self.get_parameters()
        grads_for_norm = self.get_main_grads_for_grad_norm()
        grad_norm = get_grad_norm_fp32(
            grads_for_norm, grad_stats_parallel_group=self.get_grad_stats_parallel_group()
        )
        clip_grad_by_total_norm_fp32(params, clip_grad, grad_norm)
        return grad_norm

    def count_zeros(self) -> float:
        """Count number of zeros in model's gradients."""
        params = self.get_parameters()
        return count_zeros_fp32(
            params, grad_stats_parallel_group=self.get_grad_stats_parallel_group()
        )

    @abstractmethod
-    def zero_grad(self, set_to_none=True):
    def zero_grad(self, set_to_none: bool = True):
        """Zero gradients and prepare for next forward pass."""
        pass

    @abstractmethod
-    def get_loss_scale(self):
-        """The output should be a cuda tensor of size 1."""
    def get_loss_scale(self) -> torch.Tensor:
        """
        Get current loss scale factor.
        NOTE: The output should be a CUDA tensor of size 1.
        """
        pass

-    def scale_loss(self, loss):
    def scale_loss(self, loss: torch.Tensor) -> torch.Tensor:
        """Simple scaling."""
        return self.get_loss_scale() * loss

    def start_param_sync(self, model_index: int, *unused):
        """
        Start parameter synchronization for all optimizers.
        This is a no-op for all non-distributed optimizers.
        """
        pass

    @abstractmethod
    def reload_model_params(self):
...
@@ -152,17 +229,16 @@ class MegatronOptimizer(ABC):
        with main parameters, the main parameters need to also be updated."""
        pass

    @abstractmethod
    def state_dict(self):
        """Return state_dict."""
        pass

    @abstractmethod
    def load_state_dict(self, state_dict):
        """Load pass-in `state_dict`."""
        pass

    # Promote state so it can be retrieved or set via
    # "optimizer_instance.state"
    def _get_state(self):
...
@@ -173,7 +249,6 @@ class MegatronOptimizer(ABC):
    state = property(_get_state, _set_state)

    # Promote param_groups so it can be retrieved or set via
    # "optimizer_instance.param_groups"
    # (for example, to adjust the learning rate)
...
@@ -185,200 +260,104 @@ class MegatronOptimizer(ABC):
    param_groups = property(_get_param_groups, _set_param_groups)

    @abstractmethod
-    def step(self, args, timers):
-        pass

-    def gather_model_params(self, args, timers):
-        """
-        For the case of a non-distributed-optimizer, there is nothing to
-        do here.
-        """
    def step(self):
        """Step the optimizer."""
        pass

    @abstractmethod
    def sharded_state_dict(
        self, model_sharded_state_dict: ShardedStateDict, is_loading: bool = False
    ) -> ShardedStateDict:
        """Builds sharded state dict for the optimizer, based on model's sharded state dict.

        Args:
            model_sharded_state_dict (ShardedStateDict): sharded state dict of the model
            is_loading (bool, optional): flag indicating whether the state dict will be
                used to save or load the optimizer state. Defaults to False.

        Returns: optimizer sharded state dict
        """

-    def allreduce_word_embedding_grads(self, args):
-        """
-        All-reduce word embedding grads.
-
-        Reduce grads across first and last stages to ensure that word_embeddings
-        parameters stay in sync. This should only run for models that support
-        pipelined model parallelism (BERT and GPT-2).
-        """
-        if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
-                mpu.get_pipeline_model_parallel_world_size() > 1:
-            if mpu.is_pipeline_first_stage(ignore_virtual=True):
-                unwrapped_model = self.models[0]
-            elif mpu.is_pipeline_last_stage(ignore_virtual=True):
-                unwrapped_model = self.models[-1]
-            else:  # We do not support the interleaved schedule for T5 yet.
-                unwrapped_model = self.models[0]
-            unwrapped_model = unwrap_model(unwrapped_model, (torchDDP, LocalDDP, Float16Module))
-            if unwrapped_model.share_embeddings_and_output_weights:
-                weight = unwrapped_model.shared_embedding_or_output_weight()
-                if args.DDP_impl == 'local':
-                    grad = weight.main_grad
-                else:
-                    grad = weight.grad
-                torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())

-    def allreduce_position_embedding_grads(self, args):
-        """
-        All-reduce position_embeddings grad across first (encoder) and
-        split (decoder) stages to ensure that position embeddings parameters
-        stay in sync. This should only run for T5 models with pipeline
-        parallelism.
-        """
-        if mpu.is_rank_in_position_embedding_group() and \
-                mpu.get_pipeline_model_parallel_world_size() > 1 and \
-                args.pipeline_model_parallel_split_rank is not None:
-            unwrapped_model = self.models[0]
-            unwrapped_model = unwrap_model(unwrapped_model, (torchDDP, LocalDDP, Float16Module))
-            assert args.DDP_impl == 'local', \
-                'T5 model is only supported with local DDP mode'
-            grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
-            torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())

-    def allreduce_embedding_grads(self, args):
-        """All-reduce both word and position embeddings."""
-        self.allreduce_word_embedding_grads(args)
-        self.allreduce_position_embedding_grads(args)

-    def allreduce_layernorm_grads(self, args):
-        """All-reduce layernorm grads (for sequence parallelism)."""
-        # All-reduce layernorm parameters across model parallel nodes
-        # when sequence parallelism is used
-        if mpu.get_tensor_model_parallel_world_size() > 1 and \
-                args.sequence_parallel:
-            grads = []
-            for model_module in self.models:
-                unwrapped_model = unwrap_model(model_module, (torchDDP, LocalDDP, Float16Module))
-                for param in unwrapped_model.parameters():
-                    if getattr(param, 'sequence_parallel', False):
-                        grad = param.main_grad if args.DDP_impl == 'local' else param.grad
-                        grads.append(grad.data)
-            coalesced = _flatten_dense_tensors(grads)
-            torch.distributed.all_reduce(coalesced, group=mpu.get_tensor_model_parallel_group())
-            for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
-                buf.copy_(synced)

-    def reduce_model_grads(self, args, timers):
-        """All-reduce all grads, and all-reduce embeddings."""
-        # All-reduce layer-norm grads (for sequence parallelism).
-        timers('layernorm-grads-all-reduce', log_level=1).start(barrier=args.barrier_with_L1_time)
-        self.allreduce_layernorm_grads(args)
-        timers('layernorm-grads-all-reduce').stop()
-        # All-reduce if needed.
-        if args.DDP_impl == 'local':
-            timers('grads-all-reduce', log_level=1).start(barrier=args.barrier_with_L1_time)
-            for model in self.models:
-                model.allreduce_gradients()
-            timers('grads-all-reduce').stop()
-        # All-reduce embedding grads.
-        timers('embedding-grads-all-reduce', log_level=1).start(barrier=args.barrier_with_L1_time)
-        self.allreduce_embedding_grads(args)
-        timers('embedding-grads-all-reduce').stop()

    @staticmethod
    def _extract_common_per_param_step(state_dict) -> Union[int, torch.Tensor]:
        common_step = None
        for param_idx, param_state in state_dict['state'].items():
            param_step = param_state.get('step', None)
            if param_step is not None:
                if common_step is None:
                    common_step = param_step
                elif common_step != param_step:
                    raise ValueError(
                        "The optimizer step differs per parameter. Mcore only supports "
                        "optimizers whose step is shared across all parameters."
                    )
        return common_step

    @staticmethod
    def _restore_common_per_param_step(state_dict: Dict, step: Union[int, torch.Tensor]):
        for param_idx, param_state in state_dict['state'].items():
            param_state['step'] = copy.deepcopy(step)
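`clip_grad_norm` above computes one global L2 norm over the filtered grads and then rescales every gradient by `clip_grad / total_norm` when the norm exceeds the threshold. A minimal sketch of that scale-by-total-norm behaviour, without the fused multi-tensor kernels or process groups used by the real helpers:

# Minimal sketch of clip-by-total-norm (what clip_grad_by_total_norm_fp32 does,
# stripped of fused kernels and distributed reductions).
import torch

grads = [torch.tensor([3.0, 4.0]), torch.tensor([12.0])]          # toy gradients
total_norm = torch.sqrt(sum((g ** 2).sum() for g in grads))        # global L2 norm = 13
max_norm = 6.5

clip_coeff = max_norm / (total_norm + 1.0e-6)
if clip_coeff < 1.0:
    for g in grads:
        g.mul_(clip_coeff)                                         # scale every grad in place

print(total_norm.item(), [g.tolist() for g in grads])              # norm is halved to ~6.5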
class MixedPrecisionOptimizer(MegatronOptimizer):
    """Base class for both the float-16 and the distributed optimizer.

-    Arguments:
-        optimizer: base optimizer such as Adam or SGD
-        clip_grad: clip gradeints with this global L2 norm. Note
-            that clipping is ignored if clip_grad == 0
-        log_num_zeros_in_grad: return number of zeros in the gradients.
-        params_have_main_grad: flag indicating if parameters have
-            a `main_grad` field. If this is set, we are assuming
-            that the model parameters are store in the `main_grad`
-            field instead of the typical `grad` field. This happens
-            for the DDP cases where there is a continuous buffer
-            holding the gradients. For example for bfloat16, we want
-            to do gradient accumulation and all-reduces in float32
-            and as a result we store those gradients in the main_grad.
-            Note that main grad is not necessarily in float32.
-        use_contiguous_buffers_in_local_ddp: if true, the local DDP model
-            is using a contiguous buffer to hold the model grads.
-        fp16: if true, the model is running in fp16.
-        bf16: if true, the model is running in bfloat16.
-        params_dtype: used by distributed optimizer.
-        grad_scaler: used for scaling gradients. Note that this can be
-            None. This case happens when `bf16 = True` and we don't
    Args:
        optimizer (torch.optim.Optimizer): base optimizer such as Adam or SGD.
        config (OptimizerConfig): configuration object for optimizer.
        grad_scaler (MegatronGradScaler): used for scaling gradients. Note that
            this can be None. This case happens when `bf16 = True` and we don't
            use any loss scale. Note that for `bf16 = True`, we can have
-            a constnat gradient scaler. Also for `bf16 = False`, we
            a constant gradient scaler. Also for `bf16 = False`, we
            always require a grad scaler.
-        models: list of models (i.e., the virtual pipelining models). This
-            is used by the distributed optimizer for mapping parameters.
        init_state_fn (Callable, optional): function to initialize state in the optimizer.
    """

-    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
-                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
-                 fp16, bf16, params_dtype, grad_scaler, models):
-        super().__init__(
-            optimizer, clip_grad, log_num_zeros_in_grad,
-            params_have_main_grad, use_contiguous_buffers_in_local_ddp, models)
-        self.fp16 = fp16
-        self.bf16 = bf16
-        self.params_dtype = params_dtype
    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        config: OptimizerConfig,
        grad_scaler: Optional[MegatronGradScaler],
        init_state_fn: Callable,
    ):
        if has_config_logger_enabled(config):
            log_config_to_disk(config, locals(), prefix=type(self).__name__)

        super().__init__(optimizer, config, init_state_fn)
        self.grad_scaler = grad_scaler

        # None grad scaler is only supported for bf16.
        if self.grad_scaler is None:
-            assert not self.fp16, 'fp16 expects a grad scaler.'
            assert not self.config.fp16, 'fp16 expects a grad scaler.'

        # Tensor used to determine if a nan/if has happend.
        # Any non-zero value indicates inf/nan.
        # Note that we keep this for the cases that grad scaler is none.
        # We still record nan/inf if we have a bfloat16 with a grad scaler.
        if self.grad_scaler:
-            self.found_inf = torch.cuda.FloatTensor([0.0])
            self.found_inf = torch.tensor([0.0], dtype=torch.float, device='cuda')

        # Dummy tensor needed for apex multi-apply tensor.
        # For bfloat, we don't have multi-tensor apply and for now
        # we set it to none so the multi-tensor apply gets ignored.
-        if bf16:
        if self.config.bf16:
            self._dummy_overflow_buf = None
        else:
-            self._dummy_overflow_buf = torch.cuda.IntTensor([0])
            self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device='cuda')

        # In case grad scaler is not passed, define the unity scale.
        if self.grad_scaler is None:
-            self._scale_one = torch.cuda.FloatTensor([1.0])
            self._scale_one = torch.tensor([1.0], dtype=torch.float, device='cuda')

    def get_loss_scale(self):
        if self.grad_scaler is None:
            return self._scale_one
        return self.grad_scaler.scale

    def reload_model_params(self):
        self._copy_model_params_to_main_params()

    def _unscale_main_grads_and_check_for_nan(self):
        # Collect main grads.
...
@@ -389,119 +368,139 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
        # Unscale and set found inf/nan
        torch._amp_foreach_non_finite_check_and_unscale_(
-            main_grads, self.found_inf, self.grad_scaler.inv_scale)
            main_grads, self.found_inf, self.grad_scaler.inv_scale
        )

        # Update across all model parallel instances.
-        torch.distributed.all_reduce(self.found_inf,
-                                     op=torch.distributed.ReduceOp.MAX,
-                                     group=self.get_model_parallel_group())
        torch.distributed.all_reduce(
            self.found_inf,
            op=torch.distributed.ReduceOp.MAX,
            group=self.get_grad_stats_parallel_group(),
        )

        # Check for nan.
-        found_inf_flag = (self.found_inf.item() > 0)
        found_inf_flag = self.found_inf.item() > 0

        return found_inf_flag

-    @torch.no_grad()
-    def step(self, args, timers):
    def prepare_grads(self) -> bool:
        """Pre-processing gradients before the optimizer step, returns whether inf/nan is found."""
        timers = self.config.timers

        # Copy gradients from model params to main params.
-        timers('optimizer-copy-to-main-grad', log_level=1).start(
-            barrier=args.barrier_with_L1_time)
        if timers is not None:
            timers('optimizer-copy-to-main-grad', log_level=1).start(
                barrier=self.config.barrier_with_L1_time
            )
        self._copy_model_grads_to_main_grads()
-        timers('optimizer-copy-to-main-grad').stop()
        if timers is not None:
            timers('optimizer-copy-to-main-grad').stop()

        # Do unscale, check for inf, and update grad scaler only for
        # the case that grad scaler is provided.
        if self.grad_scaler:

            # Unscale and check for inf/nan.
-            timers('optimizer-unscale-and-check-inf', log_level=1).start(
-                barrier=args.barrier_with_L1_time)
            if timers is not None:
                timers('optimizer-unscale-and-check-inf', log_level=1).start(
                    barrier=self.config.barrier_with_L1_time
                )
            found_inf_flag = self._unscale_main_grads_and_check_for_nan()
-            timers('optimizer-unscale-and-check-inf').stop()
            if timers is not None:
                timers('optimizer-unscale-and-check-inf').stop()

            # We are done with scaling gradients
            # so we can update the loss scale.
            self.grad_scaler.update(found_inf_flag)

-            # If we found inf/nan, skip the update.
-            if found_inf_flag:
-                return False, None, None
            return found_inf_flag

        return False

    @torch.no_grad()
    def step_with_ready_grads(self) -> bool:
        """Step the optimizer with ready gradients, return successful."""
        timers = self.config.timers

        # Step the optimizer.
-        timers('optimizer-inner-step', log_level=1).start(
-            barrier=args.barrier_with_L1_time)
        if timers is not None:
            timers('optimizer-inner-step', log_level=1).start(
                barrier=self.config.barrier_with_L1_time
            )
        self.optimizer.step()
-        timers('optimizer-inner-step').stop()
        if timers is not None:
            timers('optimizer-inner-step').stop()

        # Update params from main params.
-        timers('optimizer-copy-main-to-model-params', log_level=1).start(
-            barrier=args.barrier_with_L1_time)
        if timers is not None:
            timers('optimizer-copy-main-to-model-params', log_level=1).start(
                barrier=self.config.barrier_with_L1_time
            )
        self._copy_main_params_to_model_params()
-        timers('optimizer-copy-main-to-model-params').stop()
        if timers is not None:
            timers('optimizer-copy-main-to-model-params').stop()

        return True

    @torch.no_grad()
    def step(self):
        timers = self.config.timers

        found_inf_flag = self.prepare_grads()
        if found_inf_flag:
            return False, None, None

        # Clip the main gradients.
-        timers('optimizer-clip-main-grad', log_level=1).start(
-            barrier=args.barrier_with_L1_time)
        if timers is not None:
            timers('optimizer-clip-main-grad', log_level=1).start(
                barrier=self.config.barrier_with_L1_time
            )
        grad_norm = None
-        if self.clip_grad > 0.0:
-            grad_norm = self.clip_grad_norm(self.clip_grad)
-        timers('optimizer-clip-main-grad').stop()
        if self.config.clip_grad > 0.0:
            grad_norm = self.clip_grad_norm(self.config.clip_grad)
        if timers is not None:
            timers('optimizer-clip-main-grad').stop()

        # Count the zeros in the grads.
-        timers('optimizer-count-zeros', log_level=1).start(
-            barrier=args.barrier_with_L1_time)
-        num_zeros_in_grad = self.count_zeros() if \
-            self.log_num_zeros_in_grad else None
-        timers('optimizer-count-zeros').stop()
        if timers is not None:
            timers('optimizer-count-zeros', log_level=1).start(
                barrier=self.config.barrier_with_L1_time
            )
        num_zeros_in_grad = self.count_zeros() if self.config.log_num_zeros_in_grad else None
        if timers is not None:
            timers('optimizer-count-zeros').stop()

        success = self.step_with_ready_grads()

        # Successful update.
-        return True, grad_norm, num_zeros_in_grad
        return success, grad_norm, num_zeros_in_grad
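The `step()` above returns a `(success, grad_norm, num_zeros_in_grad)` tuple and leaves the weights untouched when the loss scaler saw an inf/nan. A hypothetical training-loop fragment (names and scheduler call are illustrative, not part of this file) showing how such a tuple is typically consumed:

# Illustrative consumer of the (success, grad_norm, num_zeros) contract.
def train_step(optimizer, scheduler=None):
    update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
    if update_successful:
        # Only advance the LR schedule when the update was actually applied.
        if scheduler is not None:
            scheduler.step(increment=1)
        skipped_iter = 0
    else:
        # Loss scaler saw inf/nan: gradients were discarded, weights unchanged.
        skipped_iter = 1
    return skipped_iter, grad_norm, num_zeros_in_grad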
class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
    """Float16 optimizer for fp16 and bf16 data types.

-    Arguments:
-        optimizer: base optimizer such as Adam or SGD
-        clip_grad: clip gradeints with this global L2 norm. Note
-            that clipping is ignored if clip_grad == 0
-        log_num_zeros_in_grad: return number of zeros in the gradients.
-        params_have_main_grad: flag indicating if parameters have
-            a `main_grad` field. If this is set, we are assuming
-            that the model parameters are store in the `main_grad`
-            field instead of the typical `grad` field. This happens
-            for the DDP cases where there is a continuous buffer
-            holding the gradients. For example for bfloat16, we want
-            to do gradient accumulation and all-reduces in float32
-            and as a result we store those gradients in the main_grad.
-            Note that main grad is not necessarily in float32.
-        use_contiguous_buffers_in_local_ddp: if true, the local DDP model
-            is using a contiguous buffer to hold the model grads.
-        fp16: if true, the model is running in fp16.
-        bf16: if true, the model is running in bfloat16.
-        grad_scaler: used for scaling gradients. Note that this can be
-            None. This case happens when `bf16 = True` and we don't
    Args:
        optimizer (torch.optim.Optimizer): base optimizer such as Adam or SGD.
        config (OptimizerConfig): configuration object for optimizer.
        grad_scaler (MegatronGradScaler): used for scaling gradients. Note that
            this can be None. This case happens when `bf16 = True` and we don't
            use any loss scale. Note that for `bf16 = True`, we can have
-            a constnat gradient scaler. Also for `bf16 = False`, we
            a constant gradient scaler. Also for `bf16 = False`, we
            always require a grad scaler.
-        models: list of models (i.e., the virtual pipelining models). This
-            is used by the distributed optimizer for mapping parameters.
        init_state_fn (Callable, optional): function to initialize state in the optimizer.
    """

-    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
-                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
-                 fp16, bf16, params_dtype, grad_scaler, models):
-        super().__init__(
-            optimizer, clip_grad, log_num_zeros_in_grad,
-            params_have_main_grad, use_contiguous_buffers_in_local_ddp,
-            fp16, bf16, params_dtype, grad_scaler, models)
    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        config: OptimizerConfig,
        grad_scaler: MegatronGradScaler,
        init_state_fn: Callable,
    ):
        super().__init__(optimizer, config, grad_scaler, init_state_fn)

-        # ======================
-        # main parameter stuff
-        # ======================
        # Handle main parameters.

        # Three groups of parameters:
        #   float16_groups: original float16 parameters
...
@@ -521,14 +520,12 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
                if param.requires_grad:

                    # float16 params:
-                    if param.type() in ['torch.cuda.HalfTensor',
-                                        'torch.cuda.BFloat16Tensor']:
                    if param.type() in ['torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor']:
                        float16_params_this_group.append(param)
                        # Create a copy
                        main_param = param.detach().clone().float()
                        # Copy tensor model parallel attributes.
-                        tensor_parallel.copy_tensor_model_parallel_attributes(main_param,
-                                                                              param)
                        tensor_parallel.copy_tensor_model_parallel_attributes(main_param, param)
                        if hasattr(param, 'shared'):
                            main_param.shared = param.shared
                        # Replace the optimizer params with the new fp32 copy.
...
@@ -537,26 +534,25 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
                        fp32_from_float16_params_this_group.append(main_param)
                        # Reset existing state dict key to the new main param.
                        if param in self.optimizer.state:
-                            self.optimizer.state[main_param] \
-                                = self.optimizer.state.pop(param)
                            self.optimizer.state[main_param] = self.optimizer.state.pop(param)
                    # fp32 params.
                    elif param.type() == 'torch.cuda.FloatTensor':
                        fp32_params_this_group.append(param)
                        param_group['params'][i] = param

                    else:
-                        raise TypeError('Wrapped parameters must be one of '
-                                        'torch.cuda.FloatTensor, '
-                                        'torch.cuda.HalfTensor, or '
-                                        'torch.cuda.BFloat16Tensor. '
-                                        'Received {}'.format(param.type()))
                        raise TypeError(
                            'Wrapped parameters must be one of '
                            'torch.cuda.FloatTensor, '
                            'torch.cuda.HalfTensor, or '
                            'torch.cuda.BFloat16Tensor. '
                            'Received {}'.format(param.type())
                        )

            self.float16_groups.append(float16_params_this_group)
-            self.fp32_from_float16_groups.append(
-                fp32_from_float16_params_this_group)
            self.fp32_from_float16_groups.append(fp32_from_float16_params_this_group)
            self.fp32_from_fp32_groups.append(fp32_params_this_group)

    def zero_grad(self, set_to_none=True):
        """We only need to zero the model related parameters, i.e.,
        float16_groups & fp32_from_fp32_groups. We additionally zero
...
@@ -570,7 +566,6 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
        for group in self.fp32_from_fp32_groups:
            _zero_grad_group_helper(group, set_to_none)

    def _collect_main_grad_data_for_unscaling(self):
        main_grads = []
...
@@ -586,27 +581,23 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
            for main_param in main_group:
                if main_param.grad is not None:
                    main_grads.append(main_param.grad.data)

        return main_grads

    def _get_model_and_main_params_data_float16(self):
        model_data = []
        main_data = []
-        for model_group, main_group in zip(self.float16_groups,
-                                           self.fp32_from_float16_groups):
        for model_group, main_group in zip(self.float16_groups, self.fp32_from_float16_groups):
            for model_param, main_param in zip(model_group, main_group):
                model_data.append(model_param.data)
                main_data.append(main_param.data)
        return model_data, main_data

    def _copy_model_grads_to_main_grads(self):
        # This only needs to be done for the float16 group.
-        for model_group, main_group in zip(self.float16_groups,
-                                           self.fp32_from_float16_groups):
        for model_group, main_group in zip(self.float16_groups, self.fp32_from_float16_groups):
            for model_param, main_param in zip(model_group, main_group):
-                if self.params_have_main_grad and hasattr(model_param, 'main_grad'):
                if hasattr(model_param, 'main_grad'):
                    main_param.grad = model_param.main_grad.float()
                else:
                    if model_param.grad is not None:
...
@@ -616,36 +607,25 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
                # (If using contiguous buffers, main_grad's memory should
                # persist and therefore should not be deallocated.)
                model_param.grad = None
-                if self.params_have_main_grad and \
-                   not self.use_contiguous_buffers_in_local_ddp:
-                    model_param.main_grad = None

        # For fp32 grads, we need to reset the grads to main grad.
-        if self.params_have_main_grad:
-            for model_group in self.fp32_from_fp32_groups:
-                for model_param in model_group:
-                    model_param.grad = model_param.main_grad
-
-                    # Safe to de-reference model's main_grad after copying.
-                    # (If using contiguous buffers, main_grad's memory should
-                    # persist and therefore should not be deallocated.)
-                    if not self.use_contiguous_buffers_in_local_ddp:
-                        model_param.main_grad = None
        for model_group in self.fp32_from_fp32_groups:
            for model_param in model_group:
                model_param.grad = model_param.main_grad

    def _copy_main_params_to_model_params(self):
        # Only needed for the float16 params.
        model_data, main_data = self._get_model_and_main_params_data_float16()
-        _multi_tensor_copy_this_to_that(this=main_data, that=model_data,
-                                        overflow_buf=self._dummy_overflow_buf)
        _multi_tensor_copy_this_to_that(
            this=main_data, that=model_data, overflow_buf=self._dummy_overflow_buf
        )

    def _copy_model_params_to_main_params(self):
        # Only needed for the float16 params.
        model_data, main_data = self._get_model_and_main_params_data_float16()
-        _multi_tensor_copy_this_to_that(this=model_data, that=main_data,
-                                        overflow_buf=self._dummy_overflow_buf)
        _multi_tensor_copy_this_to_that(
            this=model_data, that=main_data, overflow_buf=self._dummy_overflow_buf
        )

    def state_dict(self):
        state_dict = {}
...
@@ -655,119 +635,435 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
        state_dict['fp32_from_fp16_params'] = self.fp32_from_float16_groups
        return state_dict

    def sharded_state_dict(self, model_sharded_state_dict: ShardedStateDict, is_loading: bool = False):
        if is_loading:
            self.init_state_fn(self.optimizer)

        state_dict = self.state_dict()

        id_to_sharded_param_map = get_param_id_to_sharded_param_map(
            model_sharded_state_dict, chain.from_iterable(g for g in self.float16_groups)
        )

        # Convert fp32_from_fp16_params
        assert len(state_dict['fp32_from_fp16_params']) == len(
            state_dict['optimizer']['param_groups']
        )
        state_dict['fp32_from_fp16_params'] = [
            [
                make_sharded_optimizer_tensor(
                    id_to_sharded_param_map[param_id],
                    fp32_param,
                    prefix=f'optimizer.state.fp32_param',
                )
                for param_id, fp32_param in zip(state_group['params'], fp32_group)
            ]
            for fp32_group, state_group in zip(
                state_dict['fp32_from_fp16_params'], state_dict['optimizer']['param_groups']
            )
        ]

        step = self._extract_common_per_param_step(state_dict['optimizer'])

        # Convert regular optimizer state
        # all optimizer parameters passed to optim_state_to_sharding_state are
        # expected to have the same shape as the model parameters,
        # so we save the step separately and ignore it here
        optim_state_to_sharding_state(
            state_dict['optimizer'], id_to_sharded_param_map, exclude_keys="step"
        )
        # save step as a shared step among all parameters. Separate per-parameter
        # steps are not supported
        state_dict['optimizer']['state']['common_step'] = step
        return state_dict

    def load_state_dict(self, state_dict):
        pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
        # Optimizer.
        optimizer_key = 'optimizer'
        if optimizer_key not in state_dict:
            optimizer_key = 'optimizer_state_dict'
-            print_rank_0('***WARNING*** loading optimizer from '
-                         'an old checkpoint ...')
            logger.info('***WARNING*** loading optimizer from an old checkpoint ...')
        if 'common_step' in state_dict[optimizer_key]['state']:
            common_step = state_dict[optimizer_key]['state'].pop('common_step')
            self._restore_common_per_param_step(state_dict[optimizer_key], common_step)
        self.optimizer.load_state_dict(state_dict[optimizer_key])

        # Grad scaler.
        if 'grad_scaler' not in state_dict:
-            if self.fp16:
-                print_rank_0('***WARNING*** found an old checkpoint, will not '
-                             'load grad scaler ...')
            if self.config.fp16:
                logger.info('***WARNING*** found an old checkpoint, will not load grad scaler ...')
        else:
            if self.grad_scaler:
                self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
            else:
-                print_rank_0('***WARNING*** fould the grad scaler in the '
-                             'checkpoint but it is None in the class. '
-                             'Skipping loading grad scaler ...')
                logger.info(
                    '***WARNING*** fould the grad scaler in the '
                    'checkpoint but it is None in the class. '
                    'Skipping loading grad scaler ...'
                )

        # Copy data for the main params.
        fp32_from_float16_params_key = 'fp32_from_fp16_params'
        if fp32_from_float16_params_key not in state_dict:
            fp32_from_float16_params_key = 'fp32_from_fp16'
        for current_group, saved_group in zip(
-                self.fp32_from_float16_groups,
-                state_dict[fp32_from_float16_params_key]):
            self.fp32_from_float16_groups, state_dict[fp32_from_float16_params_key]
        ):
            for current_param, saved_param in zip(current_group, saved_group):
                current_param.data.copy_(saved_param.data)
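The fp32 "main param" bookkeeping in the class above boils down to keeping a float32 mirror of every fp16/bf16 weight for the optimizer update and copying the result back afterwards. A minimal sketch of that round-trip (toy tensor, plain update, not the multi-tensor copy):

# Minimal sketch of fp32 main-param mirroring: update in fp32, write back in fp16.
import torch

param = torch.randn(4, dtype=torch.float16, requires_grad=True)   # model weight
main_param = param.detach().clone().float()                        # fp32 master copy

# ... the optimizer would update main_param in fp32 ...
main_param.add_(-0.01)

param.data.copy_(main_param)                                        # cast back into the fp16 weight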
class FP32Optimizer(MegatronOptimizer):
    """Float32 optimizer.

    Args:
        optimizer (torch.optim.Optimizer): base optimizer such as Adam or SGD.
        config (OptimizerConfig): configuration object for optimizer.
        init_state_fn (Callable, optional): function to initialize state in the optimizer.
    """

-    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
-                 params_have_main_grad, use_contiguous_buffers_in_local_ddp, models):
-        super(FP32Optimizer, self).__init__(
-            optimizer, clip_grad, log_num_zeros_in_grad,
-            params_have_main_grad, use_contiguous_buffers_in_local_ddp, models)
-        self._scale = torch.cuda.FloatTensor([1.0])
    def __init__(
        self, optimizer: torch.optim.Optimizer, config: OptimizerConfig, init_state_fn: Callable
    ):
        if has_config_logger_enabled(config):
            log_config_to_disk(config, locals(), prefix=type(self).__name__)

        super(FP32Optimizer, self).__init__(optimizer, config, init_state_fn)
        self._scale = torch.tensor([1.0], dtype=torch.float, device='cuda')

    def zero_grad(self, set_to_none=True):
        """Copied from torch.optim.optimizer"""
        for group in self.optimizer.param_groups:
            _zero_grad_group_helper(group['params'], set_to_none)

    def get_loss_scale(self):
        """FP32 optimizer does not do any scaling."""
        return self._scale

-    @torch.no_grad()
-    def step(self, args, timers):
-        """Clip gradients (if needed) and step the base optimizer.
-        Always return successful since there is no overflow."""
    def prepare_grads(self) -> bool:
        """Pre-processing gradients before the optimizer step, returns whether inf/nan is found."""
        timers = self.config.timers

        # Copy main_grads to grads.
-        timers('optimizer-copy-to-main-grad', log_level=1).start(
-            barrier=args.barrier_with_L1_time)
-        if self.params_have_main_grad:
-            for param_group in self.optimizer.param_groups:
-                for param in param_group['params']:
-                    param.grad = param.main_grad
-
-                    # Safe to de-reference model's main_grad after copying.
-                    # (If using contiguous buffers, main_grad's memory should
-                    # persist and therefore should not be deallocated.)
-                    if not self.use_contiguous_buffers_in_local_ddp:
-                        param.main_grad = None
-        timers('optimizer-copy-to-main-grad').stop()
        if timers is not None:
            timers('optimizer-copy-to-main-grad', log_level=1).start(
                barrier=self.config.barrier_with_L1_time
            )
        for param_group in self.optimizer.param_groups:
            for param in param_group['params']:
                if hasattr(param, 'main_grad'):
                    param.grad = param.main_grad
        if timers is not None:
            timers('optimizer-copy-to-main-grad').stop()

        return False

    @torch.no_grad()
    def step_with_ready_grads(self) -> bool:
        """Step the optimizer with ready gradients, return successful."""
        timers = self.config.timers

        # Update parameters.
        if timers is not None:
            timers('optimizer-inner-step', log_level=1).start(
                barrier=self.config.barrier_with_L1_time
            )
        self.optimizer.step()
        if timers is not None:
            timers('optimizer-inner-step').stop()

        return True

    @torch.no_grad()
    def step(self):
        """Clip gradients (if needed) and step the base optimizer.
        Always return successful since there is no overflow."""
        timers = self.config.timers

        found_inf_flag = self.prepare_grads()
        if found_inf_flag:
            return False, None, None

        # Clip gradients.
-        timers('optimizer-clip-main-grad', log_level=1).start(
-            barrier=args.barrier_with_L1_time)
-        grad_norm = None
-        if self.clip_grad > 0.0:
-            grad_norm = self.clip_grad_norm(self.clip_grad)
-        timers('optimizer-clip-main-grad').stop()
        if timers is not None:
            timers('optimizer-clip-main-grad', log_level=1).start(
                barrier=self.config.barrier_with_L1_time
            )
        grad_norm = None
        if self.config.clip_grad > 0.0:
            grad_norm = self.clip_grad_norm(self.config.clip_grad)
        if timers is not None:
            timers('optimizer-clip-main-grad').stop()

-        # count the zeros in the grads
-        timers('optimizer-count-zeros', log_level=1).start(
-            barrier=args.barrier_with_L1_time)
-        num_zeros_in_grad = self.count_zeros() if \
-            self.log_num_zeros_in_grad else None
-        timers('optimizer-count-zeros').stop()
        # Count the zeros in the grads.
        if timers is not None:
            timers('optimizer-count-zeros', log_level=1).start(
                barrier=self.config.barrier_with_L1_time
            )
        num_zeros_in_grad = self.count_zeros() if self.config.log_num_zeros_in_grad else None
        if timers is not None:
            timers('optimizer-count-zeros').stop()

-        # Update parameters.
-        timers('optimizer-inner-step', log_level=1).start(
-            barrier=args.barrier_with_L1_time)
-        self.optimizer.step()
-        timers('optimizer-inner-step').stop()
        success = self.step_with_ready_grads()

        # No overflow for FP32 optimizer.
-        return True, grad_norm, num_zeros_in_grad
        return success, grad_norm, num_zeros_in_grad

    def reload_model_params(self):
        pass

    def state_dict(self):
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict):
        pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size()
        if 'common_step' in state_dict['state']:
            common_step = state_dict['state'].pop('common_step')
            self._restore_common_per_param_step(state_dict, common_step)
        self.optimizer.load_state_dict(state_dict)

    def sharded_state_dict(self, model_sharded_state_dict: ShardedStateDict, is_loading: bool = False):
        if is_loading:
            self.init_state_fn(self.optimizer)

        state_dict = self.state_dict()
        id_to_sharded_param_map = get_param_id_to_sharded_param_map(
            model_sharded_state_dict, self.get_parameters()
        )
        step = self._extract_common_per_param_step(state_dict)

        # all optimizer parameters passed to optim_state_to_sharding_state are
        # expected to have the same shape as the model parameters,
        # so we save the step separately and ignore it here
        optim_state_to_sharding_state(state_dict, id_to_sharded_param_map, exclude_keys="step")
        # save step as a shared step among all parameters. Separate per-parameter
        # steps are not supported
        state_dict['state']['common_step'] = step
        return state_dict
class ProxyDict:
    """
    A dictionary-like object that proxies to a list of dictionaries.

    e.g., ProxyDict([{'a': 1}, {'b': 2}]) behaves like:
    {
        (0, 'a'): 1,
        (1, 'b'): 2,
    }
    We use tuples as keys to avoid ambiguity with the keys of the inner dicts.
    """

    def __init__(self, inner_dicts: List[dict]):
        self._inner_dicts = inner_dicts

    def __getitem__(self, key: Tuple[int, str]):
        idx, inner_key = key
        return self._inner_dicts[idx].get(inner_key)

    def __setitem__(self, key: Tuple[int, str], value: Any):
        idx, inner_key = key
        self._inner_dicts[idx][inner_key] = value

    def __len__(self) -> int:
        return sum([len(inner_dict) for inner_dict in self._inner_dicts])

    def __iter__(self):
        for idx, inner_dict in enumerate(self._inner_dicts):
            for inner_key in inner_dict:
                yield (idx, inner_key)

    def items(self):
        """Return generator over underlying items."""
        for idx, inner_dict in enumerate(self._inner_dicts):
            for inner_key, value in inner_dict.items():
                yield (idx, inner_key), value
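A short usage sketch for the `ProxyDict` defined above (the import path is assumed to be this module, `megatron.core.optimizer.optimizer`; shown only to illustrate the tuple keying):

# Usage sketch: tuple keys (optimizer_index, state_key) over a list of dicts.
from megatron.core.optimizer.optimizer import ProxyDict  # assumed import path

state = ProxyDict([{'a': 1}, {'b': 2}])
print(state[(0, 'a')])        # 1
print(state[(1, 'b')])        # 2
state[(1, 'c')] = 3           # writes into the second inner dict
print(len(state))             # 3
print(dict(state.items()))    # {(0, 'a'): 1, (1, 'b'): 2, (1, 'c'): 3}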
class ChainedOptimizer(MegatronOptimizer):
    """ChainedOptimizer is designed for a collection of optimizers.

    These optimizers are responsible for different parts of multiple models for
    a training task and will be executed one-by-one when the model is updated.

    Args:
        chained_optimizers: a list of optimizers.
    """

    def __init__(self, chained_optimizers: List[MegatronOptimizer]):
        self.model_chunks = []
        self.config = getattr(chained_optimizers[0], 'config', None)
        for optimizer in chained_optimizers:
            if hasattr(optimizer, 'model_chunks'):
                for model_chunk in optimizer.model_chunks:
                    if model_chunk not in self.model_chunks:
                        self.model_chunks.append(model_chunk)
            assert self.config == getattr(optimizer, 'config', None)
        self.chained_optimizers = chained_optimizers

    @property
    def param_groups(self) -> List[dict]:
        """Get param_groups aggregated over underlying optimizers."""
        param_groups = []
        for optimizer in self.chained_optimizers:
            param_groups += optimizer.param_groups
        return param_groups

    @property
    def state(self) -> ProxyDict:
        """
        Return optimizer state with tuple keys, where the first element is the
        index of the optimizer in the list of chained optimizers.
        """
        return ProxyDict([opt.state for opt in self.chained_optimizers])

    def zero_grad(self, set_to_none=True):
        for optimizer in self.chained_optimizers:
            optimizer.zero_grad(set_to_none)

    def get_loss_scale(self):
        return self.chained_optimizers[0].get_loss_scale()

    def reload_model_params(self):
        for optimizer in self.chained_optimizers:
            optimizer.reload_model_params()

    def state_dict(self):
        return [optimizer.state_dict() for optimizer in self.chained_optimizers]

    def sharded_state_dict(
        self, model_sharded_state_dict: ShardedStateDict, is_loading: bool = False, **kwargs
    ):
        sharded_state_dict = {}
        for optimizer_idx, optimizer in enumerate(self.chained_optimizers):
            optim_state_dict = optimizer.sharded_state_dict(
                model_sharded_state_dict, is_loading, **kwargs
            )
            add_prefix_for_sharding(optim_state_dict, f'chained_{optimizer_idx}.')
            sharded_state_dict[optimizer_idx] = optim_state_dict
        return sharded_state_dict

    def load_state_dict(self, state_dict):
        if len(self.chained_optimizers) != len(state_dict):
            raise RuntimeError(
                f'Expected {len(self.chained_optimizers)} entries'
                f' in state dict, but got {len(state_dict)}.'
            )
        if isinstance(state_dict, dict):
            state_dict = (v for k, v in sorted(state_dict.items()))
        for optimizer, state in zip(self.chained_optimizers, state_dict):
            optimizer.load_state_dict(state)

    @torch.no_grad()
    def prepare_grads(self) -> bool:
        """Pre-processing gradients before the optimizer step, returns whether inf/nan is found."""
        found_inf_flag = False
        for optimizer in self.chained_optimizers:
            found_inf_flag |= optimizer.prepare_grads()

        return found_inf_flag

    @torch.no_grad()
    def step_with_ready_grads(self) -> bool:
        """Step the optimizer with ready gradients, return successful."""
        success = True
        for optimizer_idx, optimizer in enumerate(self.chained_optimizers):
            success &= optimizer.step_with_ready_grads()
            if self.config.overlap_param_gather_with_optimizer_step and optimizer_idx == 0:
                assert success
                assert len(optimizer.model_chunks) == 1
                optimizer.model_chunks[0].start_param_sync(force_dispatch=True)

        return success

    @torch.no_grad()
    def step(self):
        """ChainedOptimizer will step all optimizers one by one."""
        found_inf_flag = self.prepare_grads()
        if found_inf_flag:
            return False, None, None

        # Get grad norm.
        grad_norms = []
        for optimizer in self.chained_optimizers:
            _grad_norm = optimizer.get_grad_norm()
            grad_norms += [_grad_norm if _grad_norm else 0.0]
        grad_norm = math.sqrt(sum([x**2 for x in grad_norms]))

        # Clip gradients.
        for optimizer in self.chained_optimizers:
            if optimizer.config.clip_grad > 0.0:
                clip_grad_by_total_norm_fp32(
                    optimizer.get_parameters(),
                    max_norm=optimizer.config.clip_grad,
                    total_norm=grad_norm,
                )

        # Count the zeros in the grads.
        num_zeros_in_grad = 0
        for optimizer in self.chained_optimizers:
            num_zeros_in_grad += (
                optimizer.count_zeros() if optimizer.config.log_num_zeros_in_grad else 0
            )

        update_successful = self.step_with_ready_grads()

        return update_successful, grad_norm, num_zeros_in_grad

    def save_parameter_state(self, filename: str):
        """Save the distributed parameter states of all optimizers to a file.

        Args:
            filename (str): path to save parameter state to.
        """
        save_states = False
        states = []
        for optimizer in self.chained_optimizers:
            if hasattr(optimizer, 'get_parameter_state_dp_zero'):
                state_dict = optimizer.get_parameter_state_dp_zero()

                # Save checkpoint economically, only when DP rank = 0, state dict
                # needs to be saved.
                if torch.distributed.get_rank(optimizer.data_parallel_group) == 0:
                    states.append(state_dict)
                    save_states = True
                else:
                    states.append(None)
            else:
                states.append(None)

        if save_states:
            torch.save(states, filename)

    def load_parameter_state(self, filename: str, *, update_legacy_format: bool = False):
        """Load the distributed parameter states of all optimizers from a file.

        Args:
            filename (str): path to load parameter state from.
        """
        states = None
        for idx, optimizer in enumerate(self.chained_optimizers):
            if not hasattr(optimizer, 'load_parameter_state_from_dp_zero'):
                continue

            # Lazy loading checkpoint, state dict is needed only when DP rank = 0.
            if torch.distributed.get_rank(optimizer.data_parallel_group) == 0 and states is None:
                states = torch.load(filename)

            state_dict = states[idx] if states else None
            optimizer.load_parameter_state_from_dp_zero(
                state_dict, update_legacy_format=update_legacy_format
            )

    def start_param_sync(self, model_index: int, *unused):
        """Start parameter synchronization for all optimizers."""
        for optimizer in self.chained_optimizers:
            optimizer.start_param_sync(model_index, *unused)
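`ChainedOptimizer.step` above combines the per-optimizer grad norms into a single global norm before clipping. Because each chained optimizer owns a disjoint set of parameters, the global L2 norm is simply the root of the sum of squared per-optimizer norms, as in this tiny sketch:

# Minimal sketch of the grad-norm aggregation used by ChainedOptimizer.step.
import math

grad_norms = [3.0, 4.0, 12.0]                        # per-optimizer L2 norms (toy values)
global_norm = math.sqrt(sum(x ** 2 for x in grad_norms))
print(global_norm)                                    # 13.0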
megatron/core/optimizer/optimizer_config.py  0 → 100755
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from dataclasses import dataclass
from typing import Callable, Optional

import torch


@dataclass
class OptimizerConfig:
    """Configuration for optimizer."""

    ##############
    # General
    ##############
    optimizer: str = 'adam'
    """Optimizer to use (one of Adam or SGD)."""

    lr: Optional[float] = None
    """Initial learning rate. Depending on decay style and initial warmup, the learning rate at each
       iteration would be different.
    """

    min_lr: Optional[float] = None
    """Minumum value for learning rate. The scheduler clip values below this threshold."""

    decoupled_lr: Optional[float] = None
    """Separate learning rate for the input and output layer."""

    decoupled_min_lr: Optional[float] = None
    """Minimum value for learning rate for the input and output layer. The scheduler clip values
       below this threshold.
    """

    weight_decay: float = 0.01
    """Weight decay coefficient for L2 regularization."""

    ##############
    # Precision
    ##############
    fp16: bool = False
    """If true, train with fp16 mixed precision training. Defaults to False."""

    bf16: bool = False
    """If true, train with bf16 mixed precision training. Defaults to False."""

    params_dtype: torch.dtype = torch.float32
    """dtype used when intializing the weights. Defaults to torch.float32."""

    ###############
    # Loss scaling
    ###############
    loss_scale: Optional[float] = None
    """Static loss scaling, positive power of 2 values can improve fp16 convergence. If None,
       dynamic loss scaling is used.
    """

    initial_loss_scale: float = 2**32
    """Initial loss-scale for dynamic loss scaling."""

    min_loss_scale: float = 1.0
    """Minimum loss scale for dynamic loss scaling."""

    loss_scale_window: float = 1000
    """Window over which to raise/lower dynamic scale."""

    hysteresis: int = 2
    """Hysteresis for dynamic loss scaling."""

    ##############
    # Optimizer
    ##############
    # Adam
    adam_beta1: float = 0.9
    """First coefficient for computing running averages of gradient and its square in Adam
       optimizer.
    """

    adam_beta2: float = 0.999
    """Second coefficient for computing running averages of gradient and its square in Adam
       optimizer.
    """

    adam_eps: float = 1e-08
    """Term added to the denominator to improve numerical stability in Adam optimizer."""

    # SGD.
    sgd_momentum: float = 0.9
    """Momentum factor for SGD optimizer."""

    #######################
    # Distributed optimizer
    #######################
    use_distributed_optimizer: bool = False
    """Distribute optimizer state over data-parallel replicas."""

    overlap_param_gather_with_optimizer_step: bool = False
    """If true, overlap param all-gather of first bucket with optimizer step."""

    ################
    # Miscellaneous
    ################
    clip_grad: float = 1.0
    """Gradient clipping based on global L2 norm."""

    log_num_zeros_in_grad: bool = False
    """If true, calculate and log the number of zeros in gradient."""

    barrier_with_L1_time: bool = False
    """If true, use barrier with level 1 time measurements."""

    timers: Callable = None
    """Function to get timers."""

    config_logger_dir: str = ""
    """When non-empty, dumps entry-point configs to config_logger_dir"""
megatron/optimizer_param_scheduler.py → megatron/core/optimizer_param_scheduler.py
100644 → 100755
View file @
3bec6514
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

 """Learning rate decay and weight decay incr functions."""

+import logging
 import math
+from typing import Optional

-from megatron import print_rank_0
+from megatron.core.optimizer import MegatronOptimizer
+from megatron.core.utils import log_single_rank

+logger = logging.getLogger(__name__)

-class OptimizerParamScheduler(object):
-    """Anneals learning rate and weight decay"""
-
-    def __init__(self, optimizer, init_lr, max_lr, min_lr,
-                 lr_warmup_steps, lr_decay_steps, lr_decay_style,
-                 start_wd, end_wd, wd_incr_steps, wd_incr_style,
-                 use_checkpoint_opt_param_scheduler=True,
-                 override_opt_param_scheduler=False):
+class OptimizerParamScheduler:
+    """Anneals learning rate and weight decay
+
+    Args:
+        optimizer (MegatronOptimizer): the optimizer to be used
+        init_lr (float): initial learning rate
+        max_lr (float): maximum learning rate
+        min_lr (float): minimum learning rate
+        lr_warmup_steps (int): number of warmup steps
+        lr_decay_steps (int): number of decay steps
+        lr_decay_style (str): decay style for learning rate
+        start_wd (float): initial weight decay
+        end_wd (float): final weight decay
+        wd_incr_steps (int): number of weight decay increment steps
+        wd_incr_style (str): weight decay increment style
+        use_checkpoint_opt_param_scheduler (bool, optional): whether to use the checkpoint values
+            for the optimizer param scheduler
+        override_opt_param_scheduler (bool, optional): whether to override the optimizer param
+            scheduler values with the class values
+        wsd_decay_steps (int, optional): number of weight decay decay steps
+        lr_wsd_decay_style (str, optional): decay style for learning rate during weight decay decay
+            steps
+    """
+
+    def __init__(
+        self,
+        optimizer: MegatronOptimizer,
+        init_lr: float,
+        max_lr: float,
+        min_lr: float,
+        lr_warmup_steps: int,
+        lr_decay_steps: int,
+        lr_decay_style: str,
+        start_wd: float,
+        end_wd: float,
+        wd_incr_steps: int,
+        wd_incr_style: str,
+        use_checkpoint_opt_param_scheduler: Optional[bool] = True,
+        override_opt_param_scheduler: Optional[bool] = False,
+        wsd_decay_steps: Optional[int] = None,
+        lr_wsd_decay_style: Optional[str] = None,
+    ) -> None:

         # Class values.
         self.optimizer = optimizer
 ...
 ...
@@ -28,10 +68,14 @@ class OptimizerParamScheduler(object):
         self.lr_warmup_steps = lr_warmup_steps
         self.num_steps = 0
         self.lr_decay_steps = lr_decay_steps
+        self.wsd_decay_steps = wsd_decay_steps
+        self.lr_wsd_decay_style = lr_wsd_decay_style
         assert self.lr_decay_steps > 0
         assert self.lr_warmup_steps < self.lr_decay_steps

         self.lr_decay_style = lr_decay_style
+        if self.lr_decay_style == "WSD":
+            assert self.wsd_decay_steps is not None

         self.start_wd = start_wd
         self.end_wd = end_wd
 ...
 ...
@@ -43,16 +87,16 @@ class OptimizerParamScheduler(object):
         self.override_opt_param_scheduler = override_opt_param_scheduler
         self.use_checkpoint_opt_param_scheduler = use_checkpoint_opt_param_scheduler
         if self.override_opt_param_scheduler:
-            assert not self.use_checkpoint_opt_param_scheduler, 'both override and ' \
-                'use-checkpoint are set.'
+            assert not self.use_checkpoint_opt_param_scheduler, (
+                'both override and ' 'use-checkpoint are set.'
+            )

         # Set the learning rate
         self.step(0)
-        print_rank_0('> learning rate decay style: {}'.format(self.lr_decay_style))
+        log_single_rank(
+            logger, logging.INFO, f"> learning rate decay style: {self.lr_decay_style}"
+        )

-    def get_wd(self):
-        """ Weight decay incr functions"""
+    def get_wd(self) -> float:
+        """Weight decay incr functions"""
         if self.num_steps > self.wd_incr_steps:
             return self.end_wd
 ...
 ...
@@ -70,71 +114,86 @@ class OptimizerParamScheduler(object):
         elif self.wd_incr_style == 'cosine':
             coeff = 0.5 * (math.cos(math.pi * (1 - incr_ratio)) + 1.0)
         else:
-            raise Exception('{} weight decay increment style is not supported.'.format(
-                self.wd_incr_style))
+            raise Exception(f'{self.wd_incr_style} weight decay increment style is not supported.')

         return self.start_wd + coeff * delta_wd
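To make the cosine weight-decay increment above concrete, here is a standalone computation. It assumes, based on the collapsed part of this hunk, that incr_ratio = num_steps / wd_incr_steps and delta_wd = end_wd - start_wd.

import math

start_wd, end_wd, wd_incr_steps = 0.0, 0.1, 1000
num_steps = 500  # halfway through the increment window

incr_ratio = num_steps / wd_incr_steps  # assumed definition (collapsed above)
delta_wd = end_wd - start_wd            # assumed definition (collapsed above)
coeff = 0.5 * (math.cos(math.pi * (1 - incr_ratio)) + 1.0)
print(start_wd + coeff * delta_wd)      # 0.05 at the midpoint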
-    def get_lr(self):
+    def get_lr(self, param_group: dict) -> float:
         """Learning rate decay functions from:
-        https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""
+        https://openreview.net/pdf?id=BJYwwY9ll pg. 4
+
+        Args:
+            param_group (dict): parameter group from the optimizer.
+        """
+
+        max_lr = param_group.get('max_lr', self.max_lr)
+        min_lr = param_group.get('min_lr', self.min_lr)

         # Use linear warmup for the initial part.
         if self.lr_warmup_steps > 0 and self.num_steps <= self.lr_warmup_steps:
-            return (
-                self.init_lr
-                + (self.max_lr - self.init_lr) * float(self.num_steps) / float(self.lr_warmup_steps)
-            )
+            return self.init_lr + (
+                (max_lr - self.init_lr) * float(self.num_steps) / float(self.lr_warmup_steps)
+            )

         # If the learning rate is constant, just return the initial value.
         if self.lr_decay_style == 'constant':
-            return self.max_lr
+            return max_lr

-        # For any steps larger than `self.lr_decay_steps`, use `self.min_lr`.
+        # For any steps larger than `self.lr_decay_steps`, use `min_lr`.
         if self.num_steps > self.lr_decay_steps:
-            return self.min_lr
+            return min_lr

         # If we are done with the warmup period, use the decay style.
         if self.lr_decay_style == 'inverse-square-root':
             warmup_steps = max(self.lr_warmup_steps, 1)
             num_steps = max(self.num_steps, 1)
-            lr = self.max_lr * warmup_steps**0.5 / (num_steps**0.5)
-            return max(self.min_lr, lr)
+            lr = max_lr * warmup_steps**0.5 / (num_steps**0.5)
+            return max(min_lr, lr)

         num_steps_ = self.num_steps - self.lr_warmup_steps
         decay_steps_ = self.lr_decay_steps - self.lr_warmup_steps
         decay_ratio = float(num_steps_) / float(decay_steps_)
         assert decay_ratio >= 0.0
         assert decay_ratio <= 1.0
-        delta_lr = self.max_lr - self.min_lr
+        delta_lr = max_lr - min_lr

         if self.lr_decay_style == 'linear':
-            coeff = (1.0 - decay_ratio)
+            coeff = 1.0 - decay_ratio
         elif self.lr_decay_style == 'cosine':
             coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
+        elif self.lr_decay_style == 'WSD':
+            wsd_anneal_start_ = self.lr_decay_steps - self.wsd_decay_steps
+            if self.num_steps <= wsd_anneal_start_:
+                coeff = 1.0
+            else:
+                wsd_steps = self.num_steps - wsd_anneal_start_
+                wsd_decay_ratio = float(wsd_steps) / float(self.wsd_decay_steps)
+                if self.lr_wsd_decay_style == "linear":
+                    coeff = 1.0 - wsd_decay_ratio
+                elif self.lr_wsd_decay_style == "cosine":
+                    coeff = 0.5 * (math.cos(math.pi * wsd_decay_ratio) + 1.0)
+                elif self.lr_wsd_decay_style == "exponential":
+                    coeff = (2.0 * math.pow(0.5, wsd_decay_ratio)) - 1.0
         else:
-            raise Exception('{} decay style is not supported.'.format(self.lr_decay_style))
+            raise Exception(f'{self.lr_decay_style} decay style is not supported.')

-        return self.min_lr + coeff * delta_lr
+        return min_lr + coeff * delta_lr
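For reference, the warmup-plus-cosine path of get_lr() above can be exercised in isolation. This sketch inlines the same arithmetic as a free function and leaves out the per-group max_lr/min_lr overrides for brevity; the numbers are illustrative.

import math

def cosine_lr_sketch(num_steps, init_lr=0.0, max_lr=3e-4, min_lr=3e-5,
                     warmup_steps=100, decay_steps=1000):
    # Linear warmup for the initial part.
    if warmup_steps > 0 and num_steps <= warmup_steps:
        return init_lr + (max_lr - init_lr) * num_steps / warmup_steps
    # Past the decay window, hold at min_lr.
    if num_steps > decay_steps:
        return min_lr
    # Cosine decay between warmup_steps and decay_steps.
    decay_ratio = (num_steps - warmup_steps) / (decay_steps - warmup_steps)
    coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
    return min_lr + coeff * (max_lr - min_lr)

print(cosine_lr_sketch(50))    # mid-warmup: 0.00015
print(cosine_lr_sketch(1000))  # end of decay: 3e-05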
-    def step(self, increment):
-        """Set lr for all parameters groups."""
+    def step(self, increment: int) -> None:
+        """Set lr for all parameters groups.
+
+        Args:
+            increment (int): number of steps to increment
+        """
         self.num_steps += increment
-        new_lr = self.get_lr()
         new_wd = self.get_wd()
-        for group in self.optimizer.param_groups:
-            group['lr'] = new_lr * group.get('lr_mult', 1.0)
-            group['weight_decay'] = new_wd * group.get('wd_mult', 1.0)
+        for param_group in self.optimizer.param_groups:
+            new_lr = self.get_lr(param_group)
+            param_group['lr'] = new_lr * param_group.get('lr_mult', 1.0)
+            param_group['weight_decay'] = new_wd * param_group.get('wd_mult', 1.0)

-    def state_dict(self):
+    def state_dict(self) -> dict:
+        """Return the state dict."""
         state_dict = {
             'max_lr': self.max_lr,
             'lr_warmup_steps': self.lr_warmup_steps,
 ...
 ...
@@ -145,91 +204,94 @@ class OptimizerParamScheduler(object):
             'start_wd': self.start_wd,
             'end_wd': self.end_wd,
             'wd_incr_style': self.wd_incr_style,
-            'wd_incr_steps': self.wd_incr_steps
+            'wd_incr_steps': self.wd_incr_steps,
         }
         return state_dict

-    def _check_and_set(self, cls_value, sd_value, name):
+    def _check_and_set(self, cls_value: float, sd_value: float, name: str) -> float:
         """Auxiliary function for checking the values in the checkpoint and
-        setting them."""
+        setting them.
+
+        Args:
+            cls_value (float): class value
+            sd_value (float): checkpoint value
+            name (str): name of the parameter
+        """
         if self.override_opt_param_scheduler:
-            print_rank_0(' > overriding {} value to {}'.format(name, cls_value))
+            log_single_rank(logger, logging.INFO, f" > overriding {name} value to {cls_value}")
             return cls_value

         if not self.use_checkpoint_opt_param_scheduler:
-            assert cls_value == sd_value, \
-                f'OptimizerParamScheduler: class input value {cls_value} and checkpoint' \
-                f'value {sd_value} for {name} do not match'
-        print_rank_0(' > using checkpoint value {} for {}'.format(sd_value, name))
+            assert cls_value == sd_value, (
+                f'OptimizerParamScheduler: class input value {cls_value} and checkpoint'
+                f'value {sd_value} for {name} do not match'
+            )
+        log_single_rank(logger, logging.INFO, f" > using checkpoint value {sd_value} for {name}")
         return sd_value
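The _check_and_set helper above encodes a three-way precedence between class arguments and checkpoint values. A standalone sketch of just that decision logic, with plain keyword flags standing in for the scheduler attributes (the helper name is made up):

def check_and_set_sketch(cls_value, sd_value, *, override=False, use_checkpoint=True):
    # Mirrors the precedence in _check_and_set, outside the class for illustration.
    if override:
        return cls_value  # class value wins, checkpoint value is ignored
    if not use_checkpoint:
        # When checkpoint values are not trusted, they must agree with the class values.
        assert cls_value == sd_value, "class input and checkpoint value do not match"
    return sd_value  # otherwise the checkpoint value wins

print(check_and_set_sketch(3e-4, 1e-4, override=True))        # 0.0003
print(check_and_set_sketch(3e-4, 1e-4, use_checkpoint=True))  # 0.0001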
-    def load_state_dict(self, sd):
+    def load_state_dict(self, state_dict: dict) -> None:
+        """Load the state dict.
+
+        Args:
+            state_dict (dict): state dict to be loaded
+        """

-        if 'start_lr' in sd:
-            max_lr_ = sd['start_lr']
+        if 'start_lr' in state_dict:
+            max_lr_ = state_dict['start_lr']
         else:
-            max_lr_ = sd['max_lr']
-        self.max_lr = self._check_and_set(self.max_lr, max_lr_, 'learning rate')
-        self.min_lr = self._check_and_set(self.min_lr, sd['min_lr'], 'minimum learning rate')
-
-        if 'warmup_iter' in sd:
-            lr_warmup_steps_ = sd['warmup_iter']
-        elif 'warmup_steps' in sd:
-            lr_warmup_steps_ = sd['warmup_steps']
+            max_lr_ = state_dict['max_lr']
+        self.max_lr = self._check_and_set(self.max_lr, max_lr_, 'learning rate')
+        self.min_lr = self._check_and_set(
+            self.min_lr, state_dict['min_lr'], 'minimum learning rate'
+        )
+
+        if 'warmup_iter' in state_dict:
+            lr_warmup_steps_ = state_dict['warmup_iter']
+        elif 'warmup_steps' in state_dict:
+            lr_warmup_steps_ = state_dict['warmup_steps']
         else:
-            lr_warmup_steps_ = sd['lr_warmup_steps']
-        self.lr_warmup_steps = self._check_and_set(
-            self.lr_warmup_steps, lr_warmup_steps_, 'warmup iterations'
-        )
-
-        if 'end_iter' in sd:
-            lr_decay_steps_ = sd['end_iter']
-        elif 'decay_steps' in sd:
-            lr_decay_steps_ = sd['decay_steps']
+            lr_warmup_steps_ = state_dict['lr_warmup_steps']
+        self.lr_warmup_steps = self._check_and_set(
+            self.lr_warmup_steps, lr_warmup_steps_, 'warmup iterations'
+        )
+
+        if 'end_iter' in state_dict:
+            lr_decay_steps_ = state_dict['end_iter']
+        elif 'decay_steps' in state_dict:
+            lr_decay_steps_ = state_dict['decay_steps']
         else:
-            lr_decay_steps_ = sd['lr_decay_steps']
-        self.lr_decay_steps = self._check_and_set(
-            self.lr_decay_steps, lr_decay_steps_, 'total number of iterations'
-        )
+            lr_decay_steps_ = state_dict['lr_decay_steps']
+        self.lr_decay_steps = self._check_and_set(
+            self.lr_decay_steps, lr_decay_steps_, 'total number of iterations'
+        )

-        if 'decay_style' in sd:
-            lr_decay_style_ = sd['decay_style']
+        if 'decay_style' in state_dict:
+            lr_decay_style_ = state_dict['decay_style']
         else:
-            lr_decay_style_ = sd['lr_decay_style']
-        self.lr_decay_style = self._check_and_set(
-            self.lr_decay_style, lr_decay_style_, 'learning rate decay style'
-        )
+            lr_decay_style_ = state_dict['lr_decay_style']
+        self.lr_decay_style = self._check_and_set(
+            self.lr_decay_style, lr_decay_style_, 'learning rate decay style'
+        )

-        if 'num_iters' in sd:
-            num_steps = sd['num_iters']
+        if 'num_iters' in state_dict:
+            num_steps = state_dict['num_iters']
         else:
-            num_steps = sd['num_steps']
+            num_steps = state_dict['num_steps']
         self.step(increment=num_steps)

-        if 'start_wd' in sd:
-            self.start_wd = self._check_and_set(self.start_wd, sd['start_wd'], "start weight decay")
-            self.end_wd = self._check_and_set(self.end_wd, sd['end_wd'], "end weight decay")
-            self.wd_incr_steps = self._check_and_set(
-                self.wd_incr_steps, sd['wd_incr_steps'], "total number of weight decay iterations"
-            )
-            self.wd_incr_style = self._check_and_set(
-                self.wd_incr_style, sd['wd_incr_style'], "weight decay incr style"
-            )
+        if 'start_wd' in state_dict:
+            self.start_wd = self._check_and_set(
+                self.start_wd, state_dict['start_wd'], "start weight decay"
+            )
+            self.end_wd = self._check_and_set(self.end_wd, state_dict['end_wd'], "end weight decay")
+            self.wd_incr_steps = self._check_and_set(
+                self.wd_incr_steps,
+                state_dict['wd_incr_steps'],
+                "total number of weight decay iterations",
+            )
+            self.wd_incr_style = self._check_and_set(
+                self.wd_incr_style, state_dict['wd_incr_style'], "weight decay incr style"
+            )
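A hedged end-to-end sketch of the scheduler after this change. The class is typed against MegatronOptimizer, but it only touches optimizer.param_groups, so a plain torch optimizer stands in here purely for illustration; the hyperparameters are made up.

import torch

from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler

model = torch.nn.Linear(4, 4)
opt = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.1)

sched = OptimizerParamScheduler(
    opt,
    init_lr=0.0, max_lr=3e-4, min_lr=3e-5,
    lr_warmup_steps=100, lr_decay_steps=1000, lr_decay_style='cosine',
    start_wd=0.0, end_wd=0.1, wd_incr_steps=1000, wd_incr_style='linear',
)
sched.step(increment=1)
print(opt.param_groups[0]['lr'])  # 3e-06 after one warmup step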
megatron/core/package_info.py
100644 → 100755
View file @
3bec6514
...
...
@@ -2,9 +2,9 @@
 MAJOR = 0
-MINOR = 3
+MINOR = 10
 PATCH = 0
-PRE_RELEASE = ''
+PRE_RELEASE = 'rc0'

 # Use the following formatting: (major, minor, patch, pre-release)
 VERSION = (MAJOR, MINOR, PATCH, PRE_RELEASE)
...
...
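The hunk above bumps the core package from 0.3.0 to a 0.10.0 release candidate. The code that turns the tuple into a version string is outside this hunk, so the following is only a plausible sketch of that assembly; the dunder names are assumptions, not quotes from package_info.py.

MAJOR, MINOR, PATCH, PRE_RELEASE = 0, 10, 0, 'rc0'
VERSION = (MAJOR, MINOR, PATCH, PRE_RELEASE)

# Hypothetical assembly of the human-readable version string.
__shortversion__ = '.'.join(map(str, VERSION[:3]))
__version__ = __shortversion__ + ''.join(VERSION[3:])
print(__version__)  # 0.10.0rc0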
megatron/core/packed_seq_params.py
0 → 100755
View file @
3bec6514
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from dataclasses import dataclass

from torch import Tensor


@dataclass
class PackedSeqParams:
    '''
    parameters to TEDotProductAttention and fused rope kernels for the
    `thd` (packed) sequence format
    '''

    qkv_format: str = None
    cu_seqlens_q: Tensor = None
    cu_seqlens_kv: Tensor = None
    cu_seqlens_q_padded: Tensor = None
    cu_seqlens_kv_padded: Tensor = None
    max_seqlen_q: Tensor = None
    max_seqlen_kv: Tensor = None
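A small sketch of how these fields are typically filled for a packed (`thd`) batch. The three sequence lengths are made up, no padding is assumed, and the cu_seqlens convention used here (a leading zero followed by running totals) is the usual cumulative-length layout for packed attention kernels.

import torch

from megatron.core.packed_seq_params import PackedSeqParams

# Three packed sequences of lengths 5, 3, and 8 tokens (illustrative values).
seqlens = torch.tensor([5, 3, 8], dtype=torch.int32)
cu_seqlens = torch.cat(
    [torch.zeros(1, dtype=torch.int32), torch.cumsum(seqlens, dim=0).to(torch.int32)]
)

packed_seq_params = PackedSeqParams(
    qkv_format='thd',
    cu_seqlens_q=cu_seqlens,      # tensor([0, 5, 8, 16], dtype=torch.int32)
    cu_seqlens_kv=cu_seqlens,
    max_seqlen_q=seqlens.max(),
    max_seqlen_kv=seqlens.max(),
)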