Commit 7c19b3a8 authored by wangsen

Initial commit
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Module for managing distributed checkpoints metadata. """
import json
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Optional
CONFIG_FNAME = 'metadata.json'
class CheckpointingException(Exception):
""" Base checkpointing related exception """
pass
@dataclass
class CheckpointingConfig:
""" Documents backends used in the checkpoint.
Checkpoint config keeps track of formats used for storing the sharded tensors
(sharded_backend) and other objects (common_backend).
Note that versioning is not for the checkpoint content (which is application specific),
but for the checkpoint format itself.
"""
sharded_backend: str
sharded_backend_version: int = 1
common_backend: str = 'torch'
common_backend_version: int = 1
def check_is_distributed_checkpoint(checkpoint_dir):
""" Checks if `metadata.json` exists in the checkpoint and is a valid config.
Args:
checkpoint_dir: checkpoint directory
Returns:
bool: True if `metadata.json` exists in the checkpoint and is a valid config.
"""
return maybe_load_config(checkpoint_dir) is not None
def maybe_load_config(checkpoint_dir: str) -> Optional[CheckpointingConfig]:
""" Returns checkpoint config if `checkpoint_dir` is a distributed checkpoint and None otherwise
Args:
checkpoint_dir: checkpoint directory
Returns:
CheckpointingConfig (optional): None if checkpoint is not a valid distributed checkpoint
"""
config_path = Path(checkpoint_dir, CONFIG_FNAME)
if not config_path.exists():
return None
with config_path.open() as f:
config_dict = json.load(f)
return CheckpointingConfig(**config_dict)
def save_config(config: CheckpointingConfig, checkpoint_dir: str):
""" Save given config to checkpoint directory.
Args:
config: checkpoint config
checkpoint_dir: checkpoint directory
Returns:
None
"""
config_path = Path(checkpoint_dir, CONFIG_FNAME)
with config_path.open('w') as f:
json.dump(asdict(config), f)
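# A minimal usage sketch (not from the original source): round-tripping a checkpoint
# config. The directory path is hypothetical and assumed to already exist.
#
#   cfg = CheckpointingConfig(sharded_backend='torch_dist')
#   save_config(cfg, '/tmp/ckpt')
#   assert check_is_distributed_checkpoint('/tmp/ckpt')
#   assert maybe_load_config('/tmp/ckpt') == cfg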
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Utilities for operating with dicts and lists.
All functions in this module handle nesting of dicts and lists.
Other objects (e.g. tuples) are treated as atomic leaf types that cannot be traversed.
"""
from collections import defaultdict
from typing import Any, Callable, Iterable, Optional, Tuple, Union
import torch
def extract_matching_values(
x: Union[dict, list], predicate: Callable[[Any], bool], return_lists_as_dicts: bool = False
) -> Tuple[Union[dict, list], Union[dict, list]]:
""" Return matching and nonmatching values. Keeps hierarchy.
Args:
x (Union[dict, list]) : state dict to process. Top-level argument must be a dict or list
predicate (object -> bool): determines matching values
return_lists_as_dicts (bool): if True, matching lists will be turned
into dicts, with keys indicating the indices of original elements.
Useful for reconstructing the original hierarchy.
"""
def _set_elem(target, k, v):
if return_lists_as_dicts:
target[k] = v
else:
target.append(v)
if isinstance(x, dict):
matching_vals = {}
nonmatching_vals = {}
for k, v in x.items():
if isinstance(v, (list, dict)):
match, nonmatch = extract_matching_values(v, predicate, return_lists_as_dicts)
if match:
matching_vals[k] = match
if nonmatch or not v:
nonmatching_vals[k] = nonmatch
elif predicate(v):
matching_vals[k] = v
else:
nonmatching_vals[k] = v
elif isinstance(x, list):
matching_vals = {} if return_lists_as_dicts else []
nonmatching_vals = {} if return_lists_as_dicts else []
for ind, v in enumerate(x):
if isinstance(v, (list, dict)) and v:
match, nonmatch = extract_matching_values(v, predicate, return_lists_as_dicts)
if match:
_set_elem(matching_vals, ind, match)
if nonmatch or not v:
_set_elem(nonmatching_vals, ind, nonmatch)
else:
target = matching_vals if predicate(v) else nonmatching_vals
_set_elem(target, ind, v)
else:
raise ValueError(f'Unexpected top-level object type: {type(x)}')
return matching_vals, nonmatching_vals
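# A minimal usage sketch (not from the original source), with hypothetical data:
#
#   state = {'w': torch.ones(2), 'meta': {'step': 3, 'b': torch.zeros(1)}}
#   tensors, rest = extract_matching_values(state, lambda v: isinstance(v, torch.Tensor))
#   # tensors == {'w': tensor([1., 1.]), 'meta': {'b': tensor([0.])}}
#   # rest    == {'meta': {'step': 3}}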
def diff(x1: Any, x2: Any, prefix: Tuple = ()) -> Tuple[list, list, list]:
""" Recursive diff of dicts.
Args:
x1 (object): left dict
x2 (object): right dict
prefix (tuple): tracks recursive calls. Used for reporting differing keys.
Returns:
Tuple[list, list, list]: tuple of:
- only_left: Prefixes present only in left dict
- only_right: Prefixes present only in right dict
- mismatch: values present in both dicts but not equal across dicts.
For tensors equality of all elems is checked.
Each element is a tuple (prefix, type of left value, type of right value).
"""
mismatch = []
if isinstance(x1, dict) and isinstance(x2, dict):
only_left = [prefix + (k,) for k in x1.keys() - x2.keys()]
only_right = [prefix + (k,) for k in x2.keys() - x1.keys()]
for k in x2.keys() & x1.keys():
_left, _right, _mismatch = diff(x1[k], x2[k], prefix + (k,))
only_left.extend(_left)
only_right.extend(_right)
mismatch.extend(_mismatch)
elif isinstance(x1, list) and isinstance(x2, list):
only_left = list(range(len(x1) - 1, len(x2) - 1, -1))
only_right = list(range(len(x2) - 1, len(x1) - 1, -1))
for i, (v1, v2) in enumerate(zip(x1, x2)):
_left, _right, _mismatch = diff(v1, v2, prefix + (i,))
only_left.extend(_left)
only_right.extend(_right)
mismatch.extend(_mismatch)
else:
only_left = []
only_right = []
if isinstance(x1, torch.Tensor) and isinstance(x2, torch.Tensor):
_is_mismatch = not torch.all(x1 == x2)
else:
try:
_is_mismatch = bool(x1 != x2)
except RuntimeError:
_is_mismatch = True
if _is_mismatch:
mismatch.append((prefix, type(x1), type(x2)))
return only_left, only_right, mismatch
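# A minimal usage sketch (not from the original source), with hypothetical data:
#
#   left  = {'a': 1, 'b': torch.zeros(2)}
#   right = {'b': torch.ones(2), 'c': 2}
#   only_left, only_right, mismatch = diff(left, right)
#   # only_left  == [('a',)]
#   # only_right == [('c',)]
#   # mismatch   == [(('b',), torch.Tensor, torch.Tensor)]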
def inspect_types(x: Any, prefix: Tuple = (), indent: int = 4):
""" Helper to print types of (nested) dict values. """
print_indent = lambda: print(' ' * indent * len(prefix), end='')
if isinstance(x, dict):
print()
for k, v in x.items():
print_indent()
print(f'> {k}: ', end='')
inspect_types(v, prefix + (k,), indent)
elif isinstance(x, list):
print()
for i, v in enumerate(x):
print_indent()
print(f'- {i}: ', end='')
inspect_types(v, prefix + (i,), indent)
else:
if isinstance(x, torch.Tensor):
print(f'Tensor of shape {x.shape}')
else:
try:
x_str = str(x)
except Exception:
x_str = '<no string repr>'
if len(x_str) > 30:
x_str = x_str[:30] + '... (truncated)'
print(f'[{type(x)}]: {x_str}')
def nested_values(x: Union[dict, list]):
""" Returns iterator over (nested) values of a given dict or list. """
x_iter = x.values() if isinstance(x, dict) else x
for v in x_iter:
if isinstance(v, (dict, list)):
yield from nested_values(v)
else:
yield v
def nested_items_iter(x: Union[dict, list]):
""" Returns iterator over (nested) tuples (container, key, value) of a given dict or list. """
x_iter = x.items() if isinstance(x, dict) else enumerate(x)
for k, v in x_iter:
if isinstance(v, (dict, list)):
yield from nested_items_iter(v)
else:
yield x, k, v
def dict_map(f: Callable, d: dict):
""" `map` equivalent for dicts. """
for sub_d, k, v in nested_items_iter(d):
sub_d[k] = f(v)
def dict_map_with_key(f: Callable, d: dict):
""" `map` equivalent for dicts with a function that accepts tuple (key, value). """
for sub_d, k, v in nested_items_iter(d):
sub_d[k] = f(k, v)
def dict_list_map_inplace(f: Callable, x: Union[dict, list]):
""" Maps dicts and lists *in-place* with a given function. """
if isinstance(x, dict):
for k, v in x.items():
x[k] = dict_list_map_inplace(f, v)
elif isinstance(x, list):
x[:] = (dict_list_map_inplace(f, v) for v in x)
else:
return f(x)
return x
def dict_list_map_outplace(f: Callable, x: Union[dict, list]):
""" Maps dicts and lists *out-of-place* with a given function. """
if isinstance(x, dict):
return {k: dict_list_map_outplace(f, v) for k, v in x.items()}
elif isinstance(x, list):
return [dict_list_map_outplace(f, v) for v in x]
else:
return f(x)
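# A minimal usage sketch (not from the original source): mapping a function over every
# leaf without mutating the input (hypothetical data):
#
#   src = {'a': [1, 2], 'b': {'c': 3}}
#   doubled = dict_list_map_outplace(lambda v: v * 2, src)
#   # doubled == {'a': [2, 4], 'b': {'c': 6}}; `src` is left untouched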
def merge(x1: dict, x2: dict, key: Tuple[str, ...] = ()):
""" Merges dicts and lists recursively. """
if isinstance(x1, dict) and isinstance(x2, dict):
for k, v2 in x2.items():
if k not in x1:
x1[k] = v2
else:
x1[k] = merge(x1[k], v2, key=key + (k,))
elif isinstance(x1, list) and isinstance(x2, list):
if len(x1) != len(x2):
raise ValueError(
f'Cannot merge two lists with different lengths ({len(x1)} and {len(x2)}, encountered at level {key})'
)
for i, v2 in enumerate(x2):
x1[i] = merge(x1[i], v2, key=key + (i,))
else:
raise ValueError(
f'Duplicate non-dict and non-list values encountered: `{x1}` and `{x2}` (at level {key})'
)
return x1
def map_reduce(
xs: Iterable,
key_fn: Callable = lambda x: x,
value_fn: Callable = lambda x: x,
reduce_fn: Callable = lambda x: x,
) -> dict:
""" Simple map-reduce implementation following `more_itertools.map_reduce` interface. """
res = defaultdict(list)
for x in xs:
res[key_fn(x)].append(value_fn(x))
for k in res:
res[k] = reduce_fn(res[k])
return dict(res)
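# A minimal usage sketch (not from the original source): grouping values by a key and
# reducing each group (hypothetical data):
#
#   by_parity = map_reduce([1, 2, 3, 4, 5], key_fn=lambda x: x % 2, reduce_fn=sum)
#   # by_parity == {1: 9, 0: 6}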
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Core library classes for representing sharding of tensors and objects.
The main expected usage is wrapping torch.Tensors in state dicts with
ShardedTensor class (mostly with the ShardedTensor.from_rank_offsets classmethod).
"""
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, replace
from itertools import chain
from typing import Any, Callable, Dict, Optional, Tuple, Union
import numpy as np
import torch
from .core import CheckpointingException
from .dict_utils import dict_list_map_inplace, dict_list_map_outplace
logger = logging.getLogger(__name__)
# These type definitions are just hints to differentiate a plain model state
# dict (StateDict) from a state dict with tensors replaced with ShardedTensors
# (ShardedStateDict).
StateDict = Dict[str, Any]
ShardedStateDict = Dict[str, Any]
ReplicaId = Union[int, Tuple[int, ...]]
class ShardedBase(ABC):
key: str
data: object
replica_id: ReplicaId
@abstractmethod
def validate_metadata_integrity(self):
"""Codifies the constraints on metadata attributes."""
@dataclass
class ShardedTensor(ShardedBase):
"""Represents a mapping between a local tensor and a global tensor.
Global tensor is assumed to consist of many local tensors distributed
between different processes.
Args:
key: unique identifier of a global tensor
data: local tensor data. Can be None only for consistency validation
dtype: tensor dtype
local_shape: local tensor shape
global_shape: global tensor shape
global_offset: offset of a local tensor in a global tensor, specified in number of tensor elements
axis_fragmentations: global tensor fragmentation of each axis
replica_id: indicates given local tensor's replication wrt. local tensors in different processes
prepend_axis_num: number of axes prepended to the local tensor to reflect global tensor shape. The behavior is similar to unsqueezing the local tensor.
allow_shape_mismatch: if True, during loading, the global shape of a stored tensor does not have to match the expected global shape. Useful for representing tensors with flexible shape, e.g. padded.
flattened_range: specifies a slice that should be applied to a flattened tensor with `local_shape` in order to get the tensor stored as `data`
"""
key: str
data: Optional[torch.Tensor]
dtype: torch.dtype
local_shape: Tuple[int, ...]
global_shape: Tuple[int, ...]
global_offset: Tuple[int, ...]
axis_fragmentations: Optional[Tuple[int, ...]]
replica_id: ReplicaId = 0
prepend_axis_num: int = 0
allow_shape_mismatch: bool = False
flattened_range: Optional[slice] = None
def __post_init__(self):
self.validate_metadata_integrity()
def validate_metadata_integrity(self) -> None:
"""Codifies the constraints on metadata attributes.
Meeting those constraints is guaranteed when instantiating a ShardedTensor
class with `from_rank_offsets` or `from_rank_offsets_flat` constructors.
Returns:
None
"""
has_flattened_range = self.flattened_range is not None
if self.data is not None:
if self.data.dtype != self.dtype:
raise CheckpointingException(
f'Data dtype should match `dtype` attribute for {self}'
)
if not has_flattened_range and self.data.shape != self.local_shape:
raise CheckpointingException(
f'Data shape should match `local_shape` attribute for {self}'
)
if has_flattened_range:
if self.data.ndim != 1:
raise CheckpointingException(f'Data should be 1D for a flattened {self}')
real_data = self.data
try:
self.data = None
self.init_data(device='meta')
if self.data.shape != real_data.shape:
raise CheckpointingException(
f"Data shape doesn't match expected {self.data.shape} for {self}"
)
finally:
self.data = real_data
if len(self.global_shape) != len(self.global_offset):
raise CheckpointingException(
f'Global offset dimensions should be equal to global shape dimensions for {self}'
)
if len(self.local_shape) + self.prepend_axis_num != len(self.global_shape):
raise CheckpointingException(
f'Local shape together with `prepend_axis_num` dimensions should be equal to global shape dimensions for {self}'
)
for off, sh in zip(self.global_offset[self.prepend_axis_num :], self.local_shape):
if off % sh != 0:
raise CheckpointingException(
f'Global offset ({off}) must be divisible by local shape ({sh}) for {self}.'
)
if has_flattened_range and self.flattened_range.step is not None:
raise CheckpointingException(
f'`step` argument in the flattened range of a ShardedTensor is not supported.'
)
def global_slice(self) -> Tuple[Union[int, slice], ...]:
assert len(self.global_offset) == len(self.local_shape) + self.prepend_axis_num
return tuple(
chain(
(off for off in self.global_offset[: self.prepend_axis_num]),
(
slice(off, off + sh)
for off, sh in zip(
self.global_offset[self.prepend_axis_num :], self.local_shape
)
),
)
)
def global_coordinates(self) -> Tuple[np.ndarray, ...]:
if self.flattened_range is None:
raise CheckpointingException(
f'`global_coordinates` is undefined for'
f' {self.__class__.__name__} without `flattened_range`'
)
local_coords = self.local_coordinates()
assert len(local_coords) + self.prepend_axis_num == len(self.global_offset), (
len(local_coords),
self,
)
global_coords = tuple(
c + off
for c, off in zip((0,) * self.prepend_axis_num + local_coords, self.global_offset)
)
return global_coords
def local_coordinates(self) -> Tuple[np.ndarray, ...]:
if self.flattened_range is None:
raise CheckpointingException(
f'`local_coordinates` is undefined for'
f' {self.__class__.__name__} without `flattened_range`'
)
# TODO: np.unravel_index?
mask = np.zeros(np.prod(self.local_shape), dtype=bool)
mask[self.flattened_range] = True
return np.nonzero(mask.reshape(self.local_shape))
def local_chunk_offset_in_global(self) -> Tuple[int, ...]:
"""Offset of a local chunk in a global array of chunks.
Returns:
Tuple[int, ...]: the offset of the whole local chunk in a global array of chunks.
"""
assert len(self.global_offset) == len(self.local_shape) + self.prepend_axis_num
chunk_offset = list(self.global_offset[: self.prepend_axis_num])
for off, sh in zip(self.global_offset[self.prepend_axis_num :], self.local_shape):
assert off % sh == 0, str(self)
chunk_offset.append(off // sh)
return tuple(chunk_offset)
def max_allowed_chunks(self) -> Tuple[int, ...]:
chunks = []
for axis_sh, axis_fragm in zip(self.global_shape, self.axis_fragmentations):
if not self.allow_shape_mismatch and axis_sh % axis_fragm != 0:
raise CheckpointingException(
f'Axis shape ({axis_sh}) not divisible by axis fragmentation ({axis_fragm})'
)
axis_chunk_size = axis_sh // axis_fragm
chunks.append(axis_chunk_size)
return tuple(chunks)
def without_data(self):
return replace(self, data=None)
@classmethod
def from_rank_offsets(
cls,
key: str,
data: torch.Tensor,
*rank_offsets: Tuple[int, int, int],
replica_id: ReplicaId = 0,
prepend_axis_num: int = 0,
flattened_range: None = None,
**init_kwargs,
):
"""Allows to construct the ShardedTensor given offset specified in process ranks.
Args:
key (str): unique key
data (torch.Tensor): local tensor data
rank_offsets (Tuple[int, int, int]): each tuple (axis, axis_rank_offset, axis_fragm) says that if the global tensor is divided into `axis_fragm` fragments along axis `axis`, then the local tensor data corresponds to the `axis_rank_offset` chunk.
replica_id (ReplicaId): see ShardedTensor
prepend_axis_num (int): see ShardedTensor
flattened_range (None): must be None when using this constructor
init_kwargs: passed to ShardedTensor.__init__
"""
if flattened_range is not None:
raise ValueError(
'Cannot instantiate a flat ShardedTensor with `from_rank_offsets` method.'
' Use `from_rank_offsets_flat` instead'
)
global_offset = [0] * (data.ndim + prepend_axis_num)
global_shape = ([1] * prepend_axis_num) + list(data.shape)
axis_fragmentations = [1] * (data.ndim + prepend_axis_num)
_seen_axis = set()
for axis, axis_rank_offset, axis_fragm in rank_offsets:
assert axis >= 0 and axis_rank_offset >= 0 and axis_fragm >= 0, (
axis,
axis_rank_offset,
axis_fragm,
)
assert (
axis_rank_offset < axis_fragm
), 'Rank offset must be lower than axis fragmentation'
if axis in _seen_axis:
raise CheckpointingException('Duplicated axis specified')
_seen_axis.add(axis)
local_axis_shape = 1 if axis < prepend_axis_num else data.shape[axis - prepend_axis_num]
global_shape[axis] = axis_fragm * local_axis_shape
global_offset[axis] = axis_rank_offset * local_axis_shape
axis_fragmentations[axis] = axis_fragm
return cls(
key,
data,
data.dtype,
tuple(data.shape),
tuple(global_shape),
tuple(global_offset),
tuple(axis_fragmentations),
replica_id,
prepend_axis_num,
flattened_range=flattened_range,
**init_kwargs,
)
@classmethod
def from_rank_offsets_flat(
cls,
key: str,
data: torch.Tensor,
non_flat_local_shape: Tuple[int, ...],
*args,
flattened_range: Optional[slice] = None,
**kwargs,
):
"""Allows to construct a *flattened* ShardedTensor given offset specified in process ranks.
Args:
key (str):
data (torch.Tensor): this should be a flattened data tensor
non_flat_local_shape (Tuple[int, ...]): expected local shape of a non-flat chunk
*args: passed unchanged to the `from_rank_offsets` constructor
flattened_range (slice): see ShardedTensor. Defaults to None, but must be set to
a non-None slice.
**kwargs:
Returns:
ShardedTensor: constructed ShardedTensor instance
"""
if flattened_range is None:
raise CheckpointingException(
'Cannot instantiate a non-flat ShardedTensor with `from_rank_offsets_flat` method.'
' Use `from_rank_offsets` instead'
)
if data.ndim != 1:
raise CheckpointingException(
f'Flattened ShardedTensor requires 1D data, got shape: {data.shape}'
)
if flattened_range.stop - flattened_range.start != data.numel():
raise CheckpointingException(
f'Flattened ShardedTensor data length ({data.numel()}) must match the slice length: {flattened_range.stop - flattened_range.start}'
)
non_flat_data_meta = torch.empty(*non_flat_local_shape, dtype=data.dtype, device='meta')
sh_ten = cls.from_rank_offsets(key, non_flat_data_meta, *args, **kwargs)
instance = replace(sh_ten, data=data, flattened_range=flattened_range)
instance.validate_metadata_integrity()
return instance
def init_data(self, device: Union[str, torch.device], init_fn=torch.empty):
if self.data is not None:
return
self.data = init_fn(self.local_shape, dtype=self.dtype, device=device)
if self.flattened_range is not None:
self.data = self.data.flatten()[self.flattened_range.start : self.flattened_range.stop]
def __str__(self):
return f'{self.__class__.__name__}(key=\'{self.key}\')'
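# A minimal usage sketch (not from the original source): wrapping a local tensor that is
# chunk 1 of 2 along axis 0 of a hypothetical (8, 4) global tensor. Key and shapes are
# illustrative only.
#
#   local = torch.zeros(4, 4)
#   sh_ten = ShardedTensor.from_rank_offsets('layer.weight', local, (0, 1, 2))
#   # sh_ten.global_shape        == (8, 4)
#   # sh_ten.global_offset       == (4, 0)
#   # sh_ten.axis_fragmentations == (2, 1)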
def is_main_replica(replica_id: ReplicaId):
""" Checks if given `replica_id` is considered as main.
"Main" replica is:
- integer 0
- or an iterable with all 0 elements
It is the application's responsibility to set correct replica ids for sharded tensors.
Args:
replica_id (Union[int, Tuple[int, ...]]): replica id
Returns:
(bool): True for a "main" replica
"""
if isinstance(replica_id, int):
return replica_id == 0
return all(r == 0 for r in replica_id)
class LocalNonpersitentObject:
"""Object that should not be stored in a checkpoint, but restored locally.
Wrapping any object inside the state dict with LocalNonpersitentObject
will result in:
- during saving, this object will *not* be stored in the checkpoint
- during loading, a local version of this object will be placed in a state dict
"""
def __init__(self, obj):
self.obj = obj
def unwrap(self):
return self.obj
@dataclass
class ShardedObject(ShardedBase):
"""Represents a mapping between a local object and a global object.
Global object is assumed to consist of many local objects distributed
between different processes.
NOTE: Contrary to ShardedTensor, it's impossible to change global object
sharding. Conceptually, ShardedObject is a fully-sharded ShardedTensor
with atomic arbitrary typed elements.
Args:
key: unique identifier of a global tensor
data: local object data. Can be None only for consistency validation
global_shape: global object shape
global_offset: offset of a local object in a global object, specified in number of shards
replica_id: indicates local object replication wrt. local objects in different processes
"""
key: str
data: object
global_shape: Tuple[int, ...]
global_offset: Tuple[int, ...]
replica_id: ReplicaId = 0
def __post_init__(self):
self.validate_metadata_integrity()
def validate_metadata_integrity(self):
if len(self.global_shape) != len(self.global_offset):
raise CheckpointingException(
f'Global offset dimensions should be equal to global shape dimensions for {self}'
)
def without_data(self):
return replace(self, data=None)
@property
def unique_key(self):
return f'{self.key}/shard_{".".join(map(str, self.global_offset))}_{".".join(map(str, self.global_shape))}'
def __str__(self):
return f'{self.__class__.__name__}(key=\'{self.key}\')'
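# A minimal usage sketch (not from the original source): a per-rank object that is
# shard (2,) of a hypothetical 1D global object consisting of 4 shards.
#
#   sh_obj = ShardedObject('rng_state', {'seed': 1234}, (4,), (2,))
#   # sh_obj.unique_key == 'rng_state/shard_2_4'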
@dataclass
class ShardedTensorFactory(ShardedBase):
""" Allows to apply transformations to tensors before/after serialization.
The essence of those transformations is that they can be applied to
optimizer states the same way they are applied to the model params.
The ultimate state dict with sharded tensors must depend functionally on
`build_fn` arguments (key, data, replica_id, flattened_range),
which will be provided by the optimizer.
Builder creates a sub-state-dict out of a tensor before saving, and merger
merges the corresponding state dict after loading.
Args:
key (str): unique identifier of the factory
data (torch.Tensor): original model parameter that will be further transformed by this factory
build_fn (callable): function that transforms the original tensor to a sharded state dict
merge_fn (callable): function that transforms loaded subtree back into a single tensor (inverse of `build_fn`)
replica_id (ReplicaId): indicates factory replication wrt. factories in different processes
flattened_range (slice, optional): indicates additional flattening applied to the ShardedTensors produced by the factory
"""
key: str
data: torch.Tensor
build_fn: Callable[[str, torch.Tensor, ReplicaId, Optional[slice]], ShardedStateDict]
merge_fn: Callable[[StateDict], torch.Tensor]
replica_id: ReplicaId = 0
flattened_range: Optional[slice] = None
def build(self):
return self.build_fn(self.key, self.data, self.replica_id, self.flattened_range)
def validate_metadata_integrity(self):
"""No reasonable checks can be applied"""
pass
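# A minimal usage sketch (not from the original source): a factory whose `build_fn`
# stores a transposed view of the parameter and whose `merge_fn` undoes the transpose
# after loading. The transform and tensor are purely illustrative.
#
#   def build_fn(key, t, replica_id, flattened_range):
#       return {'t': ShardedTensor.from_rank_offsets(key, t.transpose(0, 1).contiguous(),
#                                                    replica_id=replica_id)}
#
#   def merge_fn(loaded_sub_state_dict):
#       return loaded_sub_state_dict['t'].transpose(0, 1)
#
#   factory = ShardedTensorFactory('layer.weight', torch.ones(2, 3), build_fn, merge_fn)
#   # `apply_factories` (below) calls `factory.build()` before saving;
#   # `apply_factory_merges` (below) applies `merge_fn` to the loaded subtree.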
def apply_factories(sharded_state_dict: ShardedStateDict):
""" Turn ShardedTensorFactories into ShardedTensors *in-place*.
Args:
sharded_state_dict (ShardedStateDict): state dict possibly containing ShardedTensorFactory objects
Returns:
None: state dict is modified in place
"""
def apply(x):
if isinstance(x, ShardedTensorFactory):
x = x.build()
return x
dict_list_map_inplace(apply, sharded_state_dict)
def apply_factory_merges(
x1: StateDict, x2: ShardedStateDict, key: Tuple[str, ...] = ()
) -> StateDict:
""" Apply merges defined by ShardedTensorFactories *in-place*.
Args:
x1 (StateDict): state dict loaded from the checkpoint
x2 (ShardedStateDict): subset of `x1` (in terms of dict keys) with ShardedTensorFactory
as (possibly nested) values that define how to merge objects from the `x1` state dict
key (Tuple[str, ...]): current key in a recursive call. Used only for reporting meaningful errors
Returns:
StateDict: `x1` modified in-place
"""
if isinstance(x2, ShardedTensorFactory):
return x2.merge_fn(x1)
# The rest is almost the same as the `merge` function from `dict_utils`
if isinstance(x1, dict) and isinstance(x2, dict):
for k, v2 in x2.items():
if k not in x1:
raise ValueError(
f'Different dict keys encountered in `apply_factory_merges` ({x1.keys()} vs {x2.keys()})'
)
else:
x1[k] = apply_factory_merges(x1[k], v2, key=key + (k,))
elif isinstance(x1, list) and isinstance(x2, list):
if len(x1) != len(x2):
err_msg = f'Cannot merge two lists with different lengths ({len(x1)} and {len(x2)}, encountered at key {key})'
logger.error(err_msg + f'\nx1: {x1}\nx2: {x2}')
raise ValueError(err_msg)
for i, v2 in enumerate(x2):
x1[i] = apply_factory_merges(x1[i], v2, key=key + (i,))
elif isinstance(x1, list) and isinstance(x2, dict):
for k, v2 in x2.items():
if not isinstance(k, int):
raise ValueError(
f'Invalid dict key {k} non-integer type encountered in a list-dict merge at level {key}'
)
if k >= len(x1):
raise ValueError(
f'Dict key {k} out of bound for list of length {len(x1)} (encountered at level {key})'
)
x1[k] = apply_factory_merges(x1[k], v2, key=key + (k,))
else:
raise ValueError(
f'Duplicate non-dict and non-list values encountered: `{x1}` and `{x2}` (at key {key})'
)
return x1
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Helpers for defining sharding for optimizer states based on existing sharding for model parameters. """
import logging
from copy import deepcopy
from dataclasses import replace
from itertools import chain
from typing import Dict, Iterable, List, Tuple, Union
logger = logging.getLogger(__name__)
import torch
from .dict_utils import nested_values
from .mapping import (
LocalNonpersitentObject,
ShardedStateDict,
ShardedTensor,
ShardedTensorFactory,
StateDict,
)
from .utils import extract_sharded_tensors_and_factories
def get_optim_param_to_id_map(optim_params_iter: Iterable[torch.nn.Parameter]) -> Dict[int, int]:
param_mappings = {}
for i, param in enumerate(optim_params_iter):
if id(param) not in param_mappings:
param_mappings[id(param)] = i
return param_mappings
def get_param_id_to_sharded_param_map(
model_sharded_state_dict: ShardedStateDict, optim_params_iter: Iterable[torch.nn.Parameter]
) -> Dict[int, Union[ShardedTensor, ShardedTensorFactory]]:
""" Generate mapping from optimizer state ids to model sharded parameters.
Args:
model_sharded_state_dict: sharded state dict with all model sharded tensors (can have any structure)
optim_params_iter: iterable which iterates over model parameters tracked by the optimizer.
The iteration must be in the same order as in the optimizer parameters.
Returns:
Dict[int, Union[ShardedTensor, ShardedTensorFactory]]: mapping from optimizer state ids
to model sharded parameters.
"""
model_sharded_state_dict, _ = extract_sharded_tensors_and_factories(model_sharded_state_dict)
id_to_sharded_param_map = {}
param_to_id_map = get_optim_param_to_id_map(optim_params_iter)
for ten in nested_values(model_sharded_state_dict):
if id(ten.data) in param_to_id_map:
id_to_sharded_param_map[param_to_id_map[id(ten.data)]] = ten
else:
logger.debug(f'{ten} is not tracked by the optimizer')
if not id_to_sharded_param_map:
logger.warning(
"Sharded parameters mapping is empty. It means tensors in model state dict"
" do not correspond to tensors in optimizer parameters map."
" Make sure to call state_dict with `keep_vars=True`."
)
return id_to_sharded_param_map
def make_sharded_optimizer_tensor(
model_param: Union[ShardedTensor, ShardedTensorFactory], optim_param: torch.Tensor, prefix: str
) -> Union[ShardedTensor, ShardedTensorFactory]:
""" Build a ShardedTensor or ShardedTensorFactory for optimizer param based on model param
Args:
model_param (Union[ShardedTensor, ShardedTensorFactory]): model param
optim_param (torch.Tensor): corresponding optimizer param
prefix (str): optimizer prefix for the ShardedTensor or ShardedTensorFactory
Returns:
Union[ShardedTensor, ShardedTensorFactory]: wrapped optimizer parameter
"""
if isinstance(model_param, ShardedTensorFactory):
return replace(model_param, key=f'{prefix}.{model_param.key}', data=optim_param)
assert (
tuple(optim_param.shape) == model_param.local_shape
), f'Optimizer shape ({tuple(optim_param.shape)}) does not match model shape ({model_param.local_shape})'
sh_ten = replace(
model_param, key=f'{prefix}.{model_param.key}', data=optim_param, dtype=optim_param.dtype
)
sh_ten.validate_metadata_integrity()
return sh_ten
def optim_state_to_sharding_state(
optim_state_dict: StateDict,
id_to_sharded_param_map: Dict[int, ShardedTensor],
exclude_keys: Tuple[str] = (),
):
""" Turn optimizer state dict to sharded state dict based on model state dict *in-place*.
Can be used to add sharding information to most common optimizer state dict.
Creates separate ShardedTensors for each key in `optim_state_dict['state']`
(e.g. for torch.optim.Adam there will be separate tensors for `exp_avg` and `exp_avg_sq`)
Args:
optim_state_dict (StateDict): optimizer state dict with
state parameters under `state` key and group hyperparameters under `param_groups` -> `params` key.
id_to_sharded_param_map (Dict[int, ShardedTensor]): mapping from optimizer param ids to model sharded tensors.
Can be generated with `get_param_id_to_sharded_param_map` function
exclude_keys (Tuple[str]): optimizer state keys to exclude from the final state dict.
Returns:
None: state dict is modified in place
"""
sharded_state = {}
for param_id, param_state in optim_state_dict['state'].items():
sharded_state[param_id] = {}
for state_key, param in param_state.items():
if state_key in exclude_keys:
continue
if param_id in id_to_sharded_param_map:
sharded_state[param_id][state_key] = make_sharded_optimizer_tensor(
id_to_sharded_param_map[param_id], param, prefix=f'optimizer.state.{state_key}'
)
else:
raise ValueError(f'Param id {param_id} does not match any model sharded param')
optim_state_dict['param_groups'] = deepcopy(optim_state_dict['param_groups'])
for group in optim_state_dict['param_groups']:
group['params'] = LocalNonpersitentObject(group['params'])
optim_state_dict['state'] = sharded_state
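# A minimal usage sketch (not from the original source): adding sharding information to a
# plain torch optimizer state dict. `model.sharded_state_dict()` is a hypothetical helper;
# a single param group is assumed and 'step' is excluded as a non-tensor state entry.
#
#   model_sharded_sd = model.sharded_state_dict()
#   optim_sd = optimizer.state_dict()
#   id_to_param = get_param_id_to_sharded_param_map(
#       model_sharded_sd, optimizer.param_groups[0]['params'])
#   optim_state_to_sharding_state(optim_sd, id_to_param, exclude_keys=('step',))
#   # `optim_sd` now contains ShardedTensors and can be merged into the model
#   # sharded state dict before calling `save`.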
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Entrypoints for saving and loading the distributed checkpoints.
Functions `load` and `save` are equivalents of `torch.load` and `torch.save`
but expect torch.Tensors to be wrapped with classes from the `mapping module`.
Additionally, `load` expects the sharded state dict argument as a guidance for loading the sharded tensors.
"""
import logging
import os
from collections import Counter, defaultdict
from itertools import chain
from pathlib import Path
from typing import Iterable, List, Optional, Tuple, Union
import numpy as np
import torch
from .core import CheckpointingConfig, maybe_load_config, save_config
from .dict_utils import (
dict_list_map_inplace,
diff,
extract_matching_values,
map_reduce,
merge,
nested_values,
)
from .mapping import (
CheckpointingException,
ShardedObject,
ShardedStateDict,
ShardedTensor,
ShardedTensorFactory,
StateDict,
apply_factories,
apply_factory_merges,
is_main_replica,
)
from .strategies.async_utils import AsyncRequest
from .strategies.base import (
AsyncSaveShardedStrategy,
LoadCommonStrategy,
LoadShardedStrategy,
SaveCommonStrategy,
SaveShardedStrategy,
StrategyAction,
get_default_strategy,
)
from .utils import (
extract_nonpersistent,
extract_sharded_base,
extract_sharded_tensors,
extract_sharded_tensors_or_nonpersistent,
)
COMMON_STATE_FNAME = 'common.pt'
logger = logging.getLogger(__name__)
def load(
sharded_state_dict: ShardedStateDict,
checkpoint_dir: str,
sharded_strategy: Union[LoadShardedStrategy, Tuple[str, int], None] = None,
common_strategy: Union[LoadCommonStrategy, Tuple[str, int], None] = None,
validate_access_integrity: bool = True,
) -> StateDict:
"""Loading entrypoint.
In the steps below, the following verbs refer to corresponding objects:
- load = load from checkpoint
- extract = extract from sharded_state_dict
- add = add to the final state dict
Steps:
1. Load common state dict and form the base of the result state dict
2. Apply factories to sharded_state_dict
3. Extract LocalNonPersistentObject and add
4. (optional) Extract ShardedObjects, load and add
5. Extract ShardedBase, load, apply factory merges and add
Args:
sharded_state_dict (ShardedStateDict): state dict of the existing model
populated with ShardedTensors. Used as a mapping to determine which
parts of global tensors stored in the checkpoint should be loaded.
checkpoint_dir (str): directory with the checkpoint
sharded_strategy (LoadShardedStrategy, Tuple[str, int], optional): configures loading behavior for sharded tensors
common_strategy (LoadCommonStrategy, Tuple[str, int], optional): configures loading behavior for common data
validate_access_integrity (bool default = True): checks if each tensor shard is accessed
exactly once (as main replica) by some process
"""
if common_strategy is not None:
raise NotImplementedError('The only supported common strategy is torch')
sharded_strategy = _verify_checkpoint_and_load_strategy(checkpoint_dir, sharded_strategy)
checkpoint_dir = Path(checkpoint_dir)
common_state_dict = load_common_state_dict(checkpoint_dir)
if not sharded_state_dict:
return common_state_dict
# Create a copy of sharded_state_dict as the passed in state dict may have
# references that prevent tensors from being deallocated
sharded_state_dict, _ = extract_matching_values(sharded_state_dict, lambda x: True)
sh_ten_factories, _ = extract_matching_values(
sharded_state_dict,
lambda x: isinstance(x, ShardedTensorFactory),
return_lists_as_dicts=True,
)
apply_factories(sharded_state_dict)
# Data inside sh_ten_factories no longer needed so delete them to reduce memory usage
def unlink_data(x):
x.data = None
return x
dict_list_map_inplace(unlink_data, sh_ten_factories)
# Non-persistent objects
nonpersistent_state_dict, sharded_state_dict = extract_nonpersistent(sharded_state_dict)
dict_list_map_inplace(lambda o: o.unwrap(), nonpersistent_state_dict)
merge(common_state_dict, nonpersistent_state_dict)
# Sharded base
if not sharded_strategy.can_handle_sharded_objects:
# TODO: implement as part of the common strategy
sharded_objects, sharded_state_dict = load_sharded_objects(
sharded_state_dict, checkpoint_dir
)
merge(common_state_dict, sharded_objects)
sharded_state_dict, _ = extract_sharded_base(sharded_state_dict)
if validate_access_integrity:
validate_sharding_integrity(nested_values(sharded_state_dict))
loaded_state_dict = sharded_strategy.load(sharded_state_dict, checkpoint_dir)
loaded_state_dict = apply_factory_merges(loaded_state_dict, sh_ten_factories)
merge(common_state_dict, loaded_state_dict)
return common_state_dict
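# A minimal usage sketch (not from the original source): typical load flow.
# `model.sharded_state_dict()` is a hypothetical helper producing a state dict with
# ShardedTensors; the checkpoint path is illustrative.
#
#   sharded_sd = model.sharded_state_dict()
#   state_dict = load(sharded_sd, '/path/to/checkpoint')
#   model.load_state_dict(state_dict)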
def _verify_checkpoint_and_load_strategy(
checkpoint_dir: str, sharded_strategy: Union[LoadShardedStrategy, Tuple[str, int], None] = None,
) -> LoadShardedStrategy:
""" Verifies if checkpoint metadata exists and matches given strategy.
Args:
checkpoint_dir (str): checkpoint directory
sharded_strategy (LoadShardedStrategy, Tuple[str, int], optional): load strategy to be verified
if compatible with the checkpoint content. If None, the default load strategy
for the checkpoint backend will be returned.
"""
if not Path(checkpoint_dir).exists():
raise CheckpointingException(f'Checkpoint directory {checkpoint_dir} does not exist')
saved_config = maybe_load_config(checkpoint_dir)
if saved_config is None:
raise CheckpointingException(f'{checkpoint_dir} is not a distributed checkpoint')
if sharded_strategy is None:
sharded_strategy = get_default_strategy(
StrategyAction.LOAD_SHARDED,
saved_config.sharded_backend,
saved_config.sharded_backend_version,
)
elif isinstance(sharded_strategy, tuple):
sharded_strategy = get_default_strategy(StrategyAction.LOAD_SHARDED, *sharded_strategy)
# TODO: implement consistency checks here
return sharded_strategy
# TODO: implement it as common torch strategy
def load_common_state_dict(checkpoint_dir: Path) -> StateDict:
""" Load common (non-sharded) objects state dict from the checkpoint.
Args:
checkpoint_dir (Path): checkpoint directory
Returns:
StateDict: state dict with non-sharded objects from the checkpoint
"""
load_path = Path(checkpoint_dir) / COMMON_STATE_FNAME
try:
return torch.load(load_path, map_location='cpu')
except FileNotFoundError as e:
err_msg = f'Common file {load_path} does not exist'
ckpt_files = [f.name for f in checkpoint_dir.iterdir()]
logger.debug(f'{err_msg}. Checkpoint directory content: {ckpt_files}')
raise CheckpointingException(err_msg) from e
def load_sharded_objects(sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
""" Replaces all ShardedObject from a given state dict with values loaded from the checkpoint.
Args:
sharded_state_dict (ShardedStateDict): sharded state dict defining what objects should be loaded.
checkpoint_dir (Path): checkpoint directory
Returns:
None: state dict is modified in place
"""
sharded_objects, sharded_state_dict = extract_matching_values(
sharded_state_dict, lambda v: isinstance(v, ShardedObject)
)
def load_sharded_object(sh_obj: ShardedObject):
sh_obj.data = None
load_path = (checkpoint_dir / sh_obj.unique_key).with_suffix('.pt')
try:
loaded_obj = torch.load(load_path)
except FileNotFoundError as e:
err_msg = f'Object shard {load_path} not found'
obj_subdir = checkpoint_dir / sh_obj.key
if obj_subdir.exists():
obj_files = [f.name for f in obj_subdir.iterdir()]
logger.debug(f'{err_msg}. Object {sh_obj.key} directory content: {obj_files}')
else:
ckpt_files = [f.name for f in checkpoint_dir.iterdir()]
logger.debug(
f'{err_msg}. Object {sh_obj.key} directory does not exist. Checkpoint directory content: {ckpt_files}'
)
raise CheckpointingException(err_msg) from e
return loaded_obj
return dict_list_map_inplace(load_sharded_object, sharded_objects), sharded_state_dict
def load_tensors_metadata(
checkpoint_dir: str, sharded_strategy: Union[LoadShardedStrategy, None] = None
) -> ShardedStateDict:
"""Load tensors metadata from the checkpoint.
Returns a dictionary similar to a sharded state dict, but note that
the dictionary keys are simply ShardedTensor keys (contrary to the
actual sharded state dicts where keys correspond to state dict keys).
Dict values are ShardedTensors without any sharding (so, the only useful
information is tensors global shape and dtype).
Concrete implementation depends on the loading strategy. If no strategy is
given, a default for a given backend is used.
"""
sharded_strategy = _verify_checkpoint_and_load_strategy(checkpoint_dir, sharded_strategy)
return sharded_strategy.load_tensors_metadata(Path(checkpoint_dir))
def load_plain_tensors(checkpoint_dir: str):
"""Load checkpoint tensors without any sharding.
NOTE: common state dict is NOT included."""
sharded_state_dict = load_tensors_metadata(checkpoint_dir)
# Don't validate integrity because shards will be overlapped
# if world_size > 1 (all processes load whole tensors)
return load(sharded_state_dict, checkpoint_dir, validate_access_integrity=False)
def save(
sharded_state_dict: ShardedStateDict,
checkpoint_dir: str,
sharded_strategy: Union[SaveShardedStrategy, Tuple[str, int], None] = None,
common_strategy: Union[SaveCommonStrategy, Tuple[str, int], None] = None,
validate_access_integrity: bool = True,
async_sharded_save: bool = False,
) -> Optional[AsyncRequest]:
"""Saving entrypoint.
Extracts ShardedTensors from the given state dict. Rank 0 saves the
"regular" part of the checkpoint to common torch file.
The ShardedTensors are saved according to a strategy specified by the
config.
Steps:
1. Apply factories
2. Extract and discard LocalNonPersistentObject
3. Extract all ShardedBase objects
4. Save all other objects to common.pt
5. (optional) Extract and save ShardedObjects
6. Save all ShardedBase objects
7. Write metadata.json file with backend and version metadata.
Step (6) can be performed asynchronously (see `async_sharded_save`), in this
case the actual save is embodied in the returned async request and can be
scheduled by the external caller. For async request, step (7) is added as
one of the finalization functions, so that metadata.json is written only
if the checkpoint is complete.
Args:
sharded_state_dict (ShardedStateDict): state dict of the model populated with
ShardedTensors. Used as a mapping to determine how local tensors
should be saved as global tensors in the checkpoint.
checkpoint_dir (str): directory to save the checkpoint to
sharded_strategy (SaveShardedStrategy, Tuple[str, int], optional): configures sharded tensors saving behavior and backend
common_strategy (SaveCommonStrategy, Tuple[str, int], optional): configures common data saving behavior and backend
validate_access_integrity (bool default = True): checks if each tensor shard is accessed
exactly once (as main replica) by some process
async_sharded_save (bool, optional): if True, for the sharded state dict part
an async save implementation will be called, with the AsyncRequest
being returned to the caller. Note that it is the caller's responsibility to
actually schedule the async save. Defaults to False.
Returns:
AsyncRequest (optional): if `async_sharded_save` is True, returns
async request that should be scheduled by the caller of this function.
None otherwise.
"""
checkpoint_dir = Path(checkpoint_dir)
if torch.distributed.get_rank() == 0:
if not checkpoint_dir.exists():
raise CheckpointingException(
f'Checkpoint destination directory does not exist: {checkpoint_dir}'
)
if next(checkpoint_dir.iterdir(), None) is not None:
raise CheckpointingException(
f'Checkpoint destination directory ({checkpoint_dir}) is not empty'
)
if common_strategy is not None:
raise NotImplementedError('The only supported common strategy is torch')
if sharded_strategy is None:
sharded_strategy = get_default_save_sharded_strategy()
if not isinstance(sharded_strategy, SaveShardedStrategy):
assert isinstance(sharded_strategy, tuple), type(sharded_strategy)
sharded_strategy = get_default_strategy(StrategyAction.SAVE_SHARDED, *sharded_strategy)
apply_factories(sharded_state_dict)
_, sharded_state_dict = extract_nonpersistent(sharded_state_dict)
sharded_state_dict, state_dict = extract_sharded_base(sharded_state_dict)
_save_common_dict(state_dict, checkpoint_dir, True)
if validate_access_integrity:
validate_sharding_integrity(list(nested_values(sharded_state_dict)))
if not sharded_strategy.can_handle_sharded_objects:
# TODO: implement as part of the common strategy
sharded_state_dict = _extract_and_save_sharded_objects(
sharded_state_dict, checkpoint_dir, validate_access_integrity
)
def metadata_finalize_fn():
if torch.distributed.get_rank() == 0:
save_config(
CheckpointingConfig(sharded_strategy.backend, sharded_strategy.version),
checkpoint_dir,
)
torch.distributed.barrier()
if not async_sharded_save:
sharded_strategy.save(sharded_state_dict, checkpoint_dir)
metadata_finalize_fn()
return
if not isinstance(sharded_strategy, AsyncSaveShardedStrategy):
raise CheckpointingException(
f'Cannot apply async_save to non-async strategy {sharded_strategy}'
)
async_request = sharded_strategy.async_save(sharded_state_dict, checkpoint_dir)
async_request.finalize_fns.append(metadata_finalize_fn)
return async_request
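# A minimal usage sketch (not from the original source): synchronous and asynchronous
# save. The target directory must already exist and be empty; `model` and the path are
# hypothetical, and the default 'torch_dist' strategy is assumed to support async save.
#
#   sharded_sd = model.sharded_state_dict()
#   save(sharded_sd, '/path/to/checkpoint')  # blocking save
#
#   async_request = save(sharded_sd, '/path/to/checkpoint', async_sharded_save=True)
#   # the caller is responsible for scheduling `async_request`,
#   # e.g. via AsyncCallsQueue from strategies/async_utils.py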
def get_default_save_sharded_strategy(
backend: str = 'torch_dist', version: int = 1
) -> SaveShardedStrategy:
return get_default_strategy(StrategyAction.SAVE_SHARDED, backend, version)
def get_default_load_sharded_strategy(checkpoint_dir: str) -> LoadShardedStrategy:
return _verify_checkpoint_and_load_strategy(checkpoint_dir)
# TODO: implement it as common torch strategy
def _save_common_dict(
state_dict: StateDict, checkpoint_dir: Path, validate_consistency: bool = False
):
if torch.distributed.get_rank() == 0:
torch.save(state_dict, checkpoint_dir / COMMON_STATE_FNAME)
if validate_consistency:
# TODO: implement checking consistency with rank 0 common dict on other ranks
pass
# torch.distributed.barrier()
# if not torch.distributed.get_rank() == 0:
# rank_0_state_dict = torch.load(checkpoint_dir / COMMON_STATE_FNAME)
# print(diff(common_state_dict, rank_0_state_dict))
def _extract_and_save_sharded_objects(
state_dict: StateDict, checkpoint_dir: Path, validate_consistency: bool = False
):
sharded_objects, state_dict = extract_matching_values(
state_dict, lambda v: isinstance(v, ShardedObject)
)
sharded_objects = list(nested_values(sharded_objects))
for sh_obj in sharded_objects:
if is_main_replica(sh_obj.replica_id):
save_path = (checkpoint_dir / sh_obj.unique_key).with_suffix('.pt')
os.makedirs(save_path.parent, exist_ok=True)
torch.save(sh_obj.data, save_path)
return state_dict
def validate_sharding_integrity(sharded_tensors: Iterable[ShardedTensor]):
""" Validate if the ShardedTensors from multiple processes define correct sharding of a global tensor.
Local ShardedTensors metadata is exchanged with `torch.distributed.all_gather_object`
and then the process with global rank 0 checks whether the main replicas of the shards:
- cover the whole global tensor
- don't overlap
Args:
sharded_tensors (Iterable[ShardedTensor]): sharded tensors local to this process
Returns:
None
Raises:
CheckpointingException for invalid access pattern
"""
sharding = [ten.without_data() for ten in sharded_tensors]
all_sharding = [None] * torch.distributed.get_world_size()
torch.distributed.all_gather_object(all_sharding, sharding)
if torch.distributed.get_rank() != 0:
return
key_shardings = defaultdict(list)
for rank, rank_shardings in enumerate(all_sharding):
for sharding in rank_shardings:
key_shardings[sharding.key].append((rank, sharding))
for key, shardings in key_shardings.items():
if isinstance(shardings[0][1], ShardedObject):
_validate_objects_for_key(shardings)
else:
_validate_sharding_for_key(shardings)
def _validate_sharding_for_key(rank_sharding: List[Tuple[int, ShardedTensor]]):
some_rank_shard = rank_sharding[0][1]
global_shape = some_rank_shard.global_shape
local_shape = some_rank_shard.local_shape
dtype = some_rank_shard.dtype
has_flattened_range = some_rank_shard.flattened_range is not None
for rank, sharding in rank_sharding:
assert sharding.dtype == dtype, (sharding.dtype, dtype, some_rank_shard)
assert sharding.global_shape == global_shape, (
sharding.global_shape,
global_shape,
some_rank_shard,
)
assert sharding.local_shape == local_shape, (
sharding.local_shape,
local_shape,
some_rank_shard,
)
assert (sharding.flattened_range is not None) == has_flattened_range, (
(sharding.flattened_range is not None),
has_flattened_range,
some_rank_shard,
)
shard_access_cnt = _compute_shards_access(rank_sharding)
if has_flattened_range:
map_reduce(
rank_sharding,
lambda x: x[1].global_offset,
lambda x: x[1],
_validate_sharding_for_key_flattened,
)
else:
if not torch.all(shard_access_cnt == 1):
logger.error(f'Invalid access pattern for {rank_sharding[0][1]}: {shard_access_cnt}')
raise CheckpointingException(f'Invalid access pattern for {rank_sharding[0][1]}')
def _compute_shards_access(rank_sharding):
shard_access_cnt = torch.zeros(
rank_sharding[0][1].axis_fragmentations, dtype=torch.int, device='cpu'
)
for rank, sharding in rank_sharding:
if is_main_replica(sharding.replica_id):
shard_access_cnt[sharding.local_chunk_offset_in_global()] += 1
# TODO: consider validating different replicas too
return shard_access_cnt
def _validate_sharding_for_key_flattened(tensors_by_shard):
all_slices = []
local_shape = tensors_by_shard[0].local_shape
for sharding in tensors_by_shard:
assert sharding.local_shape == local_shape
sharding: ShardedTensor
if not is_main_replica(sharding.replica_id):
# TODO: this checks only saving (and loading replica_id=0) consistency
continue
all_slices.append((sharding.flattened_range.start, sharding.flattened_range.stop))
starts, stops = map(np.asarray, zip(*sorted(all_slices)))
if (
starts[0] != 0
or stops[-1] != np.prod(local_shape)
or not np.all(starts[1:] == stops[:-1])
):
logger.error(
f"Flattened ranges don't cover the whole shard {tensors_by_shard[0]}. Ranges: {(starts, stops)}"
)
raise CheckpointingException(
f"Flattened ranges don't cover the whole shard {tensors_by_shard[0]}"
)
def _validate_objects_for_key(sharded_objects: List[ShardedObject]):
""" Ensure uniqueness of saved objects. """
unique_keys = [
sh_obj.unique_key for _, sh_obj in sharded_objects if is_main_replica(sh_obj.replica_id)
]
if len(unique_keys) != len(set(unique_keys)):
duplicates = {k: cnt for k, cnt in Counter(unique_keys).items() if cnt > 1}
logger.error(f'Duplicate ShardedObject keys and counts: {duplicates}')
raise CheckpointingException(f'Duplicate ShardedObject keys: {list(duplicates.keys())}')
expected_shard_num = np.prod(sharded_objects[0][1].global_shape)
if len(unique_keys) != expected_shard_num:
err_msg = f'Invalid access pattern: {expected_shard_num - len(unique_keys)} ShardedObjects are missing.'
logger.error(f'{err_msg} Existing shards: {unique_keys}')
raise CheckpointingException(err_msg)
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Various loading and saving strategies """
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""
This module provides async utilities which allow starting
a checkpoint save process in the background.
"""
import logging
from collections import deque
from time import time
from typing import Callable, List, NamedTuple, Optional, Tuple
import torch
from torch import multiprocessing as mp
logger = logging.getLogger(__name__)
class AsyncRequest(NamedTuple):
""" Represents an async request that needs to be scheduled for execution.
Args:
async_fn (Callable, optional): async function to call. None represents noop.
async_fn_args (Tuple): args to pass to `async_fn`.
finalize_fns (List[Callable]): list of functions to call to finalize the request.
These functions will be called synchronously after `async_fn` is done
*on all ranks*.
"""
async_fn: Optional[Callable]
async_fn_args: Tuple
finalize_fns: List[Callable]
is_frozen: bool = False
def add_finalize_fn(self, fn: Callable) -> None:
""" Adds a new finalize function to the request.
Args:
fn (Callable): function to add to the async request. This function
will be called *after* existing finalization functions.
Returns:
None
"""
if self.is_frozen:
raise RuntimeError('Cannot add finalization functions to a frozen AsyncRequest')
self.finalize_fns.append(fn)
def execute_sync(self) -> None:
""" Helper to synchronously execute the request.
This logic is equivalent to what should happen in case of the async call.
"""
if self.async_fn is not None:
self.async_fn(*self.async_fn_args)
torch.distributed.barrier()
for finalize_fn in self.finalize_fns:
finalize_fn()
def freeze(self) -> 'AsyncRequest':
""" Freezes the async request, disallowing adding new finalization functions.
Returns:
AsyncRequest: new async request with all same fields except for the
`is_frozen` flag.
"""
return self._replace(is_frozen=True)
class DistributedAsyncCaller:
""" Wrapper around mp.Process that ensures correct semantic of distributed finalization.
Starts process asynchronously and allows checking if all processes on all ranks are done.
"""
def __init__(self):
self.process: Optional[mp.Process] = None
self.start_time: Optional[float] = None
def schedule_async_call(self, async_fn: Optional[Callable], save_args: Tuple,) -> None:
""" Spawn a process with `async_fn` as the target.
This method must be called on all ranks.
Args:
async_fn (Callable, optional): async function to call. If None,
no process will be started.
save_args (Tuple): async function args.
"""
if async_fn is None:
return # nothing to do
torch.cuda.synchronize()
ctx = mp.get_context('fork')
self.start_time = time()
self.process = ctx.Process(target=async_fn, args=save_args,)
self.process.start()
def is_current_async_call_done(self, blocking=False) -> bool:
""" Check if async save is finished on all ranks.
For semantic correctness, requires rank synchronization in each check.
This method must be called on all ranks.
Args:
blocking (bool, optional): if True, will wait until the call is done
on all ranks. Otherwise, returns immediately if at least one rank
is still active. Defaults to False.
Returns:
bool: True if all ranks are done (immediately or after active wait
if `blocking` is True), False if at least one rank is still active.
"""
# The following takes the same overhead as torch.distributed.barrier (single integer all-reduce)
is_alive = int(self.process.is_alive()) if self.process is not None else 0
ten = torch.tensor([is_alive], dtype=torch.int, device=torch.cuda.current_device())
logger.debug(
f"rank: {torch.distributed.get_rank()}, DistributedAsyncCaller is_alive: {is_alive}"
)
torch.distributed.all_reduce(ten)
if ten[0] > 0 and not blocking:
return False
else:
if self.process is not None:
logger.debug(f"rank: {torch.distributed.get_rank()}, joining self.process")
self.process.join()
self.process = None
logger.debug(
f"DistributedAsyncCaller: Async process join finished after {time() - self.start_time:.2f}s from forking"
)
self.start_time = None
return True
class _ActiveAsyncRequest(NamedTuple):
""" Helper to represent an active async call.
Args:
idx (int): index of the call (starting from 0)
async_caller (DistributedAsyncCaller): async caller instance that represents
the async process handling the async request
async_request (AsyncRequest): async request that is being called
"""
idx: int
async_caller: DistributedAsyncCaller
async_request: AsyncRequest
class AsyncCallsQueue:
""" Manages a queue of async calls.
Allows adding a new async call with `schedule_async_request` and finalizing
active calls with `maybe_finalize_async_calls`.
"""
def __init__(self):
self.async_calls: deque[_ActiveAsyncRequest] = deque([])
self.call_idx: int = -1
def schedule_async_request(self, async_request: AsyncRequest) -> int:
""" Start a new async call and add it to a queue of active async calls.
This method must be called on all ranks.
Args:
async_request (AsyncRequest): async request to start.
Returns:
int: index of the async call that was started.
This can help the user keep track of the async calls.
"""
self.call_idx += 1
async_caller = DistributedAsyncCaller()
async_request = async_request.freeze()
async_caller.schedule_async_call(async_request.async_fn, async_request.async_fn_args)
self.async_calls.append(_ActiveAsyncRequest(self.call_idx, async_caller, async_request))
return self.call_idx
def maybe_finalize_async_calls(self, blocking=False) -> List[int]:
""" Finalizes all available calls.
This method must be called on all ranks.
Args:
blocking (bool, optional): if True, will wait until all active requests
are done. Otherwise, finalizes only the async requests that have already
finished. Defaults to False.
Returns:
List[int]: list of indices (as returned by `schedule_async_request`)
of async calls that have been successfully finalized.
"""
call_idx_finalized = []
while self.async_calls:
next_async_done = self.async_calls[0].async_caller.is_current_async_call_done(blocking)
if not next_async_done:
break
call_idx, _, async_request = self.async_calls.popleft()
for finalize_fn in async_request.finalize_fns:
finalize_fn()
ten = torch.tensor([call_idx], dtype=torch.int, device=torch.cuda.current_device())
torch.distributed.all_reduce(ten, op=torch.distributed.ReduceOp.MAX)
assert (
ten.item() == call_idx
), 'Unmatched async calls. That probably means not all ranks are participating in async finalization'
call_idx_finalized.append(call_idx)
return call_idx_finalized
def get_num_unfinalized_calls(self):
""" Get the number of active async calls. """
return len(self.async_calls)
def close(self):
""" Finalize all calls upon closing. """
self.maybe_finalize_async_calls(blocking=True)
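# A minimal usage sketch (not from the original source): driving an async checkpoint save
# from a training loop. `save(...)` refers to the serialization entrypoint with
# `async_sharded_save=True`; `sharded_sd` and `ckpt_dir` are hypothetical.
#
#   queue = AsyncCallsQueue()
#   async_request = save(sharded_sd, ckpt_dir, async_sharded_save=True)
#   queue.schedule_async_request(async_request)
#   # ... continue training; call periodically on all ranks:
#   queue.maybe_finalize_async_calls(blocking=False)
#   # at the end, block until everything is finalized:
#   queue.close()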
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Strategies base interfaces. """
from abc import ABC, abstractmethod
from collections import defaultdict
from enum import Enum
from pathlib import Path
from ..mapping import CheckpointingException, ShardedStateDict, StateDict
from .async_utils import AsyncRequest
class StrategyAction(Enum):
LOAD_COMMON = 'load_common'
LOAD_SHARDED = 'load_sharded'
SAVE_COMMON = 'save_common'
SAVE_SHARDED = 'save_sharded'
default_strategies = defaultdict(dict)
def get_default_strategy(action: StrategyAction, backend: str, version: int):
""" Retrieves a default strategy for a given action, backend and version. """
try:
if backend == 'zarr':
error_hint = ' Please install `zarr` and `tensorstore<=0.1.45` packages'
from .tensorstore import _import_trigger
from .zarr import _import_trigger
elif backend == 'torch_dist':
error_hint = ' Please use PyTorch version >=2.1'
from .torch import _import_trigger
except ImportError as e:
raise CheckpointingException(
f'Cannot import a default strategy for: {(action.value, backend, version)}. Error: {e}. Hint: {error_hint}'
) from e
try:
return default_strategies[action.value][(backend, version)]
except KeyError as e:
raise CheckpointingException(
f'Cannot find a default strategy for: {(action.value, backend, version)}'
) from e
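# Example (illustrative): resolving the default strategies for a checkpoint whose
# metadata declares the `torch_dist` sharded backend, version 1:
#
#     save_strategy = get_default_strategy(StrategyAction.SAVE_SHARDED, 'torch_dist', 1)
#     load_strategy = get_default_strategy(StrategyAction.LOAD_SHARDED, 'torch_dist', 1)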
class LoadStrategyBase(ABC):
""" Base class for a load strategy. Requires implementing checks for compatibility with a given checkpoint version. """
@abstractmethod
def check_backend_compatibility(self, loaded_version):
raise NotImplementedError
@abstractmethod
def check_version_compatibility(self, loaded_version):
raise NotImplementedError
@property
def can_handle_sharded_objects(self):
""" Returns whether or not this strategy can handle loading ShardedObjects. """
return False
class SaveStrategyBase(ABC):
""" Base class for a save strategy. Requires defining a backend type and version of the saved format. """
def __init__(self, backend: str, version: int):
self.backend = backend
self.version = version
@property
def can_handle_sharded_objects(self):
""" Returns whether or not this strategy can handle saving ShardedObjects. """
return False
def __str__(self):
return f'{self.__class__.__name__}({self.backend}, {self.version})'
class LoadCommonStrategy(LoadStrategyBase):
""" Load strategy for common (non-sharded) objects """
@abstractmethod
def load(self, checkpoint_dir: Path):
raise NotImplementedError
class LoadShardedStrategy(LoadStrategyBase):
""" Load strategy for sharded tensors """
@abstractmethod
def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
raise NotImplementedError
@abstractmethod
def load_tensors_metadata(self, checkpoint_dir: Path):
"""Load tensors metadata from the checkpoint.
Returns a dictionary similar to a sharded state dict, but note that
the dictionary keys are simply ShardedTensor keys (contrary to the
actual sharded state dicts where keys correspond to state dict keys).
Dict values are ShardedTensors without any sharding (so, the only useful
information is tensors global shape and dtype).
"""
raise NotImplementedError(
f'{self.__class__.__name__} does not allow loading only sharded metadata'
)
class SaveCommonStrategy(SaveStrategyBase):
""" Save strategy for common (non-sharded) objects """
@abstractmethod
def save(self, common_state_dict: StateDict, checkpoint_dir: Path):
raise NotImplementedError
class SaveShardedStrategy(SaveStrategyBase):
""" Save strategy for sharded tensors """
@abstractmethod
def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
raise NotImplementedError
class AsyncSaveShardedStrategy(SaveShardedStrategy):
""" Save strategy suitable for async save. """
@abstractmethod
def async_save(
self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path
) -> AsyncRequest:
""" Perform preparation and return an AsyncRequest to the external caller.
Args:
sharded_state_dict (ShardedStateDict): sharded state dict to save
checkpoint_dir (Path): checkpoint target directory
Returns:
AsyncRequest: represents the async save function and finalization function.
It is the caller's responsibility to actually schedule the async save.
"""
raise NotImplementedError
def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
""" Each async strategy can be trivially used as a sync strategy. """
async_request = self.async_save(sharded_state_dict, checkpoint_dir)
async_request.execute_sync()
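# Usage sketch (illustrative; `strategy` stands for some concrete AsyncSaveShardedStrategy
# and `queue` for an AsyncCallsQueue from `.async_utils`, neither defined in this module):
#
#     # Synchronous path: `save` executes the AsyncRequest in place.
#     strategy.save(sharded_state_dict, checkpoint_dir)
#
#     # Asynchronous path: the caller schedules the returned AsyncRequest itself.
#     async_request = strategy.async_save(sharded_state_dict, checkpoint_dir)
#     queue.schedule_async_request(async_request)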
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
""" Storage writer for PyT Distributed format allowing asynchronous save. """
import logging
import os
from itertools import chain
from pathlib import Path
from time import time
from typing import Callable, Dict, List, Optional, Tuple
import psutil
import torch
from torch import multiprocessing as mp
from torch.distributed.checkpoint import FileSystemWriter
from torch.distributed.checkpoint.filesystem import DEFAULT_SUFFIX, _StoragePrefix, _write_item
from torch.distributed.checkpoint.planner import SavePlan, SavePlanner, WriteItem, WriteItemType
from torch.distributed.checkpoint.storage import WriteResult
from torch.futures import Future
logger = logging.getLogger(__name__)
WriteBucket = Tuple[Path, str, Tuple[list, list]] # represents writes to a single file
class FileSystemWriterAsync(FileSystemWriter):
"""
Async-enabled implementation of FileSystemWriter using file IO.
This class doesn't spawn the async process itself; it relies on an external async mechanism.
Flow:
1. Call `prepare_write_data`
2. Externally start an async process with `get_save_function_and_args` function and args
3. The async function to call is `write_preloaded_data_multiproc`, which spawns
`write_preloaded_data` in multiple processes
After saving is finalized on all ranks:
4. Call `super().finish` with the results gathered in `self.writer_result`
Note that step (3) above can also be called synchronously.
Currently, it's assumed that a separate writer is created for each ckpt save
(intermediate state is stored as writer attributes).
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if not self.single_file_per_rank:
raise NotImplementedError(
'single_file_per_rank flag not supported for FileSystemWriterAsync'
)
# Intermediate state between preparation and finalization
self.write_buckets: Optional[List[WriteBucket]] = None
self.write_results: Optional[Dict[int, List[WriteResult]]] = None
def prepare_write_data(self, plan: SavePlan, planner: SavePlanner) -> None:
"""
First stage of async saving. Copy data to CPU and plan the local saving.
Args:
plan (SavePlan): save plan generated by the PyT Distributed compatible planner
planner (SavePlanner): save planner used to resolve the bytes and tensor data
Returns: None, but stores the save plan in `self.write_buckets`
"""
storage_plan: _StoragePrefix = plan.storage_data
start = time()
logger.debug(f"thread_count: {self.thread_count}, time: {start}")
item_buckets = _split_by_size_and_type(self.thread_count, plan.items)
logger.debug(f"bucket_prep, time: {time() - start}")
start = time()
# move tensors from GPU to CPU before starting async writing
# We do D2H synchronously for now
file_count = 0
def gen_file():
nonlocal file_count
file_name = f"{storage_plan.prefix}{file_count}{DEFAULT_SUFFIX}"
file_count += 1
return file_name
# Prepare bytes / tensor data in each bucket, which will be assigned to each writer process
self.write_buckets = []
for bucket in item_buckets:
bytes_data = [
(item, planner.resolve_data(item))
for item in bucket
if item.type == WriteItemType.BYTE_IO
]
tensor_data = [
(item, planner.resolve_data(item).detach().to("cpu", non_blocking=True))
for item in bucket
if item.type != WriteItemType.BYTE_IO
]
if len(bytes_data) > 0 or len(tensor_data) > 0:
file_name = gen_file()
self.write_buckets.append(
(self.path / file_name, file_name, (bytes_data, tensor_data))
)
# Check if there is anything to write on this rank
if len(self.write_buckets) > 0:
assert len(self.write_buckets) <= self.thread_count, (
len(self.write_buckets),
self.thread_count,
)
ctx = mp.get_context('fork')
self.write_results = ctx.Manager().dict()
else:
self.write_results = {}
logger.debug(f"D2H and push, time: {time() - start}")
def get_save_function_and_args(self) -> Tuple[Optional[Callable], Tuple]:
"""
Get function that saves the data to storage along with its arguments.
Allows the external caller to apply the save function synchronously or asynchronously.
Returns: (None, ()) if there is nothing to write on this rank, otherwise a tuple of:
- the function that saves the data
- arguments to that function
"""
if not self.write_buckets:
return None, ()
return (self.write_preloaded_data_multiproc, (self.write_buckets, self.write_results))
@staticmethod
def write_preloaded_data_multiproc(
write_buckets: List[WriteBucket], write_results: Dict[int, List[WriteResult]]
) -> None:
"""
Performs saving data to storage with multiple processes.
Args:
write_buckets (List[WriteBucket]): write plan
write_results (Dict[int, List[WriteResult]]): dict to store the write results to.
Assumes multiprocessing save, so keys are local process indices
Returns: None
"""
w_start = time()
ctx = mp.get_context('fork')
p_list = [
ctx.Process(
target=FileSystemWriterAsync.write_preloaded_data,
args=(i, write_bucket, write_results, True),
)
for i, write_bucket in enumerate(write_buckets)
]
for p in p_list:
p.start()
for p in p_list:
p.join()
w_end = time()
logger.debug(
f"{w_end}, rank: {torch.distributed.get_rank()}, write(sync,parallel): {w_end - w_start}"
)
@staticmethod
def write_preloaded_data(
local_proc_idx: int,
write_bucket: WriteBucket,
write_results: Dict[int, List[WriteResult]],
use_fsync: bool,
) -> None:
"""
Performs actual data saving to storage.
Args:
local_proc_idx (int): index of a local process that performs writing
write_bucket (WriteBucket): data to write to storage
write_results (Dict[int, List[WriteResult]]): dict to store the write results to.
Assumes multiprocessing save, so keys are local process indices
use_fsync (bool): if True, calls os.fsync at the end of saving
Returns: None, the write results are written to the `write_results` dict
"""
mem_before = _process_memory()
local_results = []
file_name, storage_key, (bytes_data, tensor_data) = write_bucket
with open(file_name, "wb") as stream:
for write_item, data in bytes_data:
local_results.append(_write_item(stream, data, write_item, storage_key))
for write_item, tensor in tensor_data:
assert tensor.is_cpu
local_results.append(_write_item(stream, tensor, write_item, storage_key))
if use_fsync:
os.fsync(stream.fileno())
write_results[local_proc_idx] = local_results
mem_after = _process_memory()
logger.debug(
f"{local_proc_idx} consumed: {mem_after - mem_before}, before: {mem_before}, after: {mem_after}"
)
def write_data(self, plan: SavePlan, planner: SavePlanner,) -> Future[List[WriteResult]]:
raise NotImplementedError('write_data not implemented for FileSystemWriterAsync')
def retrieve_write_results(self) -> List[WriteResult]:
"""
Turns `self.write_results` into a single list of results. Includes an error check.
Returns (List[WriteResult]): the list of write results from all local processes performing the save.
"""
assert self.write_results is not None
assert self.write_buckets is not None
if len(self.write_results) != len(self.write_buckets):
raise RuntimeError(
f'Incomplete worker results (expected {len(self.write_buckets)}, got {len(self.write_results)}).'
f' This probably indicates a worker failure.'
)
return list(chain.from_iterable(self.write_results.values()))
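# Usage sketch of the flow described in the class docstring (illustrative; `plan` and
# `planner` are assumed to come from the PyT Distributed planning step, e.g. as
# orchestrated by `save_state_dict_async_plan` in `state_dict_saver.py`):
#
#     writer = FileSystemWriterAsync(checkpoint_dir)
#     writer.prepare_write_data(plan, planner)              # D2H copies + write buckets
#     async_fn, async_args = writer.get_save_function_and_args()
#     if async_fn is not None:
#         async_fn(*async_args)                             # may be run in a separate process
#     results = writer.retrieve_write_results()             # after saving finished on this rank
#     # `results` are then gathered across ranks and passed to `writer.finish(...)`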
def _split_by_size_and_type(bins: int, items: List[WriteItem]) -> List[List[WriteItem]]:
"""
Splits write items according to item size into close to uniform bins.
Same as torch.distributed.checkpoint.filesystem._split_by_size_and_type,
but with a fixed _item_size function.
Args:
bins (int): number of bins to split into
items (List[WriteItem]): list of write items
Returns (List[List[WriteItem]]): write items split into bins
"""
if bins == 1:
return [items]
bytes_items = [wi for wi in items if wi.type == WriteItemType.BYTE_IO]
tensor_items = [wi for wi in items if wi.type != WriteItemType.BYTE_IO]
buckets: List[List[WriteItem]] = [[] for _ in range(bins)]
bucket_sizes = [0 for _ in range(bins)]
tensor_items.sort(key=_item_size, reverse=True)
# Assign bytes with a simple round-robin
for i, item in enumerate(bytes_items):
buckets[i % bins].append(item)
# Then, assign tensors according to their sizes
for item in tensor_items:
# TODO: replace with heapq
idx = min(enumerate(bucket_sizes), key=lambda x: x[1])[0]
buckets[idx].append(item)
bucket_sizes[idx] += _item_size(item)
return buckets
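# Worked example (illustrative): with bins=2 and tensor items of sizes [8, 6, 4, 2]
# bytes, the items are sorted by decreasing size and greedily placed on the currently
# smallest bucket: 8 -> bucket 0 (8), 6 -> bucket 1 (6), 4 -> bucket 1 (10),
# 2 -> bucket 0 (10), giving two buckets of 10 bytes each.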
def _item_size(item: WriteItem) -> int:
"""
Calculates size (in bytes) of a single write item.
Same as torch.distributed.checkpoint.filesystem._item_size,
but fixes computing chunk size (with item.tensor_data.chunk.sizes)
Args:
item (WriteItem): write item to compute the size of
Returns (int): size of an item in bytes
"""
size = 1
assert item.tensor_data is not None
# can't use math.prod as PT needs to support older python
for s in item.tensor_data.chunk.sizes:
size *= s
dtype = item.tensor_data.properties.dtype
return size * torch._utils._element_size(dtype)
def _process_memory() -> int:
"""
Get memory used by current process.
Returns (int): memory used by current process
"""
process = psutil.Process(os.getpid())
mem_info = process.memory_info()
return mem_info.rss
import logging
from collections import defaultdict
from functools import reduce
from itertools import zip_longest
from pathlib import Path
from time import time
from typing import Dict, List, NamedTuple, Optional, Set, Tuple, TypeVar, cast
import numpy as np
import torch
import torch.distributed as dist
from megatron.core.dist_checkpointing import ShardedTensor
from megatron.core.dist_checkpointing.core import CheckpointingException
from megatron.core.dist_checkpointing.dict_utils import (
dict_list_map_inplace,
extract_matching_values,
merge,
nested_values,
)
from megatron.core.dist_checkpointing.mapping import ShardedStateDict, StateDict, is_main_replica
from megatron.core.dist_checkpointing.serialization import validate_sharding_integrity
from megatron.core.dist_checkpointing.strategies.base import (
AsyncSaveShardedStrategy,
LoadShardedStrategy,
SaveShardedStrategy,
)
logger = logging.getLogger(__name__)
# _ShardId uniquely identifies a ShardedTensor. This is a subset of ShardedTensor
# attributes: key (str), global_offset (tuple) and flattened_range (optional tuple)
_ShardId = Tuple[str, tuple, Optional[tuple]]
class SaveLoadDistribution(NamedTuple):
""" Represents a save or load distribution of ShardedTensors.
Given distribution is valid only for a specific parallelization group,
which is implicit here (not referenced by this class).
Args:
main_rank_for_shard (Dict[_ShardId, int]): specifies which rank should hold
the main replica for a given shard
shards_in_this_group (Set[_ShardId]): which shards have a main replica
in this parallelization group
shard_to_metadata (Dict[_ShardId, ShardedTensor]): maps ShardedTensor
identifier to the original ShardedTensor
"""
main_rank_for_shard: Dict[_ShardId, int]
shards_in_this_group: Set[_ShardId]
shard_to_metadata: Dict[_ShardId, ShardedTensor]
class FullyParallelSaveStrategyWrapper(AsyncSaveShardedStrategy):
""" Wraps arbitrary strategy and distributes the save during `save`.
The save distribution happens without any *data* communication.
Only the *metadata* is exchanged, and based on the data replication across
ranks, we try to distribute the save as uniformly as possible.
This wrapper assumes that setting `replica_id` to 0 will make the
underlying strategy do the saving on the current rank. All the other `replica_id`s
are set to 1.
Currently, the save distribution is realized with a greedy algorithm
described in `distribute_shards_to_ranks`.
Args:
strategy (SaveShardedStrategy): base strategy to wrap
parallelization_group (ProcessGroup, optional): process group to use for save
distribution. Note that this doesn't have to match exactly the
data distribution, but should cover the replication pattern
to maximize performance. Defaults to the whole world.
do_cache_distribution (bool, optional): whether to cache the save distribution
from previous calls. Should be set to True only if the state dict
structure between the calls is always the same. Defaults to True.
"""
def __init__(
self,
strategy: SaveShardedStrategy,
parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
do_cache_distribution: bool = False,
):
super().__init__(strategy.backend, strategy.version)
self.base_strategy = strategy
self.parallelization_group = parallelization_group
self.do_cache_distribution = do_cache_distribution
self.cached_distribution: Optional[SaveLoadDistribution] = None
def async_save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
if not isinstance(self.base_strategy, AsyncSaveShardedStrategy):
raise CheckpointingException(
f'Cannot apply async_save to non-async base strategy {self.base_strategy}'
)
self.apply_saving_parallelization(sharded_state_dict)
return self.base_strategy.async_save(sharded_state_dict, checkpoint_dir)
def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
self.apply_saving_parallelization(sharded_state_dict)
return self.base_strategy.save(sharded_state_dict, checkpoint_dir)
def apply_saving_parallelization(self, sharded_state_dict: ShardedStateDict) -> None:
""" Distributes the save across ranks by exchanging metadata.
Exchanges metadata from the state dict and computes the uniform
(as close as possible) distribution of saves among the ranks.
If `self.do_cache_distribution` is True, caches the distribution between
the calls and subsequent distributions happen without any inter-rank
communication.
Args:
sharded_state_dict (ShardedStateDict): state dict to distribute the saving
Returns: None
"""
if self.do_cache_distribution and self.cached_distribution is not None:
logger.debug(f'Apply *cached* save parallelization')
precomputed_distribution = self.cached_distribution
else:
logger.debug(f'Apply save parallelization')
precomputed_distribution = determine_main_replica_uniform_distribution(
sharded_state_dict, self.parallelization_group
)
distribute_main_replicas_with_precomputed_distribution(
sharded_state_dict, self.parallelization_group, precomputed_distribution
)
if self.cached_distribution is None:
# First time applying the parallelization
validate_sharding_integrity(nested_values(sharded_state_dict))
if self.do_cache_distribution:
self.cached_distribution = precomputed_distribution
@property
def can_handle_sharded_objects(self):
return self.base_strategy.can_handle_sharded_objects
class FullyParallelLoadStrategyWrapper(LoadShardedStrategy):
""" Wraps arbitrary load strategy and distributes the load during `load`.
See `load` method docs for details.
Args:
strategy (LoadShardedStrategy): base strategy to wrap
parallelization_group (ProcessGroup, optional): process group to use for load
distribution. Note that this doesn't have to match exactly the
data distribution, but should cover the replication pattern
to maximize performance. Defaults to the whole world.
In most cases, it's recommended to set it to the DP group.
do_cache_distribution (bool, optional): whether to cache the load distribution
from previous calls. Should be set to True only if the state dict
structure between the calls is always the same. Defaults to False,
since the loading in general happens only once during training.
Note that the load distribution *cannot* be reused as a save distribution,
because save/load is not fully symmetrical.
exchange_algo (str): algorithm to use for exchanging the data.
Options:
- broadcast - each rank broadcasts individual tensors to others
- gather_object - ranks all_gather_object the whole loaded state dicts
- gather_rounds (default) - ranks all_gather individual tensors in rounds
See method docs for more details.
"""
def __init__(
self,
strategy: LoadShardedStrategy,
parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
do_cache_distribution: bool = False,
exchange_algo: str = 'gather_rounds',
):
super().__init__()
self.base_strategy = strategy
self.parallelization_group = parallelization_group
self.do_cache_distribution = do_cache_distribution
self.exchange_algo = exchange_algo
self.cached_distribution: Optional[SaveLoadDistribution] = None
def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> StateDict:
""" Distributes the load and calls underlying strategy only for parts of the state dict.
Steps:
1. Load metadata is exchanged between the ranks in the parallelization group.
2. Each rank deterministically plans the load for the whole workload
so that the loads are as uniform as possible.
3. Each rank loads its planned shard of the checkpoint.
4. All ranks exchange the loaded shards.
Internode communication is involved in steps (1) (with metadata)
and (4) (with actual data). Storage interaction is involved in step (3).
Currently, the load distribution (step 2) is realized with a greedy algorithm
described in `distribute_shards_to_ranks` (same as for saving distribution).
Currently, the shards are all gathered between all ranks in the parallelization
group. This might not be optimal (some ranks do not need all tensors),
but it's a reasonable approximation for an optimal exchange in most scenarios.
Args:
sharded_state_dict (ShardedStateDict): sharded state dict to load
checkpoint_dir (Path): checkpoint directory to load from
Returns:
StateDict: loaded state dict. The state dict should be equivalent to
a state dict that would be loaded with the underlying strategy
without this wrapper.
"""
if torch.distributed.get_world_size(self.parallelization_group) <= 1:
return self.base_strategy.load(sharded_state_dict, checkpoint_dir)
# Step 1 and 2: exchange load metadata and distribute the load
start = time()
precomputed_distribution = self.apply_loading_parallelization(sharded_state_dict)
assert (
precomputed_distribution is not None
), 'Expecting non-trivial distribution for non-trivial parallelization group'
end = time()
logger.debug(f'self.apply_loading_parallelization took {end - start}s')
start = end
# Step 3: load part of the checkpoint.
# Load only sharded objects first. ShardedTensors will be loaded separately
# so that we can keep track of sharded tensors loaded by this rank
(
sharded_tensors,
sharded_state_dict,
to_load_shards,
unloaded_shards,
) = self._defer_loading_sharded_tensors(sharded_state_dict)
loaded_state_dict = self.base_strategy.load(sharded_state_dict, checkpoint_dir)
end = time()
logger.debug(f'Base load of ShardedObjects took {end - start}s')
start = end
# Load sharded tensors separately
loaded_tensors = self.base_strategy.load(to_load_shards, checkpoint_dir)
end = time()
logger.debug(f'Base load of ShardedTensors took {end - start}s')
start = end
# Step 4: exchange data between ranks
logger.debug(f'Applying parallel load with algo {self.exchange_algo}')
if self.exchange_algo == 'gather_object':
exchange_fn = self.exchange_loaded_tensors_gather_object
elif self.exchange_algo == 'gather_rounds':
exchange_fn = self.exchange_loaded_tensors_gather_rounds
elif self.exchange_algo == 'broadcast':
exchange_fn = self.exchange_loaded_tensors_broadcast
else:
raise NotImplementedError(f'Unrecognized gather algorithm: {self.exchange_algo}')
all_loaded_tensors = exchange_fn(
loaded_tensors, unloaded_shards, precomputed_distribution, self.parallelization_group,
)
if not set(unloaded_shards.keys()).issubset(all_loaded_tensors.keys()):
missing_shards = set(unloaded_shards.keys()) - all_loaded_tensors.keys()
raise CheckpointingException(
f'Missing shards after fully parallel loading: {missing_shards}'
)
sync_start = time()
torch.cuda.synchronize()
end = time()
logger.debug(f'torch.cuda.synchronize took {end - sync_start}s')
logger.debug(f'self.exchange_loaded_tensors took {end - start}s')
self.fill_in_deferred_sharded_tensors(sharded_tensors, all_loaded_tensors)
merge(loaded_state_dict, sharded_tensors)
return loaded_state_dict
def _defer_loading_sharded_tensors(
self, sharded_state_dict: ShardedStateDict
) -> Tuple[
ShardedStateDict,
ShardedStateDict,
Dict[_ShardId, ShardedTensor],
Dict[_ShardId, ShardedTensor],
]:
""" Divides state dict into parts loaded by this vs other ranks.
ShardedTensors with main replica_id will be loaded by this rank,
others will be received by other ranks (after loading from storage).
Args:
sharded_state_dict (ShardedStateDict): state dict with ShardedTensor
that will be divided.
Returns: a tuple of:
- ShardedStateDict: sub-state dict only with ShardedTensors
- ShardedStateDict: sub-state dict with non-ShardedTensors
- Dict[_ShardId, ShardedTensor]: ShardedTensors are uniquely identified
by shard ids. This is a mapping from shard id to a corresponding
ShardedTensor for tensors loaded by *this* rank
- Dict[_ShardId, ShardedTensor]: mapping from shard id to a corresponding
ShardedTensor for tensors loaded by *other* ranks
"""
to_load_shards = {}
unloaded_shards = {}
sharded_tensors, sharded_state_dict = extract_matching_values(
sharded_state_dict, lambda v: isinstance(v, ShardedTensor)
)
def wrap_non_main_replicas(x):
if isinstance(x, ShardedTensor):
# Assign shard to be loaded or not
if is_main_replica(x.replica_id):
to_load_shards[_sharded_tensor_shard_id(x)] = x
else:
unloaded_shards[_sharded_tensor_shard_id(x)] = x
return x
dict_list_map_inplace(wrap_non_main_replicas, sharded_tensors)
return sharded_tensors, sharded_state_dict, to_load_shards, unloaded_shards
def apply_loading_parallelization(
self, sharded_state_dict: ShardedStateDict
) -> Optional[SaveLoadDistribution]:
""" Distributes the load across ranks by exchanging metadata.
Exchanges metadata from the state dict and computes the uniform
(as close as possible) distribution of loads among the ranks.
Marks ShardedTensors to be loaded by the current rank with replica_id 0
(and others with non-zero values).
If `self.do_cache_distribution` is True, caches the distribution between
the calls and subsequent distributions happen without any inter-rank
communication.
Args:
sharded_state_dict (ShardedStateDict): state dict to distribute the loading
Returns:
SaveLoadDistribution (optional): the computed loading distribution
"""
if self.do_cache_distribution and self.cached_distribution is not None:
logger.debug(f'Apply *cached* load parallelization')
precomputed_distribution = self.cached_distribution
else:
logger.debug(f'Apply load parallelization')
precomputed_distribution = determine_main_replica_uniform_distribution(
sharded_state_dict, self.parallelization_group, True
)
distribute_main_replicas_with_precomputed_distribution(
sharded_state_dict, self.parallelization_group, precomputed_distribution
)
if self.do_cache_distribution:
self.cached_distribution = precomputed_distribution
return precomputed_distribution
def exchange_loaded_tensors_gather_object(
self,
loaded_tensors: Dict[_ShardId, torch.Tensor],
unloaded_shards: Dict[_ShardId, ShardedTensor],
precomputed_distribution: SaveLoadDistribution,
parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
) -> Dict[_ShardId, torch.Tensor]:
""" Exchange the tensors loaded by different ranks with a simple all_gather_object call.
This version can be used for debugging purposes due to its simplistic
implementation. It shouldn't be used if performance is important.
Args:
loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor
shard ids to tensors already loaded by this rank.
unloaded_shards (Dict[_ShardId, ShardedTensor]): mapping from ShardedTensor
shard ids to ShardedTensors that aren't loaded yet.
precomputed_distribution (SaveLoadDistribution): uniform load distribution
parallelization_group (ProcessGroup, optional): process group used for load
distribution. Tensors will be exchanged within this group
Returns:
Dict[_ShardId, torch.Tensor]: dictionary mapping shard ids to tensors
needed by this rank to load a given state dict. Includes
previously loaded tensors (from `loaded_tensors` input)
"""
all_loaded_tensors_list = [None] * torch.distributed.get_world_size(
group=parallelization_group
)
torch.distributed.all_gather_object(
all_loaded_tensors_list, loaded_tensors, group=parallelization_group
)
all_loaded_tensors_list = cast(List[Dict[_ShardId, torch.Tensor]], all_loaded_tensors_list)
all_loaded_tensors = reduce(lambda x, y: {**x, **y}, all_loaded_tensors_list)
# Error checks
if len(all_loaded_tensors) != sum(map(len, all_loaded_tensors_list)):
err_msg = 'Duplicate shard ids loaded by different ranks'
if torch.distributed.get_rank() == 0:
logger.error(
f'{err_msg}. Shards ids by rank: {[lt.keys() for lt in all_loaded_tensors_list]}'
)
raise CheckpointingException(err_msg)
return all_loaded_tensors
@torch.no_grad()
def exchange_loaded_tensors_gather_rounds(
self,
loaded_tensors: Dict[_ShardId, torch.Tensor],
unloaded_shards: Dict[_ShardId, ShardedTensor],
precomputed_distribution: Optional[SaveLoadDistribution] = None,
parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
) -> Dict[_ShardId, torch.Tensor]:
""" Exchange the tensors loaded by different ranks with several all_gather calls.
Groups tensors by dtype, divides the tensors that will be exchanged into rounds,
and executes all_gather for the tensors from each round.
Note: the loading is distributed across ranks based on total loaded size
in bytes, so there is no guarantee that the number of rounds needed by each
rank will be similar, which might result in a lot of almost empty
all_gathers. The solution would be to group all tensors into a single bytes
tensor and do a single all_gather (with similarly sized messages).
Args:
loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor
shard ids to tensors already loaded by this rank.
unloaded_shards (Dict[_ShardId, ShardedTensor]): mapping from ShardedTensor
shard ids to ShardedTensors that aren't loaded yet.
precomputed_distribution (SaveLoadDistribution): uniform load distribution
parallelization_group (ProcessGroup, optional): process group used for load
distribution. Tensors will be exchanged within this group
Returns:
Dict[_ShardId, torch.Tensor]: dictionary mapping shard ids to tensors
needed by this rank to load a given state dict. Includes
previously loaded tensors (from `loaded_tensors` input)
"""
shard_to_saving_rank, _, shard_to_metadata = precomputed_distribution
local_rank = torch.distributed.get_rank(group=self.parallelization_group)
all_loaded_tensors = dict(loaded_tensors)
# Group by dtype so that we all_gather tensors of the same dtype
for dtype in sorted(
set(map(lambda sh_ten: sh_ten.dtype, shard_to_metadata.values())), key=str
):
start = time()
# shards_by_rank maps each rank to the ids of the shards loaded by that rank
shards_by_rank: List[List[_ShardId]] = [
[] for _ in range(torch.distributed.get_world_size(group=parallelization_group))
]
for shard_id, rank in shard_to_saving_rank.items():
if shard_to_metadata[shard_id].dtype == dtype:
shards_by_rank[rank].append(shard_id)
# Transpose `shards_by_rank` to form exchange rounds
shards_by_round = zip_longest(*shards_by_rank, fillvalue=None)
for round_idx, round_shard_ids in enumerate(shards_by_round):
round_tensors = []
for rank, shard_id in enumerate(round_shard_ids):
if shard_id is None:
# if there is no more useful data, the given rank will exchange an empty tensor
local_ten = torch.empty(0, dtype=dtype, device='cuda')
else:
assert isinstance(shard_id, tuple), type(shard_id)
if rank == local_rank:
assert shard_id in all_loaded_tensors, (
shard_id,
all_loaded_tensors.keys(),
)
all_loaded_tensors[shard_id] = all_loaded_tensors[shard_id].cuda()
local_ten = all_loaded_tensors[shard_id]
else:
local_ten = self._get_empty_tensor_for_exchange(
shard_id, unloaded_shards, shard_to_metadata, all_loaded_tensors
)
round_tensors.append(local_ten)
torch.distributed.all_gather(
list(round_tensors),
round_tensors[local_rank],
group=self.parallelization_group,
async_op=True,
)
del round_tensors # remove tensor references
end = time()
if torch.distributed.get_rank() == 0:
logger.debug(f'{dtype} exchange rounds all_gather schedule took {end - start}s')
return all_loaded_tensors
@torch.no_grad()
def exchange_loaded_tensors_broadcast(
self,
loaded_tensors: Dict[_ShardId, torch.Tensor],
unloaded_shards: Dict[_ShardId, ShardedTensor],
precomputed_distribution: Optional[SaveLoadDistribution] = None,
parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
) -> Dict[_ShardId, torch.Tensor]:
""" Exchange the tensors loaded by different ranks by a series of broadcasts.
For each rank for each loaded tensor do a broadcast to the whole group.
A reasonable tradeoff in terms of performance and simplicity.
Args:
loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor
shard ids to tensors already loaded by this rank.
unloaded_shards (Dict[_ShardId, ShardedTensor]): mapping from ShardedTensor
shard ids to ShardedTensors that aren't loaded yet.
precomputed_distribution (SaveLoadDistribution): uniform load distribution
parallelization_group (ProcessGroup, optional): process group used for load
distribution. Tensors will be exchanged within this group
Returns:
Dict[_ShardId, torch.Tensor]: dictionary mapping shard ids to tensors
needed by this rank to load a given state dict. Includes
previously loaded tensors (from `loaded_tensors` input)
"""
shard_to_saving_rank, _, shard_to_metadata = precomputed_distribution
local_rank = torch.distributed.get_rank(group=self.parallelization_group)
all_loaded_tensors = dict(loaded_tensors)
start = time()
for shard_id, rank in shard_to_saving_rank.items():
if rank == local_rank:
assert shard_id in all_loaded_tensors, (shard_id, all_loaded_tensors.keys())
all_loaded_tensors[shard_id] = all_loaded_tensors[shard_id].cuda()
local_ten = all_loaded_tensors[shard_id]
else:
local_ten = self._get_empty_tensor_for_exchange(
shard_id, unloaded_shards, shard_to_metadata, all_loaded_tensors
)
global_src_rank = torch.distributed.get_global_rank(parallelization_group, rank)
torch.distributed.broadcast(
local_ten, src=global_src_rank, group=parallelization_group, async_op=True
)
end = time()
if torch.distributed.get_rank() == 0:
logger.debug(f'exchange broadcast schedule took {end - start}s')
return all_loaded_tensors
def _get_empty_tensor_for_exchange(
self,
shard_id: _ShardId,
needed_shards: Dict[_ShardId, ShardedTensor],
unneeded_shards: Dict[_ShardId, ShardedTensor],
loaded_tensors: Dict[_ShardId, torch.Tensor],
) -> torch.Tensor:
""" Determines the empty tensor to use for exchange.
If shard_id is needed by this rank, it will be in `needed_shards`.
Otherwise, the metadata for this tensor can be found in `unneeded_shards`.
Args:
shard_id (_ShardId): shard_id that will be exchanged
needed_shards (Dict[_ShardId, ShardedTensor]): mapping from shard ids
to metadata for shards needed by this rank
unneeded_shards (Dict[_ShardId, ShardedTensor]): mapping from shard ids
to metadata for shards that can be discarded after exchange
loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping where useful tensors
are placed in
Returns:
torch.Tensor: empty tensor to be exchanged
"""
local_unloaded_sh_ten = needed_shards.get(shard_id)
if local_unloaded_sh_ten is None:
sh_ten = unneeded_shards[shard_id]
sh_ten.init_data('cuda')
tensor = sh_ten.data
sh_ten.data = None # won't be used. free memory
else:
local_unloaded_sh_ten.init_data('cuda')
tensor = local_unloaded_sh_ten.data
loaded_tensors[shard_id] = tensor
return tensor
def fill_in_deferred_sharded_tensors(
self, sharded_state_dict: ShardedStateDict, loaded_tensors: Dict[_ShardId, torch.Tensor]
) -> None:
""" Fill in tensors not loaded by current rank with tensors from `loaded_tensors` map.
Args:
sharded_state_dict (ShardedStateDict): sharded state dict to fill in.
ShardedTensors are completely replaced with corresponding torch.Tensors.
loaded_tensors (Dict[_ShardId, torch.Tensor]): dict allowing to map
ShardedTensor from the sharded_state_dict to loaded tensors.
Returns: None
"""
def fill_in_sharded_tensor(x):
if isinstance(x, ShardedTensor):
try:
x = loaded_tensors[_sharded_tensor_shard_id(x)]
except KeyError as e:
raise CheckpointingException(
f'Missing loaded tensor shard: {_sharded_tensor_shard_id(x)}'
) from e
return x
dict_list_map_inplace(fill_in_sharded_tensor, sharded_state_dict)
@property
def can_handle_sharded_objects(self):
return self.base_strategy.can_handle_sharded_objects
def load_tensors_metadata(self, checkpoint_dir: Path):
return self.base_strategy.load_tensors_metadata(checkpoint_dir)
def check_backend_compatibility(self, loaded_version):
self.base_strategy.check_backend_compatibility(loaded_version)
def check_version_compatibility(self, loaded_version):
self.base_strategy.check_version_compatibility(loaded_version)
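# Usage sketch (illustrative; `base_save_strategy` / `base_load_strategy` stand for
# concrete strategies, e.g. obtained via `get_default_strategy`, and
# `data_parallel_group` for the replication process group):
#
#     save_strategy = FullyParallelSaveStrategyWrapper(
#         base_save_strategy, parallelization_group=data_parallel_group,
#         do_cache_distribution=True,
#     )
#     save_strategy.save(sharded_state_dict, checkpoint_dir)
#
#     load_strategy = FullyParallelLoadStrategyWrapper(
#         base_load_strategy, parallelization_group=data_parallel_group,
#         exchange_algo='gather_rounds',
#     )
#     state_dict = load_strategy.load(sharded_state_dict, checkpoint_dir)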
def _sharded_tensor_shard_id(sharded_tensor: ShardedTensor) -> _ShardId:
""" Unique id of the sharded tensor data.
Should yield the same value for the same data replicated on different ranks.
Args:
sharded_tensor (ShardedTensor): sharded tensor representing the data shard
Returns (tuple): unique id of a data shard
"""
f_range = sharded_tensor.flattened_range
return (
sharded_tensor.key,
sharded_tensor.global_offset,
None if f_range is None else (f_range.start, f_range.stop),
)
def _shard_size(sh_ten: ShardedTensor):
""" Returns size in bytes of a given sharded tensor. """
if sh_ten.flattened_range is None:
numel = np.prod(sh_ten.local_shape)
else:
numel = sh_ten.flattened_range.stop - sh_ten.flattened_range.start
return numel * torch._utils._element_size(sh_ten.dtype)
def determine_main_replica_uniform_distribution(
sharded_state_dict: ShardedStateDict,
parallelization_group: torch.distributed.ProcessGroup,
is_loading: bool = False,
) -> Optional[SaveLoadDistribution]:
""" Computes the save distribution.
Should be used in conjunction with `distribute_main_replicas_with_precomputed_distribution`
which applies the computed save distribution.
We rely on the fact that the assignment algorithm is deterministic on all ranks,
so there is no extra communication needed after metadata exchange.
Args:
sharded_state_dict (ShardedStateDict): state dict to compute the distribution of
parallelization_group (ProcessGroup): distribution will be computed
within this process group
is_loading (bool, optional): whether the distribution is for loading or saving.
For loading, even non-main replicas must be loaded by this parallelization
group. Defaults to False.
Returns (SaveLoadDistribution, optional): distribution that can be used to apply the
parallelization. Returns None if the process_group is trivial (1 rank)
"""
group_size = torch.distributed.get_world_size(group=parallelization_group)
if group_size <= 1:
return
local_shards = list(
sh_base
for sh_base in nested_values(sharded_state_dict)
if isinstance(sh_base, ShardedTensor)
)
local_shards_no_data = [ten.without_data() for ten in local_shards]
all_shards = [None] * torch.distributed.get_world_size(group=parallelization_group)
torch.distributed.all_gather_object(
all_shards, local_shards_no_data, group=parallelization_group
)
shard_to_ranks = defaultdict(list)
shard_to_size = {}
shard_to_metadata = {}
shards_saved_by_this_parallelization_group: Set[_ShardId] = set()
for rank, rank_shards in enumerate(all_shards):
for sh_ten in rank_shards:
shard_id = _sharded_tensor_shard_id(sh_ten)
shard_to_ranks[shard_id].append(rank)
if shard_id not in shard_to_size:
shard_to_size[shard_id] = _shard_size(sh_ten)
shard_to_metadata[shard_id] = sh_ten
if is_main_replica(sh_ten.replica_id) or is_loading:
shards_saved_by_this_parallelization_group.add(shard_id)
shard_to_ranks = {
k: v for k, v in shard_to_ranks.items() if k in shards_saved_by_this_parallelization_group
}
shard_to_saving_rank = distribute_shards_to_ranks(
shard_to_ranks, shard_to_size, len(all_shards)
)
return SaveLoadDistribution(
shard_to_saving_rank, shards_saved_by_this_parallelization_group, shard_to_metadata
)
def distribute_main_replicas_with_precomputed_distribution(
sharded_state_dict: ShardedStateDict,
parallelization_group: torch.distributed.ProcessGroup,
precomputed_distribution: Optional[SaveLoadDistribution],
):
""" Applies the save distribution computed with `determine_main_replica_uniform_distribution`.
Based on rank assignment, sets replica ids of the shards saved by current rank to 0
and all the other replica ids to 1.
Args:
sharded_state_dict (ShardedStateDict): state dict to apply the save distribution to
parallelization_group (ProcessGroup): distribution will be applied within this
process group. Must match with the process group passed to
`determine_main_replica_uniform_distribution`.
precomputed_distribution (SaveLoadDistribution): distribution computed with
`determine_main_replica_uniform_distribution`
Returns: None
Example replica ids of tensors A, B, C before distribution:
rank0: A: (0, 0, 0), B: (0, 0, 0), C: (0, 0, 0)
rank1: A: (0, 0, 1), B: (0, 0, 1), C: (0, 0, 1)
rank2: A: (0, 0, 2), B: (0, 0, 2), C: (0, 0, 2)
Replicas after distribution for the example above:
rank0: A: 0, B: 1, C: 1
rank1: A: 1, B: 0, C: 1
rank2: A: 1, B: 1, C: 0
"""
if torch.distributed.get_world_size(group=parallelization_group) <= 1:
return
if precomputed_distribution is None:
raise ValueError(
'precomputed_distribution must not be None for a non-trivial parallelization group'
)
local_shards = list(
sh_base
for sh_base in nested_values(sharded_state_dict)
if isinstance(sh_base, ShardedTensor)
)
rank_within_dp_group = torch.distributed.get_rank(parallelization_group)
for sh_ten in local_shards:
shard_id = _sharded_tensor_shard_id(sh_ten)
if (
shard_id in precomputed_distribution.shards_in_this_group
and rank_within_dp_group == precomputed_distribution.main_rank_for_shard[shard_id]
):
sh_ten.replica_id = 0
else:
sh_ten.replica_id = 1
T = TypeVar('T')
def distribute_shards_to_ranks(
shard_to_ranks: Dict[T, List[int]], shard_to_size: Dict[T, int], num_ranks: int
) -> Dict[T, int]:
""" Computes uniform distribution of workload across ranks, based on sizes.
Currently, the assignment is greedy, based on:
1. Firstly, the coverage of each shard
(how many ranks the shard is available on; lower coverage is assigned first)
2. Secondly, the size of each shard (larger size is assigned first)
3. Finally, shard id for differentiation.
The third step is added because we rely on the assignment being deterministic on all ranks.
Args:
shard_to_ranks (Dict[T, List[int]]): mapping which tells which ranks have access to which shards
shard_to_size (Dict[T, int]): sizes of each shard
num_ranks (int): number of ranks in the parallelization group
Returns (Dict[T, int]): assignment of shard to rank (which rank should do the work
to achieve maximal uniformity)
"""
shard_to_ranks = {k: tuple(v) for k, v in shard_to_ranks.items()}
shard_to_saving_rank = {}
rank_sizes = [(0, rank) for rank in range(num_ranks)]
# start from tensors with lowest coverage, then go by tensor size from largest (hence minus size)
for shard_id, shard_ranks in sorted(
shard_to_ranks.items(),
key=lambda sh_id_ranks: (
len(sh_id_ranks[1]),
-shard_to_size[sh_id_ranks[0]],
sh_id_ranks[0],
),
):
# assign greedily to the least occupied rank
size, rank = min((size, rank) for size, rank in rank_sizes if rank in shard_ranks)
shard_to_saving_rank[shard_id] = rank
rank_sizes[rank] = (size + shard_to_size[shard_id], rank)
logger.debug(f'distribute_shards_to_ranks distribution: {rank_sizes}')
return shard_to_saving_rank
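# Worked example (illustrative values):
#
#     shard_to_ranks = {'a': [0, 1], 'b': [0, 1], 'c': [1]}
#     shard_to_size = {'a': 4, 'b': 2, 'c': 3}
#     distribute_shards_to_ranks(shard_to_ranks, shard_to_size, num_ranks=2)
#
# Assignment order is 'c' (lowest coverage), then 'a' (largest remaining), then 'b':
# 'c' -> rank 1 (its only option), 'a' -> rank 0 (least loaded), 'b' -> rank 1
# (load 3 vs 4), giving {'a': 0, 'b': 1, 'c': 1} with per-rank loads of [4, 5] bytes.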
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
""" State dict saver for PyT Distributed format allowing asynchronous save. """
from logging import getLogger
from time import time
from typing import TYPE_CHECKING, Optional, Tuple, cast
import torch
import torch.distributed as dist
from torch.distributed.checkpoint import CheckpointException
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE, Metadata
from torch.distributed.checkpoint.planner import SavePlanner
from torch.distributed.checkpoint.utils import _DistWrapper, _get_failure_dict
if TYPE_CHECKING:
from .filesystem_async import FileSystemWriterAsync
logger = getLogger(__name__)
def save_state_dict_async_plan(
state_dict: STATE_DICT_TYPE,
storage_writer: 'FileSystemWriterAsync',
process_group: Optional[dist.ProcessGroup] = None,
coordinator_rank: int = 0,
planner: Optional[SavePlanner] = None,
) -> Tuple['FileSystemWriterAsync', Metadata, _DistWrapper]:
"""
First stage of saving a state dict to storage.
This is an async adjustment of torch.distributed.checkpoint.state_dict_saver.
In order to support async save, saving should be split into three parts:
1. Planning
2. Actual saving
3. Finalization
Out of these, step (2) *must* happen asynchronously.
The first step is realized with this function.
The planning part consists of several steps, described here:
https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.SavePlanner
Args:
state_dict (STATE_DICT_TYPE): state dict to save
storage_writer (FileSystemWriterAsync): in the current version, only an instance of
FileSystemWriterAsync is supported
process_group (dist.ProcessGroup, optional): process group used for save planning
coordinator_rank (int, optional): coordinator rank for planning. Defaults to 0.
planner (SavePlanner, optional): save planner for torch.distributed.checkpoint format
Returns: Tuple of:
- storage writer (the one passed as input)
- metadata from planning
- distributed wrapper used for planning
The return value of this function should be passed as an input to
`save_state_dict_async_finalize`.
"""
rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
dist_wrapper = _DistWrapper(process_group, True, coordinator_rank)
if planner is None:
planner = DefaultSavePlanner()
assert planner is not None
global_metadata = None
def local_step():
assert planner is not None
planner.set_up_planner(state_dict, dist_wrapper.is_coordinator)
storage_writer.set_up_storage_writer(dist_wrapper.is_coordinator)
local_plan = planner.create_local_plan()
local_plan = storage_writer.prepare_local_plan(local_plan)
return local_plan
def global_step(all_local_plans):
nonlocal global_metadata
assert planner is not None
all_local_plans, global_metadata = planner.create_global_plan(all_local_plans)
all_local_plans = storage_writer.prepare_global_plan(all_local_plans)
return all_local_plans
# Execute local and global planning
start_plan = time()
central_plan = dist_wrapper.reduce_scatter("plan", local_step, global_step)
logger.debug(f"rank: {rank}, plan time: {time() - start_plan}")
# Prepare async writing of tensors.
# The `storage_writer` will store the information about tensors it needs to save
start = time()
final_local_plan = planner.finish_plan(central_plan)
storage_writer.prepare_write_data(final_local_plan, planner)
end = time()
logger.debug(f"{time()} rank: {rank}, write(async) time: {end - start}")
return storage_writer, cast(Metadata, global_metadata), dist_wrapper
def save_state_dict_async_finalize(
storage_writer: 'FileSystemWriterAsync', global_metadata: Metadata, dist_wrapper: _DistWrapper,
) -> None:
"""
Finalization of save_state_dict_async_plan.
The input arguments are the same as the save_state_dict_async_plan output,
the `write_results` are retrieved from the storage_writer.
Args:
storage_writer (FileSystemWriterAsync): storage writer used for planning
global_metadata (Metadata): metadata created during planning
dist_wrapper (_DistWrapper): distributed wrapper created during planning
Returns: None
"""
write_results = storage_writer.retrieve_write_results()
# Gather the write results that will be saved to the metadata file.
gather_start = time()
all_results = dist_wrapper.gather_object(write_results)
gather_end = time()
logger.debug(f"{gather_end}, {torch.distributed.get_rank()}, gather: {gather_end-gather_start}")
# Store the metadata on coordinator rank
if dist_wrapper.is_coordinator:
node_failures = _get_failure_dict(all_results)
if len(node_failures) == 0:
assert global_metadata is not None
write_start = time()
storage_writer.finish(global_metadata, all_results)
write_end = time()
logger.debug(f"{write_end}, metadata_write: {write_end - write_start}")
else:
raise CheckpointException("write", node_failures)
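# End-to-end sketch of the three stages (illustrative; `state_dict` and `checkpoint_dir`
# are assumptions, and stage (2) is run inline here for brevity instead of being
# scheduled through `AsyncCallsQueue`):
#
#     writer = FileSystemWriterAsync(checkpoint_dir)
#     writer, metadata, dist_wrapper = save_state_dict_async_plan(state_dict, writer)   # (1) planning
#     async_fn, async_args = writer.get_save_function_and_args()
#     if async_fn is not None:
#         async_fn(*async_args)                                                         # (2) actual saving
#     save_state_dict_async_finalize(writer, metadata, dist_wrapper)                    # (3) finalization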