jerrrrry / dcu_megatron · Commits · 1106877d

Commit 1106877d, authored Sep 23, 2025 by jerrrrry
Commit message: “13.0”
Pipeline #2934 failed with stages in 0 seconds
Changes: 329 · Pipelines: 1

Showing 20 changed files with 7304 additions and 0 deletions (+7304, −0)
Files changed (as shown on this page):

Megatron-LM/megatron/core/dist_checkpointing/mapping.py  (+727, −0)
Megatron-LM/megatron/core/dist_checkpointing/optimizer.py  (+142, −0)
Megatron-LM/megatron/core/dist_checkpointing/serialization.py  (+443, −0)
Megatron-LM/megatron/core/dist_checkpointing/state_dict_utils.py  (+112, −0)
Megatron-LM/megatron/core/dist_checkpointing/strategies/__init__.py  (+7, −0)
Megatron-LM/megatron/core/dist_checkpointing/strategies/async_utils.py  (+561, −0)
Megatron-LM/megatron/core/dist_checkpointing/strategies/base.py  (+228, −0)
Megatron-LM/megatron/core/dist_checkpointing/strategies/cached_metadata_filesystem_reader.py  (+38, −0)
Megatron-LM/megatron/core/dist_checkpointing/strategies/common.py  (+193, −0)
Megatron-LM/megatron/core/dist_checkpointing/strategies/filesystem_async.py  (+637, −0)
Megatron-LM/megatron/core/dist_checkpointing/strategies/fully_parallel.py  (+520, −0)
Megatron-LM/megatron/core/dist_checkpointing/strategies/resharding.py  (+320, −0)
Megatron-LM/megatron/core/dist_checkpointing/strategies/state_dict_saver.py  (+247, −0)
Megatron-LM/megatron/core/dist_checkpointing/strategies/tensorstore.py  (+149, −0)
Megatron-LM/megatron/core/dist_checkpointing/strategies/torch.py  (+1085, −0)
Megatron-LM/megatron/core/dist_checkpointing/strategies/two_stage.py  (+268, −0)
Megatron-LM/megatron/core/dist_checkpointing/strategies/zarr.py  (+345, −0)
Megatron-LM/megatron/core/dist_checkpointing/tensor_aware_state_dict.py  (+394, −0)
Megatron-LM/megatron/core/dist_checkpointing/utils.py  (+319, −0)
Megatron-LM/megatron/core/dist_checkpointing/validation.py  (+569, −0)
Megatron-LM/megatron/core/dist_checkpointing/mapping.py (new file, mode 100644)
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.

"""Core library classes for representing sharding of tensors and objects.

The main expected usage is wrapping torch.Tensors in state dicts with
ShardedTensor class (mostly with the ShardedTensor.from_rank_offsets classmethod).
"""

import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, field, replace
from itertools import chain
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch

from .core import CheckpointingException
from .dict_utils import dict_list_map_inplace

logger = logging.getLogger(__name__)

# These type definitions are just hints to differentiate a plain model state
# dict (StateDict) from a state dict with tensors replaced with ShardedTensors
# (ShardedStateDict).
StateDict = Dict[str, Any]
CommonStateDict = Dict[str, Any]
ShardedStateDict = Dict[str, Any]
ReplicaId = Union[int, Tuple[int, ...]]


class ShardedBase(ABC):
    """Base class for ShardedTensor and ShardedStateDict."""

    key: str
    data: object
    replica_id: ReplicaId

    @abstractmethod
    def validate_metadata_integrity(self):
        """Codifies the constraints on metadata attributes."""

    @abstractmethod
    def without_data(self) -> "ShardedBase":
        """Returns a new ShardedBase instance with data=None."""
        raise NotImplementedError


@dataclass
class ShardedTensor(ShardedBase):
    """Represents a mapping between a local tensor and a global tensor.

    Global tensor is assumed to consist of many local tensors distributed
    between different processes.

    Args:
        key: unique identifier of a global tensor
        data: local tensor data. Can be None only for consistency validation
        dtype: tensor dtype
        local_shape: local tensor shape
        global_shape: global tensor shape
        global_offset: offset of a local tensor in a global tensor,
            specified in number of tensor elements
        axis_fragmentations: global tensor fragmentation of each axis
        replica_id: indicates given local tensor's replication wrt.
            local tensors in different processes
        prepend_axis_num: number of axes prepended to the local tensor to
            reflect global tensor shape. The behavior is similar to
            unsqueezing the local tensor.
        allow_shape_mismatch: if True, during loading, the global shape of
            a stored tensor does not have to match the expected global shape.
            Useful for representing tensors with flexible shape,
            e.g. padded.
        flattened_range: specifies a slice that should be applied to a
            flattened tensor with `local_shape` in order to get
            the tensor stored as `data`
    """

    key: str
    data: Optional[torch.Tensor] = field(repr=False)
    dtype: torch.dtype
    local_shape: Tuple[int, ...]
    global_shape: Tuple[int, ...]
    global_offset: Tuple[int, ...]
    axis_fragmentations: Optional[Tuple[int, ...]]
    replica_id: ReplicaId = 0
    prepend_axis_num: int = 0
    allow_shape_mismatch: bool = False
    flattened_range: Optional[slice] = None

    def __post_init__(self):
        self.validate_metadata_integrity()

    def validate_metadata_integrity(self) -> None:
        """Codifies the constraints on metadata attributes.

        Meeting those constraints is guaranteed when instantiating a ShardedTensor
        class with `from_rank_offsets` or `from_rank_offsets_flat` constructors.

        Returns:
            None
        """
        has_flattened_range = self.flattened_range is not None

        if self.data is not None:
            if self.data.dtype != self.dtype:
                raise CheckpointingException(
                    f"Data dtype should match `dtype` attribute for {self}"
                )
            if not has_flattened_range and self.data.shape != self.local_shape:
                raise CheckpointingException(
                    f"Data shape should match `local_shape` attribute for {self}"
                )
            if has_flattened_range:
                if self.data.ndim != 1:
                    raise CheckpointingException(f"Data should be 1D for a flattened {self}")
                real_data = self.data
                try:
                    self.data = None
                    self.init_data(device="meta")
                    if self.data.shape != real_data.shape:
                        raise CheckpointingException(
                            f"Data shape {real_data.shape} doesn't match"
                            f" expected {self.data.shape} for {self}"
                        )
                finally:
                    self.data = real_data

        if len(self.global_shape) != len(self.global_offset):
            raise CheckpointingException(
                f"Global offset dimensions should be equal to global shape dimensions for {self}"
            )
        if len(self.local_shape) + self.prepend_axis_num != len(self.global_shape):
            raise CheckpointingException(
                f"Local shape together with `prepend_axis_num` dimensions should be "
                f"equal to global shape dimensions for {self}"
            )
        for off, sh in zip(self.global_offset[self.prepend_axis_num :], self.local_shape):
            # NOTE: In custom FSDP, we have a case where a new parameter shard is created locally.
            # For example, consider parameters [p0, p1, p2] sharded across GPU0 and GPU1.
            # GPU0 receives p0 and a portion of p1, while GPU1 receives the
            # remaining portion of p1 and p2.
            # As a result, there is no parameter shard of p2 on GPU0, and
            # the shape of p2 on GPU0 is zero.
            if sh != 0 and off % sh != 0:
                raise CheckpointingException(
                    f"Global offset ({off}) must be divisible by local shape ({sh}) for {self}."
                )

        if has_flattened_range and self.flattened_range.step is not None:
            raise CheckpointingException(
                f"`step` argument in the flattened range of a ShardedTensor is not supported."
            )

    def global_slice(self) -> Tuple[Union[int, slice], ...]:
        """
        Returns a tuple of int and slice objects representing a slice of the
        global tensor that this ShardedTensor corresponds to.
        """
        assert len(self.global_offset) == len(self.local_shape) + self.prepend_axis_num
        return tuple(
            chain(
                (off for off in self.global_offset[: self.prepend_axis_num]),
                (
                    slice(off, off + sh)
                    for off, sh in zip(
                        self.global_offset[self.prepend_axis_num :], self.local_shape
                    )
                ),
            )
        )

    def global_coordinates(self) -> Tuple[np.ndarray, ...]:
        """
        Returns a tuple of np.ndarrays representing the coordinates of the global tensor
        that this ShardedTensor corresponds to.
        """
        if self.flattened_range is None:
            raise CheckpointingException(
                f"`global_coordinates` is undefined for"
                f" {self.__class__.__name__} without `flattened_range`"
            )

        local_coords = self.local_coordinates()
        assert len(local_coords) + self.prepend_axis_num == len(self.global_offset), (
            len(local_coords),
            self,
        )
        global_coords = tuple(
            c + off
            for c, off in zip((0,) * self.prepend_axis_num + local_coords, self.global_offset)
        )
        return global_coords

    def local_coordinates(self) -> Tuple[np.ndarray, ...]:
        """
        Returns a tuple of np.ndarrays representing the coordinates of the local tensor
        that this ShardedTensor corresponds to.
        """
        if self.flattened_range is None:
            raise CheckpointingException(
                f"`local_coordinates` is undefined for"
                f" {self.__class__.__name__} without `flattened_range`"
            )

        # TODO: np.unravel_index?
        mask = np.zeros(np.product(self.local_shape), dtype=bool)
        mask[self.flattened_range] = True
        return np.nonzero(mask.reshape(self.local_shape))

    def local_chunk_offset_in_global(self) -> Tuple[int, ...]:
        """Offset of a local chunk in a global array of chunks.

        Returns:
            Tuple[int, ...]: the offset of the whole local chunk in a global array of chunks.
        """
        assert len(self.global_offset) == len(self.local_shape) + self.prepend_axis_num
        chunk_offset = list(self.global_offset[: self.prepend_axis_num])
        for off, sh in zip(self.global_offset[self.prepend_axis_num :], self.local_shape):
            assert off % sh == 0, str(self)
            chunk_offset.append(off // sh)
        return tuple(chunk_offset)

    def max_allowed_chunks(self) -> Tuple[int, ...]:
        """
        Returns the maximum allowed chunks for this ShardedTensor.
        """
        chunks = []
        for axis_sh, axis_fragm in zip(self.global_shape, self.axis_fragmentations):
            if not self.allow_shape_mismatch and axis_sh % axis_fragm != 0:
                raise CheckpointingException(
                    f"Axis shape ({axis_sh}) not divisible by axis fragmentation ({axis_fragm})"
                )
            axis_chunk_size = axis_sh // axis_fragm
            chunks.append(axis_chunk_size)
        return tuple(chunks)

    def without_data(self):
        return replace(self, data=None)

    @classmethod
    def from_rank_offsets(
        cls,
        key: str,
        data: torch.Tensor,
        *rank_offsets: Tuple[int, int, int],
        replica_id: ReplicaId = 0,
        prepend_axis_num: int = 0,
        flattened_range: None = None,
        **init_kwargs,
    ):
        """Allows to construct the ShardedTensor given offset specified in process ranks.

        Args:
            key (str): unique key
            data (torch.Tensor): local tensor data
            rank_offsets (Tuple[int, int, int]): each tuple
                (axis, axis_rank_offset, axis_fragm) says that if
                global tensor is divided into `axis_fragm` fragments along `axis`
                axis, then local tensor data corresponds to the `axis_rank_offset` chunk.
            replica_id (ReplicaId): see ShardedTensor
            prepend_axis_num (int): see ShardedTensor
            flattened_range (None): must be None when using this constructor
            init_kwargs: passed to ShardedTensor.__init__
        """
        if flattened_range is not None:
            raise ValueError(
                "Cannot instantiate a flat ShardedTensor with `from_rank_offsets` method."
                " Use `from_rank_offsets_flat` instead"
            )
        global_offset = [0] * (data.ndim + prepend_axis_num)
        global_shape = ([1] * prepend_axis_num) + list(data.shape)
        axis_fragmentations = [1] * (data.ndim + prepend_axis_num)
        _seen_axis = set()
        for axis, axis_rank_offset, axis_fragm in rank_offsets:
            if axis < 0 or axis_rank_offset < 0 or axis_fragm < 1 or axis_rank_offset >= axis_fragm:
                raise CheckpointingException(
                    f"Invalid rank offsets: {rank_offsets} for key {key}."
                )
            _seen_axis.add(axis)

            local_axis_shape = 1 if axis < prepend_axis_num else data.shape[axis - prepend_axis_num]
            global_shape[axis] = axis_fragm * local_axis_shape
            global_offset[axis] = axis_rank_offset * local_axis_shape
            axis_fragmentations[axis] = axis_fragm

        return cls(
            key,
            data,
            data.dtype,
            tuple(data.shape),
            tuple(global_shape),
            tuple(global_offset),
            tuple(axis_fragmentations),
            replica_id,
            prepend_axis_num,
            flattened_range=flattened_range,
            **init_kwargs,
        )

    @classmethod
    def from_rank_offsets_flat(
        cls,
        key: str,
        data: torch.Tensor,
        non_flat_local_shape: Tuple[int, ...],
        *args,
        flattened_range: Optional[slice] = None,
        **kwargs,
    ):
        """Allows to construct a *flattened* ShardedTensor given offset specified in process ranks.

        Args:
            key (str):
            data (torch.Tensor): this should be a flattened data tensor
            non_flat_local_shape (Tuple[int, ...]): expected local shape of a non-flat chunk
            *args: passed unchanged to the `from_rank_offsets` constructor
            flattened_range (slice): see ShardedTensor. Defaults to None, but must be set to
                a non-None slice.
            **kwargs:

        Returns:
            ShardedTensor: constructed ShardedTensor instance
        """
        if flattened_range is None:
            raise CheckpointingException(
                "Cannot instantiate a non-flat ShardedTensor with `from_rank_offsets_flat` method."
                " Use `from_rank_offsets` instead"
            )
        if data.ndim != 1:
            raise CheckpointingException(
                f"Flattened ShardedTensor requires 1D data, got shape: {data.shape}"
            )
        if flattened_range.stop - flattened_range.start != data.numel():
            raise CheckpointingException(
                f"Flattened ShardedTensor data length ({data.numel()}) must meet the "
                f"slice length: {flattened_range.stop - flattened_range.start}"
            )
        non_flat_data_meta = torch.empty(*non_flat_local_shape, dtype=data.dtype, device="meta")
        sh_ten = cls.from_rank_offsets(key, non_flat_data_meta, *args, **kwargs)
        instance = replace(sh_ten, data=data, flattened_range=flattened_range)
        instance.validate_metadata_integrity()
        return instance

    def init_data(self, device: Union[str, torch.device], init_fn=torch.empty):
        """
        Initialize the tensor data of this ShardedTensor.

        Only called if `data` attribute is None.

        Args:
            device (Union[str, torch.device]): device to place the tensor on
            init_fn (Callable, optional): function to use to initialize the tensor.
                Defaults to `torch.empty`.
        """
        if self.data is not None:
            return
        self.data = init_fn(self.local_shape, dtype=self.dtype, device=device)
        if self.flattened_range is not None:
            self.data = self.data.flatten()[self.flattened_range.start : self.flattened_range.stop]

    def narrow(self, dim: int, start: int, length: int) -> List["ShardedTensor"]:
        """This is an analogue of torch.narrow for ShardedTensors.

        Narrowing assumes that we narrow a local tensor on each rank.
        This has consequences on local_shape, global_shape, global_offset, etc.

        Args:
            dim (int): dimension to narrow. Doesn't include prepended axes.
            start (int): start element
            length (int): length of the slice

        Returns:
            List[ShardedTensor]: narrowed ShardedTensors. For non-flat tensors,
                the list will always have 1 element. For flat ShardedTensors the number of
                elements varies depending on `dim` and on overlap, because flat
                tensors must be contiguous. In particular the list can be empty.
        """
        prepended_dim = dim + self.prepend_axis_num
        local_length_along_dim = self.local_shape[dim]

        def _update_tuple(x, ind, val):
            x = list(x)
            x[ind] = val
            return tuple(x)

        def _safe_div(x, y):
            assert x % y == 0, (x, y)
            return x // y

        # Decrease global shape and global offset by `length / local_length_along_dim`
        assert (
            self.global_shape[prepended_dim] % local_length_along_dim == 0
        ), f"Only regular grid of local tensors is supported for narrowing, got: {self}"
        assert (
            self.global_offset[prepended_dim] % local_length_along_dim == 0
        ), f"Only regular grid of local tensors is supported for narrowing, got: {self}"
        global_shape = _update_tuple(
            self.global_shape,
            prepended_dim,
            _safe_div(self.global_shape[prepended_dim] * length, local_length_along_dim),
        )
        global_offset = _update_tuple(
            self.global_offset,
            prepended_dim,
            _safe_div(self.global_offset[prepended_dim] * length, local_length_along_dim),
        )

        if self.flattened_range is None:
            new_data = self.data.narrow(dim, start, length)
            # always a single result tensor
            return [
                replace(
                    self,
                    data=new_data,
                    local_shape=new_data.shape,
                    global_shape=global_shape,
                    global_offset=global_offset,
                )
            ]
        else:
            if dim != 0:
                raise CheckpointingException(
                    f"Narrowing along the first axis is supported for now only, got dim={dim}"
                )
            # If dim=0, we will always get 0 or 1 resulting tensor.
            # If dim>1, in general there can be more result tensors (e.g. max 3 for dim=1)
            # For an original flat ShardedTensor of local shape [3, 4] and
            # flattened_range=slice(5, 10),
            # the X signs mark the actual (flat) data in `self.data`
            # notice 12 (3*4) total "virtual" elements, out of which 5 is actual data.
            # flat original: [.....XXXXX..]
            # If we narrow to start=1, length=1 in the original local shape dimensions,
            # the overlapping flat slice would be:
            # narrow to:     [....XXXX....]
            # flat overlap:  [.....XXX....]

            # Now `data` is flattened and sliced, so we must compute local_shape manually
            local_shape = _update_tuple(self.local_shape, dim, length)

            other_dims_volume = np.prod(_update_tuple(local_shape, dim, 1))  # 4 in the example above
            volume_before_split = other_dims_volume * start  # 4 in the example above
            volume_of_split = other_dims_volume * length  # 4 in the example above
            flat_slice_start_shifted = (
                self.flattened_range.start - volume_before_split
            )  # 5 - 4 = 1 in the example above
            flat_slice_stop_shifted = (
                self.flattened_range.stop - volume_before_split
            )  # 10 - 4 = 6 in the example above

            # Find an intersection of
            # (flat_slice_start_shifted, flat_slice_stop_shifted) vs (0, volume_of_split)
            if flat_slice_stop_shifted <= 0 or flat_slice_start_shifted >= volume_of_split:
                return []  # no intersection

            # new_flattened_range = slice(1, 4) in the example above
            new_flattened_range = slice(
                max(flat_slice_start_shifted, 0), min(flat_slice_stop_shifted, volume_of_split)
            )

            # Apply the intersection to the flattened data tensor.
            # Compute start and slice appropriate length
            intersection_slice_start = (
                new_flattened_range.start - flat_slice_start_shifted
            )  # 0 in the example above
            new_data = self.data[
                intersection_slice_start : intersection_slice_start
                + new_flattened_range.stop
                - new_flattened_range.start
            ]
            return [
                replace(
                    self,
                    data=new_data,
                    local_shape=local_shape,
                    global_shape=global_shape,
                    global_offset=global_offset,
                    flattened_range=new_flattened_range,
                )
            ]


def is_main_replica(replica_id: ReplicaId):
    """Checks if given `replica_id` is considered as main.

    "Main" replica is:
    - integer 0
    - or an iterable with all 0 elements

    It is the application's responsibility to set correct replicas for sharded tensors.

    Args:
        replica_id (Union[int, Tuple[int, ...]]): replica id

    Returns:
        (bool): True for a "main" replica
    """
    if isinstance(replica_id, int):
        return replica_id == 0
    return all(r == 0 for r in replica_id)


class LocalNonpersistentObject:
    """Object that should not be stored in a checkpoint, but restored locally.

    Wrapping any object inside the state dict with LocalNonpersistentObject
    will result in:
    - during saving, this object will *not* be stored in the checkpoint
    - during loading, a local version of this object will be placed in a state dict
    """

    def __init__(self, obj):
        self.obj = obj

    def unwrap(self):
        """Returns the original object."""
        return self.obj


@dataclass
class ShardedObject(ShardedBase):
    """Represents a mapping between a local object and a global object.

    Global object is assumed to consist of many local objects distributed
    between different processes.

    NOTE: Contrary to ShardedTensor, it's impossible to change global object
    sharding. Conceptually, ShardedObject is a fully-sharded ShardedTensor
    with atomic arbitrary typed elements.

    Args:
        key: unique identifier of a global tensor
        data: local object data. Can be None only for consistency validation
        global_shape: global object shape
        global_offset: offset of a local object in a global object, specified in number of shards
        replica_id: indicates local object replication wrt. local objects in different processes
    """

    key: str
    data: object
    global_shape: Tuple[int, ...]
    global_offset: Tuple[int, ...]
    replica_id: ReplicaId = 0

    def __post_init__(self):
        self.validate_metadata_integrity()

    def validate_metadata_integrity(self):
        if len(self.global_shape) != len(self.global_offset):
            raise CheckpointingException(
                f"Global offset dimensions should be equal to global shape dimensions for {self}"
            )

    def without_data(self):
        return replace(self, data=None)

    @property
    def unique_key(self):
        """returns a unique key for this object"""
        return (
            f"{self.key}/shard_"
            f"{'.'.join(map(str, self.global_offset))}_"
            f"{'.'.join(map(str, self.global_shape))}"
        )

    def __str__(self):
        return f"{self.__class__.__name__}(key='{self.key}')"

    @classmethod
    def empty_from_unique_key(cls, unique_key, replica_id: ReplicaId = 0) -> "ShardedObject":
        """Instantiates a ShardedObject from a unique key.

        Args:
            unique_key: a string of the form <key>/shard_<global_offset>_<global_shape>
            replica_id: indicates local object replication wrt.
                local objects in different processes

        Returns:
            a ShardedObject with data=None
        """
        key, shard_key = unique_key.split("/")
        shard_str, offset, shape = shard_key.split("_")
        assert shard_str == "shard"
        offset = tuple(map(int, offset.split(".")))
        shape = tuple(map(int, shape.split(".")))
        if len(shape) + 1 == len(offset):
            # This is a backward-compatible fix. We don't know the last
            # element of global shape so set it to -1.
            shape += (-1,)
        return cls(key, None, shape, offset, replica_id)


FactoryBuildFn = Callable[[str, torch.Tensor, ReplicaId, Optional[slice]], ShardedStateDict]
FactoryMergeFn = Callable[[StateDict], torch.Tensor]


@dataclass
class ShardedTensorFactory(ShardedBase):
    """Allows to apply transformations to tensors before/after serialization.

    The essence of those transformations is that they can be applied to
    optimizer states the same way they are applied to the model params.

    The ultimate state dict with sharded tensors must depend functionally on
    `build_fn` arguments (key, data, replica_id, flattened_range),
    which will be provided by the optimizer.

    Builder creates a sub-state-dict out of a tensor before saving, and merger
    merges the corresponding state dict after loading.

    Args:
        key (str): unique identifier of the factory
        data (torch.Tensor): original model parameter that will be further
            transformed by this factory
        build_fn (callable): function that transforms the original tensor
            to a sharded state dict
        merge_fn (callable): function that transforms loaded subtree back
            into a single tensor (inverse of `build_fn`)
        replica_id (ReplicaId): indicates factory replication wrt.
            factories in different processes
        flattened_range (slice, optional): indicates additional flattening
            applied to the ShardedTensors produced by the factory
    """

    key: str
    data: torch.Tensor
    build_fn: FactoryBuildFn
    merge_fn: FactoryMergeFn
    replica_id: ReplicaId = 0
    flattened_range: Optional[slice] = None

    def build(self):
        """Builds a ShardedStateDict from the original tensor"""
        return self.build_fn(self.key, self.data, self.replica_id, self.flattened_range)

    def validate_metadata_integrity(self):
        """No reasonable checks can be applied"""
        pass

    def without_data(self):
        return replace(self, data=None)


def apply_factories(sharded_state_dict: ShardedStateDict):
    """Turn ShardedTensorFactories into ShardedTensors *in-place*.

    Args:
        sharded_state_dict (ShardedStateDict): state dict possibly
            containing ShardedTensorFactory objects

    Returns:
        None: state dict is modified in place
    """

    def apply(x):
        if isinstance(x, ShardedTensorFactory):
            x = x.build()
        return x

    dict_list_map_inplace(apply, sharded_state_dict)


def apply_factory_merges(
    x1: StateDict, x2: ShardedStateDict, key: Tuple[str, ...] = ()
) -> StateDict:
    """Apply merges defined by ShardedTensorFactories *in-place*.

    Args:
        x1 (StateDict): state dict loaded from the checkpoint
        x2 (ShardedStateDict): subset of `x1` (in terms of dict keys)
            with ShardedTensorFactory as (possibly nested) values that define how to
            merge objects from the `x1` state dict
        key (Tuple[str, ...]): current key in a recursive call.
            Used only for reporting meaningful errors

    Returns:
        StateDict: `x1` modified in-place
    """
    if isinstance(x2, ShardedTensorFactory):
        return x2.merge_fn(x1)

    # The rest is almost the same as the `merge` function from `dict_utils`
    if isinstance(x1, dict) and isinstance(x2, dict):
        for k, v2 in x2.items():
            if k not in x1:
                raise ValueError(
                    f"Different dict keys encountered in `apply_factory_merges` "
                    f"({x1.keys()} vs {x2.keys()})"
                )
            else:
                x1[k] = apply_factory_merges(x1[k], v2, key=key + (k,))
    elif isinstance(x1, list) and isinstance(x2, list):
        if len(x1) != len(x2):
            err_msg = (
                f"Cannot merge two lists with different lengths "
                f"({len(x1)} and {len(x2)}, encountered at key {key})"
            )
            logger.error(err_msg + f"\nx1: {x1}\nx2: {x2}")
            raise ValueError(err_msg)
        for i, v2 in enumerate(x2):
            x1[i] = apply_factory_merges(x1[i], v2, key=key + (i,))
    elif isinstance(x1, list) and isinstance(x2, dict):
        for k, v2 in x2.items():
            if not isinstance(k, int):
                raise ValueError(
                    f"Invalid dict key {k} non-integer type encountered "
                    f"in a list-dict merge at level {key}"
                )
            if k >= len(x1):
                raise ValueError(
                    f"Dict key {k} out of bound for list of length"
                    f" {len(x1)} (encountered at level {key})"
                )
            x1[k] = apply_factory_merges(x1[k], v2, key=key + (k,))
    else:
        raise ValueError(
            f"Duplicate non-dict and non-list values encountered: `{x1}` and `{x2} (at key {key})`"
        )
    return x1
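
Usage note (not part of the diff): the module docstring above names ShardedTensor.from_rank_offsets as the main entry point. Below is a minimal sketch of wrapping a tensor-parallel weight shard; the key name and the tp_rank/tp_size values are illustrative assumptions, not something this commit defines.

# Hypothetical example: wrap a local weight shard that is split along axis 0
# across `tp_size` ranks. In real code tp_rank/tp_size would come from the
# parallel state, not hard-coded constants.
import torch
from megatron.core.dist_checkpointing.mapping import ShardedTensor, is_main_replica

tp_rank, tp_size = 0, 2                       # assumed tensor-parallel rank and world size
local_weight = torch.empty(1024, 4096)        # this rank's shard of a (2048, 4096) weight

# Each rank_offsets tuple is (axis, axis_rank_offset, axis_fragm): axis 0 is cut
# into tp_size fragments and this rank holds fragment number tp_rank.
sh_ten = ShardedTensor.from_rank_offsets(
    'decoder.layers.0.mlp.fc1.weight', local_weight, (0, tp_rank, tp_size)
)
assert sh_ten.global_shape == (1024 * tp_size, 4096)
assert is_main_replica(sh_ten.replica_id)     # replica_id defaults to 0, i.e. the "main" replica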
Megatron-LM/megatron/core/dist_checkpointing/optimizer.py (new file, mode 100644)
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.

""" Helpers for defining sharding for optimizer states based on existing sharding
for model parameters.
"""

import logging
from copy import deepcopy
from dataclasses import replace
from typing import Dict, Iterable, Tuple, Union

logger = logging.getLogger(__name__)

import torch

from megatron.core.utils import to_local_if_dtensor

from .dict_utils import nested_values
from .mapping import (
    LocalNonpersistentObject,
    ShardedStateDict,
    ShardedTensor,
    ShardedTensorFactory,
    StateDict,
)
from .utils import extract_sharded_tensors_and_factories


def get_optim_param_to_id_map(optim_params_iter: Iterable[torch.nn.Parameter]) -> Dict[int, int]:
    """Generate mapping from optimizer param to optimizer state id."""
    param_mappings = {}
    for i, param in enumerate(optim_params_iter):
        param = to_local_if_dtensor(param)
        if id(param) not in param_mappings:
            param_mappings[id(param)] = i
    return param_mappings


def get_param_id_to_sharded_param_map(
    model_sharded_state_dict: ShardedStateDict, optim_params_iter: Iterable[torch.nn.Parameter]
) -> Dict[int, Union[ShardedTensor, ShardedTensorFactory]]:
    """Generate mapping from optimizer state ids to model sharded parameters.

    Args:
        model_sharded_state_dict: sharded state dict with all model sharded tensors
            (can have any structure)
        optim_params_iter: iterable which iterates over model parameters tracked by the optimizer.
            The iteration must be in the same order as in the optimizer parameters.

    Returns:
        Dict[int, Union[ShardedTensor, ShardedTensorFactory]]: mapping from optimizer state ids
            to model sharded parameters.
    """
    model_sharded_state_dict, _ = extract_sharded_tensors_and_factories(model_sharded_state_dict)
    id_to_sharded_param_map = {}
    param_to_id_map = get_optim_param_to_id_map(optim_params_iter)
    # If using PyTorch FSDP2 the values in model_sharded_state_dict would
    # have been converted to local tensors during initialization.
    # See the make_(tp)_sharded_tensor_for_checkpoint functions.
    for ten in nested_values(model_sharded_state_dict):
        if id(ten.data) in param_to_id_map:
            id_to_sharded_param_map[param_to_id_map[id(ten.data)]] = ten
        else:
            logger.debug(f'{ten} is not tracked by the optimizer')

    if not id_to_sharded_param_map:
        logger.warning(
            "Sharded parameters mapping is empty. It means tensors in model state dict"
            " do not correspond to tensors in optimizer parameters map."
            " Make sure to call state_dict with `keep_vars=True`."
        )
    return id_to_sharded_param_map


def make_sharded_optimizer_tensor(
    model_param: Union[ShardedTensor, ShardedTensorFactory], optim_param: torch.Tensor, prefix: str
) -> Union[ShardedTensor, ShardedTensorFactory]:
    """Build a ShardedTensor or ShardedTensorFactory for optimizer param based on model param

    Args:
        model_param (Union[ShardedTensor, ShardedTensorFactory]): model param
        optim_param (torch.Tensor): corresponding optimizer param
        prefix (str): optimizer prefix for the ShardedTensor or ShardedTensorFactory

    Returns:
        Union[ShardedTensor, ShardedTensorFactory]: wrapped optimizer parameter
    """
    optim_param = to_local_if_dtensor(optim_param)
    if isinstance(model_param, ShardedTensorFactory):
        return replace(model_param, key=f'{prefix}.{model_param.key}', data=optim_param)

    assert tuple(optim_param.shape) == model_param.local_shape, (
        f'Optimizer shape ({tuple(optim_param.shape)}) does not match model shape '
        f'({model_param.local_shape})'
    )
    sh_ten = replace(
        model_param, key=f'{prefix}.{model_param.key}', data=optim_param, dtype=optim_param.dtype
    )
    sh_ten.validate_metadata_integrity()
    return sh_ten


def optim_state_to_sharding_state(
    optim_state_dict: StateDict,
    id_to_sharded_param_map: Dict[int, ShardedTensor],
    exclude_keys: Tuple[str] = (),
):
    """Turn optimizer state dict to sharded state dict based on model state dict *in-place*.

    Can be used to add sharding information to most common optimizer state dict.
    Creates separate ShardedTensors for each key in `optim_state_dict['state']`
    (e.g. for torch.optim.Adam there will be separate tensors for `exp_avg` and `exp_avg_sq`)

    Args:
        optim_state_dict (StateDict): optimizer state dict with
            state parameters under `state` key and group hyperparameters under
            `param_groups` -> `params` key.
        id_to_sharded_param_map (Dict[int, ShardedTensor]): mapping from optimizer param ids
            to model sharded tensors. Can be generated with `get_param_id_to_sharded_param_map`
            function.
        exclude_keys (Tuple[str]): optimizer state keys to exclude from the final state dict.

    Returns:
        None: state dict is modified in place
    """
    sharded_state = {}
    for param_id, param_state in optim_state_dict['state'].items():
        sharded_state[param_id] = {}
        for state_key, param in param_state.items():
            if state_key in exclude_keys:
                continue
            if param_id in id_to_sharded_param_map:
                sharded_state[param_id][state_key] = make_sharded_optimizer_tensor(
                    id_to_sharded_param_map[param_id], param, prefix=f'optimizer.state.{state_key}'
                )
            else:
                raise ValueError(f'Param id {param_id} does not match any model sharded param')

    optim_state_dict['param_groups'] = deepcopy(optim_state_dict['param_groups'])
    for group in optim_state_dict['param_groups']:
        group['params'] = LocalNonpersistentObject(group['params'])
    optim_state_dict['state'] = sharded_state
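
Usage note (not part of the diff): a sketch of how these helpers are usually chained for a plain torch.optim optimizer. The model and optimizer objects, and the assumption that the model exposes a sharded_state_dict() built so that parameter identity is preserved (keep_vars=True semantics), are illustrative.

# Hypothetical glue code for adding sharding info to an optimizer state dict.
from megatron.core.dist_checkpointing.optimizer import (
    get_param_id_to_sharded_param_map,
    optim_state_to_sharding_state,
)

model_sharded_sd = model.sharded_state_dict()   # assumed: contains ShardedTensor values
optim_state = optimizer.state_dict()            # torch.optim layout: 'state' + 'param_groups'

# Map optimizer param ids to model ShardedTensors; the iteration order must match
# the order in which the optimizer tracks its parameters.
id_to_sharded = get_param_id_to_sharded_param_map(
    model_sharded_sd, (p for group in optimizer.param_groups for p in group['params'])
)

# Rewrites optim_state in place: per-param states (e.g. exp_avg, exp_avg_sq) become
# ShardedTensors and each group's 'params' list becomes a LocalNonpersistentObject.
optim_state_to_sharding_state(optim_state, id_to_sharded)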
Megatron-LM/megatron/core/dist_checkpointing/serialization.py (new file, mode 100644)
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.

""" Entrypoints for saving and loading the distributed checkpoints.

Functions `load` and `save` are equivalents of `torch.load` and `torch.save`
but expect torch.Tensors to be wrapped with classes from the `mapping` module.
Additionally, `load` expects the sharded state dict argument as a guidance for
loading the sharded tensors.
"""

import logging
from pathlib import Path
from typing import Callable, Dict, Optional, Set, Tuple, Union

import torch

from megatron.core.msc_utils import MultiStorageClientFeature

from . import ShardedTensor
from .core import CheckpointingConfig, save_config
from .dict_utils import extract_matching_values, merge
from .mapping import (
    CheckpointingException,
    CommonStateDict,
    ShardedObject,
    ShardedStateDict,
    StateDict,
    apply_factory_merges,
)
from .state_dict_utils import load_preprocess, save_preprocess
from .strategies.async_utils import AsyncRequest
from .strategies.base import (
    AsyncSaveShardedStrategy,
    LoadCommonStrategy,
    LoadShardedStrategy,
    SaveCommonStrategy,
    SaveShardedStrategy,
    StrategyAction,
    get_default_strategy,
)
from .utils import extract_sharded_base
from .validation import (
    StrictHandling,
    determine_global_metadata,
    parse_strict_flag,
    validate_integrity_and_strict_load,
    validate_sharded_objects_handling,
    verify_checkpoint_and_load_strategy,
)

logger = logging.getLogger(__name__)

# flat state dict with sharded objects without any data
CkptShardedMetadata = Dict[str, Union[ShardedTensor, ShardedObject]]

_CONTENT_METADATA_KEY = 'content_metadata'


def load(
    sharded_state_dict: ShardedStateDict,
    checkpoint_dir: str,
    sharded_strategy: Union[LoadShardedStrategy, Tuple[str, int], None] = None,
    common_strategy: Union[LoadCommonStrategy, Tuple[str, int], None] = None,
    validate_access_integrity: bool = True,
    strict: Union[str, StrictHandling] = StrictHandling.ASSUME_OK_UNEXPECTED,
) -> Union[StateDict, Tuple[StateDict, Set[str], Set[str]]]:
    """Loading entrypoint.

    In the steps below, the following verbs refer to corresponding objects:
    - load = load from checkpoint
    - extract = extract from sharded_state_dict
    - add = add to the final state dict
    Steps:
    1. Load common state dict and form the base of the result state dict
    2. Apply factories to sharded_state_dict
    3. Extract LocalNonPersistentObject and add
    4. (optional) Extract ShardedObjects, load and add
    5. Extract ShardedBase, load, apply factory merges and add

    Args:
        sharded_state_dict (ShardedStateDict): state dict of the existing model
            populated with ShardedTensors. Used as a mapping to determine which
            parts of global tensors stored in the checkpoint should be loaded.
        checkpoint_dir (str): directory with the checkpoint
        sharded_strategy (LoadShardedStrategy, Tuple[str, int], optional):
            configures loading behavior for sharded tensors
        common_strategy (LoadCommonStrategy, Tuple[str, int], optional):
            configures loading behavior for common data
        validate_access_integrity (bool default = True): checks if each tensor shard is accessed
            exactly once (as main replica) by some process
        strict (StrictHandling, str, optional): determines the behavior in case of a mismatch
            between the requested sharded state dict and the checkpoint. See `StrictHandling` docs
            for more details. Some values affect the return value of this function
            (missing and unexpected keys are returned).
            Defaults to `True` (StrictHandling.ASSUME_OK_UNEXPECTED) which doesn't
            incur any performance overhead. Other recommended values
            are: `False` (StrictHandling.LOG_UNEXPECTED) which logs only unexpected keys
            or `StrictHandling.RETURN_ALL` which returns all mismatch keys.

    Returns:
        StateDict or Tuple[StateDict, Set[str], Set[str]]: in most cases only
            the loaded state dict is returned. If `strict` flag was set to
    """
    sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy(
        checkpoint_dir, sharded_strategy, common_strategy
    )

    common_state_dict = common_strategy.load_common(checkpoint_dir)

    sharded_state_dict, nonpersistent_state_dict, sh_ten_factories = load_preprocess(
        sharded_state_dict
    )
    merge(common_state_dict, nonpersistent_state_dict)

    # At this point we are only dealing with ShardedBase objects
    sharded_state_dict, _ = extract_sharded_base(sharded_state_dict)

    # Validation
    ckpt_sharded_metadata = None
    local_metadata, global_metadata = None, None
    strict = parse_strict_flag(strict)
    if StrictHandling.requires_explicit_ckpt_mismatch_check(strict):
        ckpt_sharded_metadata = load_sharded_metadata(
            checkpoint_dir, sharded_strategy, common_strategy  # type: ignore[arg-type]
        )
    if validate_access_integrity or StrictHandling.requires_global_app_metadata(strict):
        local_metadata, global_metadata = determine_global_metadata(sharded_state_dict)

    sharded_state_dict, missing_keys, unexpected_keys = validate_integrity_and_strict_load(
        sharded_state_dict,
        strict,
        validate_access_integrity,
        local_metadata,
        global_metadata,
        ckpt_sharded_metadata,
    )

    # ShardedBase loading
    if not sharded_strategy.can_handle_sharded_objects:
        validate_sharded_objects_handling(sharded_strategy, common_strategy)
        sharded_objects_state_dict, sharded_state_dict = extract_matching_values(
            sharded_state_dict, lambda v: isinstance(v, ShardedObject)
        )
        sharded_objects = common_strategy.load_sharded_objects(
            sharded_objects_state_dict, checkpoint_dir
        )
        merge(common_state_dict, sharded_objects)

    loaded_state_dict = sharded_strategy.load(sharded_state_dict, checkpoint_dir)

    merge(common_state_dict, loaded_state_dict)

    loaded_state_dict = apply_factory_merges(common_state_dict, sh_ten_factories)

    if StrictHandling.requires_returning_mismatch_keys(strict):
        return common_state_dict, missing_keys, unexpected_keys
    else:
        return common_state_dict


def load_common_state_dict(checkpoint_dir: Union[str, Path]) -> StateDict:
    """Load common (non-sharded) objects state dict from the checkpoint.

    Args:
        checkpoint_dir (str): checkpoint directory

    Returns:
        StateDict: state dict with non-sharded objects from the checkpoint
    """
    if isinstance(checkpoint_dir, Path):
        checkpoint_dir = str(checkpoint_dir)
        logger.warning(
            "DEPRECATED: Passing 'checkpoint_dir' as a Path object in load_common_state_dict will "
            "no longer be supported in a future release. Please pass it as a string instead."
        )
    sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy(checkpoint_dir)
    return common_strategy.load_common(checkpoint_dir)


def load_tensors_metadata(
    checkpoint_dir: str, sharded_strategy: Union[LoadShardedStrategy, None] = None
) -> CkptShardedMetadata:
    """Load tensors metadata from the checkpoint.

    Returns a dictionary similar to a sharded state dict, but note that
    the dictionary keys are simply ShardedTensor keys (contrary to the
    actual sharded state dicts where keys correspond to state dict keys).

    Dict values are ShardedTensors without any sharding (so, the only useful
    information is tensors global shape and dtype).

    Concrete implementation depends on the loading strategy. If no strategy is
    given, a default for a given backend is used.

    Args:
        checkpoint_dir (str): checkpoint directory to load from
        sharded_strategy (LoadShardedStrategy, optional): sharded strategy to load metadata.
            Defaults to None - in this case a default load strategy for a given checkpoint type
            is used.

    Returns:
        CkptShardedMetadata: flat state dict without data describing ShardedTensors
            in the checkpoint
    """
    sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy(
        checkpoint_dir, sharded_strategy
    )
    return sharded_strategy.load_tensors_metadata(Path(checkpoint_dir))


def load_sharded_metadata(
    checkpoint_dir: str,
    sharded_strategy: Union[LoadShardedStrategy, None] = None,
    common_strategy: Union[LoadCommonStrategy, None] = None,
) -> CkptShardedMetadata:
    """Load sharded metadata from the checkpoint.

    Similar to `load_tensors_metadata`, but includes also ShardedObjects.

    Returns a dictionary similar to a sharded state dict, but note that
    the dictionary keys are simply ShardedTensor keys (contrary to the
    actual sharded state dicts where keys correspond to state dict keys).

    Dict values are ShardedTensors without any sharding (so, the only useful
    information is tensors global shape and dtype).

    Concrete implementation depends on the loading strategy. If no strategy is
    given, a default for a given backend is used.

    Args:
        checkpoint_dir (str): checkpoint directory to load from
        sharded_strategy (LoadShardedStrategy, optional): sharded strategy to load metadata.
            Defaults to None - in this case a default load strategy for a given checkpoint type
            is used.
        common_strategy (LoadCommonStrategy, optional): common strategy to load metadata.
            Defaults to None - in this case a default load strategy for a given checkpoint type is
            used. This strategy won't be used unless `sharded_strategy` can't handle ShardedObjects

    Returns:
        CkptShardedMetadata: flat state dict without data describing ShardedTensors
            and ShardedObjects in the checkpoint
    """
    sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy(
        checkpoint_dir, sharded_strategy, common_strategy
    )
    sharded_metadata = sharded_strategy.load_sharded_metadata(checkpoint_dir)
    if not sharded_strategy.can_handle_sharded_objects:
        validate_sharded_objects_handling(sharded_strategy, common_strategy)
        common_metadata = common_strategy.load_sharded_metadata(checkpoint_dir)
        sharded_metadata = merge(sharded_metadata, common_metadata)
    return sharded_metadata


def load_plain_tensors(checkpoint_dir: str) -> StateDict:
    """Load checkpoint tensors without any sharding and plain structure.

    NOTE: common state dict is NOT included.

    Args:
        checkpoint_dir (str): checkpoint directory to load the tensors from.

    Returns:
        StateDict: checkpoint state dict containing only torch.Tensors.
    """
    sharded_state_dict = load_tensors_metadata(checkpoint_dir)
    # Don't validate integrity because shards will be overlapped
    # if world_size > 1 (all processes load whole tensors)
    return load(sharded_state_dict, checkpoint_dir, validate_access_integrity=False)


def load_content_metadata(
    checkpoint_dir: Optional[str] = None, *, preloaded_state_dict: Optional[StateDict] = None
) -> Optional[dict]:
    """Load content metadata stored in the checkpoint with `save(..., content_metadata=...)`.

    Args:
        checkpoint_dir (str, optional): checkpoint directory to load the content metadata from.
        preloaded_state_dict (StateDict, optional): if the state dict was already loaded,
            can be provided to avoid double load from storage

    Returns:
        dict: checkpoint content metadata
        None: in case there is no content metadata in the checkpoint
    """
    if preloaded_state_dict is None:
        if checkpoint_dir is None:
            raise ValueError('Both checkpoint_dir and loaded_state_dict cannot be None')
        preloaded_state_dict = load_common_state_dict(checkpoint_dir)
    return preloaded_state_dict.get(_CONTENT_METADATA_KEY)


def remove_sharded_tensors(checkpoint_dir: str, key_prefix: str):
    """Determine the appropriate sharding strategy and delegate removal to the sharded strategy."""
    sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy(checkpoint_dir)
    sharded_strategy.remove_sharded_tensors(checkpoint_dir, key_prefix)


def save(
    sharded_state_dict: ShardedStateDict,
    checkpoint_dir: str,
    sharded_strategy: Union[SaveShardedStrategy, Tuple[str, int], None] = None,
    common_strategy: Union[SaveCommonStrategy, Tuple[str, int], None] = None,
    validate_access_integrity: bool = True,
    async_sharded_save: bool = False,
    preprocess_common_before_consistancy_check: Optional[
        Callable[[CommonStateDict], StateDict]
    ] = None,
    content_metadata: Optional[dict] = None,
) -> Optional[AsyncRequest]:
    """Saving entrypoint.

    Extracts ShardedTensors from the given state dict. Rank 0 saves the
    "regular" part of the checkpoint to common torch file.
    The ShardedTensors are saved according to a strategy specified by the
    config.

    Steps:
    1. Apply factories
    2. Extract and discard LocalNonPersistentObject
    3. Extract all ShardedBase object
    4. Save all other objects to common.pt
    5. (optional) Extract and save ShardedObjects
    6. Save all ShardedBase objects
    7. Write metadata.json file with backend and version metadata.

    Step (6) can be performed asynchronously (see `async_sharded_save`), in this
    case the actual save is embodied in the returned async request and can be
    scheduled by the external caller. For async request, step (7) is added as
    one of the finalization functions, so that metadata.json is written only
    if the checkpoint is complete.

    Args:
        sharded_state_dict (ShardedStateDict): state dict of the populated with
            ShardedTensors. Used as a mapping to determine how local tensors
            should be saved as global tensors in the checkpoint.
        checkpoint_dir (str): directory to save the checkpoint to
        sharded_strategy (SaveShardedStrategy, Tuple[str, int], optional):
            configures sharded tensors saving behavior and backend
        common_strategy (SaveCommonStrategy, Tuple[str, int], optional):
            configures common data saving behavior and backend
        validate_access_integrity (bool default = True): checks if each tensor shard is accessed
            exactly once (as main replica) by some process.
            It also makes sure the common state dict is consistent across all ranks
        async_sharded_save (bool, optional): if True, for the sharded state dict part
            an async save implementation will be called, with the AsyncRequest
            being returned to the caller. Note that it is the caller's responsibility to
            actually schedule the async save. Defaults to False.
        preprocess_common_before_consistancy_check (Callable[[CommonStateDict], StateDict], None):
            A callable function that will preprocess the common state dict (i.e can be used to
            remove keys that we expect to be different in the state dict). The function must not
            modify the original state dict
        content_metadata (dict, optional): metadata to identify the checkpoint content.
            Useful for framework specific versioning.

    Returns:
        AsyncRequest (optional): if `async_sharded_save` is True, returns
            async request that should be scheduled by the caller of this function.
            None otherwise.
    """
    if torch.distributed.get_rank() == 0:
        if MultiStorageClientFeature.is_enabled():
            msc = MultiStorageClientFeature.import_package()
            checkpoint_dir_path = msc.Path(str(checkpoint_dir))
        else:
            checkpoint_dir_path = Path(checkpoint_dir)

        if next(checkpoint_dir_path.iterdir(), None) is not None:
            # Don't throw exception here since this could cause a cascade of failures
            # without human intervention in cases where multiple jobs are queued up.
            if torch.distributed.get_rank() == 0:
                logger.warning("Overwriting old incomplete / corrupted checkpoint...")

    if common_strategy is not None:
        raise NotImplementedError('The only supported common strategy is torch')

    if sharded_strategy is None:
        sharded_strategy = get_default_save_sharded_strategy()
    if not isinstance(sharded_strategy, SaveShardedStrategy):
        assert isinstance(sharded_strategy, tuple), type(sharded_strategy)
        sharded_strategy = get_default_strategy(StrategyAction.SAVE_SHARDED, *sharded_strategy)

    if common_strategy is None:
        common_strategy = get_default_save_common_strategy()
    if not isinstance(common_strategy, SaveCommonStrategy):
        assert isinstance(common_strategy, tuple), type(common_strategy)
        common_strategy = get_default_strategy(StrategyAction.SAVE_COMMON, *common_strategy)

    if content_metadata is not None:
        sharded_state_dict[_CONTENT_METADATA_KEY] = content_metadata

    sharded_state_dict, state_dict = save_preprocess(
        sharded_state_dict, validate_access_integrity, preprocess_common_before_consistancy_check
    )
    common_strategy.save_common(state_dict, checkpoint_dir)

    if not sharded_strategy.can_handle_sharded_objects:
        validate_sharded_objects_handling(sharded_strategy, common_strategy)
        sharded_objects_state_dict, sharded_state_dict = extract_matching_values(
            sharded_state_dict, lambda v: isinstance(v, ShardedObject)
        )
        common_strategy.save_sharded_objects(sharded_objects_state_dict, checkpoint_dir)

    def metadata_finalize_fn():
        if torch.distributed.get_rank() == 0:
            save_config(
                CheckpointingConfig(sharded_strategy.backend, sharded_strategy.version),
                checkpoint_dir,
            )
        torch.distributed.barrier()

    if not async_sharded_save:
        sharded_strategy.save(sharded_state_dict, checkpoint_dir)
        metadata_finalize_fn()
        return None

    if not isinstance(sharded_strategy, AsyncSaveShardedStrategy):
        raise CheckpointingException(
            f'Cannot apply async_save to non-async strategy {sharded_strategy}'
        )
    async_request = sharded_strategy.async_save(sharded_state_dict, checkpoint_dir)
    async_request.finalize_fns.append(metadata_finalize_fn)
    return async_request


def get_default_save_sharded_strategy(
    backend: str = 'torch_dist', version: int = 1
) -> SaveShardedStrategy:
    """Get default save sharded strategy."""
    return get_default_strategy(StrategyAction.SAVE_SHARDED, backend, version)


def get_default_save_common_strategy(
    backend: str = 'torch', version: int = 1
) -> SaveCommonStrategy:
    """Get default save common strategy."""
    return get_default_strategy(StrategyAction.SAVE_COMMON, backend, version)


def get_default_load_sharded_strategy(checkpoint_dir: str) -> LoadShardedStrategy:
    """Get default load sharded strategy."""
    return verify_checkpoint_and_load_strategy(checkpoint_dir)[0]
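
Usage note (not part of the diff): a minimal save/load round trip with the two entrypoints above. The checkpoint path, the model's sharded_state_dict()/load_state_dict() methods, and the package-level re-export of save/load are assumptions; if the re-export is absent, import them from megatron.core.dist_checkpointing.serialization directly. Both calls must run on every rank of an initialized torch.distributed job.

# Hedged sketch of the save/load entrypoints.
from megatron.core import dist_checkpointing

ckpt_dir = '/tmp/dist_ckpt/iter_0000100'      # assumed: an existing (ideally empty) directory

# Save: every rank passes its sharded view of the model state.
sharded_sd = model.sharded_state_dict()       # assumed to contain ShardedTensor objects
dist_checkpointing.save(sharded_sd, ckpt_dir)

# Load: the sharded state dict acts as a template describing which shards this rank
# needs; plain torch.Tensors are returned in their place.
template_sd = model.sharded_state_dict()
state_dict = dist_checkpointing.load(template_sd, ckpt_dir)
model.load_state_dict(state_dict)             # assumed model API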
Megatron-LM/megatron/core/dist_checkpointing/state_dict_utils.py (new file, mode 100644)
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.

""" Utilities for transforming state_dict."""

from typing import Callable, Union

from .dict_utils import dict_list_map_inplace, extract_matching_values
from .mapping import (
    CommonStateDict,
    ShardedStateDict,
    ShardedTensor,
    ShardedTensorFactory,
    StateDict,
    apply_factories,
)
from .utils import extract_nonpersistent, extract_sharded_base
from .validation import determine_global_metadata, validate_sharding_integrity


def save_preprocess(
    sharded_state_dict: ShardedStateDict,
    validate_access_integrity: bool = True,
    preprocess_common_before_consistancy_check: Callable[[CommonStateDict], StateDict] = None,
):
    """Preprocesses the given state dictionary by applying factories,
    discarding non-persistent data and extracting the common state dictionary.
    Optionally, it can validate sharding integrity.

    Args:
        sharded_state_dict (ShardedStateDict): The initial state dictionary to be preprocessed.
        validate_access_integrity (bool): If True, triggers validation of sharding integrity.
        preprocess_common_before_consistancy_check (callable, None): A callable function
            that will preprocess the common state dict (i.e can be used to remove keys
            that we expect to be different in the state dict)

    Returns:
        Tuple[ShardedStateDict, dict]:
            The preprocessed sharded state dictionary and the common state dictionary.
    """
    apply_factories(sharded_state_dict)
    _, sharded_state_dict = extract_nonpersistent(sharded_state_dict)
    sharded_part, common_state_dict = extract_sharded_base(sharded_state_dict)
    sharded_part = filter_out_empty_flatten_tensor(sharded_part)
    if validate_access_integrity:
        preprocessed_common_state_dict = common_state_dict
        if preprocess_common_before_consistancy_check:
            preprocessed_common_state_dict = preprocess_common_before_consistancy_check(
                common_state_dict
            )
        validate_sharding_integrity(
            determine_global_metadata(sharded_part)[1],
            common_state_dict=preprocessed_common_state_dict,
        )
    return sharded_part, common_state_dict


def load_preprocess(sharded_state_dict: ShardedStateDict):
    """Preprocesses the given state dictionary by applying factories
    and extracting non-persistent data, without modifying the original dictionary.

    Args:
        sharded_state_dict (ShardedStateDict):
            The initial state dictionary to be processed (remains unchanged).

    Returns:
        Tuple[ShardedStateDict, dict, dict]:
            - A preprocessed copy of the sharded state dictionary.
            - A dictionary containing non-persistent state data.
            - A dictionary of `ShardedTensorFactory` instances.
    """
    # Create a copy of sharded_state_dict as the passed in state dict may have
    # references that prevent tensors from being deallocated
    sharded_state_dict, _ = extract_matching_values(sharded_state_dict, lambda x: True)
    sharded_state_dict = filter_out_empty_flatten_tensor(sharded_state_dict)
    sh_ten_factories, _ = extract_matching_values(
        sharded_state_dict,
        lambda x: isinstance(x, ShardedTensorFactory),
        return_lists_as_dicts=True,
    )
    apply_factories(sharded_state_dict)
    # Data inside sh_ten_factories no longer needed so delete them to reduce memory usage
    dict_list_map_inplace(ShardedTensorFactory.without_data, sh_ten_factories)
    # Non-persistent objects
    nonpersistent_state_dict, sharded_state_dict = extract_nonpersistent(sharded_state_dict)
    dict_list_map_inplace(lambda o: o.unwrap(), nonpersistent_state_dict)
    return sharded_state_dict, nonpersistent_state_dict, sh_ten_factories


def filter_out_empty_flatten_tensor(sharded_state_dict: Union[dict, list]):
    """
    Filter out ShardedTensors with an empty flattened range.

    Such tensors can make the PyTorch sharded-tensor consistency check fail.

    Args:
        sharded_state_dict: state dict possibly containing ShardedTensor objects
    """
    # Filter out ShardedTensors with empty flatten_range.
    # These tensors can cause the PyTorch check in
    # `TorchShardedTensor._init_from_local_shards_and_global_metadata` to fail.
    # This situation may occur in custom Fully Sharded Data Parallel (FSDP) cases.
    sharded_state_dict, _ = extract_matching_values(
        sharded_state_dict,
        lambda v: not (
            isinstance(v, ShardedTensor)
            and v.flattened_range
            and v.flattened_range.start == v.flattened_range.stop
        ),
    )
    return sharded_state_dict
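
Illustration (not part of the diff): the filter above drops exactly those ShardedTensors whose flattened slice is empty and keeps everything else. The tiny example below assumes the upstream semantics of dict_utils.extract_matching_values (keep leaves for which the predicate holds).

# Hypothetical illustration of filter_out_empty_flatten_tensor.
import torch
from megatron.core.dist_checkpointing.mapping import ShardedTensor
from megatron.core.dist_checkpointing.state_dict_utils import filter_out_empty_flatten_tensor

keep = ShardedTensor.from_rank_offsets('p1', torch.zeros(4), (0, 0, 1))
drop = ShardedTensor.from_rank_offsets_flat(          # empty slice: start == stop
    'p2', torch.empty(0), (4,), (0, 0, 1), flattened_range=slice(3, 3)
)
filtered = filter_out_empty_flatten_tensor({'p1': keep, 'p2': drop})
assert 'p1' in filtered and 'p2' not in filtered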
Megatron-LM/megatron/core/dist_checkpointing/strategies/__init__.py (new file, mode 100644)
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.

""" Various loading and saving strategies """

from megatron.core.dist_checkpointing.strategies.common import register_default_common_strategies

# We load "common" strategies by default to be always available
register_default_common_strategies()
Megatron-LM/megatron/core/dist_checkpointing/strategies/async_utils.py (new file, mode 100644)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

"""
This module provides async utilities which allow starting
a checkpoint save process in the background.
"""
import gc
import logging
from abc import ABC, abstractmethod
from collections import deque
from contextlib import contextmanager
from queue import Empty
from time import sleep, time
from typing import Callable, Dict, List, NamedTuple, Optional, Tuple

import torch
from torch import multiprocessing as mp

from ..utils import debug_time

logger = logging.getLogger(__name__)


@contextmanager
def _disable_gc():
    """Temporarily disables GC."""
    gc_enabled = gc.isenabled()
    try:
        if gc_enabled:
            gc.disable()
        yield
    finally:
        if gc_enabled:
            gc.enable()


class AsyncRequest(NamedTuple):
    """Represents an async request that needs to be scheduled for execution.

    Args:
        async_fn (Callable, optional): async function to call. None represents noop.
        async_fn_args (Tuple): args to pass to `async_fn`.
        finalize_fns (List[Callable]): list of functions to call to finalize the request.
            These functions will be called synchronously after `async_fn` is done
            *on all ranks*.
        async_fn_kwargs (Dict): kwargs to pass to `async_fn`.
        preload_fn (Callable): preload function to stage tensors from GPU to Host.
            This should be self-contained with a proper list of arguments with `partial`.
        is_frozen (bool): a flag to indicate whether this async request can still be modified.
        call_idx (int): index variable used to order async requests for synchronization
            in preloading and writing tensors on the async caller
    """

    async_fn: Optional[Callable]
    async_fn_args: Tuple
    finalize_fns: List[Callable]
    async_fn_kwargs: Dict = {}
    preload_fn: Callable = None
    is_frozen: bool = False
    call_idx: int = 0

    def add_finalize_fn(self, fn: Callable) -> None:
        """Adds a new finalize function to the request.

        Args:
            fn (Callable): function to add to the async request. This function
                will be called *after* existing finalization functions.

        Returns:
            None
        """
        if self.is_frozen:
            raise RuntimeError('Cannot add finalization functions to a frozen AsyncRequest')
        self.finalize_fns.append(fn)

    def execute_sync(self) -> None:
        """Helper to synchronously execute the request.

        This logic is equivalent to what should happen in case of the async call.
        """
        if self.async_fn is not None:
            self.async_fn(*self.async_fn_args)
        torch.distributed.barrier()
        for finalize_fn in self.finalize_fns:
            finalize_fn()

    def freeze(self) -> 'AsyncRequest':
        """Freezes the async request, disallowing adding new finalization functions.

        Returns:
            AsyncRequest: new async request with all same fields except for the
                `is_frozen` flag.
        """
        return self._replace(is_frozen=True)


class AsyncCaller(ABC):
    """Wrapper around mp.Process that ensures correct semantic of distributed finalization.

    Starts process asynchronously and allows checking if all processes on all ranks are done.
    """

    @abstractmethod
    def schedule_async_call(self, async_req: AsyncRequest) -> None:
        """Schedule `async_req` with some process forking or by reusing a persistent worker.

        This method must be called on all ranks.

        Args:
            async_req (AsyncRequest): `AsyncRequest` object used to start the async process
        """
        raise NotImplementedError("This should be implemented")

    @abstractmethod
    def is_current_async_call_done(self, blocking: bool, no_dist: bool) -> bool:
        """Check if async save is finished on all ranks.

        For semantic correctness, requires rank synchronization in each check.
        This method must be called on all ranks.

        Args:
            blocking (bool, optional): if True, will wait until the call is done
                on all ranks. Otherwise, returns immediately if at least one rank
                is still active. Defaults to False.
            no_dist (bool, optional): if True, each training rank simply checks its
                asynchronous checkpoint writer without synchronization.

        Returns:
            bool: True if all ranks are done (immediately or after active wait
                if `blocking` is True), False if at least one rank is still active.
        """
        raise NotImplementedError("This should be implemented")

    def sync_all_async_calls(self, is_alive: int) -> bool:
        """Check if all ranks have completed async checkpoint writing.

        Args:
            is_alive (bool): if True, the current async request is not completed

        Returns:
            bool: True if all ranks are done, False if at least one rank is still active.
        """
        ten = torch.tensor([is_alive], dtype=torch.int, device=torch.cuda.current_device())
        torch.distributed.all_reduce(ten)
        return ten[0] == 0

    @abstractmethod
    def close(self):
        """Terminate the async caller at exit of an application or some termination conditions"""
        logger.info(f"AsyncCaller: {torch.distributed.get_rank()}, Destroying Async Caller")

    def __del__(self):
        raise NotImplementedError("This should be implemented")


class TemporalAsyncCaller(AsyncCaller):
    """Wrapper around mp.Process that ensures correct semantic of distributed finalization.

    Starts process asynchronously and allows checking if all processes on all ranks are done.
    """

    def __init__(self):
        self.process: Optional[mp.Process] = None
        self.start_time: Optional[float] = None

    @_disable_gc()
    def schedule_async_call(self, async_req: AsyncRequest) -> None:
        """Spawn a process with `async_fn` as the target.

        This method must be called on all ranks.

        Args:
            async_req (AsyncRequest): `AsyncRequest` object used to start the async process.
                If its `async_fn` is None, no process will be started.
        """
        if async_req.async_fn is None:
            return  # nothing to do
        async_fn_args = list(async_req.async_fn_args)
        if async_req.preload_fn:
            # If there's a preload_fn in `async_req`, we call this func
            # to do the defined action in `async_req.preload_fn` to
            # stage GPU tensors to its defined destination
            async_fn_args[1] = async_req.preload_fn()

        rank = torch.distributed.get_rank()
        start_sync = time()
        torch.cuda.synchronize()
        end_sync = time()
        logger.debug(f"rank: {rank}, takes {end_sync - start_sync} to finish D2H ")

        ctx = mp.get_context('fork')
        self.start_time = time()
        self.process = ctx.Process(
            target=async_req.async_fn, args=async_fn_args, kwargs=async_req.async_fn_kwargs
        )
        self.process.start()
        init_time = time()
        logger.debug(f"rank: {rank}, takes {init_time - self.start_time} to schedule async ckpt ")

    def is_current_async_call_done(self, blocking: bool = False, no_dist: bool = False) -> bool:
        """Check if async save is finished on all ranks.

        For semantic correctness, requires rank synchronization in each check.
        This method must be called on all ranks.

        Args:
            blocking (bool, optional): if True, will wait until the call is done
                on all ranks. Otherwise, returns immediately if at least one rank
                is still active. Defaults to False.
            no_dist (bool, optional): if True, each training rank simply checks its
                asynchronous checkpoint writer without synchronization.

        Returns:
            bool: True if all ranks are done (immediately or after active wait
                if `blocking` is True), False if at least one rank is still active.
        """
        # The following takes the same overhead
        # as torch.distributed.barrier (single integer all-reduce)
        is_alive = int(self.process.is_alive()) if self.process is not None else 0
        is_done = not is_alive if no_dist else self.sync_all_async_calls(is_alive)

        if is_done or blocking:
            # Process join is called in the following cases
            # 1. blocking == True -> regardless of is_done
            # 2. blocking == False (non-blocking)
            #    -> is_done == True: async requests on all ranks are identified to be finished
            # `self.close()` makes sure the async callers terminated
            self.close()
            is_done = True
        return is_done

    def close(self):
        """For TemporalAsyncCaller, this method is called explicitly in
        `is_current_async_call_done`.

        It makes sure the TemporalAsyncCaller terminates
        with all its assigned async requests completed.
        """
        if self.process:
            logger.debug(f"rank: {torch.distributed.get_rank()}, joining self.process")
            self.process.join()
            self.process = None
            logger.debug(
                "TemporalAsyncCaller: Async process join finished "
                f"after {time() - self.start_time:.2f}s from forking"
            )
            self.start_time = None

    def __del__(self):
        pass


class PersistentAsyncCaller(AsyncCaller):
    """Wrapper around mp.Process that ensures correct semantic of distributed finalization.

    Starts process asynchronously and allows checking if all processes on all ranks are done.
    """

    def __init__(self):
        self.process: mp.Process = None
        self.start_time: Optional[float] = None
        ctx = mp.get_context('spawn')
        # main queue to deliver `AsyncRequest` from host to the ckpt worker
        self.queue: mp.JoinableQueue = ctx.JoinableQueue()
        # Queue used to synchronize for the completion of preloading tensors to host
        # between a trainer and ckpt worker
        self.preload_q: mp.JoinableQueue = ctx.JoinableQueue()
        # Queue used to inform trainer when the saving is completed
        self.comp_q: mp.Queue = ctx.Queue()
        self.cur_item: int = None
        self.cur_idx: int = -1

    def schedule_async_call(self, async_req: AsyncRequest) -> None:
        """Put `AsyncRequest` to the Persistent Async Caller.

        This method must be called on all ranks.

        Args:
            async_req (AsyncRequest): `AsyncRequest` object used to schedule a checkpointing
                request. If its `async_fn` is None, no process will be started.
        """
        if async_req.async_fn is None:
            return  # nothing to do
        start_sync = end_sync = None
        self.start_time = time()
        if self.process is None:
            ctx = mp.get_context('spawn')
            logger.info(
                f"PersistentAsyncCaller: {torch.distributed.get_rank()}, Starting Async Caller"
            )
            self.process: mp.Process = ctx.Process(
                target=PersistentAsyncCaller.async_loop,
                args=(
                    torch.distributed.get_rank(),
                    self.queue,
                    self.preload_q,
                    self.comp_q,
                    logger.getEffectiveLevel(),
                ),
            )
            self.process.start()
            logger.info(
                f"PersistentAsyncCaller: {torch.distributed.get_rank()}, Started Async Caller"
            )

        if async_req.preload_fn:
            self.preload_q.put(async_req.call_idx)
        self.queue.put(async_req)
        logger.debug(f"rank: {torch.distributed.get_rank()}, put {async_req.call_idx} ")
        if async_req.preload_fn:
            start_sync = time()
            # Synchronize for pre-staging tensors
            self.preload_q.join()
            end_sync = time()
            logger.debug(
                f"rank: {torch.distributed.get_rank()}, "
                f"takes {end_sync - start_sync} to finish D2H "
            )

        init_time = time()
        logger.debug(
            f"rank: {torch.distributed.get_rank()}, takes {init_time - self.start_time} "
            "to schedule async ckpt "
        )

    def is_current_async_call_done(self, blocking: bool = False, no_dist: bool = False) -> bool:
        """Check if async save is finished on all ranks.

        For semantic correctness, requires rank synchronization in each check.
        This method must be called on all ranks.

        Args:
            blocking (bool, optional): if True, will wait until the call is done
                on all ranks. Otherwise, returns immediately if at least one rank
                is still active. Defaults to False.
            no_dist (bool, optional): if True, each training rank simply checks its
                asynchronous checkpoint writer without synchronization.

        Returns:
            bool: True if all ranks are done (immediately or after active wait
                if `blocking` is True), False if at least one rank is still active.
        """
        is_alive: bool = False

        if self.process:
            while self.cur_item is None:
                try:
                    # Retrieve comp call_idx without waiting
                    self.cur_item = self.comp_q.get_nowait()
                except Empty:
                    # This method is called after any `AsyncRequest` is pushed to the main loop
                    # So, the background writing is still active
                    # before the worker put call_idx to `comp_q`
                    if not blocking:
                        is_alive = True
                        break
                    sleep(0.1)

            if self.cur_item is not None:
                logger.debug(
                    f"rank: {torch.distributed.get_rank()}, item: {self.cur_item}"
                    f" is completed, {is_alive}"
                )

        is_done = not is_alive if no_dist else self.sync_all_async_calls(is_alive)
        # This is set to False when blocking == False so this routine is called again
        # to simply call `sync_all_async_calls` to check if other ranks complete the writing
        if is_done:
            # The current request is completed globally. Reset the current item for polling.
            logger.debug(
                f"rank: {torch.distributed.get_rank()}, item: {self.cur_item}"
                f" is completed globally, {is_done}"
            )
            self.cur_item = None

        return is_done

    def close(self):
        """Wait on the remaining async requests and terminate the PersistentAsyncCaller.

        Signals the PersistentAsyncCaller by sending a 'DONE' message to make it terminate.
        """
        logger.info(
            f"PersistentAsyncCaller: {torch.distributed.get_rank()}, Destroying Async Caller"
        )
        if self.process:
            self.queue.put('DONE')
            self.queue.join()
            self.process.join()
            self.process = None

    def __del__(self):
        self.close()

    @staticmethod
    @_disable_gc()
    def async_loop(
        rank: int,
        queue: mp.JoinableQueue,
        preload_q: mp.JoinableQueue,
        comp_q: mp.Queue,
        log_level: int = logging.INFO,
    ):
        """Main function for the persistent checkpoint worker.

        The persistent worker is created once and terminated at exit or
        when the application calls `close()` explicitly.
        This routine receives an `AsyncRequest`, runs `preload_fn` first and
        puts the integer value in `preload_q` to inform the trainer to proceed.
        When the `async_fn` from the request is completed (background saving is done),
        it puts an integer value to `comp_q` to notify the trainer of the completion.

        Args:
            rank (int): the rank of the trainer where the persistent worker is created.
            queue (mp.JoinableQueue): the main queue used to receive `AsyncRequest`
                from the training rank
            preload_q (mp.JoinableQueue): a queue to inform trainer that preloading of tensors
                from GPU to Host or dedicated location is completed
            comp_q (mp.Queue): a queue to inform the training rank the completion of scheduled
                async checkpoint request
            log_level (int, optional): an integer to set log-level in this spawned process
                to get aligned with the training rank's logging level
        """
        logger = logging.getLogger(__name__)
        logger.setLevel(log_level)
        logger.info(f"PersistentAsyncCaller: persistent ckpt worker for {rank} has started")
        while True:
            item = queue.get()
            if isinstance(item, str) and item == 'DONE':
                queue.task_done()
                break
            elif isinstance(item, AsyncRequest):
                async_fn_args = list(item.async_fn_args)
                if item.preload_fn:
                    call_idx = preload_q.get()
                    # the 2nd arg is state dict
                    async_fn_args[1] = item.preload_fn()
                    logger.debug(f"{rank} has completed D2H of {call_idx}")
                    preload_q.task_done()
                item.async_fn(*async_fn_args, **item.async_fn_kwargs)
                logger.debug(f"{rank} has completed saving {item.call_idx}")
                comp_q.put(item.call_idx)
                queue.task_done()

        logger.info(f"PersistentAsyncCaller: persistent ckpt worker for {rank} has terminated")


class _ActiveAsyncRequest(NamedTuple):
    """Helper to represent an active async call.

    Args:
        idx (int): index of the call (starting from 0)
        async_caller (AsyncCaller): async caller instance that represents
            the async process handling the async request
        async_request (AsyncRequest): async request that is being called
    """

    idx: int
    async_caller: AsyncCaller
    async_request: AsyncRequest


class AsyncCallsQueue:
    """Manages a queue of async calls.

    Allows adding a new async call with `schedule_async_request` and finalizing
    active calls with `maybe_finalize_async_calls`.
    """

    def __init__(self, persistent: bool = False):
        self.async_calls: deque[_ActiveAsyncRequest] = deque([])
        self.call_idx: int = -1
        self.persistent: bool = persistent
        self.persistent_caller: AsyncCaller = None

    def _get_async_caller(self):
        if not self.persistent:
            return TemporalAsyncCaller()
        if self.persistent_caller is None:
            self.persistent_caller = PersistentAsyncCaller()
        return self.persistent_caller

    def schedule_async_request(self, async_request: AsyncRequest) -> int:
        """Start a new async call and add it to a queue of active async calls.

        This method must be called on all ranks.

        Args:
            async_request (AsyncRequest): async request to start.

        Returns:
            int: index of the async call that was started.
                This can help the user keep track of the async calls.
        """
        self.call_idx += 1
        async_caller = self._get_async_caller()
        # Backward compatibility for local checkpointing built with the old AsyncRequest
        if len(async_request._fields) != len(AsyncRequest._fields):
            async_request = AsyncRequest(**async_request._asdict())
        async_request = async_request.freeze()
        async_caller.schedule_async_call(
            async_request._replace(call_idx=self.call_idx, finalize_fns=[])
        )
        self.async_calls.append(_ActiveAsyncRequest(self.call_idx, async_caller, async_request))
        return self.call_idx

    def maybe_finalize_async_calls(self, blocking=False, no_dist=False) -> List[int]:
        """Finalizes all available calls.

        This method must be called on all ranks.

        Args:
            blocking (bool, optional): if True, will wait until all active requests
                are done. Otherwise, finalizes only the async requests that already
                finished. Defaults to False.

        Returns:
            List[int]: list of indices (as returned by `schedule_async_request`)
                of async calls that have been successfully finalized.
        """
        call_idx_finalized = []
        while self.async_calls:
            next_async_done = self.async_calls[0].async_caller.is_current_async_call_done(
                blocking, no_dist
            )
            if not next_async_done:
                break
            with debug_time("finalize", logger):
                call_idx, _, async_request = self.async_calls.popleft()
                for finalize_fn in async_request.finalize_fns:
                    finalize_fn()
                ten = torch.tensor([call_idx], dtype=torch.int, device=torch.cuda.current_device())
                torch.distributed.all_reduce(ten, op=torch.distributed.ReduceOp.MAX)
                assert ten.item() == call_idx, (
                    'Unmatched async calls. '
                    'That probably means not all ranks are participating in async finalization'
                )
                call_idx_finalized.append(call_idx)
        return call_idx_finalized

    def get_num_unfinalized_calls(self):
        """Get the number of active async calls."""
        return len(self.async_calls)

    def close(self):
        """Finalize all calls upon closing."""
        self.maybe_finalize_async_calls(blocking=True)
        if self.persistent and self.persistent_caller:
            self.persistent_caller.close()
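A hedged usage sketch of the queue above: it assumes `torch.distributed` is already initialized and a CUDA device is selected, since scheduling and finalization synchronize across ranks; the payload, helper function and path are illustrative.

# Illustrative sketch: schedule a trivial background save and finalize it.
import torch
from megatron.core.dist_checkpointing.strategies.async_utils import AsyncCallsQueue, AsyncRequest

def _background_save(payload, path):
    # Runs in the forked worker process.
    torch.save(payload, path)

calls_queue = AsyncCallsQueue(persistent=False)  # TemporalAsyncCaller under the hood
request = AsyncRequest(
    async_fn=_background_save,
    async_fn_args=({'step': 123}, '/tmp/ckpt_fragment.pt'),  # hypothetical payload and path
    finalize_fns=[lambda: print('finalized on this rank')],
)
call_idx = calls_queue.schedule_async_request(request)          # monotonically increasing index
finished = calls_queue.maybe_finalize_async_calls(blocking=True)  # waits until all ranks are done
calls_queue.close()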
Megatron-LM/megatron/core/dist_checkpointing/strategies/base.py
0 → 100644
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.

""" Strategies base interfaces. """

from abc import ABC, abstractmethod
from collections import defaultdict
from enum import Enum
from pathlib import Path
from typing import Any, DefaultDict, Union

from ..mapping import CheckpointingException, ShardedStateDict, StateDict
from .async_utils import AsyncCallsQueue, AsyncRequest


class StrategyAction(Enum):
    """Specifies save vs load and sharded vs common action."""

    LOAD_COMMON = 'load_common'
    LOAD_SHARDED = 'load_sharded'
    SAVE_COMMON = 'save_common'
    SAVE_SHARDED = 'save_sharded'


default_strategies: DefaultDict[str, dict[tuple, Any]] = defaultdict(dict)

async_calls = AsyncCallsQueue()


def get_default_strategy(action: StrategyAction, backend: str, version: int):
    """Retrieves a default strategy for a given action, backend and version."""
    error_hint: str = ""
    try:
        if backend == 'zarr':
            error_hint = ' Please install `zarr` and `tensorstore!=0.1.46` packages'
            from .tensorstore import register_default_tensorstore_strategies

            register_default_tensorstore_strategies()
            from .zarr import register_default_zarr_strategies

            register_default_zarr_strategies()
        elif backend == 'torch_dist':
            error_hint = ' Please use PyTorch version >=2.1'
            from .torch import register_default_torch_strategies

            register_default_torch_strategies()
    except ImportError as e:
        raise CheckpointingException(
            f'Cannot import a default strategy for: {(action.value, backend, version)}. '
            f'Error: {e}. Hint: {error_hint}'
        ) from e
    try:
        return default_strategies[action.value][(backend, version)]
    except KeyError as e:
        raise CheckpointingException(
            f'Cannot find a default strategy for: {(action.value, backend, version)}'
        ) from e


def register_default_strategy(
    action: StrategyAction,
    backend: str,
    version: int,
    strategy: Union['SaveStrategyBase', 'LoadStrategyBase'],
):
    """Adds a given strategy to the registry of default strategies.

    Args:
        action (StrategyAction): specifies save/load and sharded/common
        backend (str): backend that the strategy becomes a default for
        version (int): version that the strategy becomes a default for
        strategy (SaveStrategyBase, LoadStrategyBase): strategy to register
    """
    default_strategies[action.value][(backend, version)] = strategy


class LoadStrategyBase(ABC):
    """Base class for a load strategy. Requires implementing checks for compatibility with a
    given checkpoint version."""

    @abstractmethod
    def check_backend_compatibility(self, loaded_backend):
        """Verifies if this strategy is compatible with `loaded_backend`."""
        raise NotImplementedError

    @abstractmethod
    def check_version_compatibility(self, loaded_version):
        """Verifies if this strategy is compatible with `loaded_version`."""
        raise NotImplementedError

    @property
    def can_handle_sharded_objects(self):
        """Returns whether or not this strategy can handle loading ShardedObjects."""
        return False


class SaveStrategyBase(ABC):
    """Base class for a save strategy. Requires defining a backend type and
    version of the saved format."""

    def __init__(self, backend: str, version: int):
        self.backend = backend
        self.version = version

    @property
    def can_handle_sharded_objects(self):
        """Returns whether or not this strategy can handle saving ShardedObjects."""
        return False

    def __str__(self):
        return f'{self.__class__.__name__}({self.backend}, {self.version})'


class LoadCommonStrategy(LoadStrategyBase):
    """Load strategy for common (non-sharded) objects"""

    @abstractmethod
    def load_common(self, checkpoint_dir: Union[str, Path]):
        """Load common part of the checkpoint."""
        raise NotImplementedError

    @abstractmethod
    def load_sharded_objects(
        self, sharded_objects_state_dict: ShardedStateDict, checkpoint_dir: Union[str, Path]
    ):
        """Load sharded objects from the checkpoint."""
        raise NotImplementedError

    def load_sharded_metadata(self, checkpoint_dir: Union[str, Path]) -> ShardedStateDict:
        """Load just the metadata from the checkpoint."""
        if not self.can_handle_sharded_objects:
            return {}
        raise NotImplementedError


class LoadShardedStrategy(LoadStrategyBase):
    """Load strategy for sharded tensors"""

    @abstractmethod
    def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Union[str, Path]):
        """Load the sharded part of the checkpoint."""
        raise NotImplementedError

    @abstractmethod
    def load_tensors_metadata(self, checkpoint_dir: Union[str, Path]):
        """Load tensors metadata from the checkpoint for ShardedTensors.

        Returns a dictionary similar to a sharded state dict, but note that
        the dictionary keys are simply ShardedTensor keys (contrary to the
        actual sharded state dicts where keys correspond to state dict keys).

        Dict values are ShardedTensors without any data and sharding (so, the
        only useful information is tensors global shape and dtype).
        """
        raise NotImplementedError(
            f'Loading only tensors metadata not implemented for {self.__class__.__name__}'
        )

    def load_sharded_metadata(self, checkpoint_dir: Union[str, Path]):
        """Load sharded metadata from the checkpoint for ShardedTensors and ShardedObjects.

        Returns a dictionary similar to a sharded state dict, but note that
        the dictionary keys are simply sharded keys (contrary to the
        actual sharded state dicts where keys correspond to state dict keys).

        Dict values are ShardedTensors or ShardedObjects without any data and sharding.
        """
        if not self.can_handle_sharded_objects:
            return self.load_tensors_metadata(checkpoint_dir)
        raise NotImplementedError(
            f'Loading only sharded metadata not implemented for {self.__class__.__name__}'
        )

    def remove_sharded_tensors(self, checkpoint_dir: Union[str, Path], key_prefix: str):
        """Remove all tensors whose key starts with key_prefix"""
        raise NotImplementedError


class SaveCommonStrategy(SaveStrategyBase):
    """Save strategy for common (non-sharded) objects"""

    @abstractmethod
    def save_common(self, common_state_dict: StateDict, checkpoint_dir: Union[str, Path]):
        """Save common part of the state dict."""
        raise NotImplementedError

    def save_sharded_objects(
        self, sharded_objects_state_dict: ShardedStateDict, checkpoint_dir: Union[str, Path]
    ):
        """Save sharded objects from the state dict."""
        raise NotImplementedError


class SaveShardedStrategy(SaveStrategyBase):
    """Save strategy for sharded tensors"""

    @abstractmethod
    def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Union[str, Path]):
        """Save the sharded part of the state dict."""
        raise NotImplementedError


class AsyncSaveShardedStrategy(SaveShardedStrategy):
    """Save strategy suitable for async save."""

    @abstractmethod
    def async_save(
        self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Union[str, Path]
    ) -> AsyncRequest:
        """Perform preparation and return an AsyncRequest to the external caller.

        Args:
            sharded_state_dict (ShardedStateDict): sharded state dict to save
            checkpoint_dir (Path): checkpoint target directory

        Returns:
            AsyncRequest: represents the async save function and finalization function.
                It is the caller's responsibility to actually schedule the async save.
        """
        raise NotImplementedError

    def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Union[str, Path]):
        """Each async strategy can be trivially used as a sync strategy."""
        async_request = self.async_save(sharded_state_dict, checkpoint_dir)
        # multiprocessing routines may cause issue when called on parent process
        # We keep this verbose call for now
        global async_calls
        async_calls.schedule_async_request(async_request)
        async_calls.maybe_finalize_async_calls(blocking=True)
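To make the registry above concrete, a small sketch of registering and retrieving a custom sharded save strategy; the `MyNpySaveStrategy` class and the `my_npy` backend name are hypothetical, and the names defined in this module are assumed to be in scope.

# Illustrative sketch of the registry API defined above (class and backend name are made up).
class MyNpySaveStrategy(SaveShardedStrategy):
    def save(self, sharded_state_dict, checkpoint_dir):
        ...  # write each main-replica shard however the backend requires

register_default_strategy(StrategyAction.SAVE_SHARDED, 'my_npy', 1, MyNpySaveStrategy('my_npy', 1))
strategy = get_default_strategy(StrategyAction.SAVE_SHARDED, 'my_npy', 1)
assert isinstance(strategy, MyNpySaveStrategy)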
Megatron-LM/megatron/core/dist_checkpointing/strategies/cached_metadata_filesystem_reader.py
0 → 100644
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

""" FS Reader with metadata cached support. """

import os
from typing import Union

from torch.distributed.checkpoint import FileSystemReader, Metadata


class CachedMetadataFileSystemReader(FileSystemReader):
    """
    Extends FileSystemReader to cache metadata for improved performance.

    Attributes:
        _cached_metadata (Metadata or None): Cached metadata from the file system.
    """

    def __init__(self, path: Union[str, os.PathLike]) -> None:
        """
        Initialize with file system path.

        Args:
            path (Union[str, os.PathLike]): Path to the checkpoint directory or file.
        """
        super().__init__(path=path)
        self._cached_metadata = None

    def read_metadata(self) -> Metadata:
        """
        Read metadata from file system, caching for subsequent calls.

        Returns:
            Metadata: Checkpoint metadata.
        """
        if self._cached_metadata is None:
            self._cached_metadata = super().read_metadata()
        return self._cached_metadata
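A brief usage sketch (the checkpoint path is hypothetical): the second `read_metadata()` call returns the cached object instead of re-reading the `.metadata` file from disk.

# Illustrative: metadata is parsed from disk once and then served from the cache.
reader = CachedMetadataFileSystemReader('/path/to/torch_dist_checkpoint')  # hypothetical path
first = reader.read_metadata()
second = reader.read_metadata()
assert first is second  # same object, no second filesystem read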
Megatron-LM/megatron/core/dist_checkpointing/strategies/common.py
0 → 100644
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

""" Common strategies. """

import logging
import os
from pathlib import Path
from typing import Union

import torch

from megatron.core.dist_checkpointing.mapping import ShardedStateDict, StateDict
from megatron.core.dist_checkpointing.strategies.base import (
    SaveCommonStrategy,
    StrategyAction,
    register_default_strategy,
)
from megatron.core.msc_utils import MultiStorageClientFeature

from ..dict_utils import dict_list_map_inplace, nested_values
from ..mapping import CheckpointingException, ShardedObject, is_main_replica
from ..strategies.base import LoadCommonStrategy

COMMON_STATE_FNAME = 'common.pt'

logger = logging.getLogger(__name__)


def register_default_common_strategies():
    """Register default common strategies."""
    register_default_strategy(StrategyAction.LOAD_COMMON, 'torch', 1, TorchCommonLoadStrategy())
    register_default_strategy(
        StrategyAction.SAVE_COMMON, 'torch', 1, TorchCommonSaveStrategy('torch', 1)
    )


class TorchCommonSaveStrategy(SaveCommonStrategy):
    """Common save strategy leveraging native torch save/load."""

    def save_common(self, common_state_dict: StateDict, checkpoint_dir: Union[str, Path]):
        """Save common part of the state dict."""
        if torch.distributed.get_rank() == 0:
            path = os.path.join(checkpoint_dir, COMMON_STATE_FNAME)
            if MultiStorageClientFeature.is_enabled():
                msc = MultiStorageClientFeature.import_package()
                msc.torch.save(common_state_dict, path)
            else:
                torch.save(common_state_dict, path)

    def save_sharded_objects(
        self, sharded_objects_state_dict: ShardedStateDict, checkpoint_dir: Union[str, Path]
    ):
        """Save sharded objects from the state dict."""
        for sh_obj in nested_values(sharded_objects_state_dict):
            if is_main_replica(sh_obj.replica_id):
                save_path = os.path.join(checkpoint_dir, f"{sh_obj.unique_key}.pt")
                parent_dir = os.path.dirname(save_path)
                if MultiStorageClientFeature.is_enabled():
                    msc = MultiStorageClientFeature.import_package()
                    msc.os.makedirs(parent_dir, exist_ok=True)
                    msc.torch.save(sh_obj.data, save_path)
                else:
                    os.makedirs(parent_dir, exist_ok=True)
                    torch.save(sh_obj.data, save_path)

    def can_handle_sharded_objects(self):
        """This strategy can handle ShardedObjects."""
        return True


class TorchCommonLoadStrategy(LoadCommonStrategy):
    """Common load strategy leveraging native torch save/load."""

    def load_common(self, checkpoint_dir: Union[str, Path]):
        """Load common (non-sharded) objects state dict from the checkpoint.

        Args:
            checkpoint_dir (Union[str, Path]): checkpoint directory

        Returns:
            StateDict: state dict with non-sharded objects from the checkpoint
        """
        load_path = os.path.join(checkpoint_dir, COMMON_STATE_FNAME)
        try:
            if MultiStorageClientFeature.is_enabled():
                msc = MultiStorageClientFeature.import_package()
                return msc.torch.load(load_path, map_location='cpu', weights_only=False)
            else:
                return torch.load(load_path, map_location='cpu', weights_only=False)
        except FileNotFoundError as e:
            err_msg = f'Common file {load_path} does not exist'
            if MultiStorageClientFeature.is_enabled():
                msc = MultiStorageClientFeature.import_package()
                ckpt_files = [f.name for f in msc.Path(checkpoint_dir).iterdir()]
            else:
                ckpt_files = [f.name for f in checkpoint_dir.iterdir()]
            logger.debug(f'{err_msg}. Checkpoint directory content: {ckpt_files}')
            raise CheckpointingException(err_msg) from e

    def load_sharded_objects(
        self, sharded_objects_state_dict: ShardedStateDict, checkpoint_dir: Union[str, Path]
    ):
        """Replaces all ShardedObject from a given state dict with values loaded from the
        checkpoint.

        Args:
            sharded_objects_state_dict (ShardedStateDict):
                sharded state dict defining what objects should be loaded.
            checkpoint_dir (Union[str, Path]): checkpoint directory

        Returns:
            None: sharded state dict is modified in place
        """

        def load_sharded_object(sh_obj: ShardedObject):
            sh_obj.data = None
            load_path = os.path.join(checkpoint_dir, f'{sh_obj.unique_key}.pt')
            try:
                if MultiStorageClientFeature.is_enabled():
                    msc = MultiStorageClientFeature.import_package()
                    loaded_obj = msc.torch.load(load_path, weights_only=False)
                else:
                    loaded_obj = torch.load(load_path, weights_only=False)
            except FileNotFoundError as e:
                # Backward compatible logic: previously the save format was incorrect
                base, _ = os.path.splitext(sh_obj.unique_key)
                old_load_path = os.path.join(checkpoint_dir, f"{base}.pt")
                try:
                    if MultiStorageClientFeature.is_enabled():
                        msc = MultiStorageClientFeature.import_package()
                        loaded_obj = msc.torch.load(old_load_path, weights_only=False)
                    else:
                        loaded_obj = torch.load(old_load_path, weights_only=False)
                except FileNotFoundError:
                    err_msg = f'Object shard {load_path} not found'
                    obj_subdir = os.path.join(checkpoint_dir, sh_obj.key)
                    if os.path.exists(obj_subdir):
                        obj_files = os.listdir(obj_subdir)
                        logger.debug(
                            f'{err_msg}. Object {sh_obj.key} directory content: {obj_files}'
                        )
                    else:
                        ckpt_files = os.listdir(checkpoint_dir)
                        logger.debug(
                            f'{err_msg}. Object {sh_obj.key} directory does not exist. Checkpoint'
                            f' directory content: {ckpt_files}'
                        )
                    raise CheckpointingException(err_msg) from e
            return loaded_obj

        return dict_list_map_inplace(load_sharded_object, sharded_objects_state_dict)

    def load_sharded_metadata(self, checkpoint_dir: Union[str, Path]) -> ShardedStateDict:
        if MultiStorageClientFeature.is_enabled():
            msc = MultiStorageClientFeature.import_package()
            checkpoint_dir = msc.Path(checkpoint_dir)
        else:
            checkpoint_dir = Path(checkpoint_dir)

        sharded_metadata = {}
        for subdir in checkpoint_dir.iterdir():
            if not subdir.is_dir():
                continue
            shard_files = list(subdir.glob('shard_*.pt'))
            if not shard_files:
                continue
            sh_objs = []
            for shard_file in shard_files:
                full_key = f'{subdir.name}/{shard_file.stem}'
                sh_objs.append(ShardedObject.empty_from_unique_key(full_key))

            # This is a backward-compatibility fix, where the last global shape is missing in the
            # name
            if sh_objs[0].global_shape[-1] < 0:
                max_last_offset = max(map(lambda sh_obj: sh_obj.global_offset[-1], sh_objs))
                for sh_obj in sh_objs:
                    sh_obj.global_shape = (*sh_obj.global_shape[:-1], max_last_offset + 1)

            # Update the sharded state dict
            for sh_obj in sh_objs:
                sharded_metadata[sh_obj.unique_key] = sh_obj
        return sharded_metadata

    @property
    def can_handle_sharded_objects(self):
        """This strategy can handle ShardedObjects."""
        return True

    def check_backend_compatibility(self, loaded_version):
        pass

    def check_version_compatibility(self, loaded_version):
        pass
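A rough usage sketch of the strategies above (directory name and payload are hypothetical, `torch.distributed` is assumed to be initialized): rank 0 writes the non-sharded objects to `common.pt`, and each main replica of a `ShardedObject` is saved under its `unique_key` as a separate `.pt` file.

# Illustrative: save and reload the common (non-sharded) part of a checkpoint.
ckpt_dir = '/tmp/my_ckpt'  # hypothetical directory, assumed to already exist
save_strategy = TorchCommonSaveStrategy('torch', 1)
save_strategy.save_common({'args': {'lr': 1e-4}, 'iteration': 1000}, ckpt_dir)

load_strategy = TorchCommonLoadStrategy()
common = load_strategy.load_common(ckpt_dir)  # -> {'args': {'lr': 1e-4}, 'iteration': 1000}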
Megatron-LM/megatron/core/dist_checkpointing/strategies/filesystem_async.py
0 → 100644
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

"""Storage writer for PyT Distributed format allowing asynchronous save."""

import dataclasses
import inspect
import logging
import os
import pickle
import queue
from functools import partial
from heapq import heappop, heappush
from itertools import chain
from operator import itemgetter
from pathlib import Path
from time import time
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from torch import multiprocessing as mp
from torch.distributed.checkpoint import FileSystemWriter
from torch.distributed.checkpoint.filesystem import DEFAULT_SUFFIX, _StoragePrefix, _write_item
from torch.distributed.checkpoint.metadata import Metadata

try:
    from torch.distributed.checkpoint.filesystem import _StorageWriterTransforms
except ImportError:
    _StorageWriterTransforms = Any

from torch.distributed.checkpoint.planner import SavePlan, SavePlanner, WriteItem, WriteItemType
from torch.distributed.checkpoint.storage import WriteResult
from torch.futures import Future

from .async_utils import _disable_gc

logger = logging.getLogger(__name__)

WriteBucket = Tuple[Path, str, Tuple[list, list]]  # represents writes to a single file

try:
    import psutil

    HAVE_PSUTIL = True
except ImportError:
    HAVE_PSUTIL = False

_results_queue = None


def _get_write_results_queue():
    global _results_queue
    if _results_queue is None:
        ctx = mp.get_context("spawn")
        _results_queue = ctx.Manager().Queue()
    return _results_queue


class FileSystemWriterAsync(FileSystemWriter):
    """
    Async-enabled implementation of FileSystemWriter using file I/O.

    This class does not spawn the async process itself but relies on an external async mechanism.

    **Flow:**
    1. Call `write_data`
    2. Externally start an async process with `get_save_function_and_args` and its arguments.
    3. The async function `writer_proxy_func` calls `write_preloaded_data` across multiple
       processes.
    4. Once saving is finalized on all ranks, call `super().finish` with the results stored
       in `self.writer_result`.

    **Note:** Step (3) can also be executed synchronously.

    Currently, it is assumed that a separate writer is created for each ckpt save
    (intermediate state is stored as writer attributes).
    """

    def __init__(
        self,
        path: Union[str, os.PathLike],
        *args,
        separation_hint: Optional[str] = None,
        use_msc: bool = False,
        **kwargs,
    ):
        self.checkpoint_dir = path
        self.use_msc = use_msc

        super().__init__(path, *args, **kwargs)
        if not self.single_file_per_rank:
            raise NotImplementedError(
                "single_file_per_rank flag not supported for FileSystemWriterAsync"
            )

        self.can_run_decentralized_global_plan: bool = True

        # Intermediate state between preparation and finalization
        self.write_buckets: Optional[List[WriteBucket]] = None
        self.results_queue: Optional[mp.Queue] = None
        self.separation_hint = separation_hint

    def prepare_write_data(self, plan: SavePlan, planner: SavePlanner) -> None:
        """
        First stage of async saving. Copy data to CPU and plan the local saving.

        Args:
            plan (SavePlan): save plan generated by the PyT Distributed compatible planner
            planner (SavePlanner): save planner used to resolve the bytes and tensor data

        Returns: None, but stores the save plan in `self.write_buckets`
        """
        storage_plan: _StoragePrefix = plan.storage_data
        start = time()
        logger.debug(f"thread_count: {self.thread_count}, time: {start}")
        if self.separation_hint:
            assert (
                self.thread_count > 1
            ), "thread_count must be at least 2 if separation_hint is provided"
        bins = self.thread_count // 2 if self.separation_hint is not None else self.thread_count
        item_buckets = _split_by_size_and_type(bins, plan.items)
        logger.debug(f"bucket_prep, time: {time() - start}")

        start = time()
        # move tensors from GPU to CPU before starting async writing
        # We do D2H synchronously for now
        file_count = 0

        def gen_file(prefix=""):
            nonlocal file_count
            file_name = f"{prefix}{storage_plan.prefix}{file_count}{DEFAULT_SUFFIX}"
            file_count += 1
            return file_name

        def _clone_if_needed(ten: torch.Tensor):
            """Clone if we detect incontiguous storage for CPU tensors.

            Makes sure we perform a `clone` only if we detect incontiguous storage,
            so that we don't blow up host memory unnecessarily.

            TODO: For persistent worker, this work should be changed to move the cpu tensor
            to shared_memory.
            """
            ten = ten.detach()
            if ten.device.type != "cpu":
                # We do D2H later when the async_request is scheduled for both sync / async
                # checkpointing
                return ten
            is_view = ten.untyped_storage().size() != ten.numel() * ten.itemsize
            return ten.clone() if is_view else ten

        # Prepare bytes / tensor data in each bucket, which will be assigned to each writer process
        self.write_buckets = []
        for group_name, group_buckets in _split_by_separation_hint(
            item_buckets, self.separation_hint
        ).items():
            for bucket in group_buckets:
                bytes_data = [
                    (item, planner.resolve_data(item))
                    for item in bucket
                    if item.type == WriteItemType.BYTE_IO
                ]
                tensor_data = [
                    (item, _clone_if_needed(planner.resolve_data(item)))
                    for item in bucket
                    if item.type != WriteItemType.BYTE_IO
                ]
                if len(bytes_data) > 0 or len(tensor_data) > 0:
                    file_name = gen_file(prefix=group_name)
                    self.write_buckets.append(
                        (  # type: ignore[arg-type]
                            os.path.join(self.checkpoint_dir, file_name),
                            file_name,
                            (bytes_data, tensor_data),
                        )
                    )

        # Check if there is anything to write on this rank
        if len(self.write_buckets) > 0:
            assert len(self.write_buckets) <= self.thread_count, (
                len(self.write_buckets),
                self.thread_count,
            )
            self.results_queue = _get_write_results_queue()
        else:
            self.results_queue = None
        end = time()
        logger.debug(f"D2H and push, time: {end - start}")

    def get_save_function_and_args(self) -> Tuple[Optional[Callable], Optional[Callable], List]:
        """
        Get function that saves the data to storage along with its arguments.
        Allows the external caller to apply the save function synchronously or asynchronously.

        Returns: None (if there is nothing to write on this rank) or a tuple of:
            1) the function that saves the data.
            2) the function that stages the GPU tensors to a destination for async checkpointing.
               This function should be self-contained.
            3) arguments to the function in 1).
        """
        if not self.write_buckets:
            return None, None, []
        transform_list = [self.transforms] if hasattr(self, "transforms") else []
        return (
            partial(self.write_preloaded_data_multiproc, transform_list, self.use_msc),
            partial(self.preload_tensors, self.write_buckets, True),
            [torch.distributed.get_rank(), self.write_buckets, self.results_queue],
        )

    @staticmethod
    def preload_tensors(write_buckets: List[WriteBucket], non_blocking=True) -> List[WriteBucket]:
        """
        Preloads tensors in `state_dict` to host memory via CPU memory.

        Args:
            write_buckets (List): List of `WriteBucket` objects that define what to
                save in a checkpoint.
            non_blocking (bool, optional): knob to enable pinned D2H memcpy. Default is True.
        """
        result = []
        for bucket in write_buckets:
            file_name, storage_key, (bytes_data, tensor_data) = bucket
            tensor_data = [
                (item, tensor.to("cpu", non_blocking=non_blocking)) for item, tensor in tensor_data
            ]
            result.append((file_name, storage_key, (bytes_data, tensor_data)))
        if non_blocking:
            torch.cuda.synchronize()
        return result

    @staticmethod
    @_disable_gc()
    def write_preloaded_data_multiproc(
        transform_list: List[_StorageWriterTransforms],
        use_msc: bool,
        rank: int,
        write_buckets: List[WriteBucket],
        global_results_queue: mp.Queue,
    ) -> None:
        """
        Performs saving data to storage with multiple processes.

        Starts a predefined number of processes and uses 2 queues to make sure the results
        are complete:
        - local_results_queue - to send the actual results
        - count_queue - small queue to mark a worker as completed

        Using just one queue disallowed proper exception handling.

        This method is meant to be run in a forked subprocess.
        Triggering GC during execution leads to CUDA errors
        (cleaning up tensors owned by the parent process).
        To prevent this, we disable the GC explicitly for this function with _disable_gc.

        Args:
            write_buckets (List[WriteBucket]): write plan
            global_results_queue (mp.Queue): mp.Queue to collect Dict[List[WriteResults]]
                (or an Exception) from parallel write processes to the main training process

        Returns: None
        """
        logger = logging.getLogger(__name__)
        w_start = time()
        write_results_or_exc: Union[dict, Exception] = dict()
        ctx = mp.get_context("fork")
        local_results_queue = ctx.Queue()
        count_queue = ctx.JoinableQueue()
        p_list = []
        for i, write_bucket in enumerate(write_buckets):
            try:
                count_queue.put(i)
                kwargs = {
                    "local_proc_idx": i,
                    "write_bucket": write_bucket,
                    "results_queue": local_results_queue,
                    "count_queue": count_queue,
                    "use_fsync": True,
                }
                if use_msc:
                    import inspect

                    # Remove the inspect after the test_async_save.py is fixed.
                    signature = inspect.signature(FileSystemWriterAsync.write_preloaded_data)
                    if len(signature.parameters) > 6:
                        kwargs["use_msc"] = use_msc
                p_list.append(
                    ctx.Process(
                        target=partial(FileSystemWriterAsync.write_preloaded_data, transform_list),
                        kwargs=kwargs,
                    )
                )
            except Exception as e:
                err_msg = f"An error is caught while a proc {i} is created, error: {e}"
                logger.error(err_msg)
                write_results_or_exc = RuntimeError(err_msg)

        if not isinstance(write_results_or_exc, Exception):
            for p in p_list:
                p.start()

            logger.debug("FileSystemWriterAsync: collecting worker results...")

            # To make sure all nodes are completed
            count_queue.join()
            # At this point, all workers completed, so the queue should have exactly
            # `len(write_buckets)` items
            for proc_idx in range(len(write_buckets)):
                try:
                    local_proc_idx, local_results_or_exc = local_results_queue.get()
                except queue.Empty:
                    write_results_or_exc = RuntimeError(
                        "Unexpected empty `local_results_queue`"
                        f" (got only {proc_idx}/{len(write_buckets)} items)"
                    )
                    break
                else:
                    if isinstance(local_results_or_exc, Exception):
                        err_msg = (
                            f"Local process {local_proc_idx} encountered"
                            f" an error: {local_results_or_exc}"
                        )
                        logger.error(err_msg)
                        write_results_or_exc = local_results_or_exc
                        break
                    assert isinstance(local_results_or_exc, list), type(local_results_or_exc)
                    write_results_or_exc[local_proc_idx] = local_results_or_exc
                    p_list[local_proc_idx].join()

            logger.debug("FileSystemWriterAsync: collected worker results successfully")

        global_results_queue.put(write_results_or_exc)

        w_end = time()
        logger.debug(f"{w_end}, rank: {rank}, write(sync,parallel): {w_end - w_start}")

    @staticmethod
    @_disable_gc()
    def write_preloaded_data(
        transform_list: List[_StorageWriterTransforms],
        local_proc_idx: int,
        write_bucket: WriteBucket,
        results_queue: mp.SimpleQueue,
        count_queue: mp.JoinableQueue,
        use_fsync: bool,
        **kwargs,
    ) -> None:
        """
        Performs actual data saving to storage.

        Args:
            local_proc_idx (int): index of a local process that performs writing
            write_bucket (WriteBucket): data to write to storage
            results_queue (mp.Queue): queue to return the write results
                to the proxy checkpoint process.
            count_queue (mp.JoinableQueue): queue to mark the worker task as completed
            use_fsync (bool): if True, calls os.fsync at the end of saving

        Returns: None, the write results are put into the `queue`
        """
        logger = logging.getLogger(__name__)
        logger.debug(f"{local_proc_idx} started")
        mem_before = _process_memory()

        use_msc = kwargs.get("use_msc", False)

        local_results = []
        try:
            file_name, storage_key, (bytes_data, tensor_data) = write_bucket
            extra_kwargs = {}
            if "serialization_format" in inspect.signature(_write_item).parameters:
                from torch.distributed.checkpoint.filesystem import SerializationFormat

                extra_kwargs["serialization_format"] = SerializationFormat.TORCH_SAVE
            if use_msc:
                import multistorageclient as msc

                open_file = msc.open
            else:
                open_file = open

            with open_file(file_name, "wb") as stream:
                for write_item, data in bytes_data:
                    local_results.append(
                        _write_item(
                            *transform_list, stream, data, write_item, storage_key, **extra_kwargs
                        )
                    )

                for write_item, tensor in tensor_data:
                    assert tensor.is_cpu
                    local_results.append(
                        _write_item(
                            *transform_list, stream, tensor, write_item, storage_key, **extra_kwargs
                        )
                    )

                if use_fsync:
                    if use_msc:
                        stream.fsync()
                    else:
                        os.fsync(stream.fileno())
            local_output = (local_proc_idx, local_results)
        except Exception as e:
            logger.debug(f"{local_proc_idx} failed")
            local_output = (local_proc_idx, e)  # type: ignore[assignment]

        results_queue.put(local_output)
        # Signal this process is done.
        count_queue.get()
        count_queue.task_done()

        mem_after = _process_memory()
        logger.debug(
            f"{local_proc_idx} consumed: {mem_after - mem_before},"
            f" before: {mem_before}, after: {mem_after}"
        )

    def write_data(self, plan: SavePlan, planner: SavePlanner) -> Future[List[WriteResult]]:
        """Write all items from ``plan``."""
        raise NotImplementedError("write_data not implemented for FileSystemWriterAsync")

    def retrieve_write_results(self) -> List[WriteResult]:
        """
        Turn the latest dict including write results from `self.results_queue`
        into a single results list. Includes error check.

        Returns (List[WriteResult]): the list of write results
            from all local processes performing the save.
        """
        assert self.write_buckets is not None

        if self.results_queue is None:
            write_results_or_exc = {}
        else:
            try:
                write_results_or_exc = self.results_queue.get_nowait()
            except queue.Empty:
                raise RuntimeError("results_queue should not be empty")

        if isinstance(write_results_or_exc, Exception):
            raise RuntimeError(f"Worker failure: {write_results_or_exc}") from write_results_or_exc
        write_results: dict = write_results_or_exc
        if len(write_results) != len(self.write_buckets):
            raise RuntimeError(
                f"Incomplete worker results (expected {len(self.write_buckets)},"
                f" got {len(write_results)}. This probably indicates a worker failure."
            )
        return list(chain.from_iterable(write_results.values()))

    def prepare_decentralized_global_plan(self, local_plan: SavePlan) -> SavePlan:
        """Instead of assigning indices by plan order, uses PyT rank (same outcome).

        Args:
            local_plan (SavePlan): local plan to turn to a global plan
                (without interactions with other ranks)

        Returns:
            SavePlan - locally transformed plan equivalent to the plan that would be
                created by the coordinator
        """
        return dataclasses.replace(
            local_plan, storage_data=_StoragePrefix(f"__{torch.distributed.get_rank()}_")
        )

    def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
        """
        Finish the checkpointing process.

        Args:
            metadata (Metadata): metadata to save
            results (List[List[WriteResult]]): results to save
        """
        if self.use_msc:
            import multistorageclient as msc

            storage_md = dict()
            for wr_list in results:
                storage_md.update({wr.index: wr.storage_data for wr in wr_list})

            metadata.storage_data = storage_md
            metadata.storage_meta = self.storage_meta()

            path = os.path.join(self.checkpoint_dir, ".metadata")
            with msc.open(path, "wb") as metadata_file:
                pickle.dump(metadata, metadata_file)
        else:
            super().finish(metadata, results)

    def prepare_local_plan(self, plan: SavePlan) -> SavePlan:
        """
        Prepare the local plan for the checkpointing process.
        """
        if self.use_msc:
            import multistorageclient as msc

            msc.os.makedirs(str(self.checkpoint_dir), exist_ok=True)
        else:
            super().prepare_local_plan(plan)
        return plan

    @property
    def checkpoint_id(self) -> Union[str, os.PathLike]:
        """
        Return the checkpoint_id that will be used to save the checkpoint.
        """
        return str(self.checkpoint_dir)

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        """
        Validate the checkpoint_id that will be used to save the checkpoint.

        This method is available in PyTorch 2.3 and above.
        """
        if checkpoint_id.startswith("msc://"):
            return True
        if hasattr(FileSystemWriter, "validate_checkpoint_id"):
            return FileSystemWriter.validate_checkpoint_id(checkpoint_id)
        return False


def _split_by_size_and_type(bins: int, items: List[WriteItem]) -> List[List[WriteItem]]:
    """
    Splits write items according to item size into close to uniform bins.

    Same as torch.distributed.checkpoint.filesystem._split_by_size_and_type,
    but with a fixed _item_size function.

    Args:
        bins (int): number of bins to split to
        items (List[WriteItem]): list of write items

    Returns (List[List[WriteItem]]): write items split to bins
    """
    if bins == 1:
        return [items]

    bytes_items: List[WriteItem] = []
    tensor_items: List[WriteItem] = []
    for wi in items:
        container = bytes_items if wi.type == WriteItemType.BYTE_IO else tensor_items
        container.append(wi)

    buckets: List[List[WriteItem]] = [[] for _ in range(bins)]
    bucket_sizes = [0 for _ in range(bins)]

    # Assign bytes with a simple round-robin
    for i, item in enumerate(bytes_items):
        buckets[i % bins].append(item)

    # Sort tensor items by size in decreasing order once and store the size with item
    sized_tensors = [(item, _item_size(item)) for item in tensor_items]
    sized_tensors.sort(key=itemgetter(1), reverse=True)

    # Use a min heap for bin assignment
    # Store (total_size_of_bin, bin_index) tuples
    heap: List[Tuple[int, int]] = [(0, i) for i in range(bins)]

    # Assign tensors using heap
    for item, size in sized_tensors:
        total_bin_size, bin_idx = heappop(heap)
        buckets[bin_idx].append(item)
        heappush(heap, (total_bin_size + size, bin_idx))

    return buckets


def _split_by_separation_hint(
    buckets: List[List[WriteItem]], separation_hint: Optional[str] = None
) -> Dict[str, List[List[WriteItem]]]:
    """
    Splits buckets into those whose keys begin with the separation_hint and those whose keys
    do not.

    Args:
        buckets (List[List[WriteItem]]): buckets to split
        separation_hint (Optional[str]): optional prefix to split on

    Returns (Dict[str, List[List[WriteItem]]]): a dictionary
        mapping the prefix to the relevant buckets
    """
    bins = len(buckets)
    buckets_with_separation_hint = {}
    if separation_hint is not None:
        buckets_default = [[] for _ in range(bins)]
        buckets_hint = [[] for _ in range(bins)]
        for i in range(bins):
            for item in buckets[i]:
                if item.index.fqn.startswith(separation_hint):
                    buckets_hint[i].append(item)
                else:
                    buckets_default[i].append(item)
        buckets_with_separation_hint[""] = buckets_default
        buckets_with_separation_hint[separation_hint] = buckets_hint
    else:
        buckets_with_separation_hint[""] = buckets
    return buckets_with_separation_hint


def _item_size(item: WriteItem) -> int:
    """
    Calculates size (in bytes) of a single write item.

    Same as torch.distributed.checkpoint.filesystem._item_size,
    but fixes computing chunk size (with item.tensor_data.chunk.sizes)

    Args:
        item (WriteItem): write item to compute the size of

    Returns (int): size of an item in bytes
    """
    size = 1
    assert item.tensor_data is not None
    # can't use math.prod as PT needs to support older python
    for s in item.tensor_data.chunk.sizes:
        size *= s

    dtype = item.tensor_data.properties.dtype
    return size * torch._utils._element_size(dtype)


def _process_memory() -> int:
    """
    Get memory used by current process.

    Returns (int): memory used by current process
    """
    if not HAVE_PSUTIL:
        raise RuntimeError("psutil is not installed, please install it with `pip install psutil`")
    process = psutil.Process(os.getpid())
    mem_info = process.memory_info()
    return mem_info.rss
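The greedy heap packing in `_split_by_size_and_type` can be illustrated on plain sizes; this standalone sketch mirrors the algorithm only and does not use real `WriteItem`s.

# Illustrative re-statement of the bin-packing step: largest items first,
# each assigned to the currently lightest bin via a min-heap.
from heapq import heappop, heappush

def pack(sizes, bins):
    heap = [(0, i) for i in range(bins)]   # (total_bytes_in_bin, bin_index)
    buckets = [[] for _ in range(bins)]
    for size in sorted(sizes, reverse=True):
        total, idx = heappop(heap)
        buckets[idx].append(size)
        heappush(heap, (total + size, idx))
    return buckets

print(pack([100, 60, 40, 30, 10], bins=2))  # -> [[100, 30], [60, 40, 10]]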
Megatron-LM/megatron/core/dist_checkpointing/strategies/fully_parallel.py
0 → 100644
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import
logging
from
pathlib
import
Path
from
time
import
time
from
typing
import
Any
,
Callable
,
Dict
,
Optional
,
Tuple
,
TypeVar
import
torch
import
torch.distributed
as
dist
from
torch.distributed.checkpoint
import
Metadata
from
megatron.core.dist_checkpointing
import
ShardedObject
,
ShardedTensor
from
megatron.core.dist_checkpointing.core
import
CheckpointingException
from
megatron.core.dist_checkpointing.dict_utils
import
(
dict_list_map_inplace
,
extract_matching_values
,
merge
,
nested_values
,
)
from
megatron.core.dist_checkpointing.exchange_utils
import
(
ShardDistribution
,
determine_main_replica_uniform_distribution
,
exchange_by_distribution
,
exchange_loaded_objects_gather_object
,
)
from
megatron.core.dist_checkpointing.mapping
import
ShardedStateDict
,
StateDict
,
is_main_replica
from
megatron.core.dist_checkpointing.strategies.base
import
(
AsyncSaveShardedStrategy
,
LoadShardedStrategy
,
SaveShardedStrategy
,
)
from
megatron.core.dist_checkpointing.utils
import
(
_sharded_object_id
,
_sharded_tensor_shard_id
,
_ShardId
,
debug_time
,
)
from
megatron.core.dist_checkpointing.validation
import
(
determine_global_metadata
,
validate_sharding_integrity
,
)
from
megatron.core.utils
import
get_pg_rank
,
get_pg_size
logger
=
logging
.
getLogger
(
__name__
)
T
=
TypeVar
(
'T'
,
ShardedObject
,
ShardedTensor
)


class FullyParallelSaveStrategyWrapper(AsyncSaveShardedStrategy):
    """Wraps arbitrary strategy and distributes the save during `save`.

    The save distribution happens without any *data* communication.
    Only the *metadata* is exchanged and based on data replication on different
    ranks, we try to distribute the save as uniformly as possible.

    This wrapper assumes that setting `replica_id` to 0 will make the
    underlying strategy do the saving on the current rank. All the other `replica_id`s
    are set to 1.

    Currently, the save distribution is realized with a greedy algorithm
    described in `distribute_shards_to_ranks`.

    Args:
        strategy (SaveShardedStrategy): base strategy to wrap
        parallelization_group (ProcessGroup, optional): process group to use for save
            distribution. Note that this doesn't have to match exactly the
            data distribution, but should cover the replication pattern
            to maximize performance. Defaults to the whole world.
        do_cache_distribution (bool, optional): whether to cache the save distribution
            from previous calls. Should be set to True only if the state dict
            structure between the calls is always the same. Defaults to False.
    """

    def __init__(
        self,
        strategy: SaveShardedStrategy,
        parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
        do_cache_distribution: bool = False,
    ):
        super().__init__(strategy.backend, strategy.version)
        self.base_strategy = strategy
        if parallelization_group is None:
            parallelization_group = torch.distributed.group.WORLD
        self.parallelization_group = parallelization_group
        self.do_cache_distribution = do_cache_distribution

        self.cached_distribution: Optional[ShardDistribution] = None

    def async_save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
        if not isinstance(self.base_strategy, AsyncSaveShardedStrategy):
            raise CheckpointingException(
                f'Cannot apply async_save to non-async base strategy {self.base_strategy}'
            )
        self.apply_saving_parallelization(sharded_state_dict)
        return self.base_strategy.async_save(sharded_state_dict, checkpoint_dir)

    def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
        self.apply_saving_parallelization(sharded_state_dict)
        return self.base_strategy.save(sharded_state_dict, checkpoint_dir)

    def apply_saving_parallelization(self, sharded_state_dict: ShardedStateDict) -> None:
        """Distributes the save across ranks by exchanging metadata.

        Exchanges metadata from the state dict and computes the uniform
        (as close as possible) distribution of saves among the ranks.

        If `self.do_cache_distribution` is True, caches the distribution between
        the calls and subsequent distributions happen without any inter-rank
        communication.

        Args:
            sharded_state_dict (ShardedStateDict): state dict to distribute the saving

        Returns: None
        """
        start = time()
        if self.do_cache_distribution and self.cached_distribution is not None:
            logger.debug('Apply *cached* save parallelization')
            precomputed_distribution = self.cached_distribution
        else:
            logger.debug('Apply save parallelization')
            precomputed_distribution = determine_main_replica_uniform_distribution(
                sharded_state_dict, self.parallelization_group
            )

        distribute_main_replicas_with_precomputed_distribution(
            sharded_state_dict, self.parallelization_group, precomputed_distribution
        )
        if self.cached_distribution is None:
            # First time applying the parallelization
            validate_sharding_integrity(determine_global_metadata(sharded_state_dict)[1])
        if self.do_cache_distribution:
            self.cached_distribution = precomputed_distribution
        end = time()
        logger.debug(f"parallel save sharding, time: {end - start}")

    @property
    def can_handle_sharded_objects(self):
        return self.base_strategy.can_handle_sharded_objects
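

# Illustrative usage sketch for the wrapper above (assumes torch.distributed is already
# initialized; `TorchDistSaveShardedStrategy` comes from strategies/torch.py later in this
# commit, and the sharded state dict and checkpoint directory are built by the caller).
def _example_wrap_save_strategy(
    sharded_state_dict: ShardedStateDict, checkpoint_dir: Path
) -> None:
    from megatron.core.dist_checkpointing.strategies.torch import TorchDistSaveShardedStrategy

    base_strategy = TorchDistSaveShardedStrategy('torch_dist', 1)
    wrapped = FullyParallelSaveStrategyWrapper(
        base_strategy,
        parallelization_group=torch.distributed.group.WORLD,  # ideally the DP group
        do_cache_distribution=True,  # only if the state dict structure is stable across saves
    )
    # Each rank now writes only the shards for which the greedy distribution set
    # replica_id == 0; all other replicas are skipped by the base strategy.
    wrapped.save(sharded_state_dict, checkpoint_dir)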


class FullyParallelLoadStrategyWrapper(LoadShardedStrategy):
    """Wraps arbitrary load strategy and distributes the load during `load`.

    See `load` method docs for details.

    Args:
        strategy (LoadShardedStrategy): base strategy to wrap
        parallelization_group (ProcessGroup, optional): process group to use for load
            distribution. Note that this doesn't have to match exactly the
            data distribution, but should cover the replication pattern
            to maximize performance. Defaults to the whole world.
            In most cases, it's recommended to set it to the DP group.
        do_cache_distribution (bool, optional): whether to cache the load distribution
            from previous calls. Should be set to True only if the state dict
            structure between the calls is always the same. Defaults to False,
            since the loading in general happens only once during training.
            Note that the load distribution *cannot* be reused as a save distribution,
            because save/load is not fully symmetrical.
        exchange_algo (str): algorithm to use for exchanging the data.
            Options:
            - broadcast (default) - each rank broadcasts individual tensors to others
            - gather_object - ranks all_gather_object the whole loaded state dicts
            - gather_rounds - ranks all gather individual tensors in rounds
            See method docs for more details.
    """

    def __init__(
        self,
        strategy: LoadShardedStrategy,
        parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
        do_cache_distribution: bool = False,
        exchange_algo: str = 'broadcast',
    ):
        super().__init__()
        self.base_strategy = strategy
        if parallelization_group is None:
            parallelization_group = (
                dist.GroupMember.WORLD
            )  # explicit group needed for torch.distributed.get_global_rank call
        self.parallelization_group = parallelization_group
        self.do_cache_distribution = do_cache_distribution
        self.exchange_algo = exchange_algo

        self.cached_distribution: Optional[ShardDistribution] = None
        self.cached_global_metadata: Optional[Metadata] = None

    @debug_time("FullyParallelLoadStrategyWrapper.load", logger)
    def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> StateDict:
        """Distributes the load and calls underlying strategy only for parts of the state dict.

        Steps:
        1. Load metadata is exchanged between the ranks in the parallelization group.
        2. Each rank deterministically plans the load for the whole workload
            so that the loads are as uniform as possible.
        3. Each rank loads its planned shard of the checkpoint.
        4. All ranks exchange the loaded shards.

        Internode communication is involved in steps (1) (with metadata)
        and (4) (with actual data). Storage interaction is involved in step (3).

        Currently, the load distribution (step 2) is realized with a greedy algorithm
        described in `distribute_shards_to_ranks` (same as for saving distribution).

        Currently, the shards are all gathered between all ranks in the parallelization
        group. This might not be optimal (some ranks do not need all tensors),
        but it's a reasonable approximation for an optimal exchange in most scenarios.

        Args:
            sharded_state_dict (ShardedStateDict): sharded state dict to load
            checkpoint_dir (Path): checkpoint directory to load from

        Returns:
            StateDict: loaded state dict. The state dict should be equivalent to
            a state dict that would be loaded with the underlying strategy
            without this wrapper.
        """
        loaded_state_dict = {}

        if get_pg_size(self.parallelization_group) <= 1:
            return self.base_strategy.load(sharded_state_dict, checkpoint_dir)

        # Step 1 and 2: exchange load metadata and distribute the load
        with debug_time("self.apply_loading_parallelization", logger):
            precomputed_distribution: ShardDistribution | None = self.apply_loading_parallelization(
                sharded_state_dict
            )
            assert (
                precomputed_distribution is not None
            ), 'Expecting non-trivial distribution for non-trivial parallelization group'

        # Step 3: load part of the checkpoint.
        # Load only sharded objects first. ShardedTensors will be loaded separately
        # so that we can keep track of sharded tensors loaded by this rank
        (sharded_tensors, sharded_state_dict, to_load_shards, unloaded_shards) = (
            self._defer_loading_sharded_tensors(sharded_state_dict)
        )
        (sharded_objects, sharded_state_dict, to_load_objects, unloaded_objects) = (
            self._defer_loading_sharded_objects(sharded_state_dict)
        )
        assert (
            len(sharded_state_dict) == 0
        ), "sharded_state_dict is not empty after deferring tensors and objects"

        with debug_time("base_load_ShardedObjects", logger):
            # Load sharded objects first
            loaded_objects = self.base_strategy.load(to_load_objects, checkpoint_dir)

        with debug_time("base_load_ShardedTensors", logger):
            # Load sharded tensors separately
            loaded_tensors = self.base_strategy.load(to_load_shards, checkpoint_dir)

        with debug_time("self.exchange_loaded_tensors", logger):
            # Step 4: exchange data between ranks
            logger.debug(f'Applying parallel load with algo {self.exchange_algo}')
            all_loaded_tensors = exchange_by_distribution(
                loaded_tensors,
                unloaded_shards,
                precomputed_distribution,
                self.parallelization_group,
                self.exchange_algo,
            )
            if not set(unloaded_shards.keys()).issubset(all_loaded_tensors.keys()):
                missing_shards = set(unloaded_shards.keys()) - all_loaded_tensors.keys()
                raise CheckpointingException(
                    f'Missing shards after fully parallel loading: {missing_shards}'
                )

        with debug_time("torch.cuda.synchronize", logger):
            torch.cuda.synchronize()

        all_loaded_objects = exchange_loaded_objects_gather_object(loaded_objects)

        if not set(unloaded_objects.keys()).issubset(all_loaded_objects.keys()):
            missing_object_shards = set(unloaded_objects.keys()) - all_loaded_objects.keys()
            raise CheckpointingException(
                f'Missing object shards after fully parallel loading: {missing_object_shards}'
            )
        torch.cuda.synchronize()

        self.fill_in_deferred_sharded_tensors(sharded_tensors, all_loaded_tensors)
        self.fill_in_deferred_sharded_objects(sharded_objects, all_loaded_objects)

        merge(loaded_state_dict, sharded_objects)
        merge(loaded_state_dict, sharded_tensors)
        if hasattr(self.base_strategy, "cached_global_metadata"):
            self.cached_global_metadata = self.base_strategy.cached_global_metadata
        return loaded_state_dict

    @staticmethod
    def _defer_loading_sharded_objects(
        sharded_state_dict: ShardedStateDict,
    ) -> Tuple[
        ShardedStateDict,
        ShardedStateDict,
        Dict[_ShardId, ShardedObject],
        Dict[_ShardId, ShardedObject],
    ]:
        return _defer_loading_sharded_items(sharded_state_dict, ShardedObject, _sharded_object_id)

    @staticmethod
    def _defer_loading_sharded_tensors(
        sharded_state_dict: ShardedStateDict,
    ) -> Tuple[
        ShardedStateDict,
        ShardedStateDict,
        Dict[_ShardId, ShardedTensor],
        Dict[_ShardId, ShardedTensor],
    ]:
        return _defer_loading_sharded_items(
            sharded_state_dict, ShardedTensor, _sharded_tensor_shard_id
        )

    @staticmethod
    def fill_in_deferred_sharded_objects(
        sharded_state_dict: ShardedStateDict, loaded_objects: Dict[_ShardId, Any]
    ) -> None:
        """Fill in objects not loaded by current rank with objects from `loaded_objects` map.

        Args:
            sharded_state_dict (ShardedStateDict): sharded state dict to fill in.
                ShardedObjects are completely replaced with corresponding objects.
            loaded_objects (Dict[_ShardId, Any]): dict allowing to map
                ShardedObject from the sharded_state_dict to loaded objects.

        Returns:
            None
        """
        _fill_in_deferred_sharded_items(
            sharded_state_dict, loaded_objects, ShardedObject, _sharded_object_id
        )

    @staticmethod
    def fill_in_deferred_sharded_tensors(
        sharded_state_dict: ShardedStateDict, loaded_tensors: Dict[_ShardId, torch.Tensor]
    ) -> None:
        """Fill in tensors not loaded by current rank with tensors from `loaded_tensors` map.

        Args:
            sharded_state_dict (ShardedStateDict): sharded state dict to fill in.
                ShardedTensors are completely replaced with corresponding torch.Tensors.
            loaded_tensors (Dict[_ShardId, torch.Tensor]): dict allowing to map
                ShardedTensor from the sharded_state_dict to loaded tensors.

        Returns:
            None
        """
        _fill_in_deferred_sharded_items(
            sharded_state_dict, loaded_tensors, ShardedTensor, _sharded_tensor_shard_id
        )

    def apply_loading_parallelization(
        self, sharded_state_dict: ShardedStateDict
    ) -> Optional[ShardDistribution]:
        """Distributes the load across ranks by exchanging metadata.

        Exchanges metadata from the state dict and computes the uniform
        (as close as possible) distribution of loads among the ranks.
        Marks ShardedTensors to be loaded by the current rank with replica_id 0
        (and others with non 0 values).

        If `self.do_cache_distribution` is True, caches the distribution between
        the calls and subsequent distributions happen without any inter-rank
        communication.

        Args:
            sharded_state_dict (ShardedStateDict): state dict to distribute the loading

        Returns:
            ShardDistribution (optional): the computed loading distribution
        """
        if self.do_cache_distribution and self.cached_distribution is not None:
            logger.debug('Apply *cached* load parallelization')
            precomputed_distribution = self.cached_distribution
        else:
            logger.debug('Apply load parallelization')
            precomputed_distribution = determine_main_replica_uniform_distribution(
                sharded_state_dict, self.parallelization_group, True
            )

        distribute_main_replicas_with_precomputed_distribution(
            sharded_state_dict, self.parallelization_group, precomputed_distribution
        )
        if self.do_cache_distribution:
            self.cached_distribution = precomputed_distribution

        return precomputed_distribution

    @property
    def can_handle_sharded_objects(self):
        return self.base_strategy.can_handle_sharded_objects

    def load_tensors_metadata(self, checkpoint_dir: Path):
        return self.base_strategy.load_tensors_metadata(checkpoint_dir)

    def load_sharded_metadata(self, checkpoint_dir: Path):
        return self.base_strategy.load_sharded_metadata(checkpoint_dir)

    def check_backend_compatibility(self, loaded_version):
        return self.base_strategy.check_backend_compatibility(loaded_version)

    def check_version_compatibility(self, loaded_version):
        return self.base_strategy.check_version_compatibility(loaded_version)
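

# Illustrative usage sketch for the load wrapper above (assumes torch.distributed is already
# initialized; `TorchDistLoadShardedStrategy` is the `torch_dist` load strategy registered in
# strategies/torch.py later in this commit).
def _example_wrap_load_strategy(
    sharded_state_dict: ShardedStateDict, checkpoint_dir: Path
) -> StateDict:
    from megatron.core.dist_checkpointing.strategies.torch import TorchDistLoadShardedStrategy

    wrapped = FullyParallelLoadStrategyWrapper(
        TorchDistLoadShardedStrategy(),
        parallelization_group=torch.distributed.group.WORLD,  # typically the DP group
        exchange_algo='broadcast',
    )
    # Each rank reads only its assigned main replicas from storage (step 3) and receives the
    # remaining tensors from other ranks via the chosen exchange algorithm (step 4).
    return wrapped.load(sharded_state_dict, checkpoint_dir)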


def distribute_main_replicas_with_precomputed_distribution(
    sharded_state_dict: ShardedStateDict,
    parallelization_group: torch.distributed.ProcessGroup,
    precomputed_distribution: Optional[ShardDistribution],
):
    """Applies the save distribution computed with `determine_main_replica_uniform_distribution`.

    Based on rank assignment, sets replica ids of the shards saved by current rank to 0
    and all the other replica ids to 1.

    Args:
        sharded_state_dict (ShardedStateDict): state dict to apply the save distribution to
        parallelization_group (ProcessGroup): distribution will be applied within this
            process group. Must match with the process group passed to
            `determine_main_replica_uniform_distribution`.
        precomputed_distribution (ShardDistribution): distribution computed with
            `determine_main_replica_uniform_distribution`

    Returns: None

    Example replica ids of tensors A, B, C before distribution:
        rank0: A: (0, 0, 0), B: (0, 0, 0), C: (0, 0, 0)
        rank1: A: (0, 0, 1), B: (0, 0, 1), C: (0, 0, 1)
        rank2: A: (0, 0, 2), B: (0, 0, 2), C: (0, 0, 2)

    Replicas after distribution for the example above:
        rank0: A: 0, B: 1, C: 1
        rank1: A: 1, B: 0, C: 1
        rank2: A: 1, B: 1, C: 0
    """
    if parallelization_group is None:
        parallelization_group = torch.distributed.group.WORLD
    if get_pg_size(group=parallelization_group) <= 1:
        return
    if precomputed_distribution is None:
        raise ValueError(
            'precomputed_distribution must be not None for non-trivial parallelization group'
        )

    local_shards = list(
        sh_base
        for sh_base in nested_values(sharded_state_dict)
        if isinstance(sh_base, ShardedTensor)
    )

    rank_within_dp_group = get_pg_rank(group=parallelization_group)
    for sh_ten in local_shards:
        shard_id = _sharded_tensor_shard_id(sh_ten)
        if (
            shard_id in precomputed_distribution.shards_in_this_group
            and rank_within_dp_group == precomputed_distribution.main_rank_for_shard[shard_id]
        ):
            sh_ten.replica_id = 0
        else:
            sh_ten.replica_id = 1


def _defer_loading_sharded_items(
    sharded_state_dict: ShardedStateDict, item_type: type, shard_id_func: Callable[[T], _ShardId]
) -> Tuple[ShardedStateDict, ShardedStateDict, Dict[_ShardId, T], Dict[_ShardId, T]]:
    """Divides state dict into parts loaded by this vs other ranks.

    Args:
        sharded_state_dict (ShardedStateDict): state dict with sharded items
            that will be divided.
        item_type: The type of sharded item (ShardedObject or ShardedTensor)
        shard_id_func: Function to get the shard ID for the item type

    Returns: a tuple of:
        - ShardedStateDict: sub-state dict only with sharded items
        - ShardedStateDict: sub-state dict with non-sharded items
        - Dict[_ShardId, T]: mapping from shard id to items loaded by *this* rank
        - Dict[_ShardId, T]: mapping from shard id to items loaded by *other* ranks
    """
    to_load_shards = {}
    unloaded_shards = {}

    sharded_items, remaining_state_dict = extract_matching_values(
        sharded_state_dict, lambda v: isinstance(v, item_type)
    )

    def wrap_non_main_replicas(x: Any) -> Any:
        if isinstance(x, item_type):
            shard_id = shard_id_func(x)
            if is_main_replica(x.replica_id):
                to_load_shards[shard_id] = x
            else:
                unloaded_shards[shard_id] = x
        return x

    dict_list_map_inplace(wrap_non_main_replicas, sharded_items)
    return sharded_items, remaining_state_dict, to_load_shards, unloaded_shards


def _fill_in_deferred_sharded_items(
    sharded_state_dict: ShardedStateDict,
    loaded_items: Dict[_ShardId, Any],
    item_type: type,
    shard_id_func: Callable[[T], _ShardId],
) -> None:
    """Helper function to fill in items not loaded by current rank."""

    def fill_in_sharded_item(x: Any) -> Any:
        if isinstance(x, item_type):
            try:
                x = loaded_items[shard_id_func(x)]
            except KeyError as e:
                raise CheckpointingException(
                    f'Missing loaded item shard: {shard_id_func(x)}'
                ) from e
        return x

    dict_list_map_inplace(fill_in_sharded_item, sharded_state_dict)
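

# Small sketch of how `_defer_loading_sharded_items` splits a state dict by replica ownership.
# The keys and CPU tensors below are hypothetical; in real use the items come from a model's
# sharded state dict and replica ids are set by
# `distribute_main_replicas_with_precomputed_distribution`.
def _example_defer_split() -> None:
    sd = {
        'weight': ShardedTensor.from_rank_offsets('weight', torch.zeros(4), replica_id=0),
        'bias': ShardedTensor.from_rank_offsets('bias', torch.zeros(4), replica_id=1),
    }
    items, rest, to_load, unloaded = _defer_loading_sharded_items(
        sd, ShardedTensor, _sharded_tensor_shard_id
    )
    # `to_load` holds the main replicas (replica_id == 0) this rank will read from storage;
    # `unloaded` holds the shards that will arrive from other ranks during the exchange step.
    assert len(to_load) == 1 and len(unloaded) == 1 and not rest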
Megatron-LM/megatron/core/dist_checkpointing/strategies/resharding.py
0 → 100644
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

"""Performant resharding of flattened tensors.

Tensors that are first sharded (e.g. across TP) and then flattened cause
very irregular access patterns during loading. The idea for performant save/load
is to store tensors with global shape [X, Y, Z] and local shape [x, y, z]
as tensors with global shape [X // x, Y // y, Z // z, x * y * z] and
local shape [1, 1, 1, x * y * z]. This allows parallel save of tensors along the
last (flattened) dimension. During loading, some additional resharding is needed.
"""

import logging
import math
from dataclasses import dataclass
from itertools import product
from typing import Any, Dict, Tuple, Union

import numpy as np
import torch
from torch.distributed.checkpoint import ChunkStorageMetadata
from torch.distributed.checkpoint.resharding import _shards_get_overlap_region_wrt_saved_tensor

from megatron.core.dist_checkpointing import ShardedTensor
from megatron.core.dist_checkpointing.core import CheckpointingException
from megatron.core.dist_checkpointing.dict_utils import (
    dict_list_map_inplace,
    extract_matching_values,
)
from megatron.core.dist_checkpointing.mapping import (
    ShardedStateDict,
    ShardedTensorFactory,
    StateDict,
    apply_factories,
    apply_factory_merges,
)

logger = logging.getLogger(__name__)


@dataclass
class TensorReformulationMetadata:
    """Metadata needed to restore the original tensor shape.

    Args:
        ckpt_orig_global_shape (Tuple[int, ...]): original global shape of the tensor
            saved in the checkpoint. This is the global shape of the application,
            further reformulated into `ckpt_reform_global_shape` while saving.
        ckpt_reform_global_shape (Tuple[int, ...]): reformulated global shape of the tensor
            saved in the checkpoint. This is the actual saved shape.
    """

    ckpt_orig_global_shape: Tuple[int, ...]
    ckpt_reform_global_shape: Tuple[int, ...]

    def __post_init__(self):
        assert self.ckpt_orig_global_shape


def nd_flattened_tensor_reformulated_global_shape(sh_ten: ShardedTensor) -> Tuple[int, ...]:
    """Reformulated global shape of the flattened N-D ShardedTensor.

    N-D tensor global shape [X, Y, Z] and local shape [x, y, z]
    is reformulated into global shape [X // x, Y // y, Z // z, x * y * z] and
    local shape [1, 1, 1, x * y * z], to allow parallel save of tensors along the
    last (flattened) dimension.

    Args:
        sh_ten (ShardedTensor): flattened N-D ShardedTensor (N > 1)

    Returns:
        Tuple[int, ...]: reformulated tensor shape
    """
    assert is_nd_flattened_tensor(sh_ten), sh_ten
    return sh_ten.axis_fragmentations + (int(np.prod(sh_ten.local_shape)),)


def is_nd_flattened_tensor(sh_ten: Any) -> bool:
    """Checks if ShardedTensor is flattened and more than 1-dimensional.

    Args:
        sh_ten (Any): any object

    Returns:
        bool: whether the given object is a flattened ShardedTensor and is N-dimensional (N > 1)
    """
    return isinstance(sh_ten, ShardedTensor) and sh_ten.flattened_range is not None
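

# Worked example of the reformulation arithmetic above (hypothetical shapes): a tensor with
# global shape [8, 16] and local shape [4, 8] has axis_fragmentations == (2, 2), so its
# flattened form is stored with the reformulated global shape (2, 2, 32), i.e. one flattened
# row of 4 * 8 = 32 elements per local chunk.
def _example_reformulated_shape(
    global_shape: Tuple[int, ...] = (8, 16), local_shape: Tuple[int, ...] = (4, 8)
) -> Tuple[int, ...]:
    axis_fragmentations = tuple(g // l for g, l in zip(global_shape, local_shape))
    return axis_fragmentations + (int(np.prod(local_shape)),)  # -> (2, 2, 32)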


# information needed to restore. With current implementation, this is a nested state dict
# with ShardedTensorFactories which is basically a ShardedStateDict type
ReformulationRestoreMetadata = ShardedStateDict


def apply_nd_flattened_tensors_reformulation(
    sharded_state_dict: ShardedStateDict,
    reformulation_metadata: Dict[str, TensorReformulationMetadata],
) -> Tuple[ShardedStateDict, ReformulationRestoreMetadata]:
    """Applies N-D reformulation to a given sharded state dict.

    After applying the method and loading the reformulated state dict,
    the `restore_nd_flattened_tensors_formulation` needs to be applied.

    Current implementation uses ShardedTensorFactories for convenience of
    restoring the original structure, but it's just an implementation detail.
    Turns N-D ShardedTensors into factories and immediately applies them,
    keeping the data needed to restore the original structure.

    Args:
        sharded_state_dict (ShardedStateDict): sharded state dict potentially
            with tensors to reformulate.
        reformulation_metadata (Dict[str, TensorReformulationMetadata]): dict
            containing all metadata needed for reformulating tensors in `sharded_state_dict`.
            For each N-D flattened tensor `sh_ten` in `sharded_state_dict` there must be an
            entry with `sh_ten.key`.

    Returns:
        tuple:
            ShardedStateDict - reformulated sharded state dict
            ReformulationRestoreMetadata - data needed to restore the original formulation
                with `restore_nd_flattened_tensors_formulation`
    """

    def maybe_reformulate_nd_flattened_tensor(sh_ten: Any):
        if not isinstance(sh_ten, ShardedTensor) or not is_nd_flattened_tensor(sh_ten):
            return sh_ten
        # N-D flattened ShardedTensor
        try:
            sh_ten_reformulation_metadata = reformulation_metadata[sh_ten.key]
        except KeyError as e:
            # Handle legacy checkpointing where 1-D flatten tensor metadata was not saved
            if len(sh_ten.global_shape) == 1:
                return sh_ten
            raise CheckpointingException(
                f"Missing reformulation metadata for tensor {sh_ten}. "
                f"Existing keys: {reformulation_metadata.keys()}"
            ) from e

        ckpt_actual_saved_shape = sh_ten_reformulation_metadata.ckpt_reform_global_shape
        app_actual_load_shape = nd_flattened_tensor_reformulated_global_shape(sh_ten)
        if ckpt_actual_saved_shape == app_actual_load_shape:
            # Same shape - no need to reshard
            return sh_ten

        return reformulate_single_nd_flattened_tensor(sh_ten, sh_ten_reformulation_metadata)

    # Turn N-D tensors into factories and immediately apply them
    dict_list_map_inplace(maybe_reformulate_nd_flattened_tensor, sharded_state_dict)
    sh_ten_factories, _ = extract_matching_values(
        sharded_state_dict,
        lambda x: isinstance(x, ShardedTensorFactory),
        return_lists_as_dicts=True,
    )
    apply_factories(sharded_state_dict)

    # Unlink `data` pointers to free memory
    def unlink_data(x):
        x.data = None
        return x

    dict_list_map_inplace(unlink_data, sh_ten_factories)
    return sharded_state_dict, sh_ten_factories


def restore_nd_flattened_tensors_formulation(
    state_dict: StateDict, formulation_restore_metadata: ReformulationRestoreMetadata
) -> StateDict:
    """Restores the original state dict from a reformulated form.

    Inverse of `apply_nd_flattened_tensors_reformulation`.

    Args:
        state_dict (StateDict): state dict obtained by loading a reformulated
            sharded state dict.
        formulation_restore_metadata (ReformulationRestoreMetadata): metadata returned by
            `apply_nd_flattened_tensors_reformulation` function

    Returns:
        StateDict: state dict with the original tensors formulation restored
    """
    return apply_factory_merges(state_dict, formulation_restore_metadata)


def reformulate_single_nd_flattened_tensor(
    sh_ten: ShardedTensor, reformulation_metadata: TensorReformulationMetadata
) -> Union[Any, ShardedTensorFactory]:
    """Reformulates shapes of a single N-D flattened ShardedTensor.

    We need to define a pair of transformations:
    - turn N-D ShardedTensor with original formulation into multiple reformulated ShardedTensors
    - merge multiple reformulated loaded torch.Tensors into a single original tensor

    Current implementation uses ShardedTensorFactories as a convenient mechanism
    for specifying and keeping track of those transformations.

    Args:
        sh_ten (ShardedTensor): sharded tensor to reformulate.
        reformulation_metadata (TensorReformulationMetadata): metadata needed to
            perform the reformulation

    Returns:
        ShardedTensorFactory: factory that keeps information how to reformulate
            (build) the ShardedTensor and then restore original formulation (merge)
            after loading.
    """
    rmd = reformulation_metadata
    # Data won't be needed - remove unnecessary tensor references
    sh_ten = sh_ten.without_data()

    # Based on reformulation_metadata, determine other tensor shapes and metadata
    ckpt_axis_fragmentation = rmd.ckpt_reform_global_shape[:-1]
    for sh, fragm in zip(rmd.ckpt_orig_global_shape, ckpt_axis_fragmentation):
        assert sh % fragm == 0, (sh_ten, rmd.ckpt_reform_global_shape)
    ckpt_local_shape_with_prepended_axis = tuple(
        sh // fragm for sh, fragm in zip(rmd.ckpt_orig_global_shape, ckpt_axis_fragmentation)
    )
    assert (
        ckpt_local_shape_with_prepended_axis[: sh_ten.prepend_axis_num]
        == (1,) * sh_ten.prepend_axis_num
    ), (ckpt_local_shape_with_prepended_axis, sh_ten)
    ckpt_local_shape = ckpt_local_shape_with_prepended_axis[sh_ten.prepend_axis_num :]

    # Iterate over reformulated shapes needed by the application and from checkpoint,
    # and generate new ShardedTensors that match the checkpoint sharding.
    overlap_dim_offsets = []
    assert len(ckpt_axis_fragmentation) == len(sh_ten.axis_fragmentations), (
        ckpt_axis_fragmentation,
        sh_ten,
    )
    for dim, (app_chunk_dim_offset, ckpt_fragm, app_fragm) in enumerate(
        zip(
            sh_ten.local_chunk_offset_in_global(),
            ckpt_axis_fragmentation,
            sh_ten.axis_fragmentations,
        )
    ):
        # without `int`, it's an exact offset of the app shard expressed in ckpt_local_shape units
        first_overlap_dim_offset = int(ckpt_fragm / app_fragm * app_chunk_dim_offset)
        # `math.ceil` argument is an exact offset of the app next shard expressed
        # in ckpt_local_shape units
        next_overlap_dim_offset = math.ceil(ckpt_fragm / app_fragm * (app_chunk_dim_offset + 1))
        overlap_dim_offsets.append(range(first_overlap_dim_offset, next_overlap_dim_offset))

    logger.debug(
        f"Generated the following number of overlap shards for each dimension: "
        f"{list(map(len, overlap_dim_offsets))} for fragmentation ckpt "
        f"{ckpt_axis_fragmentation} vs app {sh_ten.axis_fragmentations} "
        f"and chunk offset {sh_ten.local_chunk_offset_in_global()}"
    )
    reformulated_sh_tens = {}
    for chunk_offset in product(*overlap_dim_offsets):
        global_offset = tuple(
            chunk_off * chunk_shape
            for chunk_off, chunk_shape in zip(chunk_offset, ckpt_local_shape_with_prepended_axis)
        )
        reformulated_sh_tens[(global_offset, ckpt_local_shape)] = ShardedTensor(
            sh_ten.key,
            None,
            sh_ten.dtype,
            ckpt_local_shape,
            rmd.ckpt_orig_global_shape,
            global_offset,
            ckpt_axis_fragmentation,
            sh_ten.replica_id,
            sh_ten.prepend_axis_num,
            sh_ten.allow_shape_mismatch,
            flattened_range=slice(0, rmd.ckpt_reform_global_shape[-1]),  # whole ckpt shard
        )

    # Now, we have to define the transformations from application sharding
    # to checkpoint sharding.

    @torch.no_grad()
    def sh_ten_build_fn(*args, **kwargs):
        # Here we simply return the precomputed tensors.
        return reformulated_sh_tens

    @torch.no_grad()
    def sh_ten_merge_fn(sub_state_dict):
        # This is the non-flattened local tensor with original formulation
        # that we are going to fill with shards loaded from the checkpoint.
        app_non_flat_ten = torch.empty(
            sh_ten.local_shape,
            dtype=sh_ten.dtype,
            device=sh_ten.data.device if sh_ten.data is not None else None,
        )

        assert len(sub_state_dict) > 0
        for (ckpt_global_offset, ckpt_local_shape), ckpt_ten in sub_state_dict.items():
            # For each ckpt shard, we fill the appropriate application shard part
            dest_ten = app_non_flat_ten
            src_ten = ckpt_ten.view(ckpt_local_shape)
            # We don't need narrowing over `prepend_axis_num` axes so we take
            # the [sh_ten.prepend_axis_num:] offsets slice
            for (
                dim,
                offset_for_saved_tensor,
                offset_for_current_tensor,
                length,
            ) in _shards_get_overlap_region_wrt_saved_tensor(
                saved_shard=ChunkStorageMetadata(
                    ckpt_global_offset[sh_ten.prepend_axis_num :], ckpt_local_shape
                ),
                current_shard=ChunkStorageMetadata(
                    sh_ten.global_offset[sh_ten.prepend_axis_num :], sh_ten.local_shape
                ),
            ):
                src_ten = src_ten.narrow(dim, offset_for_saved_tensor, length)
                dest_ten = dest_ten.narrow(dim, offset_for_current_tensor, length)
            dest_ten.copy_(src_ten)
        return app_non_flat_ten.flatten()[sh_ten.flattened_range]

    return ShardedTensorFactory(
        sh_ten.key,
        sh_ten.data,
        sh_ten_build_fn,
        sh_ten_merge_fn,
        sh_ten.replica_id,
        sh_ten.flattened_range,
    )
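

# Sketch of the intended call order for the two public functions above during loading (the
# torch_dist load strategy in strategies/torch.py below performs these steps internally;
# `load_fn` is a stand-in for any load strategy call).
def _example_reformulation_round_trip(
    sharded_state_dict: ShardedStateDict,
    reformulation_metadata: Dict[str, TensorReformulationMetadata],
    load_fn,
) -> StateDict:
    # 1. Replace N-D flattened tensors with checkpoint-shaped shards.
    sharded_state_dict, restore_metadata = apply_nd_flattened_tensors_reformulation(
        sharded_state_dict, reformulation_metadata
    )
    # 2. Load the reformulated state dict with the underlying strategy.
    state_dict = load_fn(sharded_state_dict)
    # 3. Merge the loaded checkpoint shards back into the application's formulation.
    return restore_nd_flattened_tensors_formulation(state_dict, restore_metadata)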
Megatron-LM/megatron/core/dist_checkpointing/strategies/state_dict_saver.py
0 → 100644
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

""" State dict saver for PyT Distributed format allowing asynchronous save. """

from dataclasses import fields
from logging import getLogger
from time import time
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
from torch.distributed.checkpoint import CheckpointException
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE, Metadata
from torch.distributed.checkpoint.planner import SavePlan, SavePlanner
from torch.distributed.checkpoint.utils import _DistWrapper, _get_failure_dict

if TYPE_CHECKING:
    from .filesystem_async import FileSystemWriterAsync
    from .torch import MCoreSavePlanner

logger = getLogger(__name__)


def _compare_dataclasses(obj1, obj2):
    if type(obj1) != type(obj2):
        return f"Objects are of different types: {type(obj1)} and {type(obj2)}"

    differences = []
    for field in fields(obj1):
        value1 = getattr(obj1, field.name)
        value2 = getattr(obj2, field.name)
        if value1 != value2:
            differences.append(f"{field.name}: {value1} != {value2}")

    return differences if differences else "All fields are equal"


def save_state_dict_async_plan(
    state_dict: STATE_DICT_TYPE,
    storage_writer: 'FileSystemWriterAsync',
    process_group: Optional[dist.ProcessGroup] = None,
    coordinator_rank: int = 0,
    planner: Optional[Union[SavePlanner, 'MCoreSavePlanner']] = None,
    cached_ckpt_structure: Optional[Tuple[SavePlan, SavePlan, bool]] = None,
    loaded_all_plans: Optional[List[SavePlan]] = None,
) -> Tuple[
    Tuple['FileSystemWriterAsync', Union[Metadata, None], _DistWrapper],
    SavePlan,
    SavePlan,
    bool,
    bool,
]:
    """
    First stage of saving a state dict to storage.

    This is an async adjustment of torch.distributed.checkpoint.state_dict_saver.
    In order to support async save, saving should be split into three parts:
    1. Planning
    2. Actual saving
    3. Finalization

    Out of these, step (2) *must* happen asynchronously.
    The first step is realized with this function.

    The planning part consists of several steps, described here:
    https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.SavePlanner

    Args:
        state_dict (STATE_DICT_TYPE): state dict to save
        storage_writer (FileSystemWriterAsync): in current version only an instance of
            FileSystemWriterAsync
        process_group (dist.ProcessGroup, optional): process group used for save planning
        coordinator_rank (int, optional): coordinator rank for planning. Defaults to 0.
        planner (SavePlanner, optional): save planner for torch.distributed.checkpoint format
        cached_ckpt_structure (Tuple[SavePlan, SavePlan, bool], optional):
            Each object of this tuple will be used in the following order:
            cached_central_plan (SavePlan): a globally coordinated save plan
                cached in the previous iteration
            cached_local_plan (SavePlan): a local plan
                cached in the previous iteration
            validated_cache_reuse (bool): boolean value telling whether global_metadata and
                the planning dict are consistent over iterations
        loaded_all_plans (List[SavePlan], optional): save plans loaded from the existing
            checkpoint metadata, used to verify whether the global metadata can be reused

    Returns: Tuple of:
        - a (storage writer, planning metadata or None if cached global metadata is reused,
          distributed wrapper) triple that should be passed to `save_state_dict_async_finalize`
        - the central (global) save plan
        - the local save plan
        - True if the cached central plan could be reused
        - True if global metadata reuse was verified
        The plans can be passed back via `cached_ckpt_structure` to skip `reduce_scatter`
        at the next planning call.
    """
    cached_central_plan, cached_local_plan, validated_cache_reuse = (None, None, False)
    if cached_ckpt_structure:
        cached_central_plan, cached_local_plan, validated_cache_reuse = cached_ckpt_structure

    rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
    dist_wrapper = _DistWrapper(process_group, True, coordinator_rank)
    if planner is None:
        planner = DefaultSavePlanner()
    assert planner is not None

    global_metadata = None
    logger.debug(f"rank: {rank}, starting state dict save")
    local_plan = cached_local_plan
    global_md_verify_reuse = False

    def local_step():
        nonlocal local_plan
        assert planner is not None
        # PyTorch 2.4 introduced additional `metadata` argument,
        # we have to reference `is_coordinator` args by name
        planner.set_up_planner(state_dict, is_coordinator=dist_wrapper.is_coordinator)
        storage_writer.set_up_storage_writer(dist_wrapper.is_coordinator)
        if not validated_cache_reuse and local_plan is None:
            local_plan = planner.create_local_plan()
        local_plan = storage_writer.prepare_local_plan(local_plan)
        return local_plan

    def global_step(all_local_plans):
        nonlocal global_metadata
        assert planner is not None
        all_local_plans, global_metadata = planner.create_global_plan(all_local_plans)
        all_local_plans = storage_writer.prepare_global_plan(all_local_plans)
        return all_local_plans

    # Execute local and global planning
    # Ideally we want to use the cached plan. Otherwise if the planner and storage_writer
    # allow it (`can_run_decentralized_global_plan`) we gather the plans to create
    # the metadata but prepare the plans independently on each rank.
    # In the worst case we have to reduce_scatter all the plans.
    start_plan = time()
    if validated_cache_reuse and cached_central_plan:
        logger.debug(f"rank: {rank}, Passed cache reusable")
        local_step()
        central_plan = cached_central_plan
    elif getattr(planner, 'can_run_decentralized_global_plan', False) and getattr(
        storage_writer, 'can_run_decentralized_global_plan', False
    ):
        local_plan = local_step()
        global_md_verify_reuse = verify_global_md_reuse(
            loaded_all_plans, local_plan, rank, dist_wrapper
        )

        if not loaded_all_plans or not global_md_verify_reuse:
            all_local_plans = dist_wrapper.gather_object(local_plan)
            if dist_wrapper.is_coordinator:
                _, global_metadata = planner.create_global_plan(all_local_plans)
                global_metadata.all_local_plans = all_local_plans
        else:
            logger.debug(f"rank: {rank}, Passed cached global metadata")
            global_metadata = None
        local_plan = planner.create_decentralized_global_plan(local_plan)
        local_plan = storage_writer.prepare_decentralized_global_plan(local_plan)
        central_plan = local_plan
    else:
        central_plan = dist_wrapper.reduce_scatter("plan", local_step, global_step)
    central_plan = planner.finish_plan(central_plan)
    end_plan = time()
    logger.debug(f"rank: {rank}, plan time: {end_plan - start_plan}")
    # Prepare async writing of tensors.
    # The `storage_writer` will store the information about tensors it needs to save
    start = time()
    storage_writer.prepare_write_data(central_plan, planner)
    end = time()
    logger.debug(f"{time()} rank: {rank}, write(async) time: {end - start}")
    return (
        (storage_writer, global_metadata, dist_wrapper),
        central_plan,
        local_plan,
        cached_central_plan == central_plan,
        global_md_verify_reuse,
    )
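

# Sketch of the three-stage flow started by the function above (finalization is defined further
# below in this file). `writer` and `planner` are stand-ins constructed by the caller; the
# actual asynchronous write happens between the two calls.
def _example_async_save_flow(
    state_dict: STATE_DICT_TYPE, writer: 'FileSystemWriterAsync', planner
) -> None:
    # Stage 1: planning (this module).
    (writer_meta_wrapper, central_plan, local_plan, cache_hit, md_reuse) = (
        save_state_dict_async_plan(state_dict, writer, planner=planner)
    )
    # Stage 2: the caller runs the storage writer's asynchronous write step (implemented in
    # filesystem_async.py earlier in this commit) and waits for it to finish.
    # Stage 3: finalization gathers write results and stores the `.metadata` file.
    save_state_dict_async_finalize(*writer_meta_wrapper)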


def verify_global_md_reuse(
    loaded_all_plans: List[SavePlan], local_plan: SavePlan, rank: int, dist_wrapper: _DistWrapper
) -> bool:
    """
    Verifies that global metadata reuse is possible by checking that the loaded plans from the
    checkpoint are consistent, which means we have the same settings when resuming training.

    Args:
        loaded_all_plans: List[SavePlan], The loaded plans from the checkpoint
            (stored in checkpoint metadata).
        local_plan: SavePlan, The local save plan.
        rank: Current process rank.
        dist_wrapper (_DistWrapper): distributed wrapper created during planning

    Returns: True iff the global metadata reuse is possible.
    """
    logger.debug("verifying reuse of global metadata")
    if not loaded_all_plans:
        global_md_verify_reuse = False
        logger.debug("loaded global metadata reuse verification: no loaded plans passed")
    elif len(loaded_all_plans) == dist_wrapper.get_world_size():
        local_verify_reuse = all(
            getattr(local_plan, f.name) == getattr(loaded_all_plans[rank], f.name)
            for f in fields(local_plan)
            if f.name != 'storage_data'
        )

        if not local_verify_reuse:
            logger.debug(
                f"local_verify_reuse is False: diffs -"
                f" {_compare_dataclasses(local_plan, loaded_all_plans[rank])}"
            )
        all_results = torch.tensor([local_verify_reuse], dtype=torch.int, device='cuda')
        torch.distributed.all_reduce(all_results, op=torch.distributed.ReduceOp.MIN)
        # Check if all reduced results are True
        global_md_verify_reuse = all_results.item() == 1
    else:
        global_md_verify_reuse = False
    return global_md_verify_reuse


def save_state_dict_async_finalize(
    storage_writer: 'FileSystemWriterAsync', global_metadata: Metadata, dist_wrapper: _DistWrapper
) -> None:
    """
    Finalization of save_state_dict_async_plan.

    The input arguments are the same as the save_state_dict_async_plan output,
    the `write_results` are retrieved from the storage_writer.

    Args:
        storage_writer (FileSystemWriterAsync): storage writer used for planning
        global_metadata (Metadata): metadata created during planning
        dist_wrapper (_DistWrapper): distributed wrapper created during planning

    Returns: None
    """
    write_results = storage_writer.retrieve_write_results()

    # Gather the write results that will be saved to the metadata file.
    gather_start = time()
    all_results = dist_wrapper.gather_object(write_results)
    gather_end = time()
    logger.debug(
        f"{gather_end}, {torch.distributed.get_rank()}, gather: {gather_end - gather_start}"
    )

    # Store the metadata on coordinator rank
    if dist_wrapper.is_coordinator:
        node_failures = _get_failure_dict(all_results)
        if len(node_failures) == 0:
            assert global_metadata is not None
            write_start = time()
            storage_writer.finish(global_metadata, all_results)
            write_end = time()
            logger.debug(f"{write_end}, metadata_write: {write_end - write_start}")
        else:
            raise CheckpointException("write", node_failures)
Megatron-LM/megatron/core/dist_checkpointing/strategies/tensorstore.py
0 → 100644
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.

"""Strategies using TensorStore to load and save Zarr arrays."""

from functools import partial
from itertools import starmap
from logging import getLogger
from pathlib import Path
from typing import Union

import torch

from ..core import CheckpointingException
from ..dict_utils import dict_list_map_inplace
from ..mapping import ShardedStateDict, ShardedTensor
from .base import LoadShardedStrategy, StrategyAction, register_default_strategy
from .zarr import load_zarr_based_sharded_metadata, postprocess_numpy_array

try:
    import tensorstore as ts

    HAVE_TENSORSTORE = True
except ImportError:
    from unittest.mock import MagicMock

    ts = MagicMock()
    HAVE_TENSORSTORE = False

logger = getLogger(__name__)


def register_default_tensorstore_strategies():
    """Register default strategies leveraging tensorstore."""
    register_default_strategy(
        StrategyAction.LOAD_SHARDED, "zarr", 1, TensorStoreLoadShardedStrategy()
    )


class TensorStoreLoadShardedStrategy(LoadShardedStrategy):
    """Load strategy for Zarr backend using `tensorstore` for loading."""

    def __init__(self, load_directly_on_device: bool = False):
        super().__init__()
        self.load_directly_on_device = load_directly_on_device

    def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Union[str, Path]):
        if isinstance(checkpoint_dir, str):
            checkpoint_dir = Path(checkpoint_dir)

        if torch.distributed.get_rank() == 0:
            print(f"Loading distributed checkpoint with {self.__class__.__name__}")
            if self.load_directly_on_device:
                print("Loading distributed checkpoint directly on the GPU")
        load_fn = partial(
            _load_from_array,
            checkpoint_dir=checkpoint_dir,
            load_directly_on_device=self.load_directly_on_device,
        )
        dict_list_map_inplace(load_fn, sharded_state_dict)
        return sharded_state_dict

    def load_tensors_metadata(self, checkpoint_dir: Union[str, Path]):
        if isinstance(checkpoint_dir, str):
            checkpoint_dir = Path(checkpoint_dir)

        def get_ts_shape_dtype(path):
            arr = open_ts_array(path)
            return arr.shape, arr.dtype.numpy_dtype

        return load_zarr_based_sharded_metadata(checkpoint_dir, get_ts_shape_dtype)

    def check_backend_compatibility(self, loaded_version):
        pass  # TODO

    def check_version_compatibility(self, loaded_version):
        pass  # TODO


def merge_global_slice_with_shape(global_slice, actual_shape, key):
    """Intersects the global slice with the actual shape (prevent overflow)."""

    def _merge_slice(dim_slice, dim_size):
        if isinstance(dim_slice, slice):
            assert (
                dim_slice.start < dim_size
            ), f"Got empty slice for ShardedTensor {key} ({dim_slice}, {dim_size})"
            if dim_slice.stop > dim_size:
                dim_slice = slice(dim_slice.start, dim_size, dim_slice.step)
        return dim_slice

    assert len(global_slice) == len(actual_shape), (global_slice, actual_shape, key)
    return tuple(starmap(_merge_slice, zip(global_slice, actual_shape)))


def _load_from_array(
    sharded_tensor: ShardedTensor,
    checkpoint_dir: Path,
    load_directly_on_device: bool = False,
    apply_flattened_range: bool = True,
):
    x = _load_regular_chunk(sharded_tensor, checkpoint_dir)
    ten = postprocess_numpy_array(x, sharded_tensor, apply_flattened_range)
    if load_directly_on_device:
        sharded_tensor.data.data.copy_(ten)
        return sharded_tensor.data
    else:
        return ten


def _load_regular_chunk(sharded_tensor: ShardedTensor, checkpoint_dir: Path):
    assert isinstance(sharded_tensor, ShardedTensor), type(sharded_tensor)
    arr = open_ts_array(checkpoint_dir / sharded_tensor.key)
    if sharded_tensor.global_shape == arr.shape:
        x = (
            arr[sharded_tensor.global_slice()].read().result()
        )  # flattened tensors loading is delayed
    elif sharded_tensor.allow_shape_mismatch:
        global_slice = merge_global_slice_with_shape(
            sharded_tensor.global_slice(), arr.shape, sharded_tensor.key
        )
        x = arr[global_slice].read().result()  # flattened tensors loading is delayed
    else:
        _msg = (
            f"Global shape mismatch for loaded ({arr.shape})"
            f" and expected ({sharded_tensor.global_shape}) tensor"
            f" for key {sharded_tensor.key}"
        )
        raise CheckpointingException(_msg)
    return x


def open_ts_array(arr_path: Path):
    """Opens a Zarr file array with Tensorstore with basic settings.

    Args:
        arr_path (Path): path to a Zarr (Tensorstore) array
    """
    if not HAVE_TENSORSTORE:
        raise RuntimeError(
            "tensorstore is required, please install it with `pip install tensorstore`"
        )

    spec = {"driver": "zarr", "metadata_key": ".zarray", "kvstore": {}}
    spec["kvstore"] = {"driver": "file", "path": str(arr_path)}
    try:
        arr = ts.open(ts.Spec(spec), open=True).result()
    except Exception as e:
        raise CheckpointingException(f"Array {arr_path} could not be loaded. Error: {e}") from e
    return arr
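

# Minimal sketch of the TensorStore spec that `open_ts_array` builds for a hypothetical array
# path: a local-file Zarr spec consumed by `ts.open`. Shown only to make the spec layout explicit.
def _example_ts_spec(arr_path: Path) -> dict:
    return {
        "driver": "zarr",
        "metadata_key": ".zarray",
        "kvstore": {"driver": "file", "path": str(arr_path)},
    }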
Megatron-LM/megatron/core/dist_checkpointing/strategies/torch.py
0 → 100644
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.

""" Strategies using PyTorch distributed.checkpoint as an underlying format. """
import io
import os
import pickle
import warnings
from collections import ChainMap, defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from itertools import product
from logging import getLogger
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, cast

import torch
from packaging.version import Version as PkgVersion
from torch.distributed import checkpoint
from torch.distributed._shard.metadata import ShardMetadata
from torch.distributed._shard.sharded_tensor import Shard
from torch.distributed._shard.sharded_tensor import ShardedTensor as TorchShardedTensor
from torch.distributed._shard.sharded_tensor import ShardedTensorMetadata, TensorProperties
from torch.distributed.checkpoint import (
    BytesStorageMetadata,
    DefaultLoadPlanner,
    DefaultSavePlanner,
    FileSystemReader,
    FileSystemWriter,
    LoadPlan,
    Metadata,
    ReadItem,
    SavePlan,
    TensorStorageMetadata,
    WriteItem,
)
from torch.distributed.checkpoint._nested_dict import FLATTEN_MAPPING, unflatten_state_dict
from torch.distributed.checkpoint._traverse import OBJ_PATH, traverse_state_dict
from torch.distributed.checkpoint.metadata import Metadata
from torch.distributed.checkpoint.planner_helpers import _create_write_items

from ...utils import get_torch_version, is_torch_min_version
from ..core import CheckpointingException
from ..dict_utils import nested_values
from ..mapping import (
    ShardedBase,
    ShardedObject,
    ShardedStateDict,
    ShardedTensor,
    StateDict,
    is_main_replica,
)
from .async_utils import AsyncRequest
from .base import (
    AsyncSaveShardedStrategy,
    LoadShardedStrategy,
    StrategyAction,
    register_default_strategy,
)
from .cached_metadata_filesystem_reader import CachedMetadataFileSystemReader
from .filesystem_async import FileSystemWriterAsync
from .resharding import (
    TensorReformulationMetadata,
    apply_nd_flattened_tensors_reformulation,
    is_nd_flattened_tensor,
    nd_flattened_tensor_reformulated_global_shape,
    restore_nd_flattened_tensors_formulation,
)
from .state_dict_saver import save_state_dict_async_finalize, save_state_dict_async_plan

try:
    if not torch.cuda.is_available():
        raise ImportError
    from transformer_engine.pytorch.float8_tensor import Float8Tensor

    HAVE_TE = True
except ImportError:
    HAVE_TE = False

try:
    from torch.distributed._tensor import DTensor

    HAVE_DTENSOR = True
except ImportError:
    HAVE_DTENSOR = False

from megatron.core.msc_utils import MultiStorageClientFeature

MSC_PREFIX = "msc://"

_metadata_fn: str = ".metadata"


def register_default_torch_strategies():
    """Register default strategies related to PyT Distributed backend."""
    register_default_strategy(
        StrategyAction.LOAD_SHARDED, 'torch_dist', 1, TorchDistLoadShardedStrategy()
    )
    register_default_strategy(
        StrategyAction.SAVE_SHARDED, 'torch_dist', 1, TorchDistSaveShardedStrategy('torch_dist', 1)
    )


logger = getLogger(__name__)


def flatten_state_dict(
    state_dict: ShardedStateDict,
) -> Tuple[ShardedStateDict, Dict[str, OBJ_PATH]]:
    """Flattens state dict into a single level dict.

    It's a copy of torch.distributed.checkpoint._nested_dict.flatten_state_dict
    which also accepts ShardedBase tensors as terminal objects.

    Args:
        state_dict (ShardedStateDict): state dict to be flattened

    Returns (tuple): flattened state dict and a mapping allowing to recreate the original one
    """
    flattened = {}
    mappings = {}

    def flat_copy(path: OBJ_PATH, value: Any) -> None:
        new_fqn = ".".join(map(str, path))
        if new_fqn in flattened:
            raise ValueError(f"duplicated flatten key {new_fqn}")
        flattened[new_fqn] = value
        mappings[new_fqn] = path

    traverse_state_dict(state_dict, flat_copy, lambda x: isinstance(x, (torch.Tensor, ShardedBase)))
    return flattened, mappings
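

# Small illustration of the flattening above (hypothetical nested dict): nested keys are joined
# with dots and `mappings` remembers the original paths for `unflatten_state_dict`.
def _example_flatten() -> None:
    nested = {'optimizer': {'state': {'fp32_param': torch.zeros(2)}}}
    flattened, mappings = flatten_state_dict(nested)
    assert list(flattened) == ['optimizer.state.fp32_param']
    # `mappings['optimizer.state.fp32_param']` holds the original path,
    # e.g. ('optimizer', 'state', 'fp32_param').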


def sharded_tensor_to_torch_sharded_tensor(
    sh_tens: List[ShardedTensor],
    rank: Optional[int] = None,
    load_legacy_1d_flatten_tensors: bool = False,
) -> TorchShardedTensor:
    """Convert MCore ShardedTensor to PyT ShardedTensor. PyT requires information about all chunks.

    On high-level, this function follows the logic of
    torch.distributed.fsdp._shard_utils._create_chunk_sharded_tensor.
    Additionally, it saves `prepend_axis_num` and `has_flattened_range` (specific to MCore)
    as attributes for further restoration in `_unwrap_pyt_sharded_tensor`.

    NOTE: this function assumes regular (grid) sharding of the MCore ShardedTensor.
    The only local irregularities could be introduced with a `flattened_range` attribute.

    This function handles 2 different types of ShardedTensors:
    1. Non-flat regular ShardedTensors (`not has_flattened_range`)
    2. N-D flattened ShardedTensors (`has_flattened_range`)

    Type (1) is saved according to its original shape.
    Type (2) however requires global shape adjustment for efficiency:
    we treat [X, Y, Z] global shape tensor with local shape [x, y, z]
    as a [X // x, Y // y, Z // z, x * y * z] tensor with last axis
    partitioned according to `flattened_range` slices.
    This will need special handling while resharding.

    Args:
        sh_tens (List[ShardedTensor]): list of sharded tensors to convert
        rank (int, optional): current process rank passed to PyT ShardedTensor.
            If None, assumes rank in the default pg.
        load_legacy_1d_flatten_tensors (bool, optional): flag indicating if 1-D flattened tensors
            should be loaded in a legacy way. Defaults to False.

    Returns (TorchShardedTensor): PyT ShardedTensor containing all passed shards.
    """
    if rank is None:
        rank = torch.distributed.get_rank()

    some_sh_ten = sh_tens[0]
    has_flattened_range = some_sh_ten.flattened_range is not None
    for sh_ten in sh_tens:
        assert (sh_ten.flattened_range is not None) == has_flattened_range, sh_tens
        if not sh_ten.data.is_contiguous():
            sh_ten.data = sh_ten.data.contiguous()

    if load_legacy_1d_flatten_tensors and len(some_sh_ten.global_shape) == 1:
        # Legacy 1-D flattened tensors are loaded as non-flat regular ShardedTensors
        has_flattened_range = False

    local_global_offsets = {}

    prepend_axis_num = sh_tens[0].prepend_axis_num
    # Determine local shards according to tensor type (see docs)
    if has_flattened_range:
        # Type (2) case: N-D flattened ShardedTensors
        for sh_ten in sh_tens:
            local_global_offsets.setdefault(sh_ten.local_chunk_offset_in_global(), []).append(
                sh_ten
            )
            assert sh_ten.data.ndim == 1, sh_ten
            sh_ten.data = sh_ten.data.view((1,) * len(sh_ten.global_shape) + (-1,))

        # Global shape reformulation:
        global_shape = nd_flattened_tensor_reformulated_global_shape(some_sh_ten)
        offsets_shape = (1,) * len(
            some_sh_ten.global_shape
        )  # reformulated global shape has shape equal to number of local chunks

        local_shards = [
            Shard.from_tensor_and_offsets(
                sh_ten.data,
                list(
                    sh_ten.local_chunk_offset_in_global() + (sh_ten.flattened_range.start,)
                ),  # additional flattened offset
                rank,
            )
            for sh_ten in sh_tens
        ]
    else:
        # Type (1) case: non-flat regular ShardedTensors
        for sh_ten in sh_tens:
            local_global_offsets.setdefault(sh_ten.global_offset, []).append(sh_ten)
            sh_ten.data = sh_ten.data.view(
                (1,) * prepend_axis_num + sh_ten.local_shape
            )  # adjust to prepended_axis_num

        global_shape = some_sh_ten.global_shape
        offsets_shape = some_sh_ten.data.shape  # includes prepended axes

        local_shards = [
            Shard.from_tensor_and_offsets(
                sh_ten.data, list(sh_ten.global_offset), rank  # simple case
            )
            for sh_ten in sh_tens
        ]

    # Create a ShardedTensor without invoking communication. Determine global shards
    world_size = torch.distributed.get_world_size()
    shard_metadata = []
    # NOTE: here we assume a regular grid of shards
    for fragment_offsets in product(*map(range, some_sh_ten.axis_fragmentations)):
        offset = tuple(map(lambda x: x[0] * x[1], zip(fragment_offsets, offsets_shape)))
        if offset in local_global_offsets:
            # local shard
            placement = f"rank:{rank}/cuda"
            for sh_ten in local_global_offsets[offset]:
                if has_flattened_range:
                    assert offset == sh_ten.local_chunk_offset_in_global()
                    # This is not an actual offset, but an offset of the whole shard
                    # This is needed for a PyT Dist internal integrity check
                    offset = sh_ten.local_chunk_offset_in_global() + (0,)
                    size = (1,) * len(offsets_shape) + global_shape[-1:]
                else:
                    size = sh_ten.data.shape
                shard_metadata.append(ShardMetadata(offset, size, placement))
        else:
            # pylint: disable=line-too-long
            # for shards from other ranks we provide simplistic data - this information will be discarded
            # during TorchShardedTensor._init_from_local_shards_and_global_metadata call.
            # Due to a bug in PyT 24.05 container we must specify some concrete rank within a world size.
            # The exact rank doesn't matter as long as it's different than my rank - hence (rank + 1) % WS.
            placement = f"rank:{(rank + 1) % world_size}/cuda"
            if has_flattened_range:
                offset = offset + (0,)
                size = (1,) * len(offsets_shape) + global_shape[-1:]
            else:
                size = offsets_shape
            shard_metadata.append(ShardMetadata(offset, size, placement))

    tensor = some_sh_ten.data
    sharded_tensor_metadata = ShardedTensorMetadata(
        shards_metadata=shard_metadata,
        size=torch.Size(global_shape),
        tensor_properties=TensorProperties(
            dtype=tensor.dtype,
            layout=tensor.layout,
            requires_grad=tensor.requires_grad,
            memory_format=torch.contiguous_format,
            pin_memory=tensor.is_pinned(),
        ),
    )
    pyt_sh_ten = TorchShardedTensor._init_from_local_shards_and_global_metadata(
        local_shards, sharded_tensor_metadata=sharded_tensor_metadata, process_group=None
    )
    # Store MCore related data as PyTShardedTensor attribute.
    # This won't be stored in the checkpoint, only for runtime purposes
    pyt_sh_ten.mcore_sh_ten = sh_ten.without_data()
    pyt_sh_ten.mcore_metadata = {}
    if has_flattened_range:
        pyt_sh_ten.mcore_metadata['nd_reformulated_orig_global_shape'] = sh_ten.global_shape
    return pyt_sh_ten


def mcore_to_pyt_state_dict(
    state_dict: Dict[str, List[ShardedBase]],
    is_loading: bool = False,
    init_device: torch.device = torch.device("cpu"),
    load_legacy_1d_flatten_tensors: bool = False,
) -> Dict[str, Union[TorchShardedTensor, io.BytesIO]]:
    """Convert state dict with ShardedTensors and ShardedObjects
    to state dict compatible with PyT Dist format.

    Operates in-place and returns the original state dict.

    Args:
        state_dict (Dict[str, List[ShardedBase]]): flattened state dict, where values
            are lists of either ShardedTensor or ShardedObjects.
        is_loading (bool, optional): flag indicating if loading or saving. Defaults to False.
        init_device (torch.device, optional): device to initialize potentially missing tensors
            during loading. Defaults to 'cpu'.

    Returns (Dict[str, Union[TorchShardedTensor, io.BytesIO]]): original dictionary with values
        converted either into PyT ShardedTensors or io.BytesIO.
    """
    rank = torch.distributed.get_rank()
    pyt_state_dict = {}

    def _mcore_to_torch_sharded_tensor(sh_tens: List[ShardedTensor]) -> TorchShardedTensor:
        """Build a PyT ShardedTensor from given shards.

        During loading:
        - if data is None, initialize it with an empty tensor (will be used to copy the data into)
        - if `allow_shape_mismatch` is True, the data is initialized with zeros
            prior to loading (not all parts of the tensor will be read from the checkpoint)
        """
        assert all(isinstance(sh_ten, ShardedTensor) for sh_ten in sh_tens), sh_tens
        for sh_ten in sh_tens:
            if sh_ten.data is None:
                if is_loading:
                    sh_ten.init_data(
                        init_device,
                        init_fn=torch.zeros if sh_ten.allow_shape_mismatch else torch.empty,
                    )
                else:
                    raise CheckpointingException(f'`data` attr is None for {sh_ten}')
            else:
                sh_ten.data = sh_ten.data.detach()
                if sh_ten.allow_shape_mismatch and is_loading:
                    sh_ten.data.zero_()

        torch_sh_ten = sharded_tensor_to_torch_sharded_tensor(
            sh_tens, rank, load_legacy_1d_flatten_tensors
        )
        torch_sh_ten.key = sh_tens[0].key
        return torch_sh_ten

    def _mcore_to_torch_sharded_object(sh_objs: List[ShardedObject]) -> io.BytesIO:
        """Build io.BytesIO from given sharded objects data."""
        assert all(isinstance(sh_obj, ShardedObject) for sh_obj in sh_objs), sh_objs
        serialized_data = io.BytesIO()
        torch.save([sh_obj.data for sh_obj in sh_objs], serialized_data)
        return serialized_data

    for k, v in state_dict.items():
        if isinstance(v[0], ShardedTensor):
            v = cast(List[ShardedTensor], v)
            pyt_state_dict[k] = _mcore_to_torch_sharded_tensor(v)
        else:
            v = cast(List[ShardedObject], v)
            pyt_state_dict[k] = _mcore_to_torch_sharded_object(v)

    return pyt_state_dict


def _unwrap_pyt_sharded_tensor(sh_ten: TorchShardedTensor) -> List[torch.Tensor]:
    """Unwrap tensor from PyT ShardedTensor instance.

    If `prepend_axis_num` was non-zero (which is specific to MCore ShardedTensor)
    then the tensor has additional singleton dimensions which should be squeezed.
    """
    mcore_sh_ten = sh_ten.mcore_sh_ten
    ret_tensors = []
    for sh in sh_ten.local_shards():
        ten = sh.tensor
        if mcore_sh_ten.flattened_range is not None:
            assert ten.shape[:-1] == (1,) * (len(ten.shape) - 1), ten.shape
            ten = ten.view(-1)
        else:
            for _ in range(mcore_sh_ten.prepend_axis_num):
                ten = ten.squeeze(0)
        ret_tensors.append(ten)
    return ret_tensors


def _replace_state_dict_keys_with_sharded_keys(
    sharded_state_dict: ShardedStateDict, keep_only_main_replica: bool = False
) -> Tuple[Dict[str, List[ShardedBase]], FLATTEN_MAPPING, Dict[str, List[str]]]:
    """Group ShardedBase objects by keys and
    return mappings required for recreating the original dict."""
    flat_sd, flat_mapping = flatten_state_dict(sharded_state_dict)

    rename_mapping = defaultdict(list)
    new_flat_sd = defaultdict(list)
    for k, sh_base in flat_sd.items():
        assert isinstance(sh_base, ShardedBase), type(sh_base)
        key = sh_base.unique_key if isinstance(sh_base, ShardedObject) else sh_base.key
        if is_main_replica(sh_base.replica_id) or not keep_only_main_replica:
            rename_mapping[key].append(k)
            new_flat_sd[key].append(sh_base)
    return new_flat_sd, flat_mapping, rename_mapping


def _replace_sharded_keys_with_state_dict_keys(
    state_dict: Dict[str, List[Union[torch.Tensor, io.BytesIO]]],
    flat_mapping: FLATTEN_MAPPING,
    rename_mapping: Dict[str, List[str]],
):
    """Inverse of _replace_state_dict_keys_with_sharded_keys."""
    recovered_sd = {}
    for k, tensors in state_dict.items():
        assert len(tensors) == len(rename_mapping[k])
        for ten, recovered_k in zip(tensors, rename_mapping[k]):
            recovered_sd[recovered_k] = ten

    return unflatten_state_dict(recovered_sd, flat_mapping)


def _restore_dict_types(x: Union[dict, list, Any], keys_template: Union[dict, list, Any]):
    """Recursively update `x` keys, based on `keys_template`."""
    if isinstance(keys_template, dict):
        assert isinstance(x, dict), type(x)
        for k, v in keys_template.items():
            if not isinstance(k, str):
                assert str(k) in x, (k, x.keys)
                x[k] = x.pop(str(k))
            _restore_dict_types(x[k], v)
    elif isinstance(keys_template, list):
        assert isinstance(x, list), type(x)
        for x_val, templ_val in zip(x, keys_template):
            _restore_dict_types(x_val, templ_val)
@
dataclass
(
frozen
=
True
)
class
MCoreSavePlan
(
SavePlan
):
"""SavePlan with MCore specific data."""
mcore_data
:
Optional
[
Dict
[
str
,
Dict
[
str
,
Any
]]]
=
None
# Mcore related data about each tensor
class
MCoreSavePlanner
(
DefaultSavePlanner
):
"""Differs with the default planner by saving BytesIO objects on all ranks.
In the integration of MCore with PyT Distributed format, BytesIO objects
come from ShardedObjects, which should be treated as separate objects on each rank
(not common on all ranks).
Also, the objects are already packed in io.BytesIO, so no need to redo it
in transform_object.
"""
def
__init__
(
self
,
*
args
,
dedup_replicated_tensors
:
Optional
[
bool
]
=
None
,
nd_flattened_global_shapes
:
Optional
[
Dict
[
str
,
Tuple
[
int
,
...]]]
=
None
,
can_run_decentralized_global_plan
:
bool
=
True
,
**
kwargs
,
)
->
None
:
# `dedup_replicated_tensors` was deprecated in 2.3; this check avoids warnings
# during saving.
if
get_torch_version
()
<=
PkgVersion
(
"2.2"
):
kwargs
[
'dedup_replicated_tensors'
]
=
dedup_replicated_tensors
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
nd_flattened_global_shapes
=
nd_flattened_global_shapes
or
{}
self
.
can_run_decentralized_global_plan
=
can_run_decentralized_global_plan
if
can_run_decentralized_global_plan
:
assert
(
not
dedup_replicated_tensors
),
'Cannot run decentralized plan with dedup_replicated_tensors=True'
assert
(
not
self
.
flatten_state_dict
),
'Cannot run decentralized plan with flatten_state_dict=True'
def
create_local_plan
(
self
)
->
SavePlan
:
"""Adds IOBytes write request on non-coordinator ranks."""
# NOTE: for PyT 2.4.0a0 we can't rely on `create_default_local_save_plan` because
# some alpha versions (specifically 2.4.0a0+f70bd71a48 in 24.06 NGC PyTorch container)
# add iobytes request only on coordinator ranks and some alpha versions
# (specifically 2.4.0a0+3bcc3cddb5 in 24.07 NGC PyTorch container)
# add those requests on all ranks. We inline a simplified version of this method below.
write_items
=
[]
for
fqn
,
obj
in
self
.
state_dict
.
items
():
assert
not
HAVE_DTENSOR
or
not
isinstance
(
obj
,
DTensor
)
# translation from MCore ShardedTensors shouldn't result in DTensors
# Create write requests for tensor and bytes values.
# For MCore, these should be already non-duplicates.
write_items
+=
_create_write_items
(
fqn
,
obj
)
self
.
plan
=
MCoreSavePlan
(
items
=
write_items
,
planner_data
=
self
.
mappings
,
mcore_data
=
{
k
:
sh_ten
.
mcore_metadata
for
k
,
sh_ten
in
self
.
state_dict
.
items
()
if
isinstance
(
sh_ten
,
TorchShardedTensor
)
},
)
return
self
.
plan
def
create_global_plan
(
self
,
all_plans
:
List
[
MCoreSavePlan
])
->
Tuple
[
List
[
SavePlan
],
Metadata
]:
"""Merges MCore data for all plans."""
global_plan
,
metadata
=
super
().
create_global_plan
(
all_plans
)
metadata
.
mcore_data
=
dict
(
ChainMap
(
*
(
plan
.
mcore_data
for
plan
in
all_plans
))
# type: ignore[arg-type]
)
return
global_plan
,
metadata
def
create_decentralized_global_plan
(
self
,
local_plan
:
SavePlan
)
->
SavePlan
:
"""Nothing to do, just some checks.
Args:
local_plan (SavePlan): local plan to turn to a global plan
(without interactions with other ranks)
Returns:
SavePlan - locally transformed plan equivalent to the plan that would be
created by the coordinator
"""
assert
(
not
self
.
flatten_state_dict
),
'Cannot run decentralized plan with flatten_state_dict=True'
assert
not
local_plan
.
planner_data
,
'Planner data should be empty with decentralized plan'
return
local_plan
def
transform_object
(
self
,
write_item
:
WriteItem
,
object
:
Any
):
"""Make no transformations - bytes objects are already serialized."""
return
object
class
MCoreLoadPlanner
(
DefaultLoadPlanner
):
"""Adds global shape validation to the default planner.
If global shape validation can be ignored (shouldn't!), the default
load planner can be used.
"""
def
__init__
(
self
,
*
args
,
shapes_validation_sharded_tensors
:
Iterable
[
ShardedTensor
]
=
(),
allow_shape_mismatch_sharded_tensors
:
Optional
[
Dict
[
str
,
ShardedTensor
]]
=
None
,
**
kwargs
,
)
->
None
:
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
shapes_validation_sharded_tensors
=
shapes_validation_sharded_tensors
self
.
allow_shape_mismatch_sharded_tensors
=
allow_shape_mismatch_sharded_tensors
self
.
_intermediate_read_item_and_target
:
Optional
[
Tuple
[
ReadItem
,
torch
.
Tensor
]]
=
None
@
staticmethod
def
_expected_shape
(
sh_ten
):
return
(
nd_flattened_tensor_reformulated_global_shape
(
sh_ten
)
if
is_nd_flattened_tensor
(
sh_ten
)
else
sh_ten
.
global_shape
)
def
_validate_global_shapes
(
self
,
metadata
,
sharded_tensors
):
for
sh_ten
in
sharded_tensors
:
if
sh_ten
.
key
not
in
metadata
.
state_dict_metadata
:
raise
KeyError
(
f
"
{
sh_ten
.
key
}
from model not in state dict:"
f
"
{
sorted
(
metadata
.
state_dict_metadata
.
keys
())
}
"
)
loaded_shape
=
metadata
.
state_dict_metadata
[
sh_ten
.
key
].
size
expected_shape
=
self
.
_expected_shape
(
sh_ten
)
if
loaded_shape
!=
expected_shape
:
if
is_nd_flattened_tensor
(
sh_ten
)
and
len
(
sh_ten
.
global_shape
)
==
1
:
# Handle legacy 1-D flattened tensors checkpoint format
# where the global shape is not stored in the metadata
expected_shape
=
sh_ten
.
global_shape
if
loaded_shape
==
expected_shape
:
continue
_msg
=
(
f
'Global shape mismatch for loaded (
{
loaded_shape
}
)'
f
' and expected (
{
expected_shape
}
) tensor'
f
' for key
{
sh_ten
.
key
}
'
)
raise
CheckpointingException
(
_msg
)
@
contextmanager
def
_temporarily_bypass_shape_validation
(
self
):
"""
Temporarily set the size of tensors to their expected shapes to bypass DCP shape validation.
This is used when validating the shapes during local plan creation.
"""
if
not
self
.
allow_shape_mismatch_sharded_tensors
:
yield
return
tensor_metadata
=
self
.
metadata
.
state_dict_metadata
metadata_with_sizes
=
[
(
tensor_metadata
[
key
],
tensor_metadata
[
key
].
size
,
sharded_tensor
)
for
key
,
sharded_tensor
in
self
.
allow_shape_mismatch_sharded_tensors
.
items
()
]
try
:
# Temporarily set sizes to expected shapes
for
md
,
_
,
sharded_tensor
in
metadata_with_sizes
:
md
.
size
=
self
.
_expected_shape
(
sharded_tensor
)
yield
finally
:
# Restore original sizes after yield
for
md
,
size
,
_
in
metadata_with_sizes
:
md
.
size
=
size
def
create_local_plan
(
self
)
->
LoadPlan
:
"""Runs additional shapes validation."""
self
.
_validate_global_shapes
(
self
.
metadata
,
self
.
shapes_validation_sharded_tensors
)
with
self
.
_temporarily_bypass_shape_validation
():
local_plan
=
super
().
create_local_plan
()
return
local_plan
def
resolve_tensor
(
self
,
read_item
:
ReadItem
):
"""Override to add FP8 support.
Narrowing the Float8Tensor can create incontiguous tensors and there are
no `copy` kernels for such cases. This method creates a contiguous FP8
tensors so that the subsequent `copy_` in FileSystemReader succeeds.
Note that this requires tracking the original tensor
(as `self._intermediate_read_item_and_target` attribute)
and restoring it in `commit_tensor` method.
"""
target_tensor
=
super
().
resolve_tensor
(
read_item
)
if
(
not
target_tensor
.
is_contiguous
()
and
HAVE_TE
and
isinstance
(
target_tensor
,
Float8Tensor
)
):
self
.
_intermediate_read_item_and_target
=
(
read_item
,
target_tensor
)
target_tensor
=
Float8Tensor
.
make_like
(
target_tensor
,
data
=
target_tensor
.
_data
.
contiguous
()
)
return
target_tensor
def
commit_tensor
(
self
,
read_item
:
ReadItem
,
tensor
:
torch
.
Tensor
)
->
None
:
"""Restores the original FP8 tensor saved in `resolve_tensor`."""
if
self
.
_intermediate_read_item_and_target
is
not
None
:
interm_read_item
,
target_tensor
=
self
.
_intermediate_read_item_and_target
assert
(
interm_read_item
is
read_item
),
'`commit_tensor` method should be called right after `resolve_tensor`'
target_tensor
.
copy_
(
tensor
)
tensor
=
target_tensor
self
.
_intermediate_read_item_and_target
=
None
return
super
().
commit_tensor
(
read_item
,
tensor
)
class
TorchDistSaveShardedStrategy
(
AsyncSaveShardedStrategy
):
"""Async save strategy for the PyT Distributed format.
The idea is to translate MCore ShardedTensors into PyT ShardedTensors
and use the async-adjusted torch.distributed.checkpoint saving mechanism
provided by the FileSystemWriterAsync writer.
"""
def
__init__
(
self
,
backend
:
str
,
version
:
int
,
keep_only_main_replica
:
bool
=
True
,
thread_count
:
int
=
2
,
cached_metadata
:
bool
=
False
,
separation_hint
:
Optional
[
str
]
=
None
,
):
"""Adds parameters specific to PyT Distributed format
Args:
backend (str): format backend string
version (int): format version
keep_only_main_replica (bool, optional): PyT Distributed has a mechanism
for deduplication, but replica_id aware deduplication is more coherent.
Default is True (recommended to keep it).
thread_count (int, optional): threads to use during saving.
Affects the number of files in the checkpoint (saving ranks * num_threads).
cached_metadata (bool, optional): Enables using cached global metadata to avoid
gathering local metadata every checkpointing invocation
separation_hint(str, optional): If provided, all tensors whose keys have this
prefix will be saved to a separate file.
"""
super
().
__init__
(
backend
,
version
)
self
.
keep_only_main_replica
=
keep_only_main_replica
self
.
thread_count
=
thread_count
# Cached SavePlans to skip plan in `save_state_dict_async_plan`
# cached outcome of `SavePlan.prepare_global_plan`,
# which aggregates local plans from all ranks
self
.
cached_central_plan
:
SavePlan
=
None
# cached outcome of `SavePlan.prepare_local_plan` describes how local state_dict is written
self
.
cached_local_plan
:
SavePlan
=
None
# Cached global metadata, only `coordinator` for dist-ckpt holds
# if central plans are consistent over iters
self
.
cached_global_metadata
:
Metadata
=
None
# This variable records if the ckpt structures are consistent
# so the following checkpoint savings reuse `cached_global_metadata`
self
.
validated_cache_reuse
:
bool
=
False
# The knob to enable cached metadata communication in saving
self
.
use_cached_ckpt_structure
:
bool
=
cached_metadata
self
.
separation_hint
=
separation_hint
self
.
validated_loaded_metadata_reuse
=
False
def
async_save
(
self
,
sharded_state_dict
:
ShardedStateDict
,
checkpoint_dir
:
Path
)
->
AsyncRequest
:
"""Translates MCore ShardedTensors to PyT ShardedTensors & saves in PyT Distributed format.
Args:
sharded_state_dict (ShardedStateDict): sharded state dict to save
checkpoint_dir (Path): checkpoint directory
Returns: None
"""
# Translate the state dict
(
sharded_state_dict
,
flat_mapping
,
rename_mapping
)
=
(
_replace_state_dict_keys_with_sharded_keys
(
sharded_state_dict
,
self
.
keep_only_main_replica
)
)
pyt_state_dict
=
mcore_to_pyt_state_dict
(
sharded_state_dict
,
False
)
# Use PyT saving mechanism
writer
=
FileSystemWriterAsync
(
checkpoint_dir
,
separation_hint
=
self
.
separation_hint
,
thread_count
=
self
.
thread_count
,
use_msc
=
MultiStorageClientFeature
.
is_enabled
(),
)
# This should be set differently if we run in a smaller process group than the default
coordinator
=
0
# Try twice to validate the generated `central_plan` is the same across iterations
# If so, reuse `cached_central_plan` and `cached_global_metadata`
# From the 3rd iteration, `save_state_dict_async_plan` will not generate `global_metadata`
# (return None) so `self.cached_global_metadata` is reused
args_cached_plans
=
None
loaded_all_plans
=
None
if
self
.
use_cached_ckpt_structure
:
loaded_all_plans
=
getattr
(
self
.
cached_global_metadata
,
"all_local_plans"
,
None
)
if
loaded_all_plans
is
None
:
logger
.
debug
(
"no all_local_plans in metadata - can't verify global metadata reuse..."
)
args_cached_plans
=
(
self
.
cached_central_plan
,
self
.
cached_local_plan
,
self
.
validated_cache_reuse
,
)
(
save_state_dict_ret
,
self
.
cached_central_plan
,
self
.
cached_local_plan
,
self
.
validated_cache_reuse
,
self
.
validated_loaded_metadata_reuse
,
)
=
save_state_dict_async_plan
(
pyt_state_dict
,
writer
,
None
,
coordinator
,
planner
=
MCoreSavePlanner
(
dedup_replicated_tensors
=
not
self
.
keep_only_main_replica
,
flatten_state_dict
=
False
),
cached_ckpt_structure
=
args_cached_plans
,
loaded_all_plans
=
loaded_all_plans
,
)
rank
=
torch
.
distributed
.
get_rank
()
if
self
.
use_cached_ckpt_structure
:
if
(
loaded_all_plans
and
self
.
cached_global_metadata
and
self
.
validated_loaded_metadata_reuse
):
if
coordinator
==
rank
:
logger
.
debug
(
f
"rank:
{
rank
}
, reuse global metadata from loaded"
f
" .metadata,
{
save_state_dict_ret
[
1
]
}
"
)
save_state_dict_ret
=
list
(
save_state_dict_ret
)
save_state_dict_ret
[
1
]
=
self
.
cached_global_metadata
elif
self
.
validated_cache_reuse
:
logger
.
debug
(
f
"rank:
{
rank
}
, cache validated"
)
if
save_state_dict_ret
[
1
]:
# when global_metadata is not cached
self
.
cached_global_metadata
=
save_state_dict_ret
[
1
]
# Cache Metadata
# Only Coordinator rank holds cached global_metadata
# (None is returned for global_metadata)
elif
coordinator
==
rank
:
logger
.
debug
(
f
"rank:
{
rank
}
, reuse global metadata cached from previous"
f
" save iteration,
{
save_state_dict_ret
[
1
]
}
"
)
save_state_dict_ret
=
list
(
save_state_dict_ret
)
save_state_dict_ret
[
1
]
=
self
.
cached_global_metadata
return
self
.
_get_save_and_finalize_callbacks
(
writer
,
save_state_dict_ret
)
def
_get_save_and_finalize_callbacks
(
self
,
writer
,
save_state_dict_ret
)
->
AsyncRequest
:
save_fn_args
=
writer
.
get_save_function_and_args
()
save_fn
,
preload_fn
,
save_args
=
save_fn_args
def
finalize_fn
():
save_state_dict_async_finalize
(
*
save_state_dict_ret
)
torch
.
distributed
.
barrier
()
return
AsyncRequest
(
save_fn
,
save_args
,
[
finalize_fn
],
preload_fn
=
preload_fn
)
def
can_handle_sharded_objects
(
self
):
return
True
def
_get_filesystem_reader
(
checkpoint_dir
:
Union
[
str
,
Path
],
cache_metadata
:
bool
=
False
)
->
FileSystemReader
:
if
MultiStorageClientFeature
.
is_enabled
():
msc
=
MultiStorageClientFeature
.
import_package
()
return
msc
.
torch
.
MultiStorageFileSystemReader
(
checkpoint_dir
,
thread_count
=
2
)
if
cache_metadata
:
return
CachedMetadataFileSystemReader
(
checkpoint_dir
)
return
FileSystemReader
(
checkpoint_dir
)
def
get_reformulation_metadata
(
sharded_state_dict
:
ShardedStateDict
,
checkpoint_dir
:
Path
)
->
Dict
[
str
,
TensorReformulationMetadata
]:
"""Reads MCore data for N-D flattened tensors from checkpoint metadata during ckpt load.
Args:
sharded_state_dict (ShardedStateDict): sharded state dict to load
checkpoint_dir (Path): checkpoint directory
Returns:
Dict[str, TensorReformulationMetadata] - dictionary that maps keys of every
N-D flattened tensor from the sharded_state_dict to its original global shape
as stored in `mcore_data` in the checkpoint.
"""
fs_reader
=
_get_filesystem_reader
(
checkpoint_dir
)
ckpt_metadata
=
fs_reader
.
read_metadata
()
reformulation_metadata
=
{}
for
sh_ten
in
nested_values
(
sharded_state_dict
):
if
not
is_nd_flattened_tensor
(
sh_ten
):
continue
try
:
ckpt_global_shape
=
ckpt_metadata
.
mcore_data
[
sh_ten
.
key
][
'nd_reformulated_orig_global_shape'
]
except
KeyError
as
e
:
if
len
(
sh_ten
.
global_shape
)
==
1
:
warnings
.
warn
(
f
'Legacy checkpoint format detected for 1-D flattened tensor
{
sh_ten
}
. '
'Skip metadata reformulation.'
)
continue
raise
CheckpointingException
(
f
'Cannot find global shape metadata for N-D flattened tensor
{
sh_ten
}
'
f
'in checkpoint metadata:
{
ckpt_metadata
.
mcore_data
}
'
)
from
e
reformulation_metadata
[
sh_ten
.
key
]
=
TensorReformulationMetadata
(
ckpt_global_shape
,
ckpt_metadata
.
state_dict_metadata
[
sh_ten
.
key
].
size
)
return
reformulation_metadata
class
TorchDistLoadShardedStrategy
(
LoadShardedStrategy
):
"""Basic load strategy for the PyT Distributed format."""
def
__init__
(
self
):
self
.
cached_global_metadata
:
Optional
[
Metadata
]
=
None
super
().
__init__
()
def
load
(
self
,
sharded_state_dict
:
ShardedStateDict
,
checkpoint_dir
:
Path
)
->
StateDict
:
"""Translates MCore ShardedTensors to PyT ShardedTensors & loads from PyT Distributed fmt.
Args:
sharded_state_dict (ShardedStateDict): sharded state dict with mapping
information to instruct loading
checkpoint_dir (Path): checkpoint directory
Returns: loaded state dict
"""
# Apply N-D tensors resharding
reformulation_metadata
=
get_reformulation_metadata
(
sharded_state_dict
,
checkpoint_dir
)
sharded_state_dict
,
formulation_restore_data
=
apply_nd_flattened_tensors_reformulation
(
sharded_state_dict
,
reformulation_metadata
)
# Check if there are legacy 1-D flattened tensors in the checkpoint
has_legacy_1d_flattened_tensors
=
False
for
sh_ten
in
nested_values
(
sharded_state_dict
):
if
is_nd_flattened_tensor
(
sh_ten
)
and
sh_ten
.
key
not
in
reformulation_metadata
:
has_legacy_1d_flattened_tensors
=
True
break
flexible_shape_sharded_tensors
=
[
sh_ten
for
sh_ten
in
nested_values
(
sharded_state_dict
)
if
isinstance
(
sh_ten
,
ShardedTensor
)
and
not
sh_ten
.
allow_shape_mismatch
]
allow_shape_mismatch_sharded_tensors
=
{
sh_ten
.
key
:
sh_ten
for
sh_ten
in
nested_values
(
sharded_state_dict
)
if
isinstance
(
sh_ten
,
ShardedTensor
)
and
sh_ten
.
allow_shape_mismatch
}
orig_sharded_state_dict
=
sharded_state_dict
# MCore state dict to PyT Distributed compatible
(
sharded_state_dict
,
flat_mapping
,
rename_mapping
)
=
(
_replace_state_dict_keys_with_sharded_keys
(
sharded_state_dict
)
)
pyt_state_dict
=
mcore_to_pyt_state_dict
(
sharded_state_dict
,
True
,
load_legacy_1d_flatten_tensors
=
has_legacy_1d_flattened_tensors
)
# Load PyT Distributed format
fsr
=
_get_filesystem_reader
(
checkpoint_dir
,
cache_metadata
=
True
)
checkpoint
.
load_state_dict
(
pyt_state_dict
,
fsr
,
planner
=
MCoreLoadPlanner
(
shapes_validation_sharded_tensors
=
flexible_shape_sharded_tensors
,
allow_shape_mismatch_sharded_tensors
=
allow_shape_mismatch_sharded_tensors
,
),
)
self
.
cached_global_metadata
=
(
fsr
.
read_metadata
()
)
# no storage interaction thanks to caching
pyt_state_dict
=
cast
(
Dict
[
str
,
Union
[
TorchShardedTensor
,
List
[
io
.
BytesIO
]]],
pyt_state_dict
)
# Unwrap ShardedTensors and return to original state dict
mcore_state_dict
=
{
k
:
v
if
not
isinstance
(
v
,
TorchShardedTensor
)
else
_unwrap_pyt_sharded_tensor
(
v
)
for
k
,
v
in
pyt_state_dict
.
items
()
}
mcore_state_dict
=
_replace_sharded_keys_with_state_dict_keys
(
mcore_state_dict
,
flat_mapping
,
rename_mapping
# type: ignore[arg-type]
)
_restore_dict_types
(
mcore_state_dict
,
orig_sharded_state_dict
)
# Apply N-D tensors resharding postprocessing
mcore_state_dict
=
restore_nd_flattened_tensors_formulation
(
mcore_state_dict
,
formulation_restore_data
)
return
mcore_state_dict
def
load_tensors_metadata
(
self
,
checkpoint_dir
:
Path
,
metadata
:
Metadata
=
None
):
"""Uses tensors metadata stored in the metadata file."""
if
metadata
is
None
:
fs_reader
=
_get_filesystem_reader
(
checkpoint_dir
)
metadata
=
fs_reader
.
read_metadata
()
mcore_data
=
getattr
(
metadata
,
'mcore_data'
,
{})
sharded_metadata
=
{}
for
k
,
tp
in
metadata
.
state_dict_metadata
.
items
():
if
not
isinstance
(
tp
,
TensorStorageMetadata
):
continue
# load only tensors
nd_orig_global_shape
=
mcore_data
.
get
(
k
,
{}).
get
(
'nd_reformulated_orig_global_shape'
)
if
nd_orig_global_shape
is
None
:
# Regular tensor
sharded_metadata
[
k
]
=
ShardedTensor
.
from_rank_offsets
(
k
,
torch
.
empty
(
tp
.
size
,
**
tp
.
properties
.
__dict__
,
device
=
'meta'
)
).
without_data
()
else
:
# N-D flattened tensor
unflat_ten
=
torch
.
empty
(
nd_orig_global_shape
,
**
tp
.
properties
.
__dict__
,
device
=
'meta'
)
flat_ten
=
unflat_ten
.
flatten
()
sharded_metadata
[
k
]
=
ShardedTensor
.
from_rank_offsets_flat
(
k
,
flat_ten
,
unflat_ten
.
shape
,
flattened_range
=
slice
(
0
,
unflat_ten
.
numel
()),
# whole slice
).
without_data
()
return
sharded_metadata
def
load_sharded_metadata
(
self
,
checkpoint_dir
:
Path
)
->
ShardedStateDict
:
"""Uses tensors and objects metadata stored in the metadata file."""
fs_reader
=
_get_filesystem_reader
(
checkpoint_dir
)
metadata
=
fs_reader
.
read_metadata
()
sharded_metadata
=
{}
for
metadata_key
,
storage_metadata
in
metadata
.
state_dict_metadata
.
items
():
if
not
isinstance
(
storage_metadata
,
BytesStorageMetadata
):
continue
sh_obj
=
ShardedObject
.
empty_from_unique_key
(
metadata_key
)
sharded_metadata
[
sh_obj
.
unique_key
]
=
sh_obj
sharded_metadata
.
update
(
self
.
load_tensors_metadata
(
checkpoint_dir
,
metadata
))
return
sharded_metadata
def
remove_sharded_tensors
(
self
,
checkpoint_dir
:
str
,
key_prefix
:
str
):
"""Removes checkpoint files whose keys have the given prefix.
Performs the following steps:
1. checks whether there are files that start with the key_prefix
2. loads metadata
3. removes all entries from the metadata that start with the key_prefix
4. resaves the new metadata and removes the old metadata
5. removes the relevant files
"""
assert
is_torch_min_version
(
"2.3.0"
),
f
'torch >= 2.3.0 is required for remove_sharded_tensors'
distckpt_files
=
[
f
for
f
in
os
.
listdir
(
checkpoint_dir
)
if
f
.
endswith
(
"distcp"
)]
files_to_remove
=
[
f
for
f
in
distckpt_files
if
f
.
startswith
(
key_prefix
)]
if
not
files_to_remove
:
warnings
.
warn
(
f
'There are no files in
{
checkpoint_dir
}
that begin with "
{
key_prefix
}
".'
f
' Skipping removal.'
)
return
fs_reader
=
FileSystemReader
(
checkpoint_dir
)
original_metadata
=
fs_reader
.
read_metadata
()
new_state_dict_metadata
=
{}
new_planner_data
=
{}
new_storage_data
=
{}
for
k
in
original_metadata
.
state_dict_metadata
.
keys
():
if
k
.
startswith
(
key_prefix
):
continue
new_state_dict_metadata
[
k
]
=
original_metadata
.
state_dict_metadata
[
k
]
original_planner_data
=
original_metadata
.
planner_data
if
original_planner_data
is
not
None
:
for
k
in
original_planner_data
.
keys
():
if
k
.
startswith
(
key_prefix
):
continue
new_planner_data
[
k
]
=
original_metadata
.
planner_data
[
k
]
original_storage_data
=
original_metadata
.
storage_data
if
original_storage_data
is
not
None
:
for
k
in
original_storage_data
.
keys
():
if
k
.
fqn
.
startswith
(
key_prefix
):
continue
new_storage_data
[
k
]
=
original_metadata
.
storage_data
[
k
]
metadata
=
Metadata
(
state_dict_metadata
=
new_state_dict_metadata
,
planner_data
=
new_planner_data
,
storage_data
=
new_storage_data
,
)
fs_writer
=
FileSystemWriter
(
checkpoint_dir
)
metadata_filename
=
cast
(
Path
,
fs_writer
.
fs
.
concat_path
(
fs_writer
.
path
,
_metadata_fn
))
tmp_path
=
cast
(
metadata_filename
,
# type: ignore[valid-type]
fs_writer
.
fs
.
concat_path
(
fs_writer
.
path
,
f
"
{
_metadata_fn
}
.tmp"
),
)
old_path
=
cast
(
metadata_filename
,
# type: ignore[valid-type]
fs_writer
.
fs
.
concat_path
(
fs_writer
.
path
,
f
"
{
_metadata_fn
}
.bck"
),
)
## save the new metadata
with
fs_writer
.
fs
.
create_stream
(
tmp_path
,
"wb"
)
as
metadata_file
:
pickle
.
dump
(
metadata
,
metadata_file
)
try
:
os
.
fsync
(
metadata_file
.
fileno
())
except
AttributeError
:
os
.
sync
()
## move the old metadata
fs_writer
.
fs
.
rename
(
fs_writer
.
metadata_path
,
old_path
)
try
:
## rename the new metadata
fs_writer
.
fs
.
rename
(
tmp_path
,
fs_writer
.
metadata_path
)
## finally, remove the files we want to drop
for
f
in
files_to_remove
:
fs_writer
.
fs
.
rm_file
(
checkpoint_dir
/
f
)
except
Exception
as
e
:
fs_writer
.
fs
.
rename
(
old_path
,
fs_writer
.
metadata_path
)
raise
e
else
:
fs_writer
.
fs
.
rm_file
(
old_path
)
def
can_handle_sharded_objects
(
self
):
return
True
def
check_backend_compatibility
(
self
,
loaded_version
):
pass
# TODO
def
check_version_compatibility
(
self
,
loaded_version
):
pass
# TODO
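
The save and load strategies above are normally driven by the generic dist_checkpointing entrypoints, but they can also be exercised directly. A minimal sketch, assuming torch.distributed is already initialized, a sharded state dict has been built for the model, and the `AsyncRequest.execute_sync` helper from `async_utils` is used to flush the save synchronously; the `'torch_dist'` backend name and version `1` are assumed values here, and production code would schedule the request through the async queue instead.

from pathlib import Path

from megatron.core.dist_checkpointing.strategies.torch import (
    TorchDistLoadShardedStrategy,
    TorchDistSaveShardedStrategy,
)


def save_and_reload(sharded_state_dict, checkpoint_dir: Path):
    # Translate MCore ShardedTensors to PyT ShardedTensors and write the checkpoint.
    save_strategy = TorchDistSaveShardedStrategy('torch_dist', 1, thread_count=2)
    async_request = save_strategy.async_save(sharded_state_dict, checkpoint_dir)
    async_request.execute_sync()  # assumed helper; normally handled by the async calls queue

    # Read the checkpoint back into the sharded state dict structure.
    load_strategy = TorchDistLoadShardedStrategy()
    return load_strategy.load(sharded_state_dict, checkpoint_dir)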
Megatron-LM/megatron/core/dist_checkpointing/strategies/two_stage.py
0 → 100644
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.

""" 2-stage checkpoint loading. """
import time
from collections import defaultdict
from dataclasses import dataclass
from functools import partial, wraps
from itertools import chain
from logging import getLogger
from operator import attrgetter, itemgetter
from pathlib import Path
from typing import List, Optional, Tuple, Union

import torch

from ..dict_utils import dict_list_map_inplace, map_reduce, nested_values
from ..mapping import ShardedStateDict, ShardedTensor
from .base import LoadShardedStrategy
from .tensorstore import _load_from_array, open_ts_array
from .zarr import flatten_range, load_zarr_based_sharded_metadata

_import_trigger = None

timers = defaultdict(list)

logger = getLogger(__name__)

logger.warning(
    'megatron.core.dist_checkpointing.two_stage module is deprecated'
    ' and will be removed in Megatron-Core v0.12. Please use'
    ' FullyParallelLoadStrategyWrapper to accomplish a parallelized checkpoint load.'
)


def timed(verbose=True):
    """Timing decorator."""

    def timed_dec(fn):
        name = fn.__name__

        @wraps(fn)
        def wrapped(*args, **kwargs):
            if verbose:
                logger.debug(f'{name} init')
            start = time.time()
            ret = fn(*args, **kwargs)
            took = time.time() - start
            if verbose:
                logger.debug(f'{name} took {took}s')
            timers[name].append(took)
            return ret

        return wrapped

    return timed_dec


@dataclass
class _ShardedTensorMetadata:
    global_rank: int
    sharded_tensor_no_data: ShardedTensor
    dist_group_rank: Tuple[int]  # id of distributed group
    dist_group_ranks: Tuple[int]  # id of distributed group
    data_size: Optional[int] = None  # bytes


def sharded_tensor_chunk_id(sharded_tensor: ShardedTensor):
    """Id of a sharded tensor."""
    return (sharded_tensor.key, sharded_tensor.global_offset)


class TwoStageDataParallelLoadShardedStrategy(LoadShardedStrategy):
    """Loads one checkpoint replica from storage and broadcasts to other nodes.

    This strategy loads checkpoint from storage on minimal set of nodes
    and distributes the checkpoint to other nodes with torch.distributed.
    Loading is performed with tensorstore.

    Steps:
    0. (optional) create Gloo distributed groups
    1. Exchange ShardedTensors metadata between all nodes
    2. Align needed tensors within DP groups
    3. For each globally unique tensor:
        3.a) on one of the ranks load it from storage to CPU and move to CUDA
        3.b) allocate CUDA tensor on other ranks
        3.c) broadcast within DP group
        3.d) copy tensor content to the model param location
        3.e) free tensor buffers from a) and b)

    Notes:
    1. Loading and broadcasting is done sequentially to avoid both host and device OOMs
    2. There is a lot of overlap potential between all three steps done for each tensor:
        2.a) loading from storage to numpy
        2.b) moving CPU tensors to CUDA
        2.c) broadcast
    """

    def __init__(self, data_parallel_group, cpu_transfer=True):
        super().__init__()

        self.cpu_transfer = cpu_transfer
        self.data_parallel_group_orig = data_parallel_group
        self.data_parallel_group = None if cpu_transfer else data_parallel_group
        self.dp_group_ranks = tuple(
            sorted(torch.distributed.get_process_group_ranks(data_parallel_group))
        )
        self.dp_group_rank = self.data_parallel_group_orig.rank()
        self.global_rank = torch.distributed.get_rank()

    def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
        """Main load method."""
        self.maybe_init_gloo_group()
        all_tensors_sorted = self._build_load_plan(sharded_state_dict)
        self._exchange_loaded_tensors(all_tensors_sorted, sharded_state_dict, checkpoint_dir)
        # TODO: fix hang in summarize_load_times
        # self.summarize_load_times()
        return sharded_state_dict

    def summarize_load_times(self):
        """Summarize load times."""
        torch.distributed.barrier()
        logger.info('Checkpoint loading finished. Summary:')
        # TODO: `timers` keys are not guaranteed to be the same across ranks which causes hangs
        for key, times in sorted(timers.items()):
            times_sum = sum(times)
            max_times = torch.tensor([times_sum], device='cuda')
            avg_times = torch.tensor([times_sum], device='cuda')
            torch.distributed.all_reduce(max_times, op=torch.distributed.ReduceOp.MAX)
            torch.distributed.all_reduce(avg_times, op=torch.distributed.ReduceOp.SUM)
            avg_times /= torch.distributed.get_world_size()
            if torch.distributed.get_rank() == 0:
                logger.info(f'{key}: max {max_times[0]}, avg {avg_times[0]}')

    @timed(verbose=False)
    def load_tensor_from_storage(self, checkpoint_dir, ten_meta: _ShardedTensorMetadata):
        """Load tensor from storage."""
        logger.debug(f'_load_from_array({ten_meta.sharded_tensor_no_data.key}) init')
        ret = _load_from_array(
            ten_meta.sharded_tensor_no_data,
            checkpoint_dir,
            load_directly_on_device=False,
            apply_flattened_range=False,
        )
        logger.debug(f'_load_from_array({ten_meta.sharded_tensor_no_data.key}) DONE')
        return ret

    @timed()
    def maybe_init_gloo_group(self):
        """Create Gloo groups."""
        if not self.cpu_transfer:
            return
        all_groups = [None] * torch.distributed.get_world_size()
        torch.distributed.all_gather_object(all_groups, self.dp_group_ranks)
        all_groups = set(tuple(sorted(gr)) for gr in all_groups)
        for group_ranks in sorted(all_groups):
            # "two_stage" module will be deprecated, so not replace new_group()
            # with ...parallel_state.create_group() func setting group_desc here.
            gloo_pg = torch.distributed.new_group(ranks=group_ranks, backend='gloo')
            if self.global_rank in group_ranks:
                self.data_parallel_group = gloo_pg
                assert self.dp_group_rank == self.data_parallel_group.rank()

    def check_backend_compatibility(self, loaded_version):
        pass  # TODO

    def check_version_compatibility(self, loaded_version):
        pass  # TODO

    @timed()
    def _build_load_plan(
        self, sharded_state_dict: ShardedStateDict
    ) -> List[_ShardedTensorMetadata]:
        local_meta = [
            _ShardedTensorMetadata(
                self.global_rank,
                sharded_ten.without_data(),
                self.dp_group_rank,
                self.dp_group_ranks,
            )
            for sharded_ten in nested_values(sharded_state_dict)
        ]
        all_meta = [None] * self.data_parallel_group.size()
        torch.distributed.all_gather_object(all_meta, local_meta, group=self.data_parallel_group)
        all_meta = list(chain.from_iterable(all_meta))
        all_tensors_sorted = self.deduplicate_chunks(all_meta)
        return all_tensors_sorted

    @timed()
    def deduplicate_chunks(self, ten_metas: List[_ShardedTensorMetadata]):
        """Group tensors by chunk and then pick the tensor with the lowest rank.

        NOTE: with proper loading overlap, loading from randomized ranks
        (instead of the smallest one) could be beneficial here.
        """
        ten_metas = map_reduce(
            ten_metas,
            key_fn=lambda meta: sharded_tensor_chunk_id(meta.sharded_tensor_no_data),
            reduce_fn=partial(min, key=attrgetter('dist_group_rank')),
        )
        all_metas_sorted = list(map(itemgetter(1), sorted(ten_metas.items())))
        return all_metas_sorted

    @timed()
    def _exchange_loaded_tensors(
        self, ten_metas: List[_ShardedTensorMetadata], sharded_state_dict, checkpoint_dir
    ):
        logger.debug(f'_exchange_loaded_tensors, num ten_metas: {len(ten_metas)}')
        for ten_meta in ten_metas:

            src_rank = torch.distributed.get_global_rank(
                self.data_parallel_group, ten_meta.dist_group_rank
            )

            if self.dp_group_rank == ten_meta.dist_group_rank:
                exchange_tensor = self.load_tensor_from_storage(checkpoint_dir, ten_meta)
                if not self.cpu_transfer:
                    exchange_tensor = exchange_tensor.cuda()
            else:
                # TODO: for non-flattened ranges we could reuse the buffer from the start here
                exchange_tensor = torch.empty(
                    ten_meta.sharded_tensor_no_data.local_shape,
                    device='cpu' if self.cpu_transfer else 'cuda',
                    dtype=ten_meta.sharded_tensor_no_data.dtype,
                )

            logger.debug(
                f'exchange {ten_meta.sharded_tensor_no_data.key}, {exchange_tensor.shape}'
                f'({exchange_tensor.numel()}), broadcast({src_rank} -> {self.dp_group_ranks})'
            )
            torch.distributed.broadcast(
                exchange_tensor, group=self.data_parallel_group, src=src_rank
            )
            self._distribute_data_to_state_dict(ten_meta, exchange_tensor, sharded_state_dict)
            logger.debug(f'exchange {ten_meta.sharded_tensor_no_data.key} done')

            # free buffer memory
            exchange_tensor = None

    @timed(verbose=False)
    def _distribute_data_to_state_dict(
        self,
        ten_meta: _ShardedTensorMetadata,
        loaded_ten: torch.Tensor,
        sharded_state_dict: ShardedStateDict,
    ):
        tensor_key = sharded_tensor_chunk_id(ten_meta.sharded_tensor_no_data)

        def _fill_in_data(t: Union[ShardedTensor, torch.Tensor]):
            if not isinstance(t, ShardedTensor) or sharded_tensor_chunk_id(t) != tensor_key:
                # already filled-in or key not matching
                return t
            sharded_tensor: ShardedTensor = t
            x = loaded_ten
            if sharded_tensor.flattened_range is not None:
                x = flatten_range(sharded_tensor, x)

            # Reuse existing buffer
            sharded_tensor.data.data.copy_(x)
            return sharded_tensor.data

        dict_list_map_inplace(_fill_in_data, sharded_state_dict)

    def load_tensors_metadata(self, checkpoint_dir: Path):
        def get_ts_shape_dtype(path):
            arr = open_ts_array(path)
            return arr.shape, arr.dtype.numpy_dtype

        return load_zarr_based_sharded_metadata(checkpoint_dir, get_ts_shape_dtype)
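
For the deprecated two-stage loader, the entry point is simply the strategy's `load` method; the docstring's steps 0-3 all happen inside it. A minimal sketch, assuming torch.distributed is initialized and a data-parallel process group is available (for example from megatron.core.parallel_state, which is an assumption here):

from pathlib import Path

from megatron.core.dist_checkpointing.strategies.two_stage import (
    TwoStageDataParallelLoadShardedStrategy,
)


def two_stage_load(sharded_state_dict, checkpoint_dir: Path, data_parallel_group):
    # cpu_transfer=True broadcasts over freshly created Gloo groups, so only one rank
    # per DP group reads each tensor from storage and the rest receive it via broadcast.
    strategy = TwoStageDataParallelLoadShardedStrategy(data_parallel_group, cpu_transfer=True)
    return strategy.load(sharded_state_dict, checkpoint_dir)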
Megatron-LM/megatron/core/dist_checkpointing/strategies/zarr.py
0 → 100644
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.

"""Strategies using Zarr as an underlying format."""
import logging
import os
from functools import partial
from logging import getLogger
from pathlib import Path
from typing import Callable, List, Optional, Tuple, Union

import numpy as np
import torch

from ..core import CheckpointingException
from ..dict_utils import dict_list_map_inplace, nested_values
from ..mapping import ShardedStateDict, ShardedTensor, is_main_replica
from .base import (
    LoadShardedStrategy,
    SaveShardedStrategy,
    StrategyAction,
    register_default_strategy,
)

logger = logging.getLogger(__name__)

try:
    import zarr

    HAVE_ZARR = True
except ImportError:
    from unittest.mock import MagicMock

    zarr = MagicMock()
    HAVE_ZARR = False

numpy_to_torch_dtype_dict = {
    np.dtype("bool"): torch.bool,
    np.dtype("uint8"): torch.uint8,
    np.dtype("int8"): torch.int8,
    np.dtype("int16"): torch.int16,
    np.dtype("int32"): torch.int32,
    np.dtype("int64"): torch.int64,
    np.dtype("float16"): torch.float16,
    np.dtype("float32"): torch.float32,
    np.dtype("float64"): torch.float64,
    np.dtype("complex64"): torch.complex64,
    np.dtype("complex128"): torch.complex128,
}

torch_to_numpy_dtype_dict = {v: k for k, v in numpy_to_torch_dtype_dict.items()}

try:
    # Register a bfloat16 type with this import
    import tensorstore  # pylint: disable=unused-import

    HAS_BFLOAT16 = True
    numpy_to_torch_dtype_dict[np.dtype("bfloat16")] = torch.bfloat16
    torch_to_numpy_dtype_dict[torch.bfloat16] = np.dtype("bfloat16")
except ImportError:
    HAS_BFLOAT16 = False

logger = getLogger(__name__)


def register_default_zarr_strategies():
    """Register default strategies related to Zarr backend."""
    register_default_strategy(
        StrategyAction.SAVE_SHARDED, "zarr", 1, ZarrSaveShardedStrategy("zarr", 1)
    )


class ZarrSaveShardedStrategy(SaveShardedStrategy):
    """Save strategy for Zarr backend."""

    def __init__(self, backend: str, version: int):
        super().__init__(backend, version)
        logger.warning(
            f"`zarr` distributed checkpoint backend is deprecated."
            " Please switch to PyTorch Distributed format (`torch_dist`)."
        )

    def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Union[str, Path]):
        if isinstance(checkpoint_dir, str):
            checkpoint_dir = Path(checkpoint_dir)
        sharded_tensors = list(nested_values(sharded_state_dict))
        arrays = _create_or_open_zarr_arrays(sharded_tensors, checkpoint_dir)
        for ten, arr in zip(sharded_tensors, arrays):
            _save_to_existing_array(ten, arr)
        torch.distributed.barrier()


def _create_or_open_zarr_arrays(
    sharded_tensors: List[ShardedTensor], checkpoint_dir: Path
) -> List[Optional[zarr.Array]]:
    """Returns list of zarr arrays corresponding to given tensors.

    For a sharded tensors that:
    a) is main replica and represents the first chunk (all offsets 0), creates the Zarr array
    b) is main replica but not the first chunk,
       opens the arrays created in (a) (possibly by other process)
    c) otherwise, sets the corresponding array to None since it won't be used

    Args:
        sharded_tensors (List[ShardedTensor]): sharded tensors from a given rank
            that will be saved to checkpoint
        checkpoint_dir (Path): checkpoint in which the arrays will be created
    """
    if not HAVE_ZARR:
        raise RuntimeError("zarr is required, please install it with `pip install zarr`")

    arrays = []
    for ten in sharded_tensors:
        arr = _create_zarr_array(ten, checkpoint_dir) if _should_create_array(ten) else None
        arrays.append(arr)

    torch.distributed.barrier()
    # Open arrays created above by other processes
    for arr_idx, ten in enumerate(sharded_tensors):
        if arrays[arr_idx] is not None:
            # array created by this process
            assert _should_create_array(ten), ten
            continue
        if not is_main_replica(ten.replica_id):
            # this array won't be needed for saving and can stay None
            continue
        open_kwargs = {}
        if ten.flattened_range is not None:
            open_kwargs["synchronizer"] = zarr.ProcessSynchronizer(
                str(checkpoint_dir / f"{ten.key}.sync")
            )
        arrays[arr_idx] = _open_zarr_array_verbose(checkpoint_dir / ten.key, "r+", **open_kwargs)
    return arrays


def _should_create_array(ten: ShardedTensor):
    return (
        is_main_replica(ten.replica_id)
        and set(ten.global_offset) == {0}
        and (ten.flattened_range is None or ten.flattened_range.start == 0)
    )


def _save_to_existing_array(sharded_tensor: ShardedTensor, arr: Optional[zarr.Array]):
    if not is_main_replica(sharded_tensor.replica_id):
        return
    assert arr is not None
    x = sharded_tensor.data
    x = x.detach().cpu()
    torch.cuda.synchronize()
    if x.dtype == torch.bfloat16:
        x = x.float()
        x = x.numpy()
        x = x.astype("bfloat16")
    else:
        x = x.numpy()

    if sharded_tensor.flattened_range is None:
        arr[sharded_tensor.global_slice()] = x
    else:
        arr.set_coordinate_selection(sharded_tensor.global_coordinates(), x)


def _create_zarr_array(sharded_tensor: ShardedTensor, checkpoint_dir: Path):
    np_dtype = torch_to_numpy_dtype_dict[sharded_tensor.dtype]
    try:
        arr = zarr.create(
            sharded_tensor.global_shape,
            dtype=np_dtype,
            store=checkpoint_dir / sharded_tensor.key,
            chunks=sharded_tensor.max_allowed_chunks(),
            compressor=None,
            fill_value=None,
            write_empty_chunks=True,
        )
        logger.debug(f"Created a new Zarr array at {checkpoint_dir / sharded_tensor.key}")
    except zarr.errors.ContainsArrayError as e:
        raise CheckpointingException(
            f"Array {checkpoint_dir / sharded_tensor.key} already exists"
        ) from e

    if HAS_BFLOAT16 and np_dtype == np.dtype("bfloat16"):
        arr._dtype = np_dtype
        zarray = arr.store[".zarray"]
        arr.store[".zarray"] = zarray.replace(b"<V2", b"bfloat16")
    return arr


class ZarrLoadShardedStrategy(LoadShardedStrategy):
    """Load strategy for the Zarr backend."""

    def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Union[str, Path]):
        if isinstance(checkpoint_dir, str):
            checkpoint_dir = Path(checkpoint_dir)
        dict_list_map_inplace(
            partial(_load_from_array, checkpoint_dir=checkpoint_dir), sharded_state_dict
        )
        return sharded_state_dict

    def load_tensors_metadata(self, checkpoint_dir: Union[str, Path]):
        def get_zarr_shape_dtype(path):
            arr = zarr.open(path, "r")
            return arr.shape, arr.dtype

        if isinstance(checkpoint_dir, str):
            checkpoint_dir = Path(checkpoint_dir)
        return load_zarr_based_sharded_metadata(checkpoint_dir, get_zarr_shape_dtype)

    def check_backend_compatibility(self, loaded_version):
        pass  # TODO

    def check_version_compatibility(self, loaded_version):
        pass  # TODO


def _load_from_array(sharded_tensor: ShardedTensor, checkpoint_dir: Path):
    assert isinstance(sharded_tensor, ShardedTensor), type(sharded_tensor)
    arr = _open_zarr_array_verbose(checkpoint_dir / sharded_tensor.key, "r")

    if not sharded_tensor.allow_shape_mismatch and sharded_tensor.global_shape != arr.shape:
        _msg = (
            f"Global shape mismatch for loaded ({arr.shape})"
            f" and expected ({sharded_tensor.global_shape}) tensor"
            f" for key {sharded_tensor.key}"
        )
        raise CheckpointingException(_msg)

    x = arr[sharded_tensor.global_slice()]  # flattened tensors loading is delayed
    return postprocess_numpy_array(x, sharded_tensor)


def _open_zarr_array_verbose(path: Path, mode: str, **open_kwargs):
    try:
        return zarr.open(str(path), mode, **open_kwargs)
    except zarr.errors.PathNotFoundError as e:
        ckpt_dir = path.parent
        err_msg = f"Array {path} not found"
        if ckpt_dir.exists():
            ckpt_files = [f.name for f in ckpt_dir.iterdir()]
            logger.debug(f"{err_msg}. Checkpoint directory {ckpt_dir} content: {ckpt_files}")
        else:
            err_msg += f". Checkpoint directory {ckpt_dir} does not exist."
        raise CheckpointingException(err_msg) from e


def postprocess_numpy_array(loaded_array, sharded_tensor, apply_flattened_range=True):
    """Turn numpy array to torch tensor."""
    x = loaded_array
    if HAS_BFLOAT16 and x.dtype == np.dtype("bfloat16"):
        x = x.astype(np.dtype("float32"))
        x = torch.from_numpy(x)
        x = x.bfloat16()
    else:
        x = torch.from_numpy(x)
    # TODO: consider some other consistency checks
    if x.shape != sharded_tensor.local_shape:
        if sharded_tensor.allow_shape_mismatch:
            x = pad_to_expected_shape(x, sharded_tensor)
        else:
            _msg = (
                f"Local shape mismatch for loaded ({x.shape})"
                f" and expected ({sharded_tensor.local_shape}) tensor"
                f" for key {sharded_tensor.key}"
            )
            raise CheckpointingException(_msg)

    if apply_flattened_range and sharded_tensor.flattened_range is not None:
        x = flatten_range(sharded_tensor, x)

    # TODO: consider cuda() tensors support
    return x


def flatten_range(sharded_tensor, x):
    """Apply flattened range to a tensor."""
    return x.flatten()[sharded_tensor.flattened_range]


def pad_to_expected_shape(x: torch.Tensor, expected_sharded_ten: ShardedTensor):
    """Pad tensor to the expected shape."""
    pad_args = []
    assert len(x.shape) == len(expected_sharded_ten.local_shape)
    # Reversed iteration order because F.pad expects so
    for x_sh, exp_sh, axis_fragm in reversed(
        list(
            zip(
                x.shape, expected_sharded_ten.local_shape, expected_sharded_ten.axis_fragmentations
            )
        )
    ):
        if x_sh == exp_sh:
            pad_args.extend((0, 0))
        elif x_sh > exp_sh:
            assert False, (
                f"Expected shape ({exp_sh}) smaller than actual ({x_sh})"
                f" for {repr(expected_sharded_ten)}"
            )
        else:
            pad_args.extend((0, exp_sh - x_sh))
    # TODO: behavior control with envvar is for testing purposes only, remove it
    if not int(os.environ.get("DIST_CKPT_PAD_REPLICATE", 0)):
        return torch.nn.functional.pad(x, pad_args)

    # unsqueeze and squeeze to get shapes supported by cudnn
    print(f"Replicating last row for {expected_sharded_ten.key}")
    if x.dtype == torch.bfloat16:
        return (
            torch.nn.functional.pad(x.float().unsqueeze(0), pad_args, mode="replicate")
            .squeeze(0)
            .bfloat16()
        )
    return torch.nn.functional.pad(x.unsqueeze(0), pad_args, mode="replicate").squeeze(0)


def load_zarr_based_sharded_metadata(
    checkpoint_dir: Path, get_shape_dtype_fn: Callable[[str], Tuple[Tuple[int], np.dtype]]
) -> ShardedStateDict:
    """Load metadata of Zarr arrays.

    Args:
        checkpoint_dir (str): checkpoint root directory
        get_shape_dtype_fn (str -> ((int, ...), np.dtype)): a function returning
            an array shape and dtype for a given Zarr array path
    """
    sharded_state_dict = {}
    for subdir in checkpoint_dir.iterdir():
        if not subdir.is_dir() or not (subdir / ".zarray").exists():
            continue
        key = subdir.name
        arr_shape, arr_dtype = get_shape_dtype_fn(str(subdir))

        sharded_state_dict[key] = ShardedTensor(
            key,
            None,
            numpy_to_torch_dtype_dict[arr_dtype],
            arr_shape,
            arr_shape,
            tuple(0 for _ in arr_shape),
            tuple(1 for _ in arr_shape),
        )
    return sharded_state_dict
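
A minimal round-trip sketch for the (deprecated) Zarr backend, assuming `zarr` is installed, torch.distributed is initialized, and the same sharded state dict layout is used for saving and loading; the `'zarr'` backend name and version `1` mirror the registration above.

from pathlib import Path

from megatron.core.dist_checkpointing.strategies.zarr import (
    ZarrLoadShardedStrategy,
    ZarrSaveShardedStrategy,
)


def zarr_round_trip(sharded_state_dict, checkpoint_dir: Path):
    # One Zarr array per ShardedTensor key; main replicas write, any rank can read back.
    ZarrSaveShardedStrategy('zarr', 1).save(sharded_state_dict, checkpoint_dir)
    return ZarrLoadShardedStrategy().load(sharded_state_dict, checkpoint_dir)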
Megatron-LM/megatron/core/dist_checkpointing/tensor_aware_state_dict.py
0 → 100644
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.

"""Utilities for transforming state_dict, including a tensor-aware implementation."""
import logging
from dataclasses import dataclass
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple

import torch

from .dict_utils import dict_list_map_inplace, dict_list_map_outplace, merge, nested_values
from .exchange_utils import (
    ShardDistribution,
    determine_main_replica_uniform_distribution,
    exchange_by_distribution,
)
from .mapping import ShardedObject, ShardedStateDict, ShardedTensor, StateDict, apply_factory_merges
from .state_dict_utils import load_preprocess, save_preprocess
from .utils import (
    _sharded_object_id,
    _sharded_tensor_shard_id,
    debug_time,
    extract_sharded_base,
    zip_strict,
)
from .validation import (
    StrictHandling,
    determine_global_metadata,
    parse_strict_flag,
    validate_integrity_and_strict_load,
)

logger = logging.getLogger(__name__)

try:
    from nvidia_resiliency_ext.checkpointing.local.base_state_dict import TensorAwareStateDict

    HAVE_NVRX = True
except ImportError:
    import types

    # Create a dummy class that mimics the real one
    TensorAwareStateDict = types.new_class("TensorAwareStateDict", ())
    HAVE_NVRX = False


@dataclass
class MCoreTensorAwareStateDict(TensorAwareStateDict):
    """
    MCore-specific class defining the interface between the MCore state dict and checkpoint manager.

    This class distinguishes between raw objects, the common state dict, and sharded state dicts
    (tensor parts). It also handles optional metadata needed for fully parallel save/load.
    """

    common: StateDict
    sharded_state_dict: ShardedStateDict
    _is_hollow: bool = False

    @staticmethod
    def _validate_params(algo):
        if algo != "atomic" and algo != "fully_parallel":
            raise NotImplementedError(
                'Only "atomic" and "fully_parallel" sharding algorithms are supported.'
            )

    @staticmethod
    def _get_distribution(
        fully_parallel, sharded_part, parallelization_group, cached_distribution=None
    ):
        if fully_parallel:
            if cached_distribution is None:
                distribution = determine_main_replica_uniform_distribution(
                    sharded_part, parallelization_group, True
                )
                logger.debug(f"MCore_TASD._get_distribution calculated distribution")
            else:
                distribution = cached_distribution
                logger.debug(f"MCore_TASD._get_distribution used cache")
        else:
            distribution = (None, None, None, None)
            logger.debug(f"MCore_TASD._get_distribution returned empty distribution")
        return distribution

    @staticmethod
    def _remove_redundant_data(
        fully_parallel, sharded_part, shard_to_saving_rank, parallelization_group
    ):
        if parallelization_group is None:
            parallelization_group = torch.distributed.group.WORLD
        if fully_parallel:
            for sh_base in nested_values(sharded_part):
                # TODO remove redundant objects as well
                if isinstance(sh_base, ShardedTensor):
                    shard_id = _sharded_tensor_shard_id(sh_base)
                    if shard_to_saving_rank[shard_id] != parallelization_group.rank():
                        sh_base.data = None

    @classmethod
    @debug_time("from_state_dict", logger)
    def from_state_dict(
        cls,
        sharded_state_dict: ShardedStateDict,
        algo: str = "fully_parallel",
        parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
        cached_metadata: ShardDistribution = None,
    ) -> Tuple[TensorAwareStateDict, ShardDistribution]:
        """
        Constructs a TensorAwareStateDict from a sharded state dictionary.

        This method preprocesses the input `sharded_state_dict`, validates parameters,
        and extracts the necessary data to create an instance of `MCoreTensorAwareStateDict`.

        Args:
            sharded_state_dict: The input sharded state dictionary to be converted.
            algo (str, optional): Initialization algorithm. Defaults to 'fully_parallel'.
                - 'fully_parallel' enables fully parallel initialization.
            parallelization_group (Optional): A distributed process group for parallelization.
            cached_metadata (Optional): Precomputed metadata from previous saves.
                - Reuses data that doesn't need recalculation, optimizing the creation process.

        Returns:
            TensorAwareStateDict: An instance initialized with the provided sharded state dictionary
            and optional cached metadata.
                - The metadata is stored in memory to speed up future saves.
        """
        if not HAVE_NVRX:
            raise ImportError(
                "nvidia_resiliency_ext is not installed. "
                "Please install it with "
                "`pip install nvidia-resiliency-ext`"
            )
        with debug_time("_get_distribution", logger):
            cls._validate_params(algo)
            fully_parallel = algo == "fully_parallel"
            sharded_part, common_state_dict = save_preprocess(
                sharded_state_dict, cached_metadata is None
            )
            cacheable_distribution = cls._get_distribution(
                fully_parallel, sharded_part, parallelization_group, cached_metadata
            )
            if cacheable_distribution is not None:
                shard_to_saving_rank, _, _, _ = cacheable_distribution
                cls._remove_redundant_data(
                    fully_parallel, sharded_part, shard_to_saving_rank, parallelization_group
                )
        return (
            MCoreTensorAwareStateDict(common=common_state_dict, sharded_state_dict=sharded_part),
            cacheable_distribution,
        )

    @property
    def is_hollow(self):
        """
        True iff tensors had been extracted and have not been inserted back yet.
        """
        return self._is_hollow

    @property
    def _sharded_tensors(self):
        # Three possible states for sharded_tensor:
        # 1. sharded_tensor with data (.data = tensor)
        # 2. sharded_tensor hollow (.data = None, .orig_device = orig_device)
        # 3. removed sharded_tensor (.data = None, no device information)
        # TODO: Consider simplifying by removing the entire sharded_tensor instead of just the data
        if self.is_hollow:
            for sh_base in nested_values(self.sharded_state_dict):
                # FIXME: Hacky way to store the original device of the popped tensor
                if isinstance(sh_base, ShardedTensor) and hasattr(sh_base, "orig_device"):
                    yield sh_base
        else:
            for sh_base in nested_values(self.sharded_state_dict):
                if isinstance(sh_base, ShardedTensor) and sh_base.data is not None:
                    yield sh_base

    @property
    def tensors(self) -> Iterator[torch.Tensor]:
        """
        Get the tensor data from the state dict.
        """
        assert not self.is_hollow  # TODO raise exception
        return map(lambda sh_ten: sh_ten.data, self._sharded_tensors)

    @property
    def common_state_dict(self) -> Dict:
        """
        Get the common state dict from the state dict.
        """
        return self.common

    def pop_tensors(self) -> List[torch.Tensor]:
        """
        Extracts the tensor data from the wrapped state dict, preserving metadata.

        Replaces the tensor data in sharded_tensors with device type of extracted tensors.
        After this operation, the state dictionary is "hollow", containing no tensor data.
        Further calls to `pop_tensor` will raise an error.

        @return List of extracted tensors
        """
        assert not self.is_hollow  # TODO raise exception
        result = []
        for sh_ten in self._sharded_tensors:
            result.append(sh_ten.data)
            # FIXME: Hacky way to store the original device, which is not included in the metadata
            setattr(sh_ten, "orig_device", sh_ten.data.device.type)
            sh_ten.data = None
        self._is_hollow = True
        return result

    def insert_tensors(self, tensor_data: Iterable[torch.Tensor]):
        """
        Reverse of `pop_tensors`. Replaces device type in sharded_tensors with actual values.

        Value of `self` is considered to be the same after:
        ```
        self.insert_tensors(self.pop_tensors())
        ```
        """
        assert self.is_hollow  # TODO raise exception
        for sh_ten, ten in zip_strict(self._sharded_tensors, tensor_data):
            # FIXME: Hacky way to store the original device
            if sh_ten.orig_device == ten.device.type:
                delattr(sh_ten, "orig_device")
            # Tensor might be on non-original device
            sh_ten.data = ten
        self._is_hollow = False

    def init_tensors(self):
        """
        Initializes empty tensors with the same properties as the original tensors.

        This function should only be called after the original tensors have been popped.
        It ensures that the newly created empty tensors match the shape,
        dtype, and device of the originals, but contain no data.
        """
        assert self.is_hollow  # TODO raise exception
        for sh_ten in self._sharded_tensors:
            # Hacky way to retrieve the original device
            sh_ten.init_data(sh_ten.orig_device)
            delattr(sh_ten, "orig_device")
        self._is_hollow = False

    def copy_tensors_to_cpu(self, non_blocking=False):
        """
        Stores CPU copies of tensors in the state_dict, replacing the originals,
        but without destroying them.

        The original devices are remembered for restoration with restore_tensor_device().
        Using non_blocking=True allows for asynchronous copying.
        """
        assert not self.is_hollow  # TODO raise exception
        for sh_ten in self._sharded_tensors:
            if sh_ten.data.device.type == "cpu":
                # Skip cloning if it's already confirmed to be a copy
                if not hasattr(sh_ten, "orig_device"):
                    sh_ten.data = sh_ten.data.clone()
            else:
                # FIXME: Hacky way to store the original device
                if not hasattr(sh_ten, "orig_device"):
                    setattr(sh_ten, "orig_device", sh_ten.data.device.type)
                sh_ten.data = sh_ten.data.detach().to("cpu", non_blocking=non_blocking)

    def restore_tensor_device(self, non_blocking=True):
        """
        Restores all tensors to their original devices, if a move is required.

        Using non_blocking=True allows for asynchronous copying.
        """
        assert not self.is_hollow  # TODO raise exception
        for sh_ten in self._sharded_tensors:
            # FIXME: Hacky way to store the original device
            if hasattr(sh_ten, "orig_device"):
                sh_ten.data = sh_ten.data.to(sh_ten.orig_device, non_blocking=non_blocking)
                delattr(sh_ten, "orig_device")

    def _insert_sharded_data(
        self, fully_parallel, sharded_part, parallelization_group, exchange_algo
    ):
        loaded_tensors = {}
        for sh_ten in self._sharded_tensors:
            loaded_tensors[_sharded_tensor_shard_id(sh_ten)] = sh_ten.data
        if fully_parallel:
            with debug_time("_get_distribution", logger):
                distribution = self._get_distribution(
                    fully_parallel, sharded_part, parallelization_group
                )
            if distribution is not None:
                unloaded_shards = {}
                for sh_base in nested_values(sharded_part):
                    # TODO retrieve redundant ShardedObjects once removed in _remove_redundant_data
                    if isinstance(sh_base, ShardedTensor):
                        shard_id = _sharded_tensor_shard_id(sh_base)
                        if shard_id not in loaded_tensors:
                            unloaded_shards[shard_id] = sh_base

                with debug_time("exchange_by_distribution", logger):
                    loaded_tensors = exchange_by_distribution(
                        loaded_tensors,
                        unloaded_shards,
                        distribution,
                        parallelization_group,
                        exchange_algo,
                    )
                torch.cuda.synchronize()

        loaded_objects = {}
        for sh_base in nested_values(self.sharded_state_dict):
            if not isinstance(sh_base, ShardedTensor):
                assert isinstance(sh_base, ShardedObject)
                loaded_objects[_sharded_object_id(sh_base)] = sh_base.data

        def load_sharded_base(x: Any):
            if isinstance(x, ShardedTensor):
                shard_id = _sharded_tensor_shard_id(x)
                assert shard_id in loaded_tensors, (x, shard_id, loaded_tensors.keys())
                x = loaded_tensors[shard_id]
            if isinstance(x, ShardedObject):
                object_id = _sharded_object_id(x)
                assert object_id in loaded_objects, (x, object_id, loaded_objects.keys())
                x = loaded_objects[object_id]
            return x

        dict_list_map_inplace(load_sharded_base, sharded_part)

    @debug_time("to_state_dict", logger)
    def to_state_dict(
        self,
        sharded_state_dict: ShardedStateDict,
        algo: str = "atomic",
        exchange_algo: str = "broadcast",
        validate_access_integrity: bool = True,
        parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
        strict: StrictHandling = StrictHandling.ASSUME_OK_UNEXPECTED,
        return_mismatch_keys: bool = False,
    ):
        """
        Convert tensor-aware dict back to the original state_dict
        """
        with debug_time("load_preprocess_and_state_dict_manipulations", logger):
            assert not self.is_hollow  # TODO raise exception
            self._validate_params(algo)
            fully_parallel = algo == "fully_parallel"

            # __adding__ common part
            recreated_state_dict = dict_list_map_outplace(lambda x: x, self.common)

            if not sharded_state_dict:
                return recreated_state_dict
            # TODO validate self.sharded_state_dict and sharded_state_dict are compatible

            sharded_state_dict, nonpersistent_state_dict, sh_ten_factories = load_preprocess(
                sharded_state_dict
            )
            # __adding__ nonpersistent part
            merge(recreated_state_dict, nonpersistent_state_dict)

            sharded_part, _ = extract_sharded_base(sharded_state_dict)

        # Strictness
        ckpt_sharded_metadata = None
        local_metadata, global_metadata = None, None
        strict = parse_strict_flag(strict)

        if StrictHandling.requires_explicit_ckpt_mismatch_check(strict):
            ckpt_sharded_metadata = {
                sh_base.key: sh_base.without_data()
                for sh_base in nested_values(self.sharded_state_dict)
            }
        if validate_access_integrity or StrictHandling.requires_global_app_metadata(strict):
            local_metadata, global_metadata = determine_global_metadata(sharded_part)

        sharded_state_dict, missing_keys, unexpected_keys = validate_integrity_and_strict_load(
            sharded_part,
            strict,
            validate_access_integrity,
            local_metadata,
            global_metadata,
            ckpt_sharded_metadata,
        )

        # load sharded tensors and sharded objects to sharded_part
        with debug_time("_insert_sharded_data", logger):
            self._insert_sharded_data(
                fully_parallel, sharded_part, parallelization_group, exchange_algo
            )
        with debug_time("apply_factory_merges", logger):
            sharded_part = apply_factory_merges(sharded_part, sh_ten_factories)

        # __adding__ sharded_part
        merge(recreated_state_dict, sharded_part)

        if return_mismatch_keys:
            return recreated_state_dict, missing_keys, unexpected_keys
        else:
            return recreated_state_dict
Megatron-LM/megatron/core/dist_checkpointing/utils.py
0 → 100644
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.

""" Helpers for manipulating sharded tensors and sharded state dicts. """

import logging
from contextlib import contextmanager
from time import time
from typing import Dict, Optional, Tuple

from .dict_utils import dict_list_map_inplace, extract_matching_values
from .mapping import (
    LocalNonpersistentObject,
    ShardedBase,
    ShardedObject,
    ShardedStateDict,
    ShardedTensor,
    ShardedTensorFactory,
    StateDict,
)

# _ShardId uniquely identifies a ShardedTensor. This is a subset of ShardedTensor
# attributes: key (str), global_offset (tuple) and flattened_range (optional tuple)
_ShardId = Tuple[str, tuple, Optional[tuple]]


def zip_strict(*args):
    """
    Alternative to Python's builtin zip(..., strict=True) (available in 3.10+).
    Apart from providing the functionality in earlier versions of Python, it is also more verbose.
    (Python's zip does not print lengths, only which iterable has finished earlier)
    """
    args = [list(a) for a in args]
    lens = [len(a) for a in args]
    assert len(set(lens)) <= 1, f"Tried to zip iterables of unequal lengths: {lens}!"
    return zip(*args)
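
# A short usage sketch for `zip_strict` (illustrative, not part of the upstream module).
# It behaves like zip(), but asserts that all iterables have equal length and reports them.
def _example_zip_strict():
    paired = list(zip_strict("ab", [1, 2]))  # OK: [('a', 1), ('b', 2)]
    try:
        zip_strict("abc", [1, 2])            # unequal lengths -> AssertionError with both lengths
    except AssertionError as err:
        paired.append(str(err))
    return paired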
def _sharded_tensor_shard_id(sharded_tensor: ShardedTensor) -> _ShardId:
    """Unique id of the sharded tensor data.

    Should yield the same value for same data replicated on different ranks.

    Args:
        sharded_tensor (ShardedTensor): sharded tensor representing the data shard

    Returns (tuple): unique id of a data shard
    """
    f_range = sharded_tensor.flattened_range
    return (
        sharded_tensor.key,
        sharded_tensor.global_offset,
        None if f_range is None else (f_range.start, f_range.stop),
    )


def _sharded_object_id(sharded_object: ShardedObject) -> _ShardId:
    """Unique id of the sharded object data.

    Should yield the same value for same data replicated on different ranks.

    Args:
        sharded_object (ShardedObject): sharded object representing the data shard

    Returns (tuple): unique id of a data shard
    """
    return (sharded_object.key, sharded_object.global_offset, sharded_object.global_shape)
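
# Illustrative sketch (not part of the upstream module): two replicas of the same shard
# map to the same _ShardId, so a set or dict keyed by shard ids deduplicates them.
# SimpleNamespace stands in for a real ShardedTensor; only the attributes read by
# `_sharded_tensor_shard_id` are provided.
def _example_shard_ids():
    from types import SimpleNamespace

    replica_on_rank0 = SimpleNamespace(key="decoder.weight", global_offset=(0, 0), flattened_range=None)
    replica_on_rank1 = SimpleNamespace(key="decoder.weight", global_offset=(0, 0), flattened_range=None)
    ids = {_sharded_tensor_shard_id(replica_on_rank0), _sharded_tensor_shard_id(replica_on_rank1)}
    assert len(ids) == 1  # same data shard -> same id
    return ids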
def extract_sharded_tensors(
    sharded_state_dict: ShardedStateDict,
) -> Tuple[ShardedStateDict, StateDict]:
    """Extract a dict consisting of only ShardedTensor objects
    from a given state dict with any objects.

    Args:
        sharded_state_dict: state dict possibly containing ShardedTensor objects

    Returns:
        Tuple[ShardedStateDict, StateDict]: tuple of:
            - state dict with all ShardedTensor (keeping the original state dict structure)
            - state dict with all objects other than ShardedTensor
              (keeping the original state dict structure)
    """
    return extract_matching_values(sharded_state_dict, lambda v: isinstance(v, ShardedTensor))


def extract_sharded_tensors_and_factories(
    sharded_state_dict: ShardedStateDict,
) -> Tuple[ShardedStateDict, StateDict]:
    """Extract a dict consisting of only ShardedTensor and ShardedTensorFactory objects
    from a given state dict with any objects.

    Args:
        sharded_state_dict:
            state dict possibly containing ShardedTensor and ShardedTensorFactory objects

    Returns:
        Tuple[ShardedStateDict, StateDict]: tuple of:
            - state dict with all ShardedTensor and ShardedTensorFactory
              (keeping the original state dict structure)
            - state dict with all other objects (keeping the original state dict structure)
    """
    return extract_matching_values(
        sharded_state_dict, lambda v: isinstance(v, (ShardedTensor, ShardedTensorFactory))
    )


def extract_sharded_tensors_or_nonpersistent(
    sharded_state_dict: ShardedStateDict,
) -> Tuple[ShardedStateDict, StateDict]:
    """Extract a dict consisting of only ShardedTensor, ShardedTensorFactory
    and LocalNonpersistentObject objects from a given state dict with any objects.

    Args:
        sharded_state_dict: state dict possibly containing ShardedTensor, ShardedTensorFactory
            and LocalNonpersistentObject objects

    Returns:
        Tuple[ShardedStateDict, StateDict]: tuple of:
            - state dict with all ShardedTensor, ShardedTensorFactory and LocalNonpersistentObject
              (keeping the original state dict structure)
            - state dict with all other objects (keeping the original state dict structure)
    """
    return extract_matching_values(
        sharded_state_dict,
        lambda v: isinstance(v, (ShardedTensor, LocalNonpersistentObject, ShardedTensorFactory)),
    )


def extract_sharded_base(
    sharded_state_dict: ShardedStateDict,
) -> Tuple[ShardedStateDict, StateDict]:
    """Extract a dict consisting of only ShardedBase from a given state dict with any objects.

    Args:
        sharded_state_dict: state dict possibly containing ShardedBase objects

    Returns:
        Tuple[ShardedStateDict, StateDict]: tuple of:
            - state dict with all ShardedBase objects (keeping the original state dict structure)
            - state dict with all other objects (keeping the original state dict structure)
    """
    return extract_matching_values(sharded_state_dict, lambda v: isinstance(v, ShardedBase))


def extract_nonpersistent(
    sharded_state_dict: ShardedStateDict,
) -> Tuple[ShardedStateDict, StateDict]:
    """Extract a dict consisting of only LocalNonpersistentObjects from a given state dict.

    Args:
        sharded_state_dict: state dict possibly containing LocalNonpersistentObjects

    Returns:
        Tuple[ShardedStateDict, StateDict]: tuple of:
            - state dict with all LocalNonpersistentObjects
              (keeping the original state dict structure)
            - state dict with all other objects (keeping the original state dict structure)
    """
    return extract_matching_values(
        sharded_state_dict, lambda v: isinstance(v, LocalNonpersistentObject)
    )
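
# Illustrative sketch (not part of the upstream module) of the predicate split the
# `extract_*` helpers above delegate to `extract_matching_values`: values matching the
# predicate go into the first dict, everything else into the second, preserving the
# nested structure. This is a simplified stand-in, not the real helper.
def _example_predicate_split(state_dict, predicate):
    matching, nonmatching = {}, {}
    for key, value in state_dict.items():
        if isinstance(value, dict):
            sub_match, sub_rest = _example_predicate_split(value, predicate)
            if sub_match:
                matching[key] = sub_match
            if sub_rest:
                nonmatching[key] = sub_rest
        elif predicate(value):
            matching[key] = value
        else:
            nonmatching[key] = value
    return matching, nonmatching
# e.g. _example_predicate_split({"a": 1, "b": {"c": "x"}}, lambda v: isinstance(v, int))
# -> ({"a": 1}, {"b": {"c": "x"}})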
def add_prefix_for_sharding(sharded_state_dict: ShardedStateDict, prefix: str):
    """Prepend a given prefix to all ShardedBase objects in a given state dict *in-place*.

    Args:
        sharded_state_dict (ShardedStateDict): sharded state dict
        prefix (str): prefix to be prepended

    Returns:
        None: state dict is modified in-place
    """

    def add_prefix(t):
        if isinstance(t, ShardedBase):
            t.key = f'{prefix}{t.key}'
        return t

    dict_list_map_inplace(add_prefix, sharded_state_dict)


def replace_prefix_for_sharding(
    sharded_state_dict: ShardedStateDict, old_prefix: str, new_prefix: str
):
    """Replaces the given prefix in *all* sharded keys in a given state dict.

    Errors out if some key does not begin with a given prefix.

    Args:
        sharded_state_dict (ShardedStateDict): sharded state dict to replace keys in
        old_prefix (str): prefix to be replaced in each key
        new_prefix (str): new prefix

    Returns:
        None: state dict is modified in place
    """

    def _replace_prefix(x):
        if isinstance(x, (ShardedTensor, ShardedTensorFactory, ShardedObject)):
            if not x.key.startswith(old_prefix):
                raise ValueError(f'Expected {x.key} to begin with prefix {old_prefix}')
            x.key = f'{new_prefix}{x.key[len(old_prefix):]}'  # str.removeprefix in Python >= 3.9
        return x

    dict_list_map_inplace(_replace_prefix, sharded_state_dict)


def apply_prefix_mapping(sharded_state_dict: ShardedStateDict, prefix_map: Dict[str, str]):
    """Replaces prefixes *only in keys matching* with one of prefixes in the map.

    Args:
        sharded_state_dict (ShardedStateDict): sharded state dict to replace keys in
        prefix_map (Dict[str, str]):
            map of old->new prefixes. The first matching prefix for each key is used

    Returns:
        None: state dict is modified in place
    """

    def _replace_prefixes(x):
        if not isinstance(x, (ShardedTensor, ShardedTensorFactory, ShardedObject)):
            return x
        for old_prefix, new_prefix in prefix_map.items():
            if x.key.startswith(old_prefix):
                x.key = (
                    f'{new_prefix}{x.key[len(old_prefix):]}'  # str.removeprefix in Python >= 3.9
                )
                break
        return x

    dict_list_map_inplace(_replace_prefixes, sharded_state_dict)
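
# Illustrative sketch (not part of the upstream module) of the pure key-rewriting rule
# used by `replace_prefix_for_sharding` and `apply_prefix_mapping` above, shown on
# plain strings instead of ShardedBase objects.
def _example_prefix_remap():
    prefix_map = {'decoder.': 'model.decoder.', 'encoder.': 'model.encoder.'}
    keys = ['decoder.weight', 'encoder.bias', 'head.weight']
    remapped = []
    for key in keys:
        for old_prefix, new_prefix in prefix_map.items():
            if key.startswith(old_prefix):
                key = f'{new_prefix}{key[len(old_prefix):]}'
                break  # first matching prefix wins
        remapped.append(key)
    return remapped  # ['model.decoder.weight', 'model.encoder.bias', 'head.weight']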
fallback_logger = logging.getLogger(__name__)
__LOGGER_NAME_STACK = []
__LOGGER_STACK = []


@contextmanager
def logger_stack(name: Optional[str] = None, current_logger: Optional[logging.Logger] = None):
    """Context manager for managing logger and name stack.

    Temporarily pushes a logger and/or name onto their respective stacks, allowing hierarchical
    logging and contextual logger usage. Ensures the logger stack is restored afterward.

    Args:
        name (str, optional): Name to add to the logger stack. Defaults to None.
        current_logger (logging.Logger, optional): Logger to use. Defaults to the last logger in
            the stack or a fallback if none exist.

    Yields:
        Tuple[str, logging.Logger]: A tuple with the concatenated logger name stack and
            the current logger for the block.

    Example:
        with logger_stack("scope", logger):
            logger.info("Log within 'scope'")
    """
    if name:
        __LOGGER_NAME_STACK.append(name)
    if current_logger:
        __LOGGER_STACK.append(current_logger)
        last_logger = current_logger
    elif __LOGGER_STACK:
        last_logger = __LOGGER_STACK[-1]
    else:
        last_logger = fallback_logger
    try:
        yield ".".join(__LOGGER_NAME_STACK), last_logger
    finally:
        if name and __LOGGER_NAME_STACK:
            __LOGGER_NAME_STACK.pop(-1)
        if current_logger and __LOGGER_STACK:
            __LOGGER_STACK.pop(-1)


@contextmanager
def debug_time(
    name: str,
    logger: Optional[logging.Logger] = None,
    threshold: float = float("-inf"),
    level=None,
):
    """Simple context manager for timing functions/code blocks.

    Args:
        name (str): Label describing the code being measured.
        logger (logging.Logger, optional): Logger for output. Defaults to the lowest logger.
        threshold (float, optional): Minimum time (seconds) to log. Skips logging if faster.
        level (int, optional): Logging level. Defaults to DEBUG if `threshold` is unset;
            WARNING otherwise.
    """
    with logger_stack(name, logger) as (stacked_name, last_logger):
        start = time()
        try:
            yield
        finally:
            result = time() - start
            if result < threshold:
                return
            if level is None:
                level = logging.DEBUG if threshold == float("-inf") else logging.WARNING
            last_logger.log(level, f"{stacked_name} took {result:.4f}s")


def debug_msg(msg: str):
    """Logs a debug message using the current logger stack.

    This function formats and logs a debug message with the current logger
    and name stack, preserving context from the logger_stack context manager.

    Args:
        msg (str): The message to be logged at the debug level.

    Example:
        debug_msg("Checkpoint initialized")
        # Logs: "scope_name Checkpoint initialized" if called within logger_stack("scope_name")
    """
    with logger_stack(None, None) as (stacked_name, last_logger):
        last_logger.debug(f"{stacked_name} {msg}")
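
# Illustrative usage sketch (not part of the upstream module): nesting `debug_time`
# blocks builds dotted scope names via the shared logger/name stack, and `debug_msg`
# inherits the innermost scope and logger.
def _example_debug_timing():
    logging.basicConfig(level=logging.DEBUG)
    example_logger = logging.getLogger("ckpt.example")
    with debug_time("save", example_logger):       # logs "save took ...s"
        with debug_time("serialize"):              # logs "save.serialize took ...s"
            debug_msg("starting")                  # logs "save.serialize starting"
        with debug_time("write", threshold=10.0):  # logged only if it takes >= 10s
            pass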
Megatron-LM/megatron/core/dist_checkpointing/validation.py
0 → 100644
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

import logging
import os
from collections import Counter, defaultdict
from enum import Enum
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union

import numpy as np
import torch

from megatron.core.dist_checkpointing import ShardedTensor
from megatron.core.dist_checkpointing.core import CheckpointingException, maybe_load_config
from megatron.core.dist_checkpointing.dict_utils import (
    diff,
    extract_matching_values,
    map_reduce,
    nested_values,
)
from megatron.core.dist_checkpointing.mapping import (
    CommonStateDict,
    ShardedBase,
    ShardedObject,
    ShardedStateDict,
    is_main_replica,
)
from megatron.core.dist_checkpointing.strategies.base import (
    LoadCommonStrategy,
    LoadShardedStrategy,
    SaveCommonStrategy,
    SaveShardedStrategy,
    StrategyAction,
    get_default_strategy,
)
from megatron.core.msc_utils import MultiStorageClientFeature

if TYPE_CHECKING:
    from megatron.core.dist_checkpointing.serialization import CkptShardedMetadata

logger = logging.getLogger(__name__)

# pylint: disable=line-too-long
# list of local saved/loaded ShardedBase objects
_LocalMetadata = List[Union[ShardedTensor, ShardedObject]]
# list of lists of global saved/loaded ShardedBase objects (each element corresponds to global rank)
_GlobalMetadata = List[_LocalMetadata]


class StrictHandling(Enum):
    """Determines handling of load mismatch (non-empty "unexpected" or "missing" keys).

    Different flags carry different implications on performance and behaviour and
    are divided into two groups:
    - *_UNEXPECTED
    - *_ALL
    The first group ignores missing keys (present in the checkpoint but missing
    in the sharded state dict) which is created in order to avoid inter-rank
    metadata exchange. Note that the metadata exchange will happen anyway
    with `load(..., validate_access_integrity=True)` flag in which case using the
    `*_ALL` option is recommended as it provides a more thorough check with no
    performance penalty wrt. `*_UNEXPECTED` group.

    All options except for the first one (`ASSUME_OK_UNEXPECTED`) require
    extra disk access before the load in order to remove unexpected keys
    from the sharded state dict requested to load.
    """

    # Relies on the underlying strategy to raise error on unexpected keys
    ASSUME_OK_UNEXPECTED = "assume_ok_unexpected"
    # Logs (with WARNING level) "unexpected" keys. Missing keys are ignored.
    # This is treated as a reasonable default for a "non-strict" load
    LOG_UNEXPECTED = "log_unexpected"
    # Logs (with WARNING level) all mismatched keys.
    LOG_ALL = "log_all"
    # Raise error on unexpected keys before load attempt.
    # Gives cleaner error message than `ASSUME_OK_UNEXPECTED` but requires
    # extra disk access.
    RAISE_UNEXPECTED = "raise_unexpected"
    # Raise error on any mismatch. Similar to `RAISE_UNEXPECTED` but requires
    # metadata exchange.
    RAISE_ALL = "raise_all"
    # "Unexpected" mismatches are not reported, but returned by the `load`
    # function along with the loaded state dict. Missing keys are ignored.
    RETURN_UNEXPECTED = "return_unexpected"
    # All mismatches are returned along with the loaded state dict.
    RETURN_ALL = "return_all"
    # Simply ignores mismatches (not recommended)
    IGNORE_ALL = "ignore_all"

    @staticmethod
    def requires_explicit_ckpt_mismatch_check(val: "StrictHandling") -> bool:
        """Whether a given strict flag involves mismatch check against the checkpoint."""
        return val != StrictHandling.ASSUME_OK_UNEXPECTED

    @staticmethod
    def requires_global_app_metadata(val: "StrictHandling") -> bool:
        """Whether a given strict option requires global metadata for validation."""
        return val in (
            StrictHandling.IGNORE_ALL,
            StrictHandling.RAISE_ALL,
            StrictHandling.RETURN_ALL,
            StrictHandling.LOG_ALL,
        )

    @staticmethod
    def requires_returning_mismatch_keys(val: "StrictHandling") -> bool:
        """Whether a given strict option results in extra return value from the `load` function."""
        return val in (StrictHandling.RETURN_UNEXPECTED, StrictHandling.RETURN_ALL)


def parse_strict_flag(strict: Union[str, StrictHandling]) -> StrictHandling:
    """Parse user passed strict flag from a string to StrictHandling instance.

    Args:
        strict (str, StrictHandling): strict flag to parse. If already an instance
            of StrictHandling, this function is a noop.

    Returns:
        StrictHandling: enum instance
    """
    if isinstance(strict, StrictHandling):
        return strict
    try:
        return StrictHandling(strict)
    except (ValueError, TypeError) as e:
        raise ValueError(f"Invalid strict flag: {e}") from e
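
# Illustrative usage sketch (not part of the upstream module): the `strict` argument of
# the load path can be a string or a StrictHandling member; `parse_strict_flag`
# normalizes it, and the helper predicates drive the extra work (checkpoint metadata
# read, inter-rank metadata exchange) a given mode requires.
def _example_parse_strict_flag():
    flag = parse_strict_flag("log_all")
    assert flag is StrictHandling.LOG_ALL
    assert StrictHandling.requires_explicit_ckpt_mismatch_check(flag)
    assert StrictHandling.requires_global_app_metadata(flag)
    assert not StrictHandling.requires_returning_mismatch_keys(flag)
    return flag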
def validate_integrity_and_strict_load(
    sharded_state_dict: ShardedStateDict,
    strict: StrictHandling,
    validate_access_integrity: bool,
    local_metadata: Optional[_LocalMetadata] = None,
    global_metadata: Optional[_GlobalMetadata] = None,
    ckpt_sharded_metadata: Optional["CkptShardedMetadata"] = None,
) -> Tuple[ShardedStateDict, Set[str], Set[str]]:
    """Validates sharding integrity and potential mismatches with the checkpoint.

    `validate_access_integrity` controls sharding integrity check (orthogonal
    to strictness checking) which verifies `sharded_state_dict` runtime completeness
    (in isolation from the actual checkpoint).

    `strict` flag controls handling of mismatches between the requested
    sharded state dict to load and the actual checkpoint. See `StrictHandling`
    docs for details regarding flag behavior and performance implications
    (disk interactions or inter-rank communication).

    Args:
        sharded_state_dict (ShardedStateDict): sharded state dict to verify.
        strict (StrictHandling): flag determining how to handle sharded keys mismatch.
        validate_access_integrity (bool): whether to perform sharding validation.
        local_metadata (_LocalMetadata, optional): local sharded state dict metadata.
            Defaults to None, in which case it's determined based on `sharded_state_dict`.
        global_metadata (_GlobalMetadata, optional): global sharded state dict metadata
            (exchanged between ranks). Defaults to None, in which case "missing"
            keys are not determined.
        ckpt_sharded_metadata (CkptShardedMetadata, optional): sharded metadata
            from the checkpoint. Defaults to None, which only makes sense
            for the `StrictHandling.ASSUME_OK_UNEXPECTED` strict value.

    Returns:
        Tuple[ShardedStateDict, Set[str], Set[str]]: tuple of: sharded state dict
            without unexpected keys, missing and unexpected keys. Missing keys are equal
            on all ranks, unexpected keys might differ across ranks. Additionally,
            missing keys might be erroneously empty (depending on `strict` value).
    """
    missing_keys, unexpected_keys = set(), set()
    if StrictHandling.requires_explicit_ckpt_mismatch_check(strict):
        if ckpt_sharded_metadata is None:
            raise CheckpointingException(
                "Cannot verify checkpoint mismatch with ckpt_sharded_metadata=None."
            )
        if local_metadata is None:
            local_metadata = [
                sh_base.without_data() for sh_base in nested_values(sharded_state_dict)
            ]
        # We don't want to check for missing keys even if we could
        _skip_missing_keys = strict in (
            StrictHandling.ASSUME_OK_UNEXPECTED,
            StrictHandling.LOG_UNEXPECTED,
            StrictHandling.RAISE_UNEXPECTED,
            StrictHandling.RETURN_UNEXPECTED,
        )
        missing_keys, unexpected_keys = _determine_missing_and_unexpected_keys(
            ckpt_sharded_metadata, local_metadata, None if _skip_missing_keys else global_metadata
        )

        sharded_state_dict = adjust_non_strict_load(sharded_state_dict, unexpected_keys)

        if strict == StrictHandling.IGNORE_ALL:
            missing_keys, unexpected_keys = set(), set()
        elif strict in (StrictHandling.RAISE_UNEXPECTED, StrictHandling.RAISE_ALL):
            maybe_report_missing_and_unexpected_keys(missing_keys, unexpected_keys, True)
        elif strict in (StrictHandling.LOG_UNEXPECTED, StrictHandling.LOG_ALL):
            maybe_report_missing_and_unexpected_keys(missing_keys, unexpected_keys, False)

    if validate_access_integrity:
        if global_metadata is None:
            raise CheckpointingException(
                "Cannot check sharding integrity without global_metadata (None)."
            )
        validate_sharding_integrity(global_metadata)

    return sharded_state_dict, missing_keys, unexpected_keys
def verify_checkpoint_and_load_strategy(
    checkpoint_dir: str,
    sharded_strategy: Union[LoadShardedStrategy, Tuple[str, int], None] = None,
    common_strategy: Union[LoadCommonStrategy, Tuple[str, int], None] = None,
) -> Tuple[LoadShardedStrategy, LoadCommonStrategy]:
    """Verifies if checkpoint metadata exists and matches given strategies.

    If no strategies are passed, they are determined based on the checkpoint metadata.

    Args:
        checkpoint_dir (str): checkpoint directory
        sharded_strategy (LoadShardedStrategy, Tuple[str, int], optional): sharded load strategy to be verified
            if compatible with the checkpoint content. If None, the default sharded load strategy
            for the checkpoint backend will be returned.
        common_strategy (LoadCommonStrategy, Tuple[str, int], optional): common load strategy to be verified
            if compatible with the checkpoint content. If None, the default common load strategy
            for the checkpoint backend will be returned.
    """
    isdir = True
    if MultiStorageClientFeature.is_enabled():
        msc = MultiStorageClientFeature.import_package()
        isdir = msc.os.path.isdir(str(checkpoint_dir), strict=False)
    else:
        isdir = os.path.isdir(checkpoint_dir)
    if not isdir:
        raise CheckpointingException(f"Checkpoint directory {checkpoint_dir} does not exist")

    saved_config = maybe_load_config(checkpoint_dir)
    if saved_config is None:
        raise CheckpointingException(f"{checkpoint_dir} is not a distributed checkpoint")

    if sharded_strategy is None:
        sharded_strategy = get_default_strategy(
            StrategyAction.LOAD_SHARDED,
            saved_config.sharded_backend,
            saved_config.sharded_backend_version,
        )
    elif isinstance(sharded_strategy, tuple):
        sharded_strategy = get_default_strategy(StrategyAction.LOAD_SHARDED, *sharded_strategy)

    if common_strategy is None:
        common_strategy = get_default_strategy(
            StrategyAction.LOAD_COMMON,
            saved_config.common_backend,
            saved_config.common_backend_version,
        )
    elif isinstance(common_strategy, tuple):
        common_strategy = get_default_strategy(StrategyAction.LOAD_COMMON, *common_strategy)

    sharded_strategy.check_backend_compatibility(saved_config.sharded_backend)
    sharded_strategy.check_version_compatibility(saved_config.sharded_backend_version)
    common_strategy.check_backend_compatibility(saved_config.common_backend)
    common_strategy.check_version_compatibility(saved_config.common_backend_version)
    return sharded_strategy, common_strategy


def adjust_non_strict_load(
    sharded_state_dict: ShardedStateDict, sharded_keys_to_remove: Set[str]
) -> ShardedStateDict:
    """Adjusts sharded state dict removing keys not existing in the checkpoint.

    Args:
        sharded_state_dict (ShardedStateDict): sharded state dict to modify
        sharded_keys_to_remove (Set[str]): keys to remove from the state dict

    Returns:
        ShardedStateDict: state dict without ShardedBase objects with specified keys
    """

    def is_unexpected_key(x: ShardedBase):
        assert isinstance(x, ShardedBase), f"Unexpected type {type(x)}"
        return x.key in sharded_keys_to_remove

    _, sharded_state_dict = extract_matching_values(sharded_state_dict, is_unexpected_key)
    return sharded_state_dict
def _determine_missing_and_unexpected_keys(
    ckpt_sharded_metadata: "CkptShardedMetadata",
    local_metadata: _LocalMetadata,
    global_metadata: Optional[_GlobalMetadata] = None,
) -> Tuple[Set[str], Set[str]]:
    """Determines load mismatches based on metadata.

    There is an asymmetry between "unexpected" and "missing" keys.
    Unexpected keys can be determined based only on local metadata.
    Missing keys must be based on global metadata, since other ranks might access
    different keys than the current rank.
    In consequence, the return value of this function is different on each rank:
    "missing_keys" are equal, but "unexpected_keys" might differ across ranks.

    Args:
        ckpt_sharded_metadata (CkptShardedMetadata): sharded state dict (without data)
            constructed based on the checkpoint content
        local_metadata (_LocalMetadata): list of local ShardedBase objects
            requested to be loaded by this rank
        global_metadata (_GlobalMetadata, optional): list of global ShardedBase objects
            requested to be loaded by all ranks. Defaults to None, in which case
            returned "missing" keys are empty.

    Returns:
        Tuple[Set[str], Set[str]]: missing and unexpected keys. Missing keys are equal
            on all ranks, unexpected keys might differ across ranks. If passed
            `global_metadata` is empty, returned missing keys are empty as well.
    """
    local_accessed_keys = set(sh_base.key for sh_base in local_metadata)
    ckpt_keys = set(sh_base.key for sh_base in ckpt_sharded_metadata.values())
    unexpected_keys = local_accessed_keys - ckpt_keys
    if global_metadata is not None:
        global_accessed_keys = set(
            sh_base.key for rank_metadata in global_metadata for sh_base in rank_metadata
        )
        missing_keys = ckpt_keys - global_accessed_keys
    else:
        missing_keys = set()

    if missing_keys:
        logger.debug(f"Dist ckpt load missing keys: {missing_keys}")
    if unexpected_keys:
        logger.debug(f"Dist ckpt load unexpected keys: {unexpected_keys}")

    return missing_keys, unexpected_keys
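
# Illustrative sketch (not part of the upstream module) of the key arithmetic above,
# using plain string sets in place of sharded metadata: "unexpected" keys are local
# requests absent from the checkpoint, "missing" keys are checkpoint entries that no
# rank requested.
def _example_key_mismatch_arithmetic():
    ckpt_keys = {"embedding.weight", "decoder.weight", "decoder.bias"}
    local_requested = {"decoder.weight", "optimizer.step"}                          # this rank
    globally_requested = {"embedding.weight", "decoder.weight", "optimizer.step"}   # all ranks
    unexpected_keys = local_requested - ckpt_keys                                   # {"optimizer.step"}
    missing_keys = ckpt_keys - globally_requested                                   # {"decoder.bias"}
    return missing_keys, unexpected_keys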
def maybe_report_missing_and_unexpected_keys(
    missing_keys: Set[str], unexpected_keys: Set[str], raise_error: bool = True
) -> None:
    """Raises or logs an error in case missing or unexpected keys are non-empty.

    Args:
        missing_keys (Set[str]): missing keys in the state dict
        unexpected_keys (Set[str]): unexpected keys in the state dict
        raise_error: If True, raises error on mismatch. Otherwise, logs mismatch
            with WARNING level.

    Returns:
        None

    Raises:
        CheckpointingException: if `raise_error` is True and at least one of
            `missing_keys` or `unexpected_keys` are non-empty.
    """
    if not missing_keys and not unexpected_keys:
        return
    missing_title_msg = (
        "Some keys found in the checkpoint are missing in the provided sharded state dict. "
    )
    missing_body_msg = f"Missing keys (for all ranks): {missing_keys}. "
    unexpected_title_msg = (
        "Unexpected keys (not found in the checkpoint) encountered in the provided sharded state dict. "
    )
    unexpected_body_msg = f"Unexpected keys (for this rank): {unexpected_keys}. "
    error_msg = ""
    if missing_keys:
        error_msg += missing_title_msg
    if unexpected_keys:
        error_msg += unexpected_title_msg
    error_msg += "\n"
    if missing_keys:
        error_msg += missing_body_msg
    if unexpected_keys:
        error_msg += unexpected_body_msg

    if raise_error:
        raise CheckpointingException(error_msg)
    else:
        logger.warning(error_msg)
def _validate_common_state_dict(common_state_dict: CommonStateDict) -> None:
    """Validate consistency across ranks for the common state dict.

    We save the common state dict only on rank 0. We validate to make sure that the
    common dict is consistent across ranks before saving.

    Args:
        common_state_dict: The common state dict present in all ranks
    """
    # Gather the common state dict across ranks onto rank 0 for comparison
    rank = torch.distributed.get_rank()
    other_rank_state_dicts = [None] * torch.distributed.get_world_size() if rank == 0 else None
    torch.distributed.gather_object(common_state_dict, other_rank_state_dicts)
    common_state_dict_diff = {}
    if rank == 0:
        assert other_rank_state_dicts
        main_rank_state_dict = common_state_dict
        for rank, rank_state_dict in enumerate(other_rank_state_dicts[1:], 1):
            only_left, only_right, mismatch = diff(main_rank_state_dict, rank_state_dict)
            if only_left or only_right or mismatch:
                common_state_dict_diff[rank] = (only_left, only_right, mismatch)

    if len(common_state_dict_diff) != 0:
        logger.warning(
            f"There is a difference in the common state dict in different ranks. "
            f"The differences are {common_state_dict_diff}"
        )


def validate_sharding_integrity(
    global_metadata: _GlobalMetadata, common_state_dict: CommonStateDict = None
) -> None:
    """Validate if the ShardedTensors and ShardedObjects from multiple processes define correct sharding.

    Local ShardedTensors and ShardedObject metadata is exchanged with `torch.distributed.all_gather_object`
    and then the process with global rank 0 checks if main replicas of the shards:
    - cover the whole global tensors
    - don't overlap

    Args:
        global_metadata (_GlobalMetadata): ShardedTensor and ShardedObject objects from all ranks.
        common_state_dict (CommonStateDict): The common state dict stored by rank 0

    Returns:
        None

    Raises:
        CheckpointingException for invalid access pattern
    """
    if common_state_dict is not None:
        _validate_common_state_dict(common_state_dict)

    if torch.distributed.get_rank() != 0:
        return

    key_shardings = defaultdict(list)
    for rank, rank_shardings in enumerate(global_metadata):
        for sharding in rank_shardings:
            key_shardings[sharding.key].append((rank, sharding))
    for key, shardings in key_shardings.items():
        if isinstance(shardings[0][1], ShardedObject):
            _validate_objects_for_key(shardings)
        else:
            _validate_sharding_for_key(shardings)
def _validate_sharding_for_key(rank_sharding: List[Tuple[int, ShardedTensor]]):
    some_rank_shard = rank_sharding[0][1]
    global_shape = some_rank_shard.global_shape
    local_shape = some_rank_shard.local_shape
    dtype = some_rank_shard.dtype
    has_flattened_range = some_rank_shard.flattened_range is not None
    for rank, sharding in rank_sharding:
        assert sharding.dtype == dtype, (sharding.dtype, dtype, some_rank_shard)
        assert sharding.global_shape == global_shape, (
            sharding.global_shape,
            global_shape,
            some_rank_shard,
        )
        assert sharding.local_shape == local_shape, (
            sharding.local_shape,
            local_shape,
            some_rank_shard,
        )
        assert (sharding.flattened_range is not None) == has_flattened_range, (
            (sharding.flattened_range is not None),
            has_flattened_range,
            some_rank_shard,
        )

    shard_access_cnt = _compute_shards_access(rank_sharding)
    if has_flattened_range:
        map_reduce(
            rank_sharding,
            lambda x: x[1].global_offset,
            lambda x: x[1],
            _validate_sharding_for_key_flattened,
        )
        # For each shard with at least 1 flattened tensor in it, the above
        # `_validate_sharding_for_key_flattened` ensures a correct consistent pattern.
        # The only thing that can go wrong at this point is that some shards don't have
        # *any* representatives, which will be checked later by comparing `shard_access_cnt == 1`
        shard_access_cnt = torch.minimum(shard_access_cnt, torch.tensor([1]))
    if not torch.all(shard_access_cnt == 1):
        raise CheckpointingException(
            f"Invalid access pattern for {rank_sharding[0][1]}: {shard_access_cnt}"
        )


def _compute_shards_access(rank_sharding):
    shard_access_cnt = torch.zeros(
        rank_sharding[0][1].axis_fragmentations, dtype=torch.int, device="cpu"
    )
    for rank, sharding in rank_sharding:
        if is_main_replica(sharding.replica_id):
            shard_access_cnt[sharding.local_chunk_offset_in_global()] += 1
    return shard_access_cnt
def _validate_sharding_for_key_flattened(tensors_by_shard):
    all_slices = []
    local_shape = tensors_by_shard[0].local_shape
    for sharding in tensors_by_shard:
        assert sharding.local_shape == local_shape
        sharding: ShardedTensor
        if not is_main_replica(sharding.replica_id):
            continue

        all_slices.append((sharding.flattened_range.start, sharding.flattened_range.stop))

    starts, stops = map(np.asarray, zip(*sorted(all_slices)))
    expected_size = np.prod(local_shape)
    if starts[0] != 0 or stops[-1] != expected_size or not np.all(starts[1:] == stops[:-1]):
        raise CheckpointingException(
            f"Flattened ranges don't cover the whole shard {tensors_by_shard[0]} "
            f"of size {expected_size}. Ranges: {(starts, stops)}"
        )
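
# Illustrative sketch (not part of the upstream module) of the contiguity check above:
# the sorted (start, stop) ranges of the main replicas must start at 0, end at the
# shard size, and each range must begin exactly where the previous one stopped.
def _example_flattened_range_coverage():
    shard_size = 12
    ranges = [(4, 8), (0, 4), (8, 12)]  # complete, non-overlapping coverage
    starts, stops = map(np.asarray, zip(*sorted(ranges)))
    covers_shard = (
        starts[0] == 0 and stops[-1] == shard_size and bool(np.all(starts[1:] == stops[:-1]))
    )
    return covers_shard  # True for the ranges above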
def _validate_objects_for_key(sharded_objects: List[ShardedObject]):
    """Ensure uniqueness of saved objects."""
    unique_keys = [
        sh_obj.unique_key for _, sh_obj in sharded_objects if is_main_replica(sh_obj.replica_id)
    ]
    if len(unique_keys) != len(set(unique_keys)):
        duplicates = {k: cnt for k, cnt in Counter(unique_keys).items() if cnt > 1}
        logger.error(f"Duplicate ShardedObject keys and counts: {duplicates}")
        raise CheckpointingException(f"Duplicate ShardedObject keys: {list(duplicates.keys())}")
    expected_shard_num = np.prod(sharded_objects[0][1].global_shape)
    if len(unique_keys) != expected_shard_num:
        err_msg = f"Invalid access pattern: {expected_shard_num - len(unique_keys)} ShardedObject are missing."
        logger.error(f"{err_msg} Existing shards: {unique_keys}")
        raise CheckpointingException(err_msg)
def determine_global_metadata(
    sharded_state_dict: ShardedStateDict,
) -> Tuple[_LocalMetadata, _GlobalMetadata]:
    """Exchanges local metadata with `all_gather_object` to determine global metadata.

    Args:
        sharded_state_dict (ShardedStateDict): local sharded state dict

    Returns:
        Tuple[_LocalMetadata, _GlobalMetadata]: local and global ShardedBase objects with stripped data
    """
    local_metadata = [ten.without_data() for ten in nested_values(sharded_state_dict)]
    global_metadata = [None] * torch.distributed.get_world_size()
    torch.distributed.all_gather_object(global_metadata, local_metadata)
    return local_metadata, global_metadata  # type: ignore[return-value]


def validate_sharded_objects_handling(
    sharded_strategy: Union[SaveShardedStrategy, LoadShardedStrategy],
    common_strategy: Union[SaveCommonStrategy, LoadCommonStrategy],
) -> None:
    """Checks if either of the passed strategies can handle sharded objects.

    Args:
        sharded_strategy (Union[SaveShardedStrategy, LoadShardedStrategy]): sharded strategy used for saving/loading
        common_strategy (Union[SaveCommonStrategy, LoadCommonStrategy]): common strategy used for saving/loading

    Returns:
        None

    Raises:
        CheckpointingException: if both strategies can't handle ShardedObjects
    """
    if (
        not sharded_strategy.can_handle_sharded_objects
        and not common_strategy.can_handle_sharded_objects
    ):
        raise CheckpointingException(
            f"Either sharded strategy or common strategy must implement ShardedObjects handling."
            f" Both {sharded_strategy} and {common_strategy} specify can_handle_sharded_objects=False"
        )
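
# Illustrative usage sketch (not part of the upstream module): the check only reads the
# `can_handle_sharded_objects` attribute, so lightweight stand-ins are enough to show
# the pass/fail behaviour (real Save/Load strategy objects are expected in practice).
def _example_sharded_objects_handling_check():
    from types import SimpleNamespace

    sharded = SimpleNamespace(can_handle_sharded_objects=False)
    common = SimpleNamespace(can_handle_sharded_objects=True)
    validate_sharded_objects_handling(sharded, common)  # OK: the common strategy handles them
    try:
        validate_sharded_objects_handling(
            SimpleNamespace(can_handle_sharded_objects=False),
            SimpleNamespace(can_handle_sharded_objects=False),
        )
    except CheckpointingException:
        return "neither strategy handles ShardedObjects"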