Commit bc5c7fa7 authored by wxj

First test commit

parent 70fddd0f
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
from .core import check_is_distributed_checkpoint
from .mapping import LocalNonpersitentObject, ShardedTensor
from .serialization import (
load,
load_common_state_dict,
load_plain_tensors,
load_tensors_metadata,
save,
)
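# Illustrative usage sketch of the package entrypoints (hypothetical names:
# `local_weight`, `rank`, `world_size`, `ckpt_dir`). Assumes torch.distributed
# is initialized and `ckpt_dir` is an existing, empty directory visible to all ranks:
#
#     sharded_sd = {
#         'weight': ShardedTensor.from_rank_offsets(
#             'weight', local_weight, (0, rank, world_size)
#         ),
#     }
#     save(sharded_sd, ckpt_dir)
#     loaded = load(sharded_sd, ckpt_dir)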
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Module for managing distributed checkpoints metadata. """
import json
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Optional
CONFIG_FNAME = 'metadata.json'
class CheckpointingException(Exception):
""" Base checkpointing related exception """
pass
@dataclass
class CheckpointingConfig:
""" Documents backends used in the checkpoint.
Checkpoint config keeps track of formats used for storing the sharded tensors
(sharded_backend) and other objects (common_backend).
Note that versioning is not for the checkpoint content (which is application specific),
but for the checkpoint format itself.
"""
sharded_backend: str
sharded_backend_version: int = 1
common_backend: str = 'torch'
common_backend_version: int = 1
def check_is_distributed_checkpoint(checkpoint_dir):
""" Checks if `metadata.json` exists in the checkpoint and is a valid config.
Args:
checkpoint_dir: checkpoint directory
Returns:
bool: True if `metadata.json` exists in the checkpoint and is a valid config.
"""
return maybe_load_config(checkpoint_dir) is not None
def maybe_load_config(checkpoint_dir: str) -> Optional[CheckpointingConfig]:
""" Returns checkpoint config if `checkpoint_dir` is a distributed checkpoint and None otherwise
Args:
checkpoint_dir: checkpoint directory
Returns:
CheckpointingConfig (optional): None if checkpoint is not a valid distributed checkpoint
"""
config_path = Path(checkpoint_dir, CONFIG_FNAME)
if not config_path.exists():
return None
with config_path.open() as f:
config_dict = json.load(f)
return CheckpointingConfig(**config_dict)
def save_config(config: CheckpointingConfig, checkpoint_dir: str):
""" Save given config to checkpoint directory.
Args:
config: checkpoint config
checkpoint_dir: checkpoint directory
Returns:
None
"""
config_path = Path(checkpoint_dir, CONFIG_FNAME)
with config_path.open('w') as f:
json.dump(asdict(config), f)
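# Illustrative sketch (not used by the module itself): round-trips a
# CheckpointingConfig through the helpers above. `checkpoint_dir` is assumed
# to be an existing directory.
def _example_config_roundtrip(checkpoint_dir: str) -> CheckpointingConfig:
    """Save a config and read it back to confirm the directory is recognized."""
    config = CheckpointingConfig(sharded_backend='zarr')
    save_config(config, checkpoint_dir)
    # `metadata.json` now exists, so the directory counts as a distributed checkpoint
    assert check_is_distributed_checkpoint(checkpoint_dir)
    loaded = maybe_load_config(checkpoint_dir)
    assert loaded == config
    return loaded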
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Utilities for operating with dicts and lists.
All functions in this module handle nesting of dicts and lists.
Other objects (e.g. tuples) are treated as atomic leaf types that cannot be traversed.
"""
from collections import defaultdict
from typing import Any, Callable, Iterable, Optional, Tuple, Union
import torch
def extract_matching_values(
x: Union[dict, list], predicate: Callable[[Any], bool], return_lists_as_dicts: bool = False
) -> Tuple[Union[dict, list], Union[dict, list]]:
""" Return matching and nonmatching values. Keeps hierarchy.
Args:
x (Union[dict, list]) : state dict to process. Top-level argument must be a dict or list
predicate (object -> bool): determines matching values
return_lists_as_dicts (bool): if True, matching lists will be turned
into dicts, with keys indicating the indices of original elements.
Useful for reconstructing the original hierarchy.
"""
def _set_elem(target, k, v):
if return_lists_as_dicts:
target[k] = v
else:
target.append(v)
if isinstance(x, dict):
matching_vals = {}
nonmatching_vals = {}
for k, v in x.items():
if isinstance(v, (list, dict)):
match, nonmatch = extract_matching_values(v, predicate, return_lists_as_dicts)
if match:
matching_vals[k] = match
if nonmatch or not v:
nonmatching_vals[k] = nonmatch
elif predicate(v):
matching_vals[k] = v
else:
nonmatching_vals[k] = v
elif isinstance(x, list):
matching_vals = {} if return_lists_as_dicts else []
nonmatching_vals = {} if return_lists_as_dicts else []
for ind, v in enumerate(x):
if isinstance(v, (list, dict)) and v:
match, nonmatch = extract_matching_values(v, predicate, return_lists_as_dicts)
if match:
_set_elem(matching_vals, ind, match)
if nonmatch or not v:
_set_elem(nonmatching_vals, ind, nonmatch)
else:
target = matching_vals if predicate(v) else nonmatching_vals
_set_elem(target, ind, v)
else:
raise ValueError(f'Unexpected top-level object type: {type(x)}')
return matching_vals, nonmatching_vals
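# Illustrative sketch of `extract_matching_values` on a small nested dict
# (values chosen arbitrarily for demonstration).
def _example_extract_matching_values():
    state = {'a': 1, 'b': {'c': 2, 'd': 'x'}, 'e': [3, 'y']}
    ints, rest = extract_matching_values(state, lambda v: isinstance(v, int))
    # ints == {'a': 1, 'b': {'c': 2}, 'e': [3]}
    # rest == {'b': {'d': 'x'}, 'e': ['y']}
    return ints, rest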
def diff(x1: Any, x2: Any, prefix: Tuple = ()) -> Tuple[list, list, list]:
""" Recursive diff of dicts.
Args:
x1 (object): left dict
x2 (object): right dict
prefix (tuple): tracks recursive calls. Used for reporting differing keys.
Returns:
Tuple[list, list, list]: tuple of:
- only_left: Prefixes present only in left dict
- only_right: Prefixes present only in right dict
- mismatch: values present in both dicts but not equal across dicts.
For tensors, equality of all elements is checked.
Each element is a tuple (prefix, type of left value, type of right value).
"""
mismatch = []
if isinstance(x1, dict) and isinstance(x2, dict):
only_left = [prefix + (k,) for k in x1.keys() - x2.keys()]
only_right = [prefix + (k,) for k in x2.keys() - x1.keys()]
for k in x2.keys() & x1.keys():
_left, _right, _mismatch = diff(x1[k], x2[k], prefix + (k,))
only_left.extend(_left)
only_right.extend(_right)
mismatch.extend(_mismatch)
elif isinstance(x1, list) and isinstance(x2, list):
only_left = list(range(len(x1) - 1, len(x2) - 1, -1))
only_right = list(range(len(x2) - 1, len(x1) - 1, -1))
for i, (v1, v2) in enumerate(zip(x1, x2)):
_left, _right, _mismatch = diff(v1, v2, prefix + (i,))
only_left.extend(_left)
only_right.extend(_right)
mismatch.extend(_mismatch)
else:
only_left = []
only_right = []
if isinstance(x1, torch.Tensor) and isinstance(x2, torch.Tensor):
_is_mismatch = not torch.all(x1 == x2)
else:
try:
_is_mismatch = bool(x1 != x2)
except RuntimeError:
_is_mismatch = True
if _is_mismatch:
mismatch.append((prefix, type(x1), type(x2)))
return only_left, only_right, mismatch
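# Illustrative sketch of `diff` on two small dicts (values arbitrary).
def _example_diff():
    left = {'a': 1, 'b': {'c': 2}}
    right = {'a': 1, 'b': {'c': 3}, 'd': 4}
    only_left, only_right, mismatch = diff(left, right)
    # only_left == [], only_right == [('d',)]
    # mismatch == [(('b', 'c'), int, int)]
    return only_left, only_right, mismatch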
def inspect_types(x: Any, prefix: Tuple = (), indent: int = 4):
""" Helper to print types of (nested) dict values. """
print_indent = lambda: print(' ' * indent * len(prefix), end='')
if isinstance(x, dict):
print()
for k, v in x.items():
print_indent()
print(f'> {k}: ', end='')
inspect_types(v, prefix + (k,), indent)
elif isinstance(x, list):
print()
for i, v in enumerate(x):
print_indent()
print(f'- {i}: ', end='')
inspect_types(v, prefix + (i,), indent)
else:
if isinstance(x, torch.Tensor):
print(f'Tensor of shape {x.shape}')
else:
try:
x_str = str(x)
except:
x_str = '<no string repr>'
if len(x_str) > 30:
x_str = x_str[:30] + '... (truncated)'
print(f'[{type(x)}]: {x_str}')
def nested_values(x: Union[dict, list]):
""" Returns iterator over (nested) values of a given dict or list. """
x_iter = x.values() if isinstance(x, dict) else x
for v in x_iter:
if isinstance(v, (dict, list)):
yield from nested_values(v)
else:
yield v
def nested_items_iter(x: Union[dict, list]):
""" Returns iterator over (nested) tuples (container, key, value) of a given dict or list. """
x_iter = x.items() if isinstance(x, dict) else enumerate(x)
for k, v in x_iter:
if isinstance(v, (dict, list)):
yield from nested_items_iter(v)
else:
yield x, k, v
def dict_map(f: Callable, d: dict):
""" `map` equivalent for dicts. """
for sub_d, k, v in nested_items_iter(d):
sub_d[k] = f(v)
def dict_map_with_key(f: Callable, d: dict):
""" `map` equivalent for dicts with a function that accepts tuple (key, value). """
for sub_d, k, v in nested_items_iter(d):
sub_d[k] = f(k, v)
def dict_list_map_inplace(f: Callable, x: Union[dict, list]):
""" Maps dicts and lists *in-place* with a given function. """
if isinstance(x, dict):
for k, v in x.items():
x[k] = dict_list_map_inplace(f, v)
elif isinstance(x, list):
x[:] = (dict_list_map_inplace(f, v) for v in x)
else:
return f(x)
return x
def dict_list_map_outplace(f: Callable, x: Union[dict, list]):
""" Maps dicts and lists *out-of-place* with a given function. """
if isinstance(x, dict):
return {k: dict_list_map_outplace(f, v) for k, v in x.items()}
elif isinstance(x, list):
return [dict_list_map_outplace(f, v) for v in x]
else:
return f(x)
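# Illustrative sketch: `dict_list_map_outplace` applies a function to every
# leaf while preserving the dict/list structure (values arbitrary).
def _example_dict_list_map_outplace():
    doubled = dict_list_map_outplace(lambda v: v * 2, {'a': 1, 'b': [2, 3]})
    # doubled == {'a': 2, 'b': [4, 6]}
    return doubled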
def merge(x1: dict, x2: dict, key: Tuple[str, ...] = ()):
""" Merges dicts and lists recursively. """
if isinstance(x1, dict) and isinstance(x2, dict):
for k, v2 in x2.items():
if k not in x1:
x1[k] = v2
else:
x1[k] = merge(x1[k], v2, key=key + (k,))
elif isinstance(x1, list) and isinstance(x2, list):
if len(x1) != len(x2):
raise ValueError(
f'Cannot merge two lists with different lengths ({len(x1)} and {len(x2)}, encountered at level {key})'
)
for i, v2 in enumerate(x2):
x1[i] = merge(x1[i], v2, key=key + (i,))
else:
raise ValueError(
f'Duplicate non-dict and non-list values encountered: `{x1}` and `{x2}` (at level {key})'
)
return x1
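# Illustrative sketch: `merge` combines two state dicts with disjoint leaves;
# overlapping non-dict/non-list leaves would raise ValueError (values arbitrary).
def _example_merge():
    base = {'common': {'step': 1}, 'layers': [{'w': 1}, {}]}
    extra = {'extra': 2, 'layers': [{}, {'b': 3}]}
    merged = merge(base, extra)
    # merged == {'common': {'step': 1}, 'layers': [{'w': 1}, {'b': 3}], 'extra': 2}
    return merged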
def map_reduce(
xs: Iterable,
key_fn: Callable = lambda x: x,
value_fn: Callable = lambda x: x,
reduce_fn: Callable = lambda x: x,
) -> dict:
""" Simple map-reduce implementation following `more_itertools.map_reduce` interface. """
res = defaultdict(list)
for x in xs:
res[key_fn(x)].append(value_fn(x))
for k in res:
res[k] = reduce_fn(res[k])
return dict(res)
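# Illustrative sketch of `map_reduce`: group integers by parity and sum each group.
def _example_map_reduce():
    grouped = map_reduce(range(6), key_fn=lambda x: x % 2, reduce_fn=sum)
    # grouped == {0: 0 + 2 + 4, 1: 1 + 3 + 5} == {0: 6, 1: 9}
    return grouped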
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Core library classes for representing sharding of tensors and objects.
The main expected usage is wrapping torch.Tensors in state dicts with
ShardedTensor class (mostly with the ShardedTensor.from_rank_offsets classmethod).
"""
import logging
from abc import ABC
from dataclasses import dataclass, replace
from itertools import chain
from typing import Any, Callable, Dict, Optional, Tuple, Union
import numpy as np
import torch
from .core import CheckpointingException
from .dict_utils import dict_list_map_inplace, dict_list_map_outplace
logger = logging.getLogger(__name__)
# These type definitions are just hints to differentiate a plain model state
# dict (StateDict) from a state dict with tensors replaced with ShardedTensors
# (ShardedStateDict).
StateDict = Dict[str, Any]
ShardedStateDict = Dict[str, Any]
ReplicaId = Union[int, Tuple[int, ...]]
class ShardedBase(ABC):
key: str
data: object
replica_id: ReplicaId
@dataclass
class ShardedTensor(ShardedBase):
"""Represents a mapping between a local tensor and a global tensor.
Global tensor is assumed to consist of many local tensors distributed
between different processes.
Args:
key: unique identifier of a global tensor
data: local tensor data. Can be None only for consistency validation
dtype: tensor dtype
local_shape: local tensor shape
global_shape: global tensor shape
global_offset: offset of a local tensor in a global tensor, specified in number of tensor elements
axis_fragmentations: global tensor fragmentation of each axis
replica_id: indicates given local tensor's replication wrt. local tensors in different processes
prepend_axis_num: number of axes prepended to the local tensor to reflect global tensor shape. The behavior is similar to unsqueezing the local tensor.
allow_shape_mismatch: if True, during loading, the global shape of a stored tensor does not have to match the expected global shape. Useful for representing tensors with flexible shape, e.g. padded.
flattened_range: specifies a slice that should be applied to a flattened tensor with `local_shape` in order to get the tensor stored as `data`
"""
key: str
data: Optional[torch.Tensor]
dtype: torch.dtype
local_shape: Tuple[int, ...]
global_shape: Tuple[int, ...]
global_offset: Tuple[int, ...]
axis_fragmentations: Optional[Tuple[int, ...]]
replica_id: ReplicaId = 0
prepend_axis_num: int = 0
allow_shape_mismatch: bool = False
flattened_range: Optional[slice] = None
def global_slice(self) -> Tuple[Union[int, slice], ...]:
assert len(self.global_offset) == len(self.local_shape) + self.prepend_axis_num
return tuple(
chain(
(off for off in self.global_offset[: self.prepend_axis_num]),
(
slice(off, off + sh)
for off, sh in zip(
self.global_offset[self.prepend_axis_num :], self.local_shape
)
),
)
)
def global_coordinates(self) -> Tuple[np.ndarray, ...]:
if self.flattened_range is None:
raise CheckpointingException(
f'`global_coordinates` is undefined for'
f' {self.__class__.__name__} without `flattened_range`'
)
local_coords = self.local_coordinates()
assert len(local_coords) + self.prepend_axis_num == len(self.global_offset), (
len(local_coords),
self,
)
global_coords = tuple(
c + off
for c, off in zip((0,) * self.prepend_axis_num + local_coords, self.global_offset)
)
return global_coords
def local_coordinates(self) -> Tuple[np.ndarray, ...]:
if self.flattened_range is None:
raise CheckpointingException(
f'`local_coordinates` is undefined for'
f' {self.__class__.__name__} without `flattened_range`'
)
# TODO: np.unravel_index?
mask = np.zeros(np.product(self.local_shape), dtype=bool)
mask[self.flattened_range] = True
return np.nonzero(mask.reshape(self.local_shape))
def max_allowed_chunks(self) -> Tuple[int, ...]:
chunks = []
for axis_sh, axis_fragm in zip(self.global_shape, self.axis_fragmentations):
if not self.allow_shape_mismatch and axis_sh % axis_fragm != 0:
raise CheckpointingException(
f'Axis shape ({axis_sh}) not divisible by axis fragmentation ({axis_fragm})'
)
axis_chunk_size = axis_sh // axis_fragm
chunks.append(axis_chunk_size)
return tuple(chunks)
def without_data(self):
return replace(self, data=None)
@classmethod
def from_rank_offsets(
cls,
key: str,
data: torch.Tensor,
*rank_offsets: Tuple[int, int, int],
replica_id: ReplicaId = 0,
prepend_axis_num: int = 0,
**init_kwargs,
):
"""Allows constructing the ShardedTensor given offsets specified in process ranks.
Args:
key: unique key
data: local tensor data
rank_offsets: each tuple (axis, axis_rank_offset, axis_fragm) says that if global tensor is divided into `axis_fragm` fragment along `axis` axis, then local tensor data corresponds to the `axis_rank_offset` chunk.
replica_id: see ShardedTensor
prepend_axis_num: see ShardedTensor
init_kwargs: passed to ShardedTensor.__init__
"""
global_offset = [0] * (data.ndim + prepend_axis_num)
global_shape = ([1] * prepend_axis_num) + list(data.shape)
axis_fragmentations = [1] * (data.ndim + prepend_axis_num)
_seen_axis = set()
for axis, axis_rank_offset, axis_fragm in rank_offsets:
assert axis >= 0 and axis_rank_offset >= 0 and axis_fragm >= 0, (
axis,
axis_rank_offset,
axis_fragm,
)
assert (
axis_rank_offset < axis_fragm
), 'Rank offset must be lower than axis fragmentation'
if axis in _seen_axis:
raise CheckpointingException('Duplicated axis specified')
_seen_axis.add(axis)
local_axis_shape = 1 if axis < prepend_axis_num else data.shape[axis - prepend_axis_num]
global_shape[axis] = axis_fragm * local_axis_shape
global_offset[axis] = axis_rank_offset * local_axis_shape
axis_fragmentations[axis] = axis_fragm
return cls(
key,
data,
data.dtype,
tuple(data.shape),
tuple(global_shape),
tuple(global_offset),
tuple(axis_fragmentations),
replica_id,
prepend_axis_num,
**init_kwargs,
)
def init_data(self, device: torch.device, init_fn=torch.empty):
if self.data is not None:
return
self.data = init_fn(self.local_shape, dtype=self.dtype, device=device)
def __str__(self):
return f'{self.__class__.__name__}(key=\'{self.key}\')'
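# Illustrative sketch (hypothetical key and shapes): wrapping a tensor-parallel
# weight shard with `from_rank_offsets`. The 8x4 local shard is assumed to be
# fragment 1 out of 4 along axis 0 of a 32x4 global tensor.
def _example_sharded_tensor() -> ShardedTensor:
    local_weight = torch.zeros(8, 4)
    sh_ten = ShardedTensor.from_rank_offsets(
        'decoder.weight', local_weight, (0, 1, 4)  # (axis, axis_rank_offset, axis_fragm)
    )
    # sh_ten.global_shape == (32, 4), sh_ten.global_offset == (8, 0)
    # sh_ten.axis_fragmentations == (4, 1)
    return sh_ten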
def is_main_replica(replica_id: ReplicaId):
""" Checks if the given `replica_id` is considered the main replica.
"Main" replica is:
- integer 0
- or an iterable with all 0 elements
It is the application's responsibility to set correct replicas for sharded tensors.
Args:
replica_id (Union[int, Tuple[int, ...]]): replica id
Returns:
(bool): True for a "main" replica
"""
if isinstance(replica_id, int):
return replica_id == 0
return all(r == 0 for r in replica_id)
class LocalNonpersitentObject:
"""Object that should not be stored in a checkpoint, but restored locally.
Wrapping any object inside the state dict with LocalNonpersitentObject
will result in:
- during saving, this object will *not* be stored in the checkpoint
- during loading, a local version of this object will be placed in a state dict
"""
def __init__(self, obj):
self.obj = obj
def unwrap(self):
return self.obj
@dataclass
class ShardedObject(ShardedBase):
"""Represents a mapping between a local object and a global object.
Global object is assumed to consist of many local objects distributed
between different processes.
NOTE: Contrary to ShardedTensor, it's impossible to change global object
sharding. Conceptually, ShardedObject is a fully-sharded ShardedTensor
with atomic arbitrary typed elements.
Args:
key: unique identifier of a global object
data: local object data. Can be None only for consistency validation
global_shape: global object shape
global_offset: offset of a local object in a global object, specified in number of shards
replica_id: indicates local object replication wrt. local objects in different processes
"""
key: str
data: object
global_shape: Tuple[int, ...]
global_offset: Tuple[int, ...]
replica_id: ReplicaId = 0
def without_data(self):
return replace(self, data=None)
@property
def unique_key(self):
return f'{self.key}/shard_{".".join(map(str, self.global_offset))}_{".".join(map(str, self.global_shape))}'
def __str__(self):
return f'{self.__class__.__name__}(key=\'{self.key}\')'
@dataclass
class ShardedTensorFactory(ShardedBase):
""" Allows applying transformations to tensors before/after serialization.
The essence of those transformations is that they can be applied to
optimizer states the same way they are applied to the model params.
Builder creates a sub-state-dict out of a tensor before saving, and merger
merges the corresponding state dict after loading.
Args:
key (str): unique identifier of the factory
data (torch.Tensor): original model parameter that will be further transformed by this factory
build_fn (callable): function that transforms the original tensor to a sharded state dict
merge_fn (callable): function that transforms loaded subtree back into a single tensor (inverse of `build_fn`)
replica_id (ReplicaId): indicates factory replication wrt. factories in different processes
"""
key: str
data: torch.Tensor
build_fn: Callable[[str, torch.Tensor, ReplicaId], ShardedStateDict]
merge_fn: Callable[[StateDict], torch.Tensor]
replica_id: ReplicaId = 0
def build(self):
return self.build_fn(self.key, self.data, self.replica_id)
def apply_factories(sharded_state_dict: ShardedStateDict):
""" Turn ShardedTensorFactories into ShardedTensors *in-place*.
Args:
sharded_state_dict (ShardedStateDict): state dict possibly containing ShardedTensorFactory objects
Returns:
None: state dict is modified in place
"""
def apply(x):
if isinstance(x, ShardedTensorFactory):
x = x.build()
return x
dict_list_map_inplace(apply, sharded_state_dict)
def apply_factory_merges(
x1: StateDict, x2: ShardedStateDict, key: Tuple[str, ...] = ()
) -> StateDict:
""" Apply merges defined by ShardedTensorFactories *in-place*.
Args:
x1 (StateDict): state dict loaded from the checkpoint
x2 (ShardedStateDict): subset of `x1` (in terms of dict keys) with ShardedTensorFactory
as (possibly nested) values that define how to merge objects from the `x1` state dict
key (Tuple[str, ...]): current key in a recursive call. Used only for reporting meaningful errors
Returns:
StateDict: `x1` modified in-place
"""
if isinstance(x2, ShardedTensorFactory):
return x2.merge_fn(x1)
# The rest is almost the same as the `merge` function from `dict_utils`
if isinstance(x1, dict) and isinstance(x2, dict):
for k, v2 in x2.items():
if k not in x1:
raise ValueError(
f'Different dict keys encountered in `apply_factory_merges` ({x1.keys()} vs {x2.keys()})'
)
else:
x1[k] = apply_factory_merges(x1[k], v2, key=key + (k,))
elif isinstance(x1, list) and isinstance(x2, list):
if len(x1) != len(x2):
err_msg = f'Cannot merge two lists with different lengths ({len(x1)} and {len(x2)}, encountered at key {key})'
logger.error(err_msg + f'\nx1: {x1}\nx2: {x2}')
raise ValueError(err_msg)
for i, v2 in enumerate(x2):
x1[i] = apply_factory_merges(x1[i], v2, key=key + (i,))
elif isinstance(x1, list) and isinstance(x2, dict):
for k, v2 in x2.items():
if not isinstance(k, int):
raise ValueError(
f'Invalid dict key {k}: non-integer type encountered in a list-dict merge at level {key}'
)
if k >= len(x1):
raise ValueError(
f'Dict key {k} out of bound for list of length {len(x1)} (encountered at level {key})'
)
x1[k] = apply_factory_merges(x1[k], v2, key=key + (k,))
else:
raise ValueError(
f'Duplicate non-dict and non-list values encountered: `{x1}` and `{x2}` (at key {key})'
)
return x1
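# Illustrative sketch (hypothetical key): a trivial ShardedTensorFactory whose
# build_fn splits a tensor into two halves and whose merge_fn concatenates them
# back. Real factories usually build ShardedTensors; plain tensors are used
# here only to keep the sketch short.
def _example_factory_roundtrip():
    def build_fn(key, tensor, replica_id):
        half = tensor.shape[0] // 2
        return {'first': tensor[:half], 'second': tensor[half:]}

    def merge_fn(sub_state_dict):
        return torch.cat([sub_state_dict['first'], sub_state_dict['second']])

    factory = ShardedTensorFactory('embedding.weight', torch.arange(4.0), build_fn, merge_fn)
    sharded_sd = {'embedding': factory}
    apply_factories(sharded_sd)  # now {'embedding': {'first': ..., 'second': ...}}
    merged = apply_factory_merges(sharded_sd, {'embedding': factory})
    # merged['embedding'] is the reconstructed 4-element tensor
    return merged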
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Helpers for defining sharding for optimizer states based on existing sharding for model parameters. """
import logging
from copy import deepcopy
from dataclasses import replace
from itertools import chain
from typing import Dict, Iterable, List, Tuple, Union
logger = logging.getLogger(__name__)
import torch
from .dict_utils import nested_values
from .mapping import (
LocalNonpersitentObject,
ShardedStateDict,
ShardedTensor,
ShardedTensorFactory,
StateDict,
)
from .utils import extract_sharded_tensors_and_factories
def get_optim_param_to_id_map(optim_params_iter: Iterable[torch.nn.Parameter]) -> Dict[int, int]:
param_mappings = {}
for i, param in enumerate(optim_params_iter):
if id(param) not in param_mappings:
param_mappings[id(param)] = i
return param_mappings
def get_param_id_to_sharded_param_map(
model_sharded_state_dict: ShardedStateDict, optim_params_iter: Iterable[torch.nn.Parameter]
) -> Dict[int, Union[ShardedTensor, ShardedTensorFactory]]:
""" Generate mapping from optimizer state ids to model sharded parameters.
Args:
model_sharded_state_dict: sharded state dict with all model sharded tensors (can have any structure)
optim_params_iter: iterable which iterates over model parameters tracked by the optimizer.
The iteration must be in the same order as in the optimizer parameters.
Returns:
Dict[int, Union[ShardedTensor, ShardedTensorFactory]]: mapping from optimizer state ids
to model sharded parameters.
"""
model_sharded_state_dict, _ = extract_sharded_tensors_and_factories(model_sharded_state_dict)
id_to_sharded_param_map = {}
param_to_id_map = get_optim_param_to_id_map(optim_params_iter)
for ten in nested_values(model_sharded_state_dict):
if id(ten.data) in param_to_id_map:
id_to_sharded_param_map[param_to_id_map[id(ten.data)]] = ten
else:
logger.debug(f'{ten} is not tracked by the optimizer')
if not id_to_sharded_param_map:
logger.warning(
"Sharded parameters mapping is empty. It means tensors in model state dict"
" do not correspond to tensors in optimizer parameters map."
" Make sure to call state_dict with `keep_vars=True`."
)
return id_to_sharded_param_map
def make_sharded_optimizer_tensor(
model_param: Union[ShardedTensor, ShardedTensorFactory], optim_param: torch.Tensor, prefix: str
) -> Union[ShardedTensor, ShardedTensorFactory]:
""" Build a ShardedTensor or ShardedTensorFactory for optimizer param based on model param
Args:
model_param (Union[ShardedTensor, ShardedTensorFactory]): model param
optim_param (torch.Tensor): corresponding optimizer param
prefix (str): optimizer prefix for the ShardedTensor or ShardedTensorFactory
Returns:
Union[ShardedTensor, ShardedTensorFactory]: wrapped optimizer parameter
"""
if isinstance(model_param, ShardedTensorFactory):
return replace(model_param, key=f'{prefix}.{model_param.key}', data=optim_param)
assert (
tuple(optim_param.shape) == model_param.local_shape
), f'Optimizer shape ({tuple(optim_param.shape)}) does not match model shape ({model_param.local_shape})'
return replace(
model_param, key=f'{prefix}.{model_param.key}', data=optim_param, dtype=optim_param.dtype
)
def optim_state_to_sharding_state(
optim_state_dict: StateDict,
id_to_sharded_param_map: Dict[int, ShardedTensor],
exclude_keys: Tuple[str] = (),
):
""" Turn an optimizer state dict into a sharded state dict based on the model state dict *in-place*.
Can be used to add sharding information to the most common optimizer state dict layout.
Creates separate ShardedTensors for each key in `optim_state_dict['state']`
(e.g. for torch.optim.Adam there will be separate tensors for `exp_avg` and `exp_avg_sq`)
Args:
optim_state_dict (StateDict): optimizer state dict with
state parameters under `state` key and group hyperparameters under `param_groups` -> `params` key.
id_to_sharded_param_map (Dict[int, ShardedTensor]): mapping from optimizer param ids to model sharded tensors.
Can be generated with `get_param_id_to_sharded_param_map` function
exclude_keys (Tuple[str]): optimizer state keys to exclude from the final state dict.
Returns:
None: state dict is modified in place
"""
sharded_state = {}
for param_id, param_state in optim_state_dict['state'].items():
sharded_state[param_id] = {}
for state_key, param in param_state.items():
if state_key in exclude_keys:
continue
if param_id in id_to_sharded_param_map:
sharded_state[param_id][state_key] = make_sharded_optimizer_tensor(
id_to_sharded_param_map[param_id], param, prefix=f'optimizer.state.{state_key}'
)
else:
raise ValueError(f'Param id {param_id} does not match any model sharded param')
optim_state_dict['param_groups'] = deepcopy(optim_state_dict['param_groups'])
for group in optim_state_dict['param_groups']:
group['params'] = LocalNonpersitentObject(group['params'])
optim_state_dict['state'] = sharded_state
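# Illustrative sketch (hypothetical `model` / `optimizer` objects): the typical
# flow for adding sharding information to a torch.optim.Adam state dict.
# Assumes `model.sharded_state_dict()` wraps the *same* parameter objects the
# optimizer tracks (keep_vars=True semantics) and that the optimizer state is
# already populated; 'step' is excluded since it is a scalar, not a per-param tensor.
def _example_optimizer_sharding(model, optimizer):
    model_sharded_sd = model.sharded_state_dict()
    optim_sd = optimizer.state_dict()
    id_to_sharded = get_param_id_to_sharded_param_map(
        model_sharded_sd, (p for group in optimizer.param_groups for p in group['params'])
    )
    optim_state_to_sharding_state(optim_sd, id_to_sharded, exclude_keys=('step',))
    return optim_sd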
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Entrypoints for saving and loading the distributed checkpoints.
Functions `load` and `save` are equivalents of `torch.load` and `torch.save`
but expect torch.Tensors to be wrapped with classes from the `mapping` module.
Additionally, `load` expects the sharded state dict argument as guidance for loading the sharded tensors.
"""
import logging
import os
from collections import Counter, defaultdict
from itertools import chain
from pathlib import Path
from typing import Iterable, List, Optional, Tuple, Union
import numpy as np
import torch
from .core import CheckpointingConfig, maybe_load_config, save_config
from .dict_utils import (
dict_list_map_inplace,
diff,
extract_matching_values,
map_reduce,
merge,
nested_values,
)
from .mapping import (
CheckpointingException,
ShardedObject,
ShardedStateDict,
ShardedTensor,
ShardedTensorFactory,
StateDict,
apply_factories,
apply_factory_merges,
is_main_replica,
)
from .strategies.base import (
LoadCommonStrategy,
LoadShardedStrategy,
SaveCommonStrategy,
SaveShardedStrategy,
StrategyAction,
get_default_strategy,
)
from .utils import (
extract_nonpersistent,
extract_sharded_base,
extract_sharded_tensors,
extract_sharded_tensors_or_nonpersistent,
)
COMMON_STATE_FNAME = 'common.pt'
logger = logging.getLogger(__name__)
def load(
sharded_state_dict: ShardedStateDict,
checkpoint_dir: str,
sharded_strategy: Union[LoadShardedStrategy, Tuple[str, int], None] = None,
common_strategy: Union[LoadCommonStrategy, Tuple[str, int], None] = None,
validate_access_integrity: bool = True,
) -> StateDict:
"""Loading entrypoint.
In the steps below, the following verbs refer to corresponding objects:
- load = load from checkpoint
- extract = extract from sharded_state_dict
- add = add to the final state dict
Steps:
1. Load common state dict and form the base of the result state dict
2. Apply factories to sharded_state_dict
3. Extract LocalNonPersistentObject and add
4. (optional) Extract ShardedObjects, load and add
5. Extract ShardedBase, load, apply factory merges and add
Args:
sharded_state_dict (ShardedStateDict): state dict of the existing model
populated with ShardedTensors. Used as a mapping to determine which
parts of global tensors stored in the checkpoint should be loaded.
checkpoint_dir (str): directory with the checkpoint
sharded_strategy (LoadShardedStrategy, Tuple[str, int], optional): configures loading behavior for sharded tensors
common_strategy (LoadCommonStrategy, Tuple[str, int], optional): configures loading behavior for common data
validate_access_integrity (bool default = True): checks if each tensor shard is accessed
exactly once (as main replica) by some process
"""
if common_strategy is not None:
raise NotImplementedError('The only supported common strategy is torch')
sharded_strategy = _verify_checkpoint_and_load_strategy(checkpoint_dir, sharded_strategy)
checkpoint_dir = Path(checkpoint_dir)
common_state_dict = load_common_state_dict(checkpoint_dir)
if not sharded_state_dict:
return common_state_dict
# Create a copy of sharded_state_dict as the passed in state dict may have
# references that prevent tensors from being deallocated
sharded_state_dict, _ = extract_matching_values(sharded_state_dict, lambda x: True)
sh_ten_factories, _ = extract_matching_values(
sharded_state_dict,
lambda x: isinstance(x, ShardedTensorFactory),
return_lists_as_dicts=True,
)
apply_factories(sharded_state_dict)
# Data inside sh_ten_factories no longer needed so delete them to reduce memory usage
def unlink_data(x):
x.data = None
return x
dict_list_map_inplace(unlink_data, sh_ten_factories)
# Non-persistent objects
nonpersistent_state_dict, sharded_state_dict = extract_nonpersistent(sharded_state_dict)
dict_list_map_inplace(lambda o: o.unwrap(), nonpersistent_state_dict)
merge(common_state_dict, nonpersistent_state_dict)
# Sharded base
if not sharded_strategy.can_handle_sharded_objects:
# TODO: implement as part of a common strategy
sharded_objects, sharded_state_dict = load_sharded_objects(
sharded_state_dict, checkpoint_dir
)
merge(common_state_dict, sharded_objects)
sharded_state_dict, _ = extract_sharded_base(sharded_state_dict)
if validate_access_integrity:
validate_sharding_integrity(nested_values(sharded_state_dict))
loaded_state_dict = sharded_strategy.load(sharded_state_dict, checkpoint_dir)
loaded_state_dict = apply_factory_merges(loaded_state_dict, sh_ten_factories)
merge(common_state_dict, loaded_state_dict)
return common_state_dict
def _verify_checkpoint_and_load_strategy(
checkpoint_dir: str, sharded_strategy: Union[LoadShardedStrategy, Tuple[str, int], None] = None,
) -> LoadShardedStrategy:
""" Verifies if checkpoint metadata exists and matches given strategy.
Args:
checkpoint_dir (str): checkpoint directory
sharded_strategy (LoadShardedStrategy, Tuple[str, int], optional): load strategy to be verified
if compatible with the checkpoint content. If None, the default load strategy
for the checkpoint backend will be returned.
"""
if not Path(checkpoint_dir).exists():
raise CheckpointingException(f'Checkpoint directory {checkpoint_dir} does not exist')
saved_config = maybe_load_config(checkpoint_dir)
if saved_config is None:
raise CheckpointingException(f'{checkpoint_dir} is not a distributed checkpoint')
if sharded_strategy is None:
sharded_strategy = get_default_strategy(
StrategyAction.LOAD_SHARDED,
saved_config.sharded_backend,
saved_config.sharded_backend_version,
)
elif isinstance(sharded_strategy, tuple):
sharded_strategy = get_default_strategy(StrategyAction.LOAD_SHARDED, *sharded_strategy)
# TODO: implement consistency checks here
return sharded_strategy
# TODO: implement it as common torch strategy
def load_common_state_dict(checkpoint_dir: Path) -> StateDict:
""" Load common (non-sharded) objects state dict from the checkpoint.
Args:
checkpoint_dir (Path): checkpoint directory
Returns:
StateDict: state dict with non-sharded objects from the checkpoint
"""
load_path = Path(checkpoint_dir) / COMMON_STATE_FNAME
try:
return torch.load(load_path, map_location='cpu')
except FileNotFoundError as e:
err_msg = f'Common file {load_path} does not exist'
ckpt_files = [f.name for f in checkpoint_dir.iterdir()]
logger.debug(f'{err_msg}. Checkpoint directory content: {ckpt_files}')
raise CheckpointingException(err_msg) from e
def load_sharded_objects(sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
""" Replaces all ShardedObjects in a given state dict with values loaded from the checkpoint.
Args:
sharded_state_dict (ShardedStateDict): sharded state dict defining what objects should be loaded.
checkpoint_dir (Path): checkpoint directory
Returns:
Tuple[StateDict, ShardedStateDict]: a state dict with the loaded objects and the remaining sharded state dict
"""
sharded_objects, sharded_state_dict = extract_matching_values(
sharded_state_dict, lambda v: isinstance(v, ShardedObject)
)
def load_sharded_object(sh_obj: ShardedObject):
sh_obj.data = None
load_path = (checkpoint_dir / sh_obj.unique_key).with_suffix('.pt')
try:
loaded_obj = torch.load(load_path)
except FileNotFoundError as e:
err_msg = f'Object shard {load_path} not found'
obj_subdir = checkpoint_dir / sh_obj.key
if obj_subdir.exists():
obj_files = [f.name for f in obj_subdir.iterdir()]
logger.debug(f'{err_msg}. Object {sh_obj.key} directory content: {obj_files}')
else:
ckpt_files = [f.name for f in checkpoint_dir.iterdir()]
logger.debug(
f'{err_msg}. Object {sh_obj.key} directory does not exist. Checkpoint directory content: {ckpt_files}'
)
raise CheckpointingException(err_msg) from e
return loaded_obj
return dict_list_map_inplace(load_sharded_object, sharded_objects), sharded_state_dict
def load_tensors_metadata(
checkpoint_dir: str, sharded_strategy: Union[LoadShardedStrategy, None] = None
) -> ShardedStateDict:
"""Load tensors metadata from the checkpoint.
Returns a dictionary similar to a sharded state dict, but note that
the dictionary keys are simply ShardedTensor keys (contrary to the
actual sharded state dicts where keys correspond to state dict keys).
Dict values are ShardedTensors without any sharding (so, the only useful
information is tensors global shape and dtype).
Concrete implementation depends on the loading strategy. If no strategy is
given, a default for a given backend is used.
"""
sharded_strategy = _verify_checkpoint_and_load_strategy(checkpoint_dir, sharded_strategy)
return sharded_strategy.load_tensors_metadata(Path(checkpoint_dir))
def load_plain_tensors(checkpoint_dir: str):
"""Load checkpoint tensors without any sharding.
NOTE: common state dict is NOT included."""
sharded_state_dict = load_tensors_metadata(checkpoint_dir)
# Don't validate integrity because shards will be overlapped
# if world_size > 1 (all processes load whole tensors)
return load(sharded_state_dict, checkpoint_dir, validate_access_integrity=False)
def save(
sharded_state_dict: ShardedStateDict,
checkpoint_dir: str,
sharded_strategy: Union[SaveShardedStrategy, Tuple[str, int], None] = None,
common_strategy: Union[SaveCommonStrategy, Tuple[str, int], None] = None,
validate_access_integrity: bool = True,
) -> None:
"""Saving entrypoint.
Extracts ShardedTensors from the given state dict. Rank 0 saves the
"regular" part of the checkpoint to a common torch file.
The ShardedTensors are saved according to a strategy specified by the
config.
Steps:
1. Apply factories
2. Extract and discard LocalNonPersistentObject
3. Extract all ShardedBase objects
4. Save all other objects to common.pt
5. (optional) Extract and save ShardedObjects
6. Save all ShardedBase objects
Args:
sharded_state_dict (ShardedStateDict): state dict of the model populated with
ShardedTensors. Used as a mapping to determine how local tensors
should be saved as global tensors in the checkpoint.
checkpoint_dir (str): directory to save the checkpoint to
sharded_strategy (SaveShardedStrategy, Tuple[str, int], optional): configures sharded tensors saving behavior and backend
common_strategy (SaveCommonStrategy, Tuple[str, int], optional): configures common data saving behavior and backend
validate_access_integrity (bool default = True): checks if each tensor shard is accessed
exactly once (as main replica) by some process
"""
checkpoint_dir = Path(checkpoint_dir)
if torch.distributed.get_rank() == 0:
if not checkpoint_dir.exists():
raise CheckpointingException(
f'Checkpoint destination directory does not exist: {checkpoint_dir}'
)
if next(checkpoint_dir.iterdir(), None) is not None:
raise CheckpointingException(
f'Checkpoint destination directory ({checkpoint_dir}) is not empty'
)
if common_strategy is not None:
raise NotImplementedError('The only supported common strategy is torch')
if sharded_strategy is None:
sharded_strategy = ('zarr', 1)
if not isinstance(sharded_strategy, SaveShardedStrategy):
assert isinstance(sharded_strategy, tuple), type(sharded_strategy)
sharded_strategy = get_default_strategy(StrategyAction.SAVE_SHARDED, *sharded_strategy)
apply_factories(sharded_state_dict)
_, sharded_state_dict = extract_nonpersistent(sharded_state_dict)
sharded_state_dict, state_dict = extract_sharded_base(sharded_state_dict)
_save_common_dict(state_dict, checkpoint_dir, True)
if validate_access_integrity:
validate_sharding_integrity(list(nested_values(sharded_state_dict)))
if not sharded_strategy.can_handle_sharded_objects:
# TODO: implement as part of a common strategy
sharded_state_dict = _extract_and_save_sharded_objects(
sharded_state_dict, checkpoint_dir, validate_access_integrity
)
sharded_strategy.save(sharded_state_dict, checkpoint_dir)
if torch.distributed.get_rank() == 0:
save_config(
CheckpointingConfig(sharded_strategy.backend, sharded_strategy.version), checkpoint_dir
)
torch.distributed.barrier()
# TODO: implement it as common torch strategy
def _save_common_dict(
state_dict: StateDict, checkpoint_dir: Path, validate_consistency: bool = False
):
if torch.distributed.get_rank() == 0:
torch.save(state_dict, checkpoint_dir / COMMON_STATE_FNAME)
if validate_consistency:
# TODO: implement checking consistency with rank 0 common dict on other ranks
pass
# torch.distributed.barrier()
# if not torch.distributed.get_rank() == 0:
# rank_0_state_dict = torch.load(checkpoint_dir / COMMON_STATE_FNAME)
# print(diff(common_state_dict, rank_0_state_dict))
def _extract_and_save_sharded_objects(
state_dict: StateDict, checkpoint_dir: Path, validate_consistency: bool = False
):
sharded_objects, state_dict = extract_matching_values(
state_dict, lambda v: isinstance(v, ShardedObject)
)
sharded_objects = list(nested_values(sharded_objects))
for sh_obj in sharded_objects:
if is_main_replica(sh_obj.replica_id):
save_path = (checkpoint_dir / sh_obj.unique_key).with_suffix('.pt')
os.makedirs(save_path.parent, exist_ok=True)
torch.save(sh_obj.data, save_path)
return state_dict
def validate_sharding_integrity(sharded_tensors: Iterable[ShardedTensor]):
""" Validate if the ShardedTensors from multiple processes define correct sharding of a global tensor.
Local ShardedTensors metadata is exchanged with `torch.distributed.all_gather_object`
and then the process with global rank 0 checks whether the main replicas of the shards:
- cover the whole global tensor
- don't overlap
Args:
sharded_tensors (Iterable[ShardedTensor]): sharded tensors local to this process
Returns:
None
Raises:
CheckpointingException for invalid access pattern
"""
sharding = [ten.without_data() for ten in sharded_tensors]
all_sharding = [None] * torch.distributed.get_world_size()
torch.distributed.all_gather_object(all_sharding, sharding)
if torch.distributed.get_rank() != 0:
return
key_shardings = defaultdict(list)
for rank, rank_shardings in enumerate(all_sharding):
for sharding in rank_shardings:
key_shardings[sharding.key].append((rank, sharding))
for key, shardings in key_shardings.items():
if isinstance(shardings[0][1], ShardedObject):
_validate_objects_for_key(shardings)
else:
_validate_sharding_for_key(shardings)
def _validate_sharding_for_key(rank_sharding: List[Tuple[int, ShardedTensor]]):
some_rank_shard = rank_sharding[0][1]
global_shape = some_rank_shard.global_shape
local_shape = some_rank_shard.local_shape
dtype = some_rank_shard.dtype
has_flattened_range = some_rank_shard.flattened_range is not None
for rank, sharding in rank_sharding:
assert sharding.dtype == dtype, (sharding.dtype, dtype, some_rank_shard)
assert sharding.global_shape == global_shape, (
sharding.global_shape,
global_shape,
some_rank_shard,
)
assert sharding.local_shape == local_shape, (
sharding.local_shape,
local_shape,
some_rank_shard,
)
assert (sharding.flattened_range is not None) == has_flattened_range, (
(sharding.flattened_range is not None),
has_flattened_range,
some_rank_shard,
)
shard_access_cnt = _compute_shards_access(rank_sharding)
if has_flattened_range:
map_reduce(
rank_sharding,
lambda x: x[1].global_offset,
lambda x: x[1],
_validate_sharding_for_key_flattened,
)
else:
if not torch.all(shard_access_cnt == 1):
logger.error(f'Invalid access pattern for {rank_sharding[0][1]}: {shard_access_cnt}')
raise CheckpointingException(f'Invalid access pattern for {rank_sharding[0][1]}')
def _compute_shards_access(rank_sharding):
def chunk_offset(sharding):
assert len(sharding.global_offset) == len(sharding.local_shape) + sharding.prepend_axis_num
return tuple(
chain(
(off for off in sharding.global_offset[: sharding.prepend_axis_num]),
(
off // sh
for off, sh in zip(
sharding.global_offset[sharding.prepend_axis_num :], sharding.local_shape
)
),
)
)
shard_access_cnt = torch.zeros(
rank_sharding[0][1].axis_fragmentations, dtype=torch.int, device='cpu'
)
for rank, sharding in rank_sharding:
if is_main_replica(sharding.replica_id):
shard_access_cnt[chunk_offset(sharding)] += 1
# TODO: consider validating different replicas too
return shard_access_cnt
def _validate_sharding_for_key_flattened(tensors_by_shard):
all_slices = []
local_shape = tensors_by_shard[0].local_shape
for sharding in tensors_by_shard:
assert sharding.local_shape == local_shape
sharding: ShardedTensor
if not is_main_replica(sharding.replica_id):
# TODO: this checks only saving (and loading replica_id=0) consistency
continue
all_slices.append((sharding.flattened_range.start, sharding.flattened_range.stop))
starts, stops = map(np.asarray, zip(*sorted(all_slices)))
if (
starts[0] != 0
or stops[-1] != np.product(local_shape)
or not np.all(starts[1:] == stops[:-1])
):
logger.error(
f'Flattened ranges do not cover the whole shard {tensors_by_shard[0]}. Ranges: {(starts, stops)}'
)
raise CheckpointingException(
f'Flattened ranges do not cover the whole shard {tensors_by_shard[0]}'
)
def _validate_objects_for_key(sharded_objects: List[ShardedObject]):
""" Ensure uniqueness of saved objects. """
unique_keys = [
sh_obj.unique_key for _, sh_obj in sharded_objects if is_main_replica(sh_obj.replica_id)
]
if len(unique_keys) != len(set(unique_keys)):
duplicates = {k: cnt for k, cnt in Counter(unique_keys).items() if cnt > 1}
logger.error(f'Duplicate ShardedObject keys and counts: {duplicates}')
raise CheckpointingException(f'Duplicate ShardedObject keys: {list(duplicates.keys())}')
expected_shard_num = np.prod(sharded_objects[0][1].global_shape)
if len(unique_keys) != expected_shard_num:
err_msg = f'Invalid access pattern: {expected_shard_num - len(unique_keys)} ShardedObjects are missing.'
logger.error(f'{err_msg} Existing shards: {unique_keys}')
raise CheckpointingException(err_msg)
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Various loading and saving strategies """
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Strategies base interfaces. """
from abc import ABC, abstractmethod
from collections import defaultdict
from enum import Enum
from pathlib import Path
from typing import Dict, List, Optional
from ..mapping import CheckpointingException, ShardedStateDict, ShardedTensor, StateDict
class StrategyAction(Enum):
LOAD_COMMON = 'load_common'
LOAD_SHARDED = 'load_sharded'
SAVE_COMMON = 'save_common'
SAVE_SHARDED = 'save_sharded'
default_strategies = defaultdict(dict)
def get_default_strategy(action: StrategyAction, backend: str, version: int):
""" Retrieves a default strategy for a given action, backend and version. """
try:
if backend == 'zarr':
error_hint = ' Please install `zarr` and `tensorstore<=0.1.45` packages'
from .tensorstore import _import_trigger
from .zarr import _import_trigger
elif backend == 'torch_dist':
error_hint = ' Please use PyTorch version >=2.1'
from .torch import _import_trigger
except ImportError as e:
raise CheckpointingException(
f'Cannot import a default strategy for: {(action.value, backend, version)}. Error: {e}. Hint: {error_hint}'
) from e
try:
return default_strategies[action.value][(backend, version)]
except KeyError as e:
raise CheckpointingException(
f'Cannot find a default strategy for: {(action.value, backend, version)}'
) from e
class LoadStrategyBase(ABC):
""" Base class for a load strategy. Requires implementing checks for compatibility with a given checkpoint version. """
@abstractmethod
def check_backend_compatibility(self, loaded_version):
raise NotImplementedError
@abstractmethod
def check_version_compatibility(self, loaded_version):
raise NotImplementedError
@property
def can_handle_sharded_objects(self):
""" Returns whether or not this strategy can handle loading ShardedObjects. """
return False
class SaveStrategyBase(ABC):
""" Base class for a save strategy. Requires defining a backend type and version of the saved format. """
def __init__(self, backend: str, version: int):
self.backend = backend
self.version = version
@property
def can_handle_sharded_objects(self):
""" Returns whether or not this strategy can handle saving ShardedObjects. """
return False
class LoadCommonStrategy(LoadStrategyBase):
""" Load strategy for common (non-sharded) objects """
@abstractmethod
def load(self, checkpoint_dir: Path):
raise NotImplementedError
class LoadShardedStrategy(LoadStrategyBase):
""" Load strategy for sharded tensors """
@abstractmethod
def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
raise NotImplementedError
@abstractmethod
def load_tensors_metadata(self, checkpoint_dir: Path):
"""Load tensors metadata from the checkpoint.
Returns a dictionary similar to a sharded state dict, but note that
the dictionary keys are simply ShardedTensor keys (contrary to the
actual sharded state dicts where keys correspond to state dict keys).
Dict values are ShardedTensors without any sharding (so, the only useful
information is tensors global shape and dtype).
"""
raise NotImplementedError(
f'{self.__class__.__name__} does not allow loading only the sharded metadata'
)
class SaveCommonStrategy(SaveStrategyBase):
""" Save strategy for common (non-sharded) objects """
@abstractmethod
def save(self, common_state_dict: StateDict, checkpoint_dir: Path):
raise NotImplementedError
class SaveShardedStrategy(SaveStrategyBase):
""" Save strategy for sharded tensors """
@abstractmethod
def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
raise NotImplementedError
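# Illustrative sketch (hypothetical backend name): how a concrete strategy is
# expected to register itself in `default_strategies` so that
# `get_default_strategy` can find it. The real strategy modules (zarr,
# tensorstore, torch) follow this pattern at import time.
def _example_register_strategy() -> SaveShardedStrategy:
    class NoopSaveShardedStrategy(SaveShardedStrategy):
        def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
            pass  # a real strategy writes the sharded tensors here

    default_strategies[StrategyAction.SAVE_SHARDED.value][('noop', 1)] = NoopSaveShardedStrategy(
        'noop', 1
    )
    return get_default_strategy(StrategyAction.SAVE_SHARDED, 'noop', 1)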
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
""" Storage writer for PyT Distributed format allowing asynchronous save. """
import logging
import os
from itertools import chain
from pathlib import Path
from time import time
from typing import Callable, Dict, List, Optional, Tuple
import psutil
import torch
from torch import multiprocessing as mp
from torch.distributed.checkpoint import FileSystemWriter
from torch.distributed.checkpoint.filesystem import DEFAULT_SUFFIX, _StoragePrefix, _write_item
from torch.distributed.checkpoint.planner import SavePlan, SavePlanner, WriteItem, WriteItemType
from torch.distributed.checkpoint.storage import WriteResult
from torch.futures import Future
logger = logging.getLogger(__name__)
WriteBucket = Tuple[Path, str, Tuple[list, list]] # represents writes to a single file
class FileSystemWriterAsync(FileSystemWriter):
"""
Async-enabled implementation of FileSystemWriter using file IO.
This class doesn't spawn the async process itself; it relies on an external async mechanism.
Flow:
1. Call `prepare_write_data`
2. Externally start an async process with the function and args returned by `get_save_function_and_args`
3. The async function to call is `write_preloaded_data_multiproc`, which calls
`write_preloaded_data` in multiple processes
After saving is finalized on all ranks:
4. Call `super().finish` with the results gathered in `self.writer_result`
Note that step (3) above can also be called synchronously.
Currently, it's assumed that a separate writer is created for each ckpt save
(intermediate state is stored as writer attributes).
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if not self.single_file_per_rank:
raise NotImplementedError(
'single_file_per_rank flag not supported for FileSystemWriterAsync'
)
# Intermediate state between preparation and finalization
self.write_buckets: Optional[List[WriteBucket]] = None
self.write_results: Optional[Dict[int, List[WriteResult]]] = None
def prepare_write_data(self, plan: SavePlan, planner: SavePlanner) -> None:
"""
First stage of async saving. Copy data to CPU and plan the local saving.
Args:
plan (SavePlan): save plan generated by the PyT Distributed compatible planner
planner (SavePlanner): save planner used to resolve the bytes and tensor data
Returns: None, but stores the save plan in `self.write_buckets`
"""
storage_plan: _StoragePrefix = plan.storage_data
start = time()
logger.debug(f"thread_count: {self.thread_count}, time: {start}")
item_buckets = _split_by_size_and_type(self.thread_count, plan.items)
logger.debug(f"bucket_prep, time: {time() - start}")
start = time()
# move tensors from GPU to CPU before starting async writing
# We do D2H synchronously for now
file_count = 0
def gen_file():
nonlocal file_count
file_name = f"{storage_plan.prefix}{file_count}{DEFAULT_SUFFIX}"
file_count += 1
return file_name
# Prepare bytes / tensor data in each bucket, which will be assigned to each writer process
self.write_buckets = []
for bucket in item_buckets:
bytes_data = [
(item, planner.resolve_data(item))
for item in bucket
if item.type == WriteItemType.BYTE_IO
]
tensor_data = [
(item, planner.resolve_data(item).detach().to("cpu", non_blocking=True))
for item in bucket
if item.type != WriteItemType.BYTE_IO
]
if len(bytes_data) > 0 or len(tensor_data) > 0:
file_name = gen_file()
self.write_buckets.append(
(self.path / file_name, file_name, (bytes_data, tensor_data))
)
# Check if there is anything to write on this rank
if len(self.write_buckets) > 0:
assert len(self.write_buckets) <= self.thread_count, (
len(self.write_buckets),
self.thread_count,
)
ctx = mp.get_context('fork')
self.write_results = ctx.Manager().dict()
else:
self.write_results = {}
logger.debug(f"D2H and push, time: {time() - start}")
def get_save_function_and_args(self) -> Optional[Tuple[Callable, Tuple]]:
"""
Get function that saves the data to storage along with its arguments.
Allows the external caller to apply the save function synchronously or asynchronously.
Returns: None (if there is nothing to write on this rank) or a tuple of:
- the function that saves the data
- arguments to that function
"""
if not self.write_buckets:
return None
return (self.write_preloaded_data_multiproc, (self.write_buckets, self.write_results))
@staticmethod
def write_preloaded_data_multiproc(
write_buckets: List[WriteBucket], write_results: Dict[int, List[WriteResult]]
) -> None:
"""
Performs saving data to storage with multiple processes.
Args:
write_buckets (List[WriteBucket]): write plan
write_results: (Dict[int, List[WriteResult]]): dict to store the write results to.
Assumes multiprocessing save, so keys are local process indices
Returns: None
"""
w_start = time()
ctx = mp.get_context('fork')
p_list = [
ctx.Process(
target=FileSystemWriterAsync.write_preloaded_data,
args=(i, write_bucket, write_results, True),
)
for i, write_bucket in enumerate(write_buckets)
]
for p in p_list:
p.start()
for p in p_list:
p.join()
w_end = time()
logger.debug(
f"{w_end}, rank: {torch.distributed.get_rank()}, write(sync,parallel): {w_end - w_start}"
)
@staticmethod
def write_preloaded_data(
local_proc_idx: int,
write_bucket: WriteBucket,
write_results: Dict[int, List[WriteResult]],
use_fsync: bool,
) -> None:
"""
Performs actual data saving to storage.
Args:
local_proc_idx (int): index of a local process that performs writing
write_bucket (WriteBucket): data to write to storage
write_results (Dict[int, List[WriteResult]]): dict to store the write results to.
Assumes multiprocessing save, so keys are local process indices
use_fsync (bool): if True, calls os.fsync at the end of saving
Returns: None, the write results are written to the `write_results` dict
"""
mem_before = _process_memory()
local_results = []
file_name, storage_key, (bytes_data, tensor_data) = write_bucket
with open(file_name, "wb") as stream:
for write_item, data in bytes_data:
local_results.append(_write_item(stream, data, write_item, storage_key))
for write_item, tensor in tensor_data:
assert tensor.is_cpu
local_results.append(_write_item(stream, tensor, write_item, storage_key))
if use_fsync:
os.fsync(stream.fileno())
write_results[local_proc_idx] = local_results
mem_after = _process_memory()
logger.debug(
f"{local_proc_idx} consumed: {mem_after - mem_before}, before: {mem_before}, after: {mem_after}"
)
def write_data(self, plan: SavePlan, planner: SavePlanner,) -> Future[List[WriteResult]]:
raise NotImplementedError('write_data not implemented for FileSystemWriterAsync')
def retrieve_write_results(self) -> List[WriteResult]:
"""
Turn self.write_results into a single list of results. Includes an error check.
Returns (List[WriteResult]): the list of write results from all local processes performing the save.
"""
assert self.write_results is not None
assert self.write_buckets is not None
if len(self.write_results) != len(self.write_buckets):
raise RuntimeError(
f'Incomplete worker results (expected {len(self.write_buckets)}, got {len(self.write_results)}).'
f' This probably indicates a worker failure.'
)
return list(chain.from_iterable(self.write_results.values()))
def _split_by_size_and_type(bins: int, items: List[WriteItem]) -> List[List[WriteItem]]:
"""
Splits write items into close-to-uniform bins according to item size.
Same as torch.distributed.checkpoint.filesystem._split_by_size_and_type,
but with a fixed _item_size function.
Args:
bins (int): number of bins to split into
items (List[WriteItem]): list of write items
Returns (List[List[WriteItem]]): write items split to bins
"""
if bins == 1:
return [items]
bytes_items = [wi for wi in items if wi.type == WriteItemType.BYTE_IO]
tensor_items = [wi for wi in items if wi.type != WriteItemType.BYTE_IO]
buckets: List[List[WriteItem]] = [[] for _ in range(bins)]
bucket_sizes = [0 for _ in range(bins)]
tensor_items.sort(key=_item_size, reverse=True)
# Assign bytes with a simple round-robin
for i, item in enumerate(bytes_items):
buckets[i % bins].append(item)
# Then, assign tensors according to their sizes
for item in tensor_items:
# TODO: replace with heapq
idx = min(enumerate(bucket_sizes), key=lambda x: x[1])[0]
buckets[idx].append(item)
bucket_sizes[idx] += _item_size(item)
return buckets
def _item_size(item: WriteItem) -> int:
"""
Calculates size (in bytes) of a single write item.
Same as torch.distributed.checkpoint.filesystem._item_size,
but fixes computing chunk size (with item.tensor_data.chunk.sizes)
Args:
item (WriteItem): write item to compute the size of
Returns (int): size of an item in bytes
"""
size = 1
assert item.tensor_data is not None
# can't use math.prod as PT needs to support older python
for s in item.tensor_data.chunk.sizes:
size *= s
dtype = item.tensor_data.properties.dtype
return size * torch._utils._element_size(dtype)
def _process_memory() -> int:
"""
Get memory used by current process.
Returns (int): memory used by current process
"""
process = psutil.Process(os.getpid())
mem_info = process.memory_info()
return mem_info.rss
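# Illustrative sketch (simplified): how an external caller is expected to drive
# FileSystemWriterAsync, following the class docstring. Here the "async" step
# is invoked inline; a real caller would hand `save_fn(*args)` to its own async
# mechanism (e.g. a background process).
def _example_drive_writer(writer: FileSystemWriterAsync, plan: SavePlan, planner: SavePlanner):
    writer.prepare_write_data(plan, planner)  # D2H copy + bucketing
    save_fn_and_args = writer.get_save_function_and_args()
    if save_fn_and_args is not None:  # None means nothing to write on this rank
        save_fn, args = save_fn_and_args
        save_fn(*args)  # executed synchronously in this sketch
    return writer.retrieve_write_results()  # input for `super().finish`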
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
""" State dict saver for PyT Distributed format allowing asynchronous save. """
from logging import getLogger
from time import time
from typing import TYPE_CHECKING, Optional, Tuple, cast
import torch
import torch.distributed as dist
from torch.distributed.checkpoint import CheckpointException
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE, Metadata
from torch.distributed.checkpoint.planner import SavePlanner
from torch.distributed.checkpoint.utils import _DistWrapper, _get_failure_dict
if TYPE_CHECKING:
from .filesystem_async import FileSystemWriterAsync
logger = getLogger(__name__)
def save_state_dict_async_plan(
state_dict: STATE_DICT_TYPE,
storage_writer: 'FileSystemWriterAsync',
process_group: Optional[dist.ProcessGroup] = None,
coordinator_rank: int = 0,
planner: Optional[SavePlanner] = None,
) -> Tuple['FileSystemWriterAsync', Metadata, _DistWrapper]:
"""
First stage of saving a state dict to storage.
This is an async adjustment of torch.distributed.checkpoint.state_dict_saver.
In order to support async save, saving should be split into three parts:
1. Planning
2. Actual saving
3. Finalization
Out of these, step (2) *must* happen asynchronously.
The first step is realized with this function.
The planning part consists of several steps, described here:
https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.SavePlanner
Args:
state_dict (STATE_DICT_TYPE): state dict to save
storage_writer (FileSystemWriterAsync): in current version only an instance of
FileSystemWriterAsync
process_group (dist.ProcessGroup, optional): process group used for save planning
coordinator_rank (int, optional): coordinator rank for planning. Defaults to 0.
planner (SavePlanner, optional): save planner for torch.distributed.checkpoint format
Returns: Tuple of:
- storage writer (the one passed as input)
- metadata from planning
- distributed wrapper used for planning
The return value of this function should be passed as an input to
`save_state_dict_async_finalize`.
"""
rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
dist_wrapper = _DistWrapper(process_group, True, coordinator_rank)
if planner is None:
planner = DefaultSavePlanner()
assert planner is not None
global_metadata = None
def local_step():
assert planner is not None
planner.set_up_planner(state_dict, dist_wrapper.is_coordinator)
storage_writer.set_up_storage_writer(dist_wrapper.is_coordinator)
local_plan = planner.create_local_plan()
local_plan = storage_writer.prepare_local_plan(local_plan)
return local_plan
def global_step(all_local_plans):
nonlocal global_metadata
assert planner is not None
all_local_plans, global_metadata = planner.create_global_plan(all_local_plans)
all_local_plans = storage_writer.prepare_global_plan(all_local_plans)
return all_local_plans
# Execute local and global planning
start_plan = time()
central_plan = dist_wrapper.reduce_scatter("plan", local_step, global_step)
logger.debug(f"rank: {rank}, plan time: {time() - start_plan}")
# Prepare async writing of tensors.
# The `storage_writer` will store the information about tensors it needs to save
start = time()
final_local_plan = planner.finish_plan(central_plan)
storage_writer.prepare_write_data(final_local_plan, planner)
end = time()
logger.debug(f"{time()} rank: {rank}, write(async) time: {end - start}")
return storage_writer, cast(Metadata, global_metadata), dist_wrapper
def save_state_dict_async_finalize(
storage_writer: 'FileSystemWriterAsync', global_metadata: Metadata, dist_wrapper: _DistWrapper,
) -> None:
"""
Finalization of save_state_dict_async_plan.
The input arguments are the same as the save_state_dict_async_plan output,
the `write_results` are retrieved from the storage_writer.
Args:
storage_writer (FileSystemWriterAsync): storage writer used for planning
global_metadata (Metadata): metadata created during planning
dist_wrapper (_DistWrapper): distributed wrapper created during planning
Returns: None
"""
write_results = storage_writer.retrieve_write_results()
# Gather the write results that will be saved to the metadata file.
gather_start = time()
all_results = dist_wrapper.gather_object(write_results)
gather_end = time()
logger.debug(f"{gather_end}, {torch.distributed.get_rank()}, gather: {gather_end-gather_start}")
# Store the metadata on coordinator rank
if dist_wrapper.is_coordinator:
node_failures = _get_failure_dict(all_results)
if len(node_failures) == 0:
assert global_metadata is not None
write_start = time()
storage_writer.finish(global_metadata, all_results)
write_end = time()
logger.debug(f"{write_end}, metadata_write: {write_end - write_start}")
else:
raise CheckpointException("write", node_failures)
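# Illustrative sketch (not part of the original module): the three-stage flow described
# above, executed synchronously for simplicity. Assumes an initialized process group and
# a PyT Distributed compatible `state_dict`; `checkpoint_dir` is hypothetical.
def _example_async_save_flow(state_dict, checkpoint_dir):
    from .filesystem_async import FileSystemWriterAsync

    writer = FileSystemWriterAsync(checkpoint_dir, thread_count=2)
    # Stage 1: planning (a collective operation).
    plan_result = save_state_dict_async_plan(state_dict, writer)
    # Stage 2: the actual write. In a real async setup, the function and args returned
    # here could be handed to a separate process while training continues.
    fn_and_args = writer.get_save_function_and_args()
    if fn_and_args is not None:
        write_fn, write_args = fn_and_args
        write_fn(*write_args)
    # Stage 3: finalization (gathers write results, coordinator writes the metadata file).
    save_state_dict_async_finalize(*plan_result)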
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Strategies using TensorStore to load and save Zarr arrays. """
from functools import partial
from itertools import starmap
from logging import getLogger
from pathlib import Path
import tensorstore as ts
import torch
from ..core import CheckpointingException
from ..dict_utils import dict_list_map_inplace
from ..mapping import ShardedStateDict, ShardedTensor
from .base import LoadShardedStrategy, StrategyAction, default_strategies
from .zarr import (
load_zarr_based_sharded_metadata,
numpy_to_torch_dtype_dict,
postprocess_numpy_array,
)
_import_trigger = None
logger = getLogger(__name__)
class TensorStoreLoadShardedStrategy(LoadShardedStrategy):
def __init__(self, load_directly_on_device: bool = False):
super().__init__()
self.load_directly_on_device = load_directly_on_device
def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
if torch.distributed.get_rank() == 0:
print(f'Loading distributed checkpoint with {self.__class__.__name__}')
if self.load_directly_on_device:
print('Loading distributed checkpoint directly on the GPU')
load_fn = partial(
_load_from_array,
checkpoint_dir=checkpoint_dir,
load_directly_on_device=self.load_directly_on_device,
)
dict_list_map_inplace(load_fn, sharded_state_dict)
return sharded_state_dict
def load_tensors_metadata(self, checkpoint_dir: Path):
def get_ts_shape_dtype(path):
arr = open_ts_array(path)
return arr.shape, arr.dtype.numpy_dtype
return load_zarr_based_sharded_metadata(checkpoint_dir, get_ts_shape_dtype)
def check_backend_compatibility(self, loaded_version):
pass # TODO
def check_version_compatibility(self, loaded_version):
pass # TODO
def merge_global_slice_with_shape(global_slice, actual_shape, key):
    """ Clamps the global slice of a ShardedTensor to the actual shape of the loaded array (used when shape mismatch is allowed). """
def _merge_slice(dim_slice, dim_size):
if isinstance(dim_slice, slice):
assert (
dim_slice.start < dim_size
), f'Got empty slice for ShardedTensor {key} ({dim_slice}, {dim_size})'
if dim_slice.stop > dim_size:
dim_slice = slice(dim_slice.start, dim_size, dim_slice.step)
return dim_slice
assert len(global_slice) == len(actual_shape), (global_slice, actual_shape, key)
return tuple(starmap(_merge_slice, zip(global_slice, actual_shape)))
def _load_from_array(
sharded_tensor: ShardedTensor,
checkpoint_dir: Path,
load_directly_on_device: bool = False,
apply_flattened_range: bool = True,
):
x = _load_regular_chunk(sharded_tensor, checkpoint_dir)
ten = postprocess_numpy_array(x, sharded_tensor, apply_flattened_range)
if load_directly_on_device:
sharded_tensor.data.data.copy_(ten)
return sharded_tensor.data
else:
return ten
def _load_regular_chunk(sharded_tensor: ShardedTensor, checkpoint_dir: Path):
assert isinstance(sharded_tensor, ShardedTensor), type(sharded_tensor)
arr = open_ts_array(checkpoint_dir / sharded_tensor.key)
if sharded_tensor.global_shape == arr.shape:
x = (
arr[sharded_tensor.global_slice()].read().result()
) # flattened tensors loading is delayed
elif sharded_tensor.allow_shape_mismatch:
global_slice = merge_global_slice_with_shape(
sharded_tensor.global_slice(), arr.shape, sharded_tensor.key
)
x = arr[global_slice].read().result() # flattened tensors loading is delayed
else:
_msg = (
f'Global shape mismatch for loaded ({arr.shape})'
f' and expected ({sharded_tensor.global_shape}) tensor'
f' for key {sharded_tensor.key}'
)
raise CheckpointingException(_msg)
return x
def open_ts_array(arr_path: Path):
"""Opens a Zarr file array with Tensorstore with basic setting.
Args:
arr_path (Path): path to a Zarr (Tensorstore) array
"""
spec = {'driver': 'zarr', 'metadata_key': '.zarray', 'kvstore': {}}
spec['kvstore'] = {
'driver': 'file',
'path': str(arr_path),
}
try:
arr = ts.open(ts.Spec(spec), open=True).result()
except Exception as e:
raise CheckpointingException(f'Array {arr_path} could not be loaded. Error: {e}') from e
return arr
default_strategies[StrategyAction.LOAD_SHARDED.value][
('zarr', 1)
] = TensorStoreLoadShardedStrategy()
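# Illustrative sketch (not part of the original module): reading a slab of a Zarr array
# via TensorStore, mirroring what `_load_regular_chunk` does for a regular shard.
# `checkpoint_dir` and `key` are hypothetical; the slice assumes a 2D array.
def _example_read_zarr_slab(checkpoint_dir: Path, key: str):
    arr = open_ts_array(checkpoint_dir / key)
    # TensorStore reads are asynchronous: `.read()` returns a future-like object.
    return arr[0:2, :].read().result()  # numpy array with the requested slab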
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Strategies using PyTorch distributed.checkpoint as an underlying format. """
import dataclasses
import io
import itertools
from collections import defaultdict
from logging import getLogger
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, cast
import torch
from torch.distributed import checkpoint
from torch.distributed._shard.metadata import ShardMetadata
from torch.distributed._shard.sharded_tensor import Shard, ShardedTensorMetadata, TensorProperties
from torch.distributed._sharded_tensor import ShardedTensor as TorchShardedTensor
from torch.distributed.checkpoint import (
DefaultLoadPlanner,
DefaultSavePlanner,
FileSystemReader,
LoadPlan,
SavePlan,
TensorStorageMetadata,
WriteItem,
)
from torch.distributed.checkpoint._nested_dict import FLATTEN_MAPPING, unflatten_state_dict
from torch.distributed.checkpoint._traverse import OBJ_PATH, traverse_state_dict
from torch.distributed.checkpoint.default_planner import create_default_local_save_plan
from torch.distributed.checkpoint.planner_helpers import _create_write_items
from ..core import CheckpointingException
from ..dict_utils import nested_values
from ..mapping import (
ShardedBase,
ShardedObject,
ShardedStateDict,
ShardedTensor,
StateDict,
is_main_replica,
)
from .base import LoadShardedStrategy, SaveShardedStrategy, StrategyAction, default_strategies
from .filesystem_async import FileSystemWriterAsync
from .state_dict_saver import save_state_dict_async_finalize, save_state_dict_async_plan
_import_trigger = None
logger = getLogger(__name__)
def flatten_state_dict(
state_dict: ShardedStateDict,
) -> Tuple[ShardedStateDict, Dict[str, OBJ_PATH]]:
""" Flattens state dict into a single level dict.
It's a copy of torch.distributed.checkpoint._nested_dict.flatten_state_dict
which also accepts ShardedBase instances as terminal objects.
Args:
state_dict (ShardedStateDict): state dict to be flattened
Returns (tuple): flattened state dict and a mapping allowing to recreate the original one
"""
flattened = {}
mappings = {}
def flat_copy(path: OBJ_PATH, value: Any) -> None:
new_fqn = ".".join(map(str, path))
if new_fqn in flattened:
raise ValueError(f"duplicated flatten key {new_fqn}")
flattened[new_fqn] = value
mappings[new_fqn] = path
traverse_state_dict(state_dict, flat_copy, lambda x: isinstance(x, (torch.Tensor, ShardedBase)))
return flattened, mappings
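# Illustrative sketch (not part of the original module): nested containers are flattened
# into dot-joined keys, with list indices becoming key components. The tensors are
# placeholders; real inputs hold ShardedBase objects or torch.Tensors.
def _example_flatten_state_dict():
    nested = {'model': {'layers': [torch.zeros(1), torch.ones(1)]}}
    flat, mapping = flatten_state_dict(nested)
    # flat keys:  'model.layers.0', 'model.layers.1'
    # mapping:    {'model.layers.0': ('model', 'layers', 0), ...}
    return flat, mapping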
def sharded_tensor_to_torch_sharded_tensor(
sh_tens: List[ShardedTensor], rank: Optional[int] = None
) -> TorchShardedTensor:
"""Convert MCore ShardedTensor to PyT ShardedTensor. PyT requires information about all chunks.
NOTE: this function assumes regular (grid) sharding of the MCore ShardedTensor.
The only local irregularities could be introduced with a `flattened_range` attribute.
NOTE: `flattened_range` is currently supported only for 1D tensors.
This function follows the logic of torch.distributed.fsdp._shard_utils._create_chunk_sharded_tensor.
Additionally, it saves `prepend_axis_num` (specific to MCore) as an attribute
for further restoration in `_unwrap_pyt_sharded_tensor`.
Args:
sh_tens (List[ShardedTensor]): list of sharded tensors to convert
rank (int, optional): current process rank passed to PyT ShardedTensor.
If None, assumes rank in the default pg.
Returns (TorchShardedTensor): PyT ShardedTensor containing all passed shards.
"""
if rank is None:
rank = torch.distributed.get_rank()
some_sh_ten = sh_tens[0]
has_flattened_range = some_sh_ten.flattened_range is not None
prepend_axis_num = sh_tens[0].prepend_axis_num
# Determine local shards
if has_flattened_range:
if prepend_axis_num:
raise NotImplementedError(
'`prepend_axis_num` attribute of ShardedTensor is not supported'
' together with `flattened_range` for PyT Distributed format'
)
for sh_ten in sh_tens:
assert sh_ten.flattened_range is not None
assert len(sh_ten.global_offset) == 1, sh_ten
local_shards = [
Shard.from_tensor_and_offsets(
sh_ten.data, [sh_ten.global_offset[0] + sh_ten.flattened_range.start], rank
)
for sh_ten in sh_tens
]
offsets_shape = some_sh_ten.local_shape # used to determine local offsets
else:
# Apply extra axes `prepend_axis_num` with a view
for sh_ten in sh_tens:
assert sh_ten.flattened_range is None, sh_ten.flattened_range
if prepend_axis_num:
sh_ten.data = sh_ten.data.view((1,) * prepend_axis_num + sh_ten.local_shape)
local_shards = [
Shard.from_tensor_and_offsets(sh_ten.data, list(sh_ten.global_offset), rank)
for sh_ten in sh_tens
]
offsets_shape = some_sh_ten.data.shape # includes prepended axes
local_global_offsets = {}
for sh_ten in sh_tens:
local_global_offsets.setdefault(sh_ten.global_offset, []).append(sh_ten)
# Create a ShardedTensor without invoking communication. Determine global shards
shard_metadata = []
# NOTE: here we assume a regular grid of shards
for fragment_offsets in itertools.product(*map(range, some_sh_ten.axis_fragmentations)):
offset = tuple(map(lambda x: x[0] * x[1], zip(fragment_offsets, offsets_shape)))
if offset in local_global_offsets:
# local shard
placement = f"rank:{rank}/cuda"
for sh_ten in local_global_offsets[offset]:
if has_flattened_range:
offset = (sh_ten.global_offset[0] + sh_ten.flattened_range.start,)
size = sh_ten.data.shape
shard_metadata.append(ShardMetadata(offset, size, placement))
else:
# for shards from other ranks we provide simplistic data - this information will be discarded
# during TorchShardedTensor._init_from_local_shards_and_global_metadata call
shard_metadata.append(ShardMetadata(offset, offsets_shape, "cuda"))
tensor = some_sh_ten.data
sharded_tensor_metadata = ShardedTensorMetadata(
shards_metadata=shard_metadata,
size=torch.Size(some_sh_ten.global_shape),
tensor_properties=TensorProperties(
dtype=tensor.dtype,
layout=tensor.layout,
requires_grad=tensor.requires_grad,
memory_format=torch.contiguous_format,
pin_memory=tensor.is_pinned(),
),
)
pyt_sh_ten = TorchShardedTensor._init_from_local_shards_and_global_metadata(
local_shards, sharded_tensor_metadata=sharded_tensor_metadata, process_group=None
)
pyt_sh_ten.prepend_axis_num = prepend_axis_num
return pyt_sh_ten
def mcore_to_pyt_state_dict(
state_dict: Dict[str, List[ShardedBase]],
is_loading: bool = False,
init_device: torch.device = torch.device("cpu"),
) -> Dict[str, Union[TorchShardedTensor, io.BytesIO]]:
"""Turn state dict with ShardedTensors and ShardedObjects to state dict compatible with PyT Dist format.
Operates in-place and returns the original state dict.
Args:
state_dict (Dict[str, List[ShardedBase]]): flattened state dict, where values
are lists of either ShardedTensor or ShardedObjects.
is_loading (bool, optional): flag indicating if loading or saving. Defaults to False.
init_device (torch.device, optional): device to initialize potentially missing tensors
during loading. Defaults to 'cpu'.
Returns (Dict[str, Union[TorchShardedTensor, io.BytesIO]]): original dictionary with values
converted either into PyT ShardedTensors or io.BytesIO.
"""
rank = torch.distributed.get_rank()
pyt_state_dict = {}
def _mcore_to_torch_sharded_tensor(sh_tens: List[ShardedTensor]) -> TorchShardedTensor:
"""Build a PyT ShardedTensor from given shards.
During loading:
- if data is None, initialize it with an empty tensor (will be used to copy the data into)
- if `allow_shape_mismatch` is True, the data is initialized with zeros
prior to loading (not all parts of the tensor will be read from the checkpoint)
"""
assert all(isinstance(sh_ten, ShardedTensor) for sh_ten in sh_tens), sh_tens
for sh_ten in sh_tens:
if sh_ten.data is None:
if is_loading:
sh_ten.init_data(
init_device,
init_fn=torch.zeros if sh_ten.allow_shape_mismatch else torch.empty,
)
else:
raise CheckpointingException(f'`data` attr is None for {sh_ten}')
else:
sh_ten.data = sh_ten.data.detach()
if sh_ten.allow_shape_mismatch and is_loading:
sh_ten.data.zero_()
torch_sh_ten = sharded_tensor_to_torch_sharded_tensor(sh_tens, rank)
torch_sh_ten.key = sh_tens[0].key
return torch_sh_ten
def _mcore_to_torch_sharded_object(sh_objs: List[ShardedObject]) -> io.BytesIO:
"""Build io.BytesIO from given sharded objects data."""
assert all(isinstance(sh_obj, ShardedObject) for sh_obj in sh_objs), sh_objs
serialized_data = io.BytesIO()
torch.save([sh_obj.data for sh_obj in sh_objs], serialized_data)
return serialized_data
for k, v in state_dict.items():
if isinstance(v[0], ShardedTensor):
v = cast(List[ShardedTensor], v)
pyt_state_dict[k] = _mcore_to_torch_sharded_tensor(v)
else:
v = cast(List[ShardedObject], v)
pyt_state_dict[k] = _mcore_to_torch_sharded_object(v)
return pyt_state_dict
def _unwrap_pyt_sharded_tensor(sh_ten: TorchShardedTensor) -> List[torch.Tensor]:
""" Unwrap tensor from PyT ShardedTensor instance.
If `prepend_axis_num` was non-zero (which is specific to MCore ShardedTensor)
then the tensor has additional singleton dimensions which should be squeezed.
"""
prepend_axis_num = getattr(sh_ten, 'prepend_axis_num', 0)
if prepend_axis_num == 0:
return [sh.tensor for sh in sh_ten.local_shards()]
ret_tensors = []
for sh in sh_ten.local_shards():
ten = sh.tensor
for _ in range(prepend_axis_num):
ten = ten.squeeze(0)
ret_tensors.append(ten)
return ret_tensors
def _replace_state_dict_keys_with_sharded_keys(
sharded_state_dict: ShardedStateDict, keep_only_main_replica: bool = False
) -> Tuple[Dict[str, List[ShardedBase]], FLATTEN_MAPPING, Dict[str, List[str]]]:
"""Group ShardedBase objects by keys and return mappings required for recreating the original dict. """
flat_sd, flat_mapping = flatten_state_dict(sharded_state_dict)
rename_mapping = defaultdict(list)
new_flat_sd = defaultdict(list)
for k, sh_base in flat_sd.items():
assert isinstance(sh_base, ShardedBase), type(sh_base)
key = sh_base.unique_key if isinstance(sh_base, ShardedObject) else sh_base.key
if is_main_replica(sh_base.replica_id) or not keep_only_main_replica:
rename_mapping[key].append(k)
new_flat_sd[key].append(sh_base)
return new_flat_sd, flat_mapping, rename_mapping
def _replace_sharded_keys_with_state_dict_keys(
state_dict: Dict[str, List[Union[torch.Tensor, io.BytesIO]]],
flat_mapping: FLATTEN_MAPPING,
rename_mapping: Dict[str, List[str]],
):
""" Inverse of _replace_state_dict_keys_with_sharded_keys. """
recovered_sd = {}
for k, tensors in state_dict.items():
assert len(tensors) == len(rename_mapping[k])
for ten, recovered_k in zip(tensors, rename_mapping[k]):
recovered_sd[recovered_k] = ten
return unflatten_state_dict(recovered_sd, flat_mapping)
def _restore_dict_types(x: Union[dict, list, Any], keys_template: Union[dict, list, Any]):
""" Recursively update `x` keys, based on `keys_template`. """
if isinstance(keys_template, dict):
assert isinstance(x, dict), type(x)
for k, v in keys_template.items():
if not isinstance(k, str):
assert str(k) in x, (k, x.keys)
x[k] = x.pop(str(k))
_restore_dict_types(x[k], v)
elif isinstance(keys_template, list):
assert isinstance(x, list), type(x)
for x_val, templ_val in zip(x, keys_template):
_restore_dict_types(x_val, templ_val)
class MCoreSavePlanner(DefaultSavePlanner):
"""Differs with the default planner by saving BytesIO objects on all ranks.
In the integration of MCore with PyT Distributed format, BytesIO objects
come from ShardedObjects, which should be treated as separate objects on each rank
(not common on all ranks).
Also, the objects are already packed in io.BytesIO, so no need to redo it
in transform_object.
"""
def create_local_plan(self) -> SavePlan:
plan = create_default_local_save_plan(self.state_dict, self.is_coordinator)
self._add_non_coordinator_iobytes_request(plan)
if self.flatten_state_dict:
plan = dataclasses.replace(plan, planner_data=self.mappings)
self.plan = plan
return self.plan
def _add_non_coordinator_iobytes_request(self, plan):
if self.is_coordinator:
return
for fqn, obj in self.state_dict.items():
if isinstance(obj, io.BytesIO):
plan.items.extend(_create_write_items(fqn, obj))
def transform_object(self, write_item: WriteItem, object: Any):
return object
class MCoreLoadPlanner(DefaultLoadPlanner):
"""Adds global shape validation to the default planner.
If global shape validation can be ignored (it shouldn't be!), the default
load planner can be used.
"""
def __init__(
self, *args, shapes_validation_sharded_tensors: Iterable[ShardedTensor] = (), **kwargs
) -> None:
super().__init__(*args, **kwargs)
self.shapes_validation_sharded_tensors = shapes_validation_sharded_tensors
def _validate_global_shapes(self, metadata, sharded_tensors):
for sh_ten in sharded_tensors:
loaded_shape = metadata.state_dict_metadata[sh_ten.key].size
if loaded_shape != sh_ten.global_shape:
_msg = (
f'Global shape mismatch for loaded ({loaded_shape})'
f' and expected ({sh_ten.global_shape}) tensor'
f' for key {sh_ten.key}'
)
raise CheckpointingException(_msg)
def create_local_plan(self) -> LoadPlan:
self._validate_global_shapes(self.metadata, self.shapes_validation_sharded_tensors)
return super().create_local_plan()
class TorchDistSaveShardedStrategy(SaveShardedStrategy):
"""Basic save strategy for the PyT Distributed format.
The idea is to translate MCore ShardedTensors into PyT ShardedTensors
and reuse the default torch.distributed.checkpoint saving mechanism.
"""
def __init__(
self, backend: str, version: int, keep_only_main_replica: bool = True, thread_count: int = 2
):
"""Adds parameters specific to PyT Distributed format
Args:
backend (str): format backend string
version (int): format version
keep_only_main_replica (bool, optional): PyT Distributed has a mechanism
for deduplication, but replica_id-aware deduplication is more coherent.
Default is True (recommended to keep it).
thread_count (int, optional): threads to use during saving.
Affects the number of files in the checkpoint (saving ranks * num_threads).
"""
super().__init__(backend, version)
self.keep_only_main_replica = keep_only_main_replica
self.thread_count = thread_count
# Intermediate state
self.save_state_dict_ret: Optional[Tuple[Any, ...]] = None
def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
""" Translates MCore ShardedTensors to PyT ShardedTensors and saves in PyT Distributed format.
Args:
sharded_state_dict (ShardedStateDict): sharded state dict to save
checkpoint_dir (Path): checkpoint directory
Returns: None
"""
# Translate the state dict
(
sharded_state_dict,
flat_mapping,
rename_mapping,
) = _replace_state_dict_keys_with_sharded_keys(
sharded_state_dict, self.keep_only_main_replica
)
pyt_state_dict = mcore_to_pyt_state_dict(sharded_state_dict, False)
# Using async infrastructure for sync save
writer = FileSystemWriterAsync(checkpoint_dir, thread_count=self.thread_count)
self.save_state_dict_ret = save_state_dict_async_plan(
pyt_state_dict,
writer,
None,
planner=MCoreSavePlanner(dedup_replicated_tensors=not self.keep_only_main_replica),
)
fun_args = writer.get_save_function_and_args()
if fun_args is not None:
fun, args = fun_args
fun(*args)
self._finalize_save()
def _finalize_save(self) -> None:
""" Perform save finalization.
The breakdown into `save` and `save_finalize` can be useful for async saving.
"""
if self.save_state_dict_ret is None:
raise CheckpointingException('finalize_save called, but no ckpt save in progress')
save_state_dict_async_finalize(*self.save_state_dict_ret)
self.save_state_dict_ret = None
torch.distributed.barrier()
def can_handle_sharded_objects(self):
return True
class TorchDistLoadShardedStrategy(LoadShardedStrategy):
"""Basic load strategy for the PyT Distributed format. """
def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> StateDict:
"""Translates MCore ShardedTensors to PyT ShardedTensors and loads from PyT Distributed format.
Args:
sharded_state_dict (ShardedStateDict): sharded state dict with mapping
information to instruct loading
checkpoint_dir (Path): checkpoint directory
Returns: loaded state dict
"""
flexible_shape_sharded_tensors = [
sh_ten
for sh_ten in nested_values(sharded_state_dict)
if isinstance(sh_ten, ShardedTensor) and not sh_ten.allow_shape_mismatch
]
orig_sharded_state_dict = sharded_state_dict
# MCore state dict to PyT Distributed compatible
(
sharded_state_dict,
flat_mapping,
rename_mapping,
) = _replace_state_dict_keys_with_sharded_keys(sharded_state_dict)
pyt_state_dict = mcore_to_pyt_state_dict(sharded_state_dict, True)
# Load PyT Distributed format
checkpoint.load_state_dict(
pyt_state_dict,
FileSystemReader(checkpoint_dir),
planner=MCoreLoadPlanner(
shapes_validation_sharded_tensors=flexible_shape_sharded_tensors
),
)
pyt_state_dict = cast(
Dict[str, Union[TorchShardedTensor, List[io.BytesIO]]], pyt_state_dict
)
# Unwrap ShardedTensors and return to original state dict
mcore_state_dict = {
k: v if not isinstance(v, TorchShardedTensor) else _unwrap_pyt_sharded_tensor(v)
for k, v in pyt_state_dict.items()
}
mcore_state_dict = _replace_sharded_keys_with_state_dict_keys(
mcore_state_dict, flat_mapping, rename_mapping
)
_restore_dict_types(mcore_state_dict, orig_sharded_state_dict)
return mcore_state_dict
def load_tensors_metadata(self, checkpoint_dir: Path):
"""Uses tensors metadata stored in the metadata file."""
fs_reader = FileSystemReader(checkpoint_dir)
metadata = fs_reader.read_metadata()
return {
k: ShardedTensor.from_rank_offsets(
k, torch.empty(tp.size, **tp.properties.__dict__, device='meta')
).without_data()
for k, tp in metadata.state_dict_metadata.items()
if isinstance(tp, TensorStorageMetadata)
}
def can_handle_sharded_objects(self):
return True
def check_backend_compatibility(self, loaded_version):
pass # TODO
def check_version_compatibility(self, loaded_version):
pass # TODO
default_strategies[StrategyAction.LOAD_SHARDED.value][
('torch_dist', 1)
] = TorchDistLoadShardedStrategy()
default_strategies[StrategyAction.SAVE_SHARDED.value][
('torch_dist', 1)
] = TorchDistSaveShardedStrategy('torch_dist', 1)
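# Illustrative usage sketch (not part of the original module) of the strategies registered
# above. Assumes an initialized torch.distributed process group and an MCore sharded state
# dict; in practice a fresh sharded state dict (describing the tensors to load into) is
# built before calling `load`.
def _example_torch_dist_roundtrip(sharded_state_dict, checkpoint_dir: Path):
    save_strategy = TorchDistSaveShardedStrategy('torch_dist', 1, thread_count=2)
    save_strategy.save(sharded_state_dict, checkpoint_dir)
    load_strategy = TorchDistLoadShardedStrategy()
    return load_strategy.load(sharded_state_dict, checkpoint_dir)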
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" 2-stage checkpoint loading. """
import os
import time
from collections import defaultdict
from dataclasses import dataclass
from functools import partial, wraps
from itertools import chain
from logging import DEBUG, INFO, StreamHandler, getLogger
from operator import attrgetter, itemgetter
from pathlib import Path
from typing import Iterable, List, NamedTuple, Optional, Tuple, Union
import torch
from ..dict_utils import dict_list_map_inplace, map_reduce, nested_values
from ..mapping import ShardedStateDict, ShardedTensor, StateDict
from .base import LoadShardedStrategy
from .tensorstore import TensorStoreLoadShardedStrategy, _load_from_array, open_ts_array
from .zarr import flatten_range, load_zarr_based_sharded_metadata
_import_trigger = None
timers = defaultdict(list)
logger = getLogger(__name__)
def timed(verbose=True):
def timed_dec(fn):
name = fn.__name__
@wraps(fn)
def wrapped(*args, **kwargs):
if verbose:
logger.debug(f'{name} init')
start = time.time()
ret = fn(*args, **kwargs)
took = time.time() - start
if verbose:
logger.debug(f'{name} took {took}s')
timers[name].append(took)
return ret
return wrapped
return timed_dec
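# Illustrative sketch (not part of the original module): any function wrapped with `timed`
# has its call durations appended to the module-level `timers` dict under its name.
@timed(verbose=False)
def _example_timed_noop():
    time.sleep(0.01)
# After calling _example_timed_noop(), timers['_example_timed_noop'] holds the durations.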
@dataclass
class _ShardedTensorMetadata:
global_rank: int
sharded_tensor_no_data: ShardedTensor
dist_group_rank: Tuple[int] # id of distributed group
dist_group_ranks: Tuple[int] # id of distributed group
data_size: Optional[int] = None # bytes
def sharded_tensor_chunk_id(sharded_tensor: ShardedTensor):
return (
sharded_tensor.key,
sharded_tensor.global_offset,
)
class TwoStageDataParallelLoadShardedStrategy(LoadShardedStrategy):
"""Loads one checkpoint replica from storage and broadcasts to other nodes.
This strategy loads the checkpoint from storage on a minimal set of nodes
and distributes the checkpoint to other nodes with torch.distributed.
Loading is performed with tensorstore.
Steps:
0. (optional) create Gloo distributed groups
1. Exchange ShardedTensors metadata between all nodes
2. Align needed tensors within DP groups
3. For each globally unique tensor:
3.a) on one of the ranks load it from storage to CPU and move to CUDA
3.b) allocate CUDA tensor on other ranks
3.c) broadcast within DP group
3.d) copy tensor content to the model param location
3.e) free tensor buffers from a) and b)
Notes:
1. Loading and broadcasting are done sequentially to avoid both host and device OOMs
2. There is a lot of overlap potential between all three steps done for each tensor:
2.a) loading from storage to numpy
2.b) moving CPU tensors to CUDA
2.c) broadcast
"""
def __init__(self, data_parallel_group, cpu_transfer=True):
super().__init__()
self.cpu_transfer = cpu_transfer
self.data_parallel_group_orig = data_parallel_group
self.data_parallel_group = None if cpu_transfer else data_parallel_group
self.dp_group_ranks = tuple(
sorted(torch.distributed.get_process_group_ranks(data_parallel_group))
)
self.dp_group_rank = torch.distributed.get_rank(self.data_parallel_group_orig)
self.global_rank = torch.distributed.get_rank()
def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
self.maybe_init_gloo_group()
all_tensors_sorted = self._build_load_plan(sharded_state_dict)
self._exchange_loaded_tensors(all_tensors_sorted, sharded_state_dict, checkpoint_dir)
# TODO: fix hang in summarize_load_times
# self.summarize_load_times()
return sharded_state_dict
def summarize_load_times(self):
torch.distributed.barrier()
logger.info('Checkpoint loading finished. Summary:')
# TODO: `timers` keys are not guaranteed to be the same across ranks which causes hangs
for key, times in sorted(timers.items()):
times_sum = sum(times)
max_times = torch.tensor([times_sum], device='cuda')
avg_times = torch.tensor([times_sum], device='cuda')
torch.distributed.all_reduce(max_times, op=torch.distributed.ReduceOp.MAX)
torch.distributed.all_reduce(avg_times, op=torch.distributed.ReduceOp.SUM)
avg_times /= torch.distributed.get_world_size()
if torch.distributed.get_rank() == 0:
logger.info(f'{key}: max {max_times[0]}, avg {avg_times[0]}')
@timed(verbose=False)
def load_tensor_from_storage(self, checkpoint_dir, ten_meta: _ShardedTensorMetadata):
logger.debug(f'_load_from_array({ten_meta.sharded_tensor_no_data.key}) init')
ret = _load_from_array(
ten_meta.sharded_tensor_no_data,
checkpoint_dir,
load_directly_on_device=False,
apply_flattened_range=False,
)
logger.debug(f'_load_from_array({ten_meta.sharded_tensor_no_data.key}) DONE')
return ret
@timed()
def maybe_init_gloo_group(self):
if not self.cpu_transfer:
return
all_groups = [None] * torch.distributed.get_world_size()
torch.distributed.all_gather_object(all_groups, self.dp_group_ranks)
all_groups = set(tuple(sorted(gr)) for gr in all_groups)
for group_ranks in sorted(all_groups):
gloo_pg = torch.distributed.new_group(ranks=group_ranks, backend='gloo')
if self.global_rank in group_ranks:
self.data_parallel_group = gloo_pg
assert self.dp_group_rank == torch.distributed.get_rank(self.data_parallel_group)
def check_backend_compatibility(self, loaded_version):
pass # TODO
def check_version_compatibility(self, loaded_version):
pass # TODO
@timed()
def _build_load_plan(
self, sharded_state_dict: ShardedStateDict
) -> List[_ShardedTensorMetadata]:
local_meta = [
_ShardedTensorMetadata(
self.global_rank,
sharded_ten.without_data(),
self.dp_group_rank,
self.dp_group_ranks,
)
for sharded_ten in nested_values(sharded_state_dict)
]
all_meta = [None] * torch.distributed.get_world_size(group=self.data_parallel_group)
torch.distributed.all_gather_object(all_meta, local_meta, group=self.data_parallel_group)
all_meta = list(chain.from_iterable(all_meta))
all_tensors_sorted = self.deduplicate_chunks(all_meta)
return all_tensors_sorted
@timed()
def deduplicate_chunks(self, ten_metas: List[_ShardedTensorMetadata]):
""" Group tensors by chunk and then pick the tensor with the lowest rank.
NOTE: with proper loading overlap, loading from randomized ranks
(instead of the smallest one) could be beneficial here.
"""
ten_metas = map_reduce(
ten_metas,
key_fn=lambda meta: sharded_tensor_chunk_id(meta.sharded_tensor_no_data),
reduce_fn=partial(min, key=attrgetter('dist_group_rank')),
)
all_metas_sorted = list(map(itemgetter(1), sorted(ten_metas.items())))
return all_metas_sorted
@timed()
def _exchange_loaded_tensors(
self, ten_metas: List[_ShardedTensorMetadata], sharded_state_dict, checkpoint_dir
):
logger.debug(f'_exchange_loaded_tensors, num ten_metas: {len(ten_metas)}')
for ten_meta in ten_metas:
src_rank = torch.distributed.get_global_rank(
self.data_parallel_group, ten_meta.dist_group_rank
)
if self.dp_group_rank == ten_meta.dist_group_rank:
exchange_tensor = self.load_tensor_from_storage(checkpoint_dir, ten_meta)
if not self.cpu_transfer:
exchange_tensor = exchange_tensor.cuda()
else:
# TODO: for non-flattened ranges we could reuse the buffer from the start here
exchange_tensor = torch.empty(
ten_meta.sharded_tensor_no_data.local_shape,
device='cpu' if self.cpu_transfer else 'cuda',
dtype=ten_meta.sharded_tensor_no_data.dtype,
)
logger.debug(
f'exchange {ten_meta.sharded_tensor_no_data.key}, {exchange_tensor.shape}({exchange_tensor.numel()}), broadcast({src_rank} -> {self.dp_group_ranks})'
)
torch.distributed.broadcast(
exchange_tensor, group=self.data_parallel_group, src=src_rank
)
self._distribute_data_to_state_dict(ten_meta, exchange_tensor, sharded_state_dict)
logger.debug(f'exchange {ten_meta.sharded_tensor_no_data.key} done')
# free buffer memory
exchange_tensor = None
@timed(verbose=False)
def _distribute_data_to_state_dict(
self,
ten_meta: _ShardedTensorMetadata,
loaded_ten: torch.Tensor,
sharded_state_dict: ShardedStateDict,
):
tensor_key = sharded_tensor_chunk_id(ten_meta.sharded_tensor_no_data)
def _fill_in_data(t: Union[ShardedTensor, torch.Tensor]):
if not isinstance(t, ShardedTensor) or sharded_tensor_chunk_id(t) != tensor_key:
# already filled-in or key not matching
return t
sharded_tensor: ShardedTensor = t
x = loaded_ten
if sharded_tensor.flattened_range is not None:
x = flatten_range(sharded_tensor, x)
# Reuse existing buffer
sharded_tensor.data.data.copy_(x)
return sharded_tensor.data
dict_list_map_inplace(_fill_in_data, sharded_state_dict)
def load_tensors_metadata(self, checkpoint_dir: Path):
def get_ts_shape_dtype(path):
arr = open_ts_array(path)
return arr.shape, arr.dtype.numpy_dtype
return load_zarr_based_sharded_metadata(checkpoint_dir, get_ts_shape_dtype)
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Strategies using Zarr as an underlying format. """
import logging
import os
import threading
from functools import partial
from logging import getLogger
from pathlib import Path
from typing import Callable, List, Optional, Tuple
import numpy as np
import torch
import zarr
from ..core import CheckpointingException
from ..dict_utils import dict_list_map_inplace, nested_values
from ..mapping import ShardedStateDict, ShardedTensor, is_main_replica
from .base import LoadShardedStrategy, SaveShardedStrategy, StrategyAction, default_strategies
logger = logging.getLogger(__name__)
numpy_to_torch_dtype_dict = {
np.dtype('bool'): torch.bool,
np.dtype('uint8'): torch.uint8,
np.dtype('int8'): torch.int8,
np.dtype('int16'): torch.int16,
np.dtype('int32'): torch.int32,
np.dtype('int64'): torch.int64,
np.dtype('float16'): torch.float16,
np.dtype('float32'): torch.float32,
np.dtype('float64'): torch.float64,
np.dtype('complex64'): torch.complex64,
np.dtype('complex128'): torch.complex128,
}
torch_to_numpy_dtype_dict = {v: k for k, v in numpy_to_torch_dtype_dict.items()}
try:
import tensorstore
HAS_BFLOAT16 = True
numpy_to_torch_dtype_dict[np.dtype('bfloat16')] = torch.bfloat16
torch_to_numpy_dtype_dict[torch.bfloat16] = np.dtype('bfloat16')
except ImportError:
HAS_BFLOAT16 = False
_import_trigger = None
logger = getLogger(__name__)
class ZarrSaveShardedStrategy(SaveShardedStrategy):
def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
sharded_tensors = list(nested_values(sharded_state_dict))
arrays = _create_or_open_zarr_arrays(sharded_tensors, checkpoint_dir)
for ten, arr in zip(sharded_tensors, arrays):
_save_to_existing_array(ten, arr)
torch.distributed.barrier()
def _create_or_open_zarr_arrays(
sharded_tensors: List[ShardedTensor], checkpoint_dir: Path
) -> List[Optional[zarr.Array]]:
""" Returns list of zarr arrays corresponding to given tensors.
For a sharded tensors that:
a) is main replica and represents the first chunk (all offsets 0), creates the Zarr array
b) is main replica but not the first chunk, opens the arrays created in (a) (possibly by other process)
c) otherwise, sets the corresponding array to None since it won't be used
Args:
sharded_tensors (List[ShardedTensor]): sharded tensors from a given rank that will be saved to checkpoint
checkpoint_dir (Path): checkpoint in which the arrays will be created
"""
arrays = []
for ten in sharded_tensors:
arr = _create_zarr_array(ten, checkpoint_dir) if _should_create_array(ten) else None
arrays.append(arr)
torch.distributed.barrier()
# Open arrays created above by other processes
for arr_idx, ten in enumerate(sharded_tensors):
if arrays[arr_idx] is not None:
# array created by this process
assert _should_create_array(ten), ten
continue
if not is_main_replica(ten.replica_id):
# this array won't be needed for saving and can stay None
continue
open_kwargs = {}
if ten.flattened_range is not None:
open_kwargs['synchronizer'] = zarr.ProcessSynchronizer(
str(checkpoint_dir / f'{ten.key}.sync')
)
arrays[arr_idx] = _open_zarr_array_verbose(checkpoint_dir / ten.key, 'r+', **open_kwargs)
return arrays
def _should_create_array(ten: ShardedTensor):
return (
is_main_replica(ten.replica_id)
and set(ten.global_offset) == {0}
and (ten.flattened_range is None or ten.flattened_range.start == 0)
)
def _save_to_existing_array(sharded_tensor: ShardedTensor, arr: Optional[zarr.Array]):
if not is_main_replica(sharded_tensor.replica_id):
return
assert arr is not None
x = sharded_tensor.data
x = x.detach().cpu()
torch.cuda.synchronize()
if x.dtype == torch.bfloat16:
x = x.float()
x = x.numpy()
x = x.astype('bfloat16')
else:
x = x.numpy()
if sharded_tensor.flattened_range is None:
arr[sharded_tensor.global_slice()] = x
else:
arr.set_coordinate_selection(sharded_tensor.global_coordinates(), x)
def _create_zarr_array(sharded_tensor: ShardedTensor, checkpoint_dir: Path):
np_dtype = torch_to_numpy_dtype_dict[sharded_tensor.dtype]
try:
arr = zarr.create(
sharded_tensor.global_shape,
dtype=np_dtype,
store=checkpoint_dir / sharded_tensor.key,
chunks=sharded_tensor.max_allowed_chunks(),
compressor=None,
fill_value=None,
write_empty_chunks=True,
)
logger.debug(f'Created a new Zarr array at {checkpoint_dir / sharded_tensor.key}')
except zarr.errors.ContainsArrayError as e:
raise CheckpointingException(
f'Array {checkpoint_dir / sharded_tensor.key} already exists'
) from e
if HAS_BFLOAT16 and np_dtype == np.dtype('bfloat16'):
arr._dtype = np_dtype
zarray = arr.store['.zarray']
arr.store['.zarray'] = zarray.replace(b'<V2', b'bfloat16')
return arr
class ZarrLoadShardedStrategy(LoadShardedStrategy):
def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
dict_list_map_inplace(
partial(_load_from_array, checkpoint_dir=checkpoint_dir), sharded_state_dict
)
return sharded_state_dict
def load_tensors_metadata(self, checkpoint_dir: Path):
def get_zarr_shape_dtype(path):
arr = zarr.open(path, 'r')
return arr.shape, arr.dtype
return load_zarr_based_sharded_metadata(checkpoint_dir, get_zarr_shape_dtype)
def check_backend_compatibility(self, loaded_version):
pass # TODO
def check_version_compatibility(self, loaded_version):
pass # TODO
def _load_from_array(sharded_tensor: ShardedTensor, checkpoint_dir: Path):
assert isinstance(sharded_tensor, ShardedTensor), type(sharded_tensor)
arr = _open_zarr_array_verbose(checkpoint_dir / sharded_tensor.key, 'r')
if not sharded_tensor.allow_shape_mismatch and sharded_tensor.global_shape != arr.shape:
_msg = (
f'Global shape mismatch for loaded ({arr.shape})'
f' and expected ({sharded_tensor.global_shape}) tensor'
f' for key {sharded_tensor.key}'
)
raise CheckpointingException(_msg)
x = arr[sharded_tensor.global_slice()] # flattened tensors loading is delayed
return postprocess_numpy_array(x, sharded_tensor)
def _open_zarr_array_verbose(path: Path, mode: str, **open_kwargs):
try:
return zarr.open(str(path), mode, **open_kwargs)
except zarr.errors.PathNotFoundError as e:
ckpt_dir = path.parent
err_msg = f'Array {path} not found'
if ckpt_dir.exists():
ckpt_files = [f.name for f in ckpt_dir.iterdir()]
logger.debug(f'{err_msg}. Checkpoint directory {ckpt_dir} content: {ckpt_files}')
else:
err_msg += f'. Checkpoint directory {ckpt_dir} does not exist.'
raise CheckpointingException(err_msg) from e
def postprocess_numpy_array(loaded_array, sharded_tensor, apply_flattened_range=True):
x = loaded_array
if HAS_BFLOAT16 and x.dtype == np.dtype('bfloat16'):
x = x.astype(np.dtype('float32'))
x = torch.from_numpy(x)
x = x.bfloat16()
else:
x = torch.from_numpy(x)
# TODO: consider some other consistency checks
if x.shape != sharded_tensor.local_shape:
if sharded_tensor.allow_shape_mismatch:
x = pad_to_expected_shape(x, sharded_tensor)
else:
_msg = (
f'Local shape mismatch for loaded ({x.shape})'
f' and expected ({sharded_tensor.local_shape}) tensor'
f' for key {sharded_tensor.key}'
)
raise CheckpointingException(_msg)
if apply_flattened_range and sharded_tensor.flattened_range is not None:
x = flatten_range(sharded_tensor, x)
# TODO: consider cuda() tensors support
return x
def flatten_range(sharded_tensor, x):
return x.flatten()[sharded_tensor.flattened_range]
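# Illustrative sketch (not part of the original module): `flattened_range` is a slice into
# the flattened (row-major) local chunk, so a 2x3 tensor with flattened_range=slice(2, 5)
# yields its elements 2..4. The stand-in object below is hypothetical; real callers pass a
# ShardedTensor.
def _example_flatten_range():
    class _FakeShardedTensor:
        flattened_range = slice(2, 5)
    x = torch.arange(6).reshape(2, 3)
    return flatten_range(_FakeShardedTensor(), x)  # tensor([2, 3, 4])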
def pad_to_expected_shape(x: torch.Tensor, expected_sharded_ten: ShardedTensor):
pad_args = []
assert len(x.shape) == len(expected_sharded_ten.local_shape)
# Reversed iteration order because F.pad expects padding widths for the last dimension first
for x_sh, exp_sh, axis_fragm in reversed(
list(
zip(x.shape, expected_sharded_ten.local_shape, expected_sharded_ten.axis_fragmentations)
)
):
if x_sh == exp_sh:
pad_args.extend((0, 0))
elif x_sh > exp_sh:
assert (
False
), f'Expected shape ({exp_sh}) smaller than actual ({x_sh}) for {repr(expected_sharded_ten)}'
else:
pad_args.extend((0, exp_sh - x_sh))
# TODO: behavior control with envvar is for testing purposes only, remove it
if not int(os.environ.get('DIST_CKPT_PAD_REPLICATE', 0)):
return torch.nn.functional.pad(x, pad_args)
# unsqueeze and squeeze to get shapes supported by cudnn
print(f'Replicating last row for {expected_sharded_ten.key}')
if x.dtype == torch.bfloat16:
return (
torch.nn.functional.pad(x.float().unsqueeze(0), pad_args, mode='replicate')
.squeeze(0)
.bfloat16()
)
return torch.nn.functional.pad(x.unsqueeze(0), pad_args, mode='replicate').squeeze(0)
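# Illustrative sketch (not part of the original module): F.pad consumes padding widths for
# the *last* dimension first, which is why `pad_to_expected_shape` iterates the shapes in
# reverse. Padding a (2, 3) tensor to (2, 5) therefore needs pad_args == [0, 2, 0, 0].
def _example_pad_last_dim():
    x = torch.zeros(2, 3)
    padded = torch.nn.functional.pad(x, [0, 2, 0, 0])
    return padded.shape  # torch.Size([2, 5])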
def load_zarr_based_sharded_metadata(
checkpoint_dir: Path, get_shape_dtype_fn: Callable[[str], Tuple[Tuple[int], np.dtype]]
) -> ShardedStateDict:
"""Load metadata of Zarr arrays.
Args:
checkpoint_dir (Path): checkpoint root directory
get_shape_dtype_fn (str -> ((int, ...), np.dtype)): a function returning
an array shape and dtype for a given Zarr array path
"""
sharded_state_dict = {}
for subdir in checkpoint_dir.iterdir():
if not subdir.is_dir() or not (subdir / '.zarray').exists():
continue
key = subdir.name
arr_shape, arr_dtype = get_shape_dtype_fn(str(subdir))
sharded_state_dict[key] = ShardedTensor(
key,
None,
numpy_to_torch_dtype_dict[arr_dtype],
arr_shape,
arr_shape,
tuple(0 for _ in arr_shape),
tuple(1 for _ in arr_shape),
)
return sharded_state_dict
# default_strategies[StrategyAction.LOAD_SHARDED.value][('zarr', 1)] = ZarrLoadShardedStrategy()
default_strategies[StrategyAction.SAVE_SHARDED.value][('zarr', 1)] = ZarrSaveShardedStrategy(
'zarr', 1
)
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Helpers for manipulating sharded tensors and sharded state dicts. """
from typing import Dict, Tuple
from .dict_utils import dict_list_map_inplace, extract_matching_values
from .mapping import (
LocalNonpersitentObject,
ShardedBase,
ShardedObject,
ShardedStateDict,
ShardedTensor,
ShardedTensorFactory,
StateDict,
)
def extract_sharded_tensors(
sharded_state_dict: ShardedStateDict,
) -> Tuple[ShardedStateDict, StateDict]:
""" Extract a dict consisting of only ShardedTensor objects from a given state dict with any objects.
Args:
sharded_state_dict: state dict possibly containing ShardedTensor objects
Returns:
Tuple[ShardedStateDict, StateDict]: tuple of:
- state dict with all ShardedTensor (keeping the original state dict structure)
- state dict with all objects other than ShardedTensor (keeping the original state dict structure)
"""
return extract_matching_values(sharded_state_dict, lambda v: isinstance(v, ShardedTensor))
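# Illustrative sketch (not part of the original module): splitting a mixed state dict into
# its ShardedTensor part and everything else. The tensor construction is hypothetical and
# only serves the example; this module does not import torch at the top level.
def _example_extract_sharded_tensors():
    import torch

    sd = {'weight': ShardedTensor.from_rank_offsets('weight', torch.zeros(2)), 'step': 3}
    sharded_part, rest = extract_sharded_tensors(sd)
    # sharded_part == {'weight': ShardedTensor(key='weight', ...)}, rest == {'step': 3}
    return sharded_part, rest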
def extract_sharded_tensors_and_factories(
sharded_state_dict: ShardedStateDict,
) -> Tuple[ShardedStateDict, StateDict]:
""" Extract a dict consisting of only ShardedTensor and ShardedTensorFactory objects from a given state dict with any objects.
Args:
sharded_state_dict: state dict possibly containing ShardedTensor and ShardedTensorFactory objects
Returns:
Tuple[ShardedStateDict, StateDict]: tuple of:
- state dict with all ShardedTensor and ShardedTensorFactory (keeping the original state dict structure)
- state dict with all other objects (keeping the original state dict structure)
"""
return extract_matching_values(
sharded_state_dict, lambda v: isinstance(v, (ShardedTensor, ShardedTensorFactory))
)
def extract_sharded_tensors_or_nonpersistent(
sharded_state_dict: ShardedStateDict,
) -> Tuple[ShardedStateDict, StateDict]:
""" Extract a dict consisting of only ShardedTensor, ShardedTensorFactory and LocalNonpersitentObject
objects from a given state dict with any objects.
Args:
sharded_state_dict: state dict possibly containing ShardedTensor, ShardedTensorFactory and LocalNonpersitentObject objects
Returns:
Tuple[ShardedStateDict, StateDict]: tuple of:
- state dict with all ShardedTensor, ShardedTensorFactory and LocalNonpersitentObject (keeping the original state dict structure)
- state dict with all other objects (keeping the original state dict structure)
"""
return extract_matching_values(
sharded_state_dict,
lambda v: isinstance(v, (ShardedTensor, LocalNonpersitentObject, ShardedTensorFactory)),
)
def extract_sharded_base(
sharded_state_dict: ShardedStateDict,
) -> Tuple[ShardedStateDict, StateDict]:
return extract_matching_values(sharded_state_dict, lambda v: isinstance(v, ShardedBase),)
def extract_nonpersistent(
sharded_state_dict: ShardedStateDict,
) -> Tuple[ShardedStateDict, StateDict]:
return extract_matching_values(
sharded_state_dict, lambda v: isinstance(v, LocalNonpersitentObject),
)
def add_prefix_for_sharding(sharded_state_dict: ShardedStateDict, prefix: str):
""" Prepend a given prefix to all ShardedBase objects in a given state dict *in-place*.
Args:
sharded_state_dict (ShardedStateDict): sharded state dict
prefix (str): prefix to be prepended
Returns:
None: state dict is modified in-place
"""
def add_prefix(t):
if isinstance(t, ShardedBase):
t.key = f'{prefix}{t.key}'
return t
dict_list_map_inplace(add_prefix, sharded_state_dict)
def replace_prefix_for_sharding(
sharded_state_dict: ShardedStateDict, old_prefix: str, new_prefix: str
):
""" Replaces the given prefix in *all* sharded keys in a given state dict.
Errors out if some key does not begin with the given prefix.
Args:
sharded_state_dict (ShardedStateDict): sharded state dict to replace keys in
old_prefix (str): prefix to be replaced in each key
new_prefix (str): new prefix
Returns:
None: state dict is modified in place
"""
def _replace_prefix(x):
if isinstance(x, (ShardedTensor, ShardedTensorFactory, ShardedObject)):
if not x.key.startswith(old_prefix):
raise ValueError(f'Expected {x.key} to begin with prefix {old_prefix}')
x.key = f'{new_prefix}{x.key[len(old_prefix):]}' # str.removeprefix in Python >= 3.9
return x
dict_list_map_inplace(_replace_prefix, sharded_state_dict)
def apply_prefix_mapping(sharded_state_dict: ShardedStateDict, prefix_map: Dict[str, str]):
""" Replaces prefixes *only in keys matching* with one of prefixes in the map.
Args:
sharded_state_dict (ShardedStateDict): sharded state dict to replace keys in
prefix_map (Dict[str, str]): map of old->new prefixes. The first matching prefix for each key is used
Returns:
None: state dict is modified in place
"""
def _replace_prefixes(x):
if not isinstance(x, (ShardedTensor, ShardedTensorFactory, ShardedObject)):
return x
for old_prefix, new_prefix in prefix_map.items():
if x.key.startswith(old_prefix):
x.key = (
f'{new_prefix}{x.key[len(old_prefix):]}' # str.removeprefix in Python >= 3.9
)
break
return x
dict_list_map_inplace(_replace_prefixes, sharded_state_dict)
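# Illustrative sketch (not part of the original module): renaming sharded keys in place with
# `apply_prefix_mapping`. The tensor construction and prefixes are hypothetical.
def _example_apply_prefix_mapping():
    import torch

    sd = {'w': ShardedTensor.from_rank_offsets('module.w', torch.zeros(2))}
    apply_prefix_mapping(sd, {'module.': 'decoder.'})
    # sd['w'].key == 'decoder.w'
    return sd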
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from .distributed_data_parallel import DistributedDataParallel
from .distributed_data_parallel_config import DistributedDataParallelConfig
from .finalize_model_grads import finalize_model_grads
from .param_and_grad_buffer import ParamAndGradBuffer, shard_buffer
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from contextlib import contextmanager
from logging import getLogger
from typing import Dict, Optional
import torch
from .. import parallel_state
from ..transformer.module import MegatronModule
from ..transformer.transformer_config import TransformerConfig
from .distributed_data_parallel_config import DistributedDataParallelConfig
from .param_and_grad_buffer import ParamAndGradBuffer
logger = getLogger(__name__)
class DistributedDataParallel(MegatronModule):
"""
DDP wrapper which stores grads in contiguous buffers. Also has the option of overlapping
communication with backprop computation by breaking up the full model's gradients into smaller
buckets and running all-reduce / reduce-scatter on each bucket asynchronously. This class
also provides the option to do the gradient accumulation in a type other than the param type
(e.g., fp32 for a bf16 model).
Args:
config: Transformer config object.
ddp_config: DistributedDataParallel config object.
module: Underlying model.
data_parallel_group: Data-parallel process group.
expert_data_parallel_group: Optional data-parallel process group for experts in a MoE.
disable_bucketing: If true, force assign all parameters to a single bucket. If false,
use standard bucketing policy: assign parameters to smaller buckets and all-reduce
per bucket _if_ overlap_grad_reduce is True and pp_rank is 0.
check_for_nan_in_grad: If true, check if local grad norm is NaN.
"""
def __init__(
self,
config: TransformerConfig,
ddp_config: DistributedDataParallelConfig,
module: torch.nn.Module,
data_parallel_group: torch.distributed.ProcessGroup,
expert_data_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
disable_bucketing: bool = False,
):
super().__init__(config=config)
self.module = module
# If bucket_size is not provided as an input, use sane default.
# If using very large dp_sizes, make buckets larger to ensure that chunks used in NCCL
# ring-reduce implementations are large enough to remain bandwidth-bound rather than
# latency-bound.
if ddp_config.bucket_size is None:
dp_size = parallel_state.get_data_parallel_world_size()
ddp_config.bucket_size = max(40000000, 1000000 * dp_size)
# Set bucket_size to infinity if overlap_grad_reduce is False.
if not ddp_config.overlap_grad_reduce:
ddp_config.bucket_size = None
self.ddp_config = ddp_config
# Keep references to the process groups; they are needed later in broadcast_params().
self.data_parallel_group = data_parallel_group
self.expert_data_parallel_group = expert_data_parallel_group
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
logger.info(
f'Setting up DistributedDataParallel with {type(self.ddp_config).__name__}: {self.ddp_config}'
)
# Turn off bucketing if we are on a pipeline stage that is not the first (since
# data-parallel communication on these stages is not on the critical path), or if
# disable_bucketing is True (e.g., we might not want to break up model parameters
# into buckets for model chunks after the first in the interleaved schedule).
self.bucket_size = self.ddp_config.bucket_size
if parallel_state.get_pipeline_model_parallel_rank() > 0:
self.bucket_size = None
if disable_bucketing:
self.bucket_size = None
self.module = module
self.param_to_buffer = {}
# Group parameters by their gradient type.
param_to_name = {}
dense_params = []
expert_parallel_params = []
for name, param in self.module.named_parameters():
if not param.requires_grad:
continue
param.grad_added_to_main_grad = False
param_to_name[param] = name
if getattr(param, 'allreduce', True):
dense_params.append(param)
else:
expert_parallel_params.append(param)
def allocate_buffers_for_parameters(
input_params, data_parallel_group, gradient_scaling_factor=1.0,
):
param_and_grad_dtype_to_params = {}
# Group parameters by their gradient type.
for param in input_params:
if not param.requires_grad:
continue
param_dtype = param.dtype
grad_dtype = torch.float if self.ddp_config.grad_reduce_in_fp32 else param.dtype
params = param_and_grad_dtype_to_params.get((param_dtype, grad_dtype), [])
params.append(param)
param_and_grad_dtype_to_params[(param_dtype, grad_dtype)] = params
# Allocate the grad buffers and map the grads.
buffers = []
for (param_dtype, grad_dtype), params in param_and_grad_dtype_to_params.items():
buffers.append(
ParamAndGradBuffer(
self.ddp_config,
param_dtype,
grad_dtype,
params,
data_parallel_group,
self.bucket_size,
param_to_name,
gradient_scaling_factor,
)
)
for param in params:
self.param_to_buffer[param] = buffers[-1]
return buffers
data_parallel_world_size = torch.distributed.get_world_size(data_parallel_group)
# Allocate the param+grad buffers for dense params' grads.
self.buffers = allocate_buffers_for_parameters(
dense_params,
data_parallel_group,
gradient_scaling_factor=1.0 / data_parallel_world_size,
)
# Allocate separate param+grad buffers for expert parallel params' grads.
self.expert_parallel_buffers = allocate_buffers_for_parameters(
expert_parallel_params,
expert_data_parallel_group,
gradient_scaling_factor=1.0 / data_parallel_world_size,
)
# Delete references to weight_tensor if they exist since we don't want two parameter copies
# if we re-mapped parameters (which happens when we use the distributed optimizer).
# This is a temporary workaround around a TE bug that is fixed with
# https://github.com/NVIDIA/TransformerEngine/pull/719.
if self.ddp_config.use_distributed_optimizer:
@torch.no_grad()
def unmap_weight_tensor(m):
if hasattr(m, 'weight_tensor'):
m.weight_tensor = None
self.module.apply(unmap_weight_tensor)
# Register backward hook.
# Accumulation functions for the gradients need to be stored so they
# don't go out of scope.
self.grad_accs = []
for param in self.module.parameters():
if param.requires_grad:
# Expand so we get access to grad_fn.
param_tmp = param.expand_as(param)
# Get the gradient accumulator function.
grad_acc = param_tmp.grad_fn.next_functions[0][0]
grad_acc.register_hook(self._make_param_hook(param, self.param_to_buffer))
self.grad_accs.append(grad_acc)
def forward(self, *inputs, **kwargs):
"""
Calls the wrapped module's forward() method.
"""
return self.module(*inputs, **kwargs)
def _make_param_hook(
self,
param: torch.nn.Parameter,
param_to_buffer: Dict[torch.nn.Parameter, ParamAndGradBuffer],
):
"""
Creates the all-reduce / reduce-scatter hook for backprop.
"""
def param_hook(*unused):
if param.requires_grad:
if self.ddp_config.overlap_grad_reduce:
assert (
param.grad is not None
), 'param.grad being None is not safe when overlap_grad_reduce is True'
if param.grad is not None and (
not param.grad_added_to_main_grad or getattr(param, 'zero_out_wgrad', False)
):
param.main_grad.add_(param.grad.data)
param.grad = None
if self.ddp_config.overlap_grad_reduce:
param_to_buffer[param].register_grad_ready(param)
return param_hook
@contextmanager
def no_sync(self):
"""
Context manager that turns off gradient synchronization.
"""
for buffer in self.buffers + self.expert_parallel_buffers:
buffer.is_last_microbatch = False
try:
yield
finally:
for buffer in self.buffers + self.expert_parallel_buffers:
buffer.is_last_microbatch = True
def start_grad_sync(self, *unused):
"""
Initiates grad sync (all-reduce or reduce-scatter) communication operations
for all model gradients.
When overlap_grad_reduce is set to True, dispatches asynchronous communication
calls. When overlap_grad_reduce is set to False, calls synchronous
communication ops.
"""
for buffer in self.buffers + self.expert_parallel_buffers:
buffer.start_grad_sync()
def finish_grad_sync(self):
"""
Finishes grad sync (all-reduce or reduce-scatter) communication operations
for all model gradients.
When overlap_grad_reduce is set to True, waits for asynchronous communication
calls to complete. When overlap_grad_reduce is set to False, calls synchronous
communication ops.
"""
for buffer in self.buffers + self.expert_parallel_buffers:
buffer.finish_grad_sync()
def zero_grad_buffer(self):
"""
Zeros out all grad buffers. Needs to be called at the beginning of each
training iteration.
"""
for param in self.module.parameters():
if param.requires_grad:
param.grad_added_to_main_grad = False
for buffer in self.buffers + self.expert_parallel_buffers:
buffer.reset()
def broadcast_params(self):
"""
Syncs parameters across all DP ranks.
"""
for param in self.module.parameters():
is_expert_parallel = not getattr(param, 'allreduce', True)
if is_expert_parallel:
    # `broadcast` expects a global source rank; use the first rank of the group.
    torch.distributed.broadcast(
        param.data,
        src=torch.distributed.get_process_group_ranks(self.expert_data_parallel_group)[0],
        group=self.expert_data_parallel_group,
    )
else:
    torch.distributed.broadcast(
        param.data,
        src=torch.distributed.get_process_group_ranks(self.data_parallel_group)[0],
        group=self.data_parallel_group,
    )
def state_dict(self, prefix='', keep_vars=False):
"""
Returns a dictionary containing references to the whole state of the
wrapped module.
Both parameters and persistent buffers (e.g. running averages) are included.
Keys are corresponding parameter and buffer names. Parameters and buffers
set to None are not included.
"""
return self.module.state_dict(prefix=prefix, keep_vars=keep_vars)
def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
"""
Returns wrapped module's state_dict for checkpoint saving.
"""
return self.module.state_dict_for_save_checkpoint(prefix=prefix, keep_vars=keep_vars)
def load_state_dict(self, state_dict, strict=True):
"""
Copies parameters and buffers from state_dict into the wrapped module and its
descendants. If strict is True, then the keys of state_dict must exactly match
the keys returned by this module’s state_dict() function.
"""
self.module.load_state_dict(state_dict, strict=strict)
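# --- Illustrative usage sketch (not part of the original module) ---
# A minimal sketch of how the wrapper above is typically driven during training.
# `ddp_model` is assumed to be an instance of this DistributedDataParallel class,
# `batches` an iterable of already-prepared microbatches, and `compute_loss` a
# hypothetical loss function; the optimizer step is omitted.
def _example_ddp_training_step(ddp_model, batches, compute_loss):
    # Zero the contiguous grad buffers before accumulating a new iteration.
    ddp_model.zero_grad_buffer()
    *first_batches, last_batch = list(batches)
    # Suppress grad synchronization for all but the last microbatch.
    with ddp_model.no_sync():
        for batch in first_batches:
            compute_loss(ddp_model(batch)).backward()
    # The last microbatch allows grad reduction to be issued.
    compute_loss(ddp_model(last_batch)).backward()
    # Block until all all-reduce / reduce-scatter calls have completed.
    ddp_model.finish_grad_sync()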
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from dataclasses import dataclass
from typing import Optional
@dataclass
class DistributedDataParallelConfig:
"""Configuration for DistributedDataParallel."""
grad_reduce_in_fp32: bool = False
"""If true, reduce grads in fp32."""
overlap_grad_reduce: bool = False
"""If true, overlap grad all-reduce / reduce-scatter with backward compute."""
use_distributed_optimizer: bool = False
"""If true, issue reduce-scatter collectives to aggregate gradients and clean up originally
allocated model parameters, otherwise issue all-reduce collectives.
"""
check_for_nan_in_grad: bool = False
""" If true, check for NaNs in gradients _before_ communication collective."""
bucket_size: Optional[int] = None
"""Maximum number of parameters in each bucket. If unspecified, MCore uses a default
value of max(40000000, 1000000 * dp_size) parameters (larger DP sizes need larger buckets
to ensure collectives do not become latency-bound)."""
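# --- Illustrative configuration sketch (not part of the original module) ---
# A minimal example of filling in this dataclass; the particular flag values are
# assumptions chosen for illustration, not recommended defaults.
_example_ddp_config = DistributedDataParallelConfig(
    grad_reduce_in_fp32=True,        # accumulate / reduce grads in fp32
    overlap_grad_reduce=True,        # overlap reduction with backward compute
    use_distributed_optimizer=True,  # reduce-scatter instead of all-reduce
    check_for_nan_in_grad=False,
    bucket_size=40_000_000,          # explicit bucket size in parameters
)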
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import List
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from .. import parallel_state
from ..transformer.transformer_config import TransformerConfig
from ..utils import get_attr_wrapped_model, get_model_config
def _allreduce_word_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
"""
All-reduce word embedding grads.
Reduce grads across first and last stages to ensure that word_embeddings parameters stay in
sync. This should only run for models that support pipelined model parallelism (BERT and GPT).
"""
if (
parallel_state.is_rank_in_embedding_group(ignore_virtual=True)
and parallel_state.get_pipeline_model_parallel_world_size() > 1
):
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
model_module = model[0]
elif parallel_state.is_pipeline_last_stage(ignore_virtual=True):
model_module = model[-1]
else: # We do not support the interleaved schedule for T5 yet.
model_module = model[0]
# Look for module with 'pre_process' attribute to get around the fact that DDP and
# other wrapper classes inherit from non-core MegatronModule that has
# 'share_embeddings_and_output_weights' and 'shared_embedding_or_output_weight'
# attributes already, causing get_attr_wrapped_model() to not unwrap anything here.
# TODO: Clean this up once the wrapper classes inherit from core MegatronModule.
model_module = get_attr_wrapped_model(model_module, 'pre_process', return_model_obj=True)
if model_module.share_embeddings_and_output_weights:
weight = model_module.shared_embedding_or_output_weight()
grad = weight.main_grad
torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())
def _allreduce_position_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
"""
All-reduce position_embeddings grad across first (encoder) and split (decoder) stages to
ensure that position embeddings parameters stay in sync. This should only run for T5 models
with pipeline parallelism.
"""
if (
parallel_state.is_rank_in_position_embedding_group()
and parallel_state.get_pipeline_model_parallel_world_size() > 1
and config.pipeline_model_parallel_split_rank is not None
):
model_module = model[0]
grad = get_attr_wrapped_model(
model_module, 'language_model.embedding.position_embeddings.weight.main_grad'
)
torch.distributed.all_reduce(grad, group=parallel_state.get_position_embedding_group())
def _allreduce_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
"""
All-reduce both word and position embeddings.
"""
_allreduce_word_embedding_grads(model, config)
_allreduce_position_embedding_grads(model, config)
def _allreduce_layernorm_grads(model: List[torch.nn.Module], config: TransformerConfig):
"""
All-reduce layernorm grads (for sequence parallelism).
"""
# All-reduce layernorm parameters across model parallel nodes
# when sequence parallelism is used
if parallel_state.get_tensor_model_parallel_world_size() > 1 and (
config.sequence_parallel or config.qk_layernorm
):
grads = []
for model_chunk in model:
for name, param in get_attr_wrapped_model(model_chunk, 'named_parameters')():
if (
getattr(param, 'sequence_parallel', False)
or 'q_layernorm' in name
or 'k_layernorm' in name
):
grad = param.main_grad
grads.append(grad.data)
if grads:
coalesced = _flatten_dense_tensors(grads)
torch.distributed.all_reduce(
coalesced, group=parallel_state.get_tensor_model_parallel_group()
)
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)
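# --- Illustrative sketch of the coalesced-reduction pattern used above ---
# (not part of the original module). It shows, on a single process and without
# any collective, how a list of tensors is flattened into one buffer, operated
# on once, and the results copied back into the original tensors. The toy
# tensors and the in-place doubling stand in for grads and the all-reduce.
def _example_coalesced_update():
    grads = [torch.ones(3), torch.ones(5)]
    coalesced = _flatten_dense_tensors(grads)
    coalesced.mul_(2)  # stands in for torch.distributed.all_reduce(coalesced, ...)
    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
        buf.copy_(synced)
    return grads  # each entry now holds 2.0 everywhere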
def finalize_model_grads(model: List[torch.nn.Module]):
"""
All-reduce all model grads across DP replicas, layernorm grads for sequence parallelism,
embedding grads across first and last pipeline stages (if not tied).
"""
config = get_model_config(model[0])
# All-reduce / reduce-scatter across DP replicas.
if config.timers is not None:
config.timers('all-grads-sync', log_level=1).start(barrier=config.barrier_with_L1_time)
for model_chunk in model:
model_chunk.finish_grad_sync()
if config.timers is not None:
config.timers('all-grads-sync').stop()
# All-reduce layer-norm grads (for sequence parallelism).
if config.timers is not None:
config.timers('layernorm-grads-all-reduce', log_level=1).start(
barrier=config.barrier_with_L1_time
)
_allreduce_layernorm_grads(model, config)
if config.timers is not None:
config.timers('layernorm-grads-all-reduce').stop()
# All-reduce embedding grads (for pipeline parallelism).
if config.timers is not None:
config.timers('embedding-grads-all-reduce', log_level=1).start(
barrier=config.barrier_with_L1_time
)
_allreduce_embedding_grads(model, config)
if config.timers is not None:
config.timers('embedding-grads-all-reduce').stop()
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import math
import os
from enum import Enum
from logging import getLogger
from typing import Dict, List, Optional
import torch
from .. import parallel_state
from .distributed_data_parallel_config import DistributedDataParallelConfig
logger = getLogger(__name__)
class BufferType(Enum):
PARAM = 1
GRAD = 2
def shard_buffer(buffer: torch.Tensor, data_parallel_world_size: int):
"""
Shard buffer into data_parallel_world_size chunks of equal size.
"""
assert buffer.numel() % data_parallel_world_size == 0
shard_size = buffer.numel() // data_parallel_world_size
sharded_buffer = [
buffer[(r * shard_size) : ((r + 1) * shard_size)] for r in range(data_parallel_world_size)
]
return sharded_buffer
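# --- Illustrative sketch (not part of the original module) ---
# shard_buffer() simply slices a flat tensor into equal, contiguous views, one
# per data-parallel rank: e.g. a 12-element buffer with world size 4 yields four
# views of 3 elements each. The numbers here are arbitrary.
def _example_shard_buffer():
    buffer = torch.arange(12, dtype=torch.float32)
    shards = shard_buffer(buffer, data_parallel_world_size=4)
    assert len(shards) == 4 and all(s.numel() == 3 for s in shards)
    return shards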
class Bucket:
"""
Bucket to keep track of a subset of the model's gradients. Provides functionality to register
when params in the bucket have grads ready to be synced; an asynchronous communication call
is automatically launched when _all_ params in the bucket have grads ready.
Args:
ddp_config: DistributedDataParallel config object.
params: List of parameters whose gradients are collated in this bucket.
param_data: View in larger ParamAndGradBuffer.param_data that this bucket is responsible for.
grad_data: View in larger ParamAndGradBuffer.grad_data that this bucket is responsible for.
offset: Offset of this bucket's view in the larger ParamAndGradBuffer.
numel_unpadded: Number of unpadded elements in bucket.
data_parallel_group: Data-parallel process group.
data_parallel_world_size: World size of the data-parallel group.
gradient_scaling_factor: Factor applied to gradients before communication; it serves
both to average gradients across data-parallel ranks and to scale gradients for the
Mixture of Experts (MoE) model.
"""
def __init__(
self,
ddp_config: DistributedDataParallelConfig,
params: List[torch.nn.Parameter],
param_data: Optional[torch.Tensor],
grad_data: torch.Tensor,
offset: int,
numel_unpadded: int,
data_parallel_group: torch.distributed.ProcessGroup,
data_parallel_world_size: int,
gradient_scaling_factor: float,
):
self.ddp_config = ddp_config
# State for bookkeeping: params is the set of parameters this bucket is
# responsible for, params_with_grad is the set of parameters with grads
# available. When overlap_grad_reduce is True, communication (all-reduce
# or reduce-scatter) is issued when params_with_grad equals params.
self.params_list = params
self.params = set(params)
self.params_with_grad = set()
self.param_data = param_data
self.grad_data = grad_data
# The distributed optimizer needs to keep track of this bucket's offset
# within the full grad_buffer.
self.offset = offset
self.numel_unpadded = numel_unpadded
self.data_parallel_group = data_parallel_group
self.data_parallel_world_size = data_parallel_world_size
self.data_parallel_rank = torch.distributed.get_rank(group=data_parallel_group)
self.gradient_scaling_factor = gradient_scaling_factor
self.reset()
def reset(self):
"""
Reset metadata in bucket in preparation for the next iteration of training.
"""
self.params_with_grad = set()
self.communication_handle = None
self.communication_issued = False
def start_grad_sync(self):
"""
Initiates grad sync (all-reduce or reduce-scatter) communication operation
for this bucket.
When overlap_grad_reduce is set to True, dispatches an asynchronous
communication call. When overlap_grad_reduce is set to False, makes
synchronous call.
"""
assert (
self.communication_handle is None and not self.communication_issued
), 'Should not have multiple communication calls in flight at once'
# Make sure the norm of the grads in this bucket is not NaN
# prior to data-parallel all-reduce / reduce-scatter.
if self.ddp_config.check_for_nan_in_grad:
global_rank = torch.distributed.get_rank()
norm = self.grad_data.norm(p=2)
assert not norm.isnan(), (
f'Rank {global_rank}: found NaN in local grad norm in '
f'backward pass before data-parallel communication collective. '
f'Device: {torch.cuda.current_device()}, node: {os.uname()[1]}'
)
self.grad_data *= self.gradient_scaling_factor
# Use async_op only when overlap_grad_reduce is True.
if self.ddp_config.use_distributed_optimizer:
local_data_view = shard_buffer(self.grad_data, self.data_parallel_world_size)[
self.data_parallel_rank
]
self.communication_handle = torch.distributed._reduce_scatter_base(
local_data_view,
self.grad_data,
group=self.data_parallel_group,
async_op=self.ddp_config.overlap_grad_reduce,
)
else:
self.communication_handle = torch.distributed.all_reduce(
self.grad_data,
group=self.data_parallel_group,
async_op=self.ddp_config.overlap_grad_reduce,
)
self.communication_issued = True
def finish_grad_sync(self):
"""
Finishes grad sync (all-reduce or reduce-scatter) communication operation
for this bucket.
When overlap_grad_reduce is set to True, waits for asynchronous communication
call to complete. When overlap_grad_reduce is set to False, makes synchronous call.
"""
# If overlap_grad_reduce is False, start (and finish) synchronous communication call here.
if not self.ddp_config.overlap_grad_reduce:
self.start_grad_sync()
return
assert self.communication_handle is not None and self.communication_issued, (
f'Communication call has not been issued for this bucket '
f'({len(self.params_with_grad)}/{len(self.params)} params have grad available)'
)
self.communication_handle.wait()
def register_grad_ready(self, param: torch.nn.Parameter):
"""
Registers grads for the passed-in param to be "ready" for grad sync.
When the number of microbatches is greater than 1, we only want to register
grads as ready when processing the last microbatch and overlap_grad_reduce is True.
"""
assert param in self.params, 'Param is not in the bucket'
assert param not in self.params_with_grad, 'Cannot set grad twice'
assert (
self.ddp_config.overlap_grad_reduce
), 'register_grad_ready() should be called only when overlapping grad reduce'
self.params_with_grad.add(param)
# If all params in bucket have grads available, issue communication call.
if len(self.params_with_grad) == len(self.params):
self.start_grad_sync()
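# --- Illustrative note on the reduce-scatter path (not part of the original class) ---
# When use_distributed_optimizer is True, each rank ends up owning only the reduced
# values for its own shard of grad_data. A minimal sketch of how that local view is
# selected, mirroring the call to torch.distributed._reduce_scatter_base() in
# start_grad_sync(); the argument names are hypothetical. Note also that the scaling
# factor applied there is commonly 1 / data_parallel_world_size, so the summed
# reduction yields an average.
def _example_local_grad_shard(
    grad_data: torch.Tensor, dp_world_size: int, dp_rank: int
) -> torch.Tensor:
    return shard_buffer(grad_data, dp_world_size)[dp_rank]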
class ParamAndGradBuffer:
"""
Groups parameters and gradients into a contiguous buffer, and then breaks the buffer into
buckets with roughly `bucket_size` parameters each.
Args:
ddp_config: DistributedDataParallel config object.
param_dtype: Type of param tensor.
grad_dtype: Type of grad tensor.
params: List of parameters whose data and gradients are collated in the underlying
buffers.
data_parallel_group: Data-parallel process group.
bucket_size: The rough size of each bucket in terms of number of parameters.
param_to_name: Mapping from `torch.nn.Parameter` to name (for logging purposes).
gradient_scaling_factor: Factor applied to gradients before communication; it serves
both to average gradients across data-parallel ranks and to scale gradients for the
Mixture of Experts (MoE) model.
"""
def __init__(
self,
ddp_config: DistributedDataParallelConfig,
param_dtype: torch.dtype,
grad_dtype: torch.dtype,
params: List[torch.nn.Parameter],
data_parallel_group: torch.distributed.ProcessGroup,
bucket_size: int,
param_to_name: Dict[torch.nn.Parameter, str],
gradient_scaling_factor: float,
):
self.ddp_config = ddp_config
# Check that params are unique.
unique_params = set()
for param in params:
assert param not in unique_params
unique_params.add(param)
del unique_params
# Store attributes that will be needed later.
self.param_dtype = param_dtype
self.grad_dtype = grad_dtype
self.data_parallel_group = data_parallel_group
self.data_parallel_world_size = torch.distributed.get_world_size(
group=self.data_parallel_group
)
self.gradient_scaling_factor = gradient_scaling_factor
self.is_last_microbatch = True
# Data structures to store underlying buckets and relevant indexing data.
self.buckets = []
self.param_to_bucket = {} # Param -> bucket mapping.
self.param_index_map = {} # Param -> location in buffer mapping (used in dist. optimizer).
def _pad_if_needed(data_index: int) -> int:
"""
Pads data indices if using distributed optimizer (to ensure uniform sharding).
"""
if self.ddp_config.use_distributed_optimizer:
return (
int(math.ceil(data_index / self.data_parallel_world_size))
* self.data_parallel_world_size
)
return data_index
# First, figure out how many elements should be in the underlying buffer storage.
# Note that if we need to split the buffer into smaller buckets, each of these
# might need to be padded as well (if using the distributed optimizer).
data_start_index = 0
bucket_data_start_index = data_start_index
bucket_params = set()
self.bucket_indices = []
per_bucket_numel_unpadded = []
bucket_id = 0
def _create_new_bucket(data_end_index: int) -> int:
"""
Create the bucket_id'th bucket with collected bucket_params, starting at
bucket_data_start_index.
"""
nonlocal bucket_data_start_index, bucket_params, bucket_id
per_bucket_numel_unpadded.append(data_end_index - bucket_data_start_index)
data_end_index = _pad_if_needed(data_end_index)
# Update bucket metadata.
self.bucket_indices.append((bucket_data_start_index, data_end_index))
bucket_data_start_index = data_end_index
# Re-set bucket_params and increment bucket_id for next bucket.
bucket_params = set()
bucket_id += 1
# Return the potentially padded data_end_index.
return data_end_index
for param in params[::-1]:
# Iterate through parameters in reverse order to roughly follow backprop order,
# and skip parameters that don't require gradients.
if not param.requires_grad:
continue
this_numel = param.data.nelement()
data_end_index = data_start_index + this_numel
def _does_param_require_new_bucket(param):
"""
Split shared embedding parameters into separate bucket if using distributed
optimizer that makes use of reduce-scatters instead of all-reduces.
This ensures that the first and last pipeline stage partition optimizer state
for the shared embedding parameters the same way across DP replicas, allowing
the DP reduce-scatter to run before the embedding all-reduce.
"""
return (
getattr(param, "shared_embedding", False)
and self.ddp_config.use_distributed_optimizer
)
# Create bucket with already collected parameters if current param needs its own bucket.
if _does_param_require_new_bucket(param) and len(bucket_params) > 0:
# We are creating a bucket for the already accumulated parameters, whose params
# end at the current data_start_index.
if self.ddp_config.use_distributed_optimizer:
# data_start_index should already be padded.
assert data_start_index % self.data_parallel_world_size == 0
_create_new_bucket(data_start_index)
self.param_index_map[param] = (
data_start_index,
data_end_index,
bucket_id,
)
bucket_params.add(param)
# If we have enough elements already or the current param is part of the shared embedding
# layer and needs a separate bucket, form a new bucket.
if (
bucket_size is not None
and (data_end_index - bucket_data_start_index) >= bucket_size
) or _does_param_require_new_bucket(param):
data_end_index = _create_new_bucket(data_end_index)
data_start_index = data_end_index
# Add remaining params to a new bucket.
if len(bucket_params) > 0:
data_end_index = _create_new_bucket(data_end_index)
# Next, create underlying storage for buffer (with numel elements that includes
# padding as necessary).
self.numel = data_end_index
if self.ddp_config.use_distributed_optimizer:
assert self.numel % self.data_parallel_world_size == 0
self.param_data = None
# Only re-map param tensors if using distributed optimizer.
if self.ddp_config.use_distributed_optimizer:
self.param_data = torch.zeros(
self.numel,
dtype=self.param_dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
self.grad_data = torch.zeros(
self.numel,
dtype=self.grad_dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
# Finally, map param.data and param.main_grad fields to buffers.
bucket_params = set()
bucket_data_start_index = 0
cur_bucket_id = 0
for param in params[::-1]:
if not param.requires_grad:
continue
data_start_index, data_end_index, bucket_id = self.param_index_map[param]
# Assign param.data to appropriate segment of self.param_data.
if self.param_data is not None:
old_param_data = param.data
param.data = self._get(
param.data.shape, data_start_index, buffer_type=BufferType.PARAM
)
assert old_param_data._base is None
# Copy tensor values (from initialization or checkpoint).
param.data.detach().copy_(old_param_data)
del old_param_data
param.main_grad = self._get(
param.data.shape, data_start_index, buffer_type=BufferType.GRAD
)
if bucket_id != cur_bucket_id:
bucket_data_end_index = _pad_if_needed(data_start_index)
self._set_bucket(
bucket_params=bucket_params,
start_index=bucket_data_start_index,
end_index=bucket_data_end_index,
numel_unpadded=per_bucket_numel_unpadded[cur_bucket_id],
bucket_id=cur_bucket_id,
)
bucket_data_start_index = bucket_data_end_index
bucket_params = set()
assert cur_bucket_id + 1 == len(self.buckets)
assert bucket_id == cur_bucket_id + 1
cur_bucket_id = bucket_id
bucket_params.add(param)
# Add remaining params to a new bucket.
if len(bucket_params) > 0:
bucket_data_end_index = _pad_if_needed(data_end_index)
self._set_bucket(
bucket_params=bucket_params,
start_index=bucket_data_start_index,
end_index=bucket_data_end_index,
numel_unpadded=per_bucket_numel_unpadded[cur_bucket_id],
bucket_id=cur_bucket_id,
)
# Log buckets for all PP stages.
if (
parallel_state.get_data_parallel_rank(with_context_parallel=True) == 0
and parallel_state.get_tensor_model_parallel_rank() == 0
):
logger.info(
f'Number of buckets for gradient all-reduce / reduce-scatter: {len(self.buckets)}'
)
for index, bucket in enumerate(self.buckets):
numel = 0
for param in bucket.params:
numel += param.data.nelement()
logger.info(f'Params for bucket {index+1} ({numel} elements):')
for param in bucket.params:
logger.info(f' {param_to_name[param]}')
def _get(self, shape: torch.Size, start_index: int, buffer_type: BufferType) -> torch.Tensor:
"""
Return a tensor with the input `shape` as a view into the 1-D data starting at
`start_index`.
"""
end_index = start_index + shape.numel()
assert end_index <= self.numel, 'Requested tensor is out of buffer range'
if buffer_type == BufferType.PARAM:
assert self.param_data is not None
buffer_tensor = self.param_data[start_index:end_index]
elif buffer_type == BufferType.GRAD:
buffer_tensor = self.grad_data[start_index:end_index]
else:
raise Exception("Illegal buffer type provided to ParamAndGradBuffer._get() function")
buffer_tensor = buffer_tensor.view(shape)
return buffer_tensor
def _set_bucket(
self,
bucket_params: List[torch.nn.Parameter],
start_index: int,
end_index: int,
numel_unpadded: int,
bucket_id: int,
):
"""
Helper function to create new bucket, add it to list of buckets, and
also update param->bucket mapping.
"""
# Assert that indices are correctly padded (if needed), and that bucket
# position is same as originally computed.
if self.ddp_config.use_distributed_optimizer:
assert start_index % self.data_parallel_world_size == 0
assert end_index % self.data_parallel_world_size == 0
assert (start_index, end_index) == self.bucket_indices[bucket_id]
# Get appropriate view into global ParamAndGradBuffer.
bucketed_param_data = None
if self.param_data is not None:
bucketed_param_data = self._get(
torch.Size([end_index - start_index]), start_index, buffer_type=BufferType.PARAM
)
bucketed_grad_data = self._get(
torch.Size([end_index - start_index]), start_index, buffer_type=BufferType.GRAD
)
bucket = Bucket(
ddp_config=self.ddp_config,
params=bucket_params,
param_data=bucketed_param_data,
grad_data=bucketed_grad_data,
offset=start_index,
numel_unpadded=numel_unpadded,
data_parallel_group=self.data_parallel_group,
data_parallel_world_size=self.data_parallel_world_size,
gradient_scaling_factor=self.gradient_scaling_factor,
)
self.buckets.append(bucket)
for bucket_param in bucket_params:
assert bucket_param not in self.param_to_bucket
self.param_to_bucket[bucket_param] = bucket
def reset(self):
"""
Zero out the underlying grad_buffer and reset all buckets in preparation for the next
iteration of training.
"""
self.grad_data.zero_()
for bucket in self.buckets:
bucket.reset()
self.is_last_microbatch = True
def start_grad_sync(self):
"""
Initiates grad sync (all-reduce or reduce-scatter) communication operations
for all buckets in the grad buffer.
When overlap_grad_reduce is set to True, dispatches asynchronous communication
calls. When overlap_grad_reduce is set to False, calls synchronous
communication ops.
"""
for bucket in self.buckets:
bucket.start_grad_sync()
def finish_grad_sync(self):
"""
Finishes grad sync (all-reduce or reduce-scatter) communication operations
for all buckets in the grad buffer.
When overlap_grad_reduce is set to True, waits for asynchronous communication
calls to complete. When overlap_grad_reduce is set to False, calls synchronous
communication ops.
"""
for bucket in self.buckets:
bucket.finish_grad_sync()
def register_grad_ready(self, param: torch.nn.Parameter):
"""
Registers grads for the passed-in param to be "ready" for grad sync.
When the number of microbatches is greater than 1, we only want to register
grads as ready when processing the last microbatch and overlap_grad_reduce is True.
"""
assert (
self.ddp_config.overlap_grad_reduce
), 'register_grad_ready() should only be called when overlap_grad_reduce is True'
if self.is_last_microbatch:
bucket = self.param_to_bucket[param]
bucket.register_grad_ready(param)
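# --- Illustrative sketch of the view-based mapping (not part of the original module) ---
# ParamAndGradBuffer hands each parameter a `main_grad` that is a *view* into the single
# flat grad_data tensor, which is why reset() can clear every gradient with one
# grad_data.zero_(). The toy buffer and shapes below are arbitrary and avoid constructing
# a real buffer (which would require an initialized process group).
def _example_main_grad_views():
    flat = torch.zeros(12)
    main_grad_a = flat[0:6].view(2, 3)   # analogous to one param.main_grad
    main_grad_b = flat[6:12].view(3, 2)  # analogous to another param.main_grad
    main_grad_a.add_(1.0)                # accumulation lands in the flat buffer
    assert flat[:6].sum().item() == 6.0
    flat.zero_()                         # one zero_() clears both "main grads"
    assert main_grad_a.sum().item() == 0.0 and main_grad_b.sum().item() == 0.0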