Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wuxk1
Megatron-LM
Commits
4b097dee
Commit
4b097dee
authored
Oct 29, 2024
by
liangjing
Browse files
update to core_v0.9
parent
3aca1415
Changes
341
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3730 additions
and
86 deletions
+3730
-86
megatron/core/dist_checkpointing/strategies/zarr.py
megatron/core/dist_checkpointing/strategies/zarr.py
+131
-40
megatron/core/dist_checkpointing/utils.py
megatron/core/dist_checkpointing/utils.py
+180
-5
megatron/core/dist_checkpointing/validation.py
megatron/core/dist_checkpointing/validation.py
+525
-0
megatron/core/distributed/__init__.py
megatron/core/distributed/__init__.py
+10
-0
megatron/core/distributed/distributed_data_parallel.py
megatron/core/distributed/distributed_data_parallel.py
+485
-0
megatron/core/distributed/distributed_data_parallel_config.py
...tron/core/distributed/distributed_data_parallel_config.py
+44
-0
megatron/core/distributed/finalize_model_grads.py
megatron/core/distributed/finalize_model_grads.py
+162
-0
megatron/core/distributed/param_and_grad_buffer.py
megatron/core/distributed/param_and_grad_buffer.py
+769
-0
megatron/core/extensions/__init__.py
megatron/core/extensions/__init__.py
+0
-0
megatron/core/extensions/transformer_engine.py
megatron/core/extensions/transformer_engine.py
+969
-0
megatron/core/fusions/fused_bias_dropout.py
megatron/core/fusions/fused_bias_dropout.py
+32
-19
megatron/core/fusions/fused_bias_geglu.py
megatron/core/fusions/fused_bias_geglu.py
+85
-0
megatron/core/fusions/fused_bias_gelu.py
megatron/core/fusions/fused_bias_gelu.py
+10
-3
megatron/core/fusions/fused_bias_swiglu.py
megatron/core/fusions/fused_bias_swiglu.py
+89
-0
megatron/core/fusions/fused_cross_entropy.py
megatron/core/fusions/fused_cross_entropy.py
+143
-0
megatron/core/fusions/fused_layer_norm.py
megatron/core/fusions/fused_layer_norm.py
+67
-17
megatron/core/fusions/fused_softmax.py
megatron/core/fusions/fused_softmax.py
+18
-2
megatron/core/inference/__init__.py
megatron/core/inference/__init__.py
+1
-0
megatron/core/inference/ammo_support/__init__.py
megatron/core/inference/ammo_support/__init__.py
+8
-0
megatron/core/inference/ammo_support/gpt/model_specs.py
megatron/core/inference/ammo_support/gpt/model_specs.py
+2
-0
No files found.
Too many changes to show.
To preserve performance only
341 of 341+
files are displayed.
Plain diff
Email patch
megatron/core/dist_checkpointing/strategies/zarr.py
View file @
4b097dee
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Strategies using Zarr as an underlying format. """
import
logging
import
os
from
functools
import
partial
from
logging
import
getLogger
from
pathlib
import
Path
from
typing
import
List
from
typing
import
Callable
,
List
,
Optional
,
Tuple
import
numpy
as
np
import
torch
import
zarr
from
..core
import
CheckpointingException
from
..dict_utils
import
dict_list_map_inplace
from
..dict_utils
import
dict_list_map_inplace
,
nested_values
from
..mapping
import
ShardedStateDict
,
ShardedTensor
,
is_main_replica
from
.base
import
LoadShardedStrategy
,
SaveShardedStrategy
,
StrategyAction
,
default_strategies
from
.base
import
(
LoadShardedStrategy
,
SaveShardedStrategy
,
StrategyAction
,
register_default_strategy
,
)
logger
=
logging
.
getLogger
(
__name__
)
numpy_to_torch_dtype_dict
=
{
np
.
bool
_
:
torch
.
bool
,
np
.
uint8
:
torch
.
uint8
,
np
.
int8
:
torch
.
int8
,
np
.
int16
:
torch
.
int16
,
np
.
int32
:
torch
.
int32
,
np
.
int64
:
torch
.
int64
,
np
.
float16
:
torch
.
float16
,
np
.
float32
:
torch
.
float32
,
np
.
float64
:
torch
.
float64
,
np
.
complex64
:
torch
.
complex64
,
np
.
complex128
:
torch
.
complex128
,
np
.
dtype
(
'
bool
'
)
:
torch
.
bool
,
np
.
dtype
(
'
uint8
'
)
:
torch
.
uint8
,
np
.
dtype
(
'
int8
'
)
:
torch
.
int8
,
np
.
dtype
(
'
int16
'
)
:
torch
.
int16
,
np
.
dtype
(
'
int32
'
)
:
torch
.
int32
,
np
.
dtype
(
'
int64
'
)
:
torch
.
int64
,
np
.
dtype
(
'
float16
'
)
:
torch
.
float16
,
np
.
dtype
(
'
float32
'
)
:
torch
.
float32
,
np
.
dtype
(
'
float64
'
)
:
torch
.
float64
,
np
.
dtype
(
'
complex64
'
)
:
torch
.
complex64
,
np
.
dtype
(
'
complex128
'
)
:
torch
.
complex128
,
}
torch_to_numpy_dtype_dict
=
{
v
:
k
for
k
,
v
in
numpy_to_torch_dtype_dict
.
items
()}
try
:
import
tensorstore
# Register a bfloat16 type with this import
import
tensorstore
# pylint: disable=unused-import
HAS_BFLOAT16
=
True
numpy_to_torch_dtype_dict
[
np
.
dtype
(
'bfloat16'
)]
=
torch
.
bfloat16
...
...
@@ -41,11 +51,28 @@ try:
except
ImportError
:
HAS_BFLOAT16
=
False
_import_trigger
=
None
logger
=
getLogger
(
__name__
)
def register_default_zarr_strategies():
    """Register default strategies related to Zarr backend."""
    zarr_save_strategy = ZarrSaveShardedStrategy('zarr', 1)
    register_default_strategy(StrategyAction.SAVE_SHARDED, 'zarr', 1, zarr_save_strategy)
class
ZarrSaveShardedStrategy
(
SaveShardedStrategy
):
def
save
(
self
,
sharded_tensors
:
List
[
ShardedTensor
],
checkpoint_dir
:
Path
):
"""Save strategy for Zarr backend."""
def
__init__
(
self
,
backend
:
str
,
version
:
int
):
super
().
__init__
(
backend
,
version
)
logger
.
warning
(
f
'`zarr` distributed checkpoint backend is deprecated.'
' Please switch to PyTorch Distributed format (`torch_dist`).'
)
def
save
(
self
,
sharded_state_dict
:
ShardedStateDict
,
checkpoint_dir
:
Path
):
sharded_tensors
=
list
(
nested_values
(
sharded_state_dict
))
arrays
=
_create_or_open_zarr_arrays
(
sharded_tensors
,
checkpoint_dir
)
for
ten
,
arr
in
zip
(
sharded_tensors
,
arrays
):
_save_to_existing_array
(
ten
,
arr
)
...
...
@@ -54,24 +81,41 @@ class ZarrSaveShardedStrategy(SaveShardedStrategy):
def _create_or_open_zarr_arrays(
    sharded_tensors: List[ShardedTensor], checkpoint_dir: Path
) -> List[Optional[zarr.Array]]:
    """Returns list of zarr arrays corresponding to given tensors.

    For a sharded tensors that:
    a) is main replica and represents the first chunk (all offsets 0), creates the Zarr array
    b) is main replica but not the first chunk,
        opens the arrays created in (a) (possibly by other process)
    c) otherwise, sets the corresponding array to None since it won't be used

    Args:
        sharded_tensors (List[ShardedTensor]): sharded tensors from a given rank
            that will be saved to checkpoint
        checkpoint_dir (Path): checkpoint in which the arrays will be created
    """
    arrays = []
    for tensor in sharded_tensors:
        # TODO: maybe reuse the opened arrays
        if _should_create_array(tensor):
            arrays.append(_create_zarr_array(tensor, checkpoint_dir))
        else:
            arrays.append(None)

    # Make sure all creating ranks finished before anyone opens the arrays.
    torch.distributed.barrier()

    # Open arrays created above by other processes
    for idx, tensor in enumerate(sharded_tensors):
        if arrays[idx] is not None:
            # array created by this process
            assert _should_create_array(tensor), tensor
            continue
        if not is_main_replica(tensor.replica_id):
            # this array won't be needed for saving and can stay None
            continue
        open_kwargs = {}
        if tensor.flattened_range is not None:
            open_kwargs['synchronizer'] = zarr.ProcessSynchronizer(
                str(checkpoint_dir / f'{tensor.key}.sync')
            )
        arrays[idx] = _open_zarr_array_verbose(checkpoint_dir / tensor.key, 'r+', **open_kwargs)
    return arrays
...
...
@@ -83,9 +127,10 @@ def _should_create_array(ten: ShardedTensor):
)
def
_save_to_existing_array
(
sharded_tensor
:
ShardedTensor
,
arr
:
zarr
.
Array
):
def
_save_to_existing_array
(
sharded_tensor
:
ShardedTensor
,
arr
:
Optional
[
zarr
.
Array
]
):
if
not
is_main_replica
(
sharded_tensor
.
replica_id
):
return
assert
arr
is
not
None
x
=
sharded_tensor
.
data
x
=
x
.
detach
().
cpu
()
torch
.
cuda
.
synchronize
()
...
...
@@ -114,6 +159,7 @@ def _create_zarr_array(sharded_tensor: ShardedTensor, checkpoint_dir: Path):
fill_value
=
None
,
write_empty_chunks
=
True
,
)
logger
.
debug
(
f
'Created a new Zarr array at
{
checkpoint_dir
/
sharded_tensor
.
key
}
'
)
except
zarr
.
errors
.
ContainsArrayError
as
e
:
raise
CheckpointingException
(
f
'Array
{
checkpoint_dir
/
sharded_tensor
.
key
}
already exists'
...
...
@@ -127,12 +173,21 @@ def _create_zarr_array(sharded_tensor: ShardedTensor, checkpoint_dir: Path):
class
ZarrLoadShardedStrategy
(
LoadShardedStrategy
):
"""Load strategy for the Zarr backend."""
def
load
(
self
,
sharded_state_dict
:
ShardedStateDict
,
checkpoint_dir
:
Path
):
dict_list_map_inplace
(
partial
(
_load_from_array
,
checkpoint_dir
=
checkpoint_dir
),
sharded_state_dict
)
return
sharded_state_dict
def
load_tensors_metadata
(
self
,
checkpoint_dir
:
Path
):
def
get_zarr_shape_dtype
(
path
):
arr
=
zarr
.
open
(
path
,
'r'
)
return
arr
.
shape
,
arr
.
dtype
return
load_zarr_based_sharded_metadata
(
checkpoint_dir
,
get_zarr_shape_dtype
)
def
check_backend_compatibility
(
self
,
loaded_version
):
pass
# TODO
...
...
@@ -142,12 +197,7 @@ class ZarrLoadShardedStrategy(LoadShardedStrategy):
def
_load_from_array
(
sharded_tensor
:
ShardedTensor
,
checkpoint_dir
:
Path
):
assert
isinstance
(
sharded_tensor
,
ShardedTensor
),
type
(
sharded_tensor
)
try
:
arr
=
zarr
.
open
(
checkpoint_dir
/
sharded_tensor
.
key
,
'r'
)
except
zarr
.
errors
.
PathNotFoundError
as
e
:
raise
CheckpointingException
(
f
'Array
{
checkpoint_dir
/
sharded_tensor
.
key
}
not found'
)
from
e
arr
=
_open_zarr_array_verbose
(
checkpoint_dir
/
sharded_tensor
.
key
,
'r'
)
if
not
sharded_tensor
.
allow_shape_mismatch
and
sharded_tensor
.
global_shape
!=
arr
.
shape
:
_msg
=
(
...
...
@@ -161,7 +211,22 @@ def _load_from_array(sharded_tensor: ShardedTensor, checkpoint_dir: Path):
return
postprocess_numpy_array
(
x
,
sharded_tensor
)
def _open_zarr_array_verbose(path: Path, mode: str, **open_kwargs):
    """Open a Zarr array, raising a CheckpointingException with diagnostics on failure.

    On PathNotFoundError, logs the checkpoint directory content (if the directory
    exists) to help debug which arrays are actually present.
    """
    try:
        return zarr.open(str(path), mode, **open_kwargs)
    except zarr.errors.PathNotFoundError as e:
        ckpt_dir = path.parent
        err_msg = f'Array {path} not found'
        if not ckpt_dir.exists():
            err_msg += f'. Checkpoint directory {ckpt_dir} does not exist.'
        else:
            ckpt_files = [f.name for f in ckpt_dir.iterdir()]
            logger.debug(f'{err_msg}. Checkpoint directory {ckpt_dir} content: {ckpt_files}')
        raise CheckpointingException(err_msg) from e
def
postprocess_numpy_array
(
loaded_array
,
sharded_tensor
,
apply_flattened_range
=
True
):
"""Turn numpy array to torch tensor."""
x
=
loaded_array
if
HAS_BFLOAT16
and
x
.
dtype
==
np
.
dtype
(
'bfloat16'
):
x
=
x
.
astype
(
np
.
dtype
(
'float32'
))
...
...
@@ -189,10 +254,12 @@ def postprocess_numpy_array(loaded_array, sharded_tensor, apply_flattened_range=
def flatten_range(sharded_tensor, x):
    """Apply flattened range to a tensor."""
    flat_view = x.flatten()
    return flat_view[sharded_tensor.flattened_range]
def
pad_to_expected_shape
(
x
:
torch
.
Tensor
,
expected_sharded_ten
:
ShardedTensor
):
"""Pad tensor to the expected shape."""
pad_args
=
[]
assert
len
(
x
.
shape
)
==
len
(
expected_sharded_ten
.
local_shape
)
# Reversed iteration order because F.pad expects so
...
...
@@ -204,9 +271,10 @@ def pad_to_expected_shape(x: torch.Tensor, expected_sharded_ten: ShardedTensor):
if
x_sh
==
exp_sh
:
pad_args
.
extend
((
0
,
0
))
elif
x_sh
>
exp_sh
:
assert
(
False
),
f
'Expected shape (
{
exp_sh
}
) smaller than actual (
{
x_sh
}
) for
{
repr
(
expected_sharded_ten
)
}
'
assert
False
,
(
f
'Expected shape (
{
exp_sh
}
) smaller than actual (
{
x_sh
}
)'
f
' for
{
repr
(
expected_sharded_ten
)
}
'
)
else
:
pad_args
.
extend
((
0
,
exp_sh
-
x_sh
))
# TODO: behavior control with envvar is for testing purposes only, remove it
...
...
@@ -224,7 +292,30 @@ def pad_to_expected_shape(x: torch.Tensor, expected_sharded_ten: ShardedTensor):
return
torch
.
nn
.
functional
.
pad
(
x
.
unsqueeze
(
0
),
pad_args
,
mode
=
'replicate'
).
squeeze
(
0
)
# default_strategies[StrategyAction.LOAD_SHARDED.value][('zarr', 1)] = ZarrLoadShardedStrategy()
default_strategies
[
StrategyAction
.
SAVE_SHARDED
.
value
][(
'zarr'
,
1
)]
=
ZarrSaveShardedStrategy
(
'zarr'
,
1
)
def load_zarr_based_sharded_metadata(
    checkpoint_dir: Path, get_shape_dtype_fn: Callable[[str], Tuple[Tuple[int], np.dtype]]
) -> ShardedStateDict:
    """Load metadata of Zarr arrays.

    Args:
        checkpoint_dir (str): checkpoint root directory
        get_shape_dtype_fn (str -> ((int, ...), np.dtype)): a function returning
            an array shape and dtype for a given Zarr array path
    """
    sharded_state_dict = {}
    for subdir in checkpoint_dir.iterdir():
        # Only directories containing a `.zarray` marker are Zarr arrays.
        if not subdir.is_dir():
            continue
        if not (subdir / '.zarray').exists():
            continue
        key = subdir.name
        arr_shape, arr_dtype = get_shape_dtype_fn(str(subdir))

        # Metadata-only tensor: no data, zero offsets, trivial (1, 1, ...) axis fragmentation.
        zero_offset = tuple(0 for _ in arr_shape)
        unit_fragmentation = tuple(1 for _ in arr_shape)
        sharded_state_dict[key] = ShardedTensor(
            key,
            None,
            numpy_to_torch_dtype_dict[arr_dtype],
            arr_shape,
            arr_shape,
            zero_offset,
            unit_fragmentation,
        )
    return sharded_state_dict
megatron/core/dist_checkpointing/utils.py
View file @
4b097dee
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
from
typing
import
Tuple
""" Helpers for manipulating sharded tensors and sharded state dicts. """
from
typing
import
Dict
,
Optional
,
Tuple
from
.dict_utils
import
dict_list_map_inplace
,
extract_matching_values
from
.mapping
import
(
LocalNonpersitentObject
,
LocalNonpersistentObject
,
ShardedBase
,
ShardedObject
,
ShardedStateDict
,
ShardedTensor
,
ShardedTensorFactory
,
StateDict
,
)
# _ShardId uniquely identifies a ShardedTensor. This is a subset of ShardedTensor
# attributes: key (str), global_offset (tuple) and flattened_range (optional tuple)
_ShardId
=
Tuple
[
str
,
tuple
,
Optional
[
tuple
]]
def
_sharded_tensor_shard_id
(
sharded_tensor
:
ShardedTensor
)
->
_ShardId
:
"""Unique id of the sharded tensor data.
Should yield the same value for same data replicated on different ranks.
Args:
sharded_tensor (ShardedTensor): sharded tensor representing the data shard
Returns (tuple): unique id of a data shard
"""
f_range
=
sharded_tensor
.
flattened_range
return
(
sharded_tensor
.
key
,
sharded_tensor
.
global_offset
,
None
if
f_range
is
None
else
(
f_range
.
start
,
f_range
.
stop
),
)
def
_sharded_object_id
(
sharded_object
:
ShardedObject
)
->
_ShardId
:
"""Unique id of the sharded object data.
Should yield the same value for same data replicated on different ranks.
Args:
sharded_object (ShardedObject): sharded object representing the data shard
Returns (tuple): unique id of a data shard
"""
return
(
sharded_object
.
key
,
sharded_object
.
global_offset
,
sharded_object
.
global_shape
)
def extract_sharded_tensors(
    sharded_state_dict: ShardedStateDict,
) -> Tuple[ShardedStateDict, StateDict]:
    """Extract a dict consisting of only ShardedTensor objects
    from a given state dict with any objects.

    Args:
        sharded_state_dict: state dict possibly containing ShardedTensor objects

    Returns:
        Tuple[ShardedStateDict, StateDict]: tuple of:
        - state dict with all ShardedTensor (keeping the original state dict structure)
        - state dict with all objects other than ShardedTensor
          (keeping the original state dict structure)
    """

    def _is_sharded_tensor(value):
        return isinstance(value, ShardedTensor)

    return extract_matching_values(sharded_state_dict, _is_sharded_tensor)
def extract_sharded_tensors_and_factories(
    sharded_state_dict: ShardedStateDict,
) -> Tuple[ShardedStateDict, StateDict]:
    """Extract a dict consisting of only ShardedTensor and ShardedTensorFactory objects
    from a given state dict with any objects.

    Args:
        sharded_state_dict:
            state dict possibly containing ShardedTensor and ShardedTensorFactory objects

    Returns:
        Tuple[ShardedStateDict, StateDict]: tuple of:
        - state dict with all ShardedTensor and ShardedTensorFactory
          (keeping the original state dict structure)
        - state dict with all other objects (keeping the original state dict structure)
    """

    def _is_tensor_or_factory(value):
        return isinstance(value, (ShardedTensor, ShardedTensorFactory))

    return extract_matching_values(sharded_state_dict, _is_tensor_or_factory)
...
...
@@ -29,16 +93,127 @@ def extract_sharded_tensors_and_factories(
def extract_sharded_tensors_or_nonpersistent(
    sharded_state_dict: ShardedStateDict,
) -> Tuple[ShardedStateDict, StateDict]:
    """Extract a dict consisting of only ShardedTensor, ShardedTensorFactory
    and LocalNonpersistentObject objects from a given state dict with any objects.

    Args:
        sharded_state_dict: state dict possibly containing ShardedTensor, ShardedTensorFactory
            and LocalNonpersistentObject objects

    Returns:
        Tuple[ShardedStateDict, StateDict]: tuple of:
        - state dict with all ShardedTensor, ShardedTensorFactory and LocalNonpersistentObject
          (keeping the original state dict structure)
        - state dict with all other objects (keeping the original state dict structure)
    """

    def _matches(value):
        return isinstance(value, (ShardedTensor, LocalNonpersistentObject, ShardedTensorFactory))

    return extract_matching_values(sharded_state_dict, _matches)
def extract_sharded_base(
    sharded_state_dict: ShardedStateDict,
) -> Tuple[ShardedStateDict, StateDict]:
    """Extract a dict consisting of only ShardedBase from a given state dict with any objects.

    Args:
        sharded_state_dict: state dict possibly containing ShardedBase objects

    Returns:
        Tuple[ShardedStateDict, StateDict]: tuple of:
        - state dict with all ShardedBase objects (keeping the original state dict structure)
        - state dict with all other objects (keeping the original state dict structure)
    """

    def _is_sharded_base(value):
        return isinstance(value, ShardedBase)

    return extract_matching_values(sharded_state_dict, _is_sharded_base)
def extract_nonpersistent(
    sharded_state_dict: ShardedStateDict,
) -> Tuple[ShardedStateDict, StateDict]:
    """Extract a dict consisting of only LocalNonpersistentObjects from a given state dict.

    Args:
        sharded_state_dict: state dict possibly containing LocalNonpersistentObjects

    Returns:
        Tuple[ShardedStateDict, StateDict]: tuple of:
        - state dict with all LocalNonpersistentObjects
          (keeping the original state dict structure)
        - state dict with all other objects (keeping the original state dict structure)
    """

    def _is_nonpersistent(value):
        return isinstance(value, LocalNonpersistentObject)

    return extract_matching_values(sharded_state_dict, _is_nonpersistent)
def add_prefix_for_sharding(sharded_state_dict: ShardedStateDict, prefix: str):
    """Prepend a given prefix to all ShardedBase objects in a given state dict *in-place*.

    Args:
        sharded_state_dict (ShardedStateDict): sharded state dict
        prefix (str): prefix to be prepended

    Returns:
        None: state dict is modified in-place
    """

    def _prepend(item):
        # Only sharded objects carry a `key`; leave everything else untouched.
        if isinstance(item, ShardedBase):
            item.key = f'{prefix}{item.key}'
        return item

    dict_list_map_inplace(_prepend, sharded_state_dict)
def replace_prefix_for_sharding(
    sharded_state_dict: ShardedStateDict, old_prefix: str, new_prefix: str
):
    """Replaces the given prefix in *all* sharded keys in a given state dict.

    Errors out if some key does not begin with a given prefix.

    Args:
        sharded_state_dict (ShardedStateDict): sharded state dict to replace keys in
        old_prefix (str): prefix to be replaced in each key
        new_prefix (str): new prefix

    Returns:
        None: state dict is modified in place
    """

    def _swap_prefix(item):
        if isinstance(item, (ShardedTensor, ShardedTensorFactory, ShardedObject)):
            if not item.key.startswith(old_prefix):
                raise ValueError(f'Expected {item.key} to begin with prefix {old_prefix}')
            # str.removeprefix in Python >= 3.9
            suffix = item.key[len(old_prefix):]
            item.key = f'{new_prefix}{suffix}'
        return item

    dict_list_map_inplace(_swap_prefix, sharded_state_dict)
def apply_prefix_mapping(sharded_state_dict: ShardedStateDict, prefix_map: Dict[str, str]):
    """Replaces prefixes *only in keys matching* with one of prefixes in the map.

    Args:
        sharded_state_dict (ShardedStateDict): sharded state dict to replace keys in
        prefix_map (Dict[str, str]):
            map of old->new prefixes. The first matching prefix for each key is used

    Returns:
        None: state dict is modified in place
    """

    def _remap(item):
        if not isinstance(item, (ShardedTensor, ShardedTensorFactory, ShardedObject)):
            return item
        for old_prefix, new_prefix in prefix_map.items():
            if not item.key.startswith(old_prefix):
                continue
            # str.removeprefix in Python >= 3.9
            item.key = f'{new_prefix}{item.key[len(old_prefix):]}'
            break  # only the first matching prefix is applied
        return item

    dict_list_map_inplace(_remap, sharded_state_dict)
megatron/core/dist_checkpointing/validation.py
0 → 100644
View file @
4b097dee
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import
logging
from
collections
import
Counter
,
defaultdict
from
enum
import
Enum
from
pathlib
import
Path
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Set
,
Tuple
,
Union
import
numpy
as
np
import
torch
from
megatron.core.dist_checkpointing
import
ShardedTensor
from
megatron.core.dist_checkpointing.core
import
CheckpointingException
,
maybe_load_config
from
megatron.core.dist_checkpointing.dict_utils
import
(
extract_matching_values
,
map_reduce
,
nested_values
,
)
from
megatron.core.dist_checkpointing.mapping
import
(
ShardedBase
,
ShardedObject
,
ShardedStateDict
,
is_main_replica
,
)
from
megatron.core.dist_checkpointing.strategies.base
import
(
LoadCommonStrategy
,
LoadShardedStrategy
,
SaveCommonStrategy
,
SaveShardedStrategy
,
StrategyAction
,
get_default_strategy
,
)
if
TYPE_CHECKING
:
from
megatron.core.dist_checkpointing.serialization
import
CkptShardedMetadata
logger
=
logging
.
getLogger
(
__name__
)
# list of local saved/loaded ShardedBase objects
_LocalMetadata
=
List
[
Union
[
ShardedTensor
,
ShardedObject
]]
# list of lists of global saved/loaded ShardedBase objects (each list element corresponds to global rank)
_GlobalMetadata
=
List
[
_LocalMetadata
]
class StrictHandling(Enum):
    """Determines handling of load mismatch (non-empty "unexpected" or "missing" keys).

    Different flags carry different implications on performance and behaviour and
    are divided into two groups:
    - *_UNEXPECTED
    - *_ALL
    The first group ignores missing keys (present in the checkpoint but missing
    in the sharded state dict) which is created in order to avoid inter-rank
    metadata exchange. Note that the metadata exchange will happen anyway
    with `load(..., validate_access_integrity=True)` flag in which case using the
    `*_ALL` option is recommended as it provides a more thorough check with no
    performance penalty wrt. `*_UNEXPECTED` group.

    All options except for the first one (`ASSUME_OK_UNEXPECTED`) require
    extra disk access before the load in order to remove unexpected keys
    from the sharded state dict requested to load.
    """

    # Relies on the underlying strategy to raise error on unexpected keys
    ASSUME_OK_UNEXPECTED = 'assume_ok_unexpected'
    # Logs (with WARNING level) "unexpected" keys; missing keys are ignored.
    # Reasonable default for a "non-strict" load.
    LOG_UNEXPECTED = 'log_unexpected'
    # Logs (with WARNING level) all mismatched keys.
    LOG_ALL = 'log_all'
    # Raise error on unexpected keys before load attempt: cleaner error message
    # than `ASSUME_OK_UNEXPECTED` at the cost of extra disk access.
    RAISE_UNEXPECTED = 'raise_unexpected'
    # Raise error on any mismatch; like `RAISE_UNEXPECTED` but requires
    # metadata exchange.
    RAISE_ALL = 'raise_all'
    # "Unexpected" mismatches are not reported, but returned by the `load`
    # function along with the loaded state dict; missing keys are ignored.
    RETURN_UNEXPECTED = 'return_unexpected'
    # All mismatches are returned along with the loaded state dict.
    RETURN_ALL = 'return_all'
    # Simply ignores mismatches (not recommended).
    IGNORE_ALL = 'ignore_all'

    @staticmethod
    def requires_explicit_ckpt_mismatch_check(val: 'StrictHandling') -> bool:
        """Whether a given strict flag involves mismatch check against the checkpoint."""
        return val != StrictHandling.ASSUME_OK_UNEXPECTED

    @staticmethod
    def requires_global_app_metadata(val: 'StrictHandling') -> bool:
        """Whether a given strict option requires global metadata for validation."""
        return val in (
            StrictHandling.IGNORE_ALL,
            StrictHandling.RAISE_ALL,
            StrictHandling.RETURN_ALL,
            StrictHandling.LOG_ALL,
        )

    @staticmethod
    def requires_returning_mismatch_keys(val: 'StrictHandling') -> bool:
        """Whether a given strict option results in extra return value from the `load` function."""
        return val in (StrictHandling.RETURN_UNEXPECTED, StrictHandling.RETURN_ALL)
def parse_strict_flag(strict: Union[str, StrictHandling]) -> StrictHandling:
    """Parse user passed strict flag from a string to StrictHandling instance.

    Args:
        strict (str, StrictHandling): strict flag to parse. If already an instance
            of StrictHandling, this function is a noop.

    Returns:
        StrictHandling: enum instance
    """
    if isinstance(strict, StrictHandling):
        # Already parsed; nothing to do.
        return strict
    try:
        return StrictHandling(strict)
    except (ValueError, TypeError) as e:
        raise ValueError(f'Invalid strict flag: {e}') from e
def validate_integrity_and_strict_load(
    sharded_state_dict: ShardedStateDict,
    strict: StrictHandling,
    validate_access_integrity: bool,
    local_metadata: Optional[_LocalMetadata] = None,
    global_metadata: Optional[_GlobalMetadata] = None,
    ckpt_sharded_metadata: Optional['CkptShardedMetadata'] = None,
) -> Tuple[ShardedStateDict, Set[str], Set[str]]:
    """Validates sharding integrity and potential mismatches with the checkpoint.

    `validate_access_integrity` controls sharding integrity check (orthogonal
    to strictness checking) which verifies `sharded_state_dict` runtime completeness
    (in isolation from the actual checkpoint).

    `strict` flag controls handling of mismatches between the requested
    sharded state dict to load and the actual checkpoint. See `StrictHandling`
    docs for details regarding flag behavior and performance implications
    (disk interactions or inter-rank communication).

    Args:
        sharded_state_dict (ShardedStateDict): sharded state dict to verify.
        strict (StrictHandling): flag determining how to handle sharded keys mismatch.
        validate_access_integrity (bool): whether to perform sharding validation.
        local_metadata (_LocalMetadata, optional): local sharded state dict metadata.
            Defaults to None, in which case it's determined based on `sharded_state_dict`.
        global_metadata (_GlobalMetadata, optional): global sharded state dict metadata
            (exchanged between ranks). Defaults to None, in which case "missing"
            keys are not determined.
        ckpt_sharded_metadata (CkptShardedMetadata, optional): sharded metadata
            from the checkpoint. Defaults to None, which only makes sense
            for the `StrictHandling.ASSUME_OK_UNEXPECTED` strict value.

    Returns:
        Tuple[ShardedStateDict, Set[str], Set[str]]: tuple of: sharded state dict
            without unexpected keys, missing and unexpected keys. Missing keys are equal
            on all ranks, unexpected keys might differ across ranks. Additionally,
            missing keys might be erroneously empty (depending on `strict` value).
    """
    missing_keys, unexpected_keys = [], []
    if StrictHandling.requires_explicit_ckpt_mismatch_check(strict):
        if ckpt_sharded_metadata is None:
            raise CheckpointingException(
                'Cannot verify checkpoint mismatch with ckpt_sharded_metadata=None.'
            )
        if local_metadata is None:
            local_metadata = [
                sh_base.without_data() for sh_base in nested_values(sharded_state_dict)
            ]
        # We don't want to check for missing keys even if we could
        _skip_missing_keys = strict in (
            StrictHandling.ASSUME_OK_UNEXPECTED,
            StrictHandling.LOG_UNEXPECTED,
            StrictHandling.RAISE_UNEXPECTED,
            StrictHandling.RETURN_UNEXPECTED,
        )
        missing_keys, unexpected_keys = _determine_missing_and_unexpected_keys(
            ckpt_sharded_metadata, local_metadata, None if _skip_missing_keys else global_metadata
        )

        # Drop the unexpected keys from the state dict so the load won't fail on them.
        sharded_state_dict = adjust_non_strict_load(sharded_state_dict, unexpected_keys)

        if strict == StrictHandling.IGNORE_ALL:
            missing_keys, unexpected_keys = [], []
        elif strict in (StrictHandling.RAISE_UNEXPECTED, StrictHandling.RAISE_ALL):
            maybe_report_missing_and_unexpected_keys(missing_keys, unexpected_keys, True)
        elif strict in (StrictHandling.LOG_UNEXPECTED, StrictHandling.LOG_ALL):
            maybe_report_missing_and_unexpected_keys(missing_keys, unexpected_keys, False)

    if validate_access_integrity:
        if global_metadata is None:
            raise CheckpointingException(
                # Fixed typo in the original message: 'intergrity' -> 'integrity'.
                'Cannot check sharding integrity without global_metadata (None).'
            )
        validate_sharding_integrity(global_metadata)

    return sharded_state_dict, missing_keys, unexpected_keys
def verify_checkpoint_and_load_strategy(
    checkpoint_dir: str,
    sharded_strategy: Union[LoadShardedStrategy, Tuple[str, int], None] = None,
    common_strategy: Union[LoadCommonStrategy, Tuple[str, int], None] = None,
) -> Tuple[LoadShardedStrategy, LoadCommonStrategy]:
    """Verifies if checkpoint metadata exists and matches given strategies.

    If no strategies are passed, they are determined based on the checkpoint metadata.

    Args:
        checkpoint_dir (str): checkpoint directory
        sharded_strategy (LoadShardedStrategy, Tuple[str, int], optional): sharded load
            strategy to be verified if compatible with the checkpoint content. If None,
            the default sharded load strategy for the checkpoint backend will be returned.
        common_strategy (LoadCommonStrategy, Tuple[str, int], optional): common load
            strategy to be verified if compatible with the checkpoint content. If None,
            the default common load strategy for the checkpoint backend will be returned.

    Returns:
        Tuple[LoadShardedStrategy, LoadCommonStrategy]: resolved and verified strategies.

    Raises:
        CheckpointingException: if the directory doesn't exist or isn't a
            distributed checkpoint.
    """
    if not Path(checkpoint_dir).exists():
        raise CheckpointingException(f'Checkpoint directory {checkpoint_dir} does not exist')

    saved_config = maybe_load_config(checkpoint_dir)
    if saved_config is None:
        raise CheckpointingException(f'{checkpoint_dir} is not a distributed checkpoint')

    if sharded_strategy is None:
        sharded_strategy = get_default_strategy(
            StrategyAction.LOAD_SHARDED,
            saved_config.sharded_backend,
            saved_config.sharded_backend_version,
        )
    elif isinstance(sharded_strategy, tuple):
        sharded_strategy = get_default_strategy(StrategyAction.LOAD_SHARDED, *sharded_strategy)

    if common_strategy is None:
        common_strategy = get_default_strategy(
            StrategyAction.LOAD_COMMON,
            saved_config.common_backend,
            saved_config.common_backend_version,
        )
    elif isinstance(common_strategy, tuple):
        # BUG FIX: the original assigned this result to `sharded_strategy`,
        # clobbering the sharded strategy and leaving `common_strategy` as a tuple.
        common_strategy = get_default_strategy(StrategyAction.LOAD_COMMON, *common_strategy)

    sharded_strategy.check_backend_compatibility(saved_config.sharded_backend)
    sharded_strategy.check_version_compatibility(saved_config.sharded_backend_version)
    common_strategy.check_backend_compatibility(saved_config.common_backend)
    common_strategy.check_version_compatibility(saved_config.common_backend_version)
    return sharded_strategy, common_strategy
def adjust_non_strict_load(
    sharded_state_dict: ShardedStateDict, sharded_keys_to_remove: Set[str]
) -> ShardedStateDict:
    """Drop from a sharded state dict every entry whose key is absent from the checkpoint.

    Args:
        sharded_state_dict (ShardedStateDict): sharded state dict to modify
        sharded_keys_to_remove (Set[str]): keys to remove from the state dict

    Returns:
        ShardedStateDict: state dict without ShardedBase objects with specified keys
    """

    def should_be_dropped(x: ShardedBase):
        # Every leaf of a sharded state dict must be a ShardedBase instance.
        assert isinstance(x, ShardedBase), f'Unexpected type {type(x)}'
        return x.key in sharded_keys_to_remove

    # First element (the matched/unexpected entries) is intentionally discarded;
    # keep only the part of the dict that survives the filter.
    _, sharded_state_dict = extract_matching_values(sharded_state_dict, should_be_dropped)
    return sharded_state_dict
def
_determine_missing_and_unexpected_keys
(
ckpt_sharded_metadata
:
'CkptShardedMetadata'
,
local_metadata
:
_LocalMetadata
,
global_metadata
:
Optional
[
_GlobalMetadata
]
=
None
,
)
->
Tuple
[
Set
[
str
],
Set
[
str
]]:
"""Determines load mismatches based on metadata.
There is an asymmetry between "unexpected" and "missing" keys.
Unexpected keys can be determined based only on local metadata.
Missing keys must be based on global metadata, since other ranks might access
different keys than the current rank.
In consequence, the return value of this function is different on each rank:
"missing_keys" are equal, but "unexpected_keys" might differ across ranks.
Args:
ckpt_sharded_metadata (CkptShardedMetadata): sharded state dict (without data)
constructed based on the checkpoint content
local_metadata (_LocalMetadata): list of local ShardedBase objects
requested to be loaded by this rank
global_metadata (_GlobalMetadata, optional): list of global ShardedBase objects
requested to be loaded by all ranks. Defaults to None, in which case
returned "missing" keys are empty.
Returns:
Tuple[Set[str], Set[str]]: missing and unexpected keys. Missing keys are equal
on all ranks, unexpected keys might differ across ranks. If passed
`global_metadata` is empty, returned missing keys are empty as well.
"""
local_accessed_keys
=
set
(
sh_base
.
key
for
sh_base
in
local_metadata
)
ckpt_keys
=
set
(
sh_base
.
key
for
sh_base
in
ckpt_sharded_metadata
.
values
())
unexpected_keys
=
local_accessed_keys
-
ckpt_keys
if
global_metadata
is
not
None
:
global_accessed_keys
=
set
(
sh_base
.
key
for
rank_metadata
in
global_metadata
for
sh_base
in
rank_metadata
)
missing_keys
=
ckpt_keys
-
global_accessed_keys
else
:
missing_keys
=
set
()
if
missing_keys
:
logger
.
debug
(
f
'Dist ckpt load missing keys:
{
missing_keys
}
'
)
if
unexpected_keys
:
logger
.
debug
(
f
'Dist ckpt load unexpected keys:
{
unexpected_keys
}
'
)
return
missing_keys
,
unexpected_keys
def maybe_report_missing_and_unexpected_keys(
    missing_keys: Set[str], unexpected_keys: Set[str], raise_error: bool = True
) -> None:
    """Raises or logs an error in case missing or unexpected keys are non-empty.

    Args:
        missing_keys (Set[str]): missing keys in the state dict
        unexpected_keys (Set[str]): unexpected keys in the state dict
        raise_error: If True, raises error on mismatch. Otherwise, logs mismatch
            with WARNING level.

    Returns:
        None

    Raises:
        CheckpointingException: if `raise_error` is True and at least one of
            `missing_keys` or `unexpected_keys` are non-empty.
    """
    if not missing_keys and not unexpected_keys:
        return
    # Removed needless f-prefixes on the placeholder-free title literals (F541).
    missing_title_msg = (
        'Some keys found in the checkpoint are missing in the provided sharded state dict. '
    )
    missing_body_msg = f'Missing keys (for all ranks): {missing_keys}. '
    unexpected_title_msg = (
        'Unexpected keys (not found in the checkpoint) encountered in the provided sharded state dict. '
    )
    unexpected_body_msg = f'Unexpected keys (for this rank): {unexpected_keys}. '
    # Assemble the message as: all applicable titles, newline, all applicable bodies.
    error_msg = ''
    if missing_keys:
        error_msg += missing_title_msg
    if unexpected_keys:
        error_msg += unexpected_title_msg

    error_msg += '\n'
    if missing_keys:
        error_msg += missing_body_msg
    if unexpected_keys:
        error_msg += unexpected_body_msg

    if raise_error:
        raise CheckpointingException(error_msg)
    else:
        logger.warning(error_msg)
def validate_sharding_integrity(global_metadata: _GlobalMetadata) -> None:
    """Validate if the ShardedTensors and ShardedObjects from multiple processes define correct sharding.

    Local ShardedTensors and ShardedObject metadata is exchanged with
    `torch.distributed.all_gather_object` and then process with global rank 0 checks
    if main replicas of the shards:
    - cover the whole global tensors
    - don't overlap

    Args:
        global_metadata (_GlobalMetadata): ShardedTensor and ShardedObject objects from all ranks.

    Returns:
        None

    Raises:
        CheckpointingException for invalid access pattern
    """
    # Only global rank 0 performs the validation.
    if torch.distributed.get_rank() != 0:
        return

    # Bucket all (rank, sharding) pairs by their state-dict key.
    shardings_by_key = defaultdict(list)
    for rank, shardings_of_rank in enumerate(global_metadata):
        for sharding in shardings_of_rank:
            shardings_by_key[sharding.key].append((rank, sharding))

    # Dispatch each key's shardings to the object or tensor validator.
    for shardings in shardings_by_key.values():
        validator = (
            _validate_objects_for_key
            if isinstance(shardings[0][1], ShardedObject)
            else _validate_sharding_for_key
        )
        validator(shardings)
def _validate_sharding_for_key(rank_sharding: List[Tuple[int, ShardedTensor]]):
    """Check that all ranks' shards for a single key are mutually consistent.

    Args:
        rank_sharding: (rank, ShardedTensor) pairs for one state-dict key,
            gathered from all ranks.

    Raises:
        CheckpointingException: if main replicas do not cover the global tensor
            exactly once (non-flattened case); the flattened case delegates to
            `_validate_sharding_for_key_flattened`.
    """
    some_rank_shard = rank_sharding[0][1]
    global_shape = some_rank_shard.global_shape
    local_shape = some_rank_shard.local_shape
    dtype = some_rank_shard.dtype
    has_flattened_range = some_rank_shard.flattened_range is not None
    # Every shard of the same key must agree on dtype, shapes and flattening
    # with the first shard (assertion payloads aid debugging).
    for rank, sharding in rank_sharding:
        assert sharding.dtype == dtype, (sharding.dtype, dtype, some_rank_shard)
        assert sharding.global_shape == global_shape, (
            sharding.global_shape,
            global_shape,
            some_rank_shard,
        )
        assert sharding.local_shape == local_shape, (
            sharding.local_shape,
            local_shape,
            some_rank_shard,
        )
        assert (sharding.flattened_range is not None) == has_flattened_range, (
            (sharding.flattened_range is not None),
            has_flattened_range,
            some_rank_shard,
        )

    shard_access_cnt = _compute_shards_access(rank_sharding)
    if has_flattened_range:
        # Flattened shards are grouped by global offset and each group's ranges
        # are validated for full, non-overlapping coverage.
        map_reduce(
            rank_sharding,
            lambda x: x[1].global_offset,
            lambda x: x[1],
            _validate_sharding_for_key_flattened,
        )
    else:
        # Each chunk of the global tensor must be written by exactly one main replica.
        if not torch.all(shard_access_cnt == 1):
            logger.error(
                f'Invalid access pattern for {rank_sharding[0][1]}: {shard_access_cnt}'
            )
            raise CheckpointingException(f'Invalid access pattern for {rank_sharding[0][1]}')
def _compute_shards_access(rank_sharding):
    """Count, per chunk position of the global tensor, how many main replicas write it."""
    access_counts = torch.zeros(
        rank_sharding[0][1].axis_fragmentations, dtype=torch.int, device='cpu'
    )
    for _, sharding in rank_sharding:
        # Only main replicas actually write data to the checkpoint.
        if not is_main_replica(sharding.replica_id):
            continue
        access_counts[sharding.local_chunk_offset_in_global()] += 1
    return access_counts
def _validate_sharding_for_key_flattened(tensors_by_shard):
    """Check that flattened main-replica ranges exactly tile one local shard.

    Args:
        tensors_by_shard: ShardedTensors sharing the same global offset.

    Raises:
        CheckpointingException: if the sorted (start, stop) ranges of main replicas
            don't start at 0, don't end at the shard's element count, or leave
            gaps/overlaps.
    """
    all_slices = []
    local_shape = tensors_by_shard[0].local_shape
    for sharding in tensors_by_shard:
        assert sharding.local_shape == local_shape
        sharding: ShardedTensor
        if not is_main_replica(sharding.replica_id):
            continue

        all_slices.append((sharding.flattened_range.start, sharding.flattened_range.stop))

    starts, stops = map(np.asarray, zip(*sorted(all_slices)))
    # Fixed: np.product is deprecated and removed in NumPy 2.0; use np.prod.
    if (
        starts[0] != 0
        or stops[-1] != np.prod(local_shape)
        or not np.all(starts[1:] == stops[:-1])
    ):
        logger.error(
            f'Flattened ranges dont cover the whole shard {tensors_by_shard[0]}. Ranges: {(starts, stops)}'
        )
        raise CheckpointingException(
            f'Flattened ranges dont cover the whole shard {tensors_by_shard[0]}. Ranges: {(starts, stops)}'
        )
def _validate_objects_for_key(sharded_objects: List[ShardedObject]):
    """Ensure uniqueness of saved objects."""
    # Only main replicas are actually persisted to the checkpoint.
    unique_keys = [
        sh_obj.unique_key
        for _, sh_obj in sharded_objects
        if is_main_replica(sh_obj.replica_id)
    ]
    key_counts = Counter(unique_keys)
    duplicates = {k: cnt for k, cnt in key_counts.items() if cnt > 1}
    if duplicates:
        logger.error(f'Duplicate ShardedObject keys and counts: {duplicates}')
        raise CheckpointingException(f'Duplicate ShardedObject keys: {list(duplicates.keys())}')

    # Every cell of the object's global shape must be covered by exactly one shard.
    expected_shard_num = np.prod(sharded_objects[0][1].global_shape)
    if len(unique_keys) != expected_shard_num:
        err_msg = f'Invalid access pattern: {expected_shard_num - len(unique_keys)} ShardedObject are missing.'
        logger.error(f'{err_msg} Existing shards: {unique_keys}')
        raise CheckpointingException(err_msg)
def determine_global_metadata(
    sharded_state_dict: ShardedStateDict,
) -> Tuple[_LocalMetadata, _GlobalMetadata]:
    """Exchanges local metadata with `all_gather_object` to determine global metadata.

    Args:
        sharded_state_dict (ShardedStateDict): local sharded state dict

    Returns:
        Tuple[_LocalMetadata, _GlobalMetadata]: local and global ShardedBase objects
            with stripped data
    """
    # Strip tensor data before the exchange — only metadata travels over the wire.
    stripped_local = [sh_base.without_data() for sh_base in nested_values(sharded_state_dict)]
    world_size = torch.distributed.get_world_size()
    gathered = [None] * world_size
    torch.distributed.all_gather_object(gathered, stripped_local)
    return stripped_local, gathered
def validate_sharded_objects_handling(
    sharded_strategy: Union[SaveShardedStrategy, LoadShardedStrategy],
    common_strategy: Union[SaveCommonStrategy, LoadCommonStrategy],
) -> None:
    """Checks if either of the passed strategies can handle sharded objects.

    Args:
        sharded_strategy (Union[SaveShardedStrategy, LoadShardedStrategy]): sharded
            strategy used for saving/loading
        common_strategy (Union[SaveCommonStrategy, LoadCommonStrategy]): common
            strategy used for saving/loading

    Returns:
        None

    Raises:
        CheckpointingException: if both strategies can't handle ShardedObjects
    """
    if (
        not sharded_strategy.can_handle_sharded_objects
        and not common_strategy.can_handle_sharded_objects
    ):
        raise CheckpointingException(
            # Removed needless f-prefix on the placeholder-free first literal (F541).
            'Either sharded strategy or common strategy must implement ShardedObjects handling.'
            f' Both {sharded_strategy} and {common_strategy} specify can_handle_sharded_objects=False'
        )
megatron/core/distributed/__init__.py
0 → 100644
View file @
4b097dee
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from
.distributed_data_parallel
import
DistributedDataParallel
from
.distributed_data_parallel_config
import
DistributedDataParallelConfig
from
.finalize_model_grads
import
finalize_model_grads
# For backwards compatibility. ParamAndGradBuffer will be deprecated in future release.
# ParamAndGradBuffer (which is an alias of _ParamAndGradBuffer) is not intended to be
# consumed directly by external code.
from
.param_and_grad_buffer
import
ParamAndGradBuffer
megatron/core/distributed/distributed_data_parallel.py
0 → 100644
View file @
4b097dee
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import
logging
from
contextlib
import
contextmanager
import
torch
from
..
import
parallel_state
from
..config_logger
import
has_config_logger_enabled
,
log_config_to_disk
from
..transformer.module
import
MegatronModule
from
..transformer.transformer_config
import
TransformerConfig
from
..utils
import
is_float8tensor
,
log_single_rank
from
.distributed_data_parallel_config
import
DistributedDataParallelConfig
from
.param_and_grad_buffer
import
_ParamAndGradBuffer
,
partition_buckets
logger
=
logging
.
getLogger
(
__name__
)
class DistributedDataParallel(MegatronModule):
    """
    DDP wrapper which stores grads in contiguous buffers. Also has option of overlapping
    communication with backprop computation by breaking up full model's gradients into smaller
    buckets and running all-reduce / reduce-scatter on each bucket asynchronously. This class
    also provides the option to do the gradient accumulation in a type other than the param type
    (e.g., fp32 for a bf16 model).

    Args:
        config: Transformer config object.
        ddp_config: DistributedDataParallel config object.
        module: Underlying model.
        disable_bucketing: If true, force assign all parameters to a single bucket. If false,
            use standard bucketing policy: assign parameters to smaller buckets and all-reduce
            per bucket _if_ overlap_grad_reduce is True and pp_rank is 0.
    """

    def __init__(
        self,
        config: TransformerConfig,
        ddp_config: DistributedDataParallelConfig,
        module: torch.nn.Module,
        disable_bucketing: bool = False,
    ):
        super().__init__(config=config)
        if has_config_logger_enabled(config):
            log_config_to_disk(config, locals(), prefix=type(self).__name__)

        self.module = module

        # If bucket_size is not provided as an input, use sane default.
        # If using very large dp_sizes, make buckets larger to ensure that chunks used in NCCL
        # ring-reduce implementations are large enough to remain bandwidth-bound rather than
        # latency-bound.
        if ddp_config.bucket_size is None:
            ddp_config.bucket_size = max(
                40000000, 1000000 * parallel_state.get_data_parallel_world_size()
            )
        # Set bucket_size to infinity if overlap_grad_reduce is False.
        if not ddp_config.overlap_grad_reduce:
            ddp_config.bucket_size = None

        self.ddp_config = ddp_config
        log_single_rank(
            logger,
            logging.INFO,
            f'Setting up DistributedDataParallel with config {self.ddp_config}',
        )

        # Turn off bucketing if we are on a pipeline stage that is not the first (since
        # data-parallel communication on these stages is not on the critical path), or if
        # disable_bucketing is True (e.g., we might not want to break up model parameters
        # into buckets for model chunks after the first in the interleaved schedule).
        self.bucket_size = self.ddp_config.bucket_size
        if parallel_state.get_pipeline_model_parallel_rank() > 0:
            self.bucket_size = None
        if disable_bucketing:
            self.bucket_size = None

        self.param_to_bucket_group = {}

        # Group parameters by their gradient type.
        param_to_name = {}
        dense_params = []
        expert_parallel_params = []
        self.params_with_grad = []
        for name, param in self.module.named_parameters():
            if not param.requires_grad:
                continue
            # Track params with grad to enable direct setting
            # of param.grad_added_to_main_grad
            self.params_with_grad.append(param)
            param.grad_added_to_main_grad = False
            param_to_name[param] = name
            # Params with `allreduce=False` are expert-parallel and use a separate
            # data-parallel group below.
            if getattr(param, 'allreduce', True):
                dense_params.append(param)
            else:
                expert_parallel_params.append(param)

        def _allocate_buffers_for_parameters(
            input_params, data_parallel_group, gradient_scaling_factor
        ):
            # Builds _ParamAndGradBuffer objects (one per (param dtype, grad dtype) pair)
            # and partitions their buckets into communication groups.
            param_and_grad_dtype_to_params = {}
            param_and_grad_dtype_to_offsets = {}
            param_and_grad_dtype_to_indices = {}

            # Group parameters by their gradient type.
            for param in input_params:
                assert param.requires_grad

                param_dtype = param.dtype
                if is_float8tensor(param):
                    # Currently TE's Float8Tensor is a wrapper of torch.Tensor. It has a "fake"
                    # dtype (usually a higher precision dtype such as bfloat16), but its actual
                    # data is stored in the form of a torch uint8 tensor within the Float8Tensor's
                    # ".data" attribute. Therefore, when creating the param buffer for fp8 params,
                    # it is necessary to use torch.uint8, not the "fake" dtype got from
                    # "param.dtype".
                    param_dtype = torch.uint8
                grad_dtype = torch.float if self.ddp_config.grad_reduce_in_fp32 else param.dtype

                params = param_and_grad_dtype_to_params.get((param_dtype, grad_dtype), [])
                params.append(param)
                param_and_grad_dtype_to_params[(param_dtype, grad_dtype)] = params

                # Get the index of each param among the params with same dtype, if a param is fp8,
                # use its "fake" high precision dtype to find which params have same dtype with it.
                # For example:
                #     Case 1:
                #         params = [p1(bf16), p2(bf16), p3(bf16), p4(bf16)]
                #         param_and_grad_dtype_to_indices = {
                #             (torch.bfloat16, torch.float32): [0, 1, 2, 3],
                #         }
                #     Case 2:
                #         params = [p1(bf16), p2(fp8), p3(fp8), p4(bf16)]
                #         param_and_grad_dtype_to_indices = {
                #             (torch.bfloat16, torch.float32): [0, 3],
                #             (torch.uint8, torch.float32): [1, 2],
                #         }
                # We need these indices to load a non-native-fp8 checkpoint in native-fp8 mode.
                offset = param_and_grad_dtype_to_offsets.get((param.dtype, grad_dtype), 0)
                param_and_grad_dtype_to_offsets[(param.dtype, grad_dtype)] = offset + 1
                indices = param_and_grad_dtype_to_indices.get((param_dtype, grad_dtype), [])
                indices.append(offset)
                param_and_grad_dtype_to_indices[(param_dtype, grad_dtype)] = indices

            if not config.calculate_per_token_loss:
                # Sanity-check that the caller-provided scaling factor is consistent
                # with DP world size and the averaging mode.
                target_gradient_scaling_factor = (
                    1.0 / parallel_state.get_data_parallel_world_size(with_context_parallel=True)
                )
                if self.ddp_config.average_in_collective:
                    # Collective is averaging gradients in collective with data_parallel_group.
                    assert (
                        gradient_scaling_factor
                        / torch.distributed.get_world_size(group=data_parallel_group)
                        == target_gradient_scaling_factor
                    )
                else:
                    assert gradient_scaling_factor == target_gradient_scaling_factor

            # Allocate the grad buffers and map the grads.
            buffers = []
            for (param_dtype, grad_dtype), params in param_and_grad_dtype_to_params.items():
                buffers.append(
                    _ParamAndGradBuffer(
                        self.ddp_config,
                        param_dtype,
                        grad_dtype,
                        params,
                        data_parallel_group,
                        self.bucket_size,
                        param_to_name,
                        gradient_scaling_factor,
                        param_and_grad_dtype_to_indices[(param_dtype, grad_dtype)],
                    )
                )

            # In some scenarios, we want to put buckets from different buffers into a group so that
            # their communication can be aggregated. For example, when there are both fp8 buffers
            # and bf16 buffers in the model and vpp is enabled, each model chunk will have an fp8
            # bucket and a bf16 bucket, which doubles the number of communication kernels, and
            # because of the use of CUDA_DEVICE_MAX_CONNECTIONS=1, having multiple back-to-back
            # communications will prevent the overlap of the communication kernels with computation
            # kernels.
            # If bucketing is explicitly disabled, then put all buckets in a buffer into a single
            # bucket group.
            bucket_groups = partition_buckets(buffers, force_single_bucket_group=disable_bucketing)

            # Set `next_param_gather_bucket_group` for different bucket groups by iterating through
            # buckets in reverse order (since all-gathers happen in reverse order of buckets).
            if self.ddp_config.use_distributed_optimizer and self.ddp_config.overlap_param_gather:
                num_bucket_groups = len(bucket_groups)
                for i in range(1, num_bucket_groups):
                    bucket_groups[num_bucket_groups - i].next_param_gather_bucket_group = (
                        bucket_groups[num_bucket_groups - i - 1]
                    )

            # Create map from param to bucket group, used in pre_hook.
            for bucket_group in bucket_groups:
                for bucket in bucket_group.buckets:
                    for param in bucket.params_list:
                        self.param_to_bucket_group[param] = bucket_group

            return buffers, bucket_groups

        if config.calculate_per_token_loss:
            gradient_scaling_factor = 1.0
            expert_gradient_scaling_factor = 1.0
        else:
            if self.ddp_config.average_in_collective:
                gradient_scaling_factor = 1.0
                expert_gradient_scaling_factor = (
                    1.0 / parallel_state.get_expert_model_parallel_world_size()
                )
            else:
                data_parallel_world_size = parallel_state.get_data_parallel_world_size(
                    with_context_parallel=True
                )
                gradient_scaling_factor = 1.0 / data_parallel_world_size
                expert_gradient_scaling_factor = 1.0 / data_parallel_world_size

        # Allocate the param+grad buffers for dense params' grads.
        self.buffers, self.bucket_groups = _allocate_buffers_for_parameters(
            dense_params,
            parallel_state.get_data_parallel_group(with_context_parallel=True),
            gradient_scaling_factor=gradient_scaling_factor,
        )

        # Allocate separate param+grad buffers for expert parallel params' grads.
        self.expert_parallel_buffers, self.expert_parallel_bucket_groups = (
            _allocate_buffers_for_parameters(
                expert_parallel_params,
                parallel_state.get_data_modulo_expert_parallel_group(with_context_parallel=True),
                gradient_scaling_factor=expert_gradient_scaling_factor,
            )
        )

        # Delete references to weight_tensor if they exist since we don't want two parameter copies
        # if we re-mapped parameters (which happens when we use the distributed optimizer).
        # This is a temporary workaround around a TE bug that is fixed with
        # https://github.com/NVIDIA/TransformerEngine/pull/719.
        if self.ddp_config.use_distributed_optimizer:

            @torch.no_grad()
            def unmap_weight_tensor(m):
                if hasattr(m, 'weight_tensor'):
                    m.weight_tensor = None

            self.module.apply(unmap_weight_tensor)

        # Register backward hook.
        # Accumulation function for the gradients need to be stored so they
        # don't go out of scope.
        self.grad_accs = []
        for param in self.module.parameters():
            if param.requires_grad:
                # Expand so we get access to grad_fn.
                param_tmp = param.expand_as(param)
                # Get the gradient accumulator function.
                grad_acc = param_tmp.grad_fn.next_functions[0][0]
                grad_acc.register_hook(self._make_backward_post_hook(param))
                self.grad_accs.append(grad_acc)

        # Forward pre-hooks are only needed when param all-gather is overlapped
        # with forward compute (distributed optimizer).
        self.use_forward_hook = (
            self.ddp_config.use_distributed_optimizer and self.ddp_config.overlap_param_gather
        )
        self.remove_forward_pre_hook_handles = {}
        if self.use_forward_hook:
            self.enable_forward_pre_hook()
        self.overlap_param_gather_with_optimizer_step = False

    def enable_forward_pre_hook(self):
        """
        Enable forward pre-hooks needed for param all-gather overlap with forward compute.
        """
        assert self.use_forward_hook
        assert len(self.remove_forward_pre_hook_handles) == 0
        # Register forward pre-hook for all sub-modules.
        for module in self.module.modules():
            self.remove_forward_pre_hook_handles[module] = module.register_forward_pre_hook(
                self._make_forward_pre_hook()
            )

    def disable_forward_pre_hook(self):
        """
        Disable forward pre-hooks needed for param all-gather overlap with forward compute.
        """
        assert self.use_forward_hook
        # De-register forward pre-hook for all sub-modules.
        for module in self.module.modules():
            assert self.remove_forward_pre_hook_handles[module] is not None
            self.remove_forward_pre_hook_handles[module].remove()
            del self.remove_forward_pre_hook_handles[module]
        assert len(self.remove_forward_pre_hook_handles) == 0

        # Force synchronize parameters.
        self.start_param_sync(force_sync=True)

    def forward(self, *inputs, **kwargs):
        """
        Calls the wrapped module's forward() method.
        """
        return self.module(*inputs, **kwargs)

    def _make_forward_pre_hook(self):
        """
        Create a forward pre-hook to wait on all-gather handles when necessary (i.e.,
        when a module uses a parameter in a bucket with a still incomplete all-gather).
        """

        def hook(module, *unused):
            assert (
                self.use_forward_hook
            ), "Should use pre-hook only when overlap_param_gather is True"

            # Make sure all parameters in this module have been all-gathered as necessary.
            for param in module.parameters(recurse=False):
                # Skip parameters without an associated buffer (such parameters have a
                # .requires_grad field equal to False).
                if param not in self.param_to_bucket_group:
                    continue
                assert param.requires_grad

                # If aligning param all-gather across pipeline stages, all-gather is dispatched
                # by start_param_sync calls in core/pipeline_parallelism/schedules.py.
                # If overlapping param all-gather with optimizer step, then all-gather has
                # already been dispatched in optimizer step.
                skip_next_bucket_dispatch = (
                    self.ddp_config.align_param_gather
                    or self.overlap_param_gather_with_optimizer_step
                )
                self.param_to_bucket_group[param].finish_param_sync(
                    skip_next_bucket_dispatch=skip_next_bucket_dispatch
                )

        return hook

    def _make_backward_post_hook(self, param: torch.nn.Parameter):
        """
        Creates a backward post-hook to dispatch an all-reduce / reduce-scatter when
        ready (i.e., when all grads in a bucket have been computed in all microbatches
        in a batch).
        """

        def hook(*unused):
            if param in self.param_to_bucket_group:
                assert param.requires_grad
                if self.ddp_config.overlap_grad_reduce:
                    assert (
                        param.grad is not None
                    ), 'param.grad being None is not safe when overlap_grad_reduce is True'
                # Accumulate into the contiguous main_grad buffer unless the grad was
                # already added there (or zero_out_wgrad forces re-accumulation).
                if param.grad is not None and (
                    not param.grad_added_to_main_grad or getattr(param, 'zero_out_wgrad', False)
                ):
                    param.main_grad.add_(param.grad.data)
                param.grad = None

                if self.ddp_config.overlap_grad_reduce:
                    self.param_to_bucket_group[param].register_grad_ready(param)

        return hook

    @contextmanager
    def no_sync(self):
        """
        Context manager that turns off gradient synchronization.
        """
        for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
            bucket_group.is_last_microbatch = False
        try:
            yield
        finally:
            for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
                bucket_group.is_last_microbatch = True

    def start_param_sync(self, *unused, force_sync: bool = False, force_dispatch: bool = False):
        """
        Initiates param sync (all-gather) communication operations for all model parameters.

        By default, when overlap_param_gather is set to True, dispatches asynchronous communication
        calls; when overlap_param_gather is set to False, calls synchronous communication
        ops. Can override this default behavior using flags below.

        Args:
            force_sync (bool, optional): force synchronous collective regardless of
                other settings.
            force_dispatch (bool, optional): force dispatch regardless of other settings.
        """
        if not force_sync:
            # If overlapping param AG with optimizer step, AG should not be dispatched again
            # in forward_backward_step.
            if self.overlap_param_gather_with_optimizer_step and not force_dispatch:
                return

        for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
            bucket_group.start_param_sync(force_sync=force_sync)

    def start_grad_sync(self, *unused):
        """
        Initiates grad sync (all-reduce or reduce-scatter) communication operations
        for all model gradients.

        When overlap_grad_reduce is set to True, dispatches asynchronous communication
        calls. When overlap_grad_reduce is set to False, calls synchronous
        communication ops.
        """
        for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
            bucket_group.start_grad_sync()

    def finish_grad_sync(self):
        """
        Finishes grad sync (all-reduce or reduce-scatter) communication operations
        for all model gradients.

        When overlap_grad_reduce is set to True, waits for asynchronous communication
        calls to complete. When overlap_grad_reduce is set to False, calls synchronous
        communication ops.
        """
        for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
            bucket_group.finish_grad_sync()

    def scale_gradients(self, scaling_factor: float):
        """Scale all gradients inside the buffers by `scaling_factor`."""
        for buffer in self.buffers + self.expert_parallel_buffers:
            buffer.scale_gradients(scaling_factor)

    def zero_grad_buffer(self):
        """
        Zeros out all grad buffers. Needs to be called at the beginning of each
        training iteration.
        """
        for param in self.params_with_grad:
            param.grad_added_to_main_grad = False
        for buffer in self.buffers + self.expert_parallel_buffers:
            buffer.reset()
        for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
            bucket_group.reset()

    def broadcast_params(self):
        """
        Syncs parameters across all DP ranks.
        """
        for param in self.module.parameters():
            is_expert_parallel = not getattr(param, 'allreduce', True)

            # Expert-parallel params are replicated over a different (modulo-EP) group.
            if is_expert_parallel:
                data_parallel_group = parallel_state.get_data_modulo_expert_parallel_group(
                    with_context_parallel=True
                )
            else:
                data_parallel_group = parallel_state.get_data_parallel_group(
                    with_context_parallel=True
                )
            torch.distributed.broadcast(
                param.data,
                src=torch.distributed.get_global_rank(data_parallel_group, 0),
                group=data_parallel_group,
            )

    def state_dict(self, prefix='', keep_vars=False):
        """
        Returns a dictionary containing references to the whole state of the
        wrapped module.

        Both parameters and persistent buffers (e.g. running averages) are included.
        Keys are corresponding parameter and buffer names. Parameters and buffers
        set to None are not included.
        """
        return self.module.state_dict(prefix=prefix, keep_vars=keep_vars)

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """
        Returns wrapped module's state_dict for checkpoint saving.
        """
        return self.module.state_dict_for_save_checkpoint(prefix=prefix, keep_vars=keep_vars)

    def load_state_dict(self, state_dict, strict=True):
        """
        Copies parameters and buffers from state_dict into the wrapped module and its
        descendants. If strict is True, then the keys of state_dict must exactly match
        the keys returned by this module’s state_dict() function.
        """
        self.module.load_state_dict(state_dict, strict=strict)
megatron/core/distributed/distributed_data_parallel_config.py
0 → 100644
View file @
4b097dee
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from
dataclasses
import
dataclass
from
typing
import
Optional
@dataclass
class DistributedDataParallelConfig:
    """Configuration for DistributedDataParallel.

    All flags default to the conservative (disabled) setting; callers opt in
    to overlap/fp8/distributed-optimizer behavior explicitly.
    """

    grad_reduce_in_fp32: bool = False
    """If true, reduce grads in fp32."""

    overlap_grad_reduce: bool = False
    """If true, overlap grad all-reduce / reduce-scatter with backward compute."""

    overlap_param_gather: bool = False
    """If true, overlap param all-gather with forward compute."""

    align_param_gather: bool = False
    """If true, all PP stages will launch param all-gathers simultaneously. Otherwise, each
    PP stage will independently launch as needed.
    """

    use_distributed_optimizer: bool = False
    """If true, issue reduce-scatter collectives to aggregate gradients and clean up
    originally allocated model parameters, otherwise issue all-reduce collectives.
    """

    check_for_nan_in_grad: bool = False
    """If true, check for NaNs in gradients _before_ communication collective."""

    bucket_size: Optional[int] = None
    """Maximum number of parameters in each bucket. If unspecified, MCore uses a default
    value of max(40000000, 1000000 * dp_size) parameters (larger DP sizes need larger
    buckets to ensure collectives do not become latency-bound)."""

    average_in_collective: bool = False
    """If true, compute average in collective directly, as opposed to dividing by the
    dp_size first and then computing sum in the collective."""

    fp8_param_gather: bool = False
    """If true, keep the compute param in fp8 (do not use any other intermediate dtype) and
    perform the param all-gather in fp8."""
megatron/core/distributed/finalize_model_grads.py
0 → 100644
View file @
4b097dee
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from
typing
import
List
,
Optional
import
torch
from
torch._utils
import
_flatten_dense_tensors
,
_unflatten_dense_tensors
from
..
import
parallel_state
from
..transformer.transformer_config
import
TransformerConfig
from
..utils
import
get_attr_wrapped_model
,
get_model_config
def _allreduce_word_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
    """
    All-reduce word embedding grads.

    Reduce grads across first and last stages to ensure that word_embeddings
    parameters stay in sync (only meaningful when the embedding group spans
    more than one rank and embeddings/output weights are shared).
    """
    # Ranks outside the embedding group have nothing to synchronize.
    if not parallel_state.is_rank_in_embedding_group(ignore_virtual=True):
        return
    embedding_group = parallel_state.get_embedding_group()
    if torch.distributed.get_world_size(embedding_group) <= 1:
        return

    # Select the model chunk that owns the embedding on this pipeline stage.
    if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
        chunk = model[0]
    elif parallel_state.is_pipeline_last_stage(ignore_virtual=True):
        chunk = model[-1]
    else:
        # We do not support an interleaved schedule for models with encoders yet.
        chunk = model[0]

    unwrapped_module = get_attr_wrapped_model(chunk, 'pre_process', return_model_obj=True)
    if unwrapped_module.share_embeddings_and_output_weights:
        shared_weight = unwrapped_module.shared_embedding_or_output_weight()
        torch.distributed.all_reduce(shared_weight.main_grad, group=embedding_group)
def _allreduce_position_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
    """
    All-reduce position_embeddings grad across encoder and decoder stages to
    ensure that position embeddings parameters stay in sync.
    """
    # Ranks outside the position-embedding group have nothing to synchronize.
    if not parallel_state.is_rank_in_position_embedding_group():
        return
    position_embedding_group = parallel_state.get_position_embedding_group()
    if torch.distributed.get_world_size(position_embedding_group) <= 1:
        return

    # Select the model chunk that owns the embedding on this pipeline stage.
    if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
        chunk = model[0]
    elif parallel_state.is_pipeline_last_stage(ignore_virtual=True):
        chunk = model[-1]
    else:
        # We do not support an interleaved schedule for models with encoders yet.
        chunk = model[0]

    unwrapped_module = get_attr_wrapped_model(chunk, 'pre_process', return_model_obj=True)
    assert hasattr(unwrapped_module, 'position_embeddings')
    torch.distributed.all_reduce(
        unwrapped_module.position_embeddings.weight.main_grad, group=position_embedding_group
    )
def _allreduce_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
    """
    All-reduce both word and position embeddings.

    Word embeddings are reduced first, then position embeddings; each helper
    is a no-op on ranks outside its respective group.
    """
    for reduce_fn in (_allreduce_word_embedding_grads, _allreduce_position_embedding_grads):
        reduce_fn(model, config)
def _allreduce_layernorm_grads(model: List[torch.nn.Module], config: TransformerConfig):
    """
    All-reduce layernorm grads (for sequence parallelism).

    Args:
        model: list of model chunks (one per virtual pipeline stage).
        config: transformer config; `sequence_parallel` / `qk_layernorm` decide
            whether any grads need to be reduced at all.
    """
    # All-reduce layernorm parameters across model parallel nodes
    # when sequence parallelism is used.
    if parallel_state.get_tensor_model_parallel_world_size() > 1 and (
        config.sequence_parallel or config.qk_layernorm
    ):
        grads = []
        for model_chunk in model:
            for name, param in get_attr_wrapped_model(model_chunk, 'named_parameters')():
                # BUGFIX: parenthesize the `or` chain so `requires_grad` gates every
                # case. The previous condition `A and B or C or D` parses as
                # `(A and B) or C or D` (Python: `and` binds tighter than `or`),
                # which collected q/k-layernorm params even when frozen
                # (requires_grad=False) — reducing stale grads or failing on
                # params that never had `main_grad` populated.
                if param.requires_grad and (
                    getattr(param, 'sequence_parallel', False)
                    or 'q_layernorm' in name
                    or 'k_layernorm' in name
                ):
                    grad = param.main_grad
                    grads.append(grad.data)
        if grads:
            # Flatten all layernorm grads into one contiguous tensor so a single
            # all-reduce covers them, then copy the reduced values back.
            coalesced = _flatten_dense_tensors(grads)
            torch.distributed.all_reduce(
                coalesced, group=parallel_state.get_tensor_model_parallel_group()
            )
            for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
                buf.copy_(synced)
def finalize_model_grads(model: List[torch.nn.Module], num_tokens: Optional[torch.Tensor] = None):
    """
    All-reduce all model grads across DP replicas, layernorm grads for sequence parallelism,
    embedding grads across first and last pipeline stages (if not tied),
    scale gradients by `num_tokens`.

    Args:
        model: list of model chunks (one per virtual pipeline stage).
        num_tokens: optional scalar tensor with the total number of non-padded
            tokens in the global batch; when given, grads are scaled by
            1/num_tokens after all reductions.
    """

    config = get_model_config(model[0])

    # All-reduce / reduce-scatter across DP replicas.
    if config.timers is not None:
        config.timers('all-grads-sync', log_level=1).start(barrier=config.barrier_with_L1_time)
    for model_chunk in model:
        model_chunk.finish_grad_sync()
    if config.timers is not None:
        config.timers('all-grads-sync').stop()

    # All-reduce layer-norm grads (for sequence parallelism).
    if config.timers is not None:
        config.timers('layernorm-grads-all-reduce', log_level=1).start(
            barrier=config.barrier_with_L1_time
        )
    _allreduce_layernorm_grads(model, config)
    if config.timers is not None:
        config.timers('layernorm-grads-all-reduce').stop()

    # All-reduce embedding grads (for pipeline parallelism).
    if config.timers is not None:
        config.timers('embedding-grads-all-reduce', log_level=1).start(
            barrier=config.barrier_with_L1_time
        )
    _allreduce_embedding_grads(model, config)
    if config.timers is not None:
        config.timers('embedding-grads-all-reduce').stop()

    # normalize gradients for per-token loss normalization.
    # if we are using by the number of tokens, then we use that as a divisor. this number
    # will be the total number of non-padded tokens in the global batch.
    if num_tokens is not None:
        # the number of tokens is only present on the last stage, so broadcast it
        # to the other ranks in the pipeline parallel group.
        last_rank = parallel_state.get_pipeline_model_parallel_last_rank()
        pp_group = parallel_state.get_pipeline_model_parallel_group()

        # Normalize scalar/list forms: downstream code always iterates lists.
        if not isinstance(last_rank, list):
            # NOTE(review): this assert is trivially true inside this branch
            # (we just checked `not isinstance(last_rank, list)`) — presumably
            # it was meant to guard `pp_group`; confirm against upstream.
            assert not isinstance(last_rank, list)
            last_rank = [last_rank]
            assert not isinstance(pp_group, list)
            pp_group = [pp_group]

        # need to do a broadcast for every pp group, even though num_tokens should be the same.
        num_tokens_list = []
        for lr, group in zip(last_rank, pp_group):
            torch.distributed.broadcast(num_tokens, src=lr, group=group)
            num_tokens_list.append(torch.clone(num_tokens))
        # Sanity-check that every PP group agreed on the same token count.
        assert all(x.item() == num_tokens_list[0] for x in num_tokens_list)

        # all-reduce across DP ranks.
        torch.distributed.all_reduce(num_tokens, group=parallel_state.get_data_parallel_group())
        for model_chunk in model:
            # Guard against a zero token count to avoid dividing by zero.
            if num_tokens > 0:
                scaling = 1.0 / num_tokens
                model_chunk.scale_gradients(scaling)
megatron/core/distributed/param_and_grad_buffer.py
0 → 100644
View file @
4b097dee
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import
logging
import
math
import
os
import
warnings
from
enum
import
Enum
from
typing
import
Dict
,
List
,
Optional
import
torch
from
torch.distributed
import
_coalescing_manager
from
..utils
import
is_float8tensor
,
log_on_each_pipeline_stage
from
.distributed_data_parallel_config
import
DistributedDataParallelConfig
# Module-level logger; used below to report bucket assignments at INFO level.
logger = logging.getLogger(__name__)
class BufferType(Enum):
    """
    Enumeration for buffer type.

    Selects which underlying 1-D tensor of `_ParamAndGradBuffer` a view is
    taken from (see `_ParamAndGradBuffer._get`).
    """

    PARAM = 1  # view into the parameter buffer (param_data)
    GRAD = 2  # view into the gradient buffer (grad_data)
def shard_buffer(buffer: torch.Tensor, data_parallel_world_size: int):
    """
    Shard buffer into data_parallel_world_size chunks of equal size.

    Returns a list with one slice per data-parallel rank; each slice is a
    view into `buffer` (no copy), so writes through a shard are visible in
    the original tensor.
    """
    total_numel = buffer.numel()
    assert total_numel % data_parallel_world_size == 0
    shard_size = total_numel // data_parallel_world_size
    return [buffer[start : start + shard_size] for start in range(0, total_numel, shard_size)]
class
_ParamAndGradBucket
:
"""
Bucket to keep track of a subset of the model's parameters and gradients.
Args:
params: List of parameters whose gradients are collated in this bucket.
param_data: View in ParamAndGradBuffer.param_data that this bucket is responsible for.
grad_data: View in ParamAndGradBuffer.grad_data that this bucket is responsible for.
offset: Offset of this bucket's view in the larger ParamAndGradBuffer.
numel_unpadded: Number of unpadded elements in bucket.
gradient_scaling_factor: This factor is utilized to scale gradients prior to their
communication. Its application is twofold: it facilitates the averaging of gradients
and the scaling of gradients in the context of the Mixture of Experts (MoE) model.
bucket_id: Index of bucket in buffer.
"""
def
__init__
(
self
,
params
:
List
[
torch
.
nn
.
Parameter
],
param_data
:
Optional
[
torch
.
Tensor
],
grad_data
:
torch
.
Tensor
,
offset
:
int
,
numel_unpadded
:
int
,
gradient_scaling_factor
:
float
,
bucket_id
:
int
,
):
self
.
params_list
=
params
self
.
params
=
set
(
params
)
# Make sure there are no duplicate params.
assert
len
(
self
.
params_list
)
==
len
(
self
.
params
)
self
.
param_data
=
param_data
self
.
grad_data
=
grad_data
# The distributed optimizer needs to keep track of this bucket's offset
# within the full grad_buffer.
self
.
offset
=
offset
self
.
numel_unpadded
=
numel_unpadded
self
.
gradient_scaling_factor
=
gradient_scaling_factor
self
.
bucket_id
=
bucket_id
class _ParamAndGradBucketGroup:
    """
    Put multiple buckets into a group so that their communications can be aggregated together.

    Provides functionality to register when params in the bucket group have grads ready to be
    synced; an asynchronous communication call is automatically launched when _all_ params in
    the bucket group have grads ready.

    Args:
        buckets: A list of buckets.
        ddp_config: DistributedDataParallel config object.
        data_parallel_group: Data-parallel process group.
        data_parallel_world_size: World size of the data-parallel group.
    """

    def __init__(
        self,
        buckets: List[_ParamAndGradBucket],
        ddp_config: DistributedDataParallelConfig,
        data_parallel_group: torch.distributed.ProcessGroup,
        data_parallel_world_size: int,
    ):
        self.buckets = buckets
        self.ddp_config = ddp_config
        self.data_parallel_group = data_parallel_group
        self.data_parallel_world_size = data_parallel_world_size
        self.data_parallel_rank = torch.distributed.get_rank(group=data_parallel_group)

        # State for bookkeeping: params is the set of parameters this bucket group is
        # responsible for, params_with_grad is the set of parameters with grads
        # available. When overlap_grad_reduce is True, communication (all-reduce
        # or reduce-scatter) is issued when params_with_grad equals params.
        self.param_to_bucket = {}
        self.params = set()
        for bucket in self.buckets:
            for param in bucket.params_list:
                self.param_to_bucket[param] = bucket
                self.params.add(param)

        # Set externally (by the owner) to chain bucket-group all-gathers;
        # see finish_param_sync below.
        self.next_param_gather_bucket_group = None

        self.reset()
        # Outstanding collective handles; None means no collective in flight.
        self.param_gather_handle = None
        self.param_gather_dispatched = False
        self.grad_reduce_handle = None

    def reset(self):
        """
        Reset metadata in bucket group in preparation for the next iteration of training.
        """
        self.params_with_grad = set()
        self.is_last_microbatch = True

    def check_for_nan_in_grad(self):
        """
        Make sure norm of grads in bucket are not NaN prior to data-parallel
        all-reduce / reduce-scatter.
        """
        global_rank = torch.distributed.get_rank()
        # OR together the per-bucket "norm is NaN" flags so a single assert
        # covers every bucket in the group.
        norm_is_nan = self.buckets[0].grad_data.norm(p=2).isnan()
        for i in range(1, len(self.buckets)):
            norm_is_nan.logical_or_(self.buckets[i].grad_data.norm(p=2).isnan())
        assert not norm_is_nan, (
            f'Rank {global_rank}: found NaN in local grad norm in '
            f'backward pass before data-parallel communication collective. '
            f'Device: {torch.cuda.current_device()}, node: {os.uname()[1]}'
        )

    def start_param_sync(self, force_sync: bool = False):
        """
        Initiates all necessary param all-gathers for this bucket.

        When ddp_config.overlap_param_gather is set to True, dispatches an asynchronous
        communication call (unless force_sync is True). When ddp_config.overlap_param_gather
        is set to False, makes synchronous call.

        Args:
            force_sync (bool, optional): force synchronous collective regardless of
                other settings if true.
        """
        assert self.ddp_config.use_distributed_optimizer

        if force_sync:
            # If an async all-gather is already in flight, just wait for it;
            # otherwise fall through and dispatch a synchronous one below.
            if self.param_gather_handle is not None:
                self.param_gather_handle.wait()
                self.param_gather_handle = None
                return
        else:
            assert self.param_gather_handle is None

        async_op = self.ddp_config.overlap_param_gather and not force_sync
        # Coalesce communication kernels across buckets in the bucket group.
        with _coalescing_manager(self.data_parallel_group, async_ops=async_op) as cm:
            for bucket in self.buckets:
                # Each rank contributes its own shard of the param buffer.
                local_data_view = shard_buffer(bucket.param_data, self.data_parallel_world_size)[
                    self.data_parallel_rank
                ]
                torch.distributed._all_gather_base(
                    bucket.param_data,
                    local_data_view,
                    group=self.data_parallel_group,
                    async_op=async_op,
                )
        if async_op:
            self.param_gather_handle = cm
        else:
            # When using `_coalescing_manager`, even if a synchronous op (async_op=False) is used,
            # `cm` is not None, which is different from when `_coalescing_manager` is not used in
            # which case the torch.distributed._all_gather_base() will return None. In order to
            # maintain consistency with prior code, we need to manually set communication handle to
            # None.
            self.param_gather_handle = None
        self.param_gather_dispatched = True

    def finish_param_sync(self, skip_next_bucket_dispatch: bool = False):
        """
        Finishes param sync communication operation for this bucket. Dispatches
        next bucket's param sync if available, unless skip_next_bucket_dispatch
        is True.

        When ddp_config.overlap_param_gather is set to True, waits for asynchronous
        communication call to complete (and dispatches one if one is not already
        outstanding). Throws assertion error if ddp_config.overlap_param_gather is set to
        False.

        Args:
            skip_next_bucket_dispatch (bool, optional): if true, do NOT dispatch the next
                bucket group's communication even when one is available.
        """
        assert self.ddp_config.use_distributed_optimizer
        assert self.ddp_config.overlap_param_gather

        # If current bucket's param AG has not been dispatched, dispatch it now (e.g., first
        # AG bucket in first model chunk if ddp_config.align_param_gather is False).
        if not self.param_gather_dispatched:
            self.start_param_sync()

        if self.param_gather_handle is not None:
            self.param_gather_handle.wait()
            self.param_gather_handle = None
            # Dispatch next bucket's asynchronous param AG.
            if self.next_param_gather_bucket_group is not None and not skip_next_bucket_dispatch:
                self.next_param_gather_bucket_group.start_param_sync()

    def start_grad_sync(self):
        """
        Initiates grad sync (all-reduce or reduce-scatter) communication operations
        for all buckets in the bucket group.

        When ddp_config.overlap_grad_reduce is set to True, dispatches an asynchronous
        communication call. When ddp_config.overlap_grad_reduce is set to False, makes
        synchronous call.
        """
        assert (
            self.grad_reduce_handle is None
        ), 'Should not have multiple communication calls outstanding at once'

        if self.ddp_config.check_for_nan_in_grad:
            self.check_for_nan_in_grad()

        # gradient_scaling_factor already takes into account whether we are computing
        # an average or sum in the data-parallel collective.
        for bucket in self.buckets:
            if bucket.gradient_scaling_factor != 1.0:
                bucket.grad_data *= bucket.gradient_scaling_factor

        # Decide reduce_op.
        reduce_op = torch.distributed.ReduceOp.SUM
        if self.ddp_config.average_in_collective:
            reduce_op = torch.distributed.ReduceOp.AVG

        # Use async communications only when overlap_grad_reduce is True.
        async_op = self.ddp_config.overlap_grad_reduce
        # Coalesce communication kernels across buckets in the bucket group.
        with _coalescing_manager(self.data_parallel_group, async_ops=async_op) as cm:
            for bucket in self.buckets:
                if self.ddp_config.use_distributed_optimizer:
                    # Reduce-scatter: each rank keeps only its own shard of the
                    # reduced gradients.
                    local_data_view = shard_buffer(bucket.grad_data, self.data_parallel_world_size)[
                        self.data_parallel_rank
                    ]
                    torch.distributed._reduce_scatter_base(
                        local_data_view,
                        bucket.grad_data,
                        op=reduce_op,
                        group=self.data_parallel_group,
                        async_op=async_op,
                    )
                else:
                    torch.distributed.all_reduce(
                        bucket.grad_data,
                        op=reduce_op,
                        group=self.data_parallel_group,
                        async_op=async_op,
                    )
        if async_op:
            self.grad_reduce_handle = cm
        else:
            # When using `_coalescing_manager`, even if a synchronous op (async_op=False) is used,
            # `cm` is not None, which is different from when `_coalescing_manager` is not used in
            # which case the torch.distributed._reduce_scatter_base() will return None. In order to
            # maintain consistency with prior code, we need to manually set communication handle to
            # None.
            self.grad_reduce_handle = None

    def finish_grad_sync(self):
        """
        Finishes grad sync (all-reduce or reduce-scatter) communication operations
        for all buckets in the bucket group.

        When ddp_config.overlap_grad_reduce is set to True, waits for asynchronous
        communication call to complete. When ddp_config.overlap_grad_reduce is set to False,
        makes synchronous call.
        """
        # If overlap_grad_reduce is False, start (and finish) synchronous communication call here.
        self.param_gather_dispatched = False
        if not self.ddp_config.overlap_grad_reduce:
            self.start_grad_sync()
            return
        assert self.grad_reduce_handle is not None, (
            f'Communication call has not been issued for this bucket '
            f'({len(self.params_with_grad)}/{len(self.params)} params have grad available)'
        )
        self.grad_reduce_handle.wait()
        self.grad_reduce_handle = None

    def register_grad_ready(self, param: torch.nn.Parameter):
        """
        Registers grads for the passed-in param to be "ready" for grad sync.

        When the number of microbatches is greater than 1, we only want to register
        grads as ready when processing the last microbatch and ddp_config.overlap_grad_reduce
        is True.
        """
        assert (
            self.ddp_config.overlap_grad_reduce
        ), 'register_grad_ready() should only be called when overlap_grad_reduce is True'
        if self.is_last_microbatch:
            assert param in self.param_to_bucket, 'Param is not in the bucket group'
            assert param not in self.params_with_grad, 'Cannot set grad twice'
            self.params_with_grad.add(param)
            # If all params in bucket group have grads available, issue communication call.
            if len(self.params_with_grad) == len(self.params):
                self.start_grad_sync()
class _ParamAndGradBuffer:
    """
    Groups parameters and gradients into a contiguous buffer, and then breaks the buffer into
    buckets with roughly `bucket_size` parameters each.

    Args:
        ddp_config: DistributedDataParallel config object.
        param_dtype: Type of param tensor.
        grad_dtype: Type of grad tensor.
        params: List of parameters whose parameters and gradients are collated in the underlying
            tensor.
        data_parallel_group: Data-parallel process group.
        bucket_size: The rough size of each bucket in terms of number of parameters.
        param_to_name: Mapping from `torch.nn.Parameter` to name (for logging purposes).
        gradient_scaling_factor: This factor is utilized to scale gradients prior to their
            communication. Its application is twofold: it facilitates the averaging of gradients
            and the scaling of gradients in the context of the Mixture of Experts (MoE) model.
        param_indices: The index of each param among the params with same dtype, if a param is fp8,
            use its "fake" high precision dtype to determine which params have same dtype with it.
            These indices are needed when loading a non-native-fp8 checkpoint in native-fp8 mode.
    """

    def __init__(
        self,
        ddp_config: DistributedDataParallelConfig,
        param_dtype: torch.dtype,
        grad_dtype: torch.dtype,
        params: List[torch.nn.Parameter],
        data_parallel_group: torch.distributed.ProcessGroup,
        bucket_size: int,
        param_to_name: Dict[torch.nn.Parameter, str],
        gradient_scaling_factor: float,
        param_indices: List[int],
    ):
        self.ddp_config = ddp_config
        self.params = params
        self.param_indices = param_indices

        # Check that params are unique.
        unique_params = set()
        for param in params:
            assert param not in unique_params
            unique_params.add(param)
        del unique_params

        # Store attributes that will be needed later.
        self.param_dtype = param_dtype
        self.grad_dtype = grad_dtype
        self.data_parallel_group = data_parallel_group
        self.data_parallel_world_size = torch.distributed.get_world_size(
            group=self.data_parallel_group
        )
        self.gradient_scaling_factor = gradient_scaling_factor

        # Data structures to store underlying buckets and relevant indexing data.
        self.buckets = []
        self.param_to_bucket = {}  # Param -> bucket mapping.
        self.param_index_map = {}  # Param -> location in buffer mapping (used in dist. optimizer).

        def _pad(number_to_be_padded: int, divisor: int) -> int:
            # Round `number_to_be_padded` up to the nearest multiple of `divisor`.
            return int(math.ceil(number_to_be_padded / divisor) * divisor)

        def _pad_end_of_bucket_if_needed(bucket_end_index: int) -> int:
            """
            Pads end index of bucket if using distributed optimizer (to ensure uniform sharding).
            """
            if self.ddp_config.use_distributed_optimizer:
                # Workaround for TE bug causing cuBLAS to pick an incompatible algorithm.
                # This also helps cuBLAS pick more efficient algorithms for GEMMs.
                # We now ensure that all buckets start at a memory address that is 256-byte
                # aligned (128 values since params and grads use >= 16-bit precision).
                return _pad(bucket_end_index, math.lcm(self.data_parallel_world_size, 128))
            return bucket_end_index

        def _pad_start_of_param_if_needed(param_start_index: int) -> int:
            """
            Pads start index of param if using distributed optimizer (to ensure "good" alignment).
            """
            if self.ddp_config.use_distributed_optimizer:
                # Ensure that params start at 128-byte aligned addresses (64 values
                # since params are >= 16-bit precision).
                return _pad(param_start_index, 64)
            return param_start_index

        # First, figure out how many elements should be in the underlying buffer storage.
        # Note that if we need to split the buffer into smaller buckets, each of these
        # might need to be padded as well (if using the distributed optimizer).
        param_start_index = 0
        bucket_start_index = param_start_index
        bucket_params = set()
        self.bucket_indices = []
        per_bucket_numel_unpadded = []
        bucket_id = 0

        def _update_bucket_metadata(param_end_index: int) -> int:
            """
            Record metadata for the bucket starting at bucket_start_index and ending with the
            passed-in param_end_index. Returns the bucket's end_index.
            """
            nonlocal bucket_start_index, bucket_params, bucket_id
            per_bucket_numel_unpadded.append(param_end_index - bucket_start_index)
            bucket_end_index = _pad_end_of_bucket_if_needed(param_end_index)

            # Record metadata of new bucket.
            self.bucket_indices.append((bucket_start_index, bucket_end_index))
            bucket_start_index = bucket_end_index

            # Prepare for next bucket.
            bucket_params = set()
            bucket_id += 1

            # Return the potentially padded bucket_end_index.
            return bucket_end_index

        def _does_param_require_new_bucket(param: torch.nn.Parameter) -> bool:
            """
            Split shared embedding parameters into separate bucket if using distributed
            optimizer that makes use of reduce-scatters instead of all-reduces.
            This ensures that the first and last pipeline stage partition optimizer state
            for the shared embedding parameters the same way across DP replicas, allowing
            the DP reduce-scatter to be before the embedding all-reduce.
            """
            return (
                getattr(param, "shared_embedding", False)
                and self.ddp_config.use_distributed_optimizer
            )

        for param in params[::-1]:
            # Iterate through parameters in reverse order to roughly follow backprop order.
            this_numel = param.data.nelement()
            param_start_index = _pad_start_of_param_if_needed(param_start_index)

            # Create bucket with collected parameters if current param needs its own bucket.
            if _does_param_require_new_bucket(param):
                # We are creating a bucket for the already accumulated parameters, whose params
                # end at the current param_start_index.
                if self.ddp_config.use_distributed_optimizer:
                    # Make sure new bucket is appropriately padded.
                    if param_start_index % self.data_parallel_world_size != 0:
                        param_start_index = _pad_end_of_bucket_if_needed(param_start_index)
                if len(bucket_params) > 0:
                    bucket_end_index = _update_bucket_metadata(param_start_index)

            param_end_index = param_start_index + this_numel
            self.param_index_map[param] = (param_start_index, param_end_index, bucket_id)
            bucket_params.add(param)

            # If we have enough elements already or the current param is part of the shared
            # embedding layer and needs a separate bucket, form a new bucket.
            if (
                bucket_size is not None
                and (param_end_index - bucket_start_index) >= bucket_size
            ) or _does_param_require_new_bucket(param):
                bucket_end_index = _update_bucket_metadata(param_end_index)
                param_start_index = bucket_end_index
            else:
                param_start_index = param_end_index

        # Add remaining params to a new bucket.
        if len(bucket_params) > 0:
            bucket_end_index = _update_bucket_metadata(param_end_index)

        # Next, create underlying storage for buffer (with numel elements that includes
        # padding as necessary).
        self.numel = bucket_end_index
        self.numel_unpadded = sum(per_bucket_numel_unpadded)
        assert self.numel_unpadded <= self.numel
        if self.ddp_config.use_distributed_optimizer:
            assert self.numel % self.data_parallel_world_size == 0
        else:
            assert self.numel == self.numel_unpadded

        self.param_data = None
        # Only re-map param tensors if using distributed optimizer.
        if self.ddp_config.use_distributed_optimizer:
            self.param_data = torch.zeros(
                self.numel,
                dtype=self.param_dtype,
                device=torch.cuda.current_device(),
                requires_grad=False,
            )
        self.grad_data = torch.zeros(
            self.numel,
            dtype=self.grad_dtype,
            device=torch.cuda.current_device(),
            requires_grad=False,
        )

        # Finally, map param.data and param.main_grad fields to buffers.
        bucket_params = []
        bucket_start_index = 0
        cur_bucket_id = 0
        for param in params[::-1]:
            param_start_index, param_end_index, bucket_id = self.param_index_map[param]

            # Assign param.data to appropriate segment of self.param_data.
            if self.param_data is not None:
                old_param_data = param.data
                new_param_data = self._get(
                    param.data.shape, param_start_index, buffer_type=BufferType.PARAM
                )
                if is_float8tensor(param):
                    # NOTE(review): fp8 params appear to keep their raw payload
                    # in `_data` — confirm against the Float8Tensor implementation.
                    param._data = new_param_data
                else:
                    param.data = new_param_data
                assert old_param_data._base is None
                # Copy tensor values (from initialization or checkpoint).
                param.data.detach().copy_(old_param_data)
                del old_param_data

            param.main_grad = self._get(
                param.data.shape, param_start_index, buffer_type=BufferType.GRAD
            )
            if bucket_id != cur_bucket_id:
                # Crossed a bucket boundary: flush the accumulated params into
                # a bucket before starting to collect the next one.
                bucket_end_index = _pad_end_of_bucket_if_needed(param_start_index)
                self.buckets.append(
                    self._new_bucket(
                        bucket_params=bucket_params,
                        start_index=bucket_start_index,
                        end_index=bucket_end_index,
                        numel_unpadded=per_bucket_numel_unpadded[cur_bucket_id],
                        bucket_id=cur_bucket_id,
                    )
                )
                bucket_start_index = bucket_end_index
                bucket_params = []
                assert cur_bucket_id + 1 == len(self.buckets)
                assert bucket_id == cur_bucket_id + 1
                cur_bucket_id = bucket_id
            bucket_params.append(param)

        # Add remaining params to a new bucket.
        if len(bucket_params) > 0:
            bucket_end_index = _pad_end_of_bucket_if_needed(param_end_index)
            self.buckets.append(
                self._new_bucket(
                    bucket_params=bucket_params,
                    start_index=bucket_start_index,
                    end_index=bucket_end_index,
                    numel_unpadded=per_bucket_numel_unpadded[cur_bucket_id],
                    bucket_id=cur_bucket_id,
                )
            )

        # Log buckets for all PP stages.
        log_strs = []
        log_strs.append(
            f'Number of buckets for gradient all-reduce / reduce-scatter: {len(self.buckets)}'
        )
        for index, bucket in enumerate(self.buckets):
            numel = 0
            for param in bucket.params:
                numel += param.data.nelement()
            log_strs.append(f'Params for bucket {index + 1} ({numel} elements):')
            for param in bucket.params:
                log_strs.append(f'\t{param_to_name[param]}')
        log_on_each_pipeline_stage(logger, logging.INFO, '\n'.join(log_strs))

    def scale_gradients(self, scaling_factor: float) -> None:
        """Scale the gradient data by `scaling_factor`."""
        self.grad_data *= scaling_factor

    def _get(self, shape: torch.Size, start_index: int, buffer_type: BufferType) -> torch.Tensor:
        """
        Return a tensor with the input `shape` as a view into the 1-D data starting at
        `start_index`.
        """
        end_index = start_index + shape.numel()
        assert end_index <= self.numel, 'Requested tensor is out of buffer range'
        if buffer_type == BufferType.PARAM:
            assert self.param_data is not None
            buffer_tensor = self.param_data[start_index:end_index]
        elif buffer_type == BufferType.GRAD:
            buffer_tensor = self.grad_data[start_index:end_index]
        else:
            raise Exception("Illegal buffer type provided to GradBuffer._get() function")
        buffer_tensor = buffer_tensor.view(shape)
        return buffer_tensor

    def _new_bucket(
        self,
        bucket_params: List[torch.nn.Parameter],
        start_index: int,
        end_index: int,
        numel_unpadded: int,
        bucket_id: int,
    ) -> _ParamAndGradBucket:
        """
        Helper function that creates a new bucket. Also updates param->bucket mapping.
        """
        # Assert that indices are correctly padded (if needed), and that bucket
        # position is same as originally computed.
        if self.ddp_config.use_distributed_optimizer:
            assert start_index % self.data_parallel_world_size == 0
            assert end_index % self.data_parallel_world_size == 0
        assert (start_index, end_index) == self.bucket_indices[bucket_id]

        # Get appropriate view into global ParamAndGradBuffer.
        bucketed_param_data = None
        if self.param_data is not None:
            bucketed_param_data = self._get(
                torch.Size([end_index - start_index]), start_index, buffer_type=BufferType.PARAM
            )
        bucketed_grad_data = self._get(
            torch.Size([end_index - start_index]), start_index, buffer_type=BufferType.GRAD
        )
        bucket = _ParamAndGradBucket(
            params=bucket_params,
            param_data=bucketed_param_data,
            grad_data=bucketed_grad_data,
            offset=start_index,
            numel_unpadded=numel_unpadded,
            gradient_scaling_factor=self.gradient_scaling_factor,
            bucket_id=bucket_id,
        )
        for bucket_param in bucket_params:
            assert bucket_param not in self.param_to_bucket
            self.param_to_bucket[bucket_param] = bucket

        return bucket

    def reset(self):
        """
        Zero out the underlying grad_buffer.
        """
        self.grad_data.zero_()
def partition_buckets(
    buffers: List[_ParamAndGradBuffer], force_single_bucket_group: bool = False
) -> List[_ParamAndGradBucketGroup]:
    """
    Automatically regroup the buckets of input buffers and return a list of bucket groups.

    In some scenarios, we need to put buckets from different buffers into a group so that their
    communication can be aggregated.

    For example, when there are both fp8 weights and bf16 biases in the model and virtual
    pipeline parallelism is enabled, each model chunk will have an fp8 bucket and a bf16 bucket,
    which doubles the number of communication kernels, and because of the use of
    CUDA_DEVICE_MAX_CONNECTIONS=1, having multiple back-to-back communications will prevent the
    overlap of communication kernels with computation kernels.

    The grouping strategy is:
    1. If force_single_bucket_group is True, put all buckets across all buffers into a single
       bucket group.
    2. If force_single_bucket_group is False, when there is no fp8 buffer in the input buffers,
       let each bucket group have only one bucket.
    3. If force_single_bucket_group is False, when using fp8 params, merge all non-fp8 buckets
       into the last fp8 bucket group.
       - Since the non-fp8 parameters (typically the biases of various layers) are relatively
         small, they are likely to be grouped into a single non-fp8 bucket.
       - The fp8 buckets start from the end of the model, i.e., the first bucket corresponds to
         the end of the model, while the last bucket corresponds to the beginning.
       - If we combine the non-fp8 bucket with the first fp8 bucket, we cannot initiate the
         reduce-scatter to synchronize gradients after the backward pass at the end of the model
         has completed. This is because we need to wait for the non-fp8 params from the beginning
         layers to obtain their gradients.
       - Combining the non-fp8 bucket with the last fp8 bucket can help avoid this issue.

    Args:
        buffers (list): list of input buffers.
        force_single_bucket_group (bool, optional): force all buckets across all buffers into
            a single bucket group. Defaults to False.

    Returns:
        List of `_ParamAndGradBucketGroup`s covering every bucket of every input buffer.
    """
    if len(buffers) == 0:
        return []

    dtype_to_buffer_map = {}
    for buffer in buffers:
        dtype = buffer.param_dtype
        # Make sure that the param_dtype of any two buffers is different.
        assert dtype not in dtype_to_buffer_map
        dtype_to_buffer_map[dtype] = buffer

    # Case 1: Put all buckets into a single bucket group if force_single_bucket_group is True.
    if force_single_bucket_group:
        buckets = []
        ddp_config = buffers[0].ddp_config
        data_parallel_group = buffers[0].data_parallel_group
        data_parallel_world_size = buffers[0].data_parallel_world_size
        for buffer in buffers:
            # All buffers must share the same DDP/communication settings to be mergeable.
            assert ddp_config == buffer.ddp_config
            assert data_parallel_group == buffer.data_parallel_group
            assert data_parallel_world_size == buffer.data_parallel_world_size
            buckets.extend(buffer.buckets)

        bucket_group = _ParamAndGradBucketGroup(
            buckets, ddp_config, data_parallel_group, data_parallel_world_size
        )
        return [bucket_group]

    if torch.uint8 not in dtype_to_buffer_map:
        # Case 2: When there is no fp8 buffer in the input buffers, let each bucket group have
        # only one bucket.
        bucket_groups = []
        for buffer in buffers:
            for bucket in buffer.buckets:
                bucket_groups.append(
                    _ParamAndGradBucketGroup(
                        [bucket],
                        buffer.ddp_config,
                        buffer.data_parallel_group,
                        buffer.data_parallel_world_size,
                    )
                )
        return bucket_groups
    else:
        # Case 3: When using fp8 params, merge all non-fp8 buckets into the last fp8 bucket group.
        non_fp8_buckets = []
        for buffer in buffers:
            if buffer.param_dtype != torch.uint8:
                for bucket in buffer.buckets:
                    non_fp8_buckets.append(bucket)

        bucket_groups = []
        fp8_buffer = dtype_to_buffer_map[torch.uint8]
        for bucket in fp8_buffer.buckets:
            if len(bucket_groups) == len(fp8_buffer.buckets) - 1:
                # The last bucket group.
                group_buckets = [bucket] + non_fp8_buckets
            else:
                # The first N-1 bucket groups.
                group_buckets = [bucket]
            # Use fp8_buffer's settings here (previously this read the stale `buffer`
            # loop variable left over from the non-fp8 scan above, which only worked
            # by accident when all buffers shared identical settings).
            bucket_groups.append(
                _ParamAndGradBucketGroup(
                    group_buckets,
                    fp8_buffer.ddp_config,
                    fp8_buffer.data_parallel_group,
                    fp8_buffer.data_parallel_world_size,
                )
            )
        return bucket_groups
# For backwards compatibility. ParamAndGradBuffer will be deprecated in future release.
# _ParamAndGradBuffer is not intended to be consumed directly by external code.
class ParamAndGradBuffer(_ParamAndGradBuffer):
    """Deprecated public alias of the private `_ParamAndGradBuffer`.

    Kept only so existing external imports keep working; constructing it emits a
    deprecation warning and otherwise behaves exactly like the base class.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments unchanged to `_ParamAndGradBuffer` and warn."""
        super().__init__(*args, **kwargs)
        deprecation_message = (
            "`ParamAndGradBuffer` will be deprecated in a future release, and is not "
            "intended to be used by external code."
        )
        warnings.warn(deprecation_message)
megatron/
fused_kernels/test
s/__init__.py
→
megatron/
core/extension
s/__init__.py
View file @
4b097dee
File moved
megatron/core/extensions/transformer_engine.py
0 → 100644
View file @
4b097dee
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import
dataclasses
import
os
import
warnings
from
typing
import
Callable
import
torch
import
transformer_engine
as
te
from
packaging.version
import
Version
as
PkgVersion
from
torch
import
Tensor
from
megatron.core
import
ModelParallelConfig
,
parallel_state
from
megatron.core.dist_checkpointing.utils
import
replace_prefix_for_sharding
from
megatron.core.packed_seq_params
import
PackedSeqParams
from
megatron.core.parallel_state
import
(
get_context_parallel_global_ranks
,
get_context_parallel_group
,
get_tensor_and_expert_parallel_world_size
,
get_tensor_model_parallel_group
,
)
from
megatron.core.tensor_parallel
import
get_cuda_rng_tracker
,
get_expert_parallel_rng_tracker_name
from
megatron.core.tensor_parallel.utils
import
divide
from
megatron.core.transformer.enums
import
AttnMaskType
from
megatron.core.transformer.transformer_config
import
TransformerConfig
from
megatron.core.transformer.utils
import
make_sharded_tensors_for_checkpoint
from
megatron.core.utils
import
get_te_version
,
is_te_min_version
def _get_extra_te_kwargs(config: TransformerConfig):
    """Build the version-dependent keyword arguments shared by the TE module wrappers.

    Always includes `params_dtype`; for Transformer Engine >= 0.12.0 also selects the
    `device` (CPU when `config.use_cpu_initialization` is set, otherwise the current
    CUDA device).
    """
    te_kwargs = {"params_dtype": config.params_dtype}
    # The `device` constructor argument only exists in TE >= 0.12.0.
    if is_te_min_version("0.12.0"):
        if config.use_cpu_initialization:
            te_kwargs["device"] = 'cpu'
        else:
            te_kwargs["device"] = torch.cuda.current_device()
    return te_kwargs
def condition_init_method(config, init_method):
    """Condition TE init_method on config.perform_initialization."""
    if config.perform_initialization:
        return init_method
    # Initialization disabled: TE still invokes the init hook, so hand it a no-op.
    return lambda w: None
class TENorm:
    """
    A conditional wrapper to initialize an instance of Transformer-Engine's
    `LayerNorm` or `RMSNorm` based on input.

    `__new__` returns the TE module directly, so `TENorm(...)` yields a
    `te.pytorch.LayerNorm` or `te.pytorch.RMSNorm` instance, not a `TENorm`.
    """

    # TODO should we ditch normalization config and just use spec to choose LayerNorm vs RMSNorm?
    def __new__(cls, config: TransformerConfig, hidden_size: int, eps: float = 1e-5):
        if config.normalization == "LayerNorm":
            instance = te.pytorch.LayerNorm(
                hidden_size=hidden_size,
                eps=eps,
                sequence_parallel=config.sequence_parallel,
                zero_centered_gamma=config.layernorm_zero_centered_gamma,
                **_get_extra_te_kwargs(config),
            )
        elif config.normalization == "RMSNorm":
            # RMSNorm only exists in Transformer Engine >= 0.11.
            assert hasattr(
                te.pytorch, "RMSNorm"
            ), "Transformer-Engine >= v0.11 required to use this feature"
            instance = te.pytorch.RMSNorm(
                hidden_size=hidden_size,
                eps=eps,
                sequence_parallel=config.sequence_parallel,
                zero_centered_gamma=config.layernorm_zero_centered_gamma,
                **_get_extra_te_kwargs(config),
            )
        else:
            # Fixed typo in the error message ("curently" -> "currently").
            raise Exception('Only LayerNorm and RMSNorm are currently supported')

        return instance
class TELinear(te.pytorch.Linear):
    """
    Wrapper for the Transformer-Engine's `Linear` layer.

    Note that if Megatron's parallel_state has not been initialized
    yet, the tp_group passed to TE will be None and must be set later
    via set_tensor_parallel_group().
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        *,
        parallel_mode: str,
        config: ModelParallelConfig,
        init_method: Callable,
        bias: bool,
        skip_bias_add: bool,
        skip_weight_param_allocation: bool,
        tp_comm_buffer_name: str = None,
        is_expert: bool = False,
    ):
        """Construct the TE-backed linear layer.

        Args:
            input_size: number of input features (global, before any row-parallel split).
            output_size: number of output features (global, before any column-parallel split).
            parallel_mode: TE parallel mode ("column", "row", or None).
            config: Megatron parallel config the TE kwargs are derived from.
            init_method: weight init callable; suppressed when
                config.perform_initialization is False (see condition_init_method).
            bias: whether the layer has a bias term.
            skip_bias_add: return the bias separately from forward() instead of adding it.
            skip_weight_param_allocation: unsupported here; must be False.
            tp_comm_buffer_name: userbuffer name for TP-comm overlap; required for
                TE > 1.0.0 when config.tp_comm_overlap is set.
            is_expert: whether this linear belongs to an MoE expert.
        """
        self.config = config

        # TE returns a zero length Tensor when bias=False and
        # return_bias=True, but we prefer None. So in that case we
        # tell TE to not return the bias, and return None
        # ourselves. This way our forward always returns two values
        # and we don't have to deal with the zero length Tensor.
        self.te_return_bias = skip_bias_add and bias
        self.is_first_microbatch = True
        self.disable_parameter_transpose_cache = self.config.disable_parameter_transpose_cache
        if skip_weight_param_allocation:
            raise ValueError(
                'Transformer Engine linear layers do not support skip_weight_param_allocation'
            )

        extra_kwargs = _get_extra_te_kwargs(config)

        # Version-gated userbuffer (tp-comm overlap) flags: TE >= 1.5 uses the new
        # ub_overlap_* names, older TE uses the split/atomic variants.
        if is_te_min_version("0.8.0"):
            if self.config.tp_comm_overlap:
                if is_te_min_version("1.5.0"):
                    # Use old overlap flags if they were supplied instead
                    extra_kwargs["ub_overlap_ag"] = (
                        self.config.tp_comm_overlap_ag
                        if hasattr(self.config, "tp_comm_overlap_ag")
                        else self.config.tp_comm_split_ag or self.config.tp_comm_atomic_ag
                    )
                    extra_kwargs["ub_overlap_rs"] = (
                        self.config.tp_comm_overlap_rs
                        if hasattr(self.config, "tp_comm_overlap_rs")
                        else self.config.tp_comm_split_rs or self.config.tp_comm_atomic_rs
                    )
                    # Disable ub overlap for experts.
                    if is_expert:
                        extra_kwargs["ub_overlap_ag"] = False
                        extra_kwargs["ub_overlap_rs"] = False
                else:
                    extra_kwargs["ub_split_ag"] = self.config.tp_comm_split_ag
                    extra_kwargs["ub_atomic_gemm_ag"] = self.config.tp_comm_atomic_ag
                    extra_kwargs["ub_split_rs"] = self.config.tp_comm_split_rs
                    extra_kwargs["ub_atomic_gemm_rs"] = self.config.tp_comm_atomic_rs
                    # Disable ub overlap for experts.
                    if is_expert:
                        extra_kwargs["ub_split_ag"] = False
                        extra_kwargs["ub_atomic_gemm_ag"] = False
                        extra_kwargs["ub_split_rs"] = False
                        extra_kwargs["ub_atomic_gemm_rs"] = False
                if is_te_min_version("1.0.0", check_equality=False):
                    assert (
                        tp_comm_buffer_name is not None
                    ), "Buffer name should be set to configure communication overlap settings"
                    extra_kwargs["ub_name"] = tp_comm_buffer_name

        self.expert_parallel = self.config.expert_model_parallel_size > 1
        # Expert weights get their own RNG stream so EP ranks initialize independently.
        if is_expert and self.expert_parallel:
            rng_tracker_name = get_expert_parallel_rng_tracker_name()
        else:
            rng_tracker_name = None
        if is_te_min_version("1.7.0"):
            extra_kwargs["rng_tracker_name"] = rng_tracker_name

        # Disable communications in TE when using SP or EP by making TE agnostic of model parallel.
        tp_size = self.config.tensor_model_parallel_size
        tp_group = get_tensor_model_parallel_group(check_initialized=False)
        if is_expert and (self.config.sequence_parallel or self.expert_parallel):
            if self.config.moe_extended_tp:
                tp_size = get_tensor_and_expert_parallel_world_size()
            # Pre-split the feature dimension ourselves, then hand TE a non-parallel
            # layer (tp_size=1, no group) so it issues no TP collectives.
            if parallel_mode == "column":
                output_size = divide(output_size, tp_size)
            elif parallel_mode == "row":
                input_size = divide(input_size, tp_size)
            parallel_mode = None
            tp_size = 1
            tp_group = None

        super().__init__(
            in_features=input_size,
            out_features=output_size,
            sequence_parallel=self.config.sequence_parallel,
            fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion,
            tp_group=tp_group,
            tp_size=tp_size,
            get_rng_state_tracker=(
                get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None
            ),
            init_method=condition_init_method(config, init_method),
            bias=bias,
            return_bias=self.te_return_bias,
            parallel_mode=parallel_mode,
            **extra_kwargs,
        )

        # Mark params so the DDP wrapper knows whether to all-reduce them
        # (expert params under EP are reduced separately, not all-reduced).
        for param in self.parameters():
            setattr(param, 'allreduce', not (is_expert and self.expert_parallel))

    def forward(self, x):
        """Forward.

        Returns a (output, bias_or_None) pair regardless of TE's own return
        convention (see the return_bias note in __init__).
        """
        _is_first_microbatch = (
            None if self.disable_parameter_transpose_cache else self.is_first_microbatch
        )
        out = super().forward(x, is_first_microbatch=_is_first_microbatch)
        self.is_first_microbatch = False

        # TE only returns a tuple when return_bias is True, otherwise
        # it returns a single Tensor, we always want to return two
        # values regardless of the arguments.
        if self.te_return_bias:
            return out
        return out, None
class TELayerNormColumnParallelLinear(te.pytorch.LayerNormLinear):
    """
    Wrapper for the Transformer-Engine's `LayerNormLinear` layer that combines
    layernorm and linear layers
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        *,
        config: TransformerConfig,
        init_method: Callable,
        gather_output: bool,
        bias: bool,
        skip_bias_add: bool,
        is_expert: bool,
        skip_weight_param_allocation: bool = False,
        tp_comm_buffer_name: str = None,
    ):
        """Construct the fused LayerNorm + column-parallel linear layer.

        Raises:
            ValueError: if gather_output, is_expert, or skip_weight_param_allocation
                is requested (none are supported by this wrapper), or if the
                configured normalization is unsupported by the installed TE version.
        """
        self.config = config

        if gather_output:
            raise ValueError('Transformer Engine linear layers do not support gather_output = True')

        if is_expert:
            raise ValueError('Transformer Engine linear layers do not yet support MoE')

        if skip_weight_param_allocation:
            raise ValueError(
                'Transformer Engine linear layers do not support skip_weight_param_allocation'
            )

        # TE returns a zero length Tensor when bias=False and
        # return_bias=True, but we prefer None. So in that case we
        # tell TE to not return the bias, and return None
        # ourselves. This way our forward always returns two values
        # and we don't have to deal with the zero length Tensor.
        self.te_return_bias = skip_bias_add and bias
        self.is_first_microbatch = True
        self.disable_parameter_transpose_cache = self.config.disable_parameter_transpose_cache
        extra_kwargs = _get_extra_te_kwargs(config)

        # Only Transformer-Engine version >= 0.11.0 supports `RMSNorm`
        if is_te_min_version("0.11.0"):
            extra_kwargs["normalization"] = self.config.normalization
        elif self.config.normalization != "LayerNorm":
            te_version = get_te_version()
            raise ValueError(
                f"Transformer Engine v{te_version} does not support {self.config.normalization}."
            )

        # Version-gated userbuffer (tp-comm overlap) flags; TE >= 1.5 uses the new
        # ub_overlap_* names, older TE uses the split/atomic variants.
        if is_te_min_version("0.8.0"):
            if self.config.tp_comm_overlap:
                extra_kwargs["ub_bulk_wgrad"] = self.config.tp_comm_bulk_wgrad
                extra_kwargs["ub_bulk_dgrad"] = self.config.tp_comm_bulk_dgrad
                if is_te_min_version("1.5.0", check_equality=False):
                    # Use old overlap flags if they were supplied instead
                    extra_kwargs["ub_overlap_ag"] = (
                        self.config.tp_comm_overlap_ag
                        if hasattr(self.config, "tp_comm_overlap_ag")
                        else self.config.tp_comm_split_ag or self.config.tp_comm_atomic_ag
                    )
                    if is_te_min_version("1.6.0.dev0", check_equality=False):
                        extra_kwargs["ub_overlap_rs_dgrad"] = (
                            self.config.tp_comm_overlap_rs_dgrad
                            if hasattr(self.config, "tp_comm_overlap_rs_dgrad")
                            else False
                        )
                    # Per-buffer opt-outs for overlap on the qkv and fc1 projections.
                    if tp_comm_buffer_name == 'qkv' and self.config.tp_comm_overlap_disable_qkv:
                        extra_kwargs["ub_overlap_ag"] = False
                        extra_kwargs["ub_overlap_rs_dgrad"] = False

                    if tp_comm_buffer_name == 'fc1' and self.config.tp_comm_overlap_disable_fc1:
                        extra_kwargs["ub_overlap_ag"] = False
                        extra_kwargs["ub_overlap_rs_dgrad"] = False
                else:
                    extra_kwargs["ub_atomic_gemm_ag"] = self.config.tp_comm_atomic_ag
                    extra_kwargs["ub_split_ag"] = self.config.tp_comm_split_ag
                if is_te_min_version("1.0.0", check_equality=False):
                    assert (
                        tp_comm_buffer_name is not None
                    ), "Buffer name should be set to configure communication overlap settings"
                    extra_kwargs["ub_name"] = tp_comm_buffer_name

        super().__init__(
            in_features=input_size,
            out_features=output_size,
            eps=self.config.layernorm_epsilon,
            sequence_parallel=self.config.sequence_parallel,
            fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion,
            tp_group=get_tensor_model_parallel_group(check_initialized=False),
            tp_size=self.config.tensor_model_parallel_size,
            get_rng_state_tracker=(
                get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None
            ),
            init_method=condition_init_method(config, init_method),
            bias=bias,
            return_bias=self.te_return_bias,
            parallel_mode="column",
            return_layernorm_output=False,
            zero_centered_gamma=self.config.layernorm_zero_centered_gamma,
            **extra_kwargs,
        )

    def forward(self, x):
        """Forward.

        Returns a (output, bias_or_None) pair regardless of TE's own return
        convention (see the return_bias note in __init__).
        """
        _is_first_microbatch = (
            None if self.disable_parameter_transpose_cache else self.is_first_microbatch
        )
        out = super().forward(x, is_first_microbatch=_is_first_microbatch)
        self.is_first_microbatch = False

        # TE only returns a tuple when return_bias is True, otherwise
        # it returns a single Tensor, we always want to return two
        # values regardless of the arguments.
        if self.te_return_bias:
            return out
        return out, None

    def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
        """Sharding along axis 0, bias sharded"""
        state_dict = self.state_dict(prefix='', keep_vars=True)
        return make_sharded_tensors_for_checkpoint(
            state_dict, prefix, {'weight': 0, 'bias': 0}, sharded_offsets
        )
class TEColumnParallelLinear(TELinear):
    """
    Wrapper for the Transformer-Engine's `Linear` layer but specialized similar
    to megatron's `ColumnParallelLinear` layer.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        *,
        config: ModelParallelConfig,
        init_method: Callable,
        gather_output: bool,
        bias: bool,
        skip_bias_add: bool,
        is_expert: bool,
        skip_weight_param_allocation: bool = False,
        tp_comm_buffer_name: str = None,
    ):
        """Delegate to TELinear with parallel_mode fixed to "column"."""
        # The TE backend never gathers the partitioned output, so reject the option.
        if gather_output:
            raise ValueError('Transformer Engine linear layers do not support gather_output = True')

        base_kwargs = dict(
            input_size=input_size,
            output_size=output_size,
            parallel_mode="column",
            config=config,
            init_method=condition_init_method(config, init_method),
            bias=bias,
            skip_bias_add=skip_bias_add,
            is_expert=is_expert,
            skip_weight_param_allocation=skip_weight_param_allocation,
            tp_comm_buffer_name=tp_comm_buffer_name,
        )
        super().__init__(**base_kwargs)

    def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
        """Sharding along axis 0, bias sharded"""
        local_state = self.state_dict(prefix='', keep_vars=True)
        column_shard_axes = {'weight': 0, 'bias': 0}
        return make_sharded_tensors_for_checkpoint(
            local_state, prefix, column_shard_axes, sharded_offsets
        )
class TERowParallelLinear(TELinear):
    """
    Wrapper for the Transformer-Engine's `Linear` layer but specialized similar
    to megatron's `RowParallelLinear` layer.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        *,
        config: ModelParallelConfig,
        init_method: Callable,
        bias: bool,
        input_is_parallel: bool,
        skip_bias_add: bool,
        is_expert: bool,
        tp_comm_buffer_name: str = None,
    ):
        """Delegate to TELinear with parallel_mode fixed to "row"."""
        # Row-parallel input must already be partitioned across TP ranks.
        if not input_is_parallel:
            raise ValueError(
                "Transformer Engine linear layers do not support input_is_parallel = False"
            )

        base_kwargs = dict(
            input_size=input_size,
            output_size=output_size,
            parallel_mode="row",
            config=config,
            init_method=condition_init_method(config, init_method),
            bias=bias,
            skip_bias_add=skip_bias_add,
            # Weight-param-allocation skipping is not used for row parallel layers.
            skip_weight_param_allocation=False,
            is_expert=is_expert,
            tp_comm_buffer_name=tp_comm_buffer_name,
        )
        super().__init__(**base_kwargs)

    def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
        """Sharding along axis 1, bias not sharded"""
        local_state = self.state_dict(prefix='', keep_vars=True)
        row_shard_axes = {'weight': 1}
        return make_sharded_tensors_for_checkpoint(
            local_state, prefix, row_shard_axes, sharded_offsets
        )
class TEDotProductAttention(te.pytorch.DotProductAttention):
    """
    Wrapper for the Transformer-Engine's `DotProductAttention` layer that also
    has "flash attention" enabled.

    Note that if Megatron's parallel_state has not been initialized yet, the
    tp_group and cp_group passed to TE will be None and must be set later
    via set_tensor_parallel_group() and set_context_parallel_group().
    """

    # Single CUDA stream shared by all instances for context-parallel communication;
    # lazily created in __init__ when TE >= 1.0.0.
    cp_stream: torch.cuda.Stream = None

    def __init__(
        self,
        config: TransformerConfig,
        layer_number: int,
        attn_mask_type: AttnMaskType,
        attention_dropout: float = None,
    ):
        """Validate version/env constraints and build the TE attention module.

        Raises:
            ValueError: if apply_query_key_layer_scaling disagrees with the
                NVTE_APPLY_QK_LAYER_SCALING env var, or GQA is requested on TE < 0.11.
            RuntimeError: if deterministic_mode is on but
                NVTE_ALLOW_NONDETERMINISTIC_ALGO is not 0.
        """
        self.config = config
        self.te_forward_mask_type = False
        self.qkv_format: str = 'sbhd'

        if self.config.apply_query_key_layer_scaling != bool(
            int(os.getenv('NVTE_APPLY_QK_LAYER_SCALING', '0'))
        ):
            raise ValueError(
                f"apply_query_key_layer_scaling is {self.config.apply_query_key_layer_scaling} "
                f"but environment variable NVTE_APPLY_QK_LAYER_SCALING is "
                f"{os.getenv('NVTE_APPLY_QK_LAYER_SCALING')}. Transformer Engine does not support "
                f"setting query key layer scaling via argument, so these two must match."
            )

        extra_kwargs = {}
        if is_te_min_version("0.11.0"):
            extra_kwargs["num_gqa_groups"] = self.config.num_query_groups
        elif self.config.num_query_groups != self.config.num_attention_heads:
            raise ValueError(
                f"Transformer Engine v{get_te_version()} does not support Grouped Query Attention, "
                f"use a newer version of Transformer Engine. "
                f"(num_query_groups ({self.config.num_query_groups}) != "
                f"num_attention_heads ({self.config.num_attention_heads}))"
            )

        if is_te_min_version("0.10.0"):
            extra_kwargs["attention_type"] = attention_type
            # older version don't need attention_type

        if is_te_min_version("0.12.0", check_equality=False):
            self.te_forward_mask_type = True

        # Only Transformer-Engine version >= 1.0.0 supports context parallelism
        if is_te_min_version("1.0.0"):
            if getattr(TEDotProductAttention, "cp_stream") is None:
                TEDotProductAttention.cp_stream = torch.cuda.Stream()
            extra_kwargs["cp_group"] = get_context_parallel_group(check_initialized=False)
            extra_kwargs["cp_global_ranks"] = get_context_parallel_global_ranks(
                check_initialized=False
            )
            extra_kwargs["cp_stream"] = TEDotProductAttention.cp_stream
        else:
            assert (
                self.config.context_parallel_size == 1
            ), "Only Transformer-Engine version >= 1.0.0 supports context parallelism!"

        if self.config.deterministic_mode:
            if int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")) != 0:
                raise RuntimeError(
                    "deterministic_mode is on and we are using DotProductAttention from "
                    "Transformer Engine, but NVTE_ALLOW_NONDETERMINISTIC_ALGO is not 0. "
                    f"Currently set to: {os.getenv('NVTE_ALLOW_NONDETERMINISTIC_ALGO', 'not set')}."
                )

        if config.window_size is not None:
            # Check version
            # NOTE(review): this concatenated message lacks a space between
            # "support" and "sliding" — kept byte-identical here.
            assert is_te_min_version("1.2.0"), (
                f"Transformer-Engine v{get_te_version()} must be >= 1.2.0 to support"
                "sliding window attention."
            )
            extra_kwargs['window_size'] = config.window_size

        super().__init__(
            num_attention_heads=self.config.num_attention_heads,
            kv_channels=self.config.kv_channels,
            attention_dropout=(
                self.config.attention_dropout if attention_dropout is None else attention_dropout
            ),
            attn_mask_type=attn_mask_type.name,
            sequence_parallel=self.config.sequence_parallel,
            tp_size=self.config.tensor_model_parallel_size,
            get_rng_state_tracker=(
                get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None
            ),
            tp_group=get_tensor_model_parallel_group(check_initialized=False),
            layer_number=layer_number,
            **extra_kwargs,
        )

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        attention_mask: Tensor,
        attn_mask_type: AttnMaskType,
        packed_seq_params: PackedSeqParams = None,
    ):
        """Forward.

        Adapts qkv layout, mask type, and packed-sequence kwargs to what the
        installed TE version accepts, then delegates to TE's attention forward.
        """
        packed_seq_kwargs = (
            dataclasses.asdict(packed_seq_params) if packed_seq_params is not None else {}
        )
        # overwrite self.qkv_format depending on self.config.apply_rope_fusion, which can be set
        # after init
        if self.config.apply_rope_fusion and is_te_min_version("0.13.0", check_equality=False):
            self.qkv_format = 'bshd'

        qkv_format = packed_seq_kwargs.get('qkv_format', self.qkv_format)

        if get_te_version() < PkgVersion("1.3.0"):
            # TE 1.3.0 introduces precomputing max_seqlen to remove unnecessary kernels and D2H
            # copies (#555)
            # These two arguments did not exist prior to 1.3.0
            packed_seq_kwargs.pop("max_seqlen_q", None)
            packed_seq_kwargs.pop("max_seqlen_kv", None)

        if self.config.apply_rope_fusion and qkv_format == 'bshd':
            # Convert from Megatron's sbhd layout to the bshd layout TE expects here.
            query, key, value = [x.transpose(0, 1).contiguous() for x in (query, key, value)]

            # In PyTorch, the following two tensors are in fact the same:
            #   Tensor with shape (1, S, H, D) and stride (S*H*D, H*D, D, 1)
            #   Tensor with shape (1, S, H, D) and stride (H*D, H*D, D, 1)
            # Stride for a dimension that is 1 has no meaning, so tensors created two different ways
            # can have same shape but different strides.
            # We unify them to the first one to pass the stride check in TE
            if value.shape == key.shape and value.shape[0] == 1 and value.stride() != key.stride():
                value = value.as_strided(value.shape, key.stride())

        if self.te_forward_mask_type:
            if qkv_format == 'thd' and is_te_min_version("1.7.0"):
                # thd format uses flash attention with cuDNN kernel which requires is_padding=True,
                # so the only acceptable mask types are `padding_causal` and `padding`. These do not
                # necessarily indicate there are padded tokens in the sequence.
                if attn_mask_type == AttnMaskType.causal:
                    attn_mask_type = AttnMaskType.padding_causal
                elif attn_mask_type == AttnMaskType.no_mask:
                    attn_mask_type = AttnMaskType.padding
            core_attn_out = super().forward(
                query,
                key,
                value,
                attention_mask,
                attn_mask_type=attn_mask_type.name,
                **packed_seq_kwargs,
            )
        else:
            core_attn_out = super().forward(query, key, value, attention_mask, **packed_seq_kwargs)

        if self.config.apply_rope_fusion and qkv_format == 'bshd':
            # Transpose back to the caller's sbhd layout.
            return core_attn_out.transpose(0, 1)
        else:
            return core_attn_out
if
is_te_min_version
(
"1.9.0.dev0"
):
class
TEGroupedLinear
(
te
.
pytorch
.
GroupedLinear
):
"""
Wrapper for the Transformer-Engine's `GroupedLinear` layer.
Note that if Megatron's parallel_state has not been initialized
yet, the tp_group passed to TE will be None and must be set later
via set_tensor_parallel_group().
"""
def
__init__
(
self
,
num_gemms
:
int
,
input_size
:
int
,
output_size
:
int
,
*
,
parallel_mode
:
str
,
config
:
ModelParallelConfig
,
init_method
:
Callable
,
bias
:
bool
,
skip_bias_add
:
bool
,
is_expert
:
bool
=
False
,
tp_comm_buffer_name
:
str
=
None
,
):
self
.
config
=
config
# TE returns a zero length Tensor when bias=False and
# return_bias=True, but we prefer None. So in that case we
# tell TE to not return the bias, and return None
# ourselves. This way our forward always returns two values
# and we don't have to deal with the zero length Tensor.
self
.
te_return_bias
=
skip_bias_add
and
bias
self
.
is_first_microbatch
=
True
self
.
disable_parameter_transpose_cache
=
self
.
config
.
disable_parameter_transpose_cache
extra_kwargs
=
_get_extra_te_kwargs
(
config
)
extra_kwargs
[
"ub_name"
]
=
tp_comm_buffer_name
self
.
expert_parallel
=
self
.
config
.
expert_model_parallel_size
>
1
if
self
.
expert_parallel
:
extra_kwargs
[
"rng_tracker_name"
]
=
get_expert_parallel_rng_tracker_name
()
# For MoE models, the comms between TP and EP group is explicitly handled by
# MoE token dispatcher. So we disable comms by making TE agnostic of model parallel.
self
.
explicit_expert_comm
=
is_expert
and
(
config
.
tensor_model_parallel_size
>
1
or
self
.
expert_parallel
)
tp_group
=
get_tensor_model_parallel_group
(
check_initialized
=
False
)
if
self
.
explicit_expert_comm
and
config
.
moe_extended_tp
:
tp_size
=
parallel_state
.
get_tensor_and_expert_parallel_world_size
()
else
:
tp_size
=
parallel_state
.
get_tensor_model_parallel_world_size
()
if
self
.
explicit_expert_comm
:
if
parallel_mode
==
"column"
:
output_size
=
divide
(
output_size
,
tp_size
)
elif
parallel_mode
==
"row"
:
input_size
=
divide
(
input_size
,
tp_size
)
parallel_mode
=
None
tp_size
=
1
tp_group
=
None
super
().
__init__
(
num_gemms
=
num_gemms
,
in_features
=
input_size
,
out_features
=
output_size
,
sequence_parallel
=
self
.
config
.
sequence_parallel
,
fuse_wgrad_accumulation
=
self
.
config
.
gradient_accumulation_fusion
,
tp_group
=
tp_group
,
tp_size
=
tp_size
,
get_rng_state_tracker
=
(
get_cuda_rng_tracker
if
get_cuda_rng_tracker
().
is_initialized
()
else
None
),
init_method
=
condition_init_method
(
config
,
init_method
),
bias
=
bias
,
return_bias
=
self
.
te_return_bias
,
parallel_mode
=
parallel_mode
,
**
extra_kwargs
,
)
for
param
in
self
.
parameters
():
setattr
(
param
,
'allreduce'
,
not
(
is_expert
and
self
.
expert_parallel
))
def
forward
(
self
,
x
,
m_splits
):
"""Forward."""
_is_first_microbatch
=
(
None
if
self
.
disable_parameter_transpose_cache
else
self
.
is_first_microbatch
)
out
=
super
().
forward
(
x
,
m_splits
,
is_first_microbatch
=
_is_first_microbatch
)
self
.
is_first_microbatch
=
False
# TE only returns a tuple when return_bias is True, otherwise
# it returns a single Tensor, we always want to return two
# values regardless of the arguments.
if
self
.
te_return_bias
:
return
out
return
out
,
None
def
_sharded_state_dict_grouped
(
self
,
tp_axis_map
,
prefix
=
''
,
sharded_offsets
=
(),
metadata
=
None
):
"""
prefix should be module_name to make keys identical to sequetial ones.
"""
sharded_state_dict
=
{}
full_state_dict
=
self
.
state_dict
(
prefix
=
''
,
keep_vars
=
True
)
num_global_experts
=
(
parallel_state
.
get_expert_model_parallel_world_size
()
*
self
.
num_gemms
)
local_expert_indices_offset
=
(
parallel_state
.
get_expert_model_parallel_rank
()
*
self
.
num_gemms
)
ep_axis
=
len
(
sharded_offsets
)
for
gemm_idx
in
range
(
self
.
num_gemms
):
state_dict
=
{
f
'
{
gemm_idx
}
.weight'
:
full_state_dict
[
f
'weight
{
gemm_idx
}
'
],
f
'
{
gemm_idx
}
._extra_state'
:
full_state_dict
[
'_extra_state'
],
}
if
self
.
use_bias
:
state_dict
[
f
'
{
gemm_idx
}
.bias'
]
=
full_state_dict
[
f
'bias
{
gemm_idx
}
'
]
sub_sd
=
make_sharded_tensors_for_checkpoint
(
state_dict
,
''
,
tp_axis_map
,
(
*
sharded_offsets
,
(
ep_axis
,
local_expert_indices_offset
+
gemm_idx
,
num_global_experts
),
),
)
# Remove expert layers indexing from sharded keys
replace_prefix_for_sharding
(
sub_sd
,
f
'
{
gemm_idx
}
.'
,
prefix
)
sharded_state_dict
.
update
(
{
f
'
{
prefix
}
weight
{
gemm_idx
}
'
:
sub_sd
[
f
'
{
gemm_idx
}
.weight'
],
# TODO: TE's GroupedLinear only has one _extra_state for all experts.
# We need sharding or build/merge fn to handle _extra_state correctly.
f
'
{
prefix
}
_extra_state
{
""
if
gemm_idx
==
0
else
gemm_idx
}
'
:
sub_sd
[
f
'
{
gemm_idx
}
._extra_state'
],
}
)
if
self
.
use_bias
:
sharded_state_dict
[
f
'
{
prefix
}
bias
{
gemm_idx
}
'
]
=
sub_sd
[
f
'
{
gemm_idx
}
.bias'
]
# Adjust replica ids - replication along DP modulo EP
for
k
,
sh_ten
in
sharded_state_dict
.
items
():
replica_id
=
sh_ten
.
replica_id
assert
(
len
(
replica_id
)
==
3
),
f
'Expected replica_id for
{
k
}
to be in (PP, TP, DP) format, got:
{
replica_id
}
'
sh_ten
.
replica_id
=
(
*
replica_id
[:
2
],
parallel_state
.
get_data_modulo_expert_parallel_rank
(),
)
return
sharded_state_dict
class
TEColumnParallelGroupedLinear
(
TEGroupedLinear
):
"""
Wrapper for the Transformer-Engine's `GroupedLinear` layer but specialized
to column-parallel style.
"""
def
__init__
(
self
,
num_gemms
:
int
,
input_size
:
int
,
output_size
:
int
,
*
,
config
:
ModelParallelConfig
,
init_method
:
Callable
,
bias
:
bool
,
skip_bias_add
:
bool
,
is_expert
:
bool
,
tp_comm_buffer_name
:
str
=
None
,
):
super
().
__init__
(
num_gemms
=
num_gemms
,
input_size
=
input_size
,
output_size
=
output_size
,
parallel_mode
=
"column"
,
config
=
config
,
init_method
=
condition_init_method
(
config
,
init_method
),
bias
=
bias
,
skip_bias_add
=
skip_bias_add
,
is_expert
=
is_expert
,
tp_comm_buffer_name
=
tp_comm_buffer_name
,
)
def
sharded_state_dict
(
self
,
prefix
=
''
,
sharded_offsets
=
(),
metadata
=
None
):
"""
For each gemm, sharding along axis 0, bias sharded.
Assume sharded_offsets[-1] is the expert parallel offset.
"""
tp_axis_map
=
{}
for
gemm_idx
in
range
(
self
.
num_gemms
):
tp_axis_map
.
update
({
f
'
{
gemm_idx
}
.weight'
:
0
,
f
'
{
gemm_idx
}
.bias'
:
0
})
return
super
().
_sharded_state_dict_grouped
(
tp_axis_map
,
prefix
,
sharded_offsets
,
metadata
)
class TERowParallelGroupedLinear(TEGroupedLinear):
    """
    Wrapper for the Transformer-Engine's `GroupedLinear` layer but specialized
    to row-parallel style.
    """

    def __init__(
        self,
        num_gemms: int,
        input_size: int,
        output_size: int,
        *,
        config: ModelParallelConfig,
        init_method: Callable,
        bias: bool,
        skip_bias_add: bool,
        is_expert: bool,
        tp_comm_buffer_name: str = None,
    ):
        # Delegate everything to the grouped base class, pinning the
        # parallel mode to "row".
        super().__init__(
            num_gemms=num_gemms,
            input_size=input_size,
            output_size=output_size,
            parallel_mode="row",
            config=config,
            init_method=condition_init_method(config, init_method),
            bias=bias,
            skip_bias_add=skip_bias_add,
            is_expert=is_expert,
            tp_comm_buffer_name=tp_comm_buffer_name,
        )

    def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
        """
        For each gemm, sharding along axis 1, bias not sharded.
        Assume sharded_offsets[-1] is the expert parallel offset.
        """
        # Row parallelism splits only the weights, along their second axis;
        # biases are replicated and therefore omitted from the map.
        tp_axis_map = {}
        for idx in range(self.num_gemms):
            tp_axis_map[f'{idx}.weight'] = 1
        return super()._sharded_state_dict_grouped(
            tp_axis_map, prefix, sharded_offsets, metadata
        )
else:
    # Fallback branch (the guarding `if` is above this chunk — presumably a
    # Transformer Engine feature/version check; confirm against the full file).
    # Export None placeholders so importers can feature-test grouped-GEMM
    # support with `X is not None` instead of catching ImportError.
    TEGroupedLinear = None
    TEColumnParallelGroupedLinear = None
    TERowParallelGroupedLinear = None
class TEDelayedScaling(te.common.recipe.DelayedScaling):
    """
    Wrapper for the Transformer-Engine's `DelayedScaling` layer.

    Translates Megatron's fp8 configuration fields into the keyword
    arguments expected by the installed Transformer Engine version.
    """

    def __init__(
        self,
        config: ModelParallelConfig,
        fp8_format: int,
        override_linear_precision: tuple = (False, False, False),
    ):
        recipe_kwargs = _get_extra_te_kwargs(config)

        # fp8 attention knobs only exist from TE 1.6 onward.
        if is_te_min_version("1.6.0.dev0"):
            recipe_kwargs["fp8_dpa"] = config.fp8_dot_product_attention
            recipe_kwargs["fp8_mha"] = config.fp8_multi_head_attention

        # `interval` was removed in TE 1.8; warn if the user still sets a
        # non-default value that will be ignored.
        if get_te_version() < PkgVersion("1.8.0"):
            recipe_kwargs["interval"] = config.fp8_interval
        elif config.fp8_interval != 1:
            warnings.warn("fp8_interval is deprecated and ignored from Transformer-Engine v1.8.0.")

        super().__init__(
            margin=config.fp8_margin,
            fp8_format=fp8_format,
            amax_compute_algo=config.fp8_amax_compute_algo,
            amax_history_len=config.fp8_amax_history_len,
            override_linear_precision=override_linear_precision,
            **recipe_kwargs,
        )
class TECudaRNGStatesTracker(te.pytorch.distributed.CudaRNGStatesTracker):
    """Wraps TransformerEngine's CudaRNGStatesTracker so that it is
    interchangeable with Megatron's RNG tracker."""

    def is_initialized(self):
        """Checks if the internal RNG state has been set with set_states()."""
        return self._is_initialized

    def reset(self):
        """Reset the internal RNG state and mark the tracker uninitialized."""
        super().reset()
        self._is_initialized = False

    def set_states(self, states):
        """Set the internal RNG state and mark the tracker initialized."""
        super().set_states(states)
        self._is_initialized = True

    def add(self, name, seed):
        """Track a new rng state under `name`; marks the tracker initialized."""
        super().add(name, seed)
        self._is_initialized = True
def te_checkpoint(
    forward_func,
    distribute_saved_activations,
    get_rng_state_tracker,
    tp_group,
    hidden_states,
    attention_mask,
    context,
    context_mask,
    rotary_pos_emb,
):
    """Checkpointing with Transformer-Engine.

    TE 1.5 changed `checkpoint`'s signature: the distribution/RNG/tp-group
    arguments became keyword-only and the forward args moved to the front.
    Dispatch on the installed version so both call forms work.
    """
    from transformer_engine.pytorch.distributed import checkpoint

    if not is_te_min_version("1.5.0"):
        # Legacy positional signature (pre-1.5).
        return checkpoint(
            forward_func,
            distribute_saved_activations,
            get_rng_state_tracker,
            tp_group,
            hidden_states,
            attention_mask,
            context,
            context_mask,
            rotary_pos_emb,
        )
    return checkpoint(
        forward_func,
        hidden_states,
        attention_mask,
        context,
        context_mask,
        rotary_pos_emb,
        distribute_saved_activations=distribute_saved_activations,
        get_rng_state_tracker=get_rng_state_tracker,
        tp_group=tp_group,
    )
# Optional fused split kernel: only available in TE builds that ship the
# private `_SplitAlongDim` autograd function. Callers must check for None.
try:
    from transformer_engine.pytorch.attention import _SplitAlongDim

    SplitAlongDim = _SplitAlongDim.apply

except ImportError:
    SplitAlongDim = None
# Optional CPU offload support: TE 1.10 added the `model_layers` argument,
# so shim over both signatures. Exported as None when TE lacks the module.
try:
    from transformer_engine.pytorch.cpu_offload import (
        get_cpu_offload_context as _get_cpu_offload_context,
    )

    def get_cpu_offload_context(
        enabled, num_layers, model_layers, activation_offloading, weight_offloading
    ):
        """Get CPU offload context and sync function."""
        if is_te_min_version("1.10.0.dev0"):
            context, sync_func = _get_cpu_offload_context(
                enabled, num_layers, model_layers, activation_offloading, weight_offloading
            )
        else:
            # Pre-1.10 TE does not take `model_layers`.
            context, sync_func = _get_cpu_offload_context(
                enabled, num_layers, activation_offloading, weight_offloading
            )
        return context, sync_func

except ImportError:
    get_cpu_offload_context = None
megatron/core/fusions/fused_bias_dropout.py
View file @
4b097dee
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from
typing
import
Optional
,
Tuple
import
torch
from
megatron.core.jit
import
jit_fuser
def
_bias_dropout_add_func
(
x
,
bias
,
residual
,
prob
,
training
):
# type: (Tensor, Optional[Tensor], Tensor, float, bool) -> Tensor
def
_bias_dropout_add_func
(
x
_with_
bias
,
residual
,
prob
,
training
):
# type: (
Tuple[
Tensor, Optional[Tensor]
]
, Tensor, float, bool) -> Tensor
# NOTE: Previously, the argument `bias` used to be passed as
# `bias.expand_as(residual)` when the `bias_dropout_func` is called from the
# transformer layer but broadcasting should automatically take care of that.
# Also, looking at broadcasting semantics, `expand_as` and broadcasting
# seem to be identical performance-wise (both just change the view).
x
,
bias
=
x_with_bias
# unpack
# If we want to train mixed precision, then the output of this function
# should be half precision. However, in AMP O1, the input (residual) is
# in fp32, and it will up-cast the result to fp32, causing pipeline parallel
# GPU communication to hang. Therefore, we need to cast residual to the same
# dtype as x.
residual
=
residual
if
residual
.
dtype
==
x
.
dtype
else
residual
.
to
(
x
.
dtype
)
# The Dropout operation, Residual Addition and the tensor returning can be
# done generically outside the if statement, but that stops fusing of Bias
# Addition-Dropout-Residual Addition operation. So doing it together inside
# the conditional branch to improve performance
if
bias
is
not
None
:
x
=
x
+
bias
out
=
torch
.
nn
.
functional
.
dropout
(
x
,
p
=
prob
,
training
=
training
)
out
=
residual
+
out
return
out
else
:
out
=
torch
.
nn
.
functional
.
dropout
(
x
,
p
=
prob
,
training
=
training
)
out
=
residual
+
out
return
out
def bias_dropout_add_unfused(training):
    """Return an unfused bias-dropout-add closure bound to `training`."""

    def _unfused(x_with_bias, residual, prob):
        return _bias_dropout_add_func(x_with_bias, residual, prob, training)

    return _unfused
@jit_fuser
def bias_dropout_add_fused_train(
    x_with_bias: Tuple[torch.Tensor, Optional[torch.Tensor]],
    residual: torch.Tensor,
    prob: float,
) -> torch.Tensor:
    """Jit-fused bias-dropout-add, training mode (dropout active)."""
    return _bias_dropout_add_func(x_with_bias, residual, prob, True)
@jit_fuser
def bias_dropout_add_fused_inference(
    x_with_bias: Tuple[torch.Tensor, Optional[torch.Tensor]],
    residual: torch.Tensor,
    prob: float,
) -> torch.Tensor:
    """Jit-fused bias-dropout-add, inference mode (dropout disabled)."""
    return _bias_dropout_add_func(x_with_bias, residual, prob, False)
def
get_bias_dropout_add
(
training
,
fused
):
def
unfused_bias_dropout_add
(
x_with_bias
,
residual
,
prob
):
x
,
bias
=
x_with_bias
# unpack
return
_bias_dropout_add_func
(
x
,
bias
,
residual
,
prob
,
training
)
if
fused
:
# jit scripting for a nn.module (with dropout) is not
# triggering the fusion kernel. For now, we use two
...
...
@@ -57,4 +70,4 @@ def get_bias_dropout_add(training, fused):
else
:
return
bias_dropout_add_fused_inference
else
:
return
unfused_
bias_dropout_add
return
bias_dropout_add
_unfused
(
training
)
megatron/core/fusions/fused_bias_geglu.py
0 → 100644
View file @
4b097dee
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import
torch
from
megatron.core.jit
import
jit_fuser
###### BIAS GELU FUSION/ NO AUTOGRAD ################
# 1/sqrt(2*pi)-> 0.3989423
# 1/sqrt(2) -> 0.70710678
# sqrt(2/pi) -> 0.79788456
# this function is tanh approximation of gelu
# actual gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
@jit_fuser
def geglu(y):
    """GeGLU: tanh-approximated GELU on the first half of `y`, gated by the second half."""
    gate, value = torch.chunk(y, 2, -1)
    inner = 0.79788456 * gate * (1 + 0.044715 * gate * gate)
    return (gate * 0.5 * (1.0 + torch.tanh(inner))) * value
@jit_fuser
def bias_geglu(bias, y):
    """GeGLU applied to the bias-shifted input `y + bias`."""
    return geglu(y + bias)
# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@jit_fuser
def geglu_back(g, y):
    """Backward of `geglu`: gradient w.r.t. the (pre-chunk) input `y`."""
    gate, value = torch.chunk(y, 2, -1)
    tanh_out = torch.tanh(0.79788456 * gate * (1 + 0.044715 * gate * gate))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    dgelu = 0.5 * gate * (
        (1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * gate * gate)
    ) + 0.5 * (1 + tanh_out)
    grad_gate = (g * value) * dgelu
    grad_value = g * (gate * 0.5 * (1.0 + tanh_out))
    return torch.cat((grad_gate, grad_value), -1)
@jit_fuser
def bias_geglu_back(g, y, bias):
    """Backward of `bias_geglu` (gradient w.r.t. `y`; identical for `bias` by linearity)."""
    return geglu_back(g, y + bias)
class BiasGeGLUFunction(torch.autograd.Function):
    """Autograd wrapper around the fused bias + GeGLU kernels."""

    @staticmethod
    def forward(ctx, input, bias):
        ctx.save_for_backward(input, bias)
        # Fix: `bias_geglu` is declared as (bias, y); the original passed
        # (input, bias) — harmless only because the addition inside is
        # commutative, but misleading. Pass arguments in declared order.
        return bias_geglu(bias, input)

    @staticmethod
    def backward(ctx, grad_output):
        input, bias = ctx.saved_tensors
        tmp = bias_geglu_back(grad_output, input, bias)
        # d/d(input) == d/d(bias) since they enter through a plain sum.
        return tmp, tmp
class GeGLUFunction(torch.autograd.Function):
    """Autograd wrapper around the fused GeGLU kernels (no bias)."""

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return geglu(input)

    @staticmethod
    def backward(ctx, grad_output):
        (inp,) = ctx.saved_tensors
        return geglu_back(grad_output, inp)
def bias_geglu_impl(input, bias):
    """Flatten to 2-D, apply (Bias)GeGLU, and restore the 3-D shape if needed.

    Accepts 2-D or 3-D input; the output's last dimension is half the input's.
    """
    shape = input.shape
    assert len(shape) in [2, 3]
    flat = input.view(-1, shape[-1])
    if bias is None:
        out = GeGLUFunction.apply(flat)
    else:
        out = BiasGeGLUFunction.apply(flat, bias)
    if len(shape) == 2:
        return out
    return out.view(shape[0], shape[1], -1)
megatron/core/fusions/fused_bias_gelu.py
View file @
4b097dee
...
...
@@ -2,7 +2,9 @@
import
torch
###### BIAS GELU FUSION/ NO AUTOGRAD ################
from
megatron.core.jit
import
jit_fuser
# BIAS GELU FUSION/ NO AUTOGRAD ################
# 1/sqrt(2*pi)-> 0.3989423
# 1/sqrt(2) -> 0.70710678
# sqrt(2/pi) -> 0.79788456
...
...
@@ -11,7 +13,7 @@ import torch
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
@jit_fuser
def bias_gelu(bias, y):
    """Fused bias-add followed by tanh-approximate GELU."""
    x = bias + y
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
...
...
@@ -20,7 +22,7 @@ def bias_gelu(bias, y):
# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@
torch
.
jit
.
script
@
jit_fuser
def
bias_gelu_back
(
g
,
bias
,
y
):
x
=
bias
+
y
tanh_out
=
torch
.
tanh
(
0.79788456
*
x
*
(
1
+
0.044715
*
x
*
x
))
...
...
@@ -44,5 +46,10 @@ class GeLUFunction(torch.autograd.Function):
tmp
=
bias_gelu_back
(
grad_output
,
bias
,
input
)
return
tmp
,
tmp
# This is required to make Sphinx happy :-(
@
classmethod
def
apply
(
cls
,
*
args
,
**
kwargs
):
return
super
().
apply
(
*
args
,
**
kwargs
)
bias_gelu_impl
=
GeLUFunction
.
apply
megatron/core/fusions/fused_bias_swiglu.py
0 → 100644
View file @
4b097dee
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import
torch
import
torch.nn.functional
as
F
from
megatron.core.jit
import
jit_fuser
###### BIAS SWIGLU FUSION/ NO AUTOGRAD ################
@jit_fuser
def swiglu(y):
    """SwiGLU: SiLU on the first half of `y`, gated by the second half."""
    gate, value = torch.chunk(y, 2, -1)
    return F.silu(gate) * value
@jit_fuser
def bias_swiglu(y, bias):
    """SwiGLU applied to the bias-shifted input `y + bias`."""
    return swiglu(y + bias)
# Gradient of SwiGLU (the comment previously here was copy-pasted from the
# GELU file and described the wrong function):
#   d/dy_1 [silu(y_1) * y_2] = sigmoid(y_1) * (1 + y_1 * (1 - sigmoid(y_1))) * y_2
#   d/dy_2 [silu(y_1) * y_2] = silu(y_1)
@jit_fuser
def swiglu_back(g, y):
    """Backward of `swiglu`: gradient w.r.t. the (pre-chunk) input `y`."""
    y_1, y_2 = torch.chunk(y, 2, -1)
    # Evaluate sigmoid once (the original computed it twice).
    sig = torch.sigmoid(y_1)
    return torch.cat((g * sig * (1 + y_1 * (1 - sig)) * y_2, g * F.silu(y_1)), -1)
@jit_fuser
def bias_swiglu_back(g, y, bias):
    """Backward of `bias_swiglu` (gradient w.r.t. `y`; identical for `bias` by linearity)."""
    return swiglu_back(g, y + bias)
class BiasSwiGLUFunction(torch.autograd.Function):
    """Autograd wrapper around the fused bias + SwiGLU kernels.

    Optionally stores the forward input in fp8 (e4m3) to save activation
    memory, casting it back to the original dtype in backward.
    """

    @staticmethod
    def forward(ctx, input, bias, fp8_input_store):
        saved = input.to(torch.float8_e4m3fn) if fp8_input_store else input
        ctx.save_for_backward(saved, bias)
        ctx.ori_input_dtype = input.dtype
        ctx.fp8_input_store = fp8_input_store
        return bias_swiglu(input, bias)

    @staticmethod
    def backward(ctx, grad_output):
        inp, bias = ctx.saved_tensors
        if ctx.fp8_input_store:
            inp = inp.to(ctx.ori_input_dtype)
        tmp = bias_swiglu_back(grad_output, inp, bias)
        # d/d(input) == d/d(bias); no gradient for the bool flag.
        return tmp, tmp, None
class SwiGLUFunction(torch.autograd.Function):
    """Autograd wrapper around the fused SwiGLU kernels (no bias).

    Optionally stores the forward input in fp8 (e4m3) to save activation
    memory, casting it back to the original dtype in backward.
    """

    @staticmethod
    def forward(ctx, input, fp8_input_store):
        saved = input.to(torch.float8_e4m3fn) if fp8_input_store else input
        ctx.save_for_backward(saved)
        ctx.ori_input_dtype = input.dtype
        ctx.fp8_input_store = fp8_input_store
        return swiglu(input)

    @staticmethod
    def backward(ctx, grad_output):
        (inp,) = ctx.saved_tensors
        if ctx.fp8_input_store:
            inp = inp.to(ctx.ori_input_dtype)
        # No gradient for the bool flag.
        return swiglu_back(grad_output, inp), None
def bias_swiglu_impl(input, bias, fp8_input_store=False):
    """Flatten to 2-D, apply (Bias)SwiGLU, and restore the 3-D shape if needed.

    Accepts 2-D or 3-D input; the output's last dimension is half the input's.
    `fp8_input_store=True` keeps the saved activation in fp8 (e4m3).
    """
    shape = input.shape
    assert len(shape) in [2, 3]
    flat = input.view(-1, shape[-1])
    if bias is None:
        out = SwiGLUFunction.apply(flat, fp8_input_store)
    else:
        out = BiasSwiGLUFunction.apply(flat, bias, fp8_input_store)
    if len(shape) == 2:
        return out
    return out.view(shape[0], shape[1], -1)
# bias_swiglu_impl = BiasSwiGLUFunction.apply
# swiglu_impl = SwiGLUFunction.apply
megatron/core/fusions/fused_cross_entropy.py
0 → 100644
View file @
4b097dee
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from
typing
import
Tuple
import
torch
from
megatron.core.jit
import
jit_fuser
from
megatron.core.parallel_state
import
(
get_tensor_model_parallel_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
)
from
megatron.core.tensor_parallel.cross_entropy
import
VocabParallelCrossEntropy
from
megatron.core.tensor_parallel.utils
import
VocabUtility
@jit_fuser
def calculate_logits_max(
    vocab_parallel_logits: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Fusable wrapper: subtract/track the per-row logits max (local shard)."""
    return VocabParallelCrossEntropy.calculate_logits_max(vocab_parallel_logits)
@jit_fuser
def calculate_predicted_logits(
    vocab_parallel_logits: torch.Tensor,
    target: torch.Tensor,
    logits_max: torch.Tensor,
    vocab_start_index: int,
    vocab_end_index: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Fusable wrapper: compute the local predicted logits and exp-logits.

    `predicted_logits` and `sum_exp_logits` are concatenated into one tensor
    so the caller can reduce both with a single all-reduce.
    """
    (target_mask, masked_target_1d, predicted_logits, sum_exp_logits, exp_logits) = (
        VocabParallelCrossEntropy.calculate_predicted_logits(
            vocab_parallel_logits, target, logits_max, vocab_start_index, vocab_end_index
        )
    )
    batched = torch.cat((predicted_logits, sum_exp_logits))
    return target_mask, masked_target_1d, batched, exp_logits
@jit_fuser
def calculate_cross_entropy_loss(
    exp_logits: torch.Tensor, predicted_logits_sum_exp_logits: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Fusable wrapper: split the batched reduce result and compute the loss."""
    # The batched tensor is [predicted_logits ; sum_exp_logits], equal halves.
    half = predicted_logits_sum_exp_logits.size()[0] // 2
    predicted_logits, sum_exp_logits = torch.split(predicted_logits_sum_exp_logits, half)
    return VocabParallelCrossEntropy.calculate_cross_entropy_loss(
        exp_logits, predicted_logits, sum_exp_logits
    )
@jit_fuser
def calculate_gradients(
    softmax: torch.Tensor,
    grad_output: torch.Tensor,
    target_mask: torch.Tensor,
    masked_target_1d: torch.Tensor,
) -> torch.Tensor:
    """Fusable wrapper: compute the input gradient for the fused loss."""
    (grad_2d, arange_1d, softmax_update, grad_input) = (
        VocabParallelCrossEntropy.prepare_gradient_calculation_operands(softmax, target_mask)
    )
    grad_input = VocabParallelCrossEntropy.calculate_gradients(
        grad_2d, arange_1d, masked_target_1d, softmax_update, grad_input, grad_output
    )
    # NOTE(review): the output cast is hard-coded to bfloat16 — presumably the
    # fused path targets bf16 models only; confirm before reusing elsewhere.
    return grad_input.to(torch.bfloat16)
class _VocabParallelCrossEntropy(torch.autograd.Function):
    """Fused vocab-parallel cross entropy: batches the two partial results
    into a single all-reduce over the tensor-parallel group."""

    @staticmethod
    def forward(ctx, vocab_parallel_logits, target):
        # Local max per row, then global max across the TP group for
        # numerical stability of the softmax.
        vocab_parallel_logits, logits_max = calculate_logits_max(vocab_parallel_logits)
        torch.distributed.all_reduce(
            logits_max, op=torch.distributed.ReduceOp.MAX, group=get_tensor_model_parallel_group()
        )

        # Get the partition's vocab indices
        get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
        partition_vocab_size = vocab_parallel_logits.size()[-1]
        rank = get_tensor_model_parallel_rank()
        world_size = get_tensor_model_parallel_world_size()
        vocab_start_index, vocab_end_index = get_vocab_range(partition_vocab_size, rank, world_size)

        (target_mask, masked_target_1d, predicted_logits_sum_exp_logits, exp_logits) = (
            calculate_predicted_logits(
                vocab_parallel_logits, target, logits_max, vocab_start_index, vocab_end_index
            )
        )

        # All reduce is needed to get the chunks from other GPUs.
        # In the fused case, tensors are batched to invoke a single
        # AllReduce call (predicted_logits and sum_exp_logits concatenated).
        torch.distributed.all_reduce(
            predicted_logits_sum_exp_logits,
            op=torch.distributed.ReduceOp.SUM,
            group=get_tensor_model_parallel_group(),
        )

        exp_logits, loss = calculate_cross_entropy_loss(exp_logits, predicted_logits_sum_exp_logits)

        # Store softmax, target-mask and masked-target for backward pass.
        # (exp_logits is normalized to the softmax inside the loss helper.)
        ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)

        return loss

    @staticmethod
    def backward(ctx, grad_output):
        # Retrieve tensors from the forward path.
        softmax, target_mask, masked_target_1d = ctx.saved_tensors

        grad_input = calculate_gradients(softmax, grad_output, target_mask, masked_target_1d)

        # No gradient w.r.t. the integer targets.
        return grad_input, None
def fused_vocab_parallel_cross_entropy(vocab_parallel_logits, target):
    """
    Performs cross entropy loss when logits are split across tensor parallel ranks

    Args:
        vocab_parallel_logits: logits split across tensor parallel ranks
            dimension is [sequence_length, batch_size, hidden_size]
        target: correct vocab ids of dimension [sequence_length, micro_batch_size]
    """
    return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target)
megatron/core/fusions/fused_layer_norm.py
View file @
4b097dee
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import
importlib
import
inspect
import
numbers
import
torch
from
torch
import
Tensor
from
torch.nn
import
init
from
torch.nn.parameter
import
Parameter
from
megatron.core.transformer
import
TransformerConfig
from
megatron.core.utils
import
make_viewless_tensor
# Apex fused layer-norm kernels are optional; record availability so the
# module can fall back when apex (or the contrib kernel) is missing.
try:
    from apex.contrib.layer_norm.layer_norm import FastLayerNormFN

    HAVE_PERSIST_LAYER_NORM = True
except ImportError:
    HAVE_PERSIST_LAYER_NORM = False

try:
    from apex.normalization.fused_layer_norm import FusedLayerNormAffineFunction

    HAVE_FUSED_LAYER_NORM = True
except ImportError:
    HAVE_FUSED_LAYER_NORM = False
class
FusedLayerNorm
(
torch
.
nn
.
Module
):
"""Layer Norm, fused into a single CUDA kernel.
Args:
hidden_size (int): Transformer hidden dimension.
eps (float): Epsilon added to denominator, for numerical stability.
persist_layer_norm (bool): Use persistent fused layer norm kernel.
This kernel supports only a set of hidden sizes. Please
check persist_ln_hidden_sizes if your hidden size is supported.
zero_centered_gamma (bool): Adjust LayerNorm weights such that they are
centered around zero. This improves numerical stability.
config (TransformerConfig): Transformer config. Include to match custom
layer norm interfaces.
normalization (str): Normalization type, used for Transformer Engine.
Must equal 'LayerNorm' here.
"""
def
__init__
(
self
,
hidden_size
,
eps
=
1e-5
,
persist_layer_norm
=
True
,
sequence_parallel
=
False
,
zero_centered_gamma
=
False
,
config
:
TransformerConfig
,
hidden_size
:
int
,
eps
:
float
=
1e-5
,
persist_layer_norm
:
bool
=
True
,
zero_centered_gamma
:
bool
=
False
,
normalization
:
str
=
"LayerNorm"
,
# included to match TE interface
):
super
().
__init__
()
self
.
zero_centered_gamma
=
zero_centered_gamma
self
.
config
=
config
self
.
zero_centered_gamma
=
self
.
config
.
layernorm_zero_centered_gamma
assert
(
self
.
config
.
normalization
==
"LayerNorm"
),
f
'(
{
self
.
config
.
normalization
}
) is not supported in FusedLayerNorm'
# List of hiddens sizes supported in the persistent layer norm kernel
# If the hidden size is not supported, fall back to the non-persistent
...
...
@@ -66,22 +96,24 @@ class FusedLayerNorm(torch.nn.Module):
49152
,
65536
,
]
persist_layer_norm
=
self
.
config
.
persist_layer_norm
if
hidden_size
not
in
persist_ln_hidden_sizes
or
not
HAVE_PERSIST_LAYER_NORM
:
persist_layer_norm
=
False
if
not
persist_layer_norm
and
not
HAVE_FUSED_LAYER_NORM
:
# TODO: Add pytorch only layer norm
raise
ValueError
(
f
'Apex must
currently
be installed to use
megatron c
or
e
.'
)
raise
ValueError
(
f
'Apex must be installed to use
FusedLayerN
or
m
.'
)
if
isinstance
(
hidden_size
,
numbers
.
Integral
):
hidden_size
=
(
hidden_size
,)
self
.
hidden_size
=
torch
.
Size
(
hidden_size
)
self
.
eps
=
eps
self
.
weight
=
Parameter
(
torch
.
Tensor
(
*
hidden_size
))
self
.
bias
=
Parameter
(
torch
.
Tensor
(
*
hidden_size
))
# Parameters need to be initialized with torch.empty rather than torch.Tensor for correct device placement with nemo2.
self
.
weight
=
Parameter
(
torch
.
empty
(
*
hidden_size
))
self
.
bias
=
Parameter
(
torch
.
empty
(
*
hidden_size
))
self
.
reset_parameters
()
self
.
persist_layer_norm
=
persist_layer_norm
self
.
sequence_parallel
=
sequence_parallel
self
.
sequence_parallel
=
self
.
config
.
sequence_parallel
# set sequence parallelism flag on weight and bias parameters
setattr
(
self
.
weight
,
'sequence_parallel'
,
self
.
sequence_parallel
)
...
...
@@ -96,11 +128,16 @@ class FusedLayerNorm(torch.nn.Module):
init
.
ones_
(
self
.
weight
)
init
.
zeros_
(
self
.
bias
)
def
forward
(
self
,
input
)
:
def
forward
(
self
,
input
:
Tensor
)
->
Tensor
:
weight
=
self
.
weight
+
1
if
self
.
zero_centered_gamma
else
self
.
weight
if
self
.
persist_layer_norm
:
if
'memory_efficient'
in
inspect
.
getfullargspec
(
FastLayerNormFN
.
forward
).
args
:
output
=
FastLayerNormFN
.
apply
(
input
,
weight
,
self
.
bias
,
self
.
eps
,
self
.
config
.
memory_efficient_layer_norm
)
else
:
output
=
FastLayerNormFN
.
apply
(
input
,
weight
,
self
.
bias
,
self
.
eps
)
# Apex's fast layer norm function outputs a 'view' tensor (i.e., has
...
...
@@ -112,7 +149,20 @@ class FusedLayerNorm(torch.nn.Module):
)
else
:
output
=
FusedLayerNormAffineFunction
.
apply
(
if
(
'memory_efficient'
in
inspect
.
getfullargspec
(
FusedLayerNormAffineFunction
.
forward
).
args
):
return
FusedLayerNormAffineFunction
.
apply
(
input
,
weight
,
self
.
bias
,
self
.
hidden_size
,
self
.
eps
,
self
.
config
.
memory_efficient_layer_norm
,
)
else
:
return
FusedLayerNormAffineFunction
.
apply
(
input
,
weight
,
self
.
bias
,
self
.
hidden_size
,
self
.
eps
)
...
...
megatron/core/fusions/fused_softmax.py
View file @
4b097dee
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
from
typing
import
Optional
import
torch
import
torch.nn
as
nn
from
megatron.core.transformer.enums
import
AttnMaskType
from
megatron.core.transformer.utils
import
get_default_causal_mask
class
ScaledUpperTriangMaskedSoftmax
(
torch
.
autograd
.
Function
):
...
...
@@ -96,7 +98,7 @@ class FusedScaleMaskSoftmax(nn.Module):
"""
fused operation: scaling + mask + softmax
Arg
ument
s:
Args:
input_in_fp16: flag to indicate if input in fp16 data format.
input_in_bf16: flag to indicate if input in bf16 data format.
attn_mask_type: attention mask type (pad or causal)
...
...
@@ -131,7 +133,12 @@ class FusedScaleMaskSoftmax(nn.Module):
assert
self
.
scale
is
None
or
softmax_in_fp32
,
"softmax should be in fp32 when scaled"
def
forward
(
self
,
input
,
mask
):
def
forward
(
self
,
input
:
torch
.
Tensor
,
mask
:
Optional
[
torch
.
Tensor
]):
"""Forward pass of softmax with masked input.
In case attn_mask_type is causal the mask is generated and None can be passed.
A user-defined mask is only needed when attn_mask_type is not causal.
"""
# [b, np, sq, sk]
assert
input
.
dim
()
==
4
...
...
@@ -186,6 +193,15 @@ class FusedScaleMaskSoftmax(nn.Module):
if
self
.
scale
is
not
None
:
input
=
input
*
self
.
scale
# Generate causal mask if not given
sq
,
sk
=
input
.
size
(
2
),
input
.
size
(
3
)
if
self
.
attn_mask_type
==
AttnMaskType
.
causal
and
mask
is
None
and
sq
>
1
:
# If sq == 1 then either KV cache is used or one-element context is passed
# so keeping mask=None in this case; subsequent code should handle it
assert
sq
==
sk
,
"causal mask is only for self attention"
mask
=
get_default_causal_mask
(
sq
)
mask_output
=
self
.
mask_func
(
input
,
mask
)
if
mask
is
not
None
else
input
probs
=
torch
.
nn
.
Softmax
(
dim
=-
1
)(
mask_output
)
...
...
megatron/core/inference/__init__.py
0 → 100644
View file @
4b097dee
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
megatron/core/inference/ammo_support/__init__.py
0 → 100644
View file @
4b097dee
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import warnings

# Deprecation shim: this package was renamed to `modelopt_support`; importing
# it warns once and re-exports from the new location (see sibling modules).
warnings.warn(
    "The 'megatron.core.inference.ammo_support' module is deprecated and will be removed in a future release. "
    "Please use megatron.core.inference.modelopt_support instead",
    DeprecationWarning,
)
megatron/core/inference/ammo_support/gpt/model_specs.py
0 → 100644
View file @
4b097dee
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from
megatron.core.inference.modelopt_support.gpt.model_specs
import
get_gpt_layer_modelopt_spec
Prev
1
…
8
9
10
11
12
13
14
15
16
…
18
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment