Commit 7c19b3a8 authored by wangsen
Initial commit
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Strategies using TensorStore to load and save Zarr arrays. """
from functools import partial
from itertools import starmap
from logging import getLogger
from pathlib import Path
import tensorstore as ts
import torch
from ..core import CheckpointingException
from ..dict_utils import dict_list_map_inplace
from ..mapping import ShardedStateDict, ShardedTensor
from .base import LoadShardedStrategy, StrategyAction, default_strategies
from .zarr import (
load_zarr_based_sharded_metadata,
numpy_to_torch_dtype_dict,
postprocess_numpy_array,
)
_import_trigger = None
logger = getLogger(__name__)
class TensorStoreLoadShardedStrategy(LoadShardedStrategy):
def __init__(self, load_directly_on_device: bool = False):
super().__init__()
self.load_directly_on_device = load_directly_on_device
def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
if torch.distributed.get_rank() == 0:
print(f'Loading distributed checkpoint with {self.__class__.__name__}')
if self.load_directly_on_device:
print(f'Loading distributed checkpoint directly on the GPU')
load_fn = partial(
_load_from_array,
checkpoint_dir=checkpoint_dir,
load_directly_on_device=self.load_directly_on_device,
)
dict_list_map_inplace(load_fn, sharded_state_dict)
return sharded_state_dict
def load_tensors_metadata(self, checkpoint_dir: Path):
def get_ts_shape_dtype(path):
arr = open_ts_array(path)
return arr.shape, arr.dtype.numpy_dtype
return load_zarr_based_sharded_metadata(checkpoint_dir, get_ts_shape_dtype)
def check_backend_compatibility(self, loaded_version):
pass # TODO
def check_version_compatibility(self, loaded_version):
pass # TODO
def merge_global_slice_with_shape(global_slice, actual_shape, key):
def _merge_slice(dim_slice, dim_size):
if isinstance(dim_slice, slice):
assert (
dim_slice.start < dim_size
), f'Got empty slice for ShardedTensor {key} ({dim_slice}, {dim_size})'
if dim_slice.stop > dim_size:
dim_slice = slice(dim_slice.start, dim_size, dim_slice.step)
return dim_slice
assert len(global_slice) == len(actual_shape), (global_slice, actual_shape, key)
return tuple(starmap(_merge_slice, zip(global_slice, actual_shape)))
def _load_from_array(
sharded_tensor: ShardedTensor,
checkpoint_dir: Path,
load_directly_on_device: bool = False,
apply_flattened_range: bool = True,
):
x = _load_regular_chunk(sharded_tensor, checkpoint_dir)
ten = postprocess_numpy_array(x, sharded_tensor, apply_flattened_range)
if load_directly_on_device:
sharded_tensor.data.data.copy_(ten)
return sharded_tensor.data
else:
return ten
def _load_regular_chunk(sharded_tensor: ShardedTensor, checkpoint_dir: Path):
assert isinstance(sharded_tensor, ShardedTensor), type(sharded_tensor)
arr = open_ts_array(checkpoint_dir / sharded_tensor.key)
if sharded_tensor.global_shape == arr.shape:
x = (
arr[sharded_tensor.global_slice()].read().result()
) # flattened tensors loading is delayed
elif sharded_tensor.allow_shape_mismatch:
global_slice = merge_global_slice_with_shape(
sharded_tensor.global_slice(), arr.shape, sharded_tensor.key
)
x = arr[global_slice].read().result() # flattened tensors loading is delayed
else:
_msg = (
f'Global shape mismatch for loaded ({arr.shape})'
f' and expected ({sharded_tensor.global_shape}) tensor'
f' for key {sharded_tensor.key}'
)
raise CheckpointingException(_msg)
return x
def open_ts_array(arr_path: Path):
"""Opens a Zarr file array with Tensorstore with basic setting.
Args:
arr_path (Path): path to a Zarr (Tensorstore) array
"""
spec = {'driver': 'zarr', 'metadata_key': '.zarray', 'kvstore': {}}
spec['kvstore'] = {
'driver': 'file',
'path': str(arr_path),
}
try:
arr = ts.open(ts.Spec(spec), open=True).result()
except Exception as e:
raise CheckpointingException(f'Array {arr_path} could not be loaded. Error: {e}') from e
return arr
default_strategies[StrategyAction.LOAD_SHARDED.value][
('zarr', 1)
] = TensorStoreLoadShardedStrategy()
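# --- Illustrative usage (editor's sketch, not part of the original module) ---
# Assuming torch.distributed is already initialized and `sharded_state_dict` was produced
# by the model's sharded_state_dict() method, loading a Zarr checkpoint via TensorStore
# could look like the hypothetical helper below.
def _example_tensorstore_load(sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
    # load_directly_on_device=True copies each shard straight into the pre-allocated
    # device tensors referenced by the ShardedTensors, avoiding an extra host-side copy.
    strategy = TensorStoreLoadShardedStrategy(load_directly_on_device=True)
    return strategy.load(sharded_state_dict, checkpoint_dir)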
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Strategies using PyTorch distributed.checkpoint as an underlying format. """
import dataclasses
import io
import itertools
import math
from collections import ChainMap, defaultdict
from dataclasses import dataclass
from itertools import product
from logging import getLogger
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, cast
import numpy as np
import torch
from torch.distributed import checkpoint
from torch.distributed._shard.metadata import ShardMetadata
from torch.distributed._shard.sharded_tensor import Shard, ShardedTensorMetadata, TensorProperties
from torch.distributed._sharded_tensor import ShardedTensor as TorchShardedTensor
from torch.distributed.checkpoint import (
DefaultLoadPlanner,
DefaultSavePlanner,
FileSystemReader,
LoadPlan,
Metadata,
SavePlan,
TensorStorageMetadata,
WriteItem,
)
from torch.distributed.checkpoint._nested_dict import FLATTEN_MAPPING, unflatten_state_dict
from torch.distributed.checkpoint._traverse import OBJ_PATH, traverse_state_dict
from torch.distributed.checkpoint.default_planner import create_default_local_save_plan
from torch.distributed.checkpoint.planner_helpers import _create_write_items
from ..core import CheckpointingException
from ..dict_utils import nested_values
from ..mapping import (
ShardedBase,
ShardedObject,
ShardedStateDict,
ShardedTensor,
StateDict,
is_main_replica,
)
from .async_utils import AsyncRequest
from .base import (
AsyncSaveShardedStrategy,
LoadShardedStrategy,
SaveShardedStrategy,
StrategyAction,
default_strategies,
)
from .filesystem_async import FileSystemWriterAsync
from .state_dict_saver import save_state_dict_async_finalize, save_state_dict_async_plan
_import_trigger = None
logger = getLogger(__name__)
def flatten_state_dict(
state_dict: ShardedStateDict,
) -> Tuple[ShardedStateDict, Dict[str, OBJ_PATH]]:
""" Flattens state dict into a single level dict.
It's a copy of torch.distributed.checkpoint._nested_dict.flatten_state_dict
which also accepts ShardedBase tensors as terminal objects
Args:
state_dict (ShardedStateDict): state dict to be flattened
Returns (tuple): flattened state dict and a mapping allowing to recreate the original one
"""
flattened = {}
mappings = {}
def flat_copy(path: OBJ_PATH, value: Any) -> None:
new_fqn = ".".join(map(str, path))
if new_fqn in flattened:
raise ValueError(f"duplicated flatten key {new_fqn}")
flattened[new_fqn] = value
mappings[new_fqn] = path
traverse_state_dict(state_dict, flat_copy, lambda x: isinstance(x, (torch.Tensor, ShardedBase)))
return flattened, mappings
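# --- Illustrative usage (editor's sketch) ---
# For a nested dict such as {'model': {'weight': <tensor>}}, flatten_state_dict returns a
# single-level dict keyed by dot-joined paths plus a mapping that lets the original
# structure be rebuilt with unflatten_state_dict. The helper name below is hypothetical.
def _example_flatten_state_dict():
    flat, mapping = flatten_state_dict({'model': {'weight': torch.ones(2, 2)}})
    assert 'model.weight' in flat  # keys are dot-joined paths into the original dict
    return flat, mapping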
def sharded_tensor_to_torch_sharded_tensor(
sh_tens: List[ShardedTensor], rank: Optional[int] = None
) -> TorchShardedTensor:
"""Convert MCore ShardedTensor to PyT ShardedTensor. PyT requires information about all chunks.
On high-level, this function follows the logic of torch.distributed.fsdp._shard_utils._create_chunk_sharded_tensor.
Additionally, it saves `prepend_axis_num` and `has_flattened_range` (specific to MCore) as attributes
for further restoration in `_unwrap_pyt_sharded_tensor`.
NOTE: this function assumes regular (grid) sharding of the MCore ShardedTensor.
    The only local irregularity can be introduced by a `flattened_range` attribute.
    This function handles three different types of ShardedTensors:
    1. Non-flat regular ShardedTensors (`not has_flattened_range`)
    2. 1D flattened ShardedTensors (`is_flattened_range_1d`)
    3. N-D flattened ShardedTensors (`has_flattened_range`)
    Types (1) and (2) are saved according to their original shape.
    Type (3), however, requires a global shape adjustment for efficiency:
we treat [X, Y, Z] global shape tensor with local shape [x, y, z]
as a [X // x, Y // y, Z // z, x * y * z] tensor with last axis
partitioned according to `flattened_range` slices.
This will need special handling while resharding.
Args:
sh_tens (List[ShardedTensor]): list of sharded tensors to convert
rank (int, optional): current process rank passed to PyT ShardedTensor.
If None, assumes rank in the default pg.
Returns (TorchShardedTensor): PyT ShardedTensor containing all passed shards.
"""
if rank is None:
rank = torch.distributed.get_rank()
some_sh_ten = sh_tens[0]
has_flattened_range = some_sh_ten.flattened_range is not None
is_flattened_range_1d = has_flattened_range and len(some_sh_ten.global_shape) == 1
for sh_ten in sh_tens:
assert (sh_ten.flattened_range is not None) == has_flattened_range, sh_tens
if not sh_ten.data.is_contiguous():
sh_ten.data = sh_ten.data.contiguous()
local_global_offsets = {}
prepend_axis_num = sh_tens[0].prepend_axis_num
# Determine local shards according to tensor type (see docs)
if is_flattened_range_1d:
# Type (2) case: 1D flattened ShardedTensors
for sh_ten in sh_tens:
assert len(sh_ten.global_offset) == 1, sh_ten
assert sh_ten.prepend_axis_num == 0, sh_ten
local_global_offsets.setdefault(sh_ten.global_offset, []).append(sh_ten)
global_shape = some_sh_ten.global_shape
offsets_shape = (
some_sh_ten.local_shape
) # local shape is not flattened, we need it for chunk offsets
local_shards = [
Shard.from_tensor_and_offsets(
sh_ten.data,
[
sh_ten.global_offset[0] + sh_ten.flattened_range.start
], # additional flattened offset
rank,
)
for sh_ten in sh_tens
]
elif has_flattened_range:
# Type (3) case: N-D flattened ShardedTensors
for sh_ten in sh_tens:
local_global_offsets.setdefault(sh_ten.local_chunk_offset_in_global(), []).append(
sh_ten
)
assert sh_ten.data.ndim == 1, sh_ten
sh_ten.data = sh_ten.data.view((1,) * len(sh_ten.global_shape) + (-1,))
# Global shape reformulation:
global_shape = some_sh_ten.axis_fragmentations + (int(np.prod(some_sh_ten.local_shape)),)
offsets_shape = (1,) * len(
some_sh_ten.global_shape
        )  # reformulated global shape has shape equal to the number of local chunks
local_shards = [
Shard.from_tensor_and_offsets(
sh_ten.data,
list(
sh_ten.local_chunk_offset_in_global() + (sh_ten.flattened_range.start,)
), # additional flattened offset
rank,
)
for sh_ten in sh_tens
]
else:
# Type (1) case: non-flat regular ShardedTensors
for sh_ten in sh_tens:
local_global_offsets.setdefault(sh_ten.global_offset, []).append(sh_ten)
sh_ten.data = sh_ten.data.view(
(1,) * prepend_axis_num + sh_ten.local_shape
) # adjust to prepended_axis_num
global_shape = some_sh_ten.global_shape
offsets_shape = some_sh_ten.data.shape # includes prepended axes
local_shards = [
Shard.from_tensor_and_offsets(
sh_ten.data, list(sh_ten.global_offset), rank # simple case
)
for sh_ten in sh_tens
]
# Create a ShardedTensor without invoking communication. Determine global shards
shard_metadata = []
# NOTE: here we assume a regular grid of shards
for fragment_offsets in itertools.product(*map(range, some_sh_ten.axis_fragmentations)):
offset = tuple(map(lambda x: x[0] * x[1], zip(fragment_offsets, offsets_shape)))
if offset in local_global_offsets:
# local shard
placement = f"rank:{rank}/cuda"
for sh_ten in local_global_offsets[offset]:
if is_flattened_range_1d:
offset = (sh_ten.global_offset[0] + sh_ten.flattened_range.start,)
size = sh_ten.data.shape
elif has_flattened_range:
assert offset == sh_ten.local_chunk_offset_in_global()
# This is not an actual offset, but an offset of the whole shard
# This is needed for a PyT Dist internal integrity check
offset = sh_ten.local_chunk_offset_in_global() + (0,)
size = (1,) * len(offsets_shape) + global_shape[-1:]
else:
size = sh_ten.data.shape
shard_metadata.append(ShardMetadata(offset, size, placement))
else:
# for shards from other ranks we provide simplistic data - this information will be discarded
# during TorchShardedTensor._init_from_local_shards_and_global_metadata call
if has_flattened_range and not is_flattened_range_1d:
offset = offset + (0,)
size = (1,) * len(offsets_shape) + global_shape[-1:]
else:
size = offsets_shape
shard_metadata.append(ShardMetadata(offset, size, "cuda"))
tensor = some_sh_ten.data
sharded_tensor_metadata = ShardedTensorMetadata(
shards_metadata=shard_metadata,
size=torch.Size(global_shape),
tensor_properties=TensorProperties(
dtype=tensor.dtype,
layout=tensor.layout,
requires_grad=tensor.requires_grad,
memory_format=torch.contiguous_format,
pin_memory=tensor.is_pinned(),
),
)
pyt_sh_ten = TorchShardedTensor._init_from_local_shards_and_global_metadata(
local_shards, sharded_tensor_metadata=sharded_tensor_metadata, process_group=None
)
    # Store MCore-related data as a PyT ShardedTensor attribute. It is not saved in the checkpoint, only used at runtime.
pyt_sh_ten.mcore_sh_ten = sh_ten.without_data()
pyt_sh_ten.mcore_metadata = {}
if has_flattened_range and not is_flattened_range_1d:
pyt_sh_ten.mcore_metadata['nd_reformulated_orig_global_shape'] = sh_ten.global_shape
return pyt_sh_ten
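# Worked illustration of the type (3) reformulation above (editor's note): a tensor with
# global shape [X, Y] = [8, 16] sharded into local chunks of shape [x, y] = [4, 8]
# (axis_fragmentations == (2, 2)) is presented to PyT Dist as a [2, 2, 32] tensor whose
# last axis (of size x * y = 32) is partitioned according to each shard's `flattened_range`.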
def mcore_to_pyt_state_dict(
state_dict: Dict[str, List[ShardedBase]],
is_loading: bool = False,
init_device: torch.device = torch.device("cpu"),
) -> Dict[str, Union[TorchShardedTensor, io.BytesIO]]:
"""Turn state dict with ShardedTensors and ShardedObjects to state dict compatible with PyT Dist format.
Operates in-place and returns the original state dict.
Args:
state_dict (Dict[str, List[ShardedBase]]): flattened state dict, where values
are lists of either ShardedTensor or ShardedObjects.
is_loading (bool, optional): flag indicating if loading or saving. Defaults to False.
init_device (torch.device, optional): device to initialize potentially missing tensors
during loading. Defaults to 'cpu'.
Returns (Dict[str, Union[TorchShardedTensor, io.BytesIO]]): original dictionary with values
converted either into PyT ShardedTensors or io.BytesIO.
"""
rank = torch.distributed.get_rank()
pyt_state_dict = {}
def _mcore_to_torch_sharded_tensor(sh_tens: List[ShardedTensor]) -> TorchShardedTensor:
"""Build a PyT ShardedTensor from given shards.
During loading:
- if data is None, initialize it with an empty tensor (will be used to copy the data into)
- if `allow_shape_mismatch` is True, the data is initialized with zeros
prior to loading (not all parts of the tensor will be read from the checkpoint)
"""
assert all(isinstance(sh_ten, ShardedTensor) for sh_ten in sh_tens), sh_tens
for sh_ten in sh_tens:
if sh_ten.data is None:
if is_loading:
sh_ten.init_data(
init_device,
init_fn=torch.zeros if sh_ten.allow_shape_mismatch else torch.empty,
)
else:
raise CheckpointingException(f'`data` attr is None for {sh_ten}')
else:
sh_ten.data = sh_ten.data.detach()
if sh_ten.allow_shape_mismatch and is_loading:
sh_ten.data.zero_()
torch_sh_ten = sharded_tensor_to_torch_sharded_tensor(sh_tens, rank)
torch_sh_ten.key = sh_tens[0].key
return torch_sh_ten
def _mcore_to_torch_sharded_object(sh_objs: List[ShardedObject]) -> io.BytesIO:
"""Build io.BytesIO from given sharded objects data."""
assert all(isinstance(sh_obj, ShardedObject) for sh_obj in sh_objs), sh_objs
serialized_data = io.BytesIO()
torch.save([sh_obj.data for sh_obj in sh_objs], serialized_data)
return serialized_data
for k, v in state_dict.items():
if isinstance(v[0], ShardedTensor):
v = cast(List[ShardedTensor], v)
pyt_state_dict[k] = _mcore_to_torch_sharded_tensor(v)
else:
v = cast(List[ShardedObject], v)
pyt_state_dict[k] = _mcore_to_torch_sharded_object(v)
return pyt_state_dict
def _unwrap_pyt_sharded_tensor(sh_ten: TorchShardedTensor) -> List[torch.Tensor]:
""" Unwrap tensor from PyT ShardedTensor instance.
If `prepend_axis_num` was non-zero (which is specific to MCore ShardedTensor)
then the tensor has additional singleton dimensions which should be squeezed.
"""
mcore_sh_ten = sh_ten.mcore_sh_ten
ret_tensors = []
for sh in sh_ten.local_shards():
ten = sh.tensor
if mcore_sh_ten.flattened_range is not None:
assert ten.shape[:-1] == (1,) * (len(ten.shape) - 1), ten.shape
ten = ten.view(-1)
else:
for _ in range(mcore_sh_ten.prepend_axis_num):
ten = ten.squeeze(0)
ret_tensors.append(ten)
return ret_tensors
def _replace_state_dict_keys_with_sharded_keys(
sharded_state_dict: ShardedStateDict, keep_only_main_replica: bool = False
) -> Tuple[Dict[str, List[ShardedBase]], FLATTEN_MAPPING, Dict[str, List[str]]]:
"""Group ShardedBase objects by keys and return mappings required for recreating the original dict. """
flat_sd, flat_mapping = flatten_state_dict(sharded_state_dict)
rename_mapping = defaultdict(list)
new_flat_sd = defaultdict(list)
for k, sh_base in flat_sd.items():
assert isinstance(sh_base, ShardedBase), type(sh_base)
key = sh_base.unique_key if isinstance(sh_base, ShardedObject) else sh_base.key
if is_main_replica(sh_base.replica_id) or not keep_only_main_replica:
rename_mapping[key].append(k)
new_flat_sd[key].append(sh_base)
return new_flat_sd, flat_mapping, rename_mapping
def _replace_sharded_keys_with_state_dict_keys(
state_dict: Dict[str, List[Union[torch.Tensor, io.BytesIO]]],
flat_mapping: FLATTEN_MAPPING,
rename_mapping: Dict[str, List[str]],
):
""" Inverse of _replace_state_dict_keys_with_sharded_keys. """
recovered_sd = {}
for k, tensors in state_dict.items():
assert len(tensors) == len(rename_mapping[k])
for ten, recovered_k in zip(tensors, rename_mapping[k]):
recovered_sd[recovered_k] = ten
return unflatten_state_dict(recovered_sd, flat_mapping)
def _restore_dict_types(x: Union[dict, list, Any], keys_template: Union[dict, list, Any]):
""" Recursively update `x` keys, based on `keys_template`. """
if isinstance(keys_template, dict):
assert isinstance(x, dict), type(x)
for k, v in keys_template.items():
if not isinstance(k, str):
assert str(k) in x, (k, x.keys)
x[k] = x.pop(str(k))
_restore_dict_types(x[k], v)
elif isinstance(keys_template, list):
assert isinstance(x, list), type(x)
for x_val, templ_val in zip(x, keys_template):
_restore_dict_types(x_val, templ_val)
@dataclass(frozen=True)
class MCoreSavePlan(SavePlan):
mcore_data: Dict[str, Dict[str, Any]] = None # Mcore related data about each tensor
class MCoreSavePlanner(DefaultSavePlanner):
"""Differs with the default planner by saving BytesIO objects on all ranks.
In the integration of MCore with PyT Distributed format, BytesIO objects
come from ShardedObjects, which should be treated as separate objects on each rank
(not common on all ranks).
Also, the objects are already packed in io.BytesIO, so no need to redo it
in transform_object.
"""
def __init__(
self,
*args,
nd_flattened_global_shapes: Optional[Dict[str, Tuple[int, ...]]] = None,
**kwargs,
) -> None:
super().__init__(*args, **kwargs)
self.nd_flattened_global_shapes = nd_flattened_global_shapes or {}
def create_local_plan(self) -> SavePlan:
plan = create_default_local_save_plan(self.state_dict, self.is_coordinator)
self._add_non_coordinator_iobytes_request(plan)
if self.flatten_state_dict:
plan = dataclasses.replace(plan, planner_data=self.mappings)
plan = MCoreSavePlan(
items=plan.items,
storage_data=plan.storage_data,
planner_data=plan.planner_data,
mcore_data={
k: sh_ten.mcore_metadata
for k, sh_ten in self.state_dict.items()
if isinstance(sh_ten, TorchShardedTensor)
},
)
self.plan = plan
return self.plan
def create_global_plan(self, all_plans: List[MCoreSavePlan]) -> Tuple[List[SavePlan], Metadata]:
global_plan, metadata = super().create_global_plan(all_plans)
metadata.mcore_data = dict(ChainMap(*(plan.mcore_data for plan in all_plans)))
return global_plan, metadata
def _add_non_coordinator_iobytes_request(self, plan):
if self.is_coordinator:
return
for fqn, obj in self.state_dict.items():
if isinstance(obj, io.BytesIO):
plan.items.extend(_create_write_items(fqn, obj))
def transform_object(self, write_item: WriteItem, object: Any):
return object
class MCoreLoadPlanner(DefaultLoadPlanner):
"""Adds global shape validation to the default planner.
If global shape validation can be ignored (shouldn't!), the default
load planner can be used.
"""
def __init__(
self, *args, shapes_validation_sharded_tensors: Iterable[ShardedTensor] = (), **kwargs
) -> None:
super().__init__(*args, **kwargs)
self.shapes_validation_sharded_tensors = shapes_validation_sharded_tensors
def _validate_global_shapes(self, metadata, sharded_tensors):
for sh_ten in sharded_tensors:
loaded_shape = metadata.state_dict_metadata[sh_ten.key].size
if sh_ten.flattened_range is None or len(sh_ten.global_shape) == 1:
expected_shape = sh_ten.global_shape
else:
expected_shape = sh_ten.axis_fragmentations + (int(np.prod(sh_ten.local_shape)),)
if loaded_shape != expected_shape:
_msg = (
f'Global shape mismatch for loaded ({loaded_shape})'
f' and expected ({expected_shape}) tensor'
f' for key {sh_ten.key}'
)
raise CheckpointingException(_msg)
def create_local_plan(self) -> LoadPlan:
self._validate_global_shapes(self.metadata, self.shapes_validation_sharded_tensors)
return super().create_local_plan()
class TorchDistSaveShardedStrategy(AsyncSaveShardedStrategy):
"""Async save strategy for the PyT Distributed format.
The idea is to translate MCore ShardedTensors into PyT ShardedTensors
and use the async-adjusted torch.distributed.checkpoint saving mechanism
provided by the FileSystemWriterAsync writer.
"""
def __init__(
self, backend: str, version: int, keep_only_main_replica: bool = True, thread_count: int = 2
):
"""Adds parameters specific to PyT Distributed format
Args:
backend (str): format backend string
version (int): format version
keep_only_main_replica (bool, optional): PyT Distributed has a mechanism
for deduplication, but replica_id aware deduplication is more coherent.
Default is True (recommended to keep it).
thread_count (int, optional): threads to use during saving.
Affects the number of files in the checkpoint (saving ranks * num_threads).
"""
super().__init__(backend, version)
self.keep_only_main_replica = keep_only_main_replica
self.thread_count = thread_count
def async_save(
self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path
) -> AsyncRequest:
""" Translates MCore ShardedTensors to PyT ShardedTensors and saves in PyT Distributed format.
Args:
sharded_state_dict (ShardedStateDict): sharded state dict to save
checkpoint_dir (Path): checkpoint directory
Returns: None
"""
# Translate the state dict
(
sharded_state_dict,
flat_mapping,
rename_mapping,
) = _replace_state_dict_keys_with_sharded_keys(
sharded_state_dict, self.keep_only_main_replica
)
pyt_state_dict = mcore_to_pyt_state_dict(sharded_state_dict, False)
# Use PyT saving mechanism
writer = FileSystemWriterAsync(checkpoint_dir, thread_count=self.thread_count)
save_state_dict_ret = save_state_dict_async_plan(
pyt_state_dict,
writer,
None,
planner=MCoreSavePlanner(dedup_replicated_tensors=not self.keep_only_main_replica),
)
return self._get_save_and_finalize_callbacks(writer, save_state_dict_ret)
def _get_save_and_finalize_callbacks(self, writer, save_state_dict_ret) -> AsyncRequest:
save_fn_args = writer.get_save_function_and_args()
save_fn, save_args = save_fn_args
def finalize_fn():
save_state_dict_async_finalize(*save_state_dict_ret)
torch.distributed.barrier()
return AsyncRequest(save_fn, save_args, [finalize_fn])
def can_handle_sharded_objects(self):
return True
class TorchDistLoadShardedStrategy(LoadShardedStrategy):
"""Basic load strategy for the PyT Distributed format. """
def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> StateDict:
"""Translates MCore ShardedTensors to PyT ShardedTensors and loads from PyT Distributed format.
Args:
sharded_state_dict (ShardedStateDict): sharded state dict with mapping
information to instruct loading
checkpoint_dir (Path): checkpoint directory
Returns: loaded state dict
"""
flexible_shape_sharded_tensors = [
sh_ten
for sh_ten in nested_values(sharded_state_dict)
if isinstance(sh_ten, ShardedTensor) and not sh_ten.allow_shape_mismatch
]
orig_sharded_state_dict = sharded_state_dict
# MCore state dict to PyT Distributed compatible
(
sharded_state_dict,
flat_mapping,
rename_mapping,
) = _replace_state_dict_keys_with_sharded_keys(sharded_state_dict)
pyt_state_dict = mcore_to_pyt_state_dict(sharded_state_dict, True)
# Load PyT Distributed format
checkpoint.load_state_dict(
pyt_state_dict,
FileSystemReader(checkpoint_dir),
planner=MCoreLoadPlanner(
shapes_validation_sharded_tensors=flexible_shape_sharded_tensors
),
)
pyt_state_dict = cast(
Dict[str, Union[TorchShardedTensor, List[io.BytesIO]]], pyt_state_dict
)
# Unwrap ShardedTensors and return to original state dict
mcore_state_dict = {
k: v if not isinstance(v, TorchShardedTensor) else _unwrap_pyt_sharded_tensor(v)
for k, v in pyt_state_dict.items()
}
mcore_state_dict = _replace_sharded_keys_with_state_dict_keys(
mcore_state_dict, flat_mapping, rename_mapping
)
_restore_dict_types(mcore_state_dict, orig_sharded_state_dict)
return mcore_state_dict
def load_tensors_metadata(self, checkpoint_dir: Path):
"""Uses tensors metadata stored in the metadata file."""
fs_reader = FileSystemReader(checkpoint_dir)
metadata = fs_reader.read_metadata()
mcore_data = getattr(metadata, 'mcore_data', {})
sharded_metadata = {}
for k, tp in metadata.state_dict_metadata.items():
if not isinstance(tp, TensorStorageMetadata):
continue # load only tensors
nd_orig_global_shape = mcore_data.get(k, {}).get('nd_reformulated_orig_global_shape')
if nd_orig_global_shape is None:
# Regular tensor
sharded_metadata[k] = ShardedTensor.from_rank_offsets(
k, torch.empty(tp.size, **tp.properties.__dict__, device='meta'),
).without_data()
else:
# N-D flattened tensor
unflat_ten = torch.empty(
nd_orig_global_shape, **tp.properties.__dict__, device='meta'
)
flat_ten = unflat_ten.flatten()
sharded_metadata[k] = ShardedTensor.from_rank_offsets_flat(
k,
flat_ten,
unflat_ten.shape,
flattened_range=slice(0, unflat_ten.numel()), # whole slice
).without_data()
return sharded_metadata
def can_handle_sharded_objects(self):
return True
def check_backend_compatibility(self, loaded_version):
pass # TODO
def check_version_compatibility(self, loaded_version):
pass # TODO
default_strategies[StrategyAction.LOAD_SHARDED.value][
('torch_dist', 1)
] = TorchDistLoadShardedStrategy()
default_strategies[StrategyAction.SAVE_SHARDED.value][
('torch_dist', 1)
] = TorchDistSaveShardedStrategy('torch_dist', 1)
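# --- Illustrative usage (editor's sketch, not part of the original module) ---
# With torch.distributed initialized and a sharded state dict prepared by the model, saving
# and loading in the PyT Distributed format could look like the hypothetical helpers below.
# Executing the returned AsyncRequest is left to the caller's async checkpointing machinery.
def _example_torch_dist_save(
    sharded_state_dict: ShardedStateDict, checkpoint_dir: Path
) -> AsyncRequest:
    save_strategy = TorchDistSaveShardedStrategy('torch_dist', 1, thread_count=2)
    return save_strategy.async_save(sharded_state_dict, checkpoint_dir)
def _example_torch_dist_load(sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> StateDict:
    load_strategy = TorchDistLoadShardedStrategy()
    return load_strategy.load(sharded_state_dict, checkpoint_dir)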
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" 2-stage checkpoint loading. """
import os
import time
from collections import defaultdict
from dataclasses import dataclass
from functools import partial, wraps
from itertools import chain
from logging import DEBUG, INFO, StreamHandler, getLogger
from operator import attrgetter, itemgetter
from pathlib import Path
from typing import Iterable, List, NamedTuple, Optional, Tuple, Union
import torch
from ..dict_utils import dict_list_map_inplace, map_reduce, nested_values
from ..mapping import ShardedStateDict, ShardedTensor, StateDict
from .base import LoadShardedStrategy
from .tensorstore import TensorStoreLoadShardedStrategy, _load_from_array, open_ts_array
from .zarr import flatten_range, load_zarr_based_sharded_metadata
_import_trigger = None
timers = defaultdict(list)
logger = getLogger(__name__)
def timed(verbose=True):
def timed_dec(fn):
name = fn.__name__
@wraps(fn)
def wrapped(*args, **kwargs):
if verbose:
logger.debug(f'{name} init')
start = time.time()
ret = fn(*args, **kwargs)
took = time.time() - start
if verbose:
logger.debug(f'{name} took {took}s')
timers[name].append(took)
return ret
return wrapped
return timed_dec
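# Illustrative usage of the decorator above (editor's sketch; the function name is ours):
# each call's duration is appended to the module-level `timers` dict under the function
# name, and debug logs are emitted when verbose=True.
@timed(verbose=False)
def _example_timed_sleep(seconds: float = 0.01):
    time.sleep(seconds)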
@dataclass
class _ShardedTensorMetadata:
global_rank: int
sharded_tensor_no_data: ShardedTensor
    dist_group_rank: Tuple[int]  # rank within the data-parallel group
    dist_group_ranks: Tuple[int]  # global ranks forming the data-parallel group
data_size: Optional[int] = None # bytes
def sharded_tensor_chunk_id(sharded_tensor: ShardedTensor):
return (
sharded_tensor.key,
sharded_tensor.global_offset,
)
class TwoStageDataParallelLoadShardedStrategy(LoadShardedStrategy):
"""Loads one checkpoint replica from storage and broadcasts to other nodes.
This strategy loads checkpoint from storage on minimal set of nodes
and distributes the checkpoint to other nodes with torch.distributed.
Loading is performed with tensorstore.
Steps:
0. (optional) create Gloo distributed groups
1. Exchange ShardedTensors metadata between all nodes
2. Align needed tensors within DP groups
3. For each globally unique tensor:
3.a) on one of the ranks load it from storage to CPU and move to CUDA
3.b) allocate CUDA tensor on other ranks
3.c) broadcast within DP group
3.d) copy tensor content to the model param location
3.e) free tensor buffers from a) and b)
Notes:
    1. Loading and broadcasting are done sequentially to avoid both host and device OOMs
2. There is a lot of overlap potential between all three steps done for each tensor:
2.a) loading from storage to numpy
2.b) moving CPU tensors to CUDA
2.c) broadcast
"""
def __init__(self, data_parallel_group, cpu_transfer=True):
super().__init__()
self.cpu_transfer = cpu_transfer
self.data_parallel_group_orig = data_parallel_group
self.data_parallel_group = None if cpu_transfer else data_parallel_group
self.dp_group_ranks = tuple(
sorted(torch.distributed.get_process_group_ranks(data_parallel_group))
)
self.dp_group_rank = torch.distributed.get_rank(self.data_parallel_group_orig)
self.global_rank = torch.distributed.get_rank()
def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
self.maybe_init_gloo_group()
all_tensors_sorted = self._build_load_plan(sharded_state_dict)
self._exchange_loaded_tensors(all_tensors_sorted, sharded_state_dict, checkpoint_dir)
# TODO: fix hang in summarize_load_times
# self.summarize_load_times()
return sharded_state_dict
def summarize_load_times(self):
torch.distributed.barrier()
logger.info('Checkpoint loading finished. Summary:')
# TODO: `timers` keys are not guaranteed to be the same across ranks which causes hangs
for key, times in sorted(timers.items()):
times_sum = sum(times)
max_times = torch.tensor([times_sum], device='cuda')
avg_times = torch.tensor([times_sum], device='cuda')
torch.distributed.all_reduce(max_times, op=torch.distributed.ReduceOp.MAX)
torch.distributed.all_reduce(avg_times, op=torch.distributed.ReduceOp.SUM)
avg_times /= torch.distributed.get_world_size()
if torch.distributed.get_rank() == 0:
logger.info(f'{key}: max {max_times[0]}, avg {avg_times[0]}')
@timed(verbose=False)
def load_tensor_from_storage(self, checkpoint_dir, ten_meta: _ShardedTensorMetadata):
logger.debug(f'_load_from_array({ten_meta.sharded_tensor_no_data.key}) init')
ret = _load_from_array(
ten_meta.sharded_tensor_no_data,
checkpoint_dir,
load_directly_on_device=False,
apply_flattened_range=False,
)
logger.debug(f'_load_from_array({ten_meta.sharded_tensor_no_data.key}) DONE')
return ret
@timed()
def maybe_init_gloo_group(self):
if not self.cpu_transfer:
return
all_groups = [None] * torch.distributed.get_world_size()
torch.distributed.all_gather_object(all_groups, self.dp_group_ranks)
all_groups = set(tuple(sorted(gr)) for gr in all_groups)
for group_ranks in sorted(all_groups):
gloo_pg = torch.distributed.new_group(ranks=group_ranks, backend='gloo')
if self.global_rank in group_ranks:
self.data_parallel_group = gloo_pg
assert self.dp_group_rank == torch.distributed.get_rank(self.data_parallel_group)
def check_backend_compatibility(self, loaded_version):
pass # TODO
def check_version_compatibility(self, loaded_version):
pass # TODO
@timed()
def _build_load_plan(
self, sharded_state_dict: ShardedStateDict
) -> List[_ShardedTensorMetadata]:
local_meta = [
_ShardedTensorMetadata(
self.global_rank,
sharded_ten.without_data(),
self.dp_group_rank,
self.dp_group_ranks,
)
for sharded_ten in nested_values(sharded_state_dict)
]
all_meta = [None] * torch.distributed.get_world_size(group=self.data_parallel_group)
torch.distributed.all_gather_object(all_meta, local_meta, group=self.data_parallel_group)
all_meta = list(chain.from_iterable(all_meta))
all_tensors_sorted = self.deduplicate_chunks(all_meta)
return all_tensors_sorted
@timed()
def deduplicate_chunks(self, ten_metas: List[_ShardedTensorMetadata]):
""" Group tensors by chunk and then pick the tensor with the lowest rank.
NOTE: with proper loading overlap, loading from randomized ranks
(instead of the smallest one) could be beneficial here.
"""
ten_metas = map_reduce(
ten_metas,
key_fn=lambda meta: sharded_tensor_chunk_id(meta.sharded_tensor_no_data),
reduce_fn=partial(min, key=attrgetter('dist_group_rank')),
)
all_metas_sorted = list(map(itemgetter(1), sorted(ten_metas.items())))
return all_metas_sorted
@timed()
def _exchange_loaded_tensors(
self, ten_metas: List[_ShardedTensorMetadata], sharded_state_dict, checkpoint_dir
):
logger.debug(f'_exchange_loaded_tensors, num ten_metas: {len(ten_metas)}')
for ten_meta in ten_metas:
src_rank = torch.distributed.get_global_rank(
self.data_parallel_group, ten_meta.dist_group_rank
)
if self.dp_group_rank == ten_meta.dist_group_rank:
exchange_tensor = self.load_tensor_from_storage(checkpoint_dir, ten_meta)
if not self.cpu_transfer:
exchange_tensor = exchange_tensor.cuda()
else:
# TODO: for non-flattened ranges we could reuse the buffer from the start here
exchange_tensor = torch.empty(
ten_meta.sharded_tensor_no_data.local_shape,
device='cpu' if self.cpu_transfer else 'cuda',
dtype=ten_meta.sharded_tensor_no_data.dtype,
)
logger.debug(
f'exchange {ten_meta.sharded_tensor_no_data.key}, {exchange_tensor.shape}({exchange_tensor.numel()}), broadcast({src_rank} -> {self.dp_group_ranks})'
)
torch.distributed.broadcast(
exchange_tensor, group=self.data_parallel_group, src=src_rank
)
self._distribute_data_to_state_dict(ten_meta, exchange_tensor, sharded_state_dict)
logger.debug(f'exchange {ten_meta.sharded_tensor_no_data.key} done')
# free buffer memory
exchange_tensor = None
@timed(verbose=False)
def _distribute_data_to_state_dict(
self,
ten_meta: _ShardedTensorMetadata,
loaded_ten: torch.Tensor,
sharded_state_dict: ShardedStateDict,
):
tensor_key = sharded_tensor_chunk_id(ten_meta.sharded_tensor_no_data)
def _fill_in_data(t: Union[ShardedTensor, torch.Tensor]):
if not isinstance(t, ShardedTensor) or sharded_tensor_chunk_id(t) != tensor_key:
# already filled-in or key not matching
return t
sharded_tensor: ShardedTensor = t
x = loaded_ten
if sharded_tensor.flattened_range is not None:
x = flatten_range(sharded_tensor, x)
# Reuse existing buffer
sharded_tensor.data.data.copy_(x)
return sharded_tensor.data
dict_list_map_inplace(_fill_in_data, sharded_state_dict)
def load_tensors_metadata(self, checkpoint_dir: Path):
def get_ts_shape_dtype(path):
arr = open_ts_array(path)
return arr.shape, arr.dtype.numpy_dtype
return load_zarr_based_sharded_metadata(checkpoint_dir, get_ts_shape_dtype)
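# --- Illustrative usage (editor's sketch, not part of the original module) ---
# The strategy requires an initialized torch.distributed process group and a data-parallel
# group handle supplied by the caller (e.g. from the framework's parallel state). One rank
# per DP group then reads each tensor from storage and broadcasts it to the others.
def _example_two_stage_load(
    sharded_state_dict: ShardedStateDict, checkpoint_dir: Path, data_parallel_group
):
    strategy = TwoStageDataParallelLoadShardedStrategy(data_parallel_group, cpu_transfer=True)
    return strategy.load(sharded_state_dict, checkpoint_dir)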
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Strategies using Zarr as an underlying format. """
import logging
import os
import threading
from functools import partial
from logging import getLogger
from pathlib import Path
from typing import Callable, List, Optional, Tuple
import numpy as np
import torch
import zarr
from ..core import CheckpointingException
from ..dict_utils import dict_list_map_inplace, nested_values
from ..mapping import ShardedStateDict, ShardedTensor, is_main_replica
from .base import LoadShardedStrategy, SaveShardedStrategy, StrategyAction, default_strategies
logger = logging.getLogger(__name__)
numpy_to_torch_dtype_dict = {
np.dtype('bool'): torch.bool,
np.dtype('uint8'): torch.uint8,
np.dtype('int8'): torch.int8,
np.dtype('int16'): torch.int16,
np.dtype('int32'): torch.int32,
np.dtype('int64'): torch.int64,
np.dtype('float16'): torch.float16,
np.dtype('float32'): torch.float32,
np.dtype('float64'): torch.float64,
np.dtype('complex64'): torch.complex64,
np.dtype('complex128'): torch.complex128,
}
torch_to_numpy_dtype_dict = {v: k for k, v in numpy_to_torch_dtype_dict.items()}
try:
import tensorstore
HAS_BFLOAT16 = True
numpy_to_torch_dtype_dict[np.dtype('bfloat16')] = torch.bfloat16
torch_to_numpy_dtype_dict[torch.bfloat16] = np.dtype('bfloat16')
except ImportError:
HAS_BFLOAT16 = False
_import_trigger = None
logger = getLogger(__name__)
class ZarrSaveShardedStrategy(SaveShardedStrategy):
def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
sharded_tensors = list(nested_values(sharded_state_dict))
arrays = _create_or_open_zarr_arrays(sharded_tensors, checkpoint_dir)
for ten, arr in zip(sharded_tensors, arrays):
_save_to_existing_array(ten, arr)
torch.distributed.barrier()
def _create_or_open_zarr_arrays(
sharded_tensors: List[ShardedTensor], checkpoint_dir: Path
) -> List[Optional[zarr.Array]]:
""" Returns list of zarr arrays corresponding to given tensors.
For a sharded tensors that:
a) is main replica and represents the first chunk (all offsets 0), creates the Zarr array
b) is main replica but not the first chunk, opens the arrays created in (a) (possibly by other process)
c) otherwise, sets the corresponding array to None since it won't be used
Args:
sharded_tensors (List[ShardedTensor]): sharded tensors from a given rank that will be saved to checkpoint
checkpoint_dir (Path): checkpoint in which the arrays will be created
"""
arrays = []
for ten in sharded_tensors:
arr = _create_zarr_array(ten, checkpoint_dir) if _should_create_array(ten) else None
arrays.append(arr)
torch.distributed.barrier()
# Open arrays created above by other processes
for arr_idx, ten in enumerate(sharded_tensors):
if arrays[arr_idx] is not None:
# array created by this process
assert _should_create_array(ten), ten
continue
if not is_main_replica(ten.replica_id):
# this array won't be needed for saving and can stay None
continue
open_kwargs = {}
if ten.flattened_range is not None:
open_kwargs['synchronizer'] = zarr.ProcessSynchronizer(
str(checkpoint_dir / f'{ten.key}.sync')
)
arrays[arr_idx] = _open_zarr_array_verbose(checkpoint_dir / ten.key, 'r+', **open_kwargs)
return arrays
def _should_create_array(ten: ShardedTensor):
return (
is_main_replica(ten.replica_id)
and set(ten.global_offset) == {0}
and (ten.flattened_range is None or ten.flattened_range.start == 0)
)
def _save_to_existing_array(sharded_tensor: ShardedTensor, arr: Optional[zarr.Array]):
if not is_main_replica(sharded_tensor.replica_id):
return
assert arr is not None
x = sharded_tensor.data
x = x.detach().cpu()
torch.cuda.synchronize()
if x.dtype == torch.bfloat16:
x = x.float()
x = x.numpy()
x = x.astype('bfloat16')
else:
x = x.numpy()
if sharded_tensor.flattened_range is None:
arr[sharded_tensor.global_slice()] = x
else:
arr.set_coordinate_selection(sharded_tensor.global_coordinates(), x)
def _create_zarr_array(sharded_tensor: ShardedTensor, checkpoint_dir: Path):
np_dtype = torch_to_numpy_dtype_dict[sharded_tensor.dtype]
try:
arr = zarr.create(
sharded_tensor.global_shape,
dtype=np_dtype,
store=checkpoint_dir / sharded_tensor.key,
chunks=sharded_tensor.max_allowed_chunks(),
compressor=None,
fill_value=None,
write_empty_chunks=True,
)
logger.debug(f'Created a new Zarr array at {checkpoint_dir / sharded_tensor.key}')
except zarr.errors.ContainsArrayError as e:
raise CheckpointingException(
f'Array {checkpoint_dir / sharded_tensor.key} already exists'
) from e
if HAS_BFLOAT16 and np_dtype == np.dtype('bfloat16'):
arr._dtype = np_dtype
zarray = arr.store['.zarray']
arr.store['.zarray'] = zarray.replace(b'<V2', b'bfloat16')
return arr
class ZarrLoadShardedStrategy(LoadShardedStrategy):
def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
dict_list_map_inplace(
partial(_load_from_array, checkpoint_dir=checkpoint_dir), sharded_state_dict
)
return sharded_state_dict
def load_tensors_metadata(self, checkpoint_dir: Path):
def get_zarr_shape_dtype(path):
arr = zarr.open(path, 'r')
return arr.shape, arr.dtype
return load_zarr_based_sharded_metadata(checkpoint_dir, get_zarr_shape_dtype)
def check_backend_compatibility(self, loaded_version):
pass # TODO
def check_version_compatibility(self, loaded_version):
pass # TODO
def _load_from_array(sharded_tensor: ShardedTensor, checkpoint_dir: Path):
assert isinstance(sharded_tensor, ShardedTensor), type(sharded_tensor)
arr = _open_zarr_array_verbose(checkpoint_dir / sharded_tensor.key, 'r')
if not sharded_tensor.allow_shape_mismatch and sharded_tensor.global_shape != arr.shape:
_msg = (
f'Global shape mismatch for loaded ({arr.shape})'
f' and expected ({sharded_tensor.global_shape}) tensor'
f' for key {sharded_tensor.key}'
)
raise CheckpointingException(_msg)
x = arr[sharded_tensor.global_slice()] # flattened tensors loading is delayed
return postprocess_numpy_array(x, sharded_tensor)
def _open_zarr_array_verbose(path: Path, mode: str, **open_kwargs):
try:
return zarr.open(str(path), mode, **open_kwargs)
except zarr.errors.PathNotFoundError as e:
ckpt_dir = path.parent
err_msg = f'Array {path} not found'
if ckpt_dir.exists():
ckpt_files = [f.name for f in ckpt_dir.iterdir()]
logger.debug(f'{err_msg}. Checkpoint directory {ckpt_dir} content: {ckpt_files}')
else:
err_msg += f'. Checkpoint directory {ckpt_dir} does not exist.'
raise CheckpointingException(err_msg) from e
def postprocess_numpy_array(loaded_array, sharded_tensor, apply_flattened_range=True):
x = loaded_array
if HAS_BFLOAT16 and x.dtype == np.dtype('bfloat16'):
x = x.astype(np.dtype('float32'))
x = torch.from_numpy(x)
x = x.bfloat16()
else:
x = torch.from_numpy(x)
# TODO: consider some other consistency checks
if x.shape != sharded_tensor.local_shape:
if sharded_tensor.allow_shape_mismatch:
x = pad_to_expected_shape(x, sharded_tensor)
else:
_msg = (
f'Local shape mismatch for loaded ({x.shape})'
f' and expected ({sharded_tensor.local_shape}) tensor'
f' for key {sharded_tensor.key}'
)
raise CheckpointingException(_msg)
if apply_flattened_range and sharded_tensor.flattened_range is not None:
x = flatten_range(sharded_tensor, x)
# TODO: consider cuda() tensors support
return x
def flatten_range(sharded_tensor, x):
return x.flatten()[sharded_tensor.flattened_range]
def pad_to_expected_shape(x: torch.Tensor, expected_sharded_ten: ShardedTensor):
pad_args = []
assert len(x.shape) == len(expected_sharded_ten.local_shape)
    # Iterate in reversed order because F.pad expects padding for the last dimension first
for x_sh, exp_sh, axis_fragm in reversed(
list(
zip(x.shape, expected_sharded_ten.local_shape, expected_sharded_ten.axis_fragmentations)
)
):
if x_sh == exp_sh:
pad_args.extend((0, 0))
elif x_sh > exp_sh:
assert (
False
), f'Expected shape ({exp_sh}) smaller than actual ({x_sh}) for {repr(expected_sharded_ten)}'
else:
pad_args.extend((0, exp_sh - x_sh))
# TODO: behavior control with envvar is for testing purposes only, remove it
if not int(os.environ.get('DIST_CKPT_PAD_REPLICATE', 0)):
return torch.nn.functional.pad(x, pad_args)
# unsqueeze and squeeze to get shapes supported by cudnn
print(f'Replicating last row for {expected_sharded_ten.key}')
if x.dtype == torch.bfloat16:
return (
torch.nn.functional.pad(x.float().unsqueeze(0), pad_args, mode='replicate')
.squeeze(0)
.bfloat16()
)
return torch.nn.functional.pad(x.unsqueeze(0), pad_args, mode='replicate').squeeze(0)
def load_zarr_based_sharded_metadata(
checkpoint_dir: Path, get_shape_dtype_fn: Callable[[str], Tuple[Tuple[int], np.dtype]]
) -> ShardedStateDict:
"""Load metadata of Zarr arrays.
Args:
        checkpoint_dir (Path): checkpoint root directory
get_shape_dtype_fn (str -> ((int, ...), np.dtype)): a function returning
an array shape and dtype for a given Zarr array path
"""
sharded_state_dict = {}
for subdir in checkpoint_dir.iterdir():
if not subdir.is_dir() or not (subdir / '.zarray').exists():
continue
key = subdir.name
arr_shape, arr_dtype = get_shape_dtype_fn(str(subdir))
sharded_state_dict[key] = ShardedTensor(
key,
None,
numpy_to_torch_dtype_dict[arr_dtype],
arr_shape,
arr_shape,
tuple(0 for _ in arr_shape),
tuple(1 for _ in arr_shape),
)
return sharded_state_dict
# default_strategies[StrategyAction.LOAD_SHARDED.value][('zarr', 1)] = ZarrLoadShardedStrategy()
default_strategies[StrategyAction.SAVE_SHARDED.value][('zarr', 1)] = ZarrSaveShardedStrategy(
'zarr', 1
)
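# --- Illustrative usage (editor's sketch, not part of the original module) ---
# Saving is collective (the strategy ends with a torch.distributed.barrier()), and
# `sharded_state_dict` is assumed to contain ShardedTensors only. The helper name is ours.
def _example_zarr_save(sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
    save_strategy = ZarrSaveShardedStrategy('zarr', 1)
    save_strategy.save(sharded_state_dict, checkpoint_dir)
    # Metadata-only view of the written arrays (no tensor data is read back):
    return ZarrLoadShardedStrategy().load_tensors_metadata(checkpoint_dir)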
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Helpers for manipulating sharded tensors and sharded state dicts. """
from typing import Dict, Tuple
from .dict_utils import dict_list_map_inplace, extract_matching_values
from .mapping import (
LocalNonpersitentObject,
ShardedBase,
ShardedObject,
ShardedStateDict,
ShardedTensor,
ShardedTensorFactory,
StateDict,
)
def extract_sharded_tensors(
sharded_state_dict: ShardedStateDict,
) -> Tuple[ShardedStateDict, StateDict]:
""" Extract a dict consisting of only ShardedTensor objects from a given state dict with any objects.
Args:
sharded_state_dict: state dict possibly containing ShardedTensor objects
Returns:
Tuple[ShardedStateDict, StateDict]: tuple of:
- state dict with all ShardedTensor (keeping the original state dict structure)
- state dict with all objects other than ShardedTensor (keeping the original state dict structure)
"""
return extract_matching_values(sharded_state_dict, lambda v: isinstance(v, ShardedTensor))
def extract_sharded_tensors_and_factories(
sharded_state_dict: ShardedStateDict,
) -> Tuple[ShardedStateDict, StateDict]:
""" Extract a dict consisting of only ShardedTensor and ShardedTensorFactory objects from a given state dict with any objects.
Args:
sharded_state_dict: state dict possibly containing ShardedTensor and ShardedTensorFactory objects
Returns:
Tuple[ShardedStateDict, StateDict]: tuple of:
- state dict with all ShardedTensor and ShardedTensorFactory (keeping the original state dict structure)
- state dict with all other objects (keeping the original state dict structure)
"""
return extract_matching_values(
sharded_state_dict, lambda v: isinstance(v, (ShardedTensor, ShardedTensorFactory))
)
def extract_sharded_tensors_or_nonpersistent(
sharded_state_dict: ShardedStateDict,
) -> Tuple[ShardedStateDict, StateDict]:
""" Extract a dict consisting of only ShardedTensor, ShardedTensorFactory and LocalNonpersitentObject
objects from a given state dict with any objects.
Args:
sharded_state_dict: state dict possibly containing ShardedTensor, ShardedTensorFactory and LocalNonpersitentObject objects
Returns:
Tuple[ShardedStateDict, StateDict]: tuple of:
- state dict with all ShardedTensor, ShardedTensorFactory and LocalNonpersitentObject (keeping the original state dict structure)
- state dict with all other objects (keeping the original state dict structure)
"""
return extract_matching_values(
sharded_state_dict,
lambda v: isinstance(v, (ShardedTensor, LocalNonpersitentObject, ShardedTensorFactory)),
)
def extract_sharded_base(
sharded_state_dict: ShardedStateDict,
) -> Tuple[ShardedStateDict, StateDict]:
return extract_matching_values(sharded_state_dict, lambda v: isinstance(v, ShardedBase),)
def extract_nonpersistent(
sharded_state_dict: ShardedStateDict,
) -> Tuple[ShardedStateDict, StateDict]:
return extract_matching_values(
sharded_state_dict, lambda v: isinstance(v, LocalNonpersitentObject),
)
def add_prefix_for_sharding(sharded_state_dict: ShardedStateDict, prefix: str):
""" Prepend a given prefix to all ShardedBase objects in a given state dict *in-place*.
Args:
sharded_state_dict (ShardedStateDict): sharded state dict
prefix (str): prefix to be prepended
Returns:
None: state dict is modified in-place
"""
def add_prefix(t):
if isinstance(t, ShardedBase):
t.key = f'{prefix}{t.key}'
return t
dict_list_map_inplace(add_prefix, sharded_state_dict)
def replace_prefix_for_sharding(
sharded_state_dict: ShardedStateDict, old_prefix: str, new_prefix: str
):
""" Replaces the given prefix in *all* sharded keys in a given state dict.
Errors out if some key does not begin with a given prefix.
Args:
sharded_state_dict (ShardedStateDict): sharded state dict to replace keys in
old_prefix (str): prefix to be replaced in each key
new_prefix (str): new prefix
Returns:
None: state dict is modified in place
"""
def _replace_prefix(x):
if isinstance(x, (ShardedTensor, ShardedTensorFactory, ShardedObject)):
if not x.key.startswith(old_prefix):
raise ValueError(f'Expected {x.key} to begin with prefix {old_prefix}')
x.key = f'{new_prefix}{x.key[len(old_prefix):]}' # str.removeprefix in Python >= 3.9
return x
dict_list_map_inplace(_replace_prefix, sharded_state_dict)
def apply_prefix_mapping(sharded_state_dict: ShardedStateDict, prefix_map: Dict[str, str]):
""" Replaces prefixes *only in keys matching* with one of prefixes in the map.
Args:
sharded_state_dict (ShardedStateDict): sharded state dict to replace keys in
prefix_map (Dict[str, str]): map of old->new prefixes. The first matching prefix for each key is used
Returns:
None: state dict is modified in place
"""
def _replace_prefixes(x):
if not isinstance(x, (ShardedTensor, ShardedTensorFactory, ShardedObject)):
return x
for old_prefix, new_prefix in prefix_map.items():
if x.key.startswith(old_prefix):
x.key = (
f'{new_prefix}{x.key[len(old_prefix):]}' # str.removeprefix in Python >= 3.9
)
break
return x
dict_list_map_inplace(_replace_prefixes, sharded_state_dict)
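# --- Illustrative usage (editor's sketch, not part of the original module) ---
# The prefix helpers rename ShardedBase keys in-place, so the same dict object can be
# passed on afterwards. The 'model.' / 'decoder.' prefixes below are just examples.
def _example_prefix_manipulation(sharded_state_dict: ShardedStateDict):
    add_prefix_for_sharding(sharded_state_dict, 'model.')  # 'foo' -> 'model.foo'
    apply_prefix_mapping(sharded_state_dict, {'model.decoder.': 'decoder.'})
    return sharded_state_dict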
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from .distributed_data_parallel import DistributedDataParallel
from .distributed_data_parallel_config import DistributedDataParallelConfig
from .finalize_model_grads import finalize_model_grads
from .param_and_grad_buffer import ParamAndGradBuffer, shard_buffer
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import logging
from contextlib import contextmanager
from typing import Dict, Optional
import torch
from .. import parallel_state
from ..transformer.module import MegatronModule
from ..transformer.transformer_config import TransformerConfig
from ..utils import log_single_rank
from .distributed_data_parallel_config import DistributedDataParallelConfig
from .param_and_grad_buffer import ParamAndGradBuffer
logger = logging.getLogger(__name__)
class DistributedDataParallel(MegatronModule):
"""
    DDP wrapper which stores grads in contiguous buffers. Also has the option of overlapping
    communication with backprop computation by breaking up the full model's gradients into smaller
buckets and running all-reduce / reduce-scatter on each bucket asynchronously. This class
also provides the option to do the gradient accumulation in a type other than the param type
(e.g., fp32 for a bf16 model).
Args:
config: Transformer config object.
ddp_config: DistributedDataParallel config object.
module: Underlying model.
disable_bucketing: If true, force assign all parameters to a single bucket. If false,
use standard bucketing policy: assign parameters to smaller buckets and all-reduce
per bucket _if_ overlap_grad_reduce is True and pp_rank is 0.
"""
def __init__(
self,
config: TransformerConfig,
ddp_config: DistributedDataParallelConfig,
module: torch.nn.Module,
disable_bucketing: bool = False,
):
super().__init__(config=config)
self.module = module
# If bucket_size is not provided as an input, use sane default.
# If using very large dp_sizes, make buckets larger to ensure that chunks used in NCCL
# ring-reduce implementations are large enough to remain bandwidth-bound rather than
# latency-bound.
if ddp_config.bucket_size is None:
ddp_config.bucket_size = max(
40000000, 1000000 * parallel_state.get_data_parallel_world_size()
)
# Set bucket_size to infinity if overlap_grad_reduce is False.
if not ddp_config.overlap_grad_reduce:
ddp_config.bucket_size = None
self.ddp_config = ddp_config
log_single_rank(
logger,
logging.INFO,
f'Setting up DistributedDataParallel with config {self.ddp_config}',
)
# Turn off bucketing if we are on a pipeline stage that is not the first (since
# data-parallel communication on these stages is not on the critical path), or if
# disable_bucketing is True (e.g., we might not want to break up model parameters
# into buckets for model chunks after the first in the interleaved schedule).
self.bucket_size = self.ddp_config.bucket_size
if parallel_state.get_pipeline_model_parallel_rank() > 0:
self.bucket_size = None
if disable_bucketing:
self.bucket_size = None
self.module = module
self.param_to_buffer = {}
# Group parameters by their gradient type.
param_to_name = {}
dense_params = []
expert_parallel_params = []
for name, param in self.module.named_parameters():
if not param.requires_grad:
continue
param.grad_added_to_main_grad = False
param_to_name[param] = name
if getattr(param, 'allreduce', True):
dense_params.append(param)
else:
expert_parallel_params.append(param)
def allocate_buffers_for_parameters(
input_params, data_parallel_group, gradient_scaling_factor,
):
param_and_grad_dtype_to_params = {}
# Group parameters by their gradient type.
for param in input_params:
if not param.requires_grad:
continue
param_dtype = param.dtype
grad_dtype = torch.float if self.ddp_config.grad_reduce_in_fp32 else param.dtype
params = param_and_grad_dtype_to_params.get((param_dtype, grad_dtype), [])
params.append(param)
param_and_grad_dtype_to_params[(param_dtype, grad_dtype)] = params
if not config.calculate_per_token_loss:
target_gradient_scaling_factor = 1.0 / parallel_state.get_data_parallel_world_size()
if self.ddp_config.average_in_collective:
# Collective is averaging gradients in collective with data_parallel_group.
assert (
gradient_scaling_factor
/ torch.distributed.get_world_size(group=data_parallel_group)
== target_gradient_scaling_factor
)
else:
assert gradient_scaling_factor == target_gradient_scaling_factor
# Allocate the grad buffers and map the grads.
buffers = []
for (param_dtype, grad_dtype), params in param_and_grad_dtype_to_params.items():
buffers.append(
ParamAndGradBuffer(
self.ddp_config,
param_dtype,
grad_dtype,
params,
data_parallel_group,
self.bucket_size,
param_to_name,
gradient_scaling_factor,
)
)
for param in params:
self.param_to_buffer[param] = buffers[-1]
return buffers
if config.calculate_per_token_loss:
gradient_scaling_factor = 1.0
expert_gradient_scaling_factor = 1.0
else:
if self.ddp_config.average_in_collective:
gradient_scaling_factor = 1.0
expert_gradient_scaling_factor = (
1.0 / parallel_state.get_expert_model_parallel_world_size()
)
else:
data_parallel_world_size = parallel_state.get_data_parallel_world_size()
gradient_scaling_factor = 1.0 / data_parallel_world_size
expert_gradient_scaling_factor = 1.0 / data_parallel_world_size
# Allocate the param+grad buffers for dense params' grads.
self.buffers = allocate_buffers_for_parameters(
dense_params,
parallel_state.get_data_parallel_group(with_context_parallel=True),
gradient_scaling_factor=gradient_scaling_factor,
)
# Allocate separate param+grad buffers for expert parallel params' grads.
self.expert_parallel_buffers = allocate_buffers_for_parameters(
expert_parallel_params,
parallel_state.get_data_modulo_expert_parallel_group(),
gradient_scaling_factor=expert_gradient_scaling_factor,
)
# Delete references to weight_tensor if they exist since we don't want two parameter copies
# if we re-mapped parameters (which happens when we use the distributed optimizer).
# This is a temporary workaround around a TE bug that is fixed with
# https://github.com/NVIDIA/TransformerEngine/pull/719.
if self.ddp_config.use_distributed_optimizer:
@torch.no_grad()
def unmap_weight_tensor(m):
if hasattr(m, 'weight_tensor'):
m.weight_tensor = None
self.module.apply(unmap_weight_tensor)
# Register backward hook.
        # Accumulation functions for the gradients need to be stored so they
        # don't go out of scope.
self.grad_accs = []
for param in self.module.parameters():
if param.requires_grad:
# Expand so we get access to grad_fn.
param_tmp = param.expand_as(param)
# Get the gradient accumulator function.
grad_acc = param_tmp.grad_fn.next_functions[0][0]
grad_acc.register_hook(self._make_param_hook(param, self.param_to_buffer))
self.grad_accs.append(grad_acc)
def forward(self, *inputs, **kwargs):
"""
Calls the wrapped module's forward() method.
"""
return self.module(*inputs, **kwargs)
def _make_param_hook(
self,
param: torch.nn.Parameter,
param_to_buffer: Dict[torch.nn.Parameter, ParamAndGradBuffer],
):
"""
Creates the all-reduce / reduce-scatter hook for backprop.
"""
def param_hook(*unused):
if param.requires_grad:
if self.ddp_config.overlap_grad_reduce:
assert (
param.grad is not None
), 'param.grad being None is not safe when overlap_grad_reduce is True'
if param.grad is not None and (
not param.grad_added_to_main_grad or getattr(param, 'zero_out_wgrad', False)
):
param.main_grad.add_(param.grad.data)
param.grad = None
if self.ddp_config.overlap_grad_reduce:
param_to_buffer[param].register_grad_ready(param)
return param_hook
@contextmanager
def no_sync(self):
"""
Context manager that turns off gradient synchronization.
"""
for buffer in self.buffers + self.expert_parallel_buffers:
buffer.is_last_microbatch = False
try:
yield
finally:
for buffer in self.buffers + self.expert_parallel_buffers:
buffer.is_last_microbatch = True
def start_grad_sync(self, *unused):
"""
Initiates grad sync (all-reduce or reduce-scatter) communication operations
for all model gradients.
When overlap_grad_reduce is set to True, dispatches asynchronous communication
calls. When overlap_grad_reduce is set to False, calls synchronous
communication ops.
"""
for buffer in self.buffers + self.expert_parallel_buffers:
buffer.start_grad_sync()
def scale_gradients(self, scaling_factor: float) -> None:
"""Scale all gradients inside the buffers by `scaling_factor`."""
for buffer in self.buffers + self.expert_parallel_buffers:
buffer.scale_gradients(scaling_factor)
def finish_grad_sync(self):
"""
Finishes grad sync (all-reduce or reduce-scatter) communication operations
for all model gradients.
When overlap_grad_reduce is set to True, waits for asynchronous communication
calls to complete. When overlap_grad_reduce is set to False, calls synchronous
communication ops.
"""
for buffer in self.buffers + self.expert_parallel_buffers:
buffer.finish_grad_sync()
def zero_grad_buffer(self):
"""
Zeros out all grad buffers. Needs to be called at the beginning of each
training iteration.
"""
for param in self.module.parameters():
if param.requires_grad:
param.grad_added_to_main_grad = False
for buffer in self.buffers + self.expert_parallel_buffers:
buffer.reset()
def broadcast_params(self):
"""
Syncs parameters across all DP ranks.
"""
for param in self.module.parameters():
is_expert_parallel = not getattr(param, 'allreduce', True)
if is_expert_parallel:
data_parallel_group = parallel_state.get_data_modulo_expert_parallel_group()
else:
data_parallel_group = parallel_state.get_data_parallel_group(
with_context_parallel=True
)
torch.distributed.broadcast(
param.data,
src=torch.distributed.get_global_rank(data_parallel_group, 0),
group=data_parallel_group,
)
def state_dict(self, prefix='', keep_vars=False):
"""
Returns a dictionary containing references to the whole state of the
wrapped module.
Both parameters and persistent buffers (e.g. running averages) are included.
Keys are corresponding parameter and buffer names. Parameters and buffers
set to None are not included.
"""
return self.module.state_dict(prefix=prefix, keep_vars=keep_vars)
def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
"""
Returns wrapped module's state_dict for checkpoint saving.
"""
return self.module.state_dict_for_save_checkpoint(prefix=prefix, keep_vars=keep_vars)
def load_state_dict(self, state_dict, strict=True):
"""
Copies parameters and buffers from state_dict into the wrapped module and its
descendants. If strict is True, then the keys of state_dict must exactly match
the keys returned by this module’s state_dict() function.
"""
self.module.load_state_dict(state_dict, strict=strict)
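# Editorial usage sketch (hedged; not part of the library). It shows how the wrapper's
# public methods defined above fit together in a gradient-accumulation loop. The names
# `ddp_model`, `microbatches`, and `train_step` are illustrative assumptions; the methods
# called (`zero_grad_buffer`, `no_sync`, `finish_grad_sync`) are the ones defined in this file.
def _example_grad_accumulation_step(ddp_model, microbatches, train_step):
    ddp_model.zero_grad_buffer()  # zero main_grad buffers at the start of the iteration
    with ddp_model.no_sync():  # suppress grad sync for all but the last microbatch
        for microbatch in microbatches[:-1]:
            train_step(ddp_model, microbatch)
    train_step(ddp_model, microbatches[-1])  # last microbatch allows grad sync again
    ddp_model.finish_grad_sync()  # wait for (or issue) the data-parallel reduction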
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from dataclasses import dataclass
from typing import Optional
@dataclass
class DistributedDataParallelConfig:
"""Configuration for DistributedDataParallel."""
grad_reduce_in_fp32: bool = False
"""If true, reduce grads in fp32."""
overlap_grad_reduce: bool = False
"""If true, overlap grad all-reduce / reduce-scatter with backward compute."""
use_distributed_optimizer: bool = False
"""If true, issue reduce-scatter collectives to aggregate gradients and clean up
originally allocated model parameters, otherwise issue all-reduce collectives.
"""
check_for_nan_in_grad: bool = False
""" If true, check for NaNs in gradients _before_ communication collective."""
bucket_size: Optional[int] = None
"""Maximum number of parameters in each bucket. If unspecified, MCore uses a default
value of max(40000000, 1000000 * dp_size) parameters (larger DP sizes need larger
buckets to ensure collectives do not become latency-bound)."""
average_in_collective: bool = False
"""If true, compute average in collective directly, as opposed to dividing by the
dp_size first and then computing sum in the collective."""
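# Editorial illustration (hedged; not part of the library) of the `average_in_collective`
# trade-off documented above: pre-scaling local grads by 1/dp_size followed by a SUM
# collective is equivalent (up to rounding) to averaging inside the collective. The
# "collective" here is simulated with a plain Python sum over hypothetical per-rank values.
def _example_average_in_collective_equivalence():
    dp_size = 4
    per_rank_grads = [float(rank + 1) for rank in range(dp_size)]  # hypothetical local grads
    sum_of_prescaled = sum(g / dp_size for g in per_rank_grads)
    average_in_collective = sum(per_rank_grads) / dp_size
    assert abs(sum_of_prescaled - average_in_collective) < 1e-12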
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import List, Optional
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from .. import parallel_state
from ..transformer.transformer_config import TransformerConfig
from ..utils import get_attr_wrapped_model, get_model_config
def _allreduce_word_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
"""
All-reduce word embedding grads.
Reduce grads across first and last stages to ensure that word_embeddings parameters stay in
sync. This should only run for models that support pipelined model parallelism (BERT and GPT).
"""
if (
parallel_state.is_rank_in_embedding_group(ignore_virtual=True)
and parallel_state.get_pipeline_model_parallel_world_size() > 1
):
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
model_module = model[0]
elif parallel_state.is_pipeline_last_stage(ignore_virtual=True):
model_module = model[-1]
else: # We do not support the interleaved schedule for T5 yet.
model_module = model[0]
# Look for module with 'pre_process' attribute to get around the fact that DDP and
# other wrapper classes inherit from non-core MegatronModule that has
# 'share_embeddings_and_output_weights' and 'shared_embedding_or_output_weight'
# attributes already, causing get_attr_wrapped_model() to not unwrap anything here.
# TODO: Clean this up once the wrapper classes inherit from core MegatronModule.
model_module = get_attr_wrapped_model(model_module, 'pre_process', return_model_obj=True)
if model_module.share_embeddings_and_output_weights:
weight = model_module.shared_embedding_or_output_weight()
grad = weight.main_grad
torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())
def _allreduce_position_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
"""
All-reduce position_embeddings grad across first (encoder) and split (decoder) stages to
ensure that position embeddings parameters stay in sync. This should only run for T5 models
with pipeline parallelism.
"""
if (
parallel_state.is_rank_in_position_embedding_group()
and parallel_state.get_pipeline_model_parallel_world_size() > 1
and config.pipeline_model_parallel_split_rank is not None
):
model_module = model[0]
grad = get_attr_wrapped_model(
model_module, 'language_model.embedding.position_embeddings.weight.main_grad'
)
torch.distributed.all_reduce(grad, group=parallel_state.get_position_embedding_group())
def _allreduce_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig):
"""
All-reduce both word and position embeddings.
"""
_allreduce_word_embedding_grads(model, config)
_allreduce_position_embedding_grads(model, config)
def _allreduce_layernorm_grads(model: List[torch.nn.Module], config: TransformerConfig):
"""
All-reduce layernorm grads (for sequence parallelism).
"""
# All-reduce layernorm parameters across model parallel nodes
# when sequence parallelism is used
if parallel_state.get_tensor_model_parallel_world_size() > 1 and (
config.sequence_parallel or config.qk_layernorm
):
grads = []
for model_chunk in model:
for name, param in get_attr_wrapped_model(model_chunk, 'named_parameters')():
                if param.requires_grad and (
                    getattr(param, 'sequence_parallel', False)
                    or 'q_layernorm' in name
                    or 'k_layernorm' in name
                ):
grad = param.main_grad
grads.append(grad.data)
if grads:
coalesced = _flatten_dense_tensors(grads)
torch.distributed.all_reduce(
coalesced, group=parallel_state.get_tensor_model_parallel_group()
)
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)
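# Editorial sketch (hedged; not part of the library): the flatten/unflatten coalescing
# pattern used above, shown without the collective. Flattening lets a single all-reduce
# cover many small layernorm grads; the unflatten + copy_ writes the reduced values back
# into the original tensors. Doubling the coalesced tensor stands in for the all-reduce.
def _example_coalesced_copy():
    grads = [torch.ones(2, 3), torch.full((4,), 2.0)]
    coalesced = _flatten_dense_tensors(grads)
    coalesced *= 2  # placeholder for torch.distributed.all_reduce(coalesced, group=...)
    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
        buf.copy_(synced)
    assert grads[0][0, 0].item() == 2.0 and grads[1][0].item() == 4.0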
def finalize_model_grads(model: List[torch.nn.Module], num_tokens: Optional[torch.Tensor] = None):
"""
All-reduce all model grads across DP replicas, layernorm grads for sequence parallelism,
embedding grads across first and last pipeline stages (if not tied),
scale gradients by `num_tokens`.
"""
config = get_model_config(model[0])
# All-reduce / reduce-scatter across DP replicas.
if config.timers is not None:
config.timers('all-grads-sync', log_level=1).start(barrier=config.barrier_with_L1_time)
for model_chunk in model:
model_chunk.finish_grad_sync()
if config.timers is not None:
config.timers('all-grads-sync').stop()
# All-reduce layer-norm grads (for sequence parallelism).
if config.timers is not None:
config.timers('layernorm-grads-all-reduce', log_level=1).start(
barrier=config.barrier_with_L1_time
)
_allreduce_layernorm_grads(model, config)
if config.timers is not None:
config.timers('layernorm-grads-all-reduce').stop()
# All-reduce embedding grads (for pipeline parallelism).
if config.timers is not None:
config.timers('embedding-grads-all-reduce', log_level=1).start(
barrier=config.barrier_with_L1_time
)
_allreduce_embedding_grads(model, config)
if config.timers is not None:
config.timers('embedding-grads-all-reduce').stop()
    # Normalize gradients for per-token loss normalization.
    # If we are normalizing by the number of tokens, use it as the divisor; this number
    # is the total number of non-padded tokens in the global batch.
if num_tokens is not None:
# the number of tokens is only present on the last stage, so broadcast it
# to the other ranks in the pipeline parallel group.
torch.distributed.broadcast(
num_tokens,
src=parallel_state.get_pipeline_model_parallel_last_rank(),
group=parallel_state.get_pipeline_model_parallel_group(),
)
# all-reduce across DP ranks.
torch.distributed.all_reduce(num_tokens, group=parallel_state.get_data_parallel_group())
for model_chunk in model:
if num_tokens > 0:
scaling = 1.0 / num_tokens
model_chunk.scale_gradients(scaling)
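# Editorial illustration (hedged; not part of the library) of the per-token scaling above:
# with `calculate_per_token_loss`, gradients are left unscaled before the reduction
# (scaling factor 1.0) and divided once by the global count of non-padded tokens, which the
# broadcast and all-reduce above make identical on every rank. The counts are hypothetical.
def _example_per_token_scaling():
    per_dp_rank_token_counts = [100, 120, 80, 100]  # hypothetical non-padded token counts
    global_num_tokens = sum(per_dp_rank_token_counts)  # what the DP all-reduce computes
    scaling = 1.0 / global_num_tokens if global_num_tokens > 0 else 1.0
    return scaling  # each model chunk's grads would be multiplied by this factor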
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import logging
import math
import os
from enum import Enum
from typing import Dict, List, Optional
import torch
from ..utils import log_on_each_pipeline_stage
from .distributed_data_parallel_config import DistributedDataParallelConfig
logger = logging.getLogger(__name__)
class BufferType(Enum):
PARAM = 1
GRAD = 2
def shard_buffer(buffer: torch.Tensor, data_parallel_world_size: int):
"""
Shard buffer into data_parallel_world_size chunks of equal size.
"""
assert buffer.numel() % data_parallel_world_size == 0
shard_size = buffer.numel() // data_parallel_world_size
sharded_buffer = [
buffer[(r * shard_size) : ((r + 1) * shard_size)] for r in range(data_parallel_world_size)
]
return sharded_buffer
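# Editorial usage sketch (hedged; not part of the library): shard_buffer splits a flat
# buffer into equal, contiguous views, one per data-parallel rank. This requires numel to
# be divisible by the world size, which is what the padding logic in ParamAndGradBuffer
# below guarantees when the distributed optimizer is used.
def _example_shard_buffer():
    buffer = torch.arange(8, dtype=torch.float32)
    shards = shard_buffer(buffer, data_parallel_world_size=4)
    assert len(shards) == 4 and shards[1].tolist() == [2.0, 3.0]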
class Bucket:
"""
Bucket to keep track of a subset of the model's gradients. Provides functionality to register
when params in the bucket have grads ready to be synced; an asynchronous communication call
is automatically launched when _all_ params in the bucket have grads ready.
Args:
ddp_config: DistributedDataParallel config object.
params: List of parameters whose gradients are collated in this bucket.
param_data: View in larger ParamAndGradBuffer.param_data that this bucket is responsible for.
grad_data: View in larger ParamAndGradBuffer.grad_data that this bucket is responsible for.
offset: Offset of this bucket's view in the larger ParamAndGradBuffer.
numel_unpadded: Number of unpadded elements in bucket.
data_parallel_group: Data-parallel process group.
        data_parallel_world_size: World size of the data-parallel group.
        gradient_scaling_factor: Factor used to scale gradients before communication; it
            serves both to average gradients across data-parallel ranks and to scale
            gradients for the Mixture of Experts (MoE) model.
"""
def __init__(
self,
ddp_config: DistributedDataParallelConfig,
params: List[torch.nn.Parameter],
param_data: Optional[torch.Tensor],
grad_data: torch.Tensor,
offset: int,
numel_unpadded: int,
data_parallel_group: torch.distributed.ProcessGroup,
data_parallel_world_size: int,
gradient_scaling_factor: float,
):
self.ddp_config = ddp_config
# State for bookkeeping: params is the set of parameters this bucket is
# responsible for, params_with_grad is the set of parameters with grads
# available. When overlap_grad_reduce is True, communication (all-reduce
# or reduce-scatter) is issued when params_with_grad equals params.
self.params_list = params
self.params = set(params)
self.params_with_grad = set()
self.param_data = param_data
self.grad_data = grad_data
# The distributed optimizer needs to keep track of this bucket's offset
# within the full grad_buffer.
self.offset = offset
self.numel_unpadded = numel_unpadded
self.data_parallel_group = data_parallel_group
self.data_parallel_world_size = data_parallel_world_size
self.data_parallel_rank = torch.distributed.get_rank(group=data_parallel_group)
self.gradient_scaling_factor = gradient_scaling_factor
self.reset()
def reset(self):
"""
Reset metadata in bucket in preparation for the next iteration of training.
"""
self.params_with_grad = set()
self.communication_handle = None
self.is_communication_outstanding = False
def start_grad_sync(self):
"""
Initiates grad sync (all-reduce or reduce-scatter) communication operation
for this bucket.
When overlap_grad_reduce is set to True, dispatches an asynchronous
communication call. When overlap_grad_reduce is set to False, makes
synchronous call.
"""
assert (
self.communication_handle is None and not self.is_communication_outstanding
), 'Should not have multiple communication calls outstanding at once'
# Make sure norm of grads in bucket are not NaN
# prior to data-parallel all-reduce / reduce-scatter.
if self.ddp_config.check_for_nan_in_grad:
global_rank = torch.distributed.get_rank()
norm = self.grad_data.norm(p=2)
assert not norm.isnan(), (
f'Rank {global_rank}: found NaN in local grad norm in '
f'backward pass before data-parallel communication collective. '
f'Device: {torch.cuda.current_device()}, node: {os.uname()[1]}'
)
# gradient_scaling_factor already takes into account whether we are computing
# an average or sum in the data-parallel collective.
if self.gradient_scaling_factor != 1.0:
self.grad_data *= self.gradient_scaling_factor
# Decide reduce_op.
reduce_op = torch.distributed.ReduceOp.SUM
if self.ddp_config.average_in_collective:
reduce_op = torch.distributed.ReduceOp.AVG
# Use async_op only when overlap_grad_reduce is True.
if self.ddp_config.use_distributed_optimizer:
local_data_view = shard_buffer(self.grad_data, self.data_parallel_world_size)[
self.data_parallel_rank
]
self.communication_handle = torch.distributed._reduce_scatter_base(
local_data_view,
self.grad_data,
op=reduce_op,
group=self.data_parallel_group,
async_op=self.ddp_config.overlap_grad_reduce,
)
else:
self.communication_handle = torch.distributed.all_reduce(
self.grad_data,
op=reduce_op,
group=self.data_parallel_group,
async_op=self.ddp_config.overlap_grad_reduce,
)
if self.ddp_config.overlap_grad_reduce:
self.is_communication_outstanding = True
else:
self.is_communication_outstanding = False
def finish_grad_sync(self):
"""
Finishes grad sync (all-reduce or reduce-scatter) communication operation
for this bucket.
When overlap_grad_reduce is set to True, waits for asynchronous communication
call to complete. When overlap_grad_reduce is set to False, makes synchronous call.
"""
# If overlap_grad_reduce is False, start (and finish) synchronous communication call here.
if not self.ddp_config.overlap_grad_reduce:
self.start_grad_sync()
return
assert self.communication_handle is not None and self.is_communication_outstanding, (
f'Communication call has not been issued for this bucket '
f'({len(self.params_with_grad)}/{len(self.params)} params have grad available)'
)
self.communication_handle.wait()
def register_grad_ready(self, param: torch.nn.Parameter):
"""
Registers grads for the passed-in param to be "ready" for grad sync.
When the number of microbatches is greater than 1, we only want to register
grads as ready when processing the last microbatch and overlap_grad_reduce is True.
"""
assert param in self.params, 'Param is not in the bucket'
assert param not in self.params_with_grad, 'Cannot set grad twice'
assert (
self.ddp_config.overlap_grad_reduce
), 'register_grad_ready() should be called only when overlapping grad reduce'
self.params_with_grad.add(param)
# If all params in bucket have grads available, issue communication call.
if len(self.params_with_grad) == len(self.params):
self.start_grad_sync()
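# Editorial sketch (hedged; not part of the library): the readiness bookkeeping that drives
# Bucket above, in isolation and with no communication. A bucket "fires" once every one of
# its params has reported a ready gradient, which is when start_grad_sync() would be called.
class _ToyReadinessTracker:
    def __init__(self, params):
        self.params = set(params)
        self.params_with_grad = set()
        self.fired = False
    def register_grad_ready(self, param):
        assert param in self.params and param not in self.params_with_grad
        self.params_with_grad.add(param)
        if len(self.params_with_grad) == len(self.params):
            self.fired = True  # placeholder for self.start_grad_sync()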
class ParamAndGradBuffer:
"""
Groups parameters and gradients into a contiguous buffer, and then breaks the buffer into
buckets with roughly `bucket_size` parameters each.
Args:
ddp_config: DistributedDataParallel config object.
param_dtype: Type of param tensor.
grad_dtype: Type of grad tensor.
params: List of parameters whose parameters and gradients are collated in the underlying
tensor.
data_parallel_group: Data-parallel process group.
bucket_size: The rough size of each bucket in terms of number of parameters.
param_to_name: Mapping from `torch.nn.Parameter` to name (for logging purposes).
        gradient_scaling_factor: Factor used to scale gradients before communication; it
            serves both to average gradients across data-parallel ranks and to scale
            gradients for the Mixture of Experts (MoE) model.
"""
def __init__(
self,
ddp_config: DistributedDataParallelConfig,
param_dtype: torch.dtype,
grad_dtype: torch.dtype,
params: List[torch.nn.Parameter],
data_parallel_group: torch.distributed.ProcessGroup,
bucket_size: int,
param_to_name: Dict[torch.nn.Parameter, str],
gradient_scaling_factor: float,
):
self.ddp_config = ddp_config
# Check that params are unique.
unique_params = set()
for param in params:
assert param not in unique_params
unique_params.add(param)
del unique_params
# Store attributes that will be needed later.
self.param_dtype = param_dtype
self.grad_dtype = grad_dtype
self.data_parallel_group = data_parallel_group
self.data_parallel_world_size = torch.distributed.get_world_size(
group=self.data_parallel_group
)
self.gradient_scaling_factor = gradient_scaling_factor
self.is_last_microbatch = True
# Data structures to store underlying buckets and relevant indexing data.
self.buckets = []
self.param_to_bucket = {} # Param -> bucket mapping.
self.param_index_map = {} # Param -> location in buffer mapping (used in dist. optimizer).
def _pad(number_to_be_padded: int, divisor: int) -> int:
return int(math.ceil(number_to_be_padded / divisor) * divisor)
def _pad_if_needed(data_index: int) -> int:
"""
Pads data indices if using distributed optimizer (to ensure uniform sharding).
"""
if self.ddp_config.use_distributed_optimizer:
                # Workaround for a TE bug that causes cuBLAS to pick an incompatible algorithm.
# This also helps cuBLAS pick more efficient algorithms for GEMMs.
# We now ensure that all buckets start at a memory address that is 256-byte
# aligned (128 values since params and grads use >= 16-bit precision).
return _pad(data_index, math.lcm(self.data_parallel_world_size, 128))
return data_index
# First, figure out how many elements should be in the underlying buffer storage.
# Note that if we need to split the buffer into smaller buckets, each of these
# might need to be padded as well (if using the distributed optimizer).
data_start_index = 0
bucket_data_start_index = data_start_index
bucket_params = set()
self.bucket_indices = []
per_bucket_numel_unpadded = []
bucket_id = 0
def _create_new_bucket(data_end_index: int) -> int:
"""
Create the bucket_id'th bucket with collected bucket_params, starting at
bucket_data_start_index.
"""
nonlocal bucket_data_start_index, bucket_params, bucket_id
per_bucket_numel_unpadded.append(data_end_index - bucket_data_start_index)
data_end_index = _pad_if_needed(data_end_index)
# Update bucket metadata.
self.bucket_indices.append((bucket_data_start_index, data_end_index))
bucket_data_start_index = data_end_index
# Re-set bucket_params and increment bucket_id for next bucket.
bucket_params = set()
bucket_id += 1
# Return the potentially padded data_end_index.
return data_end_index
for param in params[::-1]:
# Iterate through parameters in reverse order to roughly follow backprop order,
# and skip parameters that don't require gradients.
if not param.requires_grad:
continue
this_numel = param.data.nelement()
data_end_index = data_start_index + this_numel
def _does_param_require_new_bucket(param):
"""
Split shared embedding parameters into separate bucket if using distributed
optimizer that makes use of reduce-scatters instead of all-reduces.
This ensures that the first and last pipeline stage partition optimizer state
for the shared embedding parameters the same way across DP replicas, allowing
the DP reduce-scatter to be before the embedding all-reduce.
"""
return (
getattr(param, "shared_embedding", False)
and self.ddp_config.use_distributed_optimizer
)
# Create bucket with already collected parameters if current param needs its own bucket.
if _does_param_require_new_bucket(param) and len(bucket_params) > 0:
                # We are creating a bucket for the already accumulated parameters, whose
                # data ends at the current data_start_index.
if self.ddp_config.use_distributed_optimizer:
# data_start_index should already be padded.
assert data_start_index % self.data_parallel_world_size == 0
_create_new_bucket(data_start_index)
self.param_index_map[param] = (
data_start_index,
data_end_index,
bucket_id,
)
bucket_params.add(param)
# If we have enough elements already or the current param is part of the shared embedding
# layer and needs a separate bucket, form a new bucket.
if (
bucket_size is not None
and (data_end_index - bucket_data_start_index) >= bucket_size
) or _does_param_require_new_bucket(param):
data_end_index = _create_new_bucket(data_end_index)
data_start_index = data_end_index
# Add remaining params to a new bucket.
if len(bucket_params) > 0:
data_end_index = _create_new_bucket(data_end_index)
# Next, create underlying storage for buffer (with numel elements that includes
# padding as necessary).
self.numel = data_end_index
self.numel_unpadded = sum(per_bucket_numel_unpadded)
assert self.numel_unpadded <= self.numel
if self.ddp_config.use_distributed_optimizer:
assert self.numel % self.data_parallel_world_size == 0
else:
assert self.numel == self.numel_unpadded
self.param_data = None
# Only re-map param tensors if using distributed optimizer.
if self.ddp_config.use_distributed_optimizer:
self.param_data = torch.zeros(
self.numel,
dtype=self.param_dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
self.grad_data = torch.zeros(
self.numel,
dtype=self.grad_dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
# Finally, map param.data and param.main_grad fields to buffers.
bucket_params = set()
bucket_data_start_index = 0
cur_bucket_id = 0
for param in params[::-1]:
if not param.requires_grad:
continue
data_start_index, data_end_index, bucket_id = self.param_index_map[param]
# Assign param.data to appropriate segment of self.param_data.
if self.param_data is not None:
old_param_data = param.data
param.data = self._get(
param.data.shape, data_start_index, buffer_type=BufferType.PARAM
)
assert old_param_data._base is None
# Copy tensor values (from initialization or checkpoint).
param.data.detach().copy_(old_param_data)
del old_param_data
param.main_grad = self._get(
param.data.shape, data_start_index, buffer_type=BufferType.GRAD
)
if bucket_id != cur_bucket_id:
bucket_data_end_index = _pad_if_needed(data_start_index)
self._set_bucket(
bucket_params=bucket_params,
start_index=bucket_data_start_index,
end_index=bucket_data_end_index,
numel_unpadded=per_bucket_numel_unpadded[cur_bucket_id],
bucket_id=cur_bucket_id,
)
bucket_data_start_index = bucket_data_end_index
bucket_params = set()
assert cur_bucket_id + 1 == len(self.buckets)
assert bucket_id == cur_bucket_id + 1
cur_bucket_id = bucket_id
bucket_params.add(param)
# Add remaining params to a new bucket.
if len(bucket_params) > 0:
bucket_data_end_index = _pad_if_needed(data_end_index)
self._set_bucket(
bucket_params=bucket_params,
start_index=bucket_data_start_index,
end_index=bucket_data_end_index,
numel_unpadded=per_bucket_numel_unpadded[cur_bucket_id],
bucket_id=cur_bucket_id,
)
# Log buckets for all PP stages.
log_strs = []
log_strs.append(
f'Number of buckets for gradient all-reduce / reduce-scatter: {len(self.buckets)}'
)
for index, bucket in enumerate(self.buckets):
numel = 0
for param in bucket.params:
numel += param.data.nelement()
log_strs.append(f'Params for bucket {index+1} ({numel} elements):')
for param in bucket.params:
log_strs.append(f'\t{param_to_name[param]}')
log_on_each_pipeline_stage(logger, logging.INFO, '\n'.join(log_strs))
def scale_gradients(self, scaling_factor: float) -> None:
"""Scale the gradient data by `scaling_factor`."""
self.grad_data *= scaling_factor
def _get(self, shape: torch.Size, start_index: int, buffer_type: BufferType) -> torch.Tensor:
"""
Return a tensor with the input `shape` as a view into the 1-D data starting at
`start_index`.
"""
end_index = start_index + shape.numel()
assert end_index <= self.numel, 'Requested tensor is out of buffer range'
if buffer_type == BufferType.PARAM:
assert self.param_data is not None
buffer_tensor = self.param_data[start_index:end_index]
elif buffer_type == BufferType.GRAD:
buffer_tensor = self.grad_data[start_index:end_index]
else:
            raise Exception("Illegal buffer type provided to ParamAndGradBuffer._get() function")
buffer_tensor = buffer_tensor.view(shape)
return buffer_tensor
def _set_bucket(
self,
bucket_params: List[torch.nn.Parameter],
start_index: int,
end_index: int,
numel_unpadded: int,
bucket_id: int,
):
"""
Helper function to create new bucket, add it to list of buckets, and
also update param->bucket mapping.
"""
# Assert that indices are correctly padded (if needed), and that bucket
# position is same as originally computed.
if self.ddp_config.use_distributed_optimizer:
assert start_index % self.data_parallel_world_size == 0
assert end_index % self.data_parallel_world_size == 0
assert (start_index, end_index) == self.bucket_indices[bucket_id]
# Get appropriate view into global ParamAndGradBuffer.
bucketed_param_data = None
if self.param_data is not None:
bucketed_param_data = self._get(
torch.Size([end_index - start_index]), start_index, buffer_type=BufferType.PARAM
)
bucketed_grad_data = self._get(
torch.Size([end_index - start_index]), start_index, buffer_type=BufferType.GRAD
)
bucket = Bucket(
ddp_config=self.ddp_config,
params=bucket_params,
param_data=bucketed_param_data,
grad_data=bucketed_grad_data,
offset=start_index,
numel_unpadded=numel_unpadded,
data_parallel_group=self.data_parallel_group,
data_parallel_world_size=self.data_parallel_world_size,
gradient_scaling_factor=self.gradient_scaling_factor,
)
self.buckets.append(bucket)
for bucket_param in bucket_params:
assert bucket_param not in self.param_to_bucket
self.param_to_bucket[bucket_param] = bucket
def reset(self):
"""
Zero out the underlying grad_buffer and reset all buckets in preparation for the next
iteration of training.
"""
self.grad_data.zero_()
for bucket in self.buckets:
bucket.reset()
self.is_last_microbatch = True
def start_grad_sync(self):
"""
Initiates grad sync (all-reduce or reduce-scatter) communication operations
for all buckets in the grad buffer.
When overlap_grad_reduce is set to True, dispatches asynchronous communication
calls. When overlap_grad_reduce is set to False, calls synchronous
communication ops.
"""
for bucket in self.buckets:
bucket.start_grad_sync()
def finish_grad_sync(self):
"""
Finishes grad sync (all-reduce or reduce-scatter) communication operations
for all buckets in the grad buffer.
When overlap_grad_reduce is set to True, waits for asynchronous communication
calls to complete. When overlap_grad_reduce is set to False, calls synchronous
communication ops.
"""
for bucket in self.buckets:
bucket.finish_grad_sync()
def register_grad_ready(self, param: torch.nn.Parameter):
"""
Registers grads for the passed-in param to be "ready" for grad sync.
When the number of microbatches is greater than 1, we only want to register
grads as ready when processing the last microbatch and overlap_grad_reduce is True.
"""
assert (
self.ddp_config.overlap_grad_reduce
), 'register_grad_ready() should only be called when overlap_grad_reduce is True'
if self.is_last_microbatch:
bucket = self.param_to_bucket[param]
bucket.register_grad_ready(param)
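# Editorial illustration (hedged; not part of the library) of the distributed-optimizer
# padding used above: bucket boundaries are rounded up to a multiple of
# lcm(data_parallel_world_size, 128), which keeps every DP shard equal in size and keeps
# bucket starts 128-element (>= 256-byte) aligned. The values below are hypothetical.
def _example_bucket_padding():
    dp_size = 6
    divisor = math.lcm(dp_size, 128)  # 384
    padded_end_index = int(math.ceil(1000 / divisor) * divisor)  # mirrors _pad() above
    assert padded_end_index == 1152
    assert padded_end_index % dp_size == 0 and padded_end_index % 128 == 0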
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import enum
class ModelType(enum.Enum):
encoder_or_decoder = 1
encoder_and_decoder = 2
retro_encoder = 3
retro_decoder = 4