jerrrrry / dcu_megatron · Commits · 4e867b3c

Commit 4e867b3c, authored Aug 06, 2025 by jerrrrry

Initial commit

Changes: 327 files. Showing 20 changed files with 6740 additions and 0 deletions (+6740, -0).
Megatron-LM/megatron/core/dist_checkpointing/strategies/two_stage.py  +268  -0
Megatron-LM/megatron/core/dist_checkpointing/strategies/zarr.py  +321  -0
Megatron-LM/megatron/core/dist_checkpointing/tensor_aware_state_dict.py  +347  -0
Megatron-LM/megatron/core/dist_checkpointing/utils.py  +319  -0
Megatron-LM/megatron/core/dist_checkpointing/validation.py  +560  -0
Megatron-LM/megatron/core/distributed/README.md  +11  -0
Megatron-LM/megatron/core/distributed/__init__.py  +8  -0
Megatron-LM/megatron/core/distributed/custom_fsdp/__init__.py  +3  -0
Megatron-LM/megatron/core/distributed/custom_fsdp/fully_sharded_data_parallel.py  +749  -0
Megatron-LM/megatron/core/distributed/custom_fsdp/param_and_grad_buffer.py  +2055  -0
Megatron-LM/megatron/core/distributed/data_parallel_base.py  +96  -0
Megatron-LM/megatron/core/distributed/distributed_data_parallel.py  +520  -0
Megatron-LM/megatron/core/distributed/distributed_data_parallel_config.py  +82  -0
Megatron-LM/megatron/core/distributed/finalize_model_grads.py  +331  -0
Megatron-LM/megatron/core/distributed/param_and_grad_buffer.py  +884  -0
Megatron-LM/megatron/core/distributed/torch_fully_sharded_data_parallel.py  +141  -0
Megatron-LM/megatron/core/enums.py  +20  -0
Megatron-LM/megatron/core/export/__init__.py  +1  -0
Megatron-LM/megatron/core/export/data_type.py  +5  -0
Megatron-LM/megatron/core/export/export_config.py  +19  -0
Megatron-LM/megatron/core/dist_checkpointing/strategies/two_stage.py
new file mode 100644
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.

""" 2-stage checkpoint loading. """
import time
from collections import defaultdict
from dataclasses import dataclass
from functools import partial, wraps
from itertools import chain
from logging import getLogger
from operator import attrgetter, itemgetter
from pathlib import Path
from typing import List, Optional, Tuple, Union

import torch

from ..dict_utils import dict_list_map_inplace, map_reduce, nested_values
from ..mapping import ShardedStateDict, ShardedTensor
from .base import LoadShardedStrategy
from .tensorstore import _load_from_array, open_ts_array
from .zarr import flatten_range, load_zarr_based_sharded_metadata

_import_trigger = None

timers = defaultdict(list)

logger = getLogger(__name__)

logger.warning(
    'megatron.core.dist_checkpointing.two_stage module is deprecated'
    ' and will be removed in Megatron-Core v0.12. Please use'
    ' FullyParallelLoadStrategyWrapper to accomplish a parallelized checkpoint load.'
)


def timed(verbose=True):
    """Timing decorator."""

    def timed_dec(fn):
        name = fn.__name__

        @wraps(fn)
        def wrapped(*args, **kwargs):
            if verbose:
                logger.debug(f'{name} init')
            start = time.time()
            ret = fn(*args, **kwargs)
            took = time.time() - start
            if verbose:
                logger.debug(f'{name} took {took}s')
            timers[name].append(took)
            return ret

        return wrapped

    return timed_dec


@dataclass
class _ShardedTensorMetadata:
    global_rank: int
    sharded_tensor_no_data: ShardedTensor
    dist_group_rank: Tuple[int]  # id of distributed group
    dist_group_ranks: Tuple[int]  # id of distributed group
    data_size: Optional[int] = None  # bytes


def sharded_tensor_chunk_id(sharded_tensor: ShardedTensor):
    """Id of a sharded tensor."""
    return (sharded_tensor.key, sharded_tensor.global_offset)


class TwoStageDataParallelLoadShardedStrategy(LoadShardedStrategy):
    """Loads one checkpoint replica from storage and broadcasts to other nodes.

    This strategy loads checkpoint from storage on minimal set of nodes
    and distributes the checkpoint to other nodes with torch.distributed.
    Loading is performed with tensorstore.

    Steps:
    0. (optional) create Gloo distributed groups
    1. Exchange ShardedTensors metadata between all nodes
    2. Align needed tensors within DP groups
    3. For each globally unique tensor:
        3.a) on one of the ranks load it from storage to CPU and move to CUDA
        3.b) allocate CUDA tensor on other ranks
        3.c) broadcast within DP group
        3.d) copy tensor content to the model param location
        3.e) free tensor buffers from a) and b)

    Notes:
    1. Loading and broadcasting is done sequentially to avoid both host and device OOMs
    2. There is a lot of overlap potential between all three steps done for each tensor:
        2.a) loading from storage to numpy
        2.b) moving CPU tensors to CUDA
        2.c) broadcast
    """

    def __init__(self, data_parallel_group, cpu_transfer=True):
        super().__init__()

        self.cpu_transfer = cpu_transfer
        self.data_parallel_group_orig = data_parallel_group
        self.data_parallel_group = None if cpu_transfer else data_parallel_group
        self.dp_group_ranks = tuple(
            sorted(torch.distributed.get_process_group_ranks(data_parallel_group))
        )
        self.dp_group_rank = torch.distributed.get_rank(self.data_parallel_group_orig)
        self.global_rank = torch.distributed.get_rank()

    def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
        """Main load method."""
        self.maybe_init_gloo_group()
        all_tensors_sorted = self._build_load_plan(sharded_state_dict)
        self._exchange_loaded_tensors(all_tensors_sorted, sharded_state_dict, checkpoint_dir)
        # TODO: fix hang in summarize_load_times
        # self.summarize_load_times()
        return sharded_state_dict

    def summarize_load_times(self):
        """Summarize load times."""
        torch.distributed.barrier()
        logger.info('Checkpoint loading finished. Summary:')
        # TODO: `timers` keys are not guaranteed to be the same across ranks which causes hangs
        for key, times in sorted(timers.items()):
            times_sum = sum(times)
            max_times = torch.tensor([times_sum], device='cuda')
            avg_times = torch.tensor([times_sum], device='cuda')
            torch.distributed.all_reduce(max_times, op=torch.distributed.ReduceOp.MAX)
            torch.distributed.all_reduce(avg_times, op=torch.distributed.ReduceOp.SUM)
            avg_times /= torch.distributed.get_world_size()
            if torch.distributed.get_rank() == 0:
                logger.info(f'{key}: max {max_times[0]}, avg {avg_times[0]}')

    @timed(verbose=False)
    def load_tensor_from_storage(self, checkpoint_dir, ten_meta: _ShardedTensorMetadata):
        """Load tensor from storage."""
        logger.debug(f'_load_from_array({ten_meta.sharded_tensor_no_data.key}) init')
        ret = _load_from_array(
            ten_meta.sharded_tensor_no_data,
            checkpoint_dir,
            load_directly_on_device=False,
            apply_flattened_range=False,
        )
        logger.debug(f'_load_from_array({ten_meta.sharded_tensor_no_data.key}) DONE')
        return ret

    @timed()
    def maybe_init_gloo_group(self):
        """Create Gloo groups."""
        if not self.cpu_transfer:
            return
        all_groups = [None] * torch.distributed.get_world_size()
        torch.distributed.all_gather_object(all_groups, self.dp_group_ranks)
        all_groups = set(tuple(sorted(gr)) for gr in all_groups)
        for group_ranks in sorted(all_groups):
            # "two_stage" module will be deprecated, so not replace new_group()
            # with ...parallel_state.create_group() func setting group_desc here.
            gloo_pg = torch.distributed.new_group(ranks=group_ranks, backend='gloo')
            if self.global_rank in group_ranks:
                self.data_parallel_group = gloo_pg
                assert self.dp_group_rank == torch.distributed.get_rank(self.data_parallel_group)

    def check_backend_compatibility(self, loaded_version):
        pass  # TODO

    def check_version_compatibility(self, loaded_version):
        pass  # TODO

    @timed()
    def _build_load_plan(
        self, sharded_state_dict: ShardedStateDict
    ) -> List[_ShardedTensorMetadata]:
        local_meta = [
            _ShardedTensorMetadata(
                self.global_rank,
                sharded_ten.without_data(),
                self.dp_group_rank,
                self.dp_group_ranks,
            )
            for sharded_ten in nested_values(sharded_state_dict)
        ]
        all_meta = [None] * torch.distributed.get_world_size(group=self.data_parallel_group)
        torch.distributed.all_gather_object(all_meta, local_meta, group=self.data_parallel_group)
        all_meta = list(chain.from_iterable(all_meta))
        all_tensors_sorted = self.deduplicate_chunks(all_meta)
        return all_tensors_sorted

    @timed()
    def deduplicate_chunks(self, ten_metas: List[_ShardedTensorMetadata]):
        """Group tensors by chunk and then pick the tensor with the lowest rank.

        NOTE: with proper loading overlap, loading from randomized ranks
        (instead of the smallest one) could be beneficial here.
        """
        ten_metas = map_reduce(
            ten_metas,
            key_fn=lambda meta: sharded_tensor_chunk_id(meta.sharded_tensor_no_data),
            reduce_fn=partial(min, key=attrgetter('dist_group_rank')),
        )
        all_metas_sorted = list(map(itemgetter(1), sorted(ten_metas.items())))
        return all_metas_sorted

    @timed()
    def _exchange_loaded_tensors(
        self, ten_metas: List[_ShardedTensorMetadata], sharded_state_dict, checkpoint_dir
    ):
        logger.debug(f'_exchange_loaded_tensors, num ten_metas: {len(ten_metas)}')
        for ten_meta in ten_metas:
            src_rank = torch.distributed.get_global_rank(
                self.data_parallel_group, ten_meta.dist_group_rank
            )

            if self.dp_group_rank == ten_meta.dist_group_rank:
                exchange_tensor = self.load_tensor_from_storage(checkpoint_dir, ten_meta)
                if not self.cpu_transfer:
                    exchange_tensor = exchange_tensor.cuda()
            else:
                # TODO: for non-flattened ranges we could reuse the buffer from the start here
                exchange_tensor = torch.empty(
                    ten_meta.sharded_tensor_no_data.local_shape,
                    device='cpu' if self.cpu_transfer else 'cuda',
                    dtype=ten_meta.sharded_tensor_no_data.dtype,
                )

            logger.debug(
                f'exchange {ten_meta.sharded_tensor_no_data.key}, {exchange_tensor.shape}'
                f'({exchange_tensor.numel()}), broadcast({src_rank} -> {self.dp_group_ranks})'
            )
            torch.distributed.broadcast(
                exchange_tensor, group=self.data_parallel_group, src=src_rank
            )
            self._distribute_data_to_state_dict(ten_meta, exchange_tensor, sharded_state_dict)
            logger.debug(f'exchange {ten_meta.sharded_tensor_no_data.key} done')

            # free buffer memory
            exchange_tensor = None

    @timed(verbose=False)
    def _distribute_data_to_state_dict(
        self,
        ten_meta: _ShardedTensorMetadata,
        loaded_ten: torch.Tensor,
        sharded_state_dict: ShardedStateDict,
    ):
        tensor_key = sharded_tensor_chunk_id(ten_meta.sharded_tensor_no_data)

        def _fill_in_data(t: Union[ShardedTensor, torch.Tensor]):
            if not isinstance(t, ShardedTensor) or sharded_tensor_chunk_id(t) != tensor_key:
                # already filled-in or key not matching
                return t
            sharded_tensor: ShardedTensor = t
            x = loaded_ten
            if sharded_tensor.flattened_range is not None:
                x = flatten_range(sharded_tensor, x)

            # Reuse existing buffer
            sharded_tensor.data.data.copy_(x)
            return sharded_tensor.data

        dict_list_map_inplace(_fill_in_data, sharded_state_dict)

    def load_tensors_metadata(self, checkpoint_dir: Path):
        def get_ts_shape_dtype(path):
            arr = open_ts_array(path)
            return arr.shape, arr.dtype.numpy_dtype

        return load_zarr_based_sharded_metadata(checkpoint_dir, get_ts_shape_dtype)
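A minimal usage sketch of the strategy above (not part of the commit). It assumes torch.distributed is already initialized and that `build_sharded_state_dict()` is a hypothetical placeholder for whatever code produces the model's ShardedStateDict; only the strategy's own `load` method, defined in this file, is used.

# Sketch only: load a checkpoint on one rank per data-parallel group and broadcast it.
from pathlib import Path

from megatron.core.dist_checkpointing.strategies.two_stage import (
    TwoStageDataParallelLoadShardedStrategy,
)


def load_with_two_stage_strategy(checkpoint_dir: str, data_parallel_group):
    sharded_state_dict = build_sharded_state_dict()  # hypothetical helper
    # cpu_transfer=True broadcasts CPU tensors over a Gloo group created by the strategy;
    # cpu_transfer=False broadcasts CUDA tensors directly over the given DP group.
    strategy = TwoStageDataParallelLoadShardedStrategy(data_parallel_group, cpu_transfer=True)
    return strategy.load(sharded_state_dict, Path(checkpoint_dir))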
Megatron-LM/megatron/core/dist_checkpointing/strategies/zarr.py
new file mode 100644
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.

""" Strategies using Zarr as an underlying format. """
import logging
import os
from functools import partial
from logging import getLogger
from pathlib import Path
from typing import Callable, List, Optional, Tuple

import numpy as np
import torch
import zarr

from ..core import CheckpointingException
from ..dict_utils import dict_list_map_inplace, nested_values
from ..mapping import ShardedStateDict, ShardedTensor, is_main_replica
from .base import (
    LoadShardedStrategy,
    SaveShardedStrategy,
    StrategyAction,
    register_default_strategy,
)

logger = logging.getLogger(__name__)

numpy_to_torch_dtype_dict = {
    np.dtype('bool'): torch.bool,
    np.dtype('uint8'): torch.uint8,
    np.dtype('int8'): torch.int8,
    np.dtype('int16'): torch.int16,
    np.dtype('int32'): torch.int32,
    np.dtype('int64'): torch.int64,
    np.dtype('float16'): torch.float16,
    np.dtype('float32'): torch.float32,
    np.dtype('float64'): torch.float64,
    np.dtype('complex64'): torch.complex64,
    np.dtype('complex128'): torch.complex128,
}

torch_to_numpy_dtype_dict = {v: k for k, v in numpy_to_torch_dtype_dict.items()}


try:
    # Register a bfloat16 type with this import
    import tensorstore  # pylint: disable=unused-import

    HAS_BFLOAT16 = True
    numpy_to_torch_dtype_dict[np.dtype('bfloat16')] = torch.bfloat16
    torch_to_numpy_dtype_dict[torch.bfloat16] = np.dtype('bfloat16')
except ImportError:
    HAS_BFLOAT16 = False

logger = getLogger(__name__)


def register_default_zarr_strategies():
    """Register default strategies related to Zarr backend."""
    register_default_strategy(
        StrategyAction.SAVE_SHARDED, 'zarr', 1, ZarrSaveShardedStrategy('zarr', 1)
    )


class ZarrSaveShardedStrategy(SaveShardedStrategy):
    """Save strategy for Zarr backend."""

    def __init__(self, backend: str, version: int):
        super().__init__(backend, version)
        logger.warning(
            f'`zarr` distributed checkpoint backend is deprecated.'
            ' Please switch to PyTorch Distributed format (`torch_dist`).'
        )

    def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
        sharded_tensors = list(nested_values(sharded_state_dict))
        arrays = _create_or_open_zarr_arrays(sharded_tensors, checkpoint_dir)
        for ten, arr in zip(sharded_tensors, arrays):
            _save_to_existing_array(ten, arr)
        torch.distributed.barrier()


def _create_or_open_zarr_arrays(
    sharded_tensors: List[ShardedTensor], checkpoint_dir: Path
) -> List[Optional[zarr.Array]]:
    """Returns list of zarr arrays corresponding to given tensors.

    For a sharded tensors that:
    a) is main replica and represents the first chunk (all offsets 0), creates the Zarr array
    b) is main replica but not the first chunk,
       opens the arrays created in (a) (possibly by other process)
    c) otherwise, sets the corresponding array to None since it won't be used

    Args:
        sharded_tensors (List[ShardedTensor]): sharded tensors from a given rank
            that will be saved to checkpoint
        checkpoint_dir (Path): checkpoint in which the arrays will be created
    """
    arrays = []
    for ten in sharded_tensors:
        arr = _create_zarr_array(ten, checkpoint_dir) if _should_create_array(ten) else None
        arrays.append(arr)

    torch.distributed.barrier()
    # Open arrays created above by other processes
    for arr_idx, ten in enumerate(sharded_tensors):
        if arrays[arr_idx] is not None:
            # array created by this process
            assert _should_create_array(ten), ten
            continue
        if not is_main_replica(ten.replica_id):
            # this array won't be needed for saving and can stay None
            continue
        open_kwargs = {}
        if ten.flattened_range is not None:
            open_kwargs['synchronizer'] = zarr.ProcessSynchronizer(
                str(checkpoint_dir / f'{ten.key}.sync')
            )
        arrays[arr_idx] = _open_zarr_array_verbose(checkpoint_dir / ten.key, 'r+', **open_kwargs)
    return arrays


def _should_create_array(ten: ShardedTensor):
    return (
        is_main_replica(ten.replica_id)
        and set(ten.global_offset) == {0}
        and (ten.flattened_range is None or ten.flattened_range.start == 0)
    )


def _save_to_existing_array(sharded_tensor: ShardedTensor, arr: Optional[zarr.Array]):
    if not is_main_replica(sharded_tensor.replica_id):
        return
    assert arr is not None
    x = sharded_tensor.data
    x = x.detach().cpu()
    torch.cuda.synchronize()
    if x.dtype == torch.bfloat16:
        x = x.float()
        x = x.numpy()
        x = x.astype('bfloat16')
    else:
        x = x.numpy()

    if sharded_tensor.flattened_range is None:
        arr[sharded_tensor.global_slice()] = x
    else:
        arr.set_coordinate_selection(sharded_tensor.global_coordinates(), x)


def _create_zarr_array(sharded_tensor: ShardedTensor, checkpoint_dir: Path):
    np_dtype = torch_to_numpy_dtype_dict[sharded_tensor.dtype]
    try:
        arr = zarr.create(
            sharded_tensor.global_shape,
            dtype=np_dtype,
            store=checkpoint_dir / sharded_tensor.key,
            chunks=sharded_tensor.max_allowed_chunks(),
            compressor=None,
            fill_value=None,
            write_empty_chunks=True,
        )
        logger.debug(f'Created a new Zarr array at {checkpoint_dir / sharded_tensor.key}')
    except zarr.errors.ContainsArrayError as e:
        raise CheckpointingException(
            f'Array {checkpoint_dir / sharded_tensor.key} already exists'
        ) from e

    if HAS_BFLOAT16 and np_dtype == np.dtype('bfloat16'):
        arr._dtype = np_dtype
        zarray = arr.store['.zarray']
        arr.store['.zarray'] = zarray.replace(b'<V2', b'bfloat16')
    return arr


class ZarrLoadShardedStrategy(LoadShardedStrategy):
    """Load strategy for the Zarr backend."""

    def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
        dict_list_map_inplace(
            partial(_load_from_array, checkpoint_dir=checkpoint_dir), sharded_state_dict
        )
        return sharded_state_dict

    def load_tensors_metadata(self, checkpoint_dir: Path):
        def get_zarr_shape_dtype(path):
            arr = zarr.open(path, 'r')
            return arr.shape, arr.dtype

        return load_zarr_based_sharded_metadata(checkpoint_dir, get_zarr_shape_dtype)

    def check_backend_compatibility(self, loaded_version):
        pass  # TODO

    def check_version_compatibility(self, loaded_version):
        pass  # TODO


def _load_from_array(sharded_tensor: ShardedTensor, checkpoint_dir: Path):
    assert isinstance(sharded_tensor, ShardedTensor), type(sharded_tensor)
    arr = _open_zarr_array_verbose(checkpoint_dir / sharded_tensor.key, 'r')

    if not sharded_tensor.allow_shape_mismatch and sharded_tensor.global_shape != arr.shape:
        _msg = (
            f'Global shape mismatch for loaded ({arr.shape})'
            f' and expected ({sharded_tensor.global_shape}) tensor'
            f' for key {sharded_tensor.key}'
        )
        raise CheckpointingException(_msg)

    x = arr[sharded_tensor.global_slice()]
    # flattened tensors loading is delayed
    return postprocess_numpy_array(x, sharded_tensor)


def _open_zarr_array_verbose(path: Path, mode: str, **open_kwargs):
    try:
        return zarr.open(str(path), mode, **open_kwargs)
    except zarr.errors.PathNotFoundError as e:
        ckpt_dir = path.parent
        err_msg = f'Array {path} not found'
        if ckpt_dir.exists():
            ckpt_files = [f.name for f in ckpt_dir.iterdir()]
            logger.debug(f'{err_msg}. Checkpoint directory {ckpt_dir} content: {ckpt_files}')
        else:
            err_msg += f'. Checkpoint directory {ckpt_dir} does not exist.'
        raise CheckpointingException(err_msg) from e


def postprocess_numpy_array(loaded_array, sharded_tensor, apply_flattened_range=True):
    """Turn numpy array to torch tensor."""
    x = loaded_array
    if HAS_BFLOAT16 and x.dtype == np.dtype('bfloat16'):
        x = x.astype(np.dtype('float32'))
        x = torch.from_numpy(x)
        x = x.bfloat16()
    else:
        x = torch.from_numpy(x)

    # TODO: consider some other consistency checks
    if x.shape != sharded_tensor.local_shape:
        if sharded_tensor.allow_shape_mismatch:
            x = pad_to_expected_shape(x, sharded_tensor)
        else:
            _msg = (
                f'Local shape mismatch for loaded ({x.shape})'
                f' and expected ({sharded_tensor.local_shape}) tensor'
                f' for key {sharded_tensor.key}'
            )
            raise CheckpointingException(_msg)

    if apply_flattened_range and sharded_tensor.flattened_range is not None:
        x = flatten_range(sharded_tensor, x)

    # TODO: consider cuda() tensors support
    return x


def flatten_range(sharded_tensor, x):
    """Apply flattened range to a tensor."""
    return x.flatten()[sharded_tensor.flattened_range]


def pad_to_expected_shape(x: torch.Tensor, expected_sharded_ten: ShardedTensor):
    """Pad tensor to the expected shape."""
    pad_args = []
    assert len(x.shape) == len(expected_sharded_ten.local_shape)
    # Reversed iteration order because F.pad expects so
    for x_sh, exp_sh, axis_fragm in reversed(
        list(
            zip(
                x.shape,
                expected_sharded_ten.local_shape,
                expected_sharded_ten.axis_fragmentations,
            )
        )
    ):
        if x_sh == exp_sh:
            pad_args.extend((0, 0))
        elif x_sh > exp_sh:
            assert False, (
                f'Expected shape ({exp_sh}) smaller than actual ({x_sh})'
                f' for {repr(expected_sharded_ten)}'
            )
        else:
            pad_args.extend((0, exp_sh - x_sh))
    # TODO: behavior control with envvar is for testing purposes only, remove it
    if not int(os.environ.get('DIST_CKPT_PAD_REPLICATE', 0)):
        return torch.nn.functional.pad(x, pad_args)

    # unsqueeze and squeeze to get shapes supported by cudnn
    print(f'Replicating last row for {expected_sharded_ten.key}')
    if x.dtype == torch.bfloat16:
        return (
            torch.nn.functional.pad(x.float().unsqueeze(0), pad_args, mode='replicate')
            .squeeze(0)
            .bfloat16()
        )
    return torch.nn.functional.pad(x.unsqueeze(0), pad_args, mode='replicate').squeeze(0)


def load_zarr_based_sharded_metadata(
    checkpoint_dir: Path, get_shape_dtype_fn: Callable[[str], Tuple[Tuple[int], np.dtype]]
) -> ShardedStateDict:
    """Load metadata of Zarr arrays.

    Args:
        checkpoint_dir (str): checkpoint root directory
        get_shape_dtype_fn (str -> ((int, ...), np.dtype)): a function returning
            an array shape and dtype for a given Zarr array path
    """
    sharded_state_dict = {}
    for subdir in checkpoint_dir.iterdir():
        if not subdir.is_dir() or not (subdir / '.zarray').exists():
            continue
        key = subdir.name
        arr_shape, arr_dtype = get_shape_dtype_fn(str(subdir))

        sharded_state_dict[key] = ShardedTensor(
            key,
            None,
            numpy_to_torch_dtype_dict[arr_dtype],
            arr_shape,
            arr_shape,
            tuple(0 for _ in arr_shape),
            tuple(1 for _ in arr_shape),
        )
    return sharded_state_dict
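A small illustration of the dtype maps defined at the top of this file (not part of the commit). It mirrors the save/load round trip performed by `_save_to_existing_array` and `postprocess_numpy_array`, using only names defined in the module above.

# Sketch only: torch <-> numpy dtype round trip used by the Zarr strategies.
import torch

from megatron.core.dist_checkpointing.strategies.zarr import (
    numpy_to_torch_dtype_dict,
    torch_to_numpy_dtype_dict,
)

t = torch.ones(2, 3, dtype=torch.float16)
np_dtype = torch_to_numpy_dtype_dict[t.dtype]      # np.dtype('float16')
arr = t.detach().cpu().numpy().astype(np_dtype)    # what gets written to the Zarr store
t_back = torch.from_numpy(arr)                     # what the load path reconstructs
assert numpy_to_torch_dtype_dict[arr.dtype] == t_back.dtype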
Megatron-LM/megatron/core/dist_checkpointing/tensor_aware_state_dict.py
new file mode 100644
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.

""" Utilities for transforming state_dict, including a tensor-aware implementation."""
import logging
from dataclasses import dataclass
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple

import torch
from nvidia_resiliency_ext.checkpointing.local.base_state_dict import TensorAwareStateDict

from .dict_utils import dict_list_map_inplace, dict_list_map_outplace, merge, nested_values
from .exchange_utils import (
    ShardDistribution,
    determine_main_replica_uniform_distribution,
    exchange_by_distribution,
)
from .mapping import (
    ShardedObject,
    ShardedStateDict,
    ShardedTensor,
    StateDict,
    apply_factory_merges,
)
from .state_dict_utils import load_preprocess, save_preprocess
from .utils import (
    _sharded_object_id,
    _sharded_tensor_shard_id,
    debug_time,
    extract_sharded_base,
    zip_strict,
)
from .validation import determine_global_metadata, validate_sharding_integrity

logger = logging.getLogger(__name__)


@dataclass
class MCoreTensorAwareStateDict(TensorAwareStateDict):
    """
    MCore-specific class defining the interface between the MCore state dict and checkpoint manager.

    This class distinguishes between raw objects, the common state dict, and sharded state dicts
    (tensor parts). It also handles optional metadata needed for fully parallel save/load.
    """

    common: StateDict
    sharded_state_dict: ShardedStateDict
    _is_hollow: bool = False

    @staticmethod
    def _validate_params(algo):
        if algo != 'atomic' and algo != 'fully_parallel':
            raise NotImplementedError(
                'Only "atomic" and "fully_parallel" sharding algorithms are supported.'
            )

    @staticmethod
    def _get_distribution(
        fully_parallel, sharded_part, parallelization_group, cached_distribution=None
    ):
        if fully_parallel:
            if cached_distribution is None:
                distribution = determine_main_replica_uniform_distribution(
                    sharded_part, parallelization_group, True
                )
                logger.debug(f'MCore_TASD._get_distribution calculated distribution')
            else:
                distribution = cached_distribution
                logger.debug(f'MCore_TASD._get_distribution used cache')
        else:
            distribution = (None, None, None, None)
            logger.debug(f'MCore_TASD._get_distribution returned empty distribution')
        return distribution

    @staticmethod
    def _remove_redundant_data(
        fully_parallel, sharded_part, shard_to_saving_rank, parallelization_group
    ):
        if fully_parallel:
            for sh_base in nested_values(sharded_part):
                # TODO remove redundant objects as well
                if isinstance(sh_base, ShardedTensor):
                    shard_id = _sharded_tensor_shard_id(sh_base)
                    if shard_to_saving_rank[shard_id] != torch.distributed.get_rank(
                        group=parallelization_group
                    ):
                        sh_base.data = None

    @classmethod
    @debug_time("from_state_dict", logger)
    def from_state_dict(
        cls,
        sharded_state_dict: ShardedStateDict,
        algo: str = 'fully_parallel',
        parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
        cached_metadata: ShardDistribution = None,
    ) -> Tuple[TensorAwareStateDict, ShardDistribution]:
        """
        Constructs a TensorAwareStateDict from a sharded state dictionary.

        This method preprocesses the input `sharded_state_dict`, validates parameters,
        and extracts the necessary data to create an instance of `MCoreTensorAwareStateDict`.

        Args:
            sharded_state_dict: The input sharded state dictionary to be converted.
            algo (str, optional): Initialization algorithm. Defaults to 'fully_parallel'.
                - 'fully_parallel' enables fully parallel initialization.
            parallelization_group (Optional): A distributed process group for parallelization.
            cached_metadata (Optional): Precomputed metadata from previous saves.
                - Reuses data that doesn't need recalculation, optimizing the creation process.

        Returns:
            TensorAwareStateDict: An instance initialized with the provided sharded state dictionary
            and optional cached metadata.
                - The metadata is stored in memory to speed up future saves.
        """
        with debug_time("_get_distribution", logger):
            cls._validate_params(algo)
            fully_parallel = algo == 'fully_parallel'
            sharded_part, common_state_dict = save_preprocess(
                sharded_state_dict, cached_metadata is None
            )
            cacheable_distribution = cls._get_distribution(
                fully_parallel, sharded_part, parallelization_group, cached_metadata
            )
            if cacheable_distribution is not None:
                shard_to_saving_rank, _, _, _ = cacheable_distribution
                cls._remove_redundant_data(
                    fully_parallel, sharded_part, shard_to_saving_rank, parallelization_group
                )
        return (
            MCoreTensorAwareStateDict(common=common_state_dict, sharded_state_dict=sharded_part),
            cacheable_distribution,
        )

    @property
    def is_hollow(self):
        """
        True iff tensors had been extracted and have not been inserted back yet.
        """
        return self._is_hollow

    @property
    def _sharded_tensors(self):
        # Three possible states for sharded_tensor:
        # 1. sharded_tensor with data (.data = tensor)
        # 2. sharded_tensor hollow (.data = None, .orig_device = orig_device)
        # 3. removed sharded_tensor (.data = None, no device information)
        # TODO: Consider simplifying by removing the entire sharded_tensor instead of just the data
        if self.is_hollow:
            for sh_base in nested_values(self.sharded_state_dict):
                # FIXME: Hacky way to store the original device of the popped tensor
                if isinstance(sh_base, ShardedTensor) and hasattr(sh_base, 'orig_device'):
                    yield sh_base
        else:
            for sh_base in nested_values(self.sharded_state_dict):
                if isinstance(sh_base, ShardedTensor) and sh_base.data is not None:
                    yield sh_base

    @property
    def tensors(self) -> Iterator[torch.Tensor]:
        """
        Get the tensor data from the state dict.
        """
        assert not self.is_hollow  # TODO raise exception
        return map(lambda sh_ten: sh_ten.data, self._sharded_tensors)

    @property
    def common_state_dict(self) -> Dict:
        """
        Get the common state dict from the state dict.
        """
        return self.common

    def pop_tensors(self) -> List[torch.Tensor]:
        """
        Extracts the tensor data from the wrapped state dict, preserving metadata.

        Replaces the tensor data in sharded_tensors with device type of extracted tensors.
        After this operation, the state dictionary is "hollow", containing no tensor data.
        Further calls to `pop_tensor` will raise an error.

        @return List of extracted tensors
        """
        assert not self.is_hollow  # TODO raise exception
        result = []
        for sh_ten in self._sharded_tensors:
            result.append(sh_ten.data)
            # FIXME: Hacky way to store the original device, which is not included in the metadata
            setattr(sh_ten, 'orig_device', sh_ten.data.device.type)
            sh_ten.data = None
        self._is_hollow = True
        return result

    def insert_tensors(self, tensor_data: Iterable[torch.Tensor]):
        """
        Reverse of `pop_tensors`. Replaces device type in sharded_tensors with actual values.
        Value of `self` is considered to be the same after:
        ```
        self.insert_tensors(self.pop_tensors())
        ```
        """
        assert self.is_hollow  # TODO raise exception
        for sh_ten, ten in zip_strict(self._sharded_tensors, tensor_data):
            # FIXME: Hacky way to store the original device
            if sh_ten.orig_device == ten.device.type:
                delattr(sh_ten, 'orig_device')
            # Tensor might be on non-original device
            sh_ten.data = ten
        self._is_hollow = False

    def init_tensors(self):
        """
        Initializes empty tensors with the same properties as the original tensors.

        This function should only be called after the original tensors have been popped.
        It ensures that the newly created empty tensors match the shape,
        dtype, and device of the originals, but contain no data.
        """
        assert self.is_hollow  # TODO raise exception
        for sh_ten in self._sharded_tensors:
            # Hacky way to retrieve the original device
            sh_ten.init_data(sh_ten.orig_device)
            delattr(sh_ten, 'orig_device')
        self._is_hollow = False

    def copy_tensors_to_cpu(self, non_blocking=False):
        """
        Stores CPU copies of tensors in the state_dict, replacing the originals,
        but without destroying them.
        The original devices are remembered for restoration with restore_tensor_device().
        Using non_blocking=True allows for asynchronous copying.
        """
        assert not self.is_hollow  # TODO raise exception
        for sh_ten in self._sharded_tensors:
            if sh_ten.data.device.type == 'cpu':
                # Skip cloning if it's already confirmed to be a copy
                if not hasattr(sh_ten, 'orig_device'):
                    sh_ten.data = sh_ten.data.clone()
            else:
                # FIXME: Hacky way to store the original device
                if not hasattr(sh_ten, 'orig_device'):
                    setattr(sh_ten, 'orig_device', sh_ten.data.device.type)
                sh_ten.data = sh_ten.data.detach().to("cpu", non_blocking=non_blocking)

    def restore_tensor_device(self, non_blocking=True):
        """
        Restores all tensors to their original devices, if a move is required.
        Using non_blocking=True allows for asynchronous copying.
        """
        assert not self.is_hollow  # TODO raise exception
        for sh_ten in self._sharded_tensors:
            # FIXME: Hacky way to store the original device
            if hasattr(sh_ten, 'orig_device'):
                sh_ten.data = sh_ten.data.to(sh_ten.orig_device, non_blocking=non_blocking)
                delattr(sh_ten, 'orig_device')

    def _insert_sharded_data(
        self, fully_parallel, sharded_part, parallelization_group, exchange_algo
    ):
        loaded_tensors = {}
        for sh_ten in self._sharded_tensors:
            loaded_tensors[_sharded_tensor_shard_id(sh_ten)] = sh_ten.data
        if fully_parallel:
            with debug_time("_get_distribution", logger):
                distribution = self._get_distribution(
                    fully_parallel, sharded_part, parallelization_group
                )
            if distribution is not None:
                unloaded_shards = {}
                for sh_base in nested_values(sharded_part):
                    # TODO retrieve redundant ShardedObjects once removed in _remove_redundant_data
                    if isinstance(sh_base, ShardedTensor):
                        shard_id = _sharded_tensor_shard_id(sh_base)
                        if shard_id not in loaded_tensors:
                            unloaded_shards[shard_id] = sh_base

                with debug_time("exchange_by_distribution", logger):
                    loaded_tensors = exchange_by_distribution(
                        loaded_tensors,
                        unloaded_shards,
                        distribution,
                        parallelization_group,
                        exchange_algo,
                    )
                    torch.cuda.synchronize()

        loaded_objects = {}
        for sh_base in nested_values(self.sharded_state_dict):
            if not isinstance(sh_base, ShardedTensor):
                assert isinstance(sh_base, ShardedObject)
                loaded_objects[_sharded_object_id(sh_base)] = sh_base.data

        def load_sharded_base(x: Any):
            if isinstance(x, ShardedTensor):
                shard_id = _sharded_tensor_shard_id(x)
                assert shard_id in loaded_tensors, (x, shard_id, loaded_tensors.keys())
                x = loaded_tensors[shard_id]
            if isinstance(x, ShardedObject):
                object_id = _sharded_object_id(x)
                assert object_id in loaded_objects, (x, object_id, loaded_objects.keys())
                x = loaded_objects[object_id]
            return x

        dict_list_map_inplace(load_sharded_base, sharded_part)

    @debug_time("to_state_dict", logger)
    def to_state_dict(
        self,
        sharded_state_dict: ShardedStateDict,
        algo: str = 'atomic',
        exchange_algo: str = 'broadcast',
        validate_access_integrity: bool = True,
        parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
    ):
        """
        Convert tensor-aware dict back to the original state_dict
        """
        with debug_time("load_preprocess_and_state_dict_manipulations", logger):
            assert not self.is_hollow  # TODO raise exception
            self._validate_params(algo)
            fully_parallel = algo == 'fully_parallel'

            # __adding__ common part
            recreated_state_dict = dict_list_map_outplace(lambda x: x, self.common)

            if not sharded_state_dict:
                return recreated_state_dict
            # TODO validate self.sharded_state_dict"] and sharded_state_dict are compatible
            sharded_state_dict, nonpersistent_state_dict, sh_ten_factories = load_preprocess(
                sharded_state_dict
            )
            # __adding__ nonpersistent part
            merge(recreated_state_dict, nonpersistent_state_dict)

            sharded_part, _ = extract_sharded_base(sharded_state_dict)

        if validate_access_integrity:
            with debug_time("validate_sharding_integrity", logger):
                validate_sharding_integrity(determine_global_metadata(sharded_part)[1])

        # load sharded tensors and sharded objects to sharded_part
        with debug_time("_insert_sharded_data", logger):
            self._insert_sharded_data(
                fully_parallel, sharded_part, parallelization_group, exchange_algo
            )
        with debug_time("apply_factory_merges", logger):
            sharded_part = apply_factory_merges(sharded_part, sh_ten_factories)
            # __adding__ sharded_part
            merge(recreated_state_dict, sharded_part)

        return recreated_state_dict
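A sketch of the hollow/insert round trip that the docstrings above describe (not part of the commit). It assumes torch.distributed is initialized and treats `sharded_state_dict` as a placeholder for an already-built ShardedStateDict.

# Sketch only: pop_tensors() makes the wrapper "hollow"; insert_tensors() reverses it.
from megatron.core.dist_checkpointing.tensor_aware_state_dict import MCoreTensorAwareStateDict

ta_sd, cacheable_metadata = MCoreTensorAwareStateDict.from_state_dict(
    sharded_state_dict, algo='fully_parallel'  # sharded_state_dict: placeholder
)
tensors = ta_sd.pop_tensors()   # metadata is preserved, tensor data is extracted
assert ta_sd.is_hollow
ta_sd.insert_tensors(tensors)   # value is considered unchanged after this round trip
assert not ta_sd.is_hollow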
Megatron-LM/megatron/core/dist_checkpointing/utils.py
new file mode 100644
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.

""" Helpers for manipulating sharded tensors and sharded state dicts. """
import logging
from contextlib import contextmanager
from time import time
from typing import Dict, Optional, Tuple

from .dict_utils import dict_list_map_inplace, extract_matching_values
from .mapping import (
    LocalNonpersistentObject,
    ShardedBase,
    ShardedObject,
    ShardedStateDict,
    ShardedTensor,
    ShardedTensorFactory,
    StateDict,
)

# _ShardId uniquely identifies a ShardedTensor. This is a subset of ShardedTensor
# attributes: key (str), global_offset (tuple) and flattened_range (optional tuple)
_ShardId = Tuple[str, tuple, Optional[tuple]]


def zip_strict(*args):
    """
    Alternative to Python's builtin zip(..., strict=True) (available in 3.10+).
    Apart from providing functionality in earlier versions of Python is also more verbose.
    (Python's zip does not print lengths, only which iterable has finished earlier)
    """
    args = [list(a) for a in args]
    lens = [len(a) for a in args]
    assert len(set(lens)) <= 1, f"Tried to zip iterables of unequal lengths: {lens}!"
    return zip(*args)


def _sharded_tensor_shard_id(sharded_tensor: ShardedTensor) -> _ShardId:
    """Unique id of the sharded tensor data.

    Should yield the same value for same data replicated on different ranks.

    Args:
        sharded_tensor (ShardedTensor): sharded tensor representing the data shard

    Returns (tuple): unique id of a data shard
    """
    f_range = sharded_tensor.flattened_range
    return (
        sharded_tensor.key,
        sharded_tensor.global_offset,
        None if f_range is None else (f_range.start, f_range.stop),
    )


def _sharded_object_id(sharded_object: ShardedObject) -> _ShardId:
    """Unique id of the sharded object data.

    Should yield the same value for same data replicated on different ranks.

    Args:
        sharded_object (ShardedObject): sharded object representing the data shard

    Returns (tuple): unique id of a data shard
    """
    return (sharded_object.key, sharded_object.global_offset, sharded_object.global_shape)


def extract_sharded_tensors(
    sharded_state_dict: ShardedStateDict,
) -> Tuple[ShardedStateDict, StateDict]:
    """Extract a dict consisting of only ShardedTensor objects
    from a given state dict with any objects.

    Args:
        sharded_state_dict: state dict possibly containing ShardedTensor objects

    Returns:
        Tuple[ShardedStateDict, StateDict]: tuple of:
            - state dict with all ShardedTensor (keeping the original state dict structure)
            - state dict with all objects other than ShardedTensor
              (keeping the original state dict structure)
    """
    return extract_matching_values(sharded_state_dict, lambda v: isinstance(v, ShardedTensor))


def extract_sharded_tensors_and_factories(
    sharded_state_dict: ShardedStateDict,
) -> Tuple[ShardedStateDict, StateDict]:
    """Extract a dict consisting of only ShardedTensor and ShardedTensorFactory objects
    from a given state dict with any objects.

    Args:
        sharded_state_dict:
            state dict possibly containing ShardedTensor and ShardedTensorFactory objects

    Returns:
        Tuple[ShardedStateDict, StateDict]: tuple of:
            - state dict with all ShardedTensor and ShardedTensorFactory
              (keeping the original state dict structure)
            - state dict with all other objects (keeping the original state dict structure)
    """
    return extract_matching_values(
        sharded_state_dict, lambda v: isinstance(v, (ShardedTensor, ShardedTensorFactory))
    )


def extract_sharded_tensors_or_nonpersistent(
    sharded_state_dict: ShardedStateDict,
) -> Tuple[ShardedStateDict, StateDict]:
    """Extract a dict consisting of only ShardedTensor, ShardedTensorFactory
    and LocalNonpersistentObject objects from a given state dict with any objects.

    Args:
        sharded_state_dict: state dict possibly containing ShardedTensor, ShardedTensorFactory
            and LocalNonpersistentObject objects

    Returns:
        Tuple[ShardedStateDict, StateDict]: tuple of:
            - state dict with all ShardedTensor, ShardedTensorFactory and LocalNonpersistentObject
              (keeping the original state dict structure)
            - state dict with all other objects (keeping the original state dict structure)
    """
    return extract_matching_values(
        sharded_state_dict,
        lambda v: isinstance(v, (ShardedTensor, LocalNonpersistentObject, ShardedTensorFactory)),
    )


def extract_sharded_base(
    sharded_state_dict: ShardedStateDict,
) -> Tuple[ShardedStateDict, StateDict]:
    """Extract a dict consisting of only ShardedBase from a given state dict with any objects.

    Args:
        sharded_state_dict: state dict possibly containing ShardedBase objects

    Returns:
        Tuple[ShardedStateDict, StateDict]: tuple of:
            - state dict with all ShardedBase objects (keeping the original state dict structure)
            - state dict with all other objects (keeping the original state dict structure)
    """
    return extract_matching_values(sharded_state_dict, lambda v: isinstance(v, ShardedBase))


def extract_nonpersistent(
    sharded_state_dict: ShardedStateDict,
) -> Tuple[ShardedStateDict, StateDict]:
    """Extract a dict consisting of only LocalNonpersistentObjects from a given state dict.

    Args:
        sharded_state_dict: state dict possibly containing LocalNonpersistentObjects

    Returns:
        Tuple[ShardedStateDict, StateDict]: tuple of:
            - state dict with all LocalNonpersistentObjects
              (keeping the original state dict structure)
            - state dict with all other objects (keeping the original state dict structure)
    """
    return extract_matching_values(
        sharded_state_dict, lambda v: isinstance(v, LocalNonpersistentObject)
    )


def add_prefix_for_sharding(sharded_state_dict: ShardedStateDict, prefix: str):
    """Prepend a given prefix to all ShardedBase objects in a given state dict *in-place*.

    Args:
        sharded_state_dict (ShardedStateDict): sharded state dict
        prefix (str): prefix to be prepended

    Returns:
        None: state dict is modified in-place
    """

    def add_prefix(t):
        if isinstance(t, ShardedBase):
            t.key = f'{prefix}{t.key}'
        return t

    dict_list_map_inplace(add_prefix, sharded_state_dict)


def replace_prefix_for_sharding(
    sharded_state_dict: ShardedStateDict, old_prefix: str, new_prefix: str
):
    """Replaces the given prefix in *all* sharded keys in a given state dict.

    Errors out if some key does not begin with a given prefix.

    Args:
        sharded_state_dict (ShardedStateDict): sharded state dict to replace keys in
        old_prefix (str): prefix to be replaced in each key
        new_prefix (str): new prefix

    Returns:
        None: state dict is modified in place
    """

    def _replace_prefix(x):
        if isinstance(x, (ShardedTensor, ShardedTensorFactory, ShardedObject)):
            if not x.key.startswith(old_prefix):
                raise ValueError(f'Expected {x.key} to begin with prefix {old_prefix}')
            x.key = f'{new_prefix}{x.key[len(old_prefix):]}'  # str.removeprefix in Python >= 3.9
        return x

    dict_list_map_inplace(_replace_prefix, sharded_state_dict)


def apply_prefix_mapping(sharded_state_dict: ShardedStateDict, prefix_map: Dict[str, str]):
    """Replaces prefixes *only in keys matching* with one of prefixes in the map.

    Args:
        sharded_state_dict (ShardedStateDict): sharded state dict to replace keys in
        prefix_map (Dict[str, str]):
            map of old->new prefixes. The first matching prefix for each key is used

    Returns:
        None: state dict is modified in place
    """

    def _replace_prefixes(x):
        if not isinstance(x, (ShardedTensor, ShardedTensorFactory, ShardedObject)):
            return x
        for old_prefix, new_prefix in prefix_map.items():
            if x.key.startswith(old_prefix):
                x.key = (
                    f'{new_prefix}{x.key[len(old_prefix):]}'  # str.removeprefix in Python >= 3.9
                )
                break
        return x

    dict_list_map_inplace(_replace_prefixes, sharded_state_dict)


fallback_logger = logging.getLogger(__name__)
__LOGGER_NAME_STACK = []
__LOGGER_STACK = []


@contextmanager
def logger_stack(name: Optional[str] = None, current_logger: Optional[logging.Logger] = None):
    """Context manager for managing logger and name stack.

    Temporarily pushes a logger and/or name onto their respective stacks, allowing hierarchical
    logging and contextual logger usage. Ensures the logger stack is restored afterward.

    Args:
        name (str, optional): Name to add to the logger stack. Defaults to None.
        current_logger (logging.Logger, optional): Logger to use. Defaults to the last logger in
            the stack or a fallback if none exist.

    Yields:
        Tuple[str, logging.Logger]: A tuple with the concatenated logger name stack and
            the current logger for the block.

    Example:
        with logger_stack("scope", logger):
            logger.info("Log within 'scope'")
    """
    if name:
        __LOGGER_NAME_STACK.append(name)
    if current_logger:
        __LOGGER_STACK.append(current_logger)
        last_logger = current_logger
    elif __LOGGER_STACK:
        last_logger = __LOGGER_STACK[-1]
    else:
        last_logger = fallback_logger
    try:
        yield ".".join(__LOGGER_NAME_STACK), last_logger
    finally:
        if name and __LOGGER_NAME_STACK:
            __LOGGER_NAME_STACK.pop(-1)
        if current_logger and __LOGGER_STACK:
            __LOGGER_STACK.pop(-1)


@contextmanager
def debug_time(
    name: str,
    logger: Optional[logging.Logger] = None,
    threshold: float = float("-inf"),
    level=None,
):
    """Simple context manager for timing functions/code blocks.

    Args:
        name (str): Label describing the code being measured.
        logger (logging.Logger, optional): Logger for output. Defaults to the lowest logger.
        threshold (float, optional): Minimum time (seconds) to log. Skips logging if faster.
        level (int, optional): Logging level. Defaults to DEBUG if `threshold` is unset;
            WARNING otherwise.
    """
    with logger_stack(name, logger) as (stacked_name, last_logger):
        start = time()
        try:
            yield
        finally:
            result = time() - start
            if result < threshold:
                return
            if level is None:
                level = logging.DEBUG if threshold == float("-inf") else logging.WARNING
            last_logger.log(level, f"{stacked_name} took {result:.4f}s")


def debug_msg(msg: str):
    """Logs a debug message using the current logger stack.

    This function formats and logs a debug message with the current logger
    and name stack, preserving context from the logger_stack context manager.

    Args:
        msg (str): The message to be logged at the debug level.

    Example:
        debug_msg("Checkpoint initialized")
        # Logs: "scope_name Checkpoint initialized" if called within logger_stack("scope_name")
    """
    with logger_stack(None, None) as (stacked_name, last_logger):
        last_logger.debug(f"{stacked_name} {msg}")
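A short usage sketch for the helpers above (not part of the commit); it uses only `debug_time` and `zip_strict` as defined in this file.

# Sketch only: timing a block and zipping with a length check.
import logging
import time

from megatron.core.dist_checkpointing.utils import debug_time, zip_strict

example_logger = logging.getLogger("ckpt_example")

# Logs "load_step took <elapsed>s" at DEBUG level when the block exits.
with debug_time("load_step", example_logger):
    time.sleep(0.01)

# Unlike plain zip(), zip_strict() asserts that the iterables have equal lengths.
pairs = list(zip_strict(["a", "b"], [1, 2]))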
Megatron-LM/megatron/core/dist_checkpointing/validation.py
new file mode 100644
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import
logging
from
collections
import
Counter
,
defaultdict
from
enum
import
Enum
from
pathlib
import
Path
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Set
,
Tuple
,
Union
import
numpy
as
np
import
torch
from
megatron.core.dist_checkpointing
import
ShardedTensor
from
megatron.core.dist_checkpointing.core
import
CheckpointingException
,
maybe_load_config
from
megatron.core.dist_checkpointing.dict_utils
import
(
diff
,
extract_matching_values
,
map_reduce
,
nested_values
,
)
from
megatron.core.dist_checkpointing.mapping
import
(
CommonStateDict
,
ShardedBase
,
ShardedObject
,
ShardedStateDict
,
is_main_replica
,
)
from
megatron.core.dist_checkpointing.strategies.base
import
(
LoadCommonStrategy
,
LoadShardedStrategy
,
SaveCommonStrategy
,
SaveShardedStrategy
,
StrategyAction
,
get_default_strategy
,
)
if
TYPE_CHECKING
:
from
megatron.core.dist_checkpointing.serialization
import
CkptShardedMetadata
logger
=
logging
.
getLogger
(
__name__
)
# pylint: disable=line-too-long
# list of local saved/loaded ShardedBase objects
_LocalMetadata
=
List
[
Union
[
ShardedTensor
,
ShardedObject
]]
# list of lists of global saved/loaded ShardedBase objects (each element corresponds to global rank)
_GlobalMetadata
=
List
[
_LocalMetadata
]
class
StrictHandling
(
Enum
):
"""Determines handling of load mismatch (non-empty "unexpected" or "missing" keys).
Different flags carry different implications on performance and behaviour and
are divided into two groups:
- *_UNEXPECTED
- *_ALL
The first group ignores missing keys (present in the checkpoint but missing
in the sharded state dict) which is created in order to avoid inter-rank
metadata exchange. Note that the metadata exchange will happen anyway
with `load(..., validate_access_integrity=True)` flag in which case using the
`*_ALL` option is recommended as it provides a more thorough check with no
performance penalty wrt. `*_UNEXPECTED` group.
All options except for the first one (`ASSUME_OK_UNEXPECTED`) require
extra disk access before the load in order to remove unexpected keys
from the sharded state dict requested to load.
"""
# Relies on the underlying strategy to raise error on unexpected keys
ASSUME_OK_UNEXPECTED
=
'assume_ok_unexpected'
# Logs (with WARNING level) "unexpected" keys. Missing keys are ignored.
# This is treated as a reasonable default for a "non-strict" load
LOG_UNEXPECTED
=
'log_unexpected'
# Logs (with WARNING level) all mismatched keys.
LOG_ALL
=
'log_all'
# Raise error on unexpected keys before load attempt.
# Gives cleaner error message than `ASSUME_OK_UNEXPECTED` but requires
# extra disk access.
RAISE_UNEXPECTED
=
'raise_unexpected'
# Raise error on any mismatch. Similar to `RAISE_UNEXPECTED` but requires
# metadata exchange.
RAISE_ALL
=
'raise_all'
# "Unexpected" mismatches are not reported, but returned by the `load`
# function along with the loaded state dict. Missing keys are ignored.
RETURN_UNEXPECTED
=
'return_unexpected'
# All mismatches are returned along with the loaded state dict.
RETURN_ALL
=
'return_all'
# Simply ignores mismatches (not recommended)
IGNORE_ALL
=
'ignore_all'
@
staticmethod
def
requires_explicit_ckpt_mismatch_check
(
val
:
'StrictHandling'
)
->
bool
:
"""Whether a given strict flag involves mismatch check against the checkpoint."""
return
val
!=
StrictHandling
.
ASSUME_OK_UNEXPECTED
@
staticmethod
def
requires_global_app_metadata
(
val
:
'StrictHandling'
)
->
bool
:
"""Whether a given strict option requires global metadata for validation."""
return
val
in
(
StrictHandling
.
IGNORE_ALL
,
StrictHandling
.
RAISE_ALL
,
StrictHandling
.
RETURN_ALL
,
StrictHandling
.
LOG_ALL
,
)
@
staticmethod
def
requires_returning_mismatch_keys
(
val
:
'StrictHandling'
)
->
bool
:
"""Whether a given strict option results in extra return value from the `load` function."""
return
val
in
(
StrictHandling
.
RETURN_UNEXPECTED
,
StrictHandling
.
RETURN_ALL
)
def
parse_strict_flag
(
strict
:
Union
[
str
,
StrictHandling
])
->
StrictHandling
:
"""Parse user passed strict flag from a string to StrictHandling instance.
Args:
strict (str, StrictHandling): strict flag to parse. If already an instance
of StrictHandling, this function is a noop.
Returns:
StrictHandling: enum instance
"""
if
isinstance
(
strict
,
StrictHandling
):
return
strict
try
:
return
StrictHandling
(
strict
)
except
(
ValueError
,
TypeError
)
as
e
:
raise
ValueError
(
f
'Invalid strict flag:
{
e
}
'
)
from
e
def
validate_integrity_and_strict_load
(
sharded_state_dict
:
ShardedStateDict
,
strict
:
StrictHandling
,
validate_access_integrity
:
bool
,
local_metadata
:
Optional
[
_LocalMetadata
]
=
None
,
global_metadata
:
Optional
[
_GlobalMetadata
]
=
None
,
ckpt_sharded_metadata
:
Optional
[
'CkptShardedMetadata'
]
=
None
,
)
->
Tuple
[
ShardedStateDict
,
Set
[
str
],
Set
[
str
]]:
"""Validates sharding integrity and potential mismatches with the checkpoint.
`validate_access_integrity` controls sharding integrity check (orthogonal
to strictness checking) which verifies `sharded_state_dict` runtime completeness
(in isolation from the actual checkpoint).
`strict` flag controls handling of mismatches between the requested
sharded state dict to load and the actual checkpoint. See `StrictHandling`
docs for details regarding flag behavior and performance implications
(disk interactions or inter-rank communication).
Args:
sharded_state_dict (ShardedStateDict): sharded state dict to verify.
strict (StrictHandling): flag determining how to handle sharded keys mismatch.
validate_access_integrity (bool): whether to perform sharding validation.
local_metadata (_LocalMetadata, optional): local sharded state dict metadata.
Defaults to None, in which case it's determined based on `sharded_state_dict`.
global_metadata (_GlobalMetadata, optional): global sharded state dict metadata
(exchanged between ranks). Defaults to None, in which case "missing"
keys are not determined.
ckpt_sharded_metadata (CkptShardedMetadata, optional): sharded metadata
from the checkpoint. Defaults to None, which only makes sense
for the `StrictHandling.ASSUME_OK_UNEXPECTED` strict value.
Returns:
Tuple[ShardedStateDict, Set[str], Set[str]]: tuple of: sharded state dict
without unexpected keys, missing and unexpected keys. Missing keys are equal
on all ranks, unexpected keys might differ across ranks. Additionally,
missing keys might be erroneously empty (depending on `strict` value).
"""
missing_keys
,
unexpected_keys
=
[],
[]
if
StrictHandling
.
requires_explicit_ckpt_mismatch_check
(
strict
):
if
ckpt_sharded_metadata
is
None
:
raise
CheckpointingException
(
'Cannot verify checkpoint mismatch with ckpt_sharded_metadata=None.'
)
if
local_metadata
is
None
:
local_metadata
=
[
sh_base
.
without_data
()
for
sh_base
in
nested_values
(
sharded_state_dict
)
]
# We don't want to check for missing keys even if we could
_skip_missing_keys
=
strict
in
(
StrictHandling
.
ASSUME_OK_UNEXPECTED
,
StrictHandling
.
LOG_UNEXPECTED
,
StrictHandling
.
RAISE_UNEXPECTED
,
StrictHandling
.
RETURN_UNEXPECTED
,
)
missing_keys
,
unexpected_keys
=
_determine_missing_and_unexpected_keys
(
ckpt_sharded_metadata
,
local_metadata
,
None
if
_skip_missing_keys
else
global_metadata
)
sharded_state_dict
=
adjust_non_strict_load
(
sharded_state_dict
,
unexpected_keys
)
if
strict
==
StrictHandling
.
IGNORE_ALL
:
missing_keys
,
unexpected_keys
=
[],
[]
elif
strict
in
(
StrictHandling
.
RAISE_UNEXPECTED
,
StrictHandling
.
RAISE_ALL
):
maybe_report_missing_and_unexpected_keys
(
missing_keys
,
unexpected_keys
,
True
)
elif
strict
in
(
StrictHandling
.
LOG_UNEXPECTED
,
StrictHandling
.
LOG_ALL
):
maybe_report_missing_and_unexpected_keys
(
missing_keys
,
unexpected_keys
,
False
)
if
validate_access_integrity
:
if
global_metadata
is
None
:
raise
CheckpointingException
(
'Cannot check sharding intergrity without global_metadata (None).'
)
validate_sharding_integrity
(
global_metadata
)
return
sharded_state_dict
,
missing_keys
,
unexpected_keys
def
verify_checkpoint_and_load_strategy
(
checkpoint_dir
:
str
,
sharded_strategy
:
Union
[
LoadShardedStrategy
,
Tuple
[
str
,
int
],
None
]
=
None
,
common_strategy
:
Union
[
LoadCommonStrategy
,
Tuple
[
str
,
int
],
None
]
=
None
,
)
->
Tuple
[
LoadShardedStrategy
,
LoadCommonStrategy
]:
"""Verifies if checkpoint metadata exists and matches given strategies.
If no strategies are passed, they are determined based on the checkpoint metadata.
Args:
checkpoint_dir (str): checkpoint directory
sharded_strategy (LoadShardedStrategy, Tuple[str, int], optional): sharded load strategy to be verified
if compatible with the checkpoint content. If None, the default sharded load strategy
for the checkpoint backend will be returned.
common_strategy (LoadCommonStrategy, Tuple[str, int], optional): common load strategy to be verified
if compatible with the checkpoint content. If None, the default common load strategy
for the checkpoint backend will be returned.
"""
if
not
Path
(
checkpoint_dir
).
exists
():
raise
CheckpointingException
(
f
'Checkpoint directory
{
checkpoint_dir
}
does not exist'
)
saved_config
=
maybe_load_config
(
checkpoint_dir
)
if
saved_config
is
None
:
raise
CheckpointingException
(
f
'
{
checkpoint_dir
}
is not a distributed checkpoint'
)
if
sharded_strategy
is
None
:
sharded_strategy
=
get_default_strategy
(
StrategyAction
.
LOAD_SHARDED
,
saved_config
.
sharded_backend
,
saved_config
.
sharded_backend_version
,
)
elif
isinstance
(
sharded_strategy
,
tuple
):
sharded_strategy
=
get_default_strategy
(
StrategyAction
.
LOAD_SHARDED
,
*
sharded_strategy
)
if
common_strategy
is
None
:
common_strategy
=
get_default_strategy
(
StrategyAction
.
LOAD_COMMON
,
saved_config
.
common_backend
,
saved_config
.
common_backend_version
,
)
elif
isinstance
(
common_strategy
,
tuple
):
sharded_strategy
=
get_default_strategy
(
StrategyAction
.
LOAD_COMMON
,
*
common_strategy
)
sharded_strategy
.
check_backend_compatibility
(
saved_config
.
sharded_backend
)
sharded_strategy
.
check_version_compatibility
(
saved_config
.
sharded_backend_version
)
common_strategy
.
check_backend_compatibility
(
saved_config
.
common_backend
)
common_strategy
.
check_version_compatibility
(
saved_config
.
common_backend_version
)
return
sharded_strategy
,
common_strategy
def adjust_non_strict_load(
    sharded_state_dict: ShardedStateDict, sharded_keys_to_remove: Set[str]
) -> ShardedStateDict:
    """Adjusts sharded state dict removing keys not existing in the checkpoint.

    Args:
        sharded_state_dict (ShardedStateDict): sharded state dict to modify
        sharded_keys_to_remove (Set[str]): keys to remove from the state dict

    Returns:
        ShardedStateDict: state dict without ShardedBase objects with specified keys
    """

    def is_unexpected_key(x: ShardedBase):
        assert isinstance(x, ShardedBase), f'Unexpected type {type(x)}'
        return x.key in sharded_keys_to_remove

    _, sharded_state_dict = extract_matching_values(sharded_state_dict, is_unexpected_key)
    return sharded_state_dict
def _determine_missing_and_unexpected_keys(
    ckpt_sharded_metadata: 'CkptShardedMetadata',
    local_metadata: _LocalMetadata,
    global_metadata: Optional[_GlobalMetadata] = None,
) -> Tuple[Set[str], Set[str]]:
    """Determines load mismatches based on metadata.

    There is an asymmetry between "unexpected" and "missing" keys.
    Unexpected keys can be determined based only on local metadata.
    Missing keys must be based on global metadata, since other ranks might access
    different keys than the current rank.
    In consequence, the return value of this function is different on each rank:
    "missing_keys" are equal, but "unexpected_keys" might differ across ranks.

    Args:
        ckpt_sharded_metadata (CkptShardedMetadata): sharded state dict (without data)
            constructed based on the checkpoint content
        local_metadata (_LocalMetadata): list of local ShardedBase objects
            requested to be loaded by this rank
        global_metadata (_GlobalMetadata, optional): list of global ShardedBase objects
            requested to be loaded by all ranks. Defaults to None, in which case
            returned "missing" keys are empty.

    Returns:
        Tuple[Set[str], Set[str]]: missing and unexpected keys. Missing keys are equal
            on all ranks, unexpected keys might differ across ranks. If passed
            `global_metadata` is empty, returned missing keys are empty as well.
    """
    local_accessed_keys = set(sh_base.key for sh_base in local_metadata)
    ckpt_keys = set(sh_base.key for sh_base in ckpt_sharded_metadata.values())
    unexpected_keys = local_accessed_keys - ckpt_keys
    if global_metadata is not None:
        global_accessed_keys = set(
            sh_base.key for rank_metadata in global_metadata for sh_base in rank_metadata
        )
        missing_keys = ckpt_keys - global_accessed_keys
    else:
        missing_keys = set()

    if missing_keys:
        logger.debug(f'Dist ckpt load missing keys: {missing_keys}')
    if unexpected_keys:
        logger.debug(f'Dist ckpt load unexpected keys: {unexpected_keys}')

    return missing_keys, unexpected_keys
def maybe_report_missing_and_unexpected_keys(
    missing_keys: Set[str], unexpected_keys: Set[str], raise_error: bool = True
) -> None:
    """Raises or logs an error in case missing or unexpected keys are non-empty.

    Args:
        missing_keys (Set[str]): missing keys in the state dict
        unexpected_keys (Set[str]): unexpected keys in the state dict
        raise_error: If True, raises error on mismatch. Otherwise, logs mismatch
            with WARNING level.

    Returns:
        None

    Raises:
        CheckpointingException: if `raise_error` is True and at least one of
            `missing_keys` or `unexpected_keys` are non-empty.
    """
    if not missing_keys and not unexpected_keys:
        return
    missing_title_msg = (
        f'Some keys found in the checkpoint are missing in the provided sharded state dict. '
    )
    missing_body_msg = f'Missing keys (for all ranks): {missing_keys}. '
    unexpected_title_msg = f'Unexpected keys (not found in the checkpoint) encountered in the provided sharded state dict. '
    unexpected_body_msg = f'Unexpected keys (for this rank): {unexpected_keys}. '
    error_msg = ''
    if missing_keys:
        error_msg += missing_title_msg
    if unexpected_keys:
        error_msg += unexpected_title_msg

    error_msg += '\n'
    if missing_keys:
        error_msg += missing_body_msg
    if unexpected_keys:
        error_msg += unexpected_body_msg

    if raise_error:
        raise CheckpointingException(error_msg)
    else:
        logger.warning(error_msg)
def _validate_common_state_dict(common_state_dict: CommonStateDict) -> None:
    """Validate consistency across ranks for the common state dict.

    We save the common state dict only on rank 0. We validate to make sure that the
    common dict is consistent across ranks before saving.

    Args:
        common_state_dict: The common state dict present on all ranks
    """
    # Gather the common state dict across ranks onto rank 0 for comparison
    rank = torch.distributed.get_rank()
    other_rank_state_dicts = [None] * torch.distributed.get_world_size() if rank == 0 else None
    torch.distributed.gather_object(common_state_dict, other_rank_state_dicts)
    common_state_dict_diff = {}
    if rank == 0:
        main_rank_state_dict = common_state_dict
        for rank, rank_state_dict in enumerate(other_rank_state_dicts[1:], 1):
            only_left, only_right, mismatch = diff(main_rank_state_dict, rank_state_dict)
            if only_left or only_right or mismatch:
                common_state_dict_diff[rank] = (only_left, only_right, mismatch)

        if len(common_state_dict_diff) != 0:
            logger.warning(
                f'There is a difference in the common state dict across ranks. '
                f'The differences are {common_state_dict_diff}'
            )
def validate_sharding_integrity(
    global_metadata: _GlobalMetadata, common_state_dict: CommonStateDict = None
) -> None:
    """Validate if the ShardedTensors and ShardedObjects from multiple processes define correct sharding.

    Local ShardedTensor and ShardedObject metadata is exchanged with
    `torch.distributed.all_gather_object`, and then the process with global rank 0
    checks if the main replicas of the shards:
    - cover the whole global tensors
    - don't overlap

    Args:
        global_metadata (_GlobalMetadata): ShardedTensor and ShardedObject objects from all ranks.
        common_state_dict (CommonStateDict): The common state dict stored by rank 0

    Returns:
        None

    Raises:
        CheckpointingException for invalid access pattern
    """
    if common_state_dict is not None:
        _validate_common_state_dict(common_state_dict)

    if torch.distributed.get_rank() != 0:
        return

    key_shardings = defaultdict(list)
    for rank, rank_shardings in enumerate(global_metadata):
        for sharding in rank_shardings:
            key_shardings[sharding.key].append((rank, sharding))
    for key, shardings in key_shardings.items():
        if isinstance(shardings[0][1], ShardedObject):
            _validate_objects_for_key(shardings)
        else:
            _validate_sharding_for_key(shardings)
def _validate_sharding_for_key(rank_sharding: List[Tuple[int, ShardedTensor]]):
    some_rank_shard = rank_sharding[0][1]
    global_shape = some_rank_shard.global_shape
    local_shape = some_rank_shard.local_shape
    dtype = some_rank_shard.dtype
    has_flattened_range = some_rank_shard.flattened_range is not None
    for rank, sharding in rank_sharding:
        assert sharding.dtype == dtype, (sharding.dtype, dtype, some_rank_shard)
        assert sharding.global_shape == global_shape, (
            sharding.global_shape,
            global_shape,
            some_rank_shard,
        )
        assert sharding.local_shape == local_shape, (
            sharding.local_shape,
            local_shape,
            some_rank_shard,
        )
        assert (sharding.flattened_range is not None) == has_flattened_range, (
            (sharding.flattened_range is not None),
            has_flattened_range,
            some_rank_shard,
        )

    shard_access_cnt = _compute_shards_access(rank_sharding)
    if has_flattened_range:
        map_reduce(
            rank_sharding,
            lambda x: x[1].global_offset,
            lambda x: x[1],
            _validate_sharding_for_key_flattened,
        )
        # For each shard with at least 1 flattened tensor in it, the above
        # `_validate_sharding_for_key_flattened` ensures a correct and consistent pattern.
        # The only thing that can go wrong at this point is that some shards don't have
        # *any* representatives, which will be checked later by comparing `shard_access_cnt == 1`
        shard_access_cnt = torch.minimum(shard_access_cnt, torch.tensor([1]))
    if not torch.all(shard_access_cnt == 1):
        raise CheckpointingException(
            f'Invalid access pattern for {rank_sharding[0][1]}: {shard_access_cnt}'
        )
def _compute_shards_access(rank_sharding):
    shard_access_cnt = torch.zeros(
        rank_sharding[0][1].axis_fragmentations, dtype=torch.int, device='cpu'
    )
    for rank, sharding in rank_sharding:
        if is_main_replica(sharding.replica_id):
            shard_access_cnt[sharding.local_chunk_offset_in_global()] += 1
    return shard_access_cnt
def _validate_sharding_for_key_flattened(tensors_by_shard):
    all_slices = []
    local_shape = tensors_by_shard[0].local_shape
    for sharding in tensors_by_shard:
        assert sharding.local_shape == local_shape
        sharding: ShardedTensor
        if not is_main_replica(sharding.replica_id):
            continue

        all_slices.append((sharding.flattened_range.start, sharding.flattened_range.stop))

    starts, stops = map(np.asarray, zip(*sorted(all_slices)))
    expected_size = np.prod(local_shape)
    if starts[0] != 0 or stops[-1] != expected_size or not np.all(starts[1:] == stops[:-1]):
        raise CheckpointingException(
            f"Flattened ranges don't cover the whole shard {tensors_by_shard[0]} of size "
            f"{expected_size}. Ranges: {(starts, stops)}"
        )
def _validate_objects_for_key(sharded_objects: List[ShardedObject]):
    """Ensure uniqueness of saved objects."""
    unique_keys = [
        sh_obj.unique_key for _, sh_obj in sharded_objects if is_main_replica(sh_obj.replica_id)
    ]
    if len(unique_keys) != len(set(unique_keys)):
        duplicates = {k: cnt for k, cnt in Counter(unique_keys).items() if cnt > 1}
        logger.error(f'Duplicate ShardedObject keys and counts: {duplicates}')
        raise CheckpointingException(f'Duplicate ShardedObject keys: {list(duplicates.keys())}')
    expected_shard_num = np.prod(sharded_objects[0][1].global_shape)
    if len(unique_keys) != expected_shard_num:
        err_msg = f'Invalid access pattern: {expected_shard_num - len(unique_keys)} ShardedObjects are missing.'
        logger.error(f'{err_msg} Existing shards: {unique_keys}')
        raise CheckpointingException(err_msg)
def determine_global_metadata(
    sharded_state_dict: ShardedStateDict,
) -> Tuple[_LocalMetadata, _GlobalMetadata]:
    """Exchanges local metadata with `all_gather_object` to determine global metadata.

    Args:
        sharded_state_dict (ShardedStateDict): local sharded state dict

    Returns:
        Tuple[_LocalMetadata, _GlobalMetadata]: local and global ShardedBase objects with stripped data
    """
    local_metadata = [ten.without_data() for ten in nested_values(sharded_state_dict)]
    global_metadata = [None] * torch.distributed.get_world_size()
    torch.distributed.all_gather_object(global_metadata, local_metadata)
    return local_metadata, global_metadata
def validate_sharded_objects_handling(
    sharded_strategy: Union[SaveShardedStrategy, LoadShardedStrategy],
    common_strategy: Union[SaveCommonStrategy, LoadCommonStrategy],
) -> None:
    """Checks if either of the passed strategies can handle sharded objects.

    Args:
        sharded_strategy (Union[SaveShardedStrategy, LoadShardedStrategy]): sharded strategy used for saving/loading
        common_strategy (Union[SaveCommonStrategy, LoadCommonStrategy]): common strategy used for saving/loading

    Returns:
        None

    Raises:
        CheckpointingException: if both strategies can't handle ShardedObjects
    """
    if (
        not sharded_strategy.can_handle_sharded_objects
        and not common_strategy.can_handle_sharded_objects
    ):
        raise CheckpointingException(
            f'Either sharded strategy or common strategy must implement ShardedObjects handling.'
            f' Both {sharded_strategy} and {common_strategy} specify can_handle_sharded_objects=False'
        )
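The strict-loading checks above reduce to plain set arithmetic over the checkpoint keys and the keys each rank requests. A minimal, self-contained sketch with toy key names (standing in for the real `ShardedBase` metadata):

```python
# Toy illustration of the set arithmetic in _determine_missing_and_unexpected_keys.
ckpt_keys = {"embedding.weight", "decoder.layers.0.mlp.weight", "decoder.final_norm.weight"}

# Keys requested by this rank, and by all ranks together (the all_gather_object result).
local_accessed_keys = {"embedding.weight", "decoder.layers.0.mlp.weight", "lm_head.bias"}
global_accessed_keys = local_accessed_keys | {"decoder.final_norm.weight"}

# "Unexpected" is purely local: requested here but absent from the checkpoint.
unexpected_keys = local_accessed_keys - ckpt_keys    # {'lm_head.bias'}
# "Missing" needs the global view: stored in the checkpoint but requested by no rank.
missing_keys = ckpt_keys - global_accessed_keys      # set()

print(sorted(unexpected_keys), sorted(missing_keys))
```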
Megatron-LM/megatron/core/distributed/README.md
## How to use pytorch FSDP2?
Add these flags to enable Torch FSDP2.
```
--use-torch-fsdp2
--no-gradient-accumulation-fusion
--ckpt-format torch_dist
```
Note that `CUDA_MAX_CONNECTIONS=1` should not be set, so that FSDP communication and computation on the primary stream can be fully parallelized.
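As a rough sketch (assuming a `torchrun` launch and a `pretrain_gpt.py`-style entry point; adapt the paths and remaining arguments to your own setup), a launcher could assemble these flags and make sure `CUDA_MAX_CONNECTIONS` is not set:

```python
import os
import subprocess

# Drop CUDA_MAX_CONNECTIONS so FSDP communication can overlap with compute.
env = {k: v for k, v in os.environ.items() if k != "CUDA_MAX_CONNECTIONS"}

cmd = [
    "torchrun", "--nproc_per_node", "8",
    "pretrain_gpt.py",                      # placeholder training entry point
    "--use-torch-fsdp2",
    "--no-gradient-accumulation-fusion",
    "--ckpt-format", "torch_dist",
    # ... the rest of the usual training arguments ...
]
subprocess.run(cmd, env=env, check=True)
```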
Megatron-LM/megatron/core/distributed/__init__.py
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from packaging.version import Version

from .distributed_data_parallel import DistributedDataParallel
from .distributed_data_parallel_config import DistributedDataParallelConfig
from .finalize_model_grads import finalize_model_grads
from .torch_fully_sharded_data_parallel import TorchFullyShardedDataParallel
Megatron-LM/megatron/core/distributed/custom_fsdp/__init__.py
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from .fully_sharded_data_parallel import FullyShardedDataParallel
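For reference, the exported wrapper is typically applied around an MCore model as sketched below. This is a hedged sketch, not runnable as-is: it assumes `config` (a `TransformerConfig`) and `model` (e.g. a `GPTModel`) are already constructed, that the `DistributedDataParallelConfig` fields shown exist in this build, and that the argument order follows the `FullyShardedDataParallel.__init__` signature defined later in this commit.

```python
from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.distributed.custom_fsdp import FullyShardedDataParallel
from megatron.core.transformer.transformer_layer import TransformerLayer

ddp_config = DistributedDataParallelConfig(
    data_parallel_sharding_strategy="optim_grads_params",  # ZeRO-3-like sharding
    overlap_param_gather=True,
    overlap_grad_reduce=True,
)
model = FullyShardedDataParallel(
    config,        # TransformerConfig built elsewhere (assumed)
    ddp_config,
    model,         # e.g. a GPTModel built elsewhere (assumed)
    fsdp_unit_modules=[TransformerLayer],
)
```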
Megatron-LM/megatron/core/distributed/custom_fsdp/fully_sharded_data_parallel.py
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import
functools
import
logging
from
contextlib
import
contextmanager
from
enum
import
Enum
,
auto
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
import
torch
import
torch.nn
as
nn
from
torch.utils._pytree
import
tree_flatten
,
tree_unflatten
from
megatron.core
import
parallel_state
from
megatron.core.config_logger
import
has_config_logger_enabled
,
log_config_to_disk
from
megatron.core.distributed.custom_fsdp.param_and_grad_buffer
import
(
AllGatherPipeline
,
BucketingPolicy
,
GradReducePipeline
,
ParamAndGradBuffer
,
PrefetchOrder
,
)
from
megatron.core.distributed.data_parallel_base
import
_BaseDataParallel
from
megatron.core.distributed.distributed_data_parallel_config
import
DistributedDataParallelConfig
from
megatron.core.fp8_utils
import
is_float8tensor
from
megatron.core.transformer.transformer_config
import
TransformerConfig
from
megatron.core.transformer.transformer_layer
import
TransformerLayer
from
megatron.core.utils
import
is_submodule
,
log_single_rank
logger
=
logging
.
getLogger
(
__name__
)
class
TrainingState
(
Enum
):
"""States of a FSDP parameter group, which are coupled with
the sharding activity of parameters and gradients during training."""
# From pre-forward before post-forward, where parameters should be unsharded
FORWARD
=
auto
()
# Prior to backward computation, where parameters should be unsharded
PRE_BACKWARD
=
auto
()
# After backward computation, where gradients should be re-sharded
POST_BACKWARD
=
auto
()
# Before and after module forward computation, or before pre-backward and
# after post-backward states, where no un/sharding activity happens
IDLE
=
auto
()
class
FullyShardedDataParallel
(
_BaseDataParallel
):
"""Fully Sharded Data Parallel training for MCore models.
A distributed training wrapper that shards model parameters, gradients and optimizer
states across data parallel workers. Integrates seamlessly with MCore's tensor
and expert parallelism features.
We support the following modes:
- no_shard: Traditional data parallel training without parameter sharding.
- optim: Shards optimizer states (and the main weights used for mixed-precision
training); this is conceptually close to "ZeRO-1". The `optim_grads` and
`optim_grads_params` modes below also shard the main weights during
mixed-precision training; this is omitted from their descriptions.
- optim_grads: Shards gradients and optimizer states, this is conceptually close to "ZeRO-2".
- optim_grads_params: Shards parameters, gradients and optimizer states, this
is conceptually close to "ZeRO-3".
Key Features:
- Compatible with MCore's tensor, context and expert parallelism
- Automatic mixed precision training (BF16/FP8)
- Gradient accumulation and bucketing
- Optimized activation recompute with shard-aware communication: When recomputing
a whole Transformer layer, gather parameters once for both the recomputation
and backward computation
- Compatible with MCore's distributed checkpointing
Args:
config: Transformer config object.
ddp_config: FullyShardedDataParallel config object.
module: Underlying model.
fsdp_unit_modules: List of modules that should be treated as FSDP Unit,
i.e., the minimum releasable model unit. If not provided, defaults to
[TransformerLayer, LanguageModelEmbedding] for GPT-like models. In
addition to this, it affects the granularity of the communication
parameter grouping and triggers aggregate collective communication
in fp8 mixed precision training.
disable_bucketing: If true, force assign all parameters to a single bucket. If false,
use standard bucketing policy: assign parameters to smaller buckets and all-reduce
per bucket.
Examples:
>>> model = GPTModel(config)
>>> model = FullyShardedDataParallel(
... config,
... ddp_config,
... model,
... fsdp_unit_modules = [TransformerLayer, LanguageModelEmbedding],
... )
"""
# TODO: add hybrid FSDP (shard model states in a partial DP domain)
def
__init__
(
self
,
config
:
TransformerConfig
,
ddp_config
:
DistributedDataParallelConfig
,
module
:
torch
.
nn
.
Module
,
fsdp_unit_modules
:
Optional
[
List
[
torch
.
nn
.
Module
]]
=
None
,
disable_bucketing
:
bool
=
False
,
device
:
Optional
[
torch
.
device
]
=
None
,
):
super
().
__init__
(
config
=
config
,
module
=
module
)
if
has_config_logger_enabled
(
config
):
log_config_to_disk
(
config
,
locals
(),
prefix
=
type
(
self
).
__name__
)
self
.
module
=
module
self
.
ddp_config
=
ddp_config
log_single_rank
(
logger
,
logging
.
INFO
,
f
'Setting up DistributedDataParallel with config
{
self
.
ddp_config
}
'
,
)
self
.
bucket_size
=
self
.
ddp_config
.
bucket_size
if
disable_bucketing
:
self
.
bucket_size
=
None
self
.
device
=
device
if
device
else
torch
.
cuda
.
current_device
()
self
.
param_to_bucket_group
=
{}
if
fsdp_unit_modules
is
not
None
:
self
.
fsdp_unit_modules
=
fsdp_unit_modules
else
:
if
self
.
ddp_config
.
data_parallel_sharding_strategy
==
"optim_grads_params"
:
self
.
fsdp_unit_modules
=
[
TransformerLayer
]
else
:
self
.
fsdp_unit_modules
=
[]
self
.
main_weights
=
True
self
.
data_parallel_group
=
parallel_state
.
get_data_parallel_group
(
with_context_parallel
=
True
)
self
.
expert_data_parallel_group
=
parallel_state
.
get_expert_data_parallel_group
()
# Determine if we should delay the gradient reduction.
self
.
is_delay_grad_reduce
=
self
.
ddp_config
.
data_parallel_sharding_strategy
in
[
"no_shard"
,
"optim"
,
]
if
self
.
ddp_config
.
data_parallel_sharding_strategy
==
"optim_grads_params"
:
assert
self
.
ddp_config
.
overlap_param_gather
if
not
self
.
is_delay_grad_reduce
:
assert
self
.
ddp_config
.
overlap_grad_reduce
self
.
_init_fsdp_param_and_grad_buffer
()
self
.
_register_fsdp_hooks
(
self
.
module
)
# Delete references to weight_tensor if they exist since we don't want two parameter copies
# if we re-mapped parameters (which happens when we use the distributed optimizer).
# This is a temporary workaround around a TE bug that is fixed with
# https://github.com/NVIDIA/TransformerEngine/pull/719.
@
torch
.
no_grad
()
def
unmap_weight_tensor
(
m
):
if
hasattr
(
m
,
'weight_tensor'
):
m
.
weight_tensor
=
None
self
.
module
.
apply
(
unmap_weight_tensor
)
def
_init_fsdp_param_and_grad_buffer
(
self
):
if
self
.
config
.
calculate_per_token_loss
:
# We don't need to scale the gradients in this case.
gradient_scaling_factor
=
None
expert_gradient_scaling_factor
=
None
else
:
if
self
.
ddp_config
.
average_in_collective
:
# FIXME(@jianbinc): Will fix this issue based on Parallel Folding's EDP patch MR.
raise
Exception
(
"Not supported"
)
else
:
data_parallel_world_size
=
parallel_state
.
get_data_parallel_world_size
(
with_context_parallel
=
True
)
gradient_scaling_factor
=
1.0
/
data_parallel_world_size
expert_gradient_scaling_factor
=
1.0
/
data_parallel_world_size
# Initialize the param and grad buffer.
self
.
data_parallel_sharding_strategy
=
self
.
ddp_config
.
data_parallel_sharding_strategy
self
.
param_to_name
=
{
p
:
name
for
name
,
p
in
self
.
module
.
named_parameters
()}
self
.
param_and_grad_buffer
=
ParamAndGradBuffer
(
self
.
ddp_config
,
self
.
module
,
bucketing_policy
=
BucketingPolicy
(
suggested_bucket_size
=
self
.
bucket_size
,
fsdp_unit_modules
=
self
.
fsdp_unit_modules
,
data_parallel_sharding_strategy
=
self
.
data_parallel_sharding_strategy
,
),
data_parallel_group
=
self
.
data_parallel_group
,
expert_data_parallel_group
=
self
.
expert_data_parallel_group
,
preserve_fp32_weights
=
self
.
ddp_config
.
preserve_fp32_weights
,
grad_reduce_in_fp32
=
self
.
ddp_config
.
grad_reduce_in_fp32
,
gradient_scaling_factor
=
gradient_scaling_factor
,
expert_gradient_scaling_factor
=
expert_gradient_scaling_factor
,
device
=
self
.
device
,
reset_parameters_for_meta_device_init_module
=
self
.
config
.
init_model_with_meta_device
,
)
self
.
param_and_grad_buffer
self
.
side_stream_for_buffer_copy_and_grad_accum
=
torch
.
cuda
.
Stream
()
# Initialize the reduce-scatter pipeline.
self
.
grad_reduce_pipeline
=
GradReducePipeline
(
self
.
param_and_grad_buffer
,
cuda_stream
=
self
.
side_stream_for_buffer_copy_and_grad_accum
)
# Initialize the all-gather pipeline.
self
.
all_gather_pipeline
=
AllGatherPipeline
(
self
.
param_and_grad_buffer
)
suggested_communication_unit_size
=
self
.
ddp_config
.
suggested_communication_unit_size
if
suggested_communication_unit_size
is
None
:
if
self
.
data_parallel_sharding_strategy
==
"optim_grads_params"
:
total_param_elements
=
0
total_fsdp_module
=
0
for
module
in
self
.
module
.
modules
():
if
isinstance
(
module
,
tuple
(
self
.
fsdp_unit_modules
)):
total_fsdp_module
+=
1
total_param_elements
+=
sum
(
p
.
numel
()
for
p
in
module
.
parameters
())
# The suggested size is twice the average number of parameter elements per FSDP unit module.
# This ensures we process the current FSDP module and attempt to prefetch
# the next FSDP module, improving the overlap of communication.
suggested_communication_unit_size
=
total_param_elements
//
total_fsdp_module
*
2
elif
self
.
bucket_size
is
not
None
:
suggested_communication_unit_size
=
self
.
bucket_size
*
2
self
.
suggested_RS_queue_capacity
=
suggested_communication_unit_size
self
.
suggested_AG_prefetch_size
=
suggested_communication_unit_size
def
_register_fsdp_hooks
(
self
,
root_module
):
"""Register necessary hooks for Fully Sharded Data Parallel (FSDP) execution on the model.
This function sets up various hooks required for FSDP operations, including parameter
resharding/unsharding and gradient handling. The registered hooks are:
- Pre-forward hook: Unshards parameters before forward pass
- Post-forward hook: Reshards parameters after forward pass
- Pre-backward hook: Unshards parameters before backward pass
- Post-backward hook: Reshards parameters and reduces gradients after backward pass
Args:
root_module: The PyTorch module to register FSDP hooks on
Note:
These hooks are essential for FSDP's memory efficiency as they manage:
1. Dynamic parameter sharding/unsharding to reduce memory footprint
2. Proper gradient synchronization across distributed processes
3. Gradient accumulation for large batch training
Returns:
None
"""
# Initialize module training state.
for
m
in
root_module
.
modules
():
setattr
(
m
,
"_training_state"
,
TrainingState
.
IDLE
)
self
.
forward_pre_hooks
=
{}
self
.
forward_hooks
=
{}
self
.
backward_pre_hooks
=
{}
"""
An FSDP unit is a module designed to manage the lifecycle of model parameters
in Fully Sharded Data Parallel (FSDP) training. It ensures that parameters
are only used within the module and are released immediately after
the forward and backward computations are completed.
This approach is crucial for efficient memory management, as releasing
parameters too early can lead to issues if other computations depend on them.
`optim` and `optim_grads` do not require FSDP units because they do not
shard model parameters.
"""
fsdp_unit_modules
=
self
.
fsdp_unit_modules
def
release_module_parameters
(
module
,
*
unused
):
for
param
in
module
.
parameters
():
bucket_id
=
self
.
param_and_grad_buffer
.
param_to_param_group
[
param
]
self
.
all_gather_pipeline
.
release_bucket
(
bucket_id
)
if
not
self
.
ddp_config
.
keep_fp8_transpose_cache_when_using_custom_fsdp
:
release_params_fp8_transpose_cache
(
module
.
parameters
())
def
release_params_fp8_transpose_cache
(
params
):
for
param
in
params
:
if
is_float8tensor
(
param
):
param
.
_transpose_invalid
=
True
param
.
_transpose
=
None
def
all_gather_module_parameters
(
module
,
*
unused
,
prefetch
=
True
,
prefetch_order
=
PrefetchOrder
.
FORWARD_PASS_ORDER
,
wait_bucket_ready
=
True
,
):
ag_pipeline
=
self
.
all_gather_pipeline
ag_pipeline
.
all_gather_params
(
params
=
list
(
module
.
parameters
()),
prefetch
=
prefetch
,
prefetch_order
=
prefetch_order
,
suggested_AG_prefetch_size
=
self
.
suggested_AG_prefetch_size
,
)
if
wait_bucket_ready
:
for
param
in
module
.
parameters
():
bucket_id
=
self
.
param_and_grad_buffer
.
param_to_param_group
[
param
]
ag_pipeline
.
wait_bucket_ready
(
bucket_id
)
def
_grad_acc
(
param
):
"""
Accumulate the gradient in the main_grad buffer.
"""
group_id
=
self
.
param_and_grad_buffer
.
param_to_param_group
[
param
]
group
=
self
.
param_and_grad_buffer
.
parameter_groups
[
group_id
]
if
not
group
.
requires_grad
:
return
overwrite_main_grad
=
self
.
ddp_config
.
data_parallel_sharding_strategy
in
[
"optim_grads"
,
"optim_grads_params"
,
]
if
overwrite_main_grad
:
if
not
param
.
grad_added_to_main_grad
:
if
param
.
grad
is
not
None
:
param
.
main_grad
.
copy_
(
param
.
grad
)
del
param
.
grad
else
:
param
.
main_grad
.
zero_
()
else
:
if
not
param
.
grad_added_to_main_grad
:
if
param
.
grad
is
not
None
:
param
.
main_grad
.
add_
(
param
.
grad
)
del
param
.
grad
# Reset the grad accumulate flag.
param
.
grad_added_to_main_grad
=
False
self
.
_params_require_handle_grad
=
set
()
def
_post_backward
(
module
,
*
unused
):
if
isinstance
(
module
,
tuple
(
fsdp_unit_modules
)):
if
self
.
ddp_config
.
data_parallel_sharding_strategy
==
"optim_grads_params"
:
release_module_parameters
(
module
)
module
.
_training_state
=
TrainingState
.
IDLE
param_list
=
list
(
module
.
parameters
())
else
:
param_list
=
list
(
module
.
parameters
(
recurse
=
False
))
for
param
in
param_list
:
_grad_acc
(
param
)
self
.
_params_require_handle_grad
.
discard
(
param
)
grad_reduce_every_bprop
=
self
.
ddp_config
.
data_parallel_sharding_strategy
in
[
"optim_grads"
,
"optim_grads_params"
,
]
if
grad_reduce_every_bprop
or
self
.
is_last_microbatch
:
self
.
grad_reduce_pipeline
.
reduce_gradients
(
param_list
,
suggested_queue_capacity
=
self
.
suggested_RS_queue_capacity
)
def
_pre_forward_param_unshard
(
module
:
nn
.
Module
,
args
:
Tuple
[
Any
,
...],
kwargs
:
Dict
[
str
,
Any
]
):
# Unshard the parameters before the forward pass.
input_training_state
=
module
.
_training_state
fsdp_forward_prefetch
=
True
if
input_training_state
==
TrainingState
.
PRE_BACKWARD
:
# In activation recomputation case, we need to cancel forward prefetch.
fsdp_forward_prefetch
=
False
else
:
module
.
_training_state
=
TrainingState
.
FORWARD
if
isinstance
(
module
,
tuple
(
fsdp_unit_modules
)):
param_list
=
list
(
module
.
parameters
())
self
.
all_gather_pipeline
.
all_gather_params
(
params
=
param_list
,
prefetch
=
fsdp_forward_prefetch
,
suggested_AG_prefetch_size
=
self
.
suggested_AG_prefetch_size
,
)
for
param
in
param_list
:
bucket_id
=
self
.
param_and_grad_buffer
.
param_to_param_group
[
param
]
self
.
all_gather_pipeline
.
wait_bucket_ready
(
bucket_id
)
else
:
# All-gather the parameters in every forward pass for FSDP.
param_list
=
list
(
module
.
parameters
(
recurse
=
False
))
self
.
all_gather_pipeline
.
all_gather_params
(
params
=
param_list
,
prefetch
=
fsdp_forward_prefetch
,
suggested_AG_prefetch_size
=
self
.
suggested_AG_prefetch_size
,
)
for
param
in
param_list
:
bucket_id
=
self
.
param_and_grad_buffer
.
param_to_param_group
[
param
]
self
.
all_gather_pipeline
.
wait_bucket_ready
(
bucket_id
)
return
args
,
kwargs
def
_register_post_backward_hook
(
post_backward_hook
:
callable
,
module
:
nn
.
Module
,
args
:
Tuple
[
Any
,
...],
kwargs
:
Dict
[
str
,
Any
],
):
# Register the backward function to reduce gradients after the backward pass.
# And for optim_grads_params, we need to release the parameters after the backward pass.
if
not
torch
.
is_grad_enabled
():
return
args
,
kwargs
args_list
,
args_spec
=
tree_flatten
(
args
)
kwargs_list
,
kwargs_spec
=
tree_flatten
(
kwargs
)
args_kwargs_list
=
list
(
args_list
)
+
list
(
kwargs_list
)
inp_tensor_indices
:
List
[
int
]
=
[]
inp_tensors
:
List
[
torch
.
Tensor
]
=
[]
for
i
,
obj
in
enumerate
(
args_kwargs_list
):
if
torch
.
is_tensor
(
obj
)
and
obj
.
requires_grad
:
inp_tensor_indices
.
append
(
i
)
inp_tensors
.
append
(
obj
)
if
len
(
inp_tensors
)
==
0
:
return
args
,
kwargs
inp_tensors
=
RegisterFSDPBackwardFunction
.
apply
(
functools
.
partial
(
post_backward_hook
,
module
),
*
inp_tensors
)
for
inp_tensor_idx
,
inp_tensor
in
zip
(
inp_tensor_indices
,
inp_tensors
):
args_kwargs_list
[
inp_tensor_idx
]
=
inp_tensor
args_list
=
args_kwargs_list
[:
len
(
args_list
)]
kwargs_list
=
args_kwargs_list
[
len
(
args_list
)
:]
args
=
tree_unflatten
(
args_list
,
args_spec
)
kwargs
=
tree_unflatten
(
kwargs_list
,
kwargs_spec
)
return
args
,
kwargs
fsdp_modules
=
[]
for
name
,
module
in
root_module
.
named_modules
():
if
any
(
is_submodule
(
module
,
fsdp_module
)
for
fsdp_module
in
fsdp_modules
):
continue
if
isinstance
(
module
,
tuple
(
fsdp_unit_modules
)):
fsdp_modules
.
append
(
module
)
self
.
forward_pre_hooks
[
f
'module
{
name
}
parameter unshard'
]
=
(
module
.
register_forward_pre_hook
(
_pre_forward_param_unshard
,
prepend
=
True
,
with_kwargs
=
True
)
)
self
.
forward_pre_hooks
[
f
"module
{
name
}
register post-backward hook"
]
=
(
module
.
register_forward_pre_hook
(
functools
.
partial
(
_register_post_backward_hook
,
_post_backward
),
with_kwargs
=
True
,
)
)
def
_root_post_backward
(
*
unused
):
# Make sure all the gradients are handled.
for
param
in
self
.
_params_require_handle_grad
:
_grad_acc
(
param
)
# Reduce the remaining gradients.
grad_reduce_every_bprop
=
self
.
ddp_config
.
data_parallel_sharding_strategy
in
[
"optim_grads"
,
"optim_grads_params"
,
]
if
grad_reduce_every_bprop
or
self
.
is_last_microbatch
:
self
.
grad_reduce_pipeline
.
reduce_gradients
(
list
(
self
.
_params_require_handle_grad
),
suggested_queue_capacity
=
self
.
suggested_RS_queue_capacity
,
)
self
.
grad_reduce_pipeline
.
reset
()
# Reset root_pre_backward_hook_issued flag.
self
.
_root_pre_backward_hook_issued
=
False
def
_pre_backward
(
module
:
nn
.
Module
,
*
unused
):
module
.
_training_state
=
TrainingState
.
PRE_BACKWARD
if
isinstance
(
module
,
tuple
(
fsdp_unit_modules
)):
all_gather_module_parameters
(
module
,
prefetch_order
=
PrefetchOrder
.
BACKWARD_PASS_ORDER
)
self
.
_root_pre_backward_hook_issued
=
False
def
_root_pre_backward
(
module
:
nn
.
Module
,
*
unused
):
"""Marks the module's training state as 'pre_backward' before the
backprop, this function is registered on the root module.
This marking enables us to determine whether forward pass needs to
perform reshard/unshard operations in activation recomputation
scenarios.
"""
if
self
.
_root_pre_backward_hook_issued
:
return
self
.
_root_pre_backward_hook_issued
=
True
if
self
.
ddp_config
.
data_parallel_sharding_strategy
==
"optim_grads_params"
:
for
module
in
root_module
.
modules
():
if
isinstance
(
module
,
tuple
(
fsdp_unit_modules
)):
module
.
_training_state
=
TrainingState
.
PRE_BACKWARD
for
param
in
module
.
parameters
():
bucket_id
=
self
.
param_and_grad_buffer
.
param_to_param_group
[
param
]
self
.
all_gather_pipeline
.
wait_bucket_ready
(
bucket_id
,
empty_ok
=
True
)
self
.
all_gather_pipeline
.
release_bucket
(
bucket_id
)
self
.
_params_require_handle_grad
=
set
()
for
param_group
in
self
.
param_and_grad_buffer
.
parameter_groups
:
if
not
param_group
.
requires_grad
:
continue
self
.
_params_require_handle_grad
|=
set
(
param_group
.
params
)
for
param
in
param_group
.
params
:
param
.
grad_added_to_main_grad
=
False
torch
.
autograd
.
Variable
.
_execution_engine
.
queue_callback
(
_root_post_backward
)
def
_post_forward
(
module
:
nn
.
Module
,
input
:
Any
,
output
:
Any
):
# When composing with module-hook-based activation checkpointing, the
# post-backward hook is responsible for the reshard
if
module
.
_training_state
==
TrainingState
.
PRE_BACKWARD
:
return
output
release_module_parameters
(
module
)
module
.
_training_state
=
TrainingState
.
IDLE
return
output
def
_release_module_fp8_transpose_cache
(
module
:
nn
.
Module
,
*
unused
):
release_params_fp8_transpose_cache
(
module
.
parameters
(
recurse
=
False
))
if
len
(
fsdp_unit_modules
)
!=
0
:
fsdp_modules
=
[]
for
name
,
module
in
root_module
.
named_modules
():
if
any
(
is_submodule
(
module
,
fsdp_module
)
for
fsdp_module
in
fsdp_modules
):
continue
if
isinstance
(
module
,
tuple
(
fsdp_unit_modules
)):
fsdp_modules
.
append
(
module
)
self
.
forward_hooks
[
f
"release module
{
name
}
parameters"
]
=
(
module
.
register_forward_hook
(
_post_forward
,
prepend
=
False
)
)
self
.
backward_pre_hooks
[
f
"all-gather module
{
name
}
parameters"
]
=
(
module
.
register_full_backward_pre_hook
(
_pre_backward
)
)
elif
not
self
.
ddp_config
.
keep_fp8_transpose_cache_when_using_custom_fsdp
:
self
.
forward_hooks
[
f
"remove module
{
name
}
fp8 transpose cache"
]
=
(
module
.
register_forward_hook
(
_release_module_fp8_transpose_cache
,
prepend
=
False
)
)
# Register the root pre-backward hook on every module that holds the full set of
# parameters. This handles special cases where the forward function of root_module
# is not called directly, but the forward function of an equivalent module is.
for
name
,
module
in
root_module
.
named_modules
():
if
len
(
list
(
module
.
parameters
()))
!=
len
(
list
(
root_module
.
parameters
())):
continue
self
.
backward_pre_hooks
[
f
"
{
name
}
_root_pre_backward"
]
=
(
module
.
register_full_backward_pre_hook
(
_root_pre_backward
)
)
self
.
_root_pre_backward_hook_handle
=
root_module
.
register_full_backward_pre_hook
(
_root_pre_backward
)
@
contextmanager
def
no_sync
(
self
):
"""
Context manager that turns off gradient synchronization.
Note that in the gradient-sharding modes, gradient synchronization still happens on every backward pass.
"""
# FIXME: Better handling of grads shard mode and no_sync in the training loop so that
# the code doesn't bog down developers.
self
.
is_last_microbatch
=
False
try
:
yield
finally
:
self
.
is_last_microbatch
=
True
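    # Illustrative usage sketch (not executed here): the gradient-accumulation
    # pattern no_sync() is meant for. `model` is a FullyShardedDataParallel
    # instance; `microbatches` and `loss_fn` are assumed to exist in the
    # surrounding training loop. Only the final microbatch triggers the delayed
    # gradient synchronization.
    #
    #     model.zero_grad_buffer()
    #     for batch in microbatches[:-1]:
    #         with model.no_sync():
    #             loss_fn(model(batch)).backward()
    #     loss_fn(model(microbatches[-1])).backward()
    #     model.finish_grad_sync()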
def
start_param_sync
(
self
,
*
unused
,
force_sync
:
bool
=
False
,
force_dispatch
:
bool
=
False
):
"""
Initiates param sync (all-gather) communication operations for all model parameters.
By default, when overlap_param_gather is set to True, dispatches asynchronous communication
calls; when overlap_param_gather is set to False, calls synchronous communication
ops. Can override this default behavior using flags below.
Args:
force_sync (bool, optional): force synchronous collective regardless of
other settings.
force_dispatch (bool, optional): force dispatch regardless of other settings.
"""
if
not
force_sync
and
self
.
ddp_config
.
overlap_param_gather
:
# All-gather the first bucket before the forward pass.
first_param
=
list
(
self
.
module
.
parameters
())[
0
]
self
.
all_gather_pipeline
.
all_gather_params
(
params
=
[
first_param
],
prefetch
=
False
)
else
:
self
.
all_gather_pipeline
.
reset
()
for
bucket_id
in
range
(
self
.
all_gather_pipeline
.
num_buckets
):
self
.
all_gather_pipeline
.
all_gather_bucket_and_set_items
(
bucket_id
=
bucket_id
,
async_op
=
True
)
group
=
self
.
param_and_grad_buffer
.
parameter_groups
[
bucket_id
]
if
group
.
model_weight_buffer
is
None
:
continue
if
group
.
model_weight_buffer
.
is_data_distributed
:
# If model weight is sharded, we wait for the all-gather to complete and
# then release the bucket immediately to save memory usage.
self
.
all_gather_pipeline
.
wait_bucket_ready
(
bucket_id
)
for
bucket_id
in
range
(
self
.
all_gather_pipeline
.
num_buckets
):
self
.
all_gather_pipeline
.
wait_bucket_ready
(
bucket_id
)
def
start_grad_sync
(
self
,
*
unused
):
"""
Initiates grad sync (all-reduce or reduce-scatter) communication operations
for all model gradients.
When overlap_grad_reduce is set to True, dispatches asynchronous communication
calls. When overlap_grad_reduce is set to False, calls synchronous
communication ops.
"""
if
not
self
.
ddp_config
.
overlap_grad_reduce
:
if
self
.
data_parallel_sharding_strategy
==
"no_shard"
:
self
.
param_and_grad_buffer
.
all_reduce_gradients
(
async_op
=
self
.
ddp_config
.
overlap_grad_reduce
)
else
:
self
.
param_and_grad_buffer
.
reduce_scatter_gradients
()
def
finish_grad_sync
(
self
):
"""
Finishes grad sync (all-reduce or reduce-scatter) communication operations
for all model gradients.
When overlap_grad_reduce is set to True, waits for asynchronous communication
calls to complete. When overlap_grad_reduce is set to False, calls synchronous
communication ops.
"""
if
self
.
ddp_config
.
overlap_grad_reduce
:
self
.
grad_reduce_pipeline
.
wait_for_previous_grad_reduce
(
0
)
self
.
grad_reduce_pipeline
.
reset
()
else
:
self
.
start_grad_sync
()
self
.
param_and_grad_buffer
.
update_main_grads
()
if
self
.
ddp_config
.
overlap_param_gather
:
self
.
all_gather_pipeline
.
reset
()
def
optimizer_named_parameters
(
self
)
->
List
[
Tuple
[
str
,
torch
.
Tensor
]]:
"""
Returns a list of tuples containing the main weights and their corresponding names
for mixed-precision training, to be used by the optimizer for updates.
Returns:
List[Tuple[str, torch.Tensor]]: A list of tuples, where each tuple
contains a main weight tensor and its corresponding name.
"""
return
self
.
param_and_grad_buffer
.
optimizer_named_parameters
def
scale_gradients
(
self
,
scaling_factor
:
float
):
"""Scale all gradients inside the buffers by `scaling_factor`."""
self
.
param_and_grad_buffer
.
scale_gradients
(
scaling_factor
)
def
zero_grad_buffer
(
self
):
"""
Zeros out all grad buffers. Needs to be called at the beginning of each
training iteration.
"""
for
param
in
self
.
module
.
parameters
():
if
param
.
requires_grad
:
param
.
grad_added_to_main_grad
=
False
self
.
param_and_grad_buffer
.
zero_grad
()
def
broadcast_params
(
self
):
"""
Syncs parameters across all DP ranks.
"""
for
param
in
self
.
module
.
parameters
():
is_expert_parallel
=
not
getattr
(
param
,
'allreduce'
,
True
)
if
is_expert_parallel
:
data_parallel_group
=
parallel_state
.
get_data_modulo_expert_parallel_group
(
with_context_parallel
=
True
)
else
:
data_parallel_group
=
parallel_state
.
get_data_parallel_group
(
with_context_parallel
=
True
)
torch
.
distributed
.
broadcast
(
param
.
data
,
src
=
torch
.
distributed
.
get_global_rank
(
data_parallel_group
,
0
),
group
=
data_parallel_group
,
)
def
load_state_dict
(
self
,
state_dict
,
strict
=
True
):
"""
Copies parameters and buffers from state_dict into the wrapped module and its
descendants. If strict is True, then the keys of state_dict must exactly match
the keys returned by this module’s state_dict() function.
"""
if
self
.
ddp_config
.
data_parallel_sharding_strategy
==
"optim_grads_params"
:
# make a copy of the state_dict to avoid modifying the input state_dict
state_dict
=
state_dict
.
copy
()
state_dict_extra_states
=
{}
for
key
in
list
(
state_dict
.
keys
()):
if
key
.
endswith
(
"_extra_state"
):
state_dict_extra_states
[
key
]
=
state_dict
[
key
]
del
state_dict
[
key
]
self
.
module
.
load_state_dict
(
state_dict_extra_states
,
strict
=
False
)
prefix
=
"module."
buffer
=
self
.
param_and_grad_buffer
for
param_groups
in
buffer
.
parameter_groups
:
wbuf
=
param_groups
.
model_weight_buffer
for
model_param
in
wbuf
.
params
:
if
is_float8tensor
(
model_param
):
fp8_meta
=
model_param
.
_fp8_meta
[
'scaling_fwd'
]
fp8_meta_index
=
model_param
.
_fp8_meta_index
model_param
.
_scale_inv
.
copy_
(
fp8_meta
.
scale_inv
[
fp8_meta_index
])
param_name
=
f
"
{
buffer
.
param_to_name
[
model_param
]
}
"
[
len
(
prefix
)
:]
if
param_name
in
state_dict
:
if
wbuf
and
wbuf
.
is_data_distributed
:
model_param
.
fully_shard_param_local_shard
.
data
.
copy_
(
state_dict
[
param_name
]
)
else
:
model_param
.
data
.
copy_
(
state_dict
[
param_name
])
del
state_dict
[
param_name
]
self
.
module
.
load_state_dict
(
state_dict
,
strict
=
False
)
return
self
.
module
.
load_state_dict
(
state_dict
,
strict
=
strict
)
class
RegisterFSDPBackwardFunction
(
torch
.
autograd
.
Function
):
"""
Register a backward function that will be called after the backward pass
of the model. This function is used to release the parameters after the
backward pass.
"""
@
staticmethod
def
forward
(
ctx
,
post_backward
,
*
inputs
:
torch
.
Tensor
):
"""
Forward pass of the RegisterFSDPBackwardFunction function.
"""
ctx
.
post_backward
=
post_backward
return
inputs
@
staticmethod
def
backward
(
ctx
,
*
grads
:
torch
.
Tensor
):
"""
Backward pass of the RegisterFSDPBackwardFunction function.
"""
ctx
.
post_backward
()
return
(
None
,)
+
grads
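The identity-forward / callback-backward trick above can be exercised in isolation. Below is a standalone sketch of the same idea (simplified and independent of Megatron; the class and variable names here are illustrative only):

```python
import torch


class PostBackwardHook(torch.autograd.Function):
    """Identity op whose backward pass runs a user-supplied callback."""

    @staticmethod
    def forward(ctx, callback, *inputs):
        ctx.callback = callback
        return inputs

    @staticmethod
    def backward(ctx, *grads):
        ctx.callback()
        # No gradient for the callback argument; pass input grads through unchanged.
        return (None,) + grads


x = torch.ones(4, requires_grad=True)
(y,) = PostBackwardHook.apply(lambda: print("post-backward callback fired"), x)
y.sum().backward()
print(x.grad)  # tensor([1., 1., 1., 1.])
```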
Megatron-LM/megatron/core/distributed/custom_fsdp/param_and_grad_buffer.py
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import
dataclasses
import
gc
import
inspect
import
logging
import
math
import
traceback
import
warnings
from
collections
import
namedtuple
from
contextlib
import
ExitStack
from
enum
import
Enum
from
typing
import
Any
,
List
,
Optional
,
Tuple
import
torch
from
torch.distributed
import
_coalescing_manager
from
megatron.core
import
parallel_state
from
megatron.core.distributed.distributed_data_parallel_config
import
DistributedDataParallelConfig
from
megatron.core.fp8_utils
import
is_float8tensor
,
modify_underlying_storage
,
quantize_param_shard
from
megatron.core.tensor_parallel
import
get_cuda_rng_tracker
from
megatron.core.utils
import
is_submodule
,
is_te_min_version
,
log_on_each_pipeline_stage
try
:
from
transformer_engine.pytorch
import
fp8_model_init
except
:
pass
try
:
from
transformer_engine.pytorch.module.base
import
TransformerEngineBaseModule
except
:
pass
logger
=
logging
.
getLogger
(
__name__
)
def
_p_assert
(
cond
:
Any
,
s
:
str
,
raise_assertion_error
:
bool
=
True
)
->
None
:
"""Alternate to ``assert`` when in the backward context to print the error
message ``s`` since otherwise, it is swallowed.
"""
if
not
cond
:
print
(
s
)
traceback
.
print_stack
()
if
raise_assertion_error
:
raise
AssertionError
(
s
)
def
_alloc_storage
(
tensor
:
torch
.
Tensor
,
size
:
torch
.
Size
)
->
None
:
"""
Allocate storage for ``tensor`` with the given size.
This is a no-op if the storage is already allocated with the requested size.
"""
with
torch
.
no_grad
():
if
not
torch
.
distributed
.
_functional_collectives
.
is_torchdynamo_compiling
():
already_allocated
=
tensor
.
_typed_storage
().
_size
()
==
size
.
numel
()
if
not
already_allocated
:
tensor_storage_size
=
tensor
.
_typed_storage
().
_size
()
_p_assert
(
tensor_storage_size
==
0
,
"Tensor storage should have been resized to be 0 but got PLACEHOLDEr"
,
)
tensor
.
_typed_storage
().
_resize_
(
size
.
numel
())
def
_free_storage
(
tensor
:
torch
.
Tensor
):
"""
Frees the underlying storage of ``tensor``.
This is a no-op if the storage is already freed.
"""
with
torch
.
no_grad
():
if
not
torch
.
distributed
.
_functional_collectives
.
is_torchdynamo_compiling
():
already_freed
=
tensor
.
_typed_storage
().
_size
()
==
0
if
not
already_freed
:
_p_assert
(
tensor
.
storage_offset
()
==
0
,
"Freeing a tensor's storage is unsafe when it is not the sole occupant
\n
"
f
"storage offset:
{
tensor
.
storage_offset
()
}
\n
"
f
"storage size:
{
tensor
.
_typed_storage
().
_size
()
}
\n
"
f
"tensor shape:
{
tensor
.
shape
}
"
,
)
tensor
.
_typed_storage
().
_resize_
(
0
)
TensorItemIndex
=
namedtuple
(
'TensorItemIndex'
,
[
'global_data_index'
,
'size'
,
'item_id'
,
'bucket_id'
,
'shape'
]
)
BucketIndex
=
namedtuple
(
'BucketIndex'
,
[
'bucket_id'
,
'global_data_index'
,
'size'
,
'items'
])
ShardBucketIndex
=
namedtuple
(
'ShardBucketIndex'
,
[
'bucket_id'
,
'global_data_index'
,
'local_data_index'
,
'bucket_data_index'
,
'size'
],
)
@
dataclasses
.
dataclass
class
BucketingPolicy
:
"""
A policy for bucketing in Fully Sharded Data Parallel (FSDP) training.
Attributes:
suggested_bucket_size (int): The suggested size of each bucket in num of elements.
fsdp_unit_modules (list): A list of module classes that are treated as a
single unit for FSDP bucketing.
data_parallel_sharding_strategy (str): The strategy used for sharding
data parallel modules.
Note:
This policy is used to configure the bucketing behavior in FSDP training.
"""
suggested_bucket_size
:
Optional
[
int
]
=
40_000_000
fsdp_unit_modules
:
List
[
torch
.
nn
.
Module
]
=
dataclasses
.
field
(
default_factory
=
list
)
data_parallel_sharding_strategy
:
str
=
'no_shard'
def
_pad
(
number_to_be_padded
:
int
,
divisor
:
int
)
->
int
:
return
int
(
math
.
ceil
(
number_to_be_padded
/
divisor
)
*
divisor
)
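# Illustrative note on the padding used below (hypothetical numbers): with
# data_parallel_world_size = 8, bucket boundaries are padded to a multiple of
# lcm(8, 128) = 128 elements, e.g. _pad(1000, 128) == 1024, so every rank's
# shard holds an equal number of elements (here 1024 // 8 == 128) and starts at
# a 256-byte aligned address for >= 16-bit dtypes.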
def
build_data_parallel_buffer_index
(
elements
:
List
[
torch
.
Size
],
data_parallel_rank
:
int
,
data_parallel_world_size
:
int
,
is_data_distributed
:
bool
,
ddp_config
:
DistributedDataParallelConfig
,
bucket_id
:
int
=
0
,
)
->
Tuple
[
int
,
List
[
tuple
],
List
[
tuple
],
List
[
tuple
]]:
"""
Assuming that all input tensors are laid out consecutively in a single global
buffer, compute the index range of every tensor, of the bucket, and of the
per-rank local shard of the bucket.
Args:
elements (List[torch.Size]): Shapes of the input tensors.
data_parallel_rank (int): Rank of the current process in the data parallel group.
data_parallel_world_size (int): World size of the data parallel group.
bucket_id (int, optional): The id of the bucket. Defaults to 0.
Returns:
The item index map, the bucket index and the shard bucket index.
"""
def
_pad_if_needed
(
data_index
:
int
)
->
int
:
"""
Pads data indices if using distributed optimizer (to ensure uniform sharding).
"""
if
ddp_config
.
data_parallel_sharding_strategy
!=
'no_shard'
:
# Workaround for TE bug causing cuBLAS to pick an incompatible algorithm.
# This also helps cuBLAS pick more efficient algorithms for GEMMs.
# We now ensure that all buckets start at a memory address that is 256-byte
# aligned (128 values since params and grads use >= 16-bit precision).
return
_pad
(
data_index
,
math
.
lcm
(
data_parallel_world_size
,
128
))
return
data_index
def
add_item
(
item_id
,
item
,
bucket
,
item_index_map
,
bucket_id
):
bucket
.
append
(
item
)
bucket_size
=
sum
([
it
.
numel
()
for
it
in
bucket
])
item_index_map
.
append
(
TensorItemIndex
(
data_index
+
bucket_size
-
item
.
numel
(),
item
.
numel
(),
item_id
=
item_id
,
bucket_id
=
bucket_id
,
shape
=
item
,
)
)
item_index_map
=
[]
bucket
=
[]
data_index
=
0
for
item_id
,
item
in
enumerate
(
elements
):
add_item
(
item_id
,
item
,
bucket
,
item_index_map
,
bucket_id
)
bucket_size
=
sum
([
it
.
numel
()
for
it
in
bucket
])
bucket_size
=
_pad_if_needed
(
bucket_size
)
bucket_index
=
BucketIndex
(
bucket_id
,
data_index
,
bucket_size
,
items
=
list
(
filter
(
lambda
x
:
x
.
bucket_id
==
bucket_id
,
item_index_map
)),
)
shard_size
=
bucket_index
.
size
//
data_parallel_world_size
bucket_data_index
=
shard_size
*
data_parallel_rank
global_data_index
=
bucket_index
.
global_data_index
+
bucket_data_index
if
is_data_distributed
:
shard_bucket_index
=
ShardBucketIndex
(
bucket_id
,
global_data_index
,
0
,
bucket_data_index
,
shard_size
)
else
:
shard_bucket_index
=
ShardBucketIndex
(
bucket_id
,
global_data_index
,
global_data_index
,
bucket_data_index
,
shard_size
)
return
item_index_map
,
bucket_index
,
shard_bucket_index
@
dataclasses
.
dataclass
class
Bucket
:
"""
A container for holding data in Fully Sharded Data Parallel (FSDP) training.
Attributes:
data (torch.Tensor): A tensor containing the data elements
grouped together in a bucket.
data_operation_event (Optional[torch.cuda.Event]): An optional CUDA event
used to synchronize data operations.
status (Any): An optional status object used to track the state of the bucket.
Note:
Buckets are used to optimize communication in FSDP training by
grouping small tensors together.
"""
data
:
torch
.
Tensor
data_operation_event
:
Optional
[
torch
.
cuda
.
Event
]
=
None
status
:
Any
=
None
class
TemporaryBucketAllocator
:
"""
A utility class for managing temporary buckets (buffers) used in FSDP
operations such as parameter unsharding and gradient reduction.
This allocator handles the dynamic allocation and deallocation of temporary memory buffers
needed during FSDP (Fully Sharded Data Parallel) operations, particularly for parameter
unsharding and gradient reduction. It helps optimize memory usage by allowing temporary
buckets to be released when no longer needed.
Key Features:
- Dynamic allocation of temporary buckets for FSDP operations
- Memory-efficient management of temporary buffers
- Support for both parameter unsharding and gradient reduction operations
- Automatic cleanup of unused buckets to save memory
Usage:
```python
# Create an allocator instance
allocator = TemporaryBucketAllocator()
# Allocate a temporary bucket
temp_bucket = allocator.allocate(bucket_id=0, size=1024, dtype=torch.float32, device="cuda")
# Use the temporary bucket for FSDP operations
# ... perform all-gather or reduce-scatter ...
# Free the bucket when done
allocator.free(bucket_id=0)
```
Note:
It's important to release temporary buckets after use to prevent memory leaks
and optimize memory usage during training.
"""
def
__init__
(
self
):
self
.
buckets
=
{}
def
allocate
(
self
,
bucket_id
:
int
,
size
:
int
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
)
->
Bucket
:
"""
allocate a temporary bucket.
"""
if
bucket_id
not
in
self
.
buckets
:
self
.
buckets
[
bucket_id
]
=
Bucket
(
data
=
torch
.
empty
(
size
,
dtype
=
dtype
,
device
=
device
))
return
self
.
buckets
[
bucket_id
]
def
free
(
self
,
bucket_id
:
int
):
"""
free a temporary bucket.
"""
if
bucket_id
in
self
.
buckets
:
_free_storage
(
self
.
buckets
[
bucket_id
].
data
)
del
self
.
buckets
[
bucket_id
]
class
StorageResizeBasedBucketAllocator
(
TemporaryBucketAllocator
):
"""
A specialized temporary bucket allocator that resizes the storage of temporary buckets
based on the required size.
"""
def
__init__
(
self
):
self
.
buckets
=
{}
# {bucket_id: Bucket}
def
allocate
(
self
,
bucket_id
:
int
,
size
:
int
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
)
->
Bucket
:
"""
allocate a temporary bucket.
"""
if
bucket_id
not
in
self
.
buckets
:
self
.
buckets
[
bucket_id
]
=
Bucket
(
data
=
torch
.
empty
(
size
,
dtype
=
dtype
,
device
=
device
))
bucket
=
self
.
buckets
[
bucket_id
]
_alloc_storage
(
bucket
.
data
,
torch
.
Size
([
size
]))
return
bucket
def
free
(
self
,
bucket_id
:
int
):
"""
free a temporary bucket.
"""
if
bucket_id
in
self
.
buckets
:
_free_storage
(
self
.
buckets
[
bucket_id
].
data
)
class
RotaryBucketAllocator
(
TemporaryBucketAllocator
):
"""A specialized temporary bucket allocator that implements a circular buffer recycling strategy
to minimize memory fragmentation in FSDP operations.
RotaryBucketAllocator extends TemporaryBucketAllocator by maintaining a limited pool of
pre-allocated buffers that are reused in a circular manner. This approach helps prevent
memory fragmentation that typically occurs with frequent allocation and deallocation of
temporary buffers during FSDP operations.
Key Features:
- Circular buffer recycling strategy for memory efficiency
- Reduced memory fragmentation compared to dynamic allocation
- Pre-allocated buffer pool for faster access
- Automatic buffer reuse without explicit deallocation
Usage:
```python
# Create a rotary allocator
allocator = RotaryBucketAllocator(name="gpt_parameters")
# Get a temporary buffer from the pool
temp_bucket = allocator.allocate(bucket_id=0, size=1024, dtype=torch.float32, device="cuda")
# Use the temporary bucket for FSDP operations
# ... perform all-gather or reduce-scatter ...
# Free the bucket when done, returning its buffer to the idle pool
allocator.free(bucket_id=0)
```
"""
def
__init__
(
self
,
name
:
str
):
self
.
name
=
name
self
.
num_global_buffer
=
0
self
.
idle_buffer
=
[]
# [buffer_id]
self
.
using_buffer
=
{}
# {bucket_id: buffer_id}
self
.
buckets
=
{}
def
allocate
(
self
,
bucket_id
:
int
,
size
:
int
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
)
->
Bucket
:
"""
allocate a temporary bucket.
"""
def
_get_global_buffer
(
buffer_id
:
int
):
return
parallel_state
.
get_global_memory_buffer
().
get_tensor
(
[
size
],
dtype
=
dtype
,
name
=
self
.
_get_gbuf_name
(
buffer_id
)
)
if
bucket_id
in
self
.
using_buffer
:
buffer_id
=
self
.
using_buffer
[
bucket_id
]
return
Bucket
(
data
=
_get_global_buffer
(
buffer_id
))
if
len
(
self
.
idle_buffer
)
==
0
:
# allocate new buffer
buffer_id
=
self
.
num_global_buffer
self
.
num_global_buffer
+=
1
self
.
idle_buffer
.
append
(
buffer_id
)
buffer_id
=
self
.
idle_buffer
.
pop
(
0
)
self
.
using_buffer
[
bucket_id
]
=
buffer_id
return
Bucket
(
data
=
_get_global_buffer
(
buffer_id
))
def
_get_gbuf_name
(
self
,
buffer_id
:
int
):
return
f
"
{
self
.
name
}
_
{
buffer_id
}
"
def
free
(
self
,
bucket_id
:
int
):
"""
free a temporary bucket.
"""
if
bucket_id
in
self
.
using_buffer
:
buffer_id
=
self
.
using_buffer
.
pop
(
bucket_id
)
self
.
idle_buffer
.
append
(
buffer_id
)
class
DataParallelBuffer
:
"""
A class that manages the data parallel buffer for Fully Sharded Data Parallel (FSDP) training.
"""
def
__init__
(
self
,
ddp_config
:
DistributedDataParallelConfig
,
params
:
List
[
torch
.
nn
.
Parameter
],
is_data_distributed
:
bool
,
bucket_id
:
int
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
device
:
Optional
[
torch
.
device
]
=
None
,
data_parallel_group
:
Optional
[
torch
.
distributed
.
ProcessGroup
]
=
None
,
temporary_bucket_allocator
:
Optional
[
TemporaryBucketAllocator
]
=
None
,
init_meta_only
:
bool
=
False
,
is_dtype_float8
:
bool
=
False
,
gradient_scaling_factor
:
Optional
[
float
]
=
None
,
)
->
None
:
self
.
ddp_config
=
ddp_config
self
.
params
=
params
_param_dtype
=
{
p
.
dtype
for
p
in
self
.
params
}
assert
len
(
_param_dtype
)
==
1
,
f
'params have different dtypes:
{
_param_dtype
}
'
self
.
is_data_distributed
=
is_data_distributed
self
.
bucket_id
=
bucket_id
self
.
dtype
=
dtype
if
dtype
else
next
(
iter
(
_param_dtype
))
self
.
device
=
device
self
.
data_parallel_group
=
data_parallel_group
self
.
dp_rank
=
torch
.
distributed
.
get_rank
(
group
=
self
.
data_parallel_group
)
self
.
dp_world_size
=
torch
.
distributed
.
get_world_size
(
group
=
self
.
data_parallel_group
)
self
.
temporary_bucket_allocator
=
(
temporary_bucket_allocator
if
temporary_bucket_allocator
else
TemporaryBucketAllocator
()
)
self
.
is_dtype_float8
=
is_dtype_float8
self
.
gradient_scaling_factor
=
gradient_scaling_factor
(
self
.
item_index_map
,
self
.
bucket_index
,
self
.
shard_bucket_index
)
=
(
build_data_parallel_buffer_index
(
[
p
.
shape
for
p
in
self
.
params
],
self
.
dp_rank
,
self
.
dp_world_size
,
is_data_distributed
,
ddp_config
,
bucket_id
=
bucket_id
,
)
)
self
.
data_size
=
(
self
.
bucket_index
.
size
if
not
is_data_distributed
else
self
.
shard_bucket_index
.
size
)
if
init_meta_only
:
self
.
data
=
None
else
:
self
.
data
=
torch
.
empty
(
self
.
data_size
,
dtype
=
self
.
dtype
,
device
=
device
)
self
.
param_idx
=
{
p
:
i
for
i
,
p
in
enumerate
(
self
.
params
)}
self
.
placeholder_bucket
=
None
self
.
placeholder_items
=
{}
def
fetch_bucket
(
self
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
and_allocate_params_data
:
bool
=
False
)
->
Bucket
:
"""
Fetch a communication buffer for data-parallel operations.
The size of the bucket is defined by the `DataParallelBuffer` instance.
If `and_allocate_params_data` is True, this method resets the parameter
data stored in the `DataParallelBuffer` instance.
Args:
dtype (Optional[torch.dtype], optional): The data type of the tensor
to fetch a buffer for. Defaults to None.
and_allocate_params_data (bool, optional): Whether to allocate and
reset parameter data. Defaults to False.
Returns:
Bucket: The communication buffer for the specified data type.
"""
if
dtype
is
None
:
dtype
=
self
.
dtype
bucket_index
=
self
.
bucket_index
if
not
self
.
is_data_distributed
and
dtype
==
self
.
dtype
:
bucket
=
Bucket
(
data
=
self
.
data
[
bucket_index
.
global_data_index
:
bucket_index
.
global_data_index
+
bucket_index
.
size
]
)
else
:
bucket
=
self
.
temporary_bucket_allocator
.
allocate
(
bucket_id
=
bucket_index
.
bucket_id
,
size
=
bucket_index
.
size
,
dtype
=
dtype
,
device
=
self
.
device
,
)
if
and_allocate_params_data
:
for
p
in
self
.
params
:
item_id
=
self
.
param_idx
[
p
]
if
is_float8tensor
(
p
):
p
.
_data
=
self
.
get_item_from_bucket
(
bucket
,
item_id
).
view
(
p
.
shape
)
else
:
p
.
data
=
self
.
get_item_from_bucket
(
bucket
,
item_id
).
view
(
p
.
shape
)
return
bucket
    def free_bucket_storage(self, and_free_params_data: bool = False):
        """
        Release the storage of a temporary communication bucket.

        If the bucket is temporary, this method frees its storage.
        If `and_free_params_data` is True, this method also releases the storage
        of the parameter data stored in the `DataParallelBuffer` instance.

        Args:
            and_free_params_data (bool, optional): Whether to also release the
                storage of the parameter data. Defaults to False.

        Returns:
            None
        """
        if not self.is_data_distributed:
            return

        self.temporary_bucket_allocator.free(self.bucket_index.bucket_id)
        if and_free_params_data:
            if self.placeholder_bucket is None:
                self.placeholder_bucket = Bucket(
                    data=torch.empty(self.bucket_index.size, dtype=self.dtype, device=self.device)
                )
                for p in self.params:
                    item_id = self.param_idx[p]
                    self.placeholder_items[item_id] = self.get_item_from_bucket(
                        self.placeholder_bucket, item_id
                    ).view(p.shape)
                _free_storage(self.placeholder_bucket.data)
            for p in self.params:
                item_id = self.param_idx[p]
                if is_float8tensor(p):
                    p._data = self.placeholder_items[item_id]
                else:
                    p.data = self.placeholder_items[item_id]
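    # Illustrative usage sketch (assumed, not part of the original module): a typical
    # unshard/reshard cycle fetches a temporary bucket so that parameter `.data` views
    # the gathered storage, and frees it again once the compute that needed the full
    # parameters is done.
    #
    #   >>> bucket = dp_buffer.fetch_bucket(and_allocate_params_data=True)
    #   >>> ...  # all-gather into bucket.data, run forward/backward
    #   >>> dp_buffer.free_bucket_storage(and_free_params_data=True)
    #
    # `dp_buffer` here is a hypothetical `DataParallelBuffer` instance; the real call
    # sites live in `AllGatherPipeline` and `GradReducePipeline` later in this file.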
    def _get_item_slice_in_shard(self, item_id: int) -> Tuple[int, int]:
        item_index = self.item_index_map[item_id]
        shard_bucket_index = self.shard_bucket_index

        item_global_start = item_index.global_data_index
        item_global_end = item_index.global_data_index + item_index.size
        shard_bucket_start = shard_bucket_index.global_data_index
        shard_bucket_end = shard_bucket_index.global_data_index + shard_bucket_index.size

        if item_global_start > shard_bucket_end or item_global_end < shard_bucket_start:
            return (0, 0)

        start = max(item_global_start, shard_bucket_start) - item_global_start
        end = min(item_global_end, shard_bucket_end) - item_global_start

        return (start, end)

    # pylint: disable=missing-function-docstring
    def locate_item_in_global_item(self, item_id: int) -> Tuple[int, int]:
        item_index = self.item_index_map[item_id]
        if not self.is_data_distributed:
            return (0, item_index.size)

        slice_start, slice_end = self._get_item_local_shard_index(item_id)
        if slice_start == slice_end:
            return (0, 0)

        local_shard_index_to_global_index_offset = (
            self.shard_bucket_index.global_data_index - self.shard_bucket_index.local_data_index
        )
        slice_start += local_shard_index_to_global_index_offset
        slice_end += local_shard_index_to_global_index_offset
        return (
            slice_start - item_index.global_data_index,
            slice_end - item_index.global_data_index,
        )

    def _get_item_local_shard_index(self, item_id: int) -> Tuple[int, int]:
        slice_start, slice_end = self._get_item_slice_in_shard(item_id)
        if slice_start == slice_end:
            return (0, 0)

        item_index = self.item_index_map[item_id]
        shard_bucket_index = self.shard_bucket_index
        offset = (
            item_index.global_data_index
            - shard_bucket_index.global_data_index
            + shard_bucket_index.local_data_index
        )
        return (offset + slice_start, offset + slice_end)

    def _get_item_local_index(self, item_id: int) -> Tuple[int, int]:
        if not self.is_data_distributed:
            item_index = self.item_index_map[item_id]
            return (item_index.global_data_index, item_index.global_data_index + item_index.size)

        return self._get_item_local_shard_index(item_id)
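    # Worked example (illustrative, with made-up indices): suppose an item occupies
    # global indices [100, 300) of the bucket and this rank's shard covers global
    # indices [250, 500). `_get_item_slice_in_shard` returns the overlap expressed
    # relative to the flattened item:
    #
    #   >>> item_start, item_end = 100, 300
    #   >>> shard_start, shard_end = 250, 500
    #   >>> start = max(item_start, shard_start) - item_start   # 150
    #   >>> end = min(item_end, shard_end) - item_start         # 200
    #   >>> (start, end)
    #   (150, 200)
    #
    # i.e. elements 150..199 of the flattened item live in this rank's shard.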
    def set_item(self, item_id: int, item_data: torch.Tensor) -> None:
        """
        Update a tensor item managed by the `DataParallelBuffer` instance.

        The storage of the item is mapped to the communication bucket.
        This method updates the item data and ensures consistency with the bucket.

        Args:
            item_id (int): The ID of the tensor item to update.
            item_data (torch.Tensor): The new data for the tensor item.

        Returns:
            None
        """
        if self.is_data_distributed:
            slice_start, slice_end = self._get_item_slice_in_shard(item_id)
            item_data = item_data.flatten()[slice_start:slice_end]
        local_index_start, local_index_end = self._get_item_local_index(item_id)
        shard = self.data[local_index_start:local_index_end]
        if shard.numel() > 0:
            shard.data.copy_(item_data.flatten())

    def get_item(self, item_id: int, only_shard: bool = False) -> torch.Tensor:
        """
        Retrieve a tensor item managed by the `DataParallelBuffer` instance.

        The storage of the item is mapped to the communication bucket.
        If `only_shard` is True, returns only the shard of the item corresponding
        to the current process. Otherwise, returns the entire item.

        Args:
            item_id (int): The ID of the tensor item to retrieve.
            only_shard (bool, optional): Whether to return only the shard of the
                item. Defaults to False.

        Returns:
            torch.Tensor: The retrieved tensor item.
        """
        if only_shard:
            start, end = self._get_item_local_shard_index(item_id)
        else:
            start, end = self._get_item_local_index(item_id)

        return self.data[start:end]
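    # Illustrative round-trip sketch (assumed usage): `set_item` copies (the local
    # shard of) a parameter into the flat buffer, and `get_item(..., only_shard=True)`
    # returns that shard as a view on `self.data`.
    #
    #   >>> item_id = dp_buffer.param_idx[param]        # hypothetical param / buffer
    #   >>> dp_buffer.set_item(item_id, param.data)
    #   >>> shard = dp_buffer.get_item(item_id, only_shard=True)
    #   >>> shard.numel() <= param.numel()
    #   True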
    def get_item_from_bucket(self, bucket: Bucket, item_id: int):
        """get item from bucket."""
        item_index = self.item_index_map[item_id]
        bucket_index = self.bucket_index
        start_index = item_index.global_data_index - bucket_index.global_data_index
        end_index = start_index + item_index.size
        item = bucket.data[start_index:end_index]
        return item

    def get_shard_from_bucket(self, bucket: Bucket):
        """Get the local sharding of the bucket."""
        shard_bucket_index = self.shard_bucket_index
        offset = shard_bucket_index.bucket_data_index
        shard_size = shard_bucket_index.size
        shard = bucket.data[offset : offset + shard_size]
        return shard

    def get_shard_from_local_buffer(self) -> torch.Tensor:
        """Get the local sharding of the bucket."""
        index = self.shard_bucket_index
        return self.data[index.local_data_index : index.local_data_index + index.size]

@dataclasses.dataclass
class ParameterGroup:
    """
    A group of model parameters with associated metadata for data-parallel training.

    This dataclass encapsulates a list of PyTorch parameters and additional information
    necessary for managing data-parallel operations, such as data type, gradient requirements,
    and buffer assignments.
    """

    params: List[torch.nn.Parameter]
    dtype: Optional[torch.dtype] = None
    is_expert_param: bool = False
    requires_grad: Optional[bool] = None
    fsdp_unit_id: Optional[int] = None
    data_parallel_world_size: Optional[int] = None
    model_weight_buffer: Optional[DataParallelBuffer] = None
    main_weight_buffer: Optional[DataParallelBuffer] = None
    main_grad_buffer: Optional[DataParallelBuffer] = None
def
_get_parameter_groups
(
module
:
torch
.
nn
.
Module
,
policy
:
BucketingPolicy
,
meta_device_init_fp8_params
:
dict
,
bucket_group_by_fsdp_unit
:
bool
=
True
,
):
"""
Get the parameter group for the given module and parameters.
"""
param_to_name
=
{
p
:
name
for
name
,
p
in
module
.
named_parameters
()}
fsdp_units
=
[]
if
policy
.
fsdp_unit_modules
:
param_to_id
=
{}
for
i
,
p
in
enumerate
(
module
.
parameters
()):
param_to_id
[
p
]
=
i
fsdp_modules
=
[]
for
m
in
module
.
modules
():
# Skip nested FSDP module.
if
any
(
is_submodule
(
module
,
fsdp_module
)
for
fsdp_module
in
fsdp_modules
):
continue
if
isinstance
(
m
,
tuple
(
policy
.
fsdp_unit_modules
)):
fsdp_units
.
append
([
param_to_name
[
p
]
for
p
in
m
.
parameters
()])
fsdp_modules
.
append
(
m
)
def
_does_param_require_new_bucket
(
param
):
"""
Split shared embedding parameters into separate bucket if using distributed
optimizer that makes use of reduce-scatters instead of all-reduces.
This ensures that the first and last pipeline stage partition optimizer state
for the shared embedding parameters the same way across DP replicas, allowing
the DP reduce-scatter to be before the embedding all-reduce.
"""
return
(
getattr
(
param
,
"shared_embedding"
,
False
)
and
policy
.
data_parallel_sharding_strategy
!=
"no_shard"
)
is_expert_parameter
=
lambda
p
:
not
getattr
(
p
,
'allreduce'
,
True
)
# Step 1: Group the parameters according to their execution order and attributes.
parameter_groups
=
[]
for
name
,
param
in
module
.
named_parameters
():
param_attrs
=
dict
(
dtype
=
(
"float8"
if
is_float8tensor
(
param
)
or
meta_device_init_fp8_params
.
get
(
name
,
False
)
else
param
.
dtype
),
is_expert_param
=
is_expert_parameter
(
param
),
requires_grad
=
param
.
requires_grad
,
fsdp_unit_id
=
None
,
)
for
fsdp_unit_id
,
fsdp_unit
in
enumerate
(
fsdp_units
):
if
name
in
fsdp_unit
:
param_attrs
[
"fsdp_unit_id"
]
=
fsdp_unit_id
break
found_group
=
False
for
param_group
in
parameter_groups
:
group_attrs
=
{
key
:
value
for
key
,
value
in
param_group
.
__dict__
.
items
()
if
key
in
param_attrs
}
if
group_attrs
==
param_attrs
:
param_group
.
params
.
append
(
param
)
found_group
=
True
break
if
not
found_group
:
parameter_groups
.
append
(
ParameterGroup
([
param
],
**
param_attrs
))
# Step 2: Bucket the parameters based on the guide bucket size.
suggested_bucket_size
=
policy
.
suggested_bucket_size
bucket_groups
=
[]
for
group
in
parameter_groups
:
bucket
=
[]
basic_attrs
=
{
key
:
value
for
key
,
value
in
group
.
__dict__
.
items
()
if
key
in
[
'dtype'
,
'is_expert_param'
,
'requires_grad'
,
'fsdp_unit_id'
]
}
for
param
in
group
.
params
:
if
_does_param_require_new_bucket
(
param
):
if
len
(
bucket
)
>
0
:
bucket_groups
.
append
(
ParameterGroup
(
bucket
,
**
basic_attrs
))
bucket_groups
.
append
(
ParameterGroup
([
param
],
**
basic_attrs
))
bucket
=
[]
continue
bucket
.
append
(
param
)
if
(
group
.
fsdp_unit_id
is
None
and
suggested_bucket_size
and
sum
([
p
.
numel
()
for
p
in
bucket
])
>=
suggested_bucket_size
):
bucket_groups
.
append
(
ParameterGroup
(
bucket
,
**
basic_attrs
))
bucket
=
[]
continue
if
bucket
:
bucket_groups
.
append
(
ParameterGroup
(
bucket
,
**
basic_attrs
))
param_to_param_group
=
{}
for
group_id
,
group
in
enumerate
(
bucket_groups
):
for
param
in
group
.
params
:
param_to_param_group
[
param
]
=
group_id
# Generate the groups of collective buckets, where each group aggregates
# the collectives per FSDP unit. This improves performance by reducing
# the number of collective calls and increasing per-collective efficiency.
#
# Set default aggregate buckets of bucket.
bucket_group_of_bucket
=
{}
for
bucket_id
in
range
(
len
(
bucket_groups
)):
bucket_group_of_bucket
[
bucket_id
]
=
[
bucket_id
]
# Set aggregate buckets by FSDP units.
if
bucket_group_by_fsdp_unit
:
bucket_group_map
=
{}
for
bucket_id
,
param_group
in
enumerate
(
bucket_groups
):
if
param_group
.
fsdp_unit_id
is
None
:
continue
id
=
(
param_group
.
fsdp_unit_id
,
param_group
.
is_expert_param
)
if
id
not
in
bucket_group_map
:
bucket_group_map
[
id
]
=
[]
bucket_group_map
[
id
].
append
(
bucket_id
)
for
bucket_group
in
bucket_group_map
.
values
():
for
bucket_id
in
bucket_group
:
bucket_group_of_bucket
[
bucket_id
]
=
bucket_group
return
(
bucket_groups
,
param_to_param_group
,
bucket_group_of_bucket
)
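

# The sketch below is an illustration only (it is not used by this module) of the
# size-based bucketing rule applied in Step 2 of `_get_parameter_groups`: parameters
# are appended to the current bucket until their total element count reaches the
# suggested bucket size, at which point a new bucket is started. FSDP-unit handling,
# the shared-embedding special case and attribute matching are omitted here.
def _sketch_bucket_by_size(numels, suggested_bucket_size):
    """Group a list of parameter sizes into buckets of roughly the suggested size."""
    buckets, current = [], []
    for numel in numels:
        current.append(numel)
        if suggested_bucket_size and sum(current) >= suggested_bucket_size:
            buckets.append(current)
            current = []
    if current:
        buckets.append(current)
    return buckets


# Example: _sketch_bucket_by_size([10, 20, 30, 40], 50) -> [[10, 20, 30], [40]]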
class
ParamAndGradBuffer
:
"""A class that manages parameter grouping, buffer allocation, and
communication operations for data-parallel distributed training.
This class provides functionality to:
1. Group parameters based on their data types and communication group sizes
2. Create contiguous buffers for model weights, gradients, and high-precision
main weights
3. Handle parameter unsharding, gradient reduction, and weight
synchronization operations
Key Features:
- Efficient parameter grouping based on data types and communication patterns
- Memory-efficient contiguous buffer allocation
- Support for mixed-precision training with main weights
- Distributed operations including parameters all-gather and gradients
reduce-scatter/all-reduce
- Synchronized weight updates between model and main weights
Note:
This class is designed for distributed training scenarios where efficient
parameter management and communication are crucial for performance.
Args:
ddp_config (DistributedDataParallelConfig): The distributed data parallel
configuration.
module (torch.nn.Module): The module whose parameters are to be grouped
and flatten.
bucketing_policy (BucketingPolicy): The bucketing policy.
data_parallel_group (torch.distributed.ProcessGroup): The data parallel group.
expert_data_parallel_group (Optional[torch.distributed.ProcessGroup]):
The expert data parallel group.
preserve_fp32_weights (bool): Whether to preserve FP32 weights.
grad_reduce_in_fp32 (bool): Whether to reduce gradients in FP32.
gradient_scaling_factor (Optional[float]): The gradient scaling factor.
expert_gradient_scaling_factor (Optional[float]): The expert gradient
scaling factor.
device (torch.device): The parameter and gradient buffer device.
only_create_grad_buffer_and_main_weight_buffer_for_param_requires_grad (bool):
Whether to only create the gradient buffer and main weight buffer
for parameters that require gradients. Default is True.
"""
def
__init__
(
self
,
ddp_config
:
DistributedDataParallelConfig
,
module
:
torch
.
nn
.
Module
,
bucketing_policy
:
BucketingPolicy
,
data_parallel_group
:
torch
.
distributed
.
ProcessGroup
,
expert_data_parallel_group
:
Optional
[
torch
.
distributed
.
ProcessGroup
]
=
None
,
preserve_fp32_weights
:
bool
=
True
,
grad_reduce_in_fp32
:
bool
=
True
,
gradient_scaling_factor
:
Optional
[
float
]
=
None
,
expert_gradient_scaling_factor
:
Optional
[
float
]
=
None
,
device
:
torch
.
device
=
torch
.
device
(
'cuda'
),
only_create_grad_buffer_and_main_weight_buffer_for_param_requires_grad
:
bool
=
True
,
reset_parameters_for_meta_device_init_module
:
bool
=
False
,
):
self
.
ddp_config
=
ddp_config
self
.
module
=
module
self
.
bucketing_policy
=
bucketing_policy
self
.
param_to_name
=
{
p
:
name
for
name
,
p
in
self
.
module
.
named_parameters
()}
self
.
preserve_fp32_weights
=
preserve_fp32_weights
self
.
grad_reduce_in_fp32
=
grad_reduce_in_fp32
self
.
data_parallel_group
=
data_parallel_group
self
.
expert_data_parallel_group
=
expert_data_parallel_group
self
.
params
=
list
(
module
.
parameters
())
self
.
gradient_scaling_factor
=
gradient_scaling_factor
self
.
expert_gradient_scaling_factor
=
expert_gradient_scaling_factor
self
.
device
=
device
self
.
only_create_grad_buffer_and_main_weight_buffer_for_param_requires_grad
=
(
only_create_grad_buffer_and_main_weight_buffer_for_param_requires_grad
)
self
.
reset_parameters_for_meta_device_init_module
=
(
reset_parameters_for_meta_device_init_module
)
# Mark fp8 param.
meta_device_init_fp8_params
=
{}
if
reset_parameters_for_meta_device_init_module
:
for
m
in
module
.
modules
():
if
not
isinstance
(
m
,
TransformerEngineBaseModule
):
continue
for
name
,
param
in
m
.
named_parameters
(
recurse
=
False
):
# The fp8 param initialized from the meta device may NOT be
# an fp8 tensor, according to the internal logic of the TE
# to determine whether this parameter is fp8 or not.
fp8_meta_index
=
m
.
param_init_meta
[
name
].
fp8_meta_index
if
m
.
primary_weights_in_fp8
and
fp8_meta_index
is
not
None
:
meta_device_init_fp8_params
[
self
.
param_to_name
[
param
]]
=
True
# Get the parameter groups.
(
self
.
parameter_groups
,
self
.
param_to_param_group
,
self
.
bucket_group_of_bucket
)
=
(
_get_parameter_groups
(
module
,
bucketing_policy
,
meta_device_init_fp8_params
)
)
self
.
_init_each_parameter_group_buffers
(
meta_device_init_fp8_params
)
# Initialize the optimizer named parameters.
self
.
optimizer_named_parameters
=
self
.
_init_optimizer_named_parameters
()
self
.
_log_parameter_groups
()
def
_log_parameter_groups
(
self
):
"""
Log the parameter groups for all pipeline stages.
"""
# Log buckets for all PP stages.
if
(
parallel_state
.
get_data_parallel_rank
(
with_context_parallel
=
True
)
==
0
and
parallel_state
.
get_tensor_model_parallel_rank
()
==
0
):
bucket_groups
=
self
.
parameter_groups
param_to_name
=
self
.
param_to_name
log_strs
=
[]
log_strs
.
append
(
f
'Number of parameter groups for FSDP:
{
len
(
bucket_groups
)
}
'
)
for
index
,
group
in
enumerate
(
bucket_groups
):
numel
=
0
for
param
in
group
.
params
:
numel
+=
param
.
numel
()
log_strs
.
append
(
f
"Params for group
{
index
+
1
}
(
{
numel
}
elements, dtype:
{
group
.
dtype
}
, "
f
"fsdp_unit_id:
{
group
.
fsdp_unit_id
}
, "
f
"has_weight_buffer:
{
group
.
model_weight_buffer
is
not
None
}
, "
f
"has_grad_buffer:
{
group
.
main_grad_buffer
is
not
None
}
, "
f
"has_main_weight_buffer:
{
group
.
main_weight_buffer
is
not
None
}
):"
)
for
param
in
group
.
params
:
log_strs
.
append
(
f
'
\t
{
param_to_name
[
param
]
}
'
)
log_on_each_pipeline_stage
(
logger
,
logging
.
INFO
,
'
\n
'
.
join
(
log_strs
))
def
_init_each_parameter_group_buffers
(
self
,
meta_device_init_fp8_params
):
"""
Initialize the buffers for each parameter group.
"""
data_parallel_sharding_strategy
=
self
.
ddp_config
.
data_parallel_sharding_strategy
if
data_parallel_sharding_strategy
==
'no_shard'
:
is_model_weight_buffer_distributed
=
False
is_main_weight_buffer_distributed
=
False
is_grad_buffer_distributed
=
False
elif
data_parallel_sharding_strategy
==
'optim'
:
is_model_weight_buffer_distributed
=
False
is_main_weight_buffer_distributed
=
True
is_grad_buffer_distributed
=
False
elif
data_parallel_sharding_strategy
==
'optim_grads'
:
is_model_weight_buffer_distributed
=
False
is_main_weight_buffer_distributed
=
True
is_grad_buffer_distributed
=
True
elif
data_parallel_sharding_strategy
==
'optim_grads_params'
:
is_model_weight_buffer_distributed
=
True
is_main_weight_buffer_distributed
=
True
is_grad_buffer_distributed
=
True
else
:
raise
ValueError
(
f
'Invalid data_parallel_sharding_strategy:
{
data_parallel_sharding_strategy
}
'
)
self
.
memory_allocator_for_model_weight_buffer
=
StorageResizeBasedBucketAllocator
()
self
.
buffer_all_in_one
=
True
preserve_fp32_weights
=
self
.
preserve_fp32_weights
grad_reduce_in_fp32
=
self
.
grad_reduce_in_fp32
buffer_size
=
{
torch
.
float32
:
0
,
torch
.
float16
:
0
,
torch
.
bfloat16
:
0
,
"float8"
:
0
}
for
group_id
,
group
in
enumerate
(
self
.
parameter_groups
):
dp_group
=
(
self
.
data_parallel_group
if
not
group
.
is_expert_param
else
self
.
expert_data_parallel_group
)
group
.
data_parallel_world_size
=
torch
.
distributed
.
get_world_size
(
group
=
dp_group
)
gradient_scaling_factor
=
(
self
.
gradient_scaling_factor
if
not
group
.
is_expert_param
else
self
.
expert_gradient_scaling_factor
)
one_param
=
group
.
params
[
0
]
is_dtype_float8
=
is_float8tensor
(
one_param
)
or
meta_device_init_fp8_params
.
get
(
self
.
param_to_name
[
one_param
],
False
)
if
is_dtype_float8
:
param_dtype
=
torch
.
uint8
grad_dtype
=
torch
.
bfloat16
else
:
param_dtype
=
group
.
params
[
0
].
dtype
grad_dtype
=
param_dtype
should_create_grad_buffer_or_main_weight_buffer
=
(
not
self
.
only_create_grad_buffer_and_main_weight_buffer_for_param_requires_grad
or
group
.
requires_grad
)
# Initialize the model weight buffer.
if
data_parallel_sharding_strategy
!=
'no_shard'
:
group
.
model_weight_buffer
=
DataParallelBuffer
(
self
.
ddp_config
,
group
.
params
,
is_data_distributed
=
is_model_weight_buffer_distributed
and
group
.
data_parallel_world_size
>
1
,
dtype
=
param_dtype
,
device
=
self
.
device
,
data_parallel_group
=
dp_group
,
init_meta_only
=
True
,
is_dtype_float8
=
is_dtype_float8
,
temporary_bucket_allocator
=
self
.
memory_allocator_for_model_weight_buffer
,
bucket_id
=
group_id
,
)
# Initialize the main weight buffer.
if
should_create_grad_buffer_or_main_weight_buffer
and
preserve_fp32_weights
:
group
.
main_weight_buffer
=
DataParallelBuffer
(
self
.
ddp_config
,
group
.
params
,
is_data_distributed
=
is_main_weight_buffer_distributed
and
group
.
data_parallel_world_size
>
1
,
dtype
=
torch
.
float32
,
device
=
self
.
device
,
data_parallel_group
=
dp_group
,
init_meta_only
=
True
,
bucket_id
=
group_id
,
)
# Initialize the main grad buffer.
if
should_create_grad_buffer_or_main_weight_buffer
:
group
.
main_grad_buffer
=
DataParallelBuffer
(
self
.
ddp_config
,
group
.
params
,
is_data_distributed
=
is_grad_buffer_distributed
and
group
.
data_parallel_world_size
>
1
,
dtype
=
torch
.
float32
if
grad_reduce_in_fp32
else
grad_dtype
,
device
=
self
.
device
,
data_parallel_group
=
dp_group
,
init_meta_only
=
True
,
is_dtype_float8
=
not
grad_reduce_in_fp32
and
grad_dtype
is
torch
.
uint8
,
gradient_scaling_factor
=
gradient_scaling_factor
,
bucket_id
=
group_id
,
)
if
grad_reduce_in_fp32
:
buffer_size
[
torch
.
float32
]
+=
group
.
main_grad_buffer
.
data_size
elif
group
.
main_grad_buffer
.
is_dtype_float8
:
buffer_size
[
"float8"
]
+=
group
.
main_grad_buffer
.
data_size
else
:
buffer_size
[
group
.
main_grad_buffer
.
dtype
]
+=
group
.
main_grad_buffer
.
data_size
reset_context_args
=
{
"init_param_with_fp8"
:
self
.
ddp_config
.
fp8_param_gather
}
module_reset_flag
=
{}
if
self
.
reset_parameters_for_meta_device_init_module
:
self
.
param_to_direct_module
=
{}
for
name
,
m
in
self
.
module
.
named_modules
():
for
p
in
m
.
parameters
(
recurse
=
False
):
self
.
param_to_direct_module
[
p
]
=
(
name
,
m
)
meta_params_numel
=
0
cuda_params_numel
=
0
cpu_params_numel
=
0
for
group
in
self
.
parameter_groups
:
for
p
in
group
.
params
:
if
p
.
is_meta
:
meta_params_numel
+=
p
.
numel
()
elif
p
.
device
.
type
==
'cuda'
:
cuda_params_numel
+=
p
.
numel
()
else
:
cpu_params_numel
+=
p
.
numel
()
log_str
=
(
f
"Meta params numel:
{
meta_params_numel
/
1_000_000
:.
2
f
}
M, "
f
"CUDA params numel:
{
cuda_params_numel
/
1_000_000
:.
2
f
}
M, "
f
"CPU params numel:
{
cpu_params_numel
/
1_000_000
:.
2
f
}
M"
)
log_on_each_pipeline_stage
(
logger
,
logging
.
INFO
,
log_str
)
# Initialize the model weight buffer data of each parameter group.
for
group
in
self
.
parameter_groups
:
wbuf
=
group
.
model_weight_buffer
if
wbuf
:
wbuf
.
data
=
torch
.
empty
(
wbuf
.
data_size
,
dtype
=
wbuf
.
dtype
,
device
=
self
.
device
)
bucket
=
wbuf
.
fetch_bucket
()
mbuf
=
group
.
main_weight_buffer
if
mbuf
:
mbuf
.
data
=
torch
.
empty
(
mbuf
.
data_size
,
dtype
=
mbuf
.
dtype
,
device
=
self
.
device
)
for
item_id
,
p
in
enumerate
(
group
.
params
):
if
wbuf
:
if
self
.
reset_parameters_for_meta_device_init_module
and
p
.
is_meta
:
m_name
,
m
=
self
.
param_to_direct_module
[
p
]
if
not
module_reset_flag
.
get
(
m_name
,
False
)
and
hasattr
(
m
,
"reset_parameters"
):
old_params
=
list
(
m
.
parameters
(
recurse
=
False
))
# If the GPU memory over threshold, empty cache to leave
# some memory for initialization of the model on the
# CUDA device.
if
check_gpu_memory
(
threshold
=
0.5
):
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
m
.
to_empty
(
device
=
self
.
device
,
recurse
=
False
)
if
is_te_min_version
(
"0.9.0"
)
and
not
isinstance
(
m
,
TransformerEngineBaseModule
):
reset_context_args
[
"with_cuda_rng_tracker"
]
=
True
with
ResetParametersContext
(
**
reset_context_args
):
m
.
reset_parameters
()
module_reset_flag
[
m_name
]
=
True
new_params
=
list
(
m
.
parameters
(
recurse
=
False
))
self
.
_reset_parameters
(
old_params
,
new_params
)
p
=
group
.
params
[
item_id
]
# After resetting parameters, delete fp8 transpose cache
# if we do not need keep cache.
if
not
self
.
ddp_config
.
keep_fp8_transpose_cache_when_using_custom_fsdp
:
for
_param
in
m
.
parameters
(
recurse
=
False
):
if
is_float8tensor
(
_param
):
_param
.
_transpose_invalid
=
True
_param
.
_transpose
=
None
assert
not
p
.
is_meta
,
(
self
.
param_to_name
[
p
],
module_reset_flag
)
wbuf
.
set_item
(
item_id
,
p
.
data
)
# reset the parameter data to the buffer
new_param_data
=
wbuf
.
get_item_from_bucket
(
bucket
,
item_id
).
view
(
p
.
shape
)
if
is_float8tensor
(
p
):
modify_underlying_storage
(
p
,
new_param_data
)
else
:
old_param_data
=
p
.
data
p
.
data
=
new_param_data
assert
old_param_data
.
_base
is
None
p
.
data
.
detach
().
copy_
(
old_param_data
)
del
old_param_data
if
mbuf
:
if
hasattr
(
p
,
'get_high_precision_init_val'
):
mbuf
.
set_item
(
item_id
,
p
.
get_high_precision_init_val
())
p
.
clear_high_precision_init_val
()
else
:
mbuf
.
set_item
(
item_id
,
p
)
if
wbuf
and
wbuf
.
is_data_distributed
:
"""
When MCore Custom FSDP `optim_grads_params` is enabled,
it is necessary to save the tensor local shard. This local shard is
accessible through the `fully_shard_param_local_shard`
attribute of the tensor.
This attribute contains the local shard of the fully
sharded parameter, which is essential for correctly
saving and loading the model state when using
`optim_grads_params` with FSDP.
Example:
>>> # Assuming `tensor` is a fully sharded parameter
>>> local_shard = tensor.fully_shard_param_local_shard
>>> # Save the local shard as needed
"""
local_shard
=
wbuf
.
get_item
(
item_id
,
only_shard
=
True
)
local_shard
.
fsdp_shard_orig_param
=
p
p
.
fully_shard_param_local_shard
=
local_shard
p
.
fully_shard_param_local_index
=
wbuf
.
locate_item_in_global_item
(
item_id
)
def
disable_shard_param_to_function
(
*
unused
):
"""Prevents users from accessing the 'to' operation
on parameters after sharding.
This restriction helps maintain data integrity and
proper sharding behavior by disabling direct 'to'
device/dtype operations on sharded parameters.
"""
raise
RuntimeError
(
"Your model is wrapped by MCore Custom FSDP. All "
"parameter dtypes and devices must be set before FSDP "
"wrapping. After FSDP wrapping, parameter storage "
"is sharded and you cannot modify parameter "
"dtypes or devices."
)
setattr
(
p
,
'to'
,
disable_shard_param_to_function
)
def
disable_shard_param_cpu_function
(
*
unused
):
warnings
.
warn
(
"The parameters are sharded by custom fsdp, "
"and no actual cpu operation is performed."
)
return
torch
.
empty
([],
device
=
'cpu'
)
setattr
(
p
,
'cpu'
,
disable_shard_param_cpu_function
)
if
wbuf
and
wbuf
.
is_data_distributed
:
wbuf
.
free_bucket_storage
()
# Allocate the main_weight buffer and main_grad buffer data in one buffer.
if
self
.
buffer_all_in_one
:
self
.
buffer
=
{
torch
.
float32
:
torch
.
empty
(
buffer_size
[
torch
.
float32
],
dtype
=
torch
.
float32
,
device
=
self
.
device
),
torch
.
float16
:
torch
.
empty
(
buffer_size
[
torch
.
float16
],
dtype
=
torch
.
float16
,
device
=
self
.
device
),
torch
.
bfloat16
:
torch
.
empty
(
buffer_size
[
torch
.
bfloat16
],
dtype
=
torch
.
bfloat16
,
device
=
self
.
device
),
"float8"
:
torch
.
empty
(
buffer_size
[
"float8"
],
dtype
=
torch
.
uint8
,
device
=
self
.
device
),
}
offset
=
{
torch
.
float32
:
0
,
torch
.
float16
:
0
,
torch
.
bfloat16
:
0
,
"float8"
:
0
}
def
_alloc
(
dtype
,
size
):
if
self
.
buffer_all_in_one
:
if
dtype
==
torch
.
uint8
:
dtype
=
"float8"
data
=
self
.
buffer
[
dtype
][
offset
[
dtype
]
:
offset
[
dtype
]
+
size
]
offset
[
dtype
]
+=
size
return
data
return
torch
.
empty
(
size
,
dtype
=
dtype
,
device
=
self
.
device
)
# Initialize the main grad buffer data of each parameter group.
for
group
in
self
.
parameter_groups
:
gbuf
=
group
.
main_grad_buffer
if
not
gbuf
:
continue
gbuf
.
data
=
_alloc
(
gbuf
.
dtype
,
gbuf
.
data_size
)
gbuf
.
data
.
zero_
()
for
item_id
,
p
in
enumerate
(
group
.
params
):
p
.
fsdp_managed_main_grad
=
gbuf
.
get_item
(
item_id
)
p
.
_gbuf
=
gbuf
p
.
_item_id
=
item_id
def
main_grad_getter
(
p
):
# Make sure main_grad memory storage ready.
bucket
=
p
.
_gbuf
.
fetch_bucket
()
gbuf
=
p
.
_gbuf
item_id
=
p
.
_item_id
return
gbuf
.
get_item_from_bucket
(
bucket
,
item_id
).
view
(
p
.
shape
)
setattr
(
p
.
__class__
,
'main_grad'
,
property
(
main_grad_getter
))
if
gbuf
.
is_data_distributed
:
gbuf
.
free_bucket_storage
()
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
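    # Summary (restating the `data_parallel_sharding_strategy` branch at the top of
    # `_init_each_parameter_group_buffers` above) of which buffers are sharded across
    # data-parallel ranks under each strategy:
    #
    #   strategy              model weights   main (fp32) weights   main grads
    #   'no_shard'            replicated      replicated            replicated
    #   'optim'               replicated      sharded               replicated
    #   'optim_grads'         replicated      sharded               sharded
    #   'optim_grads_params'  sharded         sharded               sharded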
def
_reset_parameters
(
self
,
old_params
,
new_params
):
assert
len
(
old_params
)
==
len
(
new_params
)
param_map
=
{}
for
old_param
,
new_param
in
zip
(
old_params
,
new_params
):
param_map
[
old_param
]
=
new_param
self
.
param_to_name
[
new_param
]
=
self
.
param_to_name
[
old_param
]
del
self
.
param_to_name
[
old_param
]
self
.
param_to_param_group
[
new_param
]
=
self
.
param_to_param_group
[
old_param
]
del
self
.
param_to_param_group
[
old_param
]
self
.
param_to_direct_module
[
new_param
]
=
self
.
param_to_direct_module
[
old_param
]
del
self
.
param_to_direct_module
[
old_param
]
for
item_id
,
p
in
enumerate
(
self
.
params
):
if
p
in
param_map
:
new_p
=
param_map
[
p
]
self
.
params
[
item_id
]
=
new_p
for
group
in
self
.
parameter_groups
:
for
item_id
,
p
in
enumerate
(
group
.
params
):
if
p
not
in
param_map
:
continue
new_p
=
param_map
[
p
]
group
.
params
[
item_id
]
=
new_p
for
buf
in
[
group
.
model_weight_buffer
,
group
.
main_weight_buffer
,
group
.
main_grad_buffer
,
]:
if
buf
is
None
:
continue
buf
.
param_idx
[
new_p
]
=
buf
.
param_idx
[
p
]
del
buf
.
param_idx
[
p
]
    def scale_gradients(self, scaling_factor: float) -> None:
        """Scale the gradient data by `scaling_factor`."""
        for group in self.parameter_groups:
            if group.main_grad_buffer is None:
                continue
            group.main_grad_buffer.data *= scaling_factor
        self.update_main_grads()

    def zero_grad(self):
        """
        Zero out the underlying grad_buffer and reset all buckets in preparation
        for the next iteration of training.
        """
        for _, param in self.optimizer_named_parameters:
            if param.grad is not None and param.grad._base is None:
                # For tensors that are not referenced, trying to use storage
                # resize to make memory free immediately.
                _free_storage(param.grad)
            param.grad = None
        for group in self.parameter_groups:
            if group.main_grad_buffer is None:
                continue
            group.main_grad_buffer.data.zero_()
def
_init_optimizer_named_parameters
(
self
)
->
List
[
Tuple
[
str
,
torch
.
nn
.
Parameter
]]:
named_parameters
=
[]
for
pg
in
self
.
parameter_groups
:
if
pg
.
main_grad_buffer
is
None
:
continue
optimizer_state_is_shard
=
pg
.
main_grad_buffer
.
is_data_distributed
or
(
pg
.
main_weight_buffer
and
pg
.
main_weight_buffer
.
is_data_distributed
)
for
item_id
,
orig_param
in
enumerate
(
pg
.
params
):
if
pg
.
main_weight_buffer
:
param
=
pg
.
main_weight_buffer
.
get_item
(
item_id
,
only_shard
=
optimizer_state_is_shard
)
elif
pg
.
model_weight_buffer
:
param
=
pg
.
model_weight_buffer
.
get_item
(
item_id
,
only_shard
=
optimizer_state_is_shard
)
else
:
param
=
orig_param
def
set_param_attribute_closure
(
param
,
orig_param
):
def
set_param_attribute
():
for
attr_name
in
[
'requires_grad'
,
'sequence_parallel'
,
'shared'
,
'tensor_model_parallel'
,
'partition_dim'
,
'partition_stride'
,
'is_embedding_or_output_parameter'
,
]:
if
hasattr
(
orig_param
,
attr_name
):
setattr
(
param
,
attr_name
,
getattr
(
orig_param
,
attr_name
))
return
set_param_attribute
setattr
(
param
,
'reset_attribute'
,
set_param_attribute_closure
(
param
,
orig_param
))
setattr
(
param
,
'orig_param'
,
orig_param
)
param
.
reset_attribute
()
named_parameters
.
append
((
self
.
param_to_name
[
orig_param
],
param
))
return
named_parameters
def
update_main_grads
(
self
):
"""Update the main gradients for preparing the optimizer step."""
for
_
,
param
in
self
.
optimizer_named_parameters
:
param
.
reset_attribute
()
orig_param
=
param
.
orig_param
group
=
self
.
parameter_groups
[
self
.
param_to_param_group
[
orig_param
]]
item_id
=
group
.
main_grad_buffer
.
param_idx
[
orig_param
]
optimizer_grad
=
group
.
main_grad_buffer
.
get_item
(
item_id
,
only_shard
=
group
.
main_weight_buffer
.
is_data_distributed
)
setattr
(
param
,
'grad'
,
optimizer_grad
.
to
(
param
.
dtype
)
if
optimizer_grad
.
numel
()
>
0
else
None
,
)
@
property
def
num_buckets
(
self
):
"""Return the number of buckets."""
return
len
(
self
.
parameter_groups
)
@
torch
.
no_grad
()
def
copy_main_weights_to_model_weights
(
self
):
"""Update the model weights from the main weights."""
for
pg
in
self
.
parameter_groups
:
mbuf
=
pg
.
main_weight_buffer
wbuf
=
pg
.
model_weight_buffer
if
mbuf
is
None
:
continue
fp8_params
=
[]
shard_fp32_from_fp8
=
[]
shard_offsets_in_fp8
=
[]
shard_model_params
=
[]
for
param
in
pg
.
params
:
item_id
=
mbuf
.
param_idx
[
param
]
if
wbuf
:
if
wbuf
.
is_data_distributed
or
mbuf
.
is_data_distributed
:
model_param
=
wbuf
.
get_item
(
item_id
,
only_shard
=
True
)
main_weight
=
mbuf
.
get_item
(
item_id
,
only_shard
=
True
)
else
:
model_param
=
wbuf
.
get_item
(
item_id
)
main_weight
=
mbuf
.
get_item
(
item_id
)
else
:
assert
not
mbuf
.
is_data_distributed
model_param
=
param
main_weight
=
pg
.
main_weight_buffer
.
get_item
(
item_id
)
if
is_float8tensor
(
param
):
fp8_params
.
append
(
param
)
if
model_param
.
numel
()
==
0
:
shard_fp32_from_fp8
.
append
(
None
)
shard_offsets_in_fp8
.
append
(
None
)
shard_model_params
.
append
(
None
)
else
:
shard_fp32_from_fp8
.
append
(
main_weight
)
shard_offsets_in_fp8
.
append
(
wbuf
.
locate_item_in_global_item
(
item_id
)[
0
])
shard_model_params
.
append
(
model_param
)
continue
if
model_param
.
numel
()
>
0
:
model_param
.
data
.
copy_
(
main_weight
.
view
(
model_param
.
shape
))
quantize_param_shard
(
fp8_params
,
shard_fp32_from_fp8
,
shard_offsets_in_fp8
,
wbuf
.
data_parallel_group
,
shard_model_params
,
)
@
torch
.
no_grad
()
def
copy_model_weights_to_main_weights
(
self
):
"""Copy the model weights to the main weights."""
for
group
in
self
.
parameter_groups
:
mbuf
=
group
.
main_weight_buffer
if
mbuf
is
None
:
continue
wbuf
=
group
.
model_weight_buffer
if
mbuf
.
is_data_distributed
:
copyin_data
=
wbuf
.
get_shard_from_local_buffer
()
else
:
copyin_data
=
wbuf
.
data
assert
mbuf
.
data
.
numel
()
==
copyin_data
.
numel
(),
(
f
"Master weight buffer size
{
mbuf
.
data
.
numel
()
}
does not match "
f
"model weight buffer size
{
copyin_data
.
numel
()
}
"
)
mbuf
.
data
.
copy_
(
copyin_data
.
data
)
def
all_gather_parameters
(
self
,
async_op
:
bool
=
True
):
"""All gather the parameters.
Args:
async_op (bool, optional): Whether to do the all-reduce
asynchronously. Defaults to False.
"""
assert
all
(
[
not
g
.
model_weight_buffer
.
is_data_distributed
for
g
in
self
.
parameter_groups
]
),
'all_gather_parameters() should only be called when parameters are not sharded.'
all_gather_ops
=
[]
for
g
in
self
.
parameter_groups
:
shard
=
g
.
model_weight_buffer
.
get_shard_from_local_buffer
()
all_gather_handler
=
torch
.
distributed
.
all_gather_into_tensor
(
output_tensor
=
g
.
model_weight_buffer
.
data
,
input_tensor
=
shard
,
group
=
g
.
model_weight_buffer
.
data_parallel_group
,
async_op
=
async_op
,
)
if
async_op
:
all_gather_ops
.
append
(
all_gather_handler
)
for
op
in
all_gather_ops
:
op
.
wait
()
def
reduce_scatter_gradients
(
self
,
async_op
:
bool
=
True
):
"""Reduce scatter the gradients.
Args:
async_op (bool, optional): Whether to do the all-reduce
asynchronously. Defaults to False.
"""
assert
all
(
[
not
g
.
main_grad_buffer
.
is_data_distributed
for
g
in
self
.
parameter_groups
]
),
'reduce_scatter_gradients() should only be called when gradients are not sharded.'
reduce_scatter_ops
=
[]
for
g
in
self
.
parameter_groups
:
gbuf
=
g
.
main_grad_buffer
if
gbuf
is
not
None
:
continue
scaling_factor
=
gbuf
.
gradient_scaling_factor
reduce_op
=
gradient_reduce_preprocessing
(
gbuf
.
data
,
scaling_factor
,
self
.
ddp_config
)
reduce_scatter_handler
=
torch
.
distributed
.
reduce_scatter_tensor
(
output
=
gbuf
.
get_shard_from_local_buffer
(),
input
=
gbuf
.
data
,
op
=
reduce_op
,
group
=
g
.
main_grad_buffer
.
data_parallel_group
,
async_op
=
async_op
,
)
if
async_op
:
reduce_scatter_ops
.
append
(
reduce_scatter_handler
)
for
op
in
reduce_scatter_ops
:
op
.
wait
()
def
all_reduce_gradients
(
self
,
async_op
:
bool
=
False
):
"""All reduce the gradients.
Args:
async_op (bool, optional): Whether to do the all-reduce
asynchronously. Defaults to False.
"""
assert
all
(
[
not
g
.
main_grad_buffer
.
is_data_distributed
for
g
in
self
.
parameter_groups
if
g
.
main_grad_buffer
]
),
'all_reduce_gradients() should only be called when gradients are not sharded.'
all_reduce_ops
=
[]
for
g
in
self
.
parameter_groups
:
gbuf
=
g
.
main_grad_buffer
if
gbuf
is
not
None
:
continue
scaling_factor
=
gbuf
.
gradient_scaling_factor
reduce_op
=
gradient_reduce_preprocessing
(
gbuf
.
data
,
scaling_factor
,
self
.
ddp_config
)
all_reduce_handler
=
torch
.
distributed
.
all_reduce
(
gbuf
.
data
,
op
=
reduce_op
,
group
=
gbuf
.
data_parallel_group
,
async_op
=
async_op
)
if
async_op
:
all_reduce_ops
.
append
(
all_reduce_handler
)
for
op
in
all_reduce_ops
:
op
.
wait
()
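

# Illustrative usage sketch (assumed, not part of the original module): a minimal
# training-step flow over a `ParamAndGradBuffer` when neither weights nor gradients
# are sharded (the precondition asserted by `all_gather_parameters` and
# `all_reduce_gradients`). Real integration goes through the FSDP /
# DistributedDataParallel wrappers rather than calling these helpers directly; the
# names `model`, `loss_fn` and `batch` below are placeholders.
def _sketch_training_step(buffer, model, loss_fn, batch):
    """Hypothetical helper showing where the buffer's collectives would be called."""
    buffer.zero_grad()                            # clear main_grad buffers
    loss = loss_fn(model(batch))                  # forward
    loss.backward()                               # backward fills param.main_grad views
    buffer.all_reduce_gradients()                 # or reduce_scatter_gradients() when sharded
    buffer.update_main_grads()                    # expose reduced grads to optimizer params
    # ... optimizer.step() over buffer.optimizer_named_parameters ...
    buffer.copy_main_weights_to_model_weights()   # push updated fp32 weights back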

class BucketStatus(Enum):
    """
    An enumeration of possible statuses for a data-parallel communication bucket.

    Attributes:
        EMPTY (int): The bucket is empty and not in use.
        COMMUNICATING (int): The bucket is currently being used for communication.
        READY_TO_USE (int): The bucket is filled with data and ready for use.
    """

    EMPTY = 1
    COMMUNICATING = 2
    READY_TO_USE = 3
class
GradReducePipeline
:
"""
Pipeline for reducing gradients.
"""
def
__init__
(
self
,
param_and_grad_buffer
:
ParamAndGradBuffer
,
cuda_stream
:
Optional
[
torch
.
cuda
.
Stream
]
=
None
,
check_nans
:
bool
=
False
,
)
->
None
:
self
.
buffer
=
param_and_grad_buffer
self
.
grad_reduce_queue
=
[]
self
.
bucket_status
=
{
i
:
BucketStatus
.
EMPTY
for
i
in
range
(
self
.
buffer
.
num_buckets
)
if
self
.
buffer
.
parameter_groups
[
i
].
main_grad_buffer
}
self
.
bucket_grad_ready_params
=
[
set
()
for
_
in
range
(
self
.
buffer
.
num_buckets
)]
self
.
cuda_stream
=
cuda_stream
self
.
check_nans
=
check_nans
@
property
def
num_buckets
(
self
):
"""Return the number of buckets."""
return
self
.
buffer
.
num_buckets
def
reset
(
self
):
"""Handle the processing tasks and reset the pipeline."""
self
.
wait_for_previous_grad_reduce
(
0
)
for
bucket_id
,
grad_ready_params
in
enumerate
(
self
.
bucket_grad_ready_params
):
param_list
=
self
.
buffer
.
parameter_groups
[
bucket_id
].
params
n_params
=
len
(
param_list
)
param_to_name
=
self
.
buffer
.
param_to_name
assert
len
(
grad_ready_params
)
==
0
,
(
f
"Found
{
len
(
grad_ready_params
)
}
out of
{
n_params
}
parameters that are ready for "
f
"reduce-scatter/all-reduce, but the pipeline is being reset. "
f
"grad_ready_params:
{
[
param_to_name
[
p
]
for
p
in
grad_ready_params
]
}
"
f
"param_list:
{
[
param_to_name
[
p
]
for
p
in
param_list
]
}
"
)
for
bucket_id
,
_
in
self
.
bucket_status
.
items
():
gbuf
=
self
.
buffer
.
parameter_groups
[
bucket_id
].
main_grad_buffer
gbuf
.
free_bucket_storage
()
self
.
bucket_status
[
bucket_id
]
=
BucketStatus
.
EMPTY
def
reduce_gradients
(
self
,
params
:
List
[
torch
.
Tensor
],
suggested_queue_capacity
:
Optional
[
int
]
=
None
):
"""Reduce the gradients for the given parameters.
Args:
params (List[torch.Tensor]): The parameters.
suggested_queue_capacity (int, optional): The suggested queue capacity.
Defaults to None.
"""
for
param
in
params
:
bucket_id
=
self
.
buffer
.
param_to_param_group
[
param
]
param_group
=
self
.
buffer
.
parameter_groups
[
bucket_id
]
if
not
param
.
requires_grad
:
assert
param_group
.
requires_grad
is
False
,
(
f
"Param
{
self
.
buffer
.
param_to_name
[
param
]
}
has requires_grad=False, "
f
"but it is in a parameter group with requires_grad=True."
)
continue
assert
param_group
.
requires_grad
,
(
f
"Param
{
self
.
buffer
.
param_to_name
[
param
]
}
has requires_grad=True, "
f
"but it is in a parameter group with requires_grad=False."
)
# Mark grad as ready for reduce-scatter/all-reduce.
self
.
bucket_grad_ready_params
[
bucket_id
].
add
(
param
)
if
len
(
self
.
bucket_grad_ready_params
[
bucket_id
])
==
len
(
param_group
.
params
):
self
.
wait_for_previous_grad_reduce
(
suggested_queue_capacity
=
suggested_queue_capacity
)
self
.
mark_bucket_ready
(
bucket_id
,
async_rs
=
True
)
def
wait_for_previous_grad_reduce
(
self
,
suggested_queue_size
:
int
=
1
,
suggested_queue_capacity
:
Optional
[
int
]
=
None
):
"""
Wait for the previous reduce-scatter/all-reduce to finish.
Args:
suggested_queue_size (int, optional): The recommended queue size. Defaults to 1.
suggested_queue_capacity (Optional[int], optional): The recommended queue capacity.
Defaults to None.
"""
if
suggested_queue_capacity
is
not
None
:
queue_space
=
sum
(
[
self
.
buffer
.
parameter_groups
[
bucket_id
].
main_grad_buffer
.
bucket_index
.
size
for
_
,
_
,
bucket_id
in
self
.
grad_reduce_queue
]
)
while
queue_space
>
suggested_queue_capacity
:
grad_reduce_event
,
free_up_grad_bucket
,
bucket_id
=
self
.
grad_reduce_queue
.
pop
(
0
)
grad_reduce_event
.
wait
()
free_up_grad_bucket
()
queue_space
-=
self
.
buffer
.
parameter_groups
[
bucket_id
].
main_grad_buffer
.
bucket_index
.
size
else
:
suggested_queue_size
=
max
(
0
,
min
(
suggested_queue_size
,
self
.
buffer
.
num_buckets
-
1
))
while
len
(
self
.
grad_reduce_queue
)
>
suggested_queue_size
:
grad_reduce_event
,
free_up_grad_bucket
,
_
=
self
.
grad_reduce_queue
.
pop
(
0
)
grad_reduce_event
.
wait
()
free_up_grad_bucket
()
def
mark_bucket_ready
(
self
,
bucket_id
:
int
,
async_rs
:
bool
=
False
)
->
bool
:
"""Mark the bucket ready for reduce-scatter/all-reduce, if all bucket in
the bucket group are ready, then do the reduce-scatter/all-reduce.
Args:
bucket_id (int): The bucket to be marked.
async_rs (bool, optional): Whether to do the reduce-scatter/all-reduce
asynchronously. Defaults to False.
Returns:
bool: True if the bucket is go for reduce-scatter/all-reduce.
"""
# Prepare the bucket group for gradient reduce. Note that the
# some bucket parameters do not require grad, so we need to
# remove them from the bucket group.
bucket_group
=
self
.
buffer
.
bucket_group_of_bucket
[
bucket_id
]
bucket_group
=
[
i
for
i
in
bucket_group
if
self
.
buffer
.
parameter_groups
[
i
].
main_grad_buffer
]
# If any bucket in the bucket group is not ready, skip the gradient reduce
# waiting for the bucket group to be all ready before executing.
for
bucket_id
in
bucket_group
:
param_group
=
self
.
buffer
.
parameter_groups
[
bucket_id
]
if
len
(
self
.
bucket_grad_ready_params
[
bucket_id
])
!=
len
(
param_group
.
params
):
return
False
current_stream
=
torch
.
cuda
.
current_stream
()
reduce_scatter_stream
=
(
self
.
cuda_stream
if
self
.
cuda_stream
is
not
None
else
torch
.
cuda
.
current_stream
()
)
reduce_scatter_stream
.
wait_stream
(
current_stream
)
dp_group
=
self
.
buffer
.
parameter_groups
[
bucket_id
].
main_grad_buffer
.
data_parallel_group
with
torch
.
cuda
.
stream
(
reduce_scatter_stream
):
with
_coalescing_manager
(
dp_group
,
async_ops
=
async_rs
)
as
coalescing_event
:
grad_shards
=
{}
for
bucket_id
in
bucket_group
:
gbuf
=
self
.
buffer
.
parameter_groups
[
bucket_id
].
main_grad_buffer
bucket
=
gbuf
.
fetch_bucket
()
scaling_factor
=
gbuf
.
gradient_scaling_factor
reduce_op
=
gradient_reduce_preprocessing
(
gbuf
.
data
,
scaling_factor
,
gbuf
.
ddp_config
)
if
gbuf
.
ddp_config
.
data_parallel_sharding_strategy
==
'no_shard'
:
torch
.
distributed
.
all_reduce
(
bucket
.
data
,
op
=
reduce_op
,
group
=
gbuf
.
data_parallel_group
)
else
:
grad_shard
=
gbuf
.
get_shard_from_bucket
(
bucket
)
# pylint: disable=C0301
# The `grad_shard`` is part of `bucket.data`` and the following
# new empty is important for memory safety, when using
# TORCH_NCCL_AVOID_RECORD_STREAMS=1.
# For reference: https://dev-discuss.pytorch.org/t/fsdp-cudacachingallocator-an-outsider-newb-perspective/1486
grad_shard
=
torch
.
empty_like
(
grad_shard
)
torch
.
distributed
.
reduce_scatter_tensor
(
output
=
grad_shard
,
input
=
bucket
.
data
,
op
=
reduce_op
,
group
=
gbuf
.
data_parallel_group
,
)
grad_shards
[
bucket_id
]
=
grad_shard
self
.
bucket_status
[
bucket_id
]
=
BucketStatus
.
COMMUNICATING
coalescing_event
.
wait
()
for
bucket_id
in
bucket_group
:
# Local gradient accumulate
gbuf
=
self
.
buffer
.
parameter_groups
[
bucket_id
].
main_grad_buffer
if
gbuf
.
ddp_config
.
data_parallel_sharding_strategy
!=
'no_shard'
:
# Gradient accumulate on local buffer
local_buffer
=
gbuf
.
get_shard_from_local_buffer
()
local_buffer
+=
grad_shards
[
bucket_id
]
reduce_scatter_view_out_event
=
reduce_scatter_stream
.
record_event
()
free_up_grad_bucket_func
=
{}
for
bucket_id
in
bucket_group
:
def
get_closure
(
bucket_id
):
def
free_up_grad_bucket
():
self
.
bucket_grad_ready_params
[
bucket_id
]
=
set
()
gbuf
=
self
.
buffer
.
parameter_groups
[
bucket_id
].
main_grad_buffer
if
gbuf
.
is_data_distributed
:
gbuf
.
free_bucket_storage
()
self
.
bucket_status
[
bucket_id
]
=
BucketStatus
.
EMPTY
return
free_up_grad_bucket
free_up_grad_bucket_func
[
bucket_id
]
=
get_closure
(
bucket_id
)
if
async_rs
:
for
bucket_id
,
free_up_grad_bucket
in
free_up_grad_bucket_func
.
items
():
self
.
grad_reduce_queue
.
append
(
(
reduce_scatter_view_out_event
,
free_up_grad_bucket
,
bucket_id
)
)
return
True
reduce_scatter_view_out_event
.
wait
()
for
free_up_grad_bucket
in
free_up_grad_bucket_func
.
values
():
free_up_grad_bucket
()
return
True
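

# Minimal sketch (illustrative only) of the core collective in
# `GradReducePipeline.mark_bucket_ready` above: the full flat gradient bucket is
# reduce-scattered so each rank receives its shard of the reduced gradients, and the
# result is accumulated into the persistent local shard buffer. Coalescing across
# buckets, CUDA streams, FP8 handling and the 'no_shard' all-reduce path are omitted;
# `torch` is the module-level import already used throughout this file.
def _sketch_reduce_scatter_and_accumulate(bucket_data, local_shard, reduce_op, dp_group):
    """Reduce-scatter `bucket_data` across `dp_group` and accumulate into `local_shard`."""
    # Fresh output tensor (mirrors the `torch.empty_like` in the original code, which
    # avoids aliasing the bucket storage when TORCH_NCCL_AVOID_RECORD_STREAMS=1).
    grad_shard = torch.empty_like(local_shard)
    torch.distributed.reduce_scatter_tensor(
        output=grad_shard, input=bucket_data, op=reduce_op, group=dp_group
    )
    local_shard += grad_shard  # local gradient accumulation across micro-batches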

class PrefetchOrder(Enum):
    """
    An enumeration of possible prefetch orders for data-parallel operations.

    Attributes:
        FORWARD_PASS_ORDER (int): Prefetch in the order of forward pass computation.
        BACKWARD_PASS_ORDER (int): Prefetch in the order of backward pass computation.
    """

    FORWARD_PASS_ORDER = 0
    BACKWARD_PASS_ORDER = 1
class
AllGatherPipeline
:
"""
Pipeline for all-gathering parameters.
"""
def
__init__
(
self
,
param_and_grad_buffer
:
ParamAndGradBuffer
)
->
None
:
self
.
buffer
=
param_and_grad_buffer
self
.
param_gather_event_map
=
{}
self
.
bucket_status
=
{
i
:
BucketStatus
.
EMPTY
for
i
in
range
(
self
.
buffer
.
num_buckets
)}
self
.
bucket_can_be_released
=
{
i
:
False
for
i
in
range
(
self
.
buffer
.
num_buckets
)}
self
.
bucket_to_bucket_group
=
{}
group_id
=
0
for
bucket_group
in
self
.
buffer
.
bucket_group_of_bucket
.
values
():
new_group
=
False
for
bucket_id
in
bucket_group
:
if
bucket_id
not
in
self
.
bucket_to_bucket_group
:
new_group
=
True
break
if
new_group
:
group_id
+=
1
for
bucket_id
in
bucket_group
:
self
.
bucket_to_bucket_group
[
bucket_id
]
=
group_id
@
property
def
num_buckets
(
self
):
"""Return the number of buckets."""
return
self
.
buffer
.
num_buckets
def
reset
(
self
):
"""Reset the pipeline state."""
if
len
(
self
.
param_gather_event_map
)
>
0
:
warnings
.
warn
(
"There are still pending all-gather tasks, process them. "
f
"Bucket status:
{
self
.
bucket_status
}
."
,
UserWarning
,
)
while
len
(
self
.
param_gather_event_map
)
>
0
:
bucket_id
=
next
(
iter
(
self
.
param_gather_event_map
))
self
.
wait_bucket_ready
(
bucket_id
)
for
bucket_id
in
self
.
bucket_can_be_released
:
self
.
bucket_can_be_released
[
bucket_id
]
=
True
self
.
recycle_unused_buckets
()
assert
all
([
status
is
BucketStatus
.
EMPTY
for
status
in
self
.
bucket_status
.
values
()]),
(
f
"There are still working buckets, it is not safe to reset. "
f
"bucket_status:
{
self
.
bucket_status
}
."
)
assert
all
(
[
not
can_be_released
for
can_be_released
in
self
.
bucket_can_be_released
.
values
()]
),
(
f
"The bucket can be released table is in an abnormal state, not safe to reset. "
f
"bucket_can_be_released:
{
self
.
bucket_can_be_released
}
."
)
def
all_gather_params
(
self
,
params
:
List
[
torch
.
Tensor
],
prefetch
:
bool
=
False
,
prefetch_order
:
PrefetchOrder
=
PrefetchOrder
.
FORWARD_PASS_ORDER
,
suggested_AG_prefetch_size
:
Optional
[
int
]
=
None
,
):
"""All-gather the params. If prefetch is enabled, prefetch next buckets
in the order of `prefetch_order`.
Args:
params (List[torch.Tensor]): The list of params to be all-gathered.
prefetch (bool, optional): Whether to prefetch the next bucket. Defaults to False.
prefetch_order (PrefetchOrder, optional): The order of prefetching.
Defaults to PrefetchOrder.FORWARD_PASS_ORDER.
suggested_AG_prefetch_size (Optional[int], optional):
The suggested prefetch size for all-gathering. Defaults to None.
"""
if
len
(
params
)
==
0
:
return
ag_buckets
=
[
self
.
buffer
.
param_to_param_group
[
item
]
for
item
in
params
]
ag_buckets
=
list
(
sorted
(
set
(
ag_buckets
)))
parameter_groups
=
self
.
buffer
.
parameter_groups
# If prefetch is enabled, we will add prefetch buckets to ag_buckets.
if
prefetch
:
def
next_bucket_id
(
ag_buckets
):
if
prefetch_order
==
PrefetchOrder
.
FORWARD_PASS_ORDER
:
bucket_id
=
ag_buckets
[
0
]
+
1
for
i
in
ag_buckets
[
1
:]:
if
i
!=
bucket_id
:
break
bucket_id
+=
1
else
:
bucket_id
=
ag_buckets
[
-
1
]
-
1
for
i
in
reversed
(
ag_buckets
[:
-
1
]):
if
i
!=
bucket_id
:
break
bucket_id
-=
1
if
bucket_id
<
0
or
bucket_id
>=
self
.
buffer
.
num_buckets
:
return
None
return
bucket_id
if
suggested_AG_prefetch_size
is
not
None
:
bucket_id
=
next_bucket_id
(
ag_buckets
)
while
bucket_id
is
not
None
:
all_gather_size
=
sum
(
[
parameter_groups
[
i
].
model_weight_buffer
.
bucket_index
.
size
for
i
in
ag_buckets
]
)
if
all_gather_size
>=
suggested_AG_prefetch_size
:
break
ag_buckets
.
extend
(
self
.
buffer
.
bucket_group_of_bucket
[
bucket_id
])
ag_buckets
=
list
(
sorted
(
set
(
ag_buckets
)))
bucket_id
=
next_bucket_id
(
ag_buckets
)
else
:
bucket_id
=
next_bucket_id
(
ag_buckets
)
if
bucket_id
is
not
None
:
ag_buckets
.
extend
(
self
.
buffer
.
bucket_group_of_bucket
[
bucket_id
])
ag_buckets
=
list
(
sorted
(
set
(
ag_buckets
)))
ag_buckets
=
[
i
for
i
in
ag_buckets
if
self
.
bucket_status
[
i
]
==
BucketStatus
.
EMPTY
]
if
len
(
ag_buckets
)
==
0
:
return
# Divide buckets into aggregate groups
bucket_group_to_buckets
=
{}
for
bucket_id
in
ag_buckets
:
group_id
=
self
.
bucket_to_bucket_group
[
bucket_id
]
if
group_id
not
in
bucket_group_to_buckets
:
bucket_group_to_buckets
[
group_id
]
=
[]
bucket_group_to_buckets
[
group_id
].
append
(
bucket_id
)
# Coalesce all-gather operations for all buckets in the same data-parallel-group
for
_
,
buckets
in
bucket_group_to_buckets
.
items
():
param_group
=
parameter_groups
[
buckets
[
0
]]
dp_group
=
param_group
.
model_weight_buffer
.
data_parallel_group
with
_coalescing_manager
(
dp_group
,
async_ops
=
True
)
as
coalescing_event
:
for
bucket_id
in
buckets
:
self
.
all_gather_bucket_and_set_items
(
bucket_id
,
async_op
=
True
)
# reset param gather event with coalescing event
for
bucket_id
in
buckets
:
_
,
mark_bucket_ready_to_use
=
self
.
param_gather_event_map
[
bucket_id
]
self
.
param_gather_event_map
[
bucket_id
]
=
(
coalescing_event
,
mark_bucket_ready_to_use
,
)
def
wait_bucket_ready
(
self
,
bucket_id
,
empty_ok
=
False
):
"""Wait for the bucket to be ready."""
if
self
.
bucket_status
[
bucket_id
]
==
BucketStatus
.
READY_TO_USE
:
return
if
self
.
bucket_status
[
bucket_id
]
==
BucketStatus
.
EMPTY
:
if
empty_ok
:
return
raise
ValueError
(
f
"Bucket
{
bucket_id
}
is empty."
)
param_gather_event
,
mark_bucket_ready_to_use
=
self
.
param_gather_event_map
.
pop
(
bucket_id
)
param_gather_event
.
wait
()
mark_bucket_ready_to_use
()
@
torch
.
no_grad
()
def
release_bucket
(
self
,
bucket_id
:
int
):
"""Release the bucket."""
if
self
.
bucket_status
[
bucket_id
]
==
BucketStatus
.
EMPTY
:
return
if
self
.
bucket_status
[
bucket_id
]
==
BucketStatus
.
COMMUNICATING
:
raise
ValueError
(
f
"Bucket
{
bucket_id
}
is communicating."
)
wbuf
=
self
.
buffer
.
parameter_groups
[
bucket_id
].
model_weight_buffer
wbuf
.
free_bucket_storage
()
self
.
bucket_status
[
bucket_id
]
=
BucketStatus
.
EMPTY
def
recycle_unused_buckets
(
self
):
"""Recycle the unused buckets."""
for
bucket_id
,
can_be_released
in
self
.
bucket_can_be_released
.
items
():
if
can_be_released
:
self
.
release_bucket
(
bucket_id
)
self
.
bucket_can_be_released
[
bucket_id
]
=
False
@
torch
.
no_grad
()
def
all_gather_bucket_and_set_items
(
self
,
bucket_id
:
int
,
async_op
:
bool
=
False
)
->
None
:
"""All-gather the bucket and set the items."""
self
.
bucket_can_be_released
[
bucket_id
]
=
False
if
self
.
bucket_status
[
bucket_id
]
!=
BucketStatus
.
EMPTY
:
return
self
.
bucket_status
[
bucket_id
]
=
BucketStatus
.
COMMUNICATING
wbuf
=
self
.
buffer
.
parameter_groups
[
bucket_id
].
model_weight_buffer
# Lazy release the unused buckets.
self
.
recycle_unused_buckets
()
bucket
=
wbuf
.
fetch_bucket
(
and_allocate_params_data
=
True
)
param_gather_event
=
torch
.
distributed
.
all_gather_into_tensor
(
output_tensor
=
bucket
.
data
,
input_tensor
=
wbuf
.
get_shard_from_local_buffer
(),
group
=
wbuf
.
data_parallel_group
,
async_op
=
async_op
,
)
def
get_closure
(
bucket_id
):
@
torch
.
no_grad
()
def
mark_bucket_ready_to_use
():
self
.
bucket_status
[
bucket_id
]
=
BucketStatus
.
READY_TO_USE
return
mark_bucket_ready_to_use
mark_bucket_ready_to_use
=
get_closure
(
bucket_id
)
if
async_op
:
self
.
param_gather_event_map
[
bucket_id
]
=
(
param_gather_event
,
mark_bucket_ready_to_use
)
return
mark_bucket_ready_to_use
()

@torch.no_grad()
def gradient_reduce_preprocessing(grad_data, scaling_factor, ddp_config):
    """
    Gradient reduce preprocessing for gradient averaging and gradient scaling.
    """
    if scaling_factor is None:
        reduce_op = torch.distributed.ReduceOp.SUM
    elif ddp_config.average_in_collective:
        reduce_op = torch.distributed.ReduceOp.AVG
    elif ddp_config.gradient_reduce_div_fusion and grad_data.dtype != torch.bfloat16:
        reduce_op = torch.distributed._make_nccl_premul_sum(scaling_factor)
    else:
        grad_data.mul_(scaling_factor)
        reduce_op = torch.distributed.ReduceOp.SUM

    return reduce_op
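

# Illustrative note (not part of the original module): the scaled-reduce paths above
# are different ways of computing sum_i(g_i) * scaling_factor; the AVG path matches
# them when scaling_factor == 1 / dp_world_size. The pre-scale-then-SUM fallback, for
# example, looks like this in isolation (`dp_group` is a placeholder process group):
def _sketch_prescale_then_sum(grad_data, scaling_factor, dp_group):
    """Fold the scaling factor into the local gradient, then SUM-reduce it."""
    grad_data.mul_(scaling_factor)  # local pre-scaling replaces the premul-sum reduce op
    torch.distributed.all_reduce(
        grad_data, op=torch.distributed.ReduceOp.SUM, group=dp_group
    )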


def check_gpu_memory(threshold=0.9):
    """
    Check if the GPU memory is over the threshold.

    Args:
        threshold (float, optional): The threshold to check if the GPU memory is over.
            Defaults to 0.9.

    Returns:
        bool: True if the GPU memory is over the threshold.
    """
    if not torch.cuda.is_available():
        return False

    device = torch.cuda.current_device()
    allocated = torch.cuda.memory_allocated(device)
    reserved = torch.cuda.memory_reserved(device)
    total = torch.cuda.get_device_properties(device).total_memory

    allocated_ratio = allocated / total
    reserved_ratio = reserved / total
    near_full = allocated_ratio >= threshold or reserved_ratio >= threshold
    if near_full:
        log_on_each_pipeline_stage(
            logger,
            logging.INFO,
            f"GPU Memory: Allocated: {allocated_ratio:.2%}, Reserved: {reserved_ratio:.2%}",
        )

    return near_full

class ResetParametersContext:
    """
    Context manager for resetting parameters for meta device initialization module.
    """

    def __init__(self, init_param_with_fp8=False, with_cuda_rng_tracker=False):
        self.init_param_with_fp8 = init_param_with_fp8
        self.with_cuda_rng_tracker = with_cuda_rng_tracker

    def __enter__(self):
        self.stack = ExitStack()
        if self.init_param_with_fp8:
            args = {"enabled": True}
            if "preserve_high_precision_init_val" in inspect.signature(fp8_model_init).parameters:
                args["preserve_high_precision_init_val"] = True
            self.stack.enter_context(fp8_model_init(**args))
        if self.with_cuda_rng_tracker:
            self.stack.enter_context(get_cuda_rng_tracker().fork())
        return self

    def __exit__(self, *exc_details):
        self.stack.__exit__(*exc_details)
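

# Illustrative usage sketch (mirrors the call site in
# `ParamAndGradBuffer._init_each_parameter_group_buffers` above): a module built on
# the meta device is first materialized, then re-initialized under this context so
# that FP8 parameter creation and the CUDA RNG tracker are set up consistently.
#
#   >>> module.to_empty(device=torch.device('cuda'), recurse=False)
#   >>> with ResetParametersContext(init_param_with_fp8=True, with_cuda_rng_tracker=True):
#   ...     module.reset_parameters()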
Megatron-LM/megatron/core/distributed/data_parallel_base.py
0 → 100644
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from contextlib import contextmanager

import torch

from ..transformer.module import MegatronModule
from ..transformer.transformer_config import TransformerConfig


class _BaseDataParallel(MegatronModule):
    """A template class for DistributedDataParallel implementations."""

    def __init__(self, config: TransformerConfig, module: torch.nn.Module):
        super().__init__(config=config)
        self.module = module

    def forward(self, *inputs, **kwargs):
        """
        Calls the wrapped module's forward() method.
        """
        return self.module(*inputs, **kwargs)

    @contextmanager
    def no_sync(self):
        """
        Context manager that turns off gradient synchronization.
        """
        try:
            yield
        finally:
            pass

    def start_grad_sync(self, *unused):
        """
        Initiates grad sync (all-reduce or reduce-scatter) communication operations
        for all model gradients.

        When overlap_grad_reduce is set to True, dispatches asynchronous communication
        calls. When overlap_grad_reduce is set to False, calls synchronous
        communication ops.
        """
        pass

    def scale_gradients(self, scaling_factor: float) -> None:
        """Scale all gradients inside the buffers by `scaling_factor`."""
        pass

    def finish_grad_sync(self):
        """
        Finishes grad sync (all-reduce or reduce-scatter) communication operations
        for all model gradients.

        When overlap_grad_reduce is set to True, waits for asynchronous communication
        calls to complete. When overlap_grad_reduce is set to False, calls synchronous
        communication ops.
        """
        pass

    def zero_grad_buffer(self):
        """
        Zeros out all grad buffers. Needs to be called at the beginning of each
        training iteration.
        """
        pass

    def broadcast_params(self):
        """
        Syncs parameters across all DP ranks.
        """
        pass

    def state_dict(self, prefix='', keep_vars=False, destination=None):
        """
        Returns a dictionary containing references to the whole state of the
        wrapped module.

        Both parameters and persistent buffers (e.g. running averages) are included.
        Keys are corresponding parameter and buffer names. Parameters and buffers
        set to None are not included.
        """
        return self.module.state_dict(prefix=prefix, keep_vars=keep_vars, destination=destination)

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """
        Returns wrapped module's state_dict for checkpoint saving.
        """
        return self.module.state_dict_for_save_checkpoint(prefix=prefix, keep_vars=keep_vars)

    def load_state_dict(self, state_dict, strict=True):
        """
        Copies parameters and buffers from state_dict into the wrapped module and its
        descendants. If strict is True, then the keys of state_dict must exactly match
        the keys returned by this module's state_dict() function.
        """
        self.module.load_state_dict(state_dict, strict=strict)
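As a minimal illustration of how this template is meant to be specialized (not part of the commit): a concrete wrapper overrides the no-op hooks while inheriting forward/state-dict delegation. The `_LoggingDataParallel` class below is hypothetical.

class _LoggingDataParallel(_BaseDataParallel):
    """Hypothetical subclass: inherits the pass-through behavior of the template."""

    def start_grad_sync(self, *unused):
        # A real implementation would launch all-reduce / reduce-scatter collectives here.
        print("start_grad_sync called")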
Megatron-LM/megatron/core/distributed/distributed_data_parallel.py
0 → 100644
View file @
4e867b3c
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

import logging
from contextlib import contextmanager

import torch

from .. import parallel_state
from ..config_logger import has_config_logger_enabled, log_config_to_disk
from ..fp8_utils import is_float8tensor
from ..transformer.cuda_graphs import is_graph_capturing
from ..transformer.transformer_config import TransformerConfig
from ..utils import log_single_rank
from .data_parallel_base import _BaseDataParallel
from .distributed_data_parallel_config import DistributedDataParallelConfig
from .param_and_grad_buffer import _ParamAndGradBuffer, partition_buckets

logger = logging.getLogger(__name__)


class DistributedDataParallel(_BaseDataParallel):
    """
    DDP wrapper which stores grads in contiguous buffers. Also has option of overlapping
    communication with backprop computation by breaking up full model's gradients into smaller
    buckets and running all-reduce / reduce-scatter on each bucket asynchronously. This class
    also provides the option to do the gradient accumulation in a type other than the param type
    (e.g., fp32 for a bf16 model).

    Args:
        config: Transformer config object.
        ddp_config: DistributedDataParallel config object.
        module: Underlying model.
        disable_bucketing: If true, force assign all parameters to a single bucket. If false,
            use standard bucketing policy: assign parameters to smaller buckets and all-reduce
            per bucket _if_ overlap_grad_reduce is True and pp_rank is 0.
    """

    def __init__(
        self,
        config: TransformerConfig,
        ddp_config: DistributedDataParallelConfig,
        module: torch.nn.Module,
        disable_bucketing: bool = False,
    ):
        super().__init__(config=config, module=module)
        if has_config_logger_enabled(config):
            log_config_to_disk(config, locals(), prefix=type(self).__name__)

        self.module = module

        # If bucket_size is not provided as an input, use sane default.
        # If using very large dp_sizes, make buckets larger to ensure that chunks used in NCCL
        # ring-reduce implementations are large enough to remain bandwidth-bound rather than
        # latency-bound.
        if ddp_config.bucket_size is None:
            ddp_config.bucket_size = max(
                40000000, 1000000 * parallel_state.get_data_parallel_world_size()
            )
        # Set bucket_size to infinity if overlap_grad_reduce is False.
        if not ddp_config.overlap_grad_reduce:
            ddp_config.bucket_size = None

        self.ddp_config = ddp_config
        log_single_rank(
            logger,
            logging.INFO,
            f'Setting up DistributedDataParallel with config {self.ddp_config}',
        )

        # Turn off bucketing if we are on a pipeline stage that is not the first (since
        # data-parallel communication on these stages is not on the critical path), or if
        # disable_bucketing is True (e.g., we might not want to break up model parameters
        # into buckets for model chunks after the first in the interleaved schedule).
        self.bucket_size = self.ddp_config.bucket_size
        if parallel_state.get_pipeline_model_parallel_rank() > 0:
            self.bucket_size = None
        if disable_bucketing:
            self.bucket_size = None

        self.param_to_bucket_group = {}

        # Group parameters by their gradient type.
        param_to_name = {}
        dense_params = []
        expert_parallel_params = []
        self.params_with_grad = []
        for name, param in self.module.named_parameters():
            if not param.requires_grad:
                continue
            # Track params with grad to enable direct setting
            # of param.grad_added_to_main_grad
            self.params_with_grad.append(param)
            param.grad_added_to_main_grad = False
            param_to_name[param] = name
            if getattr(param, 'allreduce', True):
                dense_params.append(param)
            else:
                expert_parallel_params.append(param)

        def _allocate_buffers_for_parameters(
            input_params, data_parallel_group, gradient_scaling_factor
        ):
            param_and_grad_dtype_to_params = {}
            param_and_grad_dtype_to_offsets = {}
            param_and_grad_dtype_to_indices = {}

            # Group parameters by their gradient type.
            for param in input_params:
                assert param.requires_grad

                param_dtype = param.dtype
                if is_float8tensor(param):
                    # Currently TE's Float8Tensor is a wrapper of torch.Tensor. It has a "fake"
                    # dtype (usually a higher precision dtype such as bfloat16), but its actual
                    # data is stored in the form of a torch uint8 tensor within the Float8Tensor's
                    # ".data" attribute. Therefore, when creating the param buffer for fp8 params,
                    # it is necessary to use torch.uint8, not the "fake" dtype got from
                    # "param.dtype".
                    param_dtype = torch.uint8
                grad_dtype = torch.float if self.ddp_config.grad_reduce_in_fp32 else param.dtype

                params = param_and_grad_dtype_to_params.get((param_dtype, grad_dtype), [])
                params.append(param)
                param_and_grad_dtype_to_params[(param_dtype, grad_dtype)] = params

                # Get the index of each param among the params with same dtype, if a param is fp8,
                # use its "fake" high precision dtype to find which params have same dtype with it.
                # For example:
                #     Case 1:
                #         params = [p1(bf16), p2(bf16), p3(bf16), p4(bf16)]
                #         param_and_grad_dtype_to_indices = {
                #             (torch.bfloat16, torch.float32): [0, 1, 2, 3],
                #         }
                #     Case 2:
                #         params = [p1(bf16), p2(fp8), p3(fp8), p4(bf16)]
                #         param_and_grad_dtype_to_indices = {
                #             (torch.bfloat16, torch.float32): [0, 3],
                #             (torch.uint8, torch.float32): [1, 2],
                #         }
                # We need these indices to load a non-native-fp8 checkpoint in native-fp8 mode.
                offset = param_and_grad_dtype_to_offsets.get((param.dtype, grad_dtype), 0)
                param_and_grad_dtype_to_offsets[(param.dtype, grad_dtype)] = offset + 1
                indices = param_and_grad_dtype_to_indices.get((param_dtype, grad_dtype), [])
                indices.append(offset)
                param_and_grad_dtype_to_indices[(param_dtype, grad_dtype)] = indices

            if not config.calculate_per_token_loss:
                target_gradient_scaling_factor = 1.0 / parallel_state.get_data_parallel_world_size(
                    with_context_parallel=True
                )
                if self.ddp_config.average_in_collective:
                    if self.ddp_config.num_distributed_optimizer_instances == 1:
                        # Collective is averaging gradients in collective with data_parallel_group.
                        assert (
                            gradient_scaling_factor
                            / torch.distributed.get_world_size(group=data_parallel_group)
                            == target_gradient_scaling_factor
                        )
                    else:
                        # For non-expert parameters, gradient_scaling_factor is 1.
                        # For expert parameters, gradient_scaling_factor is edp_size/dp_size.
                        assert (gradient_scaling_factor == 1) or (
                            gradient_scaling_factor
                            == (
                                parallel_state.get_expert_data_parallel_world_size()
                                / parallel_state.get_data_parallel_world_size(
                                    with_context_parallel=True
                                )
                            )
                        )
                else:
                    assert gradient_scaling_factor == target_gradient_scaling_factor

            # Allocate the grad buffers and map the grads.
            buffers = []
            for (param_dtype, grad_dtype), params in param_and_grad_dtype_to_params.items():
                buffers.append(
                    _ParamAndGradBuffer(
                        self.ddp_config,
                        param_dtype,
                        grad_dtype,
                        params,
                        data_parallel_group,
                        self.bucket_size,
                        param_to_name,
                        gradient_scaling_factor,
                        param_and_grad_dtype_to_indices[(param_dtype, grad_dtype)],
                    )
                )

            # In some scenarios, we want to put buckets from different buffers into a group so that
            # their communication can be aggregated. For example, when there are both fp8 buffers
            # and bf16 buffers in the model and vpp is enabled, each model chunk will have an fp8
            # bucket and a bf16 bucket, which doubles the number of communication kernels, and
            # because of the use of CUDA_DEVICE_MAX_CONNECTIONS=1, having multiple back-to-back
            # communications will prevent the overlap of the communication kernels with computation
            # kernels.
            # If bucketing is explicitly disabled, then put all buckets in a buffer into a single
            # bucket group.
            bucket_groups = partition_buckets(buffers, force_single_bucket_group=disable_bucketing)

            if self.ddp_config.num_distributed_optimizer_instances > 1:
                assert (
                    parallel_state.get_expert_model_parallel_world_size() == 1
                ), "Partial DistOpt cannot support MoE models with expert parallelism."
                assert (
                    self.ddp_config.use_distributed_optimizer
                ), 'Partial DistOpt cannot be used without DistOpt'

                communication_stream = torch.cuda.Stream(device=torch.cuda.current_device())
                for bucket_group in bucket_groups:
                    bucket_group.inter_distributed_optimizer_instance_group = (
                        parallel_state.get_inter_partial_data_parallel_group()
                    )
                    bucket_group.communication_stream = communication_stream

            # Set `next_param_gather_bucket_group` for different bucket groups by iterating through
            # buckets in reverse order (since all-gathers happen in reverse order of buckets).
            if self.ddp_config.use_distributed_optimizer and self.ddp_config.overlap_param_gather:
                num_bucket_groups = len(bucket_groups)
                for i in range(1, num_bucket_groups):
                    bucket_groups[num_bucket_groups - i].next_param_gather_bucket_group = (
                        bucket_groups[num_bucket_groups - i - 1]
                    )

            # Create map from param to bucket group, used in pre_hook.
            for bucket_group in bucket_groups:
                for bucket in bucket_group.buckets:
                    for param in bucket.params_list:
                        self.param_to_bucket_group[param] = bucket_group

            return buffers, bucket_groups

        if config.calculate_per_token_loss:
            assert (
                not self.ddp_config.average_in_collective
            ), "Cannot average in collective when calculating per-token loss!"
            gradient_scaling_factor = 1.0
            expert_gradient_scaling_factor = 1.0
        else:
            # The goal is to scale reduced gradients by 1/dp_size.
            # This can be achieved in two ways:
            #
            # Case 1: average_in_collective=True
            #   - Non-expert parameters:
            #     1. No pre-scaling (gradient_scaling_factor=1.0)
            #     2. Do average reduction over dp group (equals to sum then divide by dp_size)
            #     3. Final result is scaled by 1/dp_size as desired
            #
            #   - Expert parameters:
            #     1. Scale by edp_size/dp_size before reduction
            #     2. Do average reduction over edp group (equals to sum then divide by edp_size)
            #     3. Resulted scaling: (edp_size/dp_size) * (1/edp_size) = 1/dp_size as desired
            #        (edp_size = expert data parallel world size)
            #
            # Case 2: average_in_collective=False
            #   - Both expert and non-expert parameters:
            #     1. Scale gradients by 1/dp_size before reduction
            #     2. Do sum reduction across data parallel ranks
            #     3. Final result is scaled by 1/dp_size as desired
            if self.ddp_config.average_in_collective:
                gradient_scaling_factor = 1.0
                expert_gradient_scaling_factor = (
                    parallel_state.get_expert_data_parallel_world_size()
                    / parallel_state.get_data_parallel_world_size(with_context_parallel=True)
                )
            else:
                data_parallel_world_size = parallel_state.get_data_parallel_world_size(
                    with_context_parallel=True
                )
                gradient_scaling_factor = 1.0 / data_parallel_world_size
                expert_gradient_scaling_factor = 1.0 / data_parallel_world_size

        # Allocate the param+grad buffers for dense params' grads.
        self.buffers, self.bucket_groups = _allocate_buffers_for_parameters(
            dense_params,
            parallel_state.get_data_parallel_group(
                with_context_parallel=True, partial_data_parallel=True
            ),
            gradient_scaling_factor=gradient_scaling_factor,
        )

        # Allocate separate param+grad buffers for expert parallel params' grads.
        self.expert_parallel_buffers, self.expert_parallel_bucket_groups = (
            _allocate_buffers_for_parameters(
                expert_parallel_params,
                parallel_state.get_expert_data_parallel_group(),
                gradient_scaling_factor=expert_gradient_scaling_factor,
            )
        )

        # Delete references to weight_tensor if they exist since we don't want two parameter copies
        # if we re-mapped parameters (which happens when we use the distributed optimizer).
        # This is a temporary workaround around a TE bug that is fixed with
        # https://github.com/NVIDIA/TransformerEngine/pull/719.
        if self.ddp_config.use_distributed_optimizer:

            @torch.no_grad()
            def unmap_weight_tensor(m):
                if hasattr(m, 'weight_tensor'):
                    m.weight_tensor = None

            self.module.apply(unmap_weight_tensor)

        # Register backward hook.
        # Accumulation function for the gradients need to be stored so they
        # don't go out of scope.
        self.grad_accs = []
        for param in self.module.parameters():
            if param.requires_grad:
                # Expand so we get access to grad_fn.
                param_tmp = param.expand_as(param)
                # Get the gradient accumulator function.
                grad_acc = param_tmp.grad_fn.next_functions[0][0]
                grad_acc.register_hook(self._make_backward_post_hook(param))
                self.grad_accs.append(grad_acc)

        self.use_forward_hook = (
            self.ddp_config.use_distributed_optimizer and self.ddp_config.overlap_param_gather
        )
        self.remove_forward_pre_hook_handles = {}
        if self.use_forward_hook:
            self.enable_forward_pre_hook()
        self.overlap_param_gather_with_optimizer_step = False

    def enable_forward_pre_hook(self):
        """
        Enable forward pre-hooks needed for param all-gather overlap with forward compute.
        """
        assert self.use_forward_hook
        assert len(self.remove_forward_pre_hook_handles) == 0
        # Register forward pre-hook for all sub-modules.
        for module in self.module.modules():
            self.remove_forward_pre_hook_handles[module] = module.register_forward_pre_hook(
                self._make_forward_pre_hook()
            )

    def disable_forward_pre_hook(self, param_sync: bool = True):
        """
        Disable forward pre-hooks needed for param all-gather overlap with forward compute.
        Skip synchronous param all-gather if `param_sync` is False.
        """
        assert self.use_forward_hook
        # De-register forward pre-hook for all sub-modules.
        for module in self.module.modules():
            assert self.remove_forward_pre_hook_handles[module] is not None
            self.remove_forward_pre_hook_handles[module].remove()
            del self.remove_forward_pre_hook_handles[module]
        assert len(self.remove_forward_pre_hook_handles) == 0

        # Force synchronize parameters.
        if param_sync:
            self.start_param_sync(force_sync=True)

    def _make_forward_pre_hook(self):
        """
        Create a forward pre-hook to wait on all-gather handles when necessary (i.e.,
        when a module uses a parameter in a bucket with a still incomplete all-gather).
        """

        def hook(module, *unused):
            assert (
                self.use_forward_hook
            ), "Should use pre-hook only when overlap_param_gather is True"

            if is_graph_capturing():
                return

            # Make sure all parameters in this module have been all-gathered as necessary.
            for param in module.parameters(recurse=False):
                # Skip parameters without an associated buffer (such parameters have a
                # .requires_grad field equal to False).
                if param not in self.param_to_bucket_group:
                    continue
                assert param.requires_grad

                # If aligning param all-gather across pipeline stages, all-gather is dispatched
                # by start_param_sync calls in core/pipeline_parallelism/schedules.py.
                # If overlapping param all-gather with optimizer step, then all-gather has
                # already been dispatched in optimizer step.
                skip_next_bucket_dispatch = (
                    self.ddp_config.align_param_gather
                    or self.overlap_param_gather_with_optimizer_step
                )
                self.param_to_bucket_group[param].finish_param_sync(
                    skip_next_bucket_dispatch=skip_next_bucket_dispatch
                )

        return hook

    def _make_backward_post_hook(self, param: torch.nn.Parameter):
        """
        Creates a backward post-hook to dispatch an all-reduce / reduce-scatter when
        ready (i.e., when all grads in a bucket have been computed in all microbatches
        in a batch).
        """

        def hook(*unused):
            if is_graph_capturing():
                return

            if param in self.param_to_bucket_group:
                assert param.requires_grad
                if self.ddp_config.overlap_grad_reduce:
                    assert (
                        param.grad is not None
                    ), 'param.grad being None is not safe when overlap_grad_reduce is True'
                if param.grad is not None and (
                    not param.grad_added_to_main_grad or getattr(param, 'zero_out_wgrad', False)
                ):
                    param.main_grad.add_(param.grad.data)
                param.grad = None

                if self.ddp_config.overlap_grad_reduce:
                    self.param_to_bucket_group[param].register_grad_ready(param)

        return hook

    @contextmanager
    def no_sync(self):
        """
        Context manager that turns off gradient synchronization.
        """
        for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
            bucket_group.is_last_microbatch = False
        try:
            yield
        finally:
            for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
                bucket_group.is_last_microbatch = True

    def start_param_sync(self, *unused, force_sync: bool = False, force_dispatch: bool = False):
        """
        Initiates param sync (all-gather) communication operations for all model parameters.

        By default, when overlap_param_gather is set to True, dispatches asynchronous communication
        calls; when overlap_param_gather is set to False, calls synchronous communication
        ops. Can override this default behavior using flags below.

        Args:
            force_sync (bool, optional): force synchronous collective regardless of
                other settings.
            force_dispatch (bool, optional): force dispatch regardless of other settings.
        """
        if not force_sync:
            # If overlapping param AG with optimizer step, AG should not be dispatched again
            # in forward_backward_step.
            if self.overlap_param_gather_with_optimizer_step and not force_dispatch:
                return

        for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
            bucket_group.start_param_sync(force_sync=force_sync)

    def start_grad_sync(self, *unused):
        """
        Initiates grad sync (all-reduce or reduce-scatter) communication operations
        for all model gradients.

        When overlap_grad_reduce is set to True, dispatches asynchronous communication
        calls. When overlap_grad_reduce is set to False, calls synchronous
        communication ops.
        """
        for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
            bucket_group.start_grad_sync()

    def finish_grad_sync(self):
        """
        Finishes grad sync (all-reduce or reduce-scatter) communication operations
        for all model gradients.

        When overlap_grad_reduce is set to True, waits for asynchronous communication
        calls to complete. When overlap_grad_reduce is set to False, calls synchronous
        communication ops.
        """
        for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
            bucket_group.finish_grad_sync()

    def scale_gradients(self, scaling_factor: float):
        """Scale all gradients inside the buffers by `scaling_factor`."""
        for buffer in self.buffers + self.expert_parallel_buffers:
            buffer.scale_gradients(scaling_factor)

    def zero_grad_buffer(self):
        """
        Zeros out all grad buffers. Needs to be called at the beginning of each
        training iteration.
        """
        if not getattr(self.config, 'external_cuda_graph', False):
            # Don't reset grad_added_to_main_grad when CUDA Graph is used.
            # Because in CUDA Graph it no longer has the opportunity to set it back
            # to True, and there will be a double-GA.
            for param in self.params_with_grad:
                param.grad_added_to_main_grad = False
        for buffer in self.buffers + self.expert_parallel_buffers:
            buffer.reset()
        for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
            bucket_group.reset()

    def broadcast_params(self):
        """
        Syncs parameters across all DP ranks.
        """
        for param in self.module.parameters():
            is_expert_parallel = not getattr(param, 'allreduce', True)

            if is_expert_parallel:
                data_parallel_group = parallel_state.get_expert_data_parallel_group()
            else:
                data_parallel_group = parallel_state.get_data_parallel_group(
                    with_context_parallel=True, partial_data_parallel=True
                )
            torch.distributed.broadcast(
                param.data,
                src=torch.distributed.get_global_rank(data_parallel_group, 0),
                group=data_parallel_group,
            )
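To make the intended call pattern concrete, here is a hedged, hypothetical gradient-accumulation loop (not part of the commit). It assumes `config`, `ddp_config`, `gpt_model`, `microbatches`, and `loss_fn` already exist; the DDP-specific calls (`zero_grad_buffer`, `no_sync`, `finish_grad_sync`) are the ones defined above.

from contextlib import nullcontext

# Hypothetical training-iteration sketch: grad collectives only fire for the
# last microbatch; earlier microbatches accumulate into the contiguous buffers.
ddp_model = DistributedDataParallel(config, ddp_config, gpt_model)
ddp_model.zero_grad_buffer()                      # must run at the start of each iteration
for i, batch in enumerate(microbatches):
    is_last = i == len(microbatches) - 1
    ctx = nullcontext() if is_last else ddp_model.no_sync()
    with ctx:                                     # suppress grad sync except on last microbatch
        loss_fn(ddp_model(batch)).backward()
ddp_model.finish_grad_sync()                      # wait for (or issue) the grad collectives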
Megatron-LM/megatron/core/distributed/distributed_data_parallel_config.py
0 → 100644
View file @
4e867b3c
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from dataclasses import dataclass
from typing import Optional


@dataclass
class DistributedDataParallelConfig:
    """Configuration for DistributedDataParallel."""

    grad_reduce_in_fp32: bool = False
    """If true, reduce grads in fp32."""

    overlap_grad_reduce: bool = False
    """If true, overlap grad all-reduce / reduce-scatter with backward compute."""

    overlap_param_gather: bool = False
    """If true, overlap param all-gather with forward compute."""

    align_param_gather: bool = False
    """If true, all PP stages will launch param all-gathers simultaneously. Otherwise, each
    PP stage will independently launch as needed.
    """

    use_distributed_optimizer: bool = False
    """If true, issue reduce-scatter collectives to aggregate gradients and clean up
    originally allocated model parameters, otherwise issue all-reduce collectives.
    """

    num_distributed_optimizer_instances: int = 1
    """Sets the factor by which the DP domain is sharded to have the partial DistOpt
    enabled. Defaults to 1, which means DistOpt is across entire DP domain.
    """

    check_for_nan_in_grad: bool = False
    """If true, check for NaNs and Infs in gradients _before_ communication collective."""

    check_for_large_grads: bool = False
    """If true, check for unexpectedly large gradients _before_ communication collective."""

    bucket_size: Optional[int] = None
    """Maximum number of parameters in each bucket. If unspecified, MCore uses a default
    value of max(40000000, 1000000 * dp_size) parameters (larger DP sizes need larger
    buckets to ensure collectives do not become latency-bound)."""

    pad_buckets_for_high_nccl_busbw: bool = False
    """If true, make sure the bucket size is divisible by a large power of 2 (2^16) to
    ensure NCCL collectives have high bus bandwidth at large DP counts, since NCCL
    message size (which for ring algorithms is bucket_size / dp_size) apparently needs
    to be divisible by a power of 2 for high busbw."""

    average_in_collective: bool = False
    """If true, compute average in collective directly, as opposed to dividing by the
    dp_size first and then computing sum in the collective."""

    fp8_param_gather: bool = False
    """If true, keep the compute param in fp8 (do not use any other intermediate dtype) and
    perform the param all-gather in fp8."""

    use_custom_fsdp: bool = False
    """If true, use the FSDP code path for DDP."""

    data_parallel_sharding_strategy: str = 'no_shard'
    """Sharding strategy for FSDP. Valid values are 'no_shard', 'optim',
    'optim_grads', 'optim_grads_params'."""

    gradient_reduce_div_fusion: bool = True
    """If true, perform gradient reduce and division fusion."""

    suggested_communication_unit_size: int = None
    """Specifies the number of elements to communicate at once during
    FSDP (Fully Sharded Data Parallel) operations.
    This flag also affects FSDP all-gather prefetch behavior. Setting a larger
    value increases the communication buffer size, while a smaller value
    disables prefetching and may degrade performance. Adjust this value
    based on your system's memory and performance requirements."""

    preserve_fp32_weights: bool = True
    """If true, preserve fp32 weights in the custom FSDP ParamAndGradBuffer."""

    keep_fp8_transpose_cache_when_using_custom_fsdp: bool = False
    """If true, keep the fp8 transpose cache when using custom FSDP."""
Megatron-LM/megatron/core/distributed/finalize_model_grads.py
0 → 100644
View file @
4e867b3c
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from typing import List, Optional, Union

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

try:
    from torch.distributed._tensor import DTensor, distribute_tensor

    HAVE_DTENSOR = True
except ImportError:
    HAVE_DTENSOR = False

from .. import parallel_state
from ..transformer.moe.moe_utils import get_updated_expert_bias
from ..transformer.transformer_config import TransformerConfig
from ..utils import get_attr_wrapped_model, get_model_config


def _get_main_grad_attr(param: torch.nn.Parameter, use_custom_fsdp: bool = False):
    if use_custom_fsdp:
        return "fsdp_managed_main_grad"
    if hasattr(param, "main_grad"):
        return "main_grad"
    return "grad"


def _unshard_if_dtensor(tensor: Union[torch.Tensor, "DTensor"]) -> torch.Tensor:
    """
    Unshards the input tensor if it is a DTensor and otherwise returns the
    tensor unmodified.

    Args:
        tensor (Union[torch.Tensor, DTensor]): The tensor to potentially unshard.

    Returns:
        An unsharded version of the input tensor if it is a DTensor, or the
        input tensor unmodified if it is not a DTensor.
    """
    if HAVE_DTENSOR and isinstance(tensor, DTensor):
        unsharded_tensor = tensor.full_tensor()
        for k, v in vars(tensor).items():
            setattr(unsharded_tensor, k, v)
        return unsharded_tensor
    return tensor


def _reshard_if_dtensor(
    tensor_to_shard: torch.Tensor, reference_tensor: Union[torch.Tensor, "DTensor"]
) -> Union[torch.Tensor, "DTensor"]:
    """
    Reshards the input tensor to match the sharding configuration of the
    reference tensor if the reference tensor is a DTensor. Otherwise, returns
    the reference tensor unmodified.

    Args:
        tensor_to_shard (torch.Tensor): The tensor to be potentially sharded.
        reference_tensor (Union[torch.Tensor, DTensor]): The reference tensor
            for the sharding configuration.

    Returns:
        Union[torch.Tensor, DTensor]: The sharded tensor matching the reference tensor's
            configuration, or the reference tensor itself if it is not a DTensor.
    """
    if HAVE_DTENSOR and isinstance(reference_tensor, DTensor):
        sharded_tensor = distribute_tensor(
            tensor_to_shard,
            device_mesh=reference_tensor.device_mesh,
            placements=reference_tensor.placements,
        )
        for k, v in vars(reference_tensor).items():
            setattr(sharded_tensor, k, v)
        return sharded_tensor
    return reference_tensor


def _allreduce_conditional_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
    """
    All-reduce conditional embedding grads.

    Reduce grads across all the pp stages to ensure that parameters of the conditional embedders
    (e.g., timestep embedder, FPS embedder, label embedder) stay in sync.
    This is for the models with replicated embedders on each PP / VPP rank, like diffusion models.
    """
    if parallel_state.get_pipeline_model_parallel_world_size() > 1 and getattr(
        config, "has_cond_embedder", False
    ):
        grads_dict = {}
        for model_chunk in model:
            for name, param in get_attr_wrapped_model(model_chunk, 'named_parameters')():
                if param.requires_grad and getattr(param, 'pipeline_parallel', False):
                    grad = param.main_grad
                    if name in grads_dict:
                        # Add all the virtual PP rank's gradients to
                        # the first local virtual PP rank.
                        grads_dict[name][0].add_(grad)
                        # Append to the end for later update after cross-rank reduce.
                        grads_dict[name].append(grad)
                    else:
                        grads_dict[name] = [grad]

        if grads_dict:
            # All-reduce the gradient on the first VPP rank.
            grads = [param_grad[0] for _, param_grad in grads_dict.items()]
            coalesced = _flatten_dense_tensors(grads)
            torch.distributed.all_reduce(
                coalesced, group=parallel_state.get_pipeline_model_parallel_group()
            )
            for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
                buf.copy_(synced)

            # Update the gradients on other VPP ranks.
            for grads in grads_dict.values():
                for grad in grads[1:]:
                    grad.copy_(grads[0])


def _allreduce_word_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
    """
    All-reduce word embedding grads.

    Reduce grads across first and last stages to ensure that word_embeddings parameters stay in
    sync.
    """
    if (
        parallel_state.is_rank_in_embedding_group(ignore_virtual=True)
        and torch.distributed.get_world_size(parallel_state.get_embedding_group()) > 1
    ):
        if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
            model_module = model[0]
        elif parallel_state.is_pipeline_last_stage(ignore_virtual=True):
            model_module = model[-1]
        else:
            # We do not support an interleaved schedule for models with encoders yet.
            model_module = model[0]

        ddp_config = model_module.ddp_config
        model_module = get_attr_wrapped_model(model_module, 'pre_process', return_model_obj=True)
        # If share_embeddings_and_output_weights is True, we need to maintain duplicated
        # embedding weights in post processing stage. If use Multi-Token Prediction (MTP),
        # we also need to maintain duplicated embedding weights in mtp process stage.
        # So we need to allreduce grads of embedding in the embedding group in these cases.
        if model_module.share_embeddings_and_output_weights or getattr(config, 'mtp_num_layers', 0):
            weight = model_module.shared_embedding_or_output_weight()
            grad_attr = _get_main_grad_attr(weight, ddp_config.use_custom_fsdp)
            orig_grad = getattr(weight, grad_attr)
            grad = _unshard_if_dtensor(orig_grad)
            torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())
            setattr(weight, grad_attr, _reshard_if_dtensor(grad, orig_grad))


def _allreduce_position_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
    """
    All-reduce position_embeddings grad across encoder and decoder stages to ensure that position
    embeddings parameters stay in sync.
    """
    if (
        parallel_state.is_rank_in_position_embedding_group()
        and torch.distributed.get_world_size(parallel_state.get_position_embedding_group()) > 1
    ):
        if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
            model_module = model[0]
        elif parallel_state.is_pipeline_last_stage(ignore_virtual=True):
            model_module = model[-1]
        else:
            # We do not support an interleaved schedule for models with encoders yet.
            model_module = model[0]

        ddp_config = model_module.ddp_config
        model_module = get_attr_wrapped_model(model_module, 'pre_process', return_model_obj=True)
        assert hasattr(model_module, 'position_embeddings')
        weight = model_module.position_embeddings.weight
        grad_attr = _get_main_grad_attr(weight, ddp_config.use_custom_fsdp)
        orig_grad = getattr(weight, grad_attr)
        grad = _unshard_if_dtensor(orig_grad)
        torch.distributed.all_reduce(grad, group=parallel_state.get_position_embedding_group())
        setattr(weight, grad_attr, _reshard_if_dtensor(grad, orig_grad))


def _allreduce_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
    """
    All-reduce both word and position embeddings.
    """
    _allreduce_word_embedding_grads(model, config)
    _allreduce_position_embedding_grads(model, config)


def _allreduce_layernorm_grads(model: List[torch.nn.Module], config: TransformerConfig):
    """
    All-reduce layernorm grads (for sequence parallelism).
    """
    # All-reduce layernorm parameters across model parallel nodes
    # when sequence parallelism is used
    if parallel_state.get_tensor_model_parallel_world_size() > 1 and (
        config.sequence_parallel or config.qk_layernorm
    ):
        params = []
        grads = []
        for model_chunk in model:
            ddp_config = model_chunk.ddp_config
            for name, param in get_attr_wrapped_model(model_chunk, 'named_parameters')():
                if param.requires_grad and (
                    getattr(param, 'sequence_parallel', False)
                    or 'q_layernorm' in name
                    or 'k_layernorm' in name
                ):
                    params.append(param)
                    grad_attr = _get_main_grad_attr(param, ddp_config.use_custom_fsdp)
                    grad = getattr(param, grad_attr)
                    grad = _unshard_if_dtensor(grad)
                    grads.append(grad.data)
        if grads:
            coalesced = _flatten_dense_tensors(grads)
            torch.distributed.all_reduce(
                coalesced, group=parallel_state.get_tensor_model_parallel_group()
            )
            for param, buf, synced in zip(
                params, grads, _unflatten_dense_tensors(coalesced, grads)
            ):
                buf.copy_(synced)
                grad_attr = _get_main_grad_attr(param, ddp_config.use_custom_fsdp)
                orig_grad = getattr(param, grad_attr)
                setattr(param, grad_attr, _reshard_if_dtensor(buf, orig_grad))


def _update_router_expert_bias(model: List[torch.nn.Module], config: TransformerConfig):
    """
    Update the expert bias of the router for a global batch.
    This requires all-reduce of local_tokens_per_expert across TPxCPxDP ranks
    """
    tokens_per_expert_list = []
    expert_bias_list = []
    for model_chunk in model:
        for module in get_attr_wrapped_model(model_chunk, 'modules')():
            if hasattr(module, 'expert_bias'):
                tokens_per_expert_list.append(module.local_tokens_per_expert)
                expert_bias_list.append(module.expert_bias)
    # For hybrid models with both MoE and Dense layers, this list can be empty.
    if len(expert_bias_list) == 0:
        return
    stacked_tokens_per_expert = torch.stack(tokens_per_expert_list, dim=0)
    stacked_expert_bias = torch.stack(expert_bias_list, dim=0)
    stacked_updated_expert_bias = get_updated_expert_bias(
        stacked_tokens_per_expert, stacked_expert_bias, config.moe_router_bias_update_rate
    )

    for tokens_per_expert, expert_bias, updated_expert_bias in zip(
        tokens_per_expert_list, expert_bias_list, stacked_updated_expert_bias
    ):
        tokens_per_expert.zero_()
        expert_bias.copy_(updated_expert_bias)


def finalize_model_grads(model: List[torch.nn.Module], num_tokens: Optional[torch.Tensor] = None):
    """
    All-reduce all model grads across DP replicas, layernorm grads for sequence parallelism,
    embedding grads across first and last pipeline stages (if not tied),
    scale gradients by `num_tokens`.
    """

    config = get_model_config(model[0])

    # All-reduce / reduce-scatter across DP replicas.
    if config.timers is not None:
        config.timers('all-grads-sync', log_level=1).start(barrier=config.barrier_with_L1_time)
    for model_chunk in model:
        model_chunk.finish_grad_sync()
    if config.timers is not None:
        config.timers('all-grads-sync').stop()

    # All-reduce t_embedder grads (for pp & vpp of DiT).
    if config.timers is not None:
        config.timers('conditional-embedder-grads-all-reduce', log_level=1).start(
            barrier=config.barrier_with_L1_time
        )
    _allreduce_conditional_embedding_grads(model, config)
    if config.timers is not None:
        config.timers('conditional-embedder-grads-all-reduce').stop()

    # All-reduce layer-norm grads (for sequence parallelism).
    if config.timers is not None:
        config.timers('layernorm-grads-all-reduce', log_level=1).start(
            barrier=config.barrier_with_L1_time
        )
    _allreduce_layernorm_grads(model, config)
    if config.timers is not None:
        config.timers('layernorm-grads-all-reduce').stop()

    # All-reduce embedding grads (for pipeline parallelism).
    if config.timers is not None:
        config.timers('embedding-grads-all-reduce', log_level=1).start(
            barrier=config.barrier_with_L1_time
        )
    _allreduce_embedding_grads(model, config)
    if config.timers is not None:
        config.timers('embedding-grads-all-reduce').stop()

    if config.moe_router_enable_expert_bias:
        _update_router_expert_bias(model, config)

    # normalize gradients for per-token loss normalization.
    # if we are using by the number of tokens, then we use that as a divisor. this number
    # will be the total number of non-padded tokens in the global batch.
    if num_tokens is not None:
        # the number of tokens is only present on the last stage, so broadcast it
        # to the other ranks in the pipeline parallel group.
        last_rank = parallel_state.get_pipeline_model_parallel_last_rank()
        pp_group = parallel_state.get_pipeline_model_parallel_group()

        if not isinstance(last_rank, list):
            assert not isinstance(last_rank, list)
            last_rank = [last_rank]
            assert not isinstance(pp_group, list)
            pp_group = [pp_group]

        # need to do a broadcast for every pp group, even though num_tokens should be the same.
        num_tokens_list = []
        for lr, group in zip(last_rank, pp_group):
            torch.distributed.broadcast(num_tokens, src=lr, group=group)
            num_tokens_list.append(torch.clone(num_tokens))
        assert all(x.item() == num_tokens_list[0] for x in num_tokens_list)

        # all-reduce across DP ranks.
        torch.distributed.all_reduce(num_tokens, group=parallel_state.get_data_parallel_group())
        for model_chunk in model:
            if num_tokens > 0:
                scaling = 1.0 / num_tokens
                model_chunk.scale_gradients(scaling)
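A hypothetical call site for the function above (not part of the commit): after all backward passes for a global batch, finish the grad collectives and normalize by the token count. `model_chunks` (a list of DDP-wrapped model chunks) and `total_tokens` (a scalar CUDA tensor holding the non-padded token count on the last pipeline stage) are assumed names.

# Sketch: run once per training iteration, after the last microbatch's backward.
finalize_model_grads(model_chunks, num_tokens=total_tokens)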
Megatron-LM/megatron/core/distributed/param_and_grad_buffer.py
0 → 100644
View file @
4e867b3c
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import
logging
import
math
import
warnings
from
contextlib
import
nullcontext
from
enum
import
Enum
from
functools
import
partial
from
typing
import
Dict
,
List
,
Optional
import
torch
from
torch.distributed
import
_coalescing_manager
from
megatron.core.rerun_state_machine
import
get_rerun_state_machine
from
..fp8_utils
import
is_float8tensor
,
modify_underlying_storage
from
..utils
import
is_torch_min_version
,
log_on_each_pipeline_stage
from
.distributed_data_parallel_config
import
DistributedDataParallelConfig
logger
=
logging
.
getLogger
(
__name__
)
if
is_torch_min_version
(
"1.13.0"
):
dist_all_gather_func
=
torch
.
distributed
.
all_gather_into_tensor
dist_reduce_scatter_func
=
torch
.
distributed
.
reduce_scatter_tensor
else
:
dist_all_gather_func
=
torch
.
distributed
.
_all_gather_base
dist_reduce_scatter_func
=
torch
.
distributed
.
_reduce_scatter_base
class
BufferType
(
Enum
):
"""
Enumeration for buffer type.
"""
PARAM
=
1
GRAD
=
2
def
shard_buffer
(
buffer
:
torch
.
Tensor
,
data_parallel_world_size
:
int
):
"""
Shard buffer into data_parallel_world_size chunks of equal size.
"""
assert
buffer
.
numel
()
%
data_parallel_world_size
==
0
shard_size
=
buffer
.
numel
()
//
data_parallel_world_size
sharded_buffer
=
[
buffer
[(
r
*
shard_size
)
:
((
r
+
1
)
*
shard_size
)]
for
r
in
range
(
data_parallel_world_size
)
]
return
sharded_buffer
class
_ParamAndGradBucket
:
"""
Bucket to keep track of a subset of the model's parameters and gradients.
Args:
params: List of parameters whose gradients are collated in this bucket.
param_data: View in _ParamAndGradBuffer.param_data that this bucket is responsible for.
grad_data: View in _ParamAndGradBuffer.grad_data that this bucket is responsible for.
offset: Offset of this bucket's view in the larger _ParamAndGradBuffer.
numel_unpadded: Number of unpadded elements in bucket.
gradient_scaling_factor: This factor is utilized to scale gradients prior to their
communication. Its application is twofold: it facilitates the averaging of gradients
and the scaling of gradients in the context of the Mixture of Experts (MoE) model.
bucket_id: Index of bucket in buffer.
"""
def
__init__
(
self
,
params
:
List
[
torch
.
nn
.
Parameter
],
param_data
:
Optional
[
torch
.
Tensor
],
grad_data
:
torch
.
Tensor
,
offset
:
int
,
numel_unpadded
:
int
,
gradient_scaling_factor
:
float
,
bucket_id
:
int
,
):
self
.
params_list
=
params
self
.
params
=
set
(
params
)
# Make sure there are no duplicate params.
assert
len
(
self
.
params_list
)
==
len
(
self
.
params
)
self
.
param_data
=
param_data
self
.
grad_data
=
grad_data
# The distributed optimizer needs to keep track of this bucket's offset
# within the full grad_buffer.
self
.
offset
=
offset
self
.
numel_unpadded
=
numel_unpadded
self
.
gradient_scaling_factor
=
gradient_scaling_factor
self
.
bucket_id
=
bucket_id
class
_ParamAndGradBucketGroup
:
"""
Put multiple buckets into a group so that their communications can be aggregated together.
Provides functionality to register when params in the bucket group have grads ready to be
synced; an asynchronous communication call is automatically launched when _all_ params in
the bucket group have grads ready.
Args:
buckets: A list of buckets.
ddp_config: DistributedDataParallel config object.
collective_group: intra_distributed_optimizer_instance_group if using distributed
optimizer, data_parallel_group if not.
collective_group_size: World size using the intra data-parallel group.
"""
def
__init__
(
self
,
buckets
:
List
[
_ParamAndGradBucket
],
ddp_config
:
DistributedDataParallelConfig
,
collective_group
:
torch
.
distributed
.
ProcessGroup
,
collective_group_size
:
int
,
):
self
.
buckets
=
buckets
self
.
ddp_config
=
ddp_config
if
self
.
ddp_config
.
use_distributed_optimizer
:
self
.
intra_distributed_optimizer_instance_group
=
collective_group
self
.
intra_distributed_optimizer_instance_size
=
collective_group_size
self
.
intra_distributed_optimizer_instance_rank
=
torch
.
distributed
.
get_rank
(
group
=
collective_group
)
else
:
self
.
data_parallel_group
=
collective_group
# State for bookkeeping: params is the set of parameters this bucket group is
# responsible for, params_with_grad is the set of parameters with grads
# available. When overlap_grad_reduce is True, communication (all-reduce
# or reduce-scatter) is issued when params_with_grad equals params.
self
.
param_to_bucket
=
{}
self
.
params
=
set
()
for
bucket
in
self
.
buckets
:
for
param
in
bucket
.
params_list
:
self
.
param_to_bucket
[
param
]
=
bucket
self
.
params
.
add
(
param
)
self
.
next_param_gather_bucket_group
=
None
if
self
.
ddp_config
.
num_distributed_optimizer_instances
>
1
:
self
.
inter_distributed_optimizer_instance_group
=
None
self
.
communication_stream
=
None
self
.
reset
()
self
.
param_gather_handle
=
None
self
.
param_gather_dispatched
=
False
self
.
grad_reduce_handle
=
None
def
reset
(
self
):
"""
Reset metadata in bucket group in preparation for the next iteration of training.
"""
self
.
params_with_grad
=
set
()
self
.
is_last_microbatch
=
True
def
check_grads
(
self
,
check_for_nan_or_inf
,
check_for_large
):
"""
Make sure norm of grads in bucket are not NaN prior to data-parallel
all-reduce / reduce-scatter.
"""
rerun_state_machine
=
get_rerun_state_machine
()
for
i
in
range
(
len
(
self
.
buckets
)):
grad_norm
=
self
.
buckets
[
i
].
grad_data
.
norm
(
p
=
2
)
# check for NaN, Inf and unexpectedly large grads
if
check_for_nan_or_inf
:
rerun_state_machine
.
validate_result
(
result
=
grad_norm
,
rejection_func
=
torch
.
isnan
,
message
=
f
"found NaN in local grad norm for bucket #
{
i
}
"
f
"in backward pass before data-parallel communication collective"
,
tolerance
=
0.001
,
# 0.1% tolerance to account for non-deterministic FA backward
fatal
=
True
,
)
rerun_state_machine
.
validate_result
(
result
=
grad_norm
,
rejection_func
=
torch
.
isinf
,
message
=
f
"found Inf in local grad norm for bucket #
{
i
}
"
f
"in backward pass before data-parallel communication collective"
,
tolerance
=
0.001
,
# 0.1% tolerance to account for non-deterministic FA backward
fatal
=
True
,
)
if
check_for_large
:
rerun_state_machine
.
validate_result
(
result
=
grad_norm
,
rejection_func
=
partial
(
rerun_state_machine
.
is_unexpectedly_large
,
threshold
=
10
,
context
=
"grads"
),
message
=
f
"found unexpected large grads in bucket #
{
i
}
"
f
"in backward pass before data-parallel communication collective"
,
tolerance
=
0.001
,
# 0.1% tolerance to account for non-deterministic FA backward
fatal
=
False
,
)
def
start_param_sync
(
self
,
force_sync
:
bool
=
False
):
"""
Initiates all necessary param all-gathers for this bucket.
When ddp_config.overlap_param_gather is set to True, dispatches an asynchronous
communication call (unless force_sync is True). When ddp_config.overlap_param_gather
is set to False, makes synchronous call.
Args:
force_sync (bool, optional): force synchronous collective regardless of
other settings if true.
"""
assert
self
.
ddp_config
.
use_distributed_optimizer
if
force_sync
:
if
self
.
param_gather_handle
is
not
None
:
self
.
param_gather_handle
.
wait
()
self
.
param_gather_handle
=
None
return
else
:
assert
self
.
param_gather_handle
is
None
async_op
=
self
.
ddp_config
.
overlap_param_gather
and
not
force_sync
# Coalesce communication kernels across buckets in the bucket group.
with
_coalescing_manager
(
self
.
intra_distributed_optimizer_instance_group
,
async_ops
=
async_op
)
as
cm
:
for
bucket
in
self
.
buckets
:
local_data_view
=
shard_buffer
(
bucket
.
param_data
,
self
.
intra_distributed_optimizer_instance_size
)[
self
.
intra_distributed_optimizer_instance_rank
]
dist_all_gather_func
(
bucket
.
param_data
,
local_data_view
,
group
=
self
.
intra_distributed_optimizer_instance_group
,
async_op
=
async_op
,
)
if
async_op
:
self
.
param_gather_handle
=
cm
else
:
# When using `_coalescing_manager`, even if a synchronous op (async_op=False) is used,
# `cm` is not None, which is different from when `_coalescing_manager` is not used in
# which case the torch.distributed._all_gather_base() will return None. In order to
# maintain consistency with prior code, we need to manually set communication handle to
# None.
self
.
param_gather_handle
=
None
self
.
param_gather_dispatched
=
True
def
finish_param_sync
(
self
,
skip_next_bucket_dispatch
:
bool
=
False
):
"""
Finishes param sync communication operation for this bucket. Dispatches
next bucket's param sync if available, unless skip_next_bucket_dispatch
is True.
When ddp_config.overlap_param_gather is set to True, waits for asynchronous
communication call to complete (and dispatches one if one is not already
outstanding). Throws assertion error if ddp_config.overlap_param_gather is set to
False.
Args:
skip_next_bucket_dispatch (bool, optional): if true, dispatch next
bucket's communication if available.
"""
assert
self
.
ddp_config
.
use_distributed_optimizer
assert
self
.
ddp_config
.
overlap_param_gather
# If current bucket's param AG has not been dispatched, dispatch it now (e.g., first
# AG bucket in first model chunk if ddp_config.align_param_gather is False).
if
not
self
.
param_gather_dispatched
:
self
.
start_param_sync
()
if
self
.
param_gather_handle
is
not
None
:
self
.
param_gather_handle
.
wait
()
self
.
param_gather_handle
=
None
# Dispatch next bucket's asynchronous param AG only if it has not been dispatched yet.
if
self
.
next_param_gather_bucket_group
is
not
None
and
not
skip_next_bucket_dispatch
:
if
self
.
next_param_gather_bucket_group
.
param_gather_dispatched
:
warnings
.
warn
(
"The next bucket's parameter all-gather operation has already been "
"dispatched. This may be caused by a mismatch between the order of "
"parameter registration and forward pass execution, which will "
"hurt the communication-computation overlap performance."
)
else
:
self
.
next_param_gather_bucket_group
.
start_param_sync
()
def
start_grad_sync
(
self
):
"""
Initiates grad sync (all-reduce or reduce-scatter) communication operations
for all buckets in the bucket group.
When ddp_config.overlap_grad_reduce is set to True, dispatches an asynchronous
communication call. When ddp_config.overlap_grad_reduce is set to False, makes
synchronous call.
"""
assert
(
self
.
grad_reduce_handle
is
None
),
'Should not have multiple communication calls outstanding at once'
if
self
.
ddp_config
.
check_for_nan_in_grad
or
self
.
ddp_config
.
check_for_large_grads
:
self
.
check_grads
(
check_for_nan_or_inf
=
self
.
ddp_config
.
check_for_nan_in_grad
,
check_for_large
=
self
.
ddp_config
.
check_for_large_grads
,
)
# gradient_scaling_factor already takes into account whether we are computing
# an average or sum in the data-parallel collective.
for
bucket
in
self
.
buckets
:
if
bucket
.
gradient_scaling_factor
!=
1.0
:
bucket
.
grad_data
*=
bucket
.
gradient_scaling_factor
# Decide reduce_op.
reduce_op
=
torch
.
distributed
.
ReduceOp
.
SUM
if
self
.
ddp_config
.
average_in_collective
:
reduce_op
=
torch
.
distributed
.
ReduceOp
.
AVG
# We use the following stream synchronization for the gradient reduction
# within and across DistOpt instances.
# Compute Stream: -------------Gradient compute-------------------
# Comm. Stream: ------(wait for NCCL)-----(wait for NCCL)-------
# NCCL Stream: -------RS------ -------AR------
# Use async communications only when overlap_grad_reduce is True.
async_op
=
(
self
.
ddp_config
.
overlap_grad_reduce
and
self
.
ddp_config
.
num_distributed_optimizer_instances
==
1
)
if
(
self
.
ddp_config
.
num_distributed_optimizer_instances
>
1
and
self
.
ddp_config
.
overlap_grad_reduce
):
# Assign a communication stream if we have multiple DistOpt instances and we
# need to overlap communication.
stream_context
=
torch
.
cuda
.
stream
(
self
.
communication_stream
)
# The RS/AR communication stream needs to wait for the default stream
# to complete its gradient computation before launching the next
# gradient reduction collective.
self
.
communication_stream
.
wait_stream
(
torch
.
cuda
.
default_stream
())
else
:
stream_context
=
nullcontext
()
if
self
.
ddp_config
.
use_distributed_optimizer
:
communication_group
=
self
.
intra_distributed_optimizer_instance_group
else
:
communication_group
=
self
.
data_parallel_group
# Coalesce communication kernels across buckets in the bucket group.
with
stream_context
,
_coalescing_manager
(
communication_group
,
async_ops
=
async_op
)
as
cm
:
for
bucket
in
self
.
buckets
:
if
self
.
ddp_config
.
use_distributed_optimizer
:
local_data_view
=
shard_buffer
(
bucket
.
grad_data
,
self
.
intra_distributed_optimizer_instance_size
)[
self
.
intra_distributed_optimizer_instance_rank
]
dist_reduce_scatter_func
(
local_data_view
,
bucket
.
grad_data
,
op
=
reduce_op
,
group
=
communication_group
,
async_op
=
async_op
,
)
else
:
torch
.
distributed
.
all_reduce
(
bucket
.
grad_data
,
op
=
reduce_op
,
group
=
communication_group
,
async_op
=
async_op
)
# With multiple DistOpt instances, we need to all-reduce across instances.
if
(
self
.
ddp_config
.
use_distributed_optimizer
and
self
.
ddp_config
.
num_distributed_optimizer_instances
>
1
):
assert
self
.
inter_distributed_optimizer_instance_group
is
not
None
# Create a new coalescing manager for the inter-instance all-reduce.
with
stream_context
,
_coalescing_manager
(
self
.
inter_distributed_optimizer_instance_group
,
async_ops
=
async_op
)
as
cm
:
for
bucket
in
self
.
buckets
:
local_data_view
=
shard_buffer
(
bucket
.
grad_data
,
self
.
intra_distributed_optimizer_instance_size
)[
self
.
intra_distributed_optimizer_instance_rank
]
torch
.
distributed
.
all_reduce
(
local_data_view
,
op
=
reduce_op
,
group
=
self
.
inter_distributed_optimizer_instance_group
,
async_op
=
async_op
,
)
if
async_op
:
self
.
grad_reduce_handle
=
cm
else
:
# When using `_coalescing_manager`, even if a synchronous op (async_op=False) is used,
# `cm` is not None, which is different from when `_coalescing_manager` is not used in
# which case the torch.distributed._reduce_scatter_base() will return None. In order to
# maintain consistency with prior code, we need to manually set communication handle to
# None.
self
.
grad_reduce_handle
=
None
def
finish_grad_sync
(
self
):
"""
Finishes grad sync (all-reduce or reduce-scatter) communication operations
for all buckets in the bucket group.
When ddp_config.overlap_grad_reduce is set to True, waits for asynchronous
communication call to complete. When ddp_config.overlap_grad_reduce is set to False,
makes synchronous call.
"""
self
.
param_gather_dispatched
=
False
# If overlap_grad_reduce is False, start (and finish) synchronous communication call here.
if
not
self
.
ddp_config
.
overlap_grad_reduce
:
self
.
start_grad_sync
()
return
# When using multiple DistOpt instances, we don't need to sync here as we launch
# communications on a separate communication stream.
if
self
.
ddp_config
.
num_distributed_optimizer_instances
>
1
:
torch
.
cuda
.
default_stream
().
wait_stream
(
self
.
communication_stream
)
return
assert
self
.
grad_reduce_handle
is
not
None
,
(
f
'Communication call has not been issued for this bucket '
f
'(
{
len
(
self
.
params_with_grad
)
}
/
{
len
(
self
.
params
)
}
params have grad available)'
)
self
.
grad_reduce_handle
.
wait
()
self
.
grad_reduce_handle
=
None
def
register_grad_ready
(
self
,
param
:
torch
.
nn
.
Parameter
):
"""
Registers grads for the passed-in param to be "ready" for grad sync.
When the number of microbatches is greater than 1, we only want to register
grads as ready when processing the last microbatch and ddp_config.overlap_grad_reduce
is True.
"""
assert
(
self
.
ddp_config
.
overlap_grad_reduce
),
'register_grad_ready() should only be called when overlap_grad_reduce is True'
if
self
.
is_last_microbatch
:
assert
param
in
self
.
param_to_bucket
,
'Param is not in the bucket group'
assert
param
not
in
self
.
params_with_grad
,
'Cannot set grad twice'
self
.
params_with_grad
.
add
(
param
)
# If all params in bucket group have grads available, issue communication call.
if
len
(
self
.
params_with_grad
)
==
len
(
self
.
params
):
self
.
start_grad_sync
()
class
_ParamAndGradBuffer
:
"""
Groups parameters and gradients into a contiguous buffer, and then breaks the buffer into
buckets with roughly `bucket_size` parameters each.
Args:
ddp_config: DistributedDataParallel config object.
param_dtype: Type of param tensor.
grad_dtype: Type of grad tensor.
params: List of parameters whose parameters and gradients are collated in the underlying
tensor.
data_parallel_group: Data-parallel process group.
bucket_size: The rough size of each bucket in terms of number of parameters.
param_to_name: Mapping from `torch.nn.Parameter` to name (for logging purposes).
gradient_scaling_factor: This factor is utilized to scale gradients prior to their
communication. Its application is twofold: it facilitates the averaging of gradients
and the scaling of gradients in the context of the Mixture of Experts (MoE) model.
param_indices: The index of each param among the params with same dtype, if a param is fp8,
use its "fake" high precision dtype to determine which params have same dtype with it.
These indices are needed when loading a non-native-fp8 checkpoint in native-fp8 mode.
"""
    def __init__(
        self,
        ddp_config: DistributedDataParallelConfig,
        param_dtype: torch.dtype,
        grad_dtype: torch.dtype,
        params: List[torch.nn.Parameter],
        data_parallel_group: torch.distributed.ProcessGroup,
        bucket_size: int,
        param_to_name: Dict[torch.nn.Parameter, str],
        gradient_scaling_factor: float,
        param_indices: List[int],
    ):
        self.ddp_config = ddp_config
        self.params = params
        self.param_indices = param_indices

        # Check that params are unique.
        unique_params = set()
        for param in params:
            assert param not in unique_params
            unique_params.add(param)
        del unique_params

        # Store attributes that will be needed later.
        self.param_dtype = param_dtype
        self.grad_dtype = grad_dtype
        self.data_parallel_group = data_parallel_group
        self.data_parallel_world_size = torch.distributed.get_world_size(
            group=self.data_parallel_group
        )
        self.gradient_scaling_factor = gradient_scaling_factor

        # Data structures to store underlying buckets and relevant indexing data.
        self.buckets = []
        self.param_to_bucket = {}  # Param -> bucket mapping.
        self.param_index_map = {}  # Param -> location in buffer mapping (used in dist. optimizer).

        def _pad(number_to_be_padded: int, divisor: int) -> int:
            return int(math.ceil(number_to_be_padded / divisor) * divisor)

        def _pad_end_of_bucket_if_needed(bucket_end_index: int) -> int:
            """
            Pads end index of bucket if using distributed optimizer (to ensure uniform sharding).
            """
            if self.ddp_config.use_distributed_optimizer:
                # Workaround for TE bug causing cuBLAS to pick an incompatible algorithm.
                # This also helps cuBLAS pick more efficient algorithms for GEMMs.
                # We now ensure that all buckets start at a memory address that is 256-byte
                # aligned (128 values since params and grads use >= 16-bit precision).
                if self.ddp_config.pad_buckets_for_high_nccl_busbw:
                    # Make sure the bucket size is divisible by a large power of 2 (2^16) to
                    # ensure NCCL collectives have high bus bandwidth at large DP counts,
                    # since NCCL message size (which for ring algorithms is bucket_size /
                    # dp_size) apparently needs to be divisible by a power of 2 for high busbw.
                    bucket_size_divisor = math.lcm(self.data_parallel_world_size, 128, 2**16)
                else:
                    bucket_size_divisor = math.lcm(self.data_parallel_world_size, 128)
                return _pad(bucket_end_index, bucket_size_divisor)
            return bucket_end_index

        def _pad_start_of_param_if_needed(param_start_index: int) -> int:
            """
            Pads start index of param if using distributed optimizer (to ensure "good" alignment).
            """
            if self.ddp_config.use_distributed_optimizer:
                # Ensure that params start at 128-byte aligned addresses (64 values
                # since params are >= 16-bit precision).
                return _pad(param_start_index, 64)
            return param_start_index

        # First, figure out how many elements should be in the underlying buffer storage.
        # Note that if we need to split the buffer into smaller buckets, each of these
        # might need to be padded as well (if using the distributed optimizer).
        param_start_index = 0
        bucket_start_index = param_start_index
        bucket_params = set()
        self.bucket_indices = []
        per_bucket_numel_unpadded = []
        bucket_id = 0

        def _update_bucket_metadata(param_end_index: int) -> int:
            """
            Record metadata for the bucket starting at bucket_start_index and ending with the
            passed-in param_end_index. Returns the bucket's end_index.
            """
            nonlocal bucket_start_index, bucket_params, bucket_id
            per_bucket_numel_unpadded.append(param_end_index - bucket_start_index)
            bucket_end_index = _pad_end_of_bucket_if_needed(param_end_index)

            # Record metadata of new bucket.
            self.bucket_indices.append((bucket_start_index, bucket_end_index))
            bucket_start_index = bucket_end_index

            # Prepare for next bucket.
            bucket_params = set()
            bucket_id += 1

            # Return the potentially padded bucket_end_index.
            return bucket_end_index

        def _does_param_require_new_bucket(param):
            """
            Split shared embedding parameters into separate bucket if using distributed
            optimizer that makes use of reduce-scatters instead of all-reduces.
            This ensures that the first and last pipeline stage partition optimizer state
            for the shared embedding parameters the same way across DP replicas, allowing
            the DP reduce-scatter to be before the embedding all-reduce.
            """
            return (
                getattr(param, "shared_embedding", False)
                and self.ddp_config.use_distributed_optimizer
            )

        for param in params[::-1]:
            # Iterate through parameters in reverse order to roughly follow backprop order.
            this_numel = param.data.nelement()
            param_start_index = _pad_start_of_param_if_needed(param_start_index)

            # Create bucket with collected parameters if current param needs its own bucket.
            if _does_param_require_new_bucket(param):
                # We are creating a bucket for the already accumulated parameters, whose params
                # end at the current param_start_index.
                if self.ddp_config.use_distributed_optimizer:
                    # Make sure new bucket is appropriately padded.
                    if param_start_index % self.data_parallel_world_size != 0:
                        param_start_index = _pad_end_of_bucket_if_needed(param_start_index)
                if len(bucket_params) > 0:
                    bucket_end_index = _update_bucket_metadata(param_start_index)

            param_end_index = param_start_index + this_numel
            self.param_index_map[param] = (param_start_index, param_end_index, bucket_id)
            bucket_params.add(param)

            # If we have enough elements already or the current param is part of the shared
            # embedding layer and needs a separate bucket, form a new bucket.
            if (
                bucket_size is not None
                and (param_end_index - bucket_start_index) >= bucket_size
            ) or _does_param_require_new_bucket(param):
                bucket_end_index = _update_bucket_metadata(param_end_index)
                param_start_index = bucket_end_index
            else:
                param_start_index = param_end_index

        # Add remaining params to a new bucket.
        if len(bucket_params) > 0:
            bucket_end_index = _update_bucket_metadata(param_end_index)

        # Next, create underlying storage for buffer (with numel elements that includes
        # padding as necessary).
        self.numel = bucket_end_index
        self.numel_unpadded = sum(per_bucket_numel_unpadded)
        assert self.numel_unpadded <= self.numel
        if self.ddp_config.use_distributed_optimizer:
            assert self.numel % self.data_parallel_world_size == 0
        else:
            assert self.numel == self.numel_unpadded

        self.param_data = None
        # Only re-map param tensors if using distributed optimizer.
        if self.ddp_config.use_distributed_optimizer:
            self.param_data = torch.zeros(
                self.numel,
                dtype=self.param_dtype,
                device=torch.cuda.current_device(),
                requires_grad=False,
            )
        self.grad_data = torch.zeros(
            self.numel,
            dtype=self.grad_dtype,
            device=torch.cuda.current_device(),
            requires_grad=False,
        )

        # Finally, map param.data and param.main_grad fields to buffers.
        bucket_params = []
        bucket_start_index = 0
        cur_bucket_id = 0
        for param in params[::-1]:
            param_start_index, param_end_index, bucket_id = self.param_index_map[param]

            # Assign param.data to appropriate segment of self.param_data.
            if self.param_data is not None:
                new_param_data = self._get(
                    param.data.shape, param_start_index, buffer_type=BufferType.PARAM
                )
                if is_float8tensor(param):
                    modify_underlying_storage(param, new_param_data)
                else:
                    old_param_data = param.data
                    param.data = new_param_data
                    assert old_param_data._base is None
                    # Copy tensor values (from initialization or checkpoint).
                    param.data.detach().copy_(old_param_data)
                    del old_param_data

            param.main_grad = self._get(
                param.data.shape, param_start_index, buffer_type=BufferType.GRAD
            )
            if bucket_id != cur_bucket_id:
                bucket_end_index = _pad_end_of_bucket_if_needed(param_start_index)
                self.buckets.append(
                    self._new_bucket(
                        bucket_params=bucket_params,
                        start_index=bucket_start_index,
                        end_index=bucket_end_index,
                        numel_unpadded=per_bucket_numel_unpadded[cur_bucket_id],
                        bucket_id=cur_bucket_id,
                    )
                )
                bucket_start_index = bucket_end_index
                bucket_params = []
                assert cur_bucket_id + 1 == len(self.buckets)
                assert bucket_id == cur_bucket_id + 1
                cur_bucket_id = bucket_id
            bucket_params.append(param)

        # Add remaining params to a new bucket.
        if len(bucket_params) > 0:
            bucket_end_index = _pad_end_of_bucket_if_needed(param_end_index)
            self.buckets.append(
                self._new_bucket(
                    bucket_params=bucket_params,
                    start_index=bucket_start_index,
                    end_index=bucket_end_index,
                    numel_unpadded=per_bucket_numel_unpadded[cur_bucket_id],
                    bucket_id=cur_bucket_id,
                )
            )

        # Log buckets for all PP stages.
        log_strs = []
        log_strs.append(
            f'Number of buckets for gradient all-reduce / reduce-scatter: {len(self.buckets)}'
        )
        for index, bucket in enumerate(self.buckets):
            numel = 0
            for param in bucket.params:
                numel += param.data.nelement()
            log_strs.append(
                f"Params for bucket {index + 1} ({numel} elements, "
                f"{bucket.grad_data.nelement()} padded size):"
            )
            for param in bucket.params:
                log_strs.append(f'\t{param_to_name[param]}')
        log_on_each_pipeline_stage(logger, logging.INFO, '\n'.join(log_strs))
    def scale_gradients(self, scaling_factor: float) -> None:
        """Scale the gradient data by `scaling_factor`."""
        self.grad_data *= scaling_factor
    def _get(self, shape: torch.Size, start_index: int, buffer_type: BufferType) -> torch.Tensor:
        """
        Return a tensor with the input `shape` as a view into the 1-D data starting at
        `start_index`.
        """
        end_index = start_index + shape.numel()
        assert end_index <= self.numel, 'Requested tensor is out of buffer range'
        if buffer_type == BufferType.PARAM:
            assert self.param_data is not None
            buffer_tensor = self.param_data[start_index:end_index]
        elif buffer_type == BufferType.GRAD:
            buffer_tensor = self.grad_data[start_index:end_index]
        else:
            raise Exception("Illegal buffer type provided to GradBuffer._get() function")
        buffer_tensor = buffer_tensor.view(shape)
        return buffer_tensor
    def _new_bucket(
        self,
        bucket_params: List[torch.nn.Parameter],
        start_index: int,
        end_index: int,
        numel_unpadded: int,
        bucket_id: int,
    ) -> _ParamAndGradBucket:
        """
        Helper function that creates a new bucket. Also updates param->bucket mapping.
        """
        # Assert that indices are correctly padded (if needed), and that bucket
        # position is same as originally computed.
        if self.ddp_config.use_distributed_optimizer:
            assert start_index % self.data_parallel_world_size == 0
            assert end_index % self.data_parallel_world_size == 0
        assert (start_index, end_index) == self.bucket_indices[bucket_id]

        # Get appropriate view into global _ParamAndGradBuffer.
        bucketed_param_data = None
        if self.param_data is not None:
            bucketed_param_data = self._get(
                torch.Size([end_index - start_index]), start_index, buffer_type=BufferType.PARAM
            )
        bucketed_grad_data = self._get(
            torch.Size([end_index - start_index]), start_index, buffer_type=BufferType.GRAD
        )
        bucket = _ParamAndGradBucket(
            params=bucket_params,
            param_data=bucketed_param_data,
            grad_data=bucketed_grad_data,
            offset=start_index,
            numel_unpadded=numel_unpadded,
            gradient_scaling_factor=self.gradient_scaling_factor,
            bucket_id=bucket_id,
        )
        for bucket_param in bucket_params:
            assert bucket_param not in self.param_to_bucket
            self.param_to_bucket[bucket_param] = bucket

        return bucket
    def reset(self):
        """
        Zero out the underlying grad_buffer.
        """
        self.grad_data.zero_()
def partition_buckets(
    buffers: List[_ParamAndGradBuffer], force_single_bucket_group: bool = False
) -> List[_ParamAndGradBucketGroup]:
    """
    Automatically regroup the buckets of input buffers and return a list of bucket groups.

    In some scenarios, we need to put buckets from different buffers into a group so that their
    communication can be aggregated.

    For example, when there are both fp8 weights and bf16 biases in the model and virtual
    pipeline parallelism is enabled, each model chunk will have an fp8 bucket and a bf16 bucket,
    which doubles the number of communication kernels, and because of the use of
    CUDA_DEVICE_MAX_CONNECTIONS=1, having multiple back-to-back communications will prevent the
    overlap of communication kernels with computation kernels.

    The grouping strategy is:
    1. If force_single_bucket_group is True, put all buckets across all buffers into a single
       bucket group.
    2. If force_single_bucket_group is False, when there is no fp8 buffer in the input buffers,
       let each bucket group have only one bucket.
    3. If force_single_bucket_group is False, when using fp8 params, merge all non-fp8 buckets
       into the last fp8 bucket group.
       - Since the non-fp8 parameters (typically the biases of various layers) are relatively
         small, they are likely to be grouped into a single non-fp8 bucket.
       - The fp8 buckets start from the end of the model, i.e., the first bucket corresponds to
         the end of the model, while the last bucket corresponds to the beginning.
       - If we combine the non-fp8 bucket with the first fp8 bucket, we cannot initiate the
         reduce-scatter to synchronize gradients after the backward pass at the end of the model
         has completed. This is because we need to wait for the non-fp8 params from the beginning
         layers to obtain their gradients.
       - Combining the non-fp8 bucket with the last fp8 bucket can help avoid this issue.

    Args:
        buffers (list): list of input buffers.
        force_single_bucket_group (bool, optional): force grouping all buckets across all
            buffers into a single bucket group.
    """
    if len(buffers) == 0:
        return []

    dtype_to_buffer_map = {}
    for buffer in buffers:
        dtype = buffer.param_dtype
        # Make sure that the param_dtype of any two buffers is different.
        assert dtype not in dtype_to_buffer_map
        dtype_to_buffer_map[dtype] = buffer

    # Case 1: Put all buckets into a single bucket group if force_single_bucket_group is True.
    if force_single_bucket_group:
        buckets = []
        ddp_config = buffers[0].ddp_config
        data_parallel_group = buffers[0].data_parallel_group
        data_parallel_world_size = buffers[0].data_parallel_world_size
        for buffer in buffers:
            assert ddp_config == buffer.ddp_config
            assert data_parallel_group == buffer.data_parallel_group
            assert data_parallel_world_size == buffer.data_parallel_world_size
            buckets.extend(buffer.buckets)

        bucket_group = _ParamAndGradBucketGroup(
            buckets, ddp_config, data_parallel_group, data_parallel_world_size
        )
        return [bucket_group]

    if torch.uint8 not in dtype_to_buffer_map:
        # Case 2: When there is no fp8 buffer in the input buffers, let each bucket group have
        # only one bucket.
        bucket_groups = []
        for buffer in buffers:
            for bucket in buffer.buckets:
                bucket_groups.append(
                    _ParamAndGradBucketGroup(
                        [bucket],
                        buffer.ddp_config,
                        buffer.data_parallel_group,
                        buffer.data_parallel_world_size,
                    )
                )
        return bucket_groups
    else:
        # Case 3: When using fp8 params, merge all non-fp8 buckets into the last fp8 bucket group.
        non_fp8_buckets = []
        for buffer in buffers:
            if buffer.param_dtype != torch.uint8:
                for bucket in buffer.buckets:
                    non_fp8_buckets.append(bucket)

        bucket_groups = []
        fp8_buffer = dtype_to_buffer_map[torch.uint8]
        for bucket in fp8_buffer.buckets:
            if len(bucket_groups) == len(fp8_buffer.buckets) - 1:
                # The last bucket group.
                group_buckets = [bucket] + non_fp8_buckets
            else:
                # The first N-1 bucket groups.
                group_buckets = [bucket]
            bucket_groups.append(
                _ParamAndGradBucketGroup(
                    group_buckets,
                    buffer.ddp_config,
                    buffer.data_parallel_group,
                    buffer.data_parallel_world_size,
                )
            )
        return bucket_groups
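The alignment rules used when laying out the buffer above (pad each parameter start to 64 elements, pad each bucket end to a multiple of lcm(DP size, 128), or of lcm(DP size, 128, 2^16) when padding for high NCCL bus bandwidth) can be exercised in isolation. The following is a minimal standalone sketch that re-implements the padding arithmetic outside the class; the helper names and example inputs are illustrative assumptions, not part of the buffer implementation.

# Standalone sketch of the bucket/param padding arithmetic described above.
# Illustrative only: dp_size and the element counts below are made-up example inputs.
import math


def pad(n: int, divisor: int) -> int:
    # Round n up to the next multiple of divisor.
    return int(math.ceil(n / divisor) * divisor)


def pad_end_of_bucket(bucket_end_index: int, dp_size: int, high_nccl_busbw: bool) -> int:
    # Bucket ends are padded so each DP rank gets an equal, 128-element-aligned shard;
    # optionally also to a multiple of 2**16 for better NCCL bus bandwidth.
    divisor = math.lcm(dp_size, 128, 2**16) if high_nccl_busbw else math.lcm(dp_size, 128)
    return pad(bucket_end_index, divisor)


def pad_start_of_param(param_start_index: int) -> int:
    # Param starts are padded to 64 elements (128 bytes at >= 16-bit precision).
    return pad(param_start_index, 64)


if __name__ == "__main__":
    dp_size = 8
    # A 1,000,003-element bucket is padded up to the next multiple of lcm(8, 128) = 128 ...
    print(pad_end_of_bucket(1_000_003, dp_size, high_nccl_busbw=False))  # 1000064
    # ... or to the next multiple of lcm(8, 128, 65536) = 65536 with busbw padding.
    print(pad_end_of_bucket(1_000_003, dp_size, high_nccl_busbw=True))   # 1048576
    print(pad_start_of_param(70))  # 128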
Megatron-LM/megatron/core/distributed/torch_fully_sharded_data_parallel.py
0 → 100644
View file @
4e867b3c
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from typing import Optional, Set

import torch

try:
    from torch.distributed import DeviceMesh
    from torch.distributed._composable.fsdp import fully_shard

    HAVE_FSDP = True
except ImportError:
    HAVE_FSDP = False

from torch.distributed import ProcessGroup

from megatron.core.fp8_utils import is_float8tensor

from .. import parallel_state, tensor_parallel
from ..models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from ..models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from ..transformer.transformer_config import TransformerConfig
from ..transformer.transformer_layer import TransformerLayer
from .data_parallel_base import _BaseDataParallel
from .distributed_data_parallel_config import DistributedDataParallelConfig


class TorchFullyShardedDataParallel(_BaseDataParallel):
    """
    Enables fully sharded data parallelism by wrapping the given model with
    the PyTorch FSDP2 API:
    https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md

    To utilize this class, PyTorch version >= 2.4.0 is required.

    Args:
        config: Transformer config object.
        ddp_config: DistributedDataParallel config object.
        module: Underlying model.
        sub_modules_to_wrap: Set of sub_modules to shard with FSDP.
            Parameters within each sub_module will be all-gathered just-in-time.
            The default set includes the following submodules derived from the
            GPT model architecture:
                TransformerLayer (all Transformer layers)
                LanguageModelEmbedding (initial embedding layer)
                RotaryEmbedding (initial RoPE layer)
                tensor_parallel.ColumnParallelLinear (final output layer)
            Users can set the _fsdp_modules attribute on submodules to mark additional
            submodules to shard with FSDP.
        process_group: Optional ProcessGroup to use for distributed operations.
            If None (default), the data parallel process group will be obtained from
            parallel_state.get_data_parallel_group(with_context_parallel=True).
    """

    def __init__(
        self,
        config: TransformerConfig,
        ddp_config: DistributedDataParallelConfig,
        module: torch.nn.Module,
        sub_modules_to_wrap: Set[torch.nn.Module] = {
            TransformerLayer,
            LanguageModelEmbedding,
            RotaryEmbedding,
            tensor_parallel.ColumnParallelLinear,
        },
        process_group: Optional[ProcessGroup] = None,
    ):
        assert (
            HAVE_FSDP
        ), 'TorchFullyShardedDataParallel requires PyTorch >= 2.4.0 with FSDP 2 support.'

        super().__init__(config=config, module=module)

        if process_group is None:
            self.process_group = parallel_state.get_data_parallel_group(
                with_context_parallel=True
            )
        else:
            self.process_group = process_group

        self.device_mesh = DeviceMesh.from_group(self.process_group, "cuda")
        kwargs = {"mesh": self.device_mesh}

        def save_custom_attrs(module):
            custom_attrs = {}
            for name, param in module.named_parameters():
                attrs = vars(param)
                if is_float8tensor(param):
                    # Disable the fp8 transpose cache and transpose fp8 weights at each
                    # micro-batch, because torch-FSDP doesn't recognize the micro-batch id;
                    # this removes unnecessary memory stores.
                    attrs['_fp8_attrs']['transpose_invalid'] = False
                    del attrs['_fp8_attrs']['transpose']
                custom_attrs[name] = {k: v for k, v in attrs.items()}
            return custom_attrs

        def restore_custom_attrs(module, custom_attrs):
            for name, param in module.named_parameters():
                if name in custom_attrs:
                    for attr_name, attr_value in custom_attrs[name].items():
                        setattr(param, attr_name, attr_value)

        # Save the custom attributes on Parameters before FSDP overwrites them.
        # See https://github.com/pytorch/pytorch/issues/136929.
        attrs = save_custom_attrs(self.module)

        sub_modules_to_wrap = set(sub_modules_to_wrap)
        for sub_module in self.module.modules():
            fsdp_modules = getattr(sub_module, "_fsdp_modules", [])
            for f in fsdp_modules:
                sub_modules_to_wrap.add(f)

        prev_module = None
        for sub_module in self.module.modules():
            # Wrap individual submodules to fetch parameters just-in-time rather than
            # conservatively fetching all parameters at the start of each iteration.
            # See https://github.com/pytorch/pytorch/issues/114299.
            if any(
                isinstance(sub_module, sub_module_to_wrap)
                for sub_module_to_wrap in sub_modules_to_wrap
            ):
                fully_shard(sub_module, **kwargs)

                # Explicitly set the FSDP backward prefetch schedule to prevent activation
                # recomputation from disrupting the automatically generated default schedule.
                if config.recompute_granularity is not None:
                    sub_module.set_modules_to_backward_prefetch(
                        [prev_module] if prev_module else []
                    )
                    prev_module = sub_module

        # Wrap the root module as required by the FSDP API.
        # See https://github.com/pytorch/pytorch/issues/114299.
        fully_shard(self.module, **kwargs)

        restore_custom_attrs(self.module, attrs)

    def load_state_dict(self, state_dict, strict=True):
        """
        No-op because tensors are already loaded in-place by
        `_load_base_checkpoint` with FSDP2."""
        pass
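As a rough illustration of the wrapping pattern used above (shard selected submodules first, then the root module), the sketch below applies PyTorch's fully_shard to the layers of a toy model. This is a generic FSDP2 example under stated assumptions: ToyBlock/ToyModel are invented for illustration, it requires PyTorch >= 2.4 with CUDA, it must be launched with torchrun, and it does not use any Megatron-specific wiring or prefetch scheduling.

# Minimal FSDP2 sketch (assumes PyTorch >= 2.4, CUDA, launched with `torchrun --nproc-per-node=N`).
# ToyBlock/ToyModel are hypothetical modules, not Megatron classes.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._composable.fsdp import fully_shard


class ToyBlock(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.linear = torch.nn.Linear(dim, dim)

    def forward(self, x):
        return torch.relu(self.linear(x))


class ToyModel(torch.nn.Module):
    def __init__(self, dim: int = 128, n_layers: int = 4):
        super().__init__()
        self.layers = torch.nn.ModuleList(ToyBlock(dim) for _ in range(n_layers))
        self.head = torch.nn.Linear(dim, dim)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return self.head(x)


def main():
    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    mesh = init_device_mesh("cuda", (dist.get_world_size(),))

    model = ToyModel().cuda()
    # Shard each block individually so its parameters are all-gathered just-in-time ...
    for block in model.layers:
        fully_shard(block, mesh=mesh)
    # ... then shard the root module, as required by the FSDP2 API.
    fully_shard(model, mesh=mesh)

    out = model(torch.randn(8, 128, device="cuda"))
    print(out.shape)
    dist.destroy_process_group()


if __name__ == "__main__":
    main()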
Megatron-LM/megatron/core/enums.py
0 → 100644
View file @
4e867b3c
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

import enum


class ModelType(enum.Enum):
    """Model type."""

    encoder_or_decoder = 1
    encoder_and_decoder = 2
    retro_encoder = 3
    retro_decoder = 4


class Fp8Recipe(str, enum.Enum):
    """FP8 recipe names: delayed, tensorwise, mxfp8."""

    delayed = "delayed"
    tensorwise = "tensorwise"
    mxfp8 = "mxfp8"
Megatron-LM/megatron/core/export/__init__.py
0 → 100644
View file @
4e867b3c
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
Megatron-LM/megatron/core/export/data_type.py
0 → 100644
View file @
4e867b3c
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from enum import Enum

DataType = Enum('DataType', ["bfloat16", "float16", "float32"])
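For reference, the functional Enum constructor used above assigns auto-incrementing integer values starting at 1, and members can be looked up by name. A short sketch of how such a DataType enum behaves (rebuilt locally for illustration):

# Behavior of an Enum built with the functional API, mirroring DataType above.
from enum import Enum

DataType = Enum('DataType', ["bfloat16", "float16", "float32"])

assert DataType.bfloat16.value == 1                 # values auto-increment from 1
assert DataType["float16"] is DataType.float16      # lookup by member name
assert [m.name for m in DataType] == ["bfloat16", "float16", "float32"]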
Megatron-LM/megatron/core/export/export_config.py
0 → 100644
View file @
4e867b3c
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from dataclasses import dataclass


@dataclass
class ExportConfig:
    """Base configuration for Megatron Core Export

    These parameters control the export setting for trtllm.
    """

    inference_tp_size: int = 1

    inference_pp_size: int = 1

    use_parallel_embedding: bool = False

    use_embedding_sharing: bool = False
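A brief usage sketch of this dataclass; it assumes the ExportConfig definition above is importable, and the parallel sizes shown are arbitrary examples rather than recommended settings.

# Example instantiation of ExportConfig; the parallel sizes here are arbitrary.
config = ExportConfig(inference_tp_size=2, inference_pp_size=1, use_parallel_embedding=True)
print(config)
# ExportConfig(inference_tp_size=2, inference_pp_size=1, use_parallel_embedding=True,
#              use_embedding_sharing=False)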