Commit 4e867b3c authored by jerrrrry

Initial commit
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import os
from typing import Any, Dict, NamedTuple, Protocol, Tuple
import torch
try:
import boto3
import botocore.exceptions as exceptions
except ModuleNotFoundError:
pass
S3_PREFIX = "s3://"
class S3Config(NamedTuple):
"""Config when the data (.bin) file and the index (.idx) file are in S3
TODO: These parameters are few and can be consolidated with parameters specific to bin reader
classes - @jkamalu
Attributes:
path_to_idx_cache (str): The local directory where we will store the index (.idx) file
bin_chunk_nbytes (int): If the number of bytes is too small, then we send a request to S3 at
each call of the `read` method in _S3BinReader, which is slow, because each request has a
fixed cost independent of the size of the byte range requested. If the number of bytes is too
large, then we only rarely have to send requests to S3, but it takes a lot of time to complete
the request when we do, which can block training. We've found that 256 * 1024 * 1024
(i.e., 256 MiB) has worked well (though we have not put that much effort into tuning it),
so we default to it.
"""
path_to_idx_cache: str
bin_chunk_nbytes: int = 256 * 1024 * 1024
class S3Client(Protocol):
"""The protocol which all s3 clients should abide by"""
def download_file(self, Bucket: str, Key: str, Filename: str) -> None: ...
def upload_file(self, Filename: str, Bucket: str, Key: str) -> None: ...
def head_object(self, Bucket: str, Key: str) -> Dict[str, Any]: ...
def get_object(self, Bucket: str, Key: str, Range: str) -> Dict[str, Any]: ...
def close(self) -> None: ...
def is_s3_path(path: str) -> bool:
"""Ascertain whether a path is in S3
Args:
path (str): The path
Returns:
bool: True if the path is in S3, False otherwise
"""
return path.startswith(S3_PREFIX)
def parse_s3_path(path: str) -> Tuple[str, str]:
"""Parses the given S3 path returning correspsonding bucket and key.
Args:
path (str): The S3 path
Returns:
Tuple[str, str]: A (bucket, key) tuple
"""
assert is_s3_path(path)
parts = path.replace(S3_PREFIX, "").split("/")
bucket = parts[0]
if len(parts) > 1:
key = "/".join(parts[1:])
assert S3_PREFIX + bucket + "/" + key == path
else:
key = ""
return bucket, key
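# A minimal usage sketch of the two helpers above, assuming a hypothetical bucket and key;
# `my-bucket` and the object path below are illustrative only.
def _example_parse_s3_path() -> None:
    assert is_s3_path("s3://my-bucket/path/to/data.bin")
    assert not is_s3_path("/local/path/to/data.bin")
    bucket, key = parse_s3_path("s3://my-bucket/path/to/data.bin")
    assert (bucket, key) == ("my-bucket", "path/to/data.bin")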
def object_exists(client: S3Client, path: str) -> bool:
"""Ascertain whether the object at the given S3 path exists in S3
Args:
client (S3Client): The S3 client
path (str): The S3 path
Raises:
botocore.exceptions.ClientError: If the request fails with an error code other than 404
Returns:
bool: True if the object exists in S3, False otherwise
"""
parsed_s3_path = parse_s3_path(path)
try:
client.head_object(Bucket=parsed_s3_path[0], Key=parsed_s3_path[1])
except exceptions.ClientError as e:
if e.response["Error"]["Code"] == "404":
return False
raise e
return True
def _download_file(client: S3Client, s3_path: str, local_path: str) -> None:
"""Download the object at the given S3 path to the given local file system path
Args:
client (S3Client): The S3 client
s3_path (str): The S3 source path
local_path (str): The local destination path
"""
dirname = os.path.dirname(local_path)
os.makedirs(dirname, exist_ok=True)
parsed_s3_path = parse_s3_path(s3_path)
client.download_file(parsed_s3_path[0], parsed_s3_path[1], local_path)
def maybe_download_file(s3_path: str, local_path: str) -> None:
"""Download the object at the given S3 path to the given local file system path
In a distributed setting, the download proceeds in stages so that as few
processes as possible download the object while still ensuring that all
ranks have access to the downloaded object.
Args:
s3_path (str): The S3 source path
local_path (str): The local destination path
"""
if torch.distributed.is_initialized():
rank = torch.distributed.get_rank()
local_rank = rank % torch.cuda.device_count()
else:
rank = 0
local_rank = 0
s3_client = boto3.client("s3")
if (not os.path.exists(local_path)) and (rank == 0):
_download_file(s3_client, s3_path, local_path)
if torch.distributed.is_initialized():
torch.distributed.barrier()
# If the `local_path` is in a file system that is not
# shared across all the ranks, then we assume it's in the
# host file system and each host needs to download the file.
if (not os.path.exists(local_path)) and (local_rank == 0):
_download_file(s3_client, s3_path, local_path)
if torch.distributed.is_initialized():
torch.distributed.barrier()
# If the `local_path` still does not exist, then we assume
# each rank is saving to a separate location.
if not os.path.exists(local_path):
_download_file(s3_client, s3_path, local_path)
if torch.distributed.is_initialized():
torch.distributed.barrier()
assert os.path.exists(local_path)
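# A hedged usage sketch of the staged download above. The S3 URI and local cache path are
# hypothetical; running it requires boto3 and valid AWS credentials, and under
# torch.distributed it should be invoked by every rank so the barriers line up.
def _example_maybe_download_file() -> None:
    maybe_download_file("s3://my-bucket/dataset/train.idx", "/tmp/idx_cache/train.idx")
    assert os.path.exists("/tmp/idx_cache/train.idx")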
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
from .core import check_is_distributed_checkpoint
from .mapping import LocalNonpersistentObject, ShardedObject, ShardedTensor
from .serialization import (
load,
load_common_state_dict,
load_plain_tensors,
load_tensors_metadata,
remove_sharded_tensors,
save,
)
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Module for managing distributed checkpoints metadata. """
import json
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Optional
CONFIG_FNAME = 'metadata.json'
class CheckpointingException(Exception):
"""Base checkpointing related exception"""
pass
@dataclass
class CheckpointingConfig:
"""Documents backends used in the checkpoint.
Checkpoint config keeps track of formats used for storing the sharded tensors
(sharded_backend) and other objects (common_backend).
Note that versioning is not for the checkpoint content (which is application specific),
but for the checkpoint format itself.
"""
sharded_backend: str
sharded_backend_version: int = 1
common_backend: str = 'torch'
common_backend_version: int = 1
def check_is_distributed_checkpoint(checkpoint_dir):
"""Checks if `metadata.json` exists in the checkpoint and is a valid config.
Args:
checkpoint_dir: checkpoint directory
Returns:
bool: True if `metadata.json` exists in the checkpoint and is a valid config.
"""
return maybe_load_config(checkpoint_dir) is not None
def maybe_load_config(checkpoint_dir: str) -> Optional[CheckpointingConfig]:
"""Returns checkpoint config if `checkpoint_dir` is a distributed checkpoint and None otherwise
Args:
checkpoint_dir: checkpoint directory
Returns:
CheckpointingConfig (optional): None if checkpoint is not a valid distributed checkpoint
"""
config_path = Path(checkpoint_dir, CONFIG_FNAME)
if not config_path.exists():
return None
with config_path.open() as f:
config_dict = json.load(f)
return CheckpointingConfig(**config_dict)
def save_config(config: CheckpointingConfig, checkpoint_dir: str):
"""Save given config to checkpoint directory.
Args:
config: checkpoint config
checkpoint_dir: checkpoint directory
Returns:
None
"""
config_path = Path(checkpoint_dir, CONFIG_FNAME)
with config_path.open('w') as f:
json.dump(asdict(config), f)
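# A minimal round-trip sketch for the config helpers above: save a CheckpointingConfig and read
# it back. The 'torch_dist' backend name is only an illustrative string.
def _example_config_roundtrip() -> None:
    import tempfile
    with tempfile.TemporaryDirectory() as checkpoint_dir:
        save_config(CheckpointingConfig(sharded_backend='torch_dist'), checkpoint_dir)
        assert check_is_distributed_checkpoint(checkpoint_dir)
        assert maybe_load_config(checkpoint_dir).sharded_backend == 'torch_dist'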
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Utilities for operating with dicts and lists.
All functions in this module handle nesting of dicts and lists.
Other objects (e.g. tuples) are treated as atomic leaf types that cannot be traversed.
"""
from collections import defaultdict
from typing import Any, Callable, Dict, Iterable, List, Tuple, TypeVar, Union
import numpy as np
import torch
U, V = TypeVar("U"), TypeVar("V")
def extract_matching_values(
x: Union[dict, list], predicate: Callable[[Any], bool], return_lists_as_dicts: bool = False
) -> Tuple[Union[dict, list], Union[dict, list]]:
"""Return matching and nonmatching values. Keeps hierarchy.
Args:
x (Union[dict, list]) : state dict to process. Top-level argument must be a dict or list
predicate (object -> bool): determines matching values
return_lists_as_dicts (bool): if True, matching lists will be turned
into dicts, with keys indicating the indices of original elements.
Useful for reconstructing the original hierarchy.
"""
def _set_elem(target, k, v):
if return_lists_as_dicts:
target[k] = v
else:
target.append(v)
if isinstance(x, dict):
matching_vals = {}
nonmatching_vals = {}
for k, v in x.items():
if isinstance(v, (list, dict)):
match, nonmatch = extract_matching_values(v, predicate, return_lists_as_dicts)
if match:
matching_vals[k] = match
if nonmatch or not v:
nonmatching_vals[k] = nonmatch
elif predicate(v):
matching_vals[k] = v
else:
nonmatching_vals[k] = v
elif isinstance(x, list): # type: ignore
matching_vals = {} if return_lists_as_dicts else []
nonmatching_vals = {} if return_lists_as_dicts else []
for ind, v in enumerate(x):
if isinstance(v, (list, dict)) and v:
match, nonmatch = extract_matching_values(v, predicate, return_lists_as_dicts)
if match:
_set_elem(matching_vals, ind, match)
if nonmatch or not v:
_set_elem(nonmatching_vals, ind, nonmatch)
else:
target = matching_vals if predicate(v) else nonmatching_vals
_set_elem(target, ind, v)
else:
raise ValueError(f'Unexpected top-level object type: {type(x)}')
return matching_vals, nonmatching_vals
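# A small worked example of `extract_matching_values`: split a nested state dict into its
# int leaves and everything else (keys and values are made up).
def _example_extract_matching_values() -> None:
    state = {'a': 1, 'b': [2, 'x'], 'c': {'d': 3}}
    ints, rest = extract_matching_values(state, lambda v: isinstance(v, int))
    assert ints == {'a': 1, 'b': [2], 'c': {'d': 3}}
    assert rest == {'b': ['x']}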
def diff(x1: Any, x2: Any, prefix: Tuple = ()) -> Tuple[list, list, list]:
"""Recursive diff of dicts.
Args:
x1 (object): left dict
x2 (object): right dict
prefix (tuple): tracks recursive calls. Used for reporting differing keys.
Returns:
Tuple[list, list, list]: tuple of:
- only_left: Prefixes present only in left dict
- only_right: Prefixes present only in right dict
- mismatch: values present in both dicts but not equal across dicts.
For tensors equality of all elems is checked.
Each element is a tuple (prefix, type of left value, type of right value).
"""
mismatch = []
if isinstance(x1, dict) and isinstance(x2, dict):
only_left = [prefix + (k,) for k in x1.keys() - x2.keys()]
only_right = [prefix + (k,) for k in x2.keys() - x1.keys()]
for k in x2.keys() & x1.keys():
_left, _right, _mismatch = diff(x1[k], x2[k], prefix + (k,))
only_left.extend(_left)
only_right.extend(_right)
mismatch.extend(_mismatch)
elif isinstance(x1, list) or isinstance(x1, tuple) or isinstance(x1, np.ndarray):
assert type(x1) == type(x2)
only_left = list(range(len(x1) - 1, len(x2) - 1, -1))
only_right = list(range(len(x2) - 1, len(x1) - 1, -1))
for i, (v1, v2) in enumerate(zip(x1, x2)):
_left, _right, _mismatch = diff(v1, v2, prefix + (i,))
only_left.extend(_left)
only_right.extend(_right)
mismatch.extend(_mismatch)
else:
only_left = []
only_right = []
if isinstance(x1, torch.Tensor) and isinstance(x2, torch.Tensor):
if x1.device != x2.device:
_is_mismatch = not torch.all(x1.cpu() == x2.cpu())
else:
_is_mismatch = not torch.all(x1 == x2)
# TODO: change with concrete type that has both replica_id and data attrs
elif hasattr(x1, 'replica_id') and hasattr(x2, 'replica_id'):
assert type(x1) == type(x2)
only_left, only_right, mismatch = diff(
x1.data, x2.data, prefix + (type(x1),)
) # type: ignore
_is_mismatch = False
else:
try:
_is_mismatch = bool(x1 != x2)
except RuntimeError:
_is_mismatch = True
if _is_mismatch:
mismatch.append((prefix, type(x1), type(x2)))
return only_left, only_right, mismatch
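# A small worked example of `diff`: keys unique to either side are reported as prefixes and
# unequal leaves as (prefix, left type, right type) tuples.
def _example_diff() -> None:
    left = {'a': 1, 'b': {'c': 2}}
    right = {'a': 1, 'b': {'c': 3}, 'd': 4}
    only_left, only_right, mismatch = diff(left, right)
    assert only_left == [] and only_right == [('d',)]
    assert mismatch == [(('b', 'c'), int, int)]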
def inspect_types(x: Any, prefix: Tuple = (), indent: int = 4):
"""Helper to print types of (nested) dict values."""
print_indent = lambda: print(' ' * indent * len(prefix), end='')
if isinstance(x, dict):
print()
for k, v in x.items():
print_indent()
print(f'> {k}: ', end='')
inspect_types(v, prefix + (k,), indent)
elif isinstance(x, list):
print()
for i, v in enumerate(x):
print_indent()
print(f'- {i}: ', end='')
inspect_types(v, prefix + (i,), indent)
else:
if isinstance(x, torch.Tensor):
print(f'Tensor of shape {x.shape}')
else:
try:
x_str = str(x)
except Exception:
x_str = '<no string repr>'
if len(x_str) > 30:
x_str = x_str[:30] + '... (truncated)'
print(f'[{type(x)}]: {x_str}')
def nested_values(x: Union[dict, list]):
"""Returns iterator over (nested) values of a given dict or list."""
x_iter = x.values() if isinstance(x, dict) else x
for v in x_iter:
if isinstance(v, (dict, list)):
yield from nested_values(v)
else:
yield v
def nested_items_iter(x: Union[dict, list]):
"""Returns iterator over (nested) tuples (container, key, value) of a given dict or list."""
x_iter = x.items() if isinstance(x, dict) else enumerate(x)
for k, v in x_iter:
if isinstance(v, (dict, list)):
yield from nested_items_iter(v)
else:
yield x, k, v
def dict_map(f: Callable, d: dict):
"""`map` equivalent for dicts."""
for sub_d, k, v in nested_items_iter(d):
sub_d[k] = f(v)
def dict_map_with_key(f: Callable, d: dict):
"""`map` equivalent for dicts with a function that accepts tuple (key, value)."""
for sub_d, k, v in nested_items_iter(d):
sub_d[k] = f(k, v)
def dict_list_map_inplace(f: Callable[[U], V], x: Union[Dict, List, U]):
"""Maps dicts and lists *in-place* with a given function."""
if isinstance(x, dict):
for k, v in x.items():
x[k] = dict_list_map_inplace(f, v)
elif isinstance(x, list):
x[:] = (dict_list_map_inplace(f, v) for v in x)
else:
return f(x)
return x
def dict_list_map_outplace(f: Callable[[U], V], x: Union[Dict, List, U]) -> Union[Dict, List, V]:
"""Maps dicts and lists *out-of-place* with a given function."""
if isinstance(x, dict):
return {k: dict_list_map_outplace(f, v) for k, v in x.items()}
elif isinstance(x, list):
return [dict_list_map_outplace(f, v) for v in x]
else:
return f(x)
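# A quick illustration of the out-of-place variant: the result is a new structure and the
# input is left untouched, unlike `dict_list_map_inplace`.
def _example_dict_list_map_outplace() -> None:
    nested = {'layers': [{'w': 1}, {'w': 2}]}
    doubled = dict_list_map_outplace(lambda v: v * 2, nested)
    assert doubled == {'layers': [{'w': 2}, {'w': 4}]}
    assert nested == {'layers': [{'w': 1}, {'w': 2}]}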
def merge(x1: Union[dict, list], x2: Union[dict, list], key: Tuple[Union[str, int], ...] = ()):
"""Merges dicts and lists recursively."""
if isinstance(x1, dict) and isinstance(x2, dict):
for k, v2 in x2.items():
if k not in x1:
x1[k] = v2
else:
x1[k] = merge(x1[k], v2, key=key + (k,))
elif isinstance(x1, list) and isinstance(x2, list):
if len(x1) != len(x2):
raise ValueError(
f'Cannot merge two lists with different lengths ({len(x1)} and {len(x2)}, '
f'encountered at level {key})'
)
for i, v2 in enumerate(x2):
x1[i] = merge(x1[i], v2, key=key + (i,))
else:
raise ValueError(
f'Duplicate non-dict and non-list values encountered: `{x1}` and `{x2}` '
f'(at level {key})'
)
return x1
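# A quick illustration of `merge`: nested dicts are combined in place (and returned), while
# conflicting leaf values raise, so the inputs must not share leaf keys.
def _example_merge() -> None:
    x1 = {'optimizer': {'lr': 0.1}}
    x2 = {'optimizer': {'wd': 0.01}, 'iteration': 5}
    merged = merge(x1, x2)
    assert merged is x1
    assert merged == {'optimizer': {'lr': 0.1, 'wd': 0.01}, 'iteration': 5}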
def map_reduce(
xs: Iterable,
key_fn: Callable = lambda x: x,
value_fn: Callable = lambda x: x,
reduce_fn: Callable = lambda x: x,
) -> dict:
"""Simple map-reduce implementation following `more_itertools.map_reduce` interface."""
res = defaultdict(list)
for x in xs:
res[key_fn(x)].append(value_fn(x))
for k in res:
res[k] = reduce_fn(res[k])
return dict(res)
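# A quick illustration of `map_reduce`: group integers by parity, scale them, and reduce each
# group with `sum`.
def _example_map_reduce() -> None:
    groups = map_reduce(range(6), key_fn=lambda x: x % 2, value_fn=lambda x: x * 10, reduce_fn=sum)
    assert groups == {0: 60, 1: 90}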
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
"""Utilities for exchanging data between ranks."""
import logging
from collections import defaultdict
from functools import reduce
from itertools import zip_longest
from typing import Any, Dict, List, NamedTuple, Optional, Set, Tuple, TypeVar, cast
import numpy as np
import torch
from .core import CheckpointingException
from .dict_utils import nested_values
from .mapping import ShardedStateDict, ShardedTensor, is_main_replica
from .utils import _sharded_tensor_shard_id, _ShardId, debug_time
# TODO: remove TE references once the TE bug is fixed
# Check if Transformer Engine has Float8Tensor class
HAVE_TE_FLOAT8TENSOR = False
try:
from transformer_engine.pytorch.float8_tensor import Float8Tensor
HAVE_TE_FLOAT8TENSOR = True
except (ImportError, ModuleNotFoundError):
# Float8Tensor not found
pass
def is_float8tensor(tensor: torch.Tensor) -> bool:
"""Check if a tensor is a Transformer Engine Float8Tensor"""
return HAVE_TE_FLOAT8TENSOR and isinstance(tensor, Float8Tensor)
logger = logging.getLogger(__name__)
class ShardDistribution(NamedTuple):
"""Represents a distribution of ShardedTensors.
Given distribution is valid only for a specific parallelization group,
which is implicit here (not referenced by this class).
Args:
main_rank_for_shard (Dict[_ShardId, int]): specifies which rank should hold
the main replica for a given shard
shards_in_this_group (Set[_ShardId]): which shards have a main replica
in this parallelization group
shard_to_metadata (Dict[_ShardId, ShardedTensor]): maps ShardedTensor
identifier to the original ShardedTensor
all_ranks_for_shard (Dict[_ShardId, List[int]]): specifies which ranks
need a given shard in a given parallelization group
"""
main_rank_for_shard: Dict[_ShardId, int]
shards_in_this_group: Set[_ShardId]
shard_to_metadata: Dict[_ShardId, ShardedTensor]
all_ranks_for_shard: Dict[_ShardId, List[int]]
def _shard_size(sh_ten: ShardedTensor):
"""Returns size in bytes of a given sharded tensor."""
if sh_ten.flattened_range is None:
numel = np.prod(sh_ten.local_shape)
else:
numel = sh_ten.flattened_range.stop - sh_ten.flattened_range.start
return numel * torch._utils._element_size(sh_ten.dtype)
def _get_empty_tensor_for_exchange(
shard_id: _ShardId,
needed_shards: Dict[_ShardId, ShardedTensor],
unneeded_shards: Dict[_ShardId, ShardedTensor],
loaded_tensors: Dict[_ShardId, torch.Tensor],
) -> Tuple[torch.Tensor, Optional[torch.device]]:
"""Determines the empty tensor to use for exchange.
If shard_id is needed by this rank, it will be in `needed_shards`.
Otherwise, the metadata for this tensor can be found in `unneeded_shards`.
Args:
shard_id (_ShardId): shard_id that will be exchanged
needed_shards (Dict[_ShardId, ShardedTensor]): mapping from shard ids
to metadata for shards needed by this rank
unneeded_shards (Dict[_ShardId, ShardedTensor]): mapping from shard ids
to metadata for shards that can be discarded after exchange
loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping where useful tensors
are placed in
Returns:
Tuple[torch.Tensor, Optional[torch.device]]: empty CUDA tensor to be exchanged,
and the device of the original state dict tensor (if there was any)
"""
local_unloaded_sh_ten = needed_shards.get(shard_id)
if local_unloaded_sh_ten is None:
orig_device = None # this tensor will be discarded anyway
sh_ten = unneeded_shards[shard_id]
if sh_ten.data is None:
sh_ten.init_data('cuda')
tensor = sh_ten.data
sh_ten.data = None # won't be used. free memory
else:
tensor = sh_ten.data
if tensor.device.type == 'cpu':
tensor = torch.empty_like(tensor, device='cuda')
else:
local_unloaded_sh_ten.init_data('cuda')
orig_device = local_unloaded_sh_ten.data.device
tensor = local_unloaded_sh_ten.data
if tensor.device.type == 'cpu':
tensor = torch.empty_like(tensor, device='cuda')
loaded_tensors[shard_id] = tensor
return tensor, orig_device
T = TypeVar('T')
def distribute_shards_to_ranks(
shard_to_ranks: Dict[T, List[int]], shard_to_size: Dict[T, int], num_ranks: int
) -> Dict[T, int]:
"""Computes uniform distribution of workload across ranks, based on sizes.
Currently, the assignment is greedy, based on:
1. Firstly, the coverage of each shard
(how many ranks the shard is available on; lower coverage is assigned first)
2. Secondly, the size of each shard (larger size is assigned first)
3. Finally, shard id for differentiation.
The third step is added because we rely on the fact that
the assignment is deterministic on all ranks.
Args:
shard_to_ranks (Dict[T, List[int]]): mapping of rank access to shards
shard_to_size (Dict[T, int]): sizes of each shard
num_ranks (int): number of ranks in the parallelization group
Returns (Dict[T, int]): assignment of shard to rank (which rank should do the work
to achieve maximal uniformity)
"""
shard_to_ranks = {k: tuple(v) for k, v in shard_to_ranks.items()}
shard_to_saving_rank = {}
rank_sizes = [(0, rank) for rank in range(num_ranks)]
# start from tensors of lowest coverage, then go by tensor size from largest (hence minus size)
for shard_id, shard_ranks in sorted(
shard_to_ranks.items(),
key=lambda sh_id_ranks: (
len(sh_id_ranks[1]),
-shard_to_size[sh_id_ranks[0]],
sh_id_ranks[0],
),
):
# assign greedily to the least occupied rank
size, rank = min((size, rank) for size, rank in rank_sizes if rank in shard_ranks)
shard_to_saving_rank[shard_id] = rank
rank_sizes[rank] = (size + shard_to_size[shard_id], rank)
logger.debug(f'distribute_shards_to_ranks distribution: {rank_sizes}')
return shard_to_saving_rank
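# A small worked example of the greedy assignment above, with three hypothetical shards and
# two ranks: 'bias' has coverage 1 so it must go to rank 1, and the two large shards are then
# placed to balance the per-rank byte totals.
def _example_distribute_shards_to_ranks() -> None:
    shard_to_ranks = {'w1': [0, 1], 'w2': [0, 1], 'bias': [1]}
    shard_to_size = {'w1': 100, 'w2': 80, 'bias': 10}
    assignment = distribute_shards_to_ranks(shard_to_ranks, shard_to_size, num_ranks=2)
    assert assignment == {'bias': 1, 'w1': 0, 'w2': 1}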
def determine_main_replica_uniform_distribution(
sharded_state_dict: ShardedStateDict,
parallelization_group: torch.distributed.ProcessGroup,
ignore_groups: bool = False,
) -> Optional[ShardDistribution]:
"""Computes the save distribution.
Should be used in conjunction with `distribute_main_replicas_with_precomputed_distribution`
which applies the computed save distribution.
We rely on the fact that the assignment algorithm is deterministic on all ranks,
so there is no extra communication needed after metadata exchange.
Args:
sharded_state_dict (ShardedStateDict): state dict to compute the distribution of
parallelization_group (ProcessGroup): distribution will be computed
within this process group
ignore_groups (bool, optional): whether the distribution defines groups.
This option is primarily used during loading, as it ensures that all replicas,
including non-main ones, are loaded by this parallelization group.
Defaults to False.
Returns (ShardDistribution, optional): distribution that can be used to apply the
parallelization. Returns None if the process_group is trivial (1 rank)
"""
group_size = torch.distributed.get_world_size(group=parallelization_group)
if group_size <= 1:
return
local_shards = list(
sh_base
for sh_base in nested_values(sharded_state_dict)
if isinstance(sh_base, ShardedTensor)
)
local_shards_no_data = [ten.without_data() for ten in local_shards]
all_shards = [None] * torch.distributed.get_world_size(group=parallelization_group)
torch.distributed.all_gather_object(
all_shards, local_shards_no_data, group=parallelization_group
)
shard_to_ranks = defaultdict(list)
shard_to_size = {}
shard_to_metadata = {}
shards_in_this_parallelization_group: Set[_ShardId] = set()
for rank, rank_shards in enumerate(all_shards):
for sh_ten in rank_shards:
shard_id = _sharded_tensor_shard_id(sh_ten)
shard_to_ranks[shard_id].append(rank)
if shard_id not in shard_to_size:
shard_to_size[shard_id] = _shard_size(sh_ten)
shard_to_metadata[shard_id] = sh_ten
if is_main_replica(sh_ten.replica_id) or ignore_groups:
shards_in_this_parallelization_group.add(shard_id)
shard_to_ranks = {
k: v for k, v in shard_to_ranks.items() if k in shards_in_this_parallelization_group
}
shard_to_saving_rank = distribute_shards_to_ranks(
shard_to_ranks, shard_to_size, len(all_shards)
)
return ShardDistribution(
shard_to_saving_rank,
shards_in_this_parallelization_group,
shard_to_metadata,
shard_to_ranks,
)
@torch.no_grad()
@debug_time(f"exchange_loaded_tensors_gather_rounds", logger)
def exchange_loaded_tensors_gather_rounds(
loaded_tensors: Dict[_ShardId, torch.Tensor],
unloaded_shards: Dict[_ShardId, ShardedTensor],
shard_distribution: Optional[ShardDistribution] = None,
parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
) -> Dict[_ShardId, torch.Tensor]:
"""Exchange the tensors loaded by different ranks with several all_gather calls.
Groups tensors by dtype, divides the tensors to be exchanged into rounds,
and executes an all_gather for the tensors of each round.
Note: the loading is distributed across ranks based on total loaded size
in bytes, so there is no guarantee that number of rounds needed for each
rank will be similar, which might result in a lot of almost empty
all_gathers. The solution would be to group all tensors into a single
byte tensor and do a single all_gather (with similarly sized messages).
Args:
loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor
shard ids to tensors already loaded by this rank.
unloaded_shards (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor
shard ids to ShardedTensors that aren't loaded yet.
shard_distribution (ShardDistribution): distribution of all shards
parallelization_group (ProcessGroup, optional): process group used for load
distribution. Tensors will be exchanged within this group
Returns:
Dict[_ShardId, torch.Tensor]: dictionary mapping shard ids to tensors
needed by this rank to load a given state dict. Includes
previously loaded tensors (from `loaded_tensors` input)
"""
main_rank_for_shard, _, shard_to_metadata, all_ranks_for_shard = shard_distribution
local_rank = torch.distributed.get_rank(group=parallelization_group)
all_loaded_tensors = dict(loaded_tensors)
# Group by dtype so that we all_gather tensors of the same dtype
for dtype in sorted(set(map(lambda sh_ten: sh_ten.dtype, shard_to_metadata.values())), key=str):
with debug_time(f"dtype_{dtype}"):
# shards_by_rank maps each rank to the shard ids loaded by that rank
shards_by_rank: List[List[_ShardId]] = [
[] for _ in range(torch.distributed.get_world_size(group=parallelization_group))
]
for shard_id, rank in main_rank_for_shard.items():
if len(all_ranks_for_shard[shard_id]) == 1:
assert all_ranks_for_shard[shard_id][0] == main_rank_for_shard[shard_id], (
f'When there is only 1 rank that needs a given shard,'
f' it should be the loading rank.'
f' Got: needs [{all_ranks_for_shard[shard_id][0]}]'
f' vs loads [{main_rank_for_shard[shard_id]}]'
)
# Skipping the exchange since only the loading rank needs this tensor
# TODO: we can employ some optimizations even for `len(shard_to_ranks) > 1`
# case, e.g. P2P exchange. Currently handling this case saves most of the
# work though.
continue
if shard_to_metadata[shard_id].dtype == dtype:
shards_by_rank[rank].append(shard_id)
# Transpose `shards_by_rank` to form exchange rounds
shards_by_round = zip_longest(*shards_by_rank, fillvalue=None)
for round_idx, round_shard_ids in enumerate(shards_by_round):
round_tensors = []
orig_devices = {}
for rank, shard_id in enumerate(round_shard_ids):
if shard_id is None:
# if no more useful data, the given rank will exchange empty tensor
local_ten = torch.empty(0, dtype=dtype, device='cuda')
orig_device = None
else:
assert isinstance(shard_id, tuple), type(shard_id)
if rank == local_rank:
assert shard_id in all_loaded_tensors, (
shard_id,
all_loaded_tensors.keys(),
)
orig_device = all_loaded_tensors[shard_id].device
all_loaded_tensors[shard_id] = all_loaded_tensors[shard_id].cuda()
local_ten = all_loaded_tensors[shard_id]
else:
local_ten, orig_device = _get_empty_tensor_for_exchange(
shard_id, unloaded_shards, shard_to_metadata, all_loaded_tensors
)
# Because of a TE bug, we have to exchange a nominal dtype instead of FP8
# It's ok to keep the nominal dtype after exchange, because TE will handle
# this during state dict load.
# TODO: remove it once the bug is fixed
if is_float8tensor(local_ten):
try:
local_ten = local_ten.from_float8()
except Exception as e:
local_ten = local_ten.dequantize()
all_loaded_tensors[shard_id] = local_ten
round_tensors.append(local_ten)
if orig_device is not None:
orig_devices[shard_id] = orig_device
torch.distributed.all_gather(
list(round_tensors),
round_tensors[local_rank],
group=parallelization_group,
async_op=False,
)
# Move tensors back to CPU if originally was on CPU
for shard_id, orig_device in orig_devices.items():
all_loaded_tensors[shard_id] = all_loaded_tensors[shard_id].to(orig_device)
del round_tensors # remove tensor references
return all_loaded_tensors
def exchange_loaded_tensors_gather_object(
loaded_tensors: Dict[_ShardId, torch.Tensor],
unloaded_shards: Dict[_ShardId, ShardedTensor],
shard_distribution: ShardDistribution,
parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
) -> Dict[_ShardId, torch.Tensor]:
"""Exchange the tensors loaded by different ranks with a simple all_gather_object call.
This version can be used for debugging purposes due to its simplistic
implementation. Shouldn't be used if performance is important.
Args:
loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor
shard ids to tensors already loaded by this rank.
unloaded_shards (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor
shard ids to ShardedTensors that aren't loaded yet.
shard_distribution (ShardDistribution): distribution of all shards
parallelization_group (ProcessGroup, optional): process group used for load
distribution. Tensors will be exchanged within this group
Returns:
Dict[_ShardId, torch.Tensor]: dictionary mapping shard ids to tensors
needed by this rank to load a given state dict. Includes
previously loaded tensors (from `loaded_tensors` input)
"""
all_loaded_tensors_list = [None] * torch.distributed.get_world_size(group=parallelization_group)
torch.distributed.all_gather_object(
all_loaded_tensors_list, loaded_tensors, group=parallelization_group
)
all_loaded_tensors_list = cast(List[Dict[_ShardId, torch.Tensor]], all_loaded_tensors_list)
all_loaded_tensors = reduce(lambda x, y: {**x, **y}, all_loaded_tensors_list)
# Error checks
if len(all_loaded_tensors) != sum(map(len, all_loaded_tensors_list)):
err_msg = 'Duplicate shard ids loaded by different ranks'
if torch.distributed.get_rank() == 0:
logger.error(
f'{err_msg}. Shard ids by rank:'
f' {[lt.keys() for lt in all_loaded_tensors_list]}'
)
raise CheckpointingException(err_msg)
return all_loaded_tensors
def exchange_loaded_objects_gather_object(
loaded_objects: Dict[_ShardId, Any]
) -> Dict[_ShardId, Any]:
"""Exchange the objects loaded by different ranks with a simple all_gather_object call.
Args:
loaded_objects (Dict[_ShardId, Any]): mapping from shard ids to objects
already loaded by this rank.
Returns:
Dict[_ShardId, Any]: dictionary mapping shard ids to objects needed by this rank to
load a given state dict.
"""
all_loaded_objects_list = [None] * torch.distributed.get_world_size(group=None)
torch.distributed.all_gather_object(all_loaded_objects_list, loaded_objects, group=None)
all_loaded_objects_list = cast(List[Dict[_ShardId, Any]], all_loaded_objects_list)
all_loaded_objects = reduce(lambda x, y: {**x, **y}, all_loaded_objects_list)
# Error checks
if len(all_loaded_objects) != sum(map(len, all_loaded_objects_list)):
err_msg = 'Duplicate shard ids loaded by different ranks'
if torch.distributed.get_rank() == 0:
logger.error(
f'{err_msg}. Shard ids by rank:'
f' {[lt.keys() for lt in all_loaded_objects_list]}'
)
raise CheckpointingException(err_msg)
return all_loaded_objects
@torch.no_grad()
@debug_time("exchange_loaded_tensors_broadcast", logger)
def exchange_loaded_tensors_broadcast(
loaded_tensors: Dict[_ShardId, torch.Tensor],
unloaded_shards: Dict[_ShardId, ShardedTensor],
shard_distribution: ShardDistribution,
parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
) -> Dict[_ShardId, torch.Tensor]:
"""Exchange the tensors loaded by different ranks by a series of broadcasts.
For each rank for each loaded tensor do a broadcast to the whole group.
A reasonable tradeoff in terms of performance and simplicity.
Args:
loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor
shard ids to tensors already loaded by this rank.
unloaded_shards (Dict[_ShardId, ShardedTensor]): mapping from ShardedTensor
shard ids to ShardedTensors that aren't loaded yet.
shard_distribution (ShardDistribution): distribution of all shards
parallelization_group (ProcessGroup, optional): process group used for load
distribution. Tensors will be exchanged within this group
Returns:
Dict[_ShardId, torch.Tensor]: dictionary mapping shard ids to tensors
needed by this rank to load a given state dict. Includes
previously loaded tensors (from `loaded_tensors` input)
"""
main_rank_for_shard, _, shard_to_metadata, all_ranks_for_shard = shard_distribution
local_rank = torch.distributed.get_rank(group=parallelization_group)
all_loaded_tensors = dict(loaded_tensors)
for idx, (shard_id, rank) in enumerate(main_rank_for_shard.items()):
if len(all_ranks_for_shard[shard_id]) == 1:
assert all_ranks_for_shard[shard_id][0] == main_rank_for_shard[shard_id], (
f'When there is only 1 rank that needs a given shard,'
f' it should be the loading rank.'
f' Got: needs [{all_ranks_for_shard[shard_id][0]}]'
f' vs loads [{main_rank_for_shard[shard_id]}]'
)
# Skipping the exchange since only the loading rank needs this tensor
# TODO: we can employ some optimizations even for `len(shard_to_ranks) > 1` case,
# e.g. P2P exchange. Currently handling this case saves most of the work though.
continue
if rank == local_rank:
assert shard_id in all_loaded_tensors, (shard_id, all_loaded_tensors.keys())
orig_device = all_loaded_tensors[shard_id].device
local_ten = all_loaded_tensors[shard_id].cuda()
else:
local_ten, orig_device = _get_empty_tensor_for_exchange(
shard_id, unloaded_shards, shard_to_metadata, all_loaded_tensors
)
# Because of a TE bug, we have to exchange a nominal dtype instead of FP8
# It's ok to keep the nominal dtype after exchange, because TE will handle
# this during state dict load.
# TODO: remove it once the bug is fixed
if is_float8tensor(local_ten):
try:
local_ten = local_ten.from_float8()
except Exception as e:
local_ten = local_ten.dequantize()
all_loaded_tensors[shard_id] = local_ten
global_src_rank = (
rank
if parallelization_group is None
else torch.distributed.get_global_rank(parallelization_group, rank)
)
# We can do async_op=True only if there is no CPU-copy follow-up
torch.distributed.broadcast(
local_ten,
src=global_src_rank,
group=parallelization_group,
async_op=orig_device is None,
)
# Move tensor back to CPU if originally was on CPU
if orig_device is not None:
all_loaded_tensors[shard_id] = local_ten.to(orig_device)
del local_ten
return all_loaded_tensors
def exchange_by_distribution(
loaded_tensors: Dict[_ShardId, torch.Tensor],
unloaded_shards: Dict[_ShardId, ShardedTensor],
shard_distribution: ShardDistribution,
parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
exchange_algo='broadcast',
) -> Dict[_ShardId, torch.Tensor]:
"""Exchange tensors loaded by different ranks using the specified exchange_algo.
Args:
loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor
shard ids to tensors already loaded by this rank.
unloaded_shards (Dict[_ShardId, ShardedTensor]): mapping from ShardedTensor
shard ids to ShardedTensors that aren't loaded yet.
shard_distribution (ShardDistribution): distribution of all shards
parallelization_group (ProcessGroup, optional): process group used for load
distribution. Tensors will be exchanged within this group
exchange_algo (str): The algorithm used for performing exchanges.
Defaults to 'broadcast'.
Returns:
Dict[_ShardId, torch.Tensor]: dictionary mapping shard ids to tensors
needed by this rank to load a given state dict. Includes
previously loaded tensors (from `loaded_tensors` input)
"""
assert shard_distribution is not None, 'Expecting distribution to perform exchange'
if exchange_algo == 'gather_object':
exchange_fn = exchange_loaded_tensors_gather_object
elif exchange_algo == 'gather_rounds':
exchange_fn = exchange_loaded_tensors_gather_rounds
elif exchange_algo == 'broadcast':
exchange_fn = exchange_loaded_tensors_broadcast
else:
raise NotImplementedError(f'Unrecognized gather algorithm: {exchange_algo}')
return exchange_fn(loaded_tensors, unloaded_shards, shard_distribution, parallelization_group)
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Core library classes for representing sharding of tensors and objects.
The main expected usage is wrapping torch.Tensors in state dicts with
ShardedTensor class (mostly with the ShardedTensor.from_rank_offsets classmethod).
"""
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, field, replace
from itertools import chain
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from .core import CheckpointingException
from .dict_utils import dict_list_map_inplace
logger = logging.getLogger(__name__)
# These type definitions are just hints to differentiate a plain model state
# dict (StateDict) from a state dict with tensors replaced with ShardedTensors
# (ShardedStateDict).
StateDict = Dict[str, Any]
CommonStateDict = Dict[str, Any]
ShardedStateDict = Dict[str, Any]
ReplicaId = Union[int, Tuple[int, ...]]
class ShardedBase(ABC):
"""Base class for ShardedTensor and ShardedStateDict."""
key: str
data: object
replica_id: ReplicaId
@abstractmethod
def validate_metadata_integrity(self):
"""Codifies the constraints on metadata attributes."""
@abstractmethod
def without_data(self) -> 'ShardedBase':
"""Returns a new ShardedBase instance with data=None."""
raise NotImplementedError
@dataclass
class ShardedTensor(ShardedBase):
"""Represents a mapping between a local tensor and a global tensor.
Global tensor is assumed to consist of many local tensors distributed
between different processes.
Args:
key: unique identifier of a global tensor
data: local tensor data. Can be None only for consistency validation
dtype: tensor dtype
local_shape: local tensor shape
global_shape: global tensor shape
global_offset: offset of a local tensor in a global tensor,
specified in number of tensor elements
axis_fragmentations: global tensor fragmentation of each axis
replica_id: indicates given local tensor's replication wrt.
local tensors in different processes
prepend_axis_num: number of axes prepended to the local tensor to
reflect global tensor shape. The behavior is similar to
unsqueezing the local tensor.
allow_shape_mismatch: if True, during loading, the global shape of
a stored tensor does not have to match the expected global shape.
Useful for representing tensors with flexible shape,
e.g. padded.
flattened_range: specifies a slice that should be applied to a
flattened tensor with `local_shape` in order to get
the tensor stored as `data`
"""
key: str
data: Optional[torch.Tensor] = field(repr=False)
dtype: torch.dtype
local_shape: Tuple[int, ...]
global_shape: Tuple[int, ...]
global_offset: Tuple[int, ...]
axis_fragmentations: Optional[Tuple[int, ...]]
replica_id: ReplicaId = 0
prepend_axis_num: int = 0
allow_shape_mismatch: bool = False
flattened_range: Optional[slice] = None
def __post_init__(self):
self.validate_metadata_integrity()
def validate_metadata_integrity(self) -> None:
"""Codifies the constraints on metadata attributes.
Meeting those constraints is guaranteed when instantiating a ShardedTensor
class with `from_rank_offsets` or `from_rank_offsets_flat` constructors.
Returns:
None
"""
has_flattened_range = self.flattened_range is not None
if self.data is not None:
if self.data.dtype != self.dtype:
raise CheckpointingException(
f'Data dtype should match `dtype` attribute for {self}'
)
if not has_flattened_range and self.data.shape != self.local_shape:
raise CheckpointingException(
f'Data shape should match `local_shape` attribute for {self}'
)
if has_flattened_range:
if self.data.ndim != 1:
raise CheckpointingException(f'Data should be 1D for a flattened {self}')
real_data = self.data
try:
self.data = None
self.init_data(device='meta')
if self.data.shape != real_data.shape:
raise CheckpointingException(
f'Data shape {real_data.shape} does not match'
f' expected {self.data.shape} for {self}'
)
finally:
self.data = real_data
if len(self.global_shape) != len(self.global_offset):
raise CheckpointingException(
f'Global offset dimensions should be equal to global shape dimensions for {self}'
)
if len(self.local_shape) + self.prepend_axis_num != len(self.global_shape):
raise CheckpointingException(
f'Local shape together with `prepend_axis_num` dimensions should be '
f'equal to global shape dimensions for {self}'
)
for off, sh in zip(self.global_offset[self.prepend_axis_num :], self.local_shape):
# NOTE: In custom FSDP, we have a case where a new parameter shard is created locally.
# For example, consider parameters [p0, p1, p2] sharded across GPU0 and GPU1.
# GPU0 receives p0 and a portion of p1, while GPU1 receives the
# remaining portion of p1 and p2.
# As a result, there is no parameter shard of p2 on GPU0, and
# the shape of p2 on GPU0 is zero.
if sh != 0 and off % sh != 0:
raise CheckpointingException(
f'Global offset ({off}) must be divisible by local shape ({sh}) for {self}.'
)
if has_flattened_range and self.flattened_range.step is not None:
raise CheckpointingException(
f'`step` argument in the flattened range of a ShardedTensor is not supported.'
)
def global_slice(self) -> Tuple[Union[int, slice], ...]:
"""
Returns a tuple of int and slice objects representing a slice of the
global tensor that this ShardedTensor corresponds to.
"""
assert len(self.global_offset) == len(self.local_shape) + self.prepend_axis_num
return tuple(
chain(
(off for off in self.global_offset[: self.prepend_axis_num]),
(
slice(off, off + sh)
for off, sh in zip(
self.global_offset[self.prepend_axis_num :], self.local_shape
)
),
)
)
def global_coordinates(self) -> Tuple[np.ndarray, ...]:
"""
Returns a tuple of np.ndarrays representing the coordinates of the global tensor
that this ShardedTensor corresponds to.
"""
if self.flattened_range is None:
raise CheckpointingException(
f'`global_coordinates` is undefined for'
f' {self.__class__.__name__} without `flattened_range`'
)
local_coords = self.local_coordinates()
assert len(local_coords) + self.prepend_axis_num == len(self.global_offset), (
len(local_coords),
self,
)
global_coords = tuple(
c + off
for c, off in zip((0,) * self.prepend_axis_num + local_coords, self.global_offset)
)
return global_coords
def local_coordinates(self) -> Tuple[np.ndarray, ...]:
"""
Returns a tuple of np.ndarrays representing the coordinates of the local tensor
that this ShardedTensor corresponds to.
"""
if self.flattened_range is None:
raise CheckpointingException(
f'`local_coordinates` is undefined for'
f' {self.__class__.__name__} without `flattened_range`'
)
# TODO: np.unravel_index?
mask = np.zeros(np.prod(self.local_shape), dtype=bool)
mask[self.flattened_range] = True
return np.nonzero(mask.reshape(self.local_shape))
def local_chunk_offset_in_global(self) -> Tuple[int, ...]:
"""Offset of a local chunk in a global array of chunks.
Returns:
Tuple[int, ...]: the offset of the whole local chunk in a global array of chunks.
"""
assert len(self.global_offset) == len(self.local_shape) + self.prepend_axis_num
chunk_offset = list(self.global_offset[: self.prepend_axis_num])
for off, sh in zip(self.global_offset[self.prepend_axis_num :], self.local_shape):
assert off % sh == 0, str(self)
chunk_offset.append(off // sh)
return tuple(chunk_offset)
def max_allowed_chunks(self) -> Tuple[int, ...]:
"""
Returns the maximum allowed chunks for this ShardedTensor.
"""
chunks = []
for axis_sh, axis_fragm in zip(self.global_shape, self.axis_fragmentations):
if not self.allow_shape_mismatch and axis_sh % axis_fragm != 0:
raise CheckpointingException(
f'Axis shape ({axis_sh}) not divisible by axis fragmentation ({axis_fragm})'
)
axis_chunk_size = axis_sh // axis_fragm
chunks.append(axis_chunk_size)
return tuple(chunks)
def without_data(self):
return replace(self, data=None)
@classmethod
def from_rank_offsets(
cls,
key: str,
data: torch.Tensor,
*rank_offsets: Tuple[int, int, int],
replica_id: ReplicaId = 0,
prepend_axis_num: int = 0,
flattened_range: None = None,
**init_kwargs,
):
"""Allows to construct the ShardedTensor given offset specified in process ranks.
Args:
key (str): unique key
data (torch.Tensor): local tensor data
rank_offsets (Tuple[int, int, int]): each tuple
(axis, axis_rank_offset, axis_fragm) says that if
global tensor is divided into `axis_fragm` fragment along `axis`
axis, then local tensor data corresponds to the `axis_rank_offset` chunk.
replica_id (ReplicaId): see ShardedTensor
prepend_axis_num (int): see ShardedTensor
flattened_range (None): must be None when using this constructor
init_kwargs: passed to ShardedTensor.__init__
"""
if flattened_range is not None:
raise ValueError(
'Cannot instantiate a flat ShardedTensor with `from_rank_offsets` method.'
' Use `from_rank_offsets_flat` instead'
)
global_offset = [0] * (data.ndim + prepend_axis_num)
global_shape = ([1] * prepend_axis_num) + list(data.shape)
axis_fragmentations = [1] * (data.ndim + prepend_axis_num)
_seen_axis = set()
for axis, axis_rank_offset, axis_fragm in rank_offsets:
if axis < 0 or axis_rank_offset < 0 or axis_fragm < 1 or axis_rank_offset >= axis_fragm:
raise CheckpointingException(f'Invalid rank offsets: {rank_offsets} for key {key}.')
_seen_axis.add(axis)
local_axis_shape = 1 if axis < prepend_axis_num else data.shape[axis - prepend_axis_num]
global_shape[axis] = axis_fragm * local_axis_shape
global_offset[axis] = axis_rank_offset * local_axis_shape
axis_fragmentations[axis] = axis_fragm
return cls(
key,
data,
data.dtype,
tuple(data.shape),
tuple(global_shape),
tuple(global_offset),
tuple(axis_fragmentations),
replica_id,
prepend_axis_num,
flattened_range=flattened_range,
**init_kwargs,
)
@classmethod
def from_rank_offsets_flat(
cls,
key: str,
data: torch.Tensor,
non_flat_local_shape: Tuple[int, ...],
*args,
flattened_range: Optional[slice] = None,
**kwargs,
):
"""Allows to construct a *flattened* ShardedTensor given offset specified in process ranks.
Args:
key (str):
data (torch.Tensor): this should be a flattened data tensor
non_flat_local_shape (Tuple[int, ...]): expected local shape of a non-flat chunk
*args: passed unchanged to the `from_rank_offsets` constructor
flattened_range (slice): see ShardedTensor. Defaults to None, but must be set to
a non-None slice.
**kwargs:
Returns:
ShardedTensor: constructed ShardedTensor instance
"""
if flattened_range is None:
raise CheckpointingException(
'Cannot instantiate a non-flat ShardedTensor with `from_rank_offsets_flat` method.'
' Use `from_rank_offsets` instead'
)
if data.ndim != 1:
raise CheckpointingException(
f'Flattened ShardedTensor requires 1D data, got shape: {data.shape}'
)
if flattened_range.stop - flattened_range.start != data.numel():
raise CheckpointingException(
f'Flattened ShardedTensor data length ({data.numel()}) must match the '
f'slice length: {flattened_range.stop - flattened_range.start}'
)
non_flat_data_meta = torch.empty(*non_flat_local_shape, dtype=data.dtype, device='meta')
sh_ten = cls.from_rank_offsets(key, non_flat_data_meta, *args, **kwargs)
instance = replace(sh_ten, data=data, flattened_range=flattened_range)
instance.validate_metadata_integrity()
return instance
def init_data(self, device: Union[str, torch.device], init_fn=torch.empty):
"""
Initialize the tensor data of this ShardedTensor.
Only called if `data` attribute is None.
Args:
device (Union[str, torch.device]): device to place the tensor on
init_fn (Callable, optional): function to use to initialize the tensor.
Defaults to `torch.empty`.
"""
if self.data is not None:
return
self.data = init_fn(self.local_shape, dtype=self.dtype, device=device)
if self.flattened_range is not None:
self.data = self.data.flatten()[self.flattened_range.start : self.flattened_range.stop]
def narrow(self, dim: int, start: int, length: int) -> List['ShardedTensor']:
"""This is an analogue of torch.narrow for ShardedTensors.
Narrowing assumes that we narrow a local tensor on each rank.
This has consequences on local_shape, global_shape, global_offset, etc.
Args:
dim (int): dimension to narrow. Doesn't include prepended axes.
start (int): start element
length (int): length of the slice
Returns:
List[ShardedTensor]: narrowed ShardedTensors. For non-flat tensors,
the list will always have 1 element. For flat ShardedTensors the number of
elements varies depending on `dim` and on overlap, because flat
tensors must be contiguous. In particular the list can be empty.
"""
prepended_dim = dim + self.prepend_axis_num
local_length_along_dim = self.local_shape[dim]
def _update_tuple(x, ind, val):
x = list(x)
x[ind] = val
return tuple(x)
def _safe_div(x, y):
assert x % y == 0, (x, y)
return x // y
# Decrease global shape and global offset by `length / local_length_along_dim`
assert (
self.global_shape[prepended_dim] % local_length_along_dim == 0
), f'Only regular grid of local tensors is supported for narrowing, got: {self}'
assert (
self.global_offset[prepended_dim] % local_length_along_dim == 0
), f'Only regular grid of local tensors is supported for narrowing, got: {self}'
global_shape = _update_tuple(
self.global_shape,
prepended_dim,
_safe_div(self.global_shape[prepended_dim] * length, local_length_along_dim),
)
global_offset = _update_tuple(
self.global_offset,
prepended_dim,
_safe_div(self.global_offset[prepended_dim] * length, local_length_along_dim),
)
if self.flattened_range is None:
new_data = self.data.narrow(dim, start, length)
# always a single result tensor
return [
replace(
self,
data=new_data,
local_shape=new_data.shape,
global_shape=global_shape,
global_offset=global_offset,
)
]
else:
if dim != 0:
raise CheckpointingException(
f'Only narrowing along the first axis (dim=0) is supported for now, got dim={dim}'
)
# If dim=0, we will always get 0 or 1 resulting tensor.
# If dim > 0, in general there can be more resulting tensors (e.g. max 3 for dim=1)
# For an original flat ShardedTensor of local shape [3, 4] and
# flattened_range=slice(5, 10),
# the X signs mark the actual (flat) data in `self.data`
# notice 12 (3*4) total "virtual" elements, out of which 5 is actual data.
# flat original: [.....XXXXX..]
# If we narrow to start=1, length=1 in the original local shape dimensions,
# the overlapping flat slice would be:
# narrow to: [....XXXX....]
# flat overlap: [.....XXX....]
# Now `data` is flattened and sliced, so we must compute local_shape manually
local_shape = _update_tuple(self.local_shape, dim, length)
other_dims_volume = np.prod(
_update_tuple(local_shape, dim, 1)
) # 4 in the example above
volume_before_split = other_dims_volume * start # 4 in the example above
volume_of_split = other_dims_volume * length # 4 in the example above
flat_slice_start_shifted = (
self.flattened_range.start - volume_before_split
) # 5 - 4 = 1 in the example above
flat_slice_stop_shifted = (
self.flattened_range.stop - volume_before_split
) # 10 - 4 = 6 in the example above
# Find an intersection of
# (flat_slice_start_shifted, flat_slice_stop_shifted) vs (0, volume_of_split)
if flat_slice_stop_shifted <= 0 or flat_slice_start_shifted >= volume_of_split:
return [] # no intersection
# new_flattened_range = slice(1, 4) in the example above
new_flattened_range = slice(
max(flat_slice_start_shifted, 0), min(flat_slice_stop_shifted, volume_of_split)
)
# Apply the intersection to the flattened data tensor.
# Compute start and slice appropriate length
intersection_slice_start = (
new_flattened_range.start - flat_slice_start_shifted
) # 0 in the example above
new_data = self.data[
intersection_slice_start : intersection_slice_start
+ new_flattened_range.stop
- new_flattened_range.start
]
return [
replace(
self,
data=new_data,
local_shape=local_shape,
global_shape=global_shape,
global_offset=global_offset,
flattened_range=new_flattened_range,
)
]
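# Two illustrative sketches of the ShardedTensor API above; the keys and shapes are
# hypothetical and describe a single rank's view of a 2-way split along axis 0.
def _example_sharded_tensor_from_rank_offsets() -> None:
    # A [4, 8] global weight split along axis 0 across 2 ranks; this rank holds rows 2..3,
    # i.e. chunk 1 of 2 along axis 0, expressed as the rank offset (axis=0, offset=1, fragm=2).
    local_chunk = torch.zeros(2, 8)
    sh_ten = ShardedTensor.from_rank_offsets('decoder.weight', local_chunk, (0, 1, 2))
    assert sh_ten.global_shape == (4, 8)
    assert sh_ten.global_offset == (2, 0)
    assert sh_ten.axis_fragmentations == (2, 1)
def _example_sharded_tensor_narrow() -> None:
    # Narrowing the same kind of shard to its first row halves the global extent along dim 0
    # and rescales the global offset accordingly.
    sh_ten = ShardedTensor.from_rank_offsets('decoder.weight', torch.zeros(2, 4), (0, 1, 2))
    (narrowed,) = sh_ten.narrow(dim=0, start=0, length=1)
    assert narrowed.local_shape == (1, 4)
    assert narrowed.global_shape == (2, 4)
    assert narrowed.global_offset == (1, 0)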
def is_main_replica(replica_id: ReplicaId):
"""Checks if given `replica_id` is considered as main.
"Main" replica is:
- integer 0
- or an iterable with all 0 elements
It is the application responsibility to set correct replicas for sharded tensors.
Args:
replica_id (Union[int, Tuple[int, ...]]): replica id
Returns:
(bool): True for a "main" replica
"""
if isinstance(replica_id, int):
return replica_id == 0
return all(r == 0 for r in replica_id)
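# A quick illustration of the "main replica" convention described above.
def _example_is_main_replica() -> None:
    assert is_main_replica(0)
    assert is_main_replica((0, 0, 0))
    assert not is_main_replica(1)
    assert not is_main_replica((0, 1, 0))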
class LocalNonpersistentObject:
"""Object that should not be stored in a checkpoint, but restored locally.
Wrapping any object inside the state dict with LocalNonpersistentObject
will result in:
- during saving, this object will *not* be stored in the checkpoint
- during loading, a local version of this object will be placed in a state dict
"""
def __init__(self, obj):
self.obj = obj
def unwrap(self):
"""Returns the original object."""
return self.obj
@dataclass
class ShardedObject(ShardedBase):
"""Represents a mapping between a local object and a global object.
Global object is assumed to consist of many local objects distributed
between different processes.
NOTE: Contrary to ShardedTensor, it's impossible to change global object
sharding. Conceptually, ShardedObject is a fully-sharded ShardedTensor
with atomic arbitrary typed elements.
Args:
key: unique identifier of a global tensor
data: local object data. Can be None only for consistency validation
global_shape: global object shape
global_offset: offset of a local object in a global object, specified in number of shards
replica_id: indicates local object replication wrt. local objects in different processes
"""
key: str
data: object
global_shape: Tuple[int, ...]
global_offset: Tuple[int, ...]
replica_id: ReplicaId = 0
def __post_init__(self):
self.validate_metadata_integrity()
def validate_metadata_integrity(self):
if len(self.global_shape) != len(self.global_offset):
raise CheckpointingException(
f'Global offset dimensions should be equal to global shape dimensions for {self}'
)
def without_data(self):
return replace(self, data=None)
@property
def unique_key(self):
"""returns a unique key for this object"""
return (
f'{self.key}/shard_'
f'{".".join(map(str, self.global_offset))}_'
f'{".".join(map(str, self.global_shape))}'
)
def __str__(self):
return f'{self.__class__.__name__}(key=\'{self.key}\')'
@classmethod
def empty_from_unique_key(cls, unique_key, replica_id: ReplicaId = 0) -> 'ShardedObject':
"""Instantiates a ShardedObject from a unique key.
Args:
unique_key: a string of the form
<key>/shard_<global_offset>_<global_shape>
replica_id: indicates local object replication wrt.
local objects in different processes
Returns:
a ShardedObject with data=None
"""
key, shard_key = unique_key.split('/')
shard_str, offset, shape = shard_key.split('_')
assert shard_str == 'shard'
offset = tuple(map(int, offset.split('.')))
shape = tuple(map(int, shape.split('.')))
if len(shape) + 1 == len(offset):
# This is a backward-compatible fix. We don't know the last
# element of global shape so set it to -1.
shape += (-1,)
return cls(key, None, shape, offset, replica_id)
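# A round-trip sketch for ShardedObject: the unique key encodes offset and shape, so a
# data-less object can be reconstructed from it. The key name is hypothetical.
def _example_sharded_object_unique_key() -> None:
    sh_obj = ShardedObject('optimizer.step', None, (4,), (1,))
    assert sh_obj.unique_key == 'optimizer.step/shard_1_4'
    restored = ShardedObject.empty_from_unique_key(sh_obj.unique_key)
    assert restored == ShardedObject('optimizer.step', None, (4,), (1,))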
FactoryBuildFn = Callable[[str, torch.Tensor, ReplicaId, Optional[slice]], ShardedStateDict]
FactoryMergeFn = Callable[[StateDict], torch.Tensor]
@dataclass
class ShardedTensorFactory(ShardedBase):
"""Allows to apply transformations to tensors before/after serialization.
The essence of those transformations is that they can be applied to
optimizer states the same way they are applied to the model params.
The ultimate state dict with sharded tensors must depend functionally on
`build_fn` arguments (key, data, replica_id, flattened_range),
which will be provided by the optimizer.
Builder creates a sub-state-dict out of a tensor before saving, and merger
merges the corresponding state dict after loading.
Args:
key (str): unique identifier of the factory
data (torch.Tensor): original model parameter that will be further
transformed by this factory
build_fn (callable): function that transforms the original tensor
to a sharded state dict
merge_fn (callable): function that transforms loaded subtree back
into a single tensor (inverse of `build_fn`)
replica_id (ReplicaId): indicates factory replication wrt.
factories in different processes
flattened_range (slice, optional): indicates additional flattening
applied to the ShardedTensors produced by the factory
"""
key: str
data: torch.Tensor
build_fn: FactoryBuildFn
merge_fn: FactoryMergeFn
replica_id: ReplicaId = 0
flattened_range: Optional[slice] = None
def build(self):
"""Builds a ShardedStateDict from the original tensor"""
return self.build_fn(self.key, self.data, self.replica_id, self.flattened_range)
def validate_metadata_integrity(self):
"""No reasonable checks can be applied"""
pass
def without_data(self):
return replace(self, data=None)
def apply_factories(sharded_state_dict: ShardedStateDict):
"""Turn ShardedTensorFactories into ShardedTensors *in-place*.
Args:
sharded_state_dict (ShardedStateDict): state dict possibly
containing ShardedTensorFactory objects
Returns:
None: state dict is modified in place
"""
def apply(x):
if isinstance(x, ShardedTensorFactory):
x = x.build()
return x
dict_list_map_inplace(apply, sharded_state_dict)
def apply_factory_merges(
x1: StateDict, x2: ShardedStateDict, key: Tuple[str, ...] = ()
) -> StateDict:
"""Apply merges defined by ShardedTensorFactories *in-place*.
Args:
x1 (StateDict): state dict loaded from the checkpoint
x2 (ShardedStateDict): subset of `x1` (in terms of dict keys)
with ShardedTensorFactory
as (possibly nested) values that define how to
merge objects from the `x1` state dict
key (Tuple[str, ...]): current key in a recursive call.
Used only for reporting meaningful errors
Returns:
StateDict: `x1` modified in-place
"""
if isinstance(x2, ShardedTensorFactory):
return x2.merge_fn(x1)
# The rest is almost the same as the `merge` function from `dict_utils`
if isinstance(x1, dict) and isinstance(x2, dict):
for k, v2 in x2.items():
if k not in x1:
raise ValueError(
f'Different dict keys encountered in `apply_factory_merges` '
f'({x1.keys()} vs {x2.keys()})'
)
else:
x1[k] = apply_factory_merges(x1[k], v2, key=key + (k,))
elif isinstance(x1, list) and isinstance(x2, list):
if len(x1) != len(x2):
err_msg = (
f'Cannot merge two lists with different lengths '
f'({len(x1)} and {len(x2)}, encountered at key {key})'
)
logger.error(err_msg + f'\nx1: {x1}\nx2: {x2}')
raise ValueError(err_msg)
for i, v2 in enumerate(x2):
x1[i] = apply_factory_merges(x1[i], v2, key=key + (i,))
elif isinstance(x1, list) and isinstance(x2, dict):
for k, v2 in x2.items():
if not isinstance(k, int):
raise ValueError(
f'Invalid dict key {k} of non-integer type encountered '
f'in a list-dict merge at level {key}'
)
if k >= len(x1):
raise ValueError(
f'Dict key {k} out of bounds for list of length '
f'{len(x1)} (encountered at level {key})'
)
x1[k] = apply_factory_merges(x1[k], v2, key=key + (k,))
else:
raise ValueError(
f'Duplicate non-dict and non-list values encountered: `{x1}` and `{x2}` (at key {key})'
)
return x1
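# --- Illustrative sketch (not part of the original module) ---
# Shows the intended build/merge round trip of `ShardedTensorFactory` at a purely
# structural level: here `build_fn` returns plain tensors instead of real
# ShardedTensors, which is enough to exercise `apply_factories` and
# `apply_factory_merges`. All names below are hypothetical.
def _example_factory_roundtrip():
    data = torch.arange(8.0)
    def build_fn(key, tensor, replica_id, flattened_range):
        # A real build_fn would return ShardedTensors; plain tensors suffice
        # to illustrate the sub-state-dict structure.
        half = tensor.numel() // 2
        return {'first_half': tensor[:half], 'second_half': tensor[half:]}
    def merge_fn(sub_state_dict):
        # Inverse of build_fn: reassemble the original tensor after loading.
        return torch.cat([sub_state_dict['first_half'], sub_state_dict['second_half']])
    factory = ShardedTensorFactory('example.param', data, build_fn, merge_fn)
    # Saving path: factories are expanded in place into their sub-state-dicts.
    to_save = {'param': factory}
    apply_factories(to_save)
    assert set(to_save['param']) == {'first_half', 'second_half'}
    # Loading path: the loaded subtree is merged back into a single tensor.
    loaded = {'param': {'first_half': data[:4].clone(), 'second_half': data[4:].clone()}}
    merged = apply_factory_merges(loaded, {'param': factory})
    assert torch.equal(merged['param'], data)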
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Helpers for defining sharding for optimizer states based on existing sharding
for model parameters.
"""
import logging
from copy import deepcopy
from dataclasses import replace
from typing import Dict, Iterable, Tuple, Union
import torch
from megatron.core.utils import to_local_if_dtensor
from .dict_utils import nested_values
from .mapping import (
LocalNonpersistentObject,
ShardedStateDict,
ShardedTensor,
ShardedTensorFactory,
StateDict,
)
from .utils import extract_sharded_tensors_and_factories
logger = logging.getLogger(__name__)
def get_optim_param_to_id_map(optim_params_iter: Iterable[torch.nn.Parameter]) -> Dict[int, int]:
"""Generate mapping from optimizer param to optimizer state id."""
param_mappings = {}
for i, param in enumerate(optim_params_iter):
param = to_local_if_dtensor(param)
if id(param) not in param_mappings:
param_mappings[id(param)] = i
return param_mappings
def get_param_id_to_sharded_param_map(
model_sharded_state_dict: ShardedStateDict, optim_params_iter: Iterable[torch.nn.Parameter]
) -> Dict[int, Union[ShardedTensor, ShardedTensorFactory]]:
"""Generate mapping from optimizer state ids to model sharded parameters.
Args:
model_sharded_state_dict: sharded state dict with all model sharded tensors
(can have any structure)
optim_params_iter: iterable which iterates over model parameters tracked by the optimizer.
The iteration must be in the same order as in the optimizer parameters.
Returns:
Dict[int, Union[ShardedTensor, ShardedTensorFactory]]: mapping from optimizer state ids
to model sharded parameters.
"""
model_sharded_state_dict, _ = extract_sharded_tensors_and_factories(model_sharded_state_dict)
id_to_sharded_param_map = {}
param_to_id_map = get_optim_param_to_id_map(optim_params_iter)
# If using PyTorch FSDP2 the values in model_sharded_state_dict would
# have been converted to local tensors during initialization.
# See the make_(tp)_sharded_tensor_for_checkpoint functions.
for ten in nested_values(model_sharded_state_dict):
if id(ten.data) in param_to_id_map:
id_to_sharded_param_map[param_to_id_map[id(ten.data)]] = ten
else:
logger.debug(f'{ten} is not tracked by the optimizer')
if not id_to_sharded_param_map:
logger.warning(
"Sharded parameters mapping is empty. It means tensors in model state dict"
" do not correspond to tensors in optimizer parameters map."
" Make sure to call state_dict with `keep_vars=True`."
)
return id_to_sharded_param_map
def make_sharded_optimizer_tensor(
model_param: Union[ShardedTensor, ShardedTensorFactory], optim_param: torch.Tensor, prefix: str
) -> Union[ShardedTensor, ShardedTensorFactory]:
"""Build a ShardedTensor or ShardedTensorFactory for optimizer param based on model param
Args:
model_param (Union[ShardedTensor, ShardedTensorFactory]): model param
optim_param (torch.Tensor): corresponding optimizer param
prefix (str): optimizer prefix for the ShardedTensor or ShardedTensorFactory
Returns:
Union[ShardedTensor, ShardedTensorFactory]: wrapped optimizer parameter
"""
optim_param = to_local_if_dtensor(optim_param)
if isinstance(model_param, ShardedTensorFactory):
return replace(model_param, key=f'{prefix}.{model_param.key}', data=optim_param)
assert tuple(optim_param.shape) == model_param.local_shape, (
f'Optimizer shape ({tuple(optim_param.shape)}) does not match model shape '
f'({model_param.local_shape})'
)
sh_ten = replace(
model_param, key=f'{prefix}.{model_param.key}', data=optim_param, dtype=optim_param.dtype
)
sh_ten.validate_metadata_integrity()
return sh_ten
def optim_state_to_sharding_state(
optim_state_dict: StateDict,
id_to_sharded_param_map: Dict[int, ShardedTensor],
exclude_keys: Tuple[str] = (),
):
"""Turn optimizer state dict to sharded state dict based on model state dict *in-place*.
Can be used to add sharding information to most common optimizer state dict.
Creates separate ShardedTensors for each key in `optim_state_dict['state']`
(e.g. for torch.optim.Adam there will be separate tensors for `exp_avg` and `exp_avg_sq`)
Args:
optim_state_dict (StateDict): optimizer state dict with
state parameters under `state` key and group hyperparameters under
`param_groups` -> `params` key.
id_to_sharded_param_map (Dict[int, ShardedTensor]): mapping from optimizer param ids
to model sharded tensors. Can be generated with `get_param_id_to_sharded_param_map`
function.
exclude_keys (Tuple[str]): optimizer state keys to exclude from the final state dict.
Returns:
None: state dict is modified in place
"""
sharded_state = {}
for param_id, param_state in optim_state_dict['state'].items():
sharded_state[param_id] = {}
for state_key, param in param_state.items():
if state_key in exclude_keys:
continue
if param_id in id_to_sharded_param_map:
sharded_state[param_id][state_key] = make_sharded_optimizer_tensor(
id_to_sharded_param_map[param_id], param, prefix=f'optimizer.state.{state_key}'
)
else:
raise ValueError(f'Param id {param_id} does not match any model sharded param')
optim_state_dict['param_groups'] = deepcopy(optim_state_dict['param_groups'])
for group in optim_state_dict['param_groups']:
group['params'] = LocalNonpersistentObject(group['params'])
optim_state_dict['state'] = sharded_state
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Entrypoints for saving and loading the distributed checkpoints.
Functions `load` and `save` are equivalents of `torch.load` and `torch.save`
but expect torch.Tensors to be wrapped with classes from the `mapping` module.
Additionally, `load` expects the sharded state dict argument as a guidance for
loading the sharded tensors.
"""
import logging
from pathlib import Path
from typing import Callable, Dict, Optional, Set, Tuple, Union
import torch
from . import ShardedTensor
from .core import CheckpointingConfig, save_config
from .dict_utils import extract_matching_values, merge
from .mapping import (
CheckpointingException,
CommonStateDict,
ShardedObject,
ShardedStateDict,
StateDict,
apply_factory_merges,
)
from .state_dict_utils import load_preprocess, save_preprocess
from .strategies.async_utils import AsyncRequest
from .strategies.base import (
AsyncSaveShardedStrategy,
LoadCommonStrategy,
LoadShardedStrategy,
SaveCommonStrategy,
SaveShardedStrategy,
StrategyAction,
get_default_strategy,
)
from .utils import extract_sharded_base
from .validation import (
StrictHandling,
determine_global_metadata,
parse_strict_flag,
validate_integrity_and_strict_load,
validate_sharded_objects_handling,
verify_checkpoint_and_load_strategy,
)
logger = logging.getLogger(__name__)
# flat state dict with sharded objects without any data
CkptShardedMetadata = Dict[str, Union[ShardedTensor, ShardedObject]]
def load(
sharded_state_dict: ShardedStateDict,
checkpoint_dir: str,
sharded_strategy: Union[LoadShardedStrategy, Tuple[str, int], None] = None,
common_strategy: Union[LoadCommonStrategy, Tuple[str, int], None] = None,
validate_access_integrity: bool = True,
strict: Union[str, StrictHandling] = StrictHandling.ASSUME_OK_UNEXPECTED,
) -> Union[StateDict, Tuple[StateDict, Set[str], Set[str]]]:
"""Loading entrypoint.
In the steps below, the following verbs refer to corresponding objects:
- load = load from checkpoint
- extract = extract from sharded_state_dict
- add = add to the final state dict
Steps:
1. Load common state dict and form the base of the result state dict
2. Apply factories to sharded_state_dict
3. Extract LocalNonpersistentObject and add
4. (optional) Extract ShardedObjects, load and add
5. Extract ShardedBase, load, apply factory merges and add
Args:
sharded_state_dict (ShardedStateDict): state dict of the existing model
populated with ShardedTensors. Used as a mapping to determine which
parts of global tensors stored in the checkpoint should be loaded.
checkpoint_dir (str): directory with the checkpoint
sharded_strategy (LoadShardedStrategy, Tuple[str, int], optional):
configures loading behavior for sharded tensors
common_strategy (LoadCommonStrategy, Tuple[str, int], optional):
configures loading behavior for common data
validate_access_integrity (bool default = True): checks if each tensor shard is accessed
exactly once (as main replica) by some process
strict (StrictHandling, str, optional): determines the behavior in case of a mismatch
between the requested sharded state dict and the checkpoint. See `StrictHandling` docs
for more details. Some values affect the return value of this function
(missing and unexpected keys are returned).
Defaults to `True` (StrictHandling.ASSUME_OK_UNEXPECTED) which doesn't
incur any performance overhead. Other recommended values
are: `False` (StrictHandling.LOG_UNEXPECTED) which logs only unexpected keys
or `StrictHandling.RETURN_ALL` which returns all mismatch keys.
Returns:
StateDict or Tuple[StateDict, Set[str], Set[str]]: in most cases only
the loaded state dict is returned. If the `strict` flag was set to a value that
requires returning mismatch keys (e.g. `StrictHandling.RETURN_ALL`), a tuple of
(loaded state dict, missing keys, unexpected keys) is returned.
"""
sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy(
checkpoint_dir, sharded_strategy, common_strategy
)
checkpoint_dir = Path(checkpoint_dir)
common_state_dict = common_strategy.load_common(checkpoint_dir)
sharded_state_dict, nonpersistent_state_dict, sh_ten_factories = load_preprocess(
sharded_state_dict
)
merge(common_state_dict, nonpersistent_state_dict)
# At this point we are only dealing with ShardedBase objects
sharded_state_dict, _ = extract_sharded_base(sharded_state_dict)
# Validation
ckpt_sharded_metadata = None
local_metadata, global_metadata = None, None
strict = parse_strict_flag(strict)
if StrictHandling.requires_explicit_ckpt_mismatch_check(strict):
ckpt_sharded_metadata = load_sharded_metadata(
str(checkpoint_dir), sharded_strategy, common_strategy
)
if validate_access_integrity or StrictHandling.requires_global_app_metadata(strict):
local_metadata, global_metadata = determine_global_metadata(sharded_state_dict)
sharded_state_dict, missing_keys, unexpected_keys = validate_integrity_and_strict_load(
sharded_state_dict,
strict,
validate_access_integrity,
local_metadata,
global_metadata,
ckpt_sharded_metadata,
)
# ShardedBase loading
if not sharded_strategy.can_handle_sharded_objects:
validate_sharded_objects_handling(sharded_strategy, common_strategy)
sharded_objects_state_dict, sharded_state_dict = extract_matching_values(
sharded_state_dict, lambda v: isinstance(v, ShardedObject)
)
sharded_objects = common_strategy.load_sharded_objects(
sharded_objects_state_dict, checkpoint_dir
)
merge(common_state_dict, sharded_objects)
loaded_state_dict = sharded_strategy.load(sharded_state_dict, checkpoint_dir)
merge(common_state_dict, loaded_state_dict)
loaded_state_dict = apply_factory_merges(common_state_dict, sh_ten_factories)
if StrictHandling.requires_returning_mismatch_keys(strict):
return common_state_dict, missing_keys, unexpected_keys
else:
return common_state_dict
def load_common_state_dict(checkpoint_dir: Path) -> StateDict:
"""Load common (non-sharded) objects state dict from the checkpoint.
Args:
checkpoint_dir (Path): checkpoint directory
Returns:
StateDict: state dict with non-sharded objects from the checkpoint
"""
sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy(str(checkpoint_dir))
return common_strategy.load_common(checkpoint_dir)
def load_tensors_metadata(
checkpoint_dir: str, sharded_strategy: Union[LoadShardedStrategy, None] = None
) -> CkptShardedMetadata:
"""Load tensors metadata from the checkpoint.
Returns a dictionary similar to a sharded state dict, but note that
the dictionary keys are simply ShardedTensor keys (contrary to the
actual sharded state dicts where keys correspond to state dict keys).
Dict values are ShardedTensors without any sharding (so, the only useful
information is the tensors' global shape and dtype).
Concrete implementation depends on the loading strategy. If no strategy is
given, a default for a given backend is used.
Args:
checkpoint_dir (str): checkpoint directory to load from
sharded_strategy (LoadShardedStrategy, optional): sharded strategy to load metadata.
Defaults to None - in this case a default load strategy for a given checkpoint type
is used.
Returns:
CkptShardedMetadata: flat state dict without data describing ShardedTensors
in the checkpoint
"""
sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy(
checkpoint_dir, sharded_strategy
)
return sharded_strategy.load_tensors_metadata(Path(checkpoint_dir))
def load_sharded_metadata(
checkpoint_dir: str,
sharded_strategy: Union[LoadShardedStrategy, None] = None,
common_strategy: Union[LoadCommonStrategy, None] = None,
) -> CkptShardedMetadata:
"""Load sharded metadata from the checkpoint.
Similar to `load_tensors_metadata`, but includes also ShardedObjects.
Returns a dictionary similar to a sharded state dict, but note that
the dictionary keys are simply ShardedTensor keys (contrary to the
actual sharded state dicts where keys correspond to state dict keys).
Dict values are ShardedTensors without any sharding (so, the only useful
information is the tensors' global shape and dtype).
Concrete implementation depends on the loading strategy. If no strategy is
given, a default for a given backend is used.
Args:
checkpoint_dir (str): checkpoint directory to load from
sharded_strategy (LoadShardedStrategy, optional): sharded strategy to load metadata.
Defaults to None - in this case a default load strategy for a given checkpoint type
is used.
common_strategy (LoadCommonStrategy, optional): common strategy to load metadata.
Defaults to None - in this case a default load strategy for a given checkpoint type is
used. This strategy won't be used unless `sharded_strategy` can't handle ShardedObjects
Returns:
CkptShardedMetadata: flat state dict without data describing ShardedTensors
and ShardedObjects in the checkpoint
"""
sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy(
checkpoint_dir, sharded_strategy, common_strategy
)
sharded_metadata = sharded_strategy.load_sharded_metadata(Path(checkpoint_dir))
if not sharded_strategy.can_handle_sharded_objects:
validate_sharded_objects_handling(sharded_strategy, common_strategy)
common_metadata = common_strategy.load_sharded_metadata(Path(checkpoint_dir))
sharded_metadata = merge(sharded_metadata, common_metadata)
return sharded_metadata
def load_plain_tensors(checkpoint_dir: str) -> StateDict:
"""Load checkpoint tensors without any sharding and plain structure.
NOTE: common state dict is NOT included.
Args:
checkpoint_dir (str): checkpoint directory to load the tensors from.
Returns:
StateDict: checkpoint state dict containing only torch.Tensors.
"""
sharded_state_dict = load_tensors_metadata(checkpoint_dir)
# Don't validate integrity because shards will be overlapped
# if world_size > 1 (all processes load whole tensors)
return load(sharded_state_dict, checkpoint_dir, validate_access_integrity=False)
#
# def load_plain_tensors_and_objects(checkpoint_dir: str) -> StateDict:
# """Load checkpoint tensors and objects without any sharding and plain structure.
#
# NOTE: state dict structure might be different than the one used for checkpoint saving.
# NOTE: common state dict is NOT included.
#
# Args:
# checkpoint_dir (str): checkpoint directory to load the state dict from.
#
# Returns:
# StateDict: complete checkpoint state dict without any sharding.
# """
# sharded_state_dict = load_tensors_metadata(checkpoint_dir)
# # Don't validate integrity because shards will be overlapped
# # if world_size > 1 (all processes load whole tensors)
# return load(sharded_state_dict, checkpoint_dir, validate_access_integrity=False)
def remove_sharded_tensors(checkpoint_dir: str, key_prefix: str):
"""determine the appropriate sharding strategy and delegate removal to the sharded strategy"""
sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy(checkpoint_dir)
sharded_strategy.remove_sharded_tensors(checkpoint_dir, key_prefix)
def save(
sharded_state_dict: ShardedStateDict,
checkpoint_dir: str,
sharded_strategy: Union[SaveShardedStrategy, Tuple[str, int], None] = None,
common_strategy: Union[SaveCommonStrategy, Tuple[str, int], None] = None,
validate_access_integrity: bool = True,
async_sharded_save: bool = False,
preprocess_common_before_consistancy_check: Callable[[CommonStateDict], StateDict] = None,
) -> Optional[AsyncRequest]:
"""Saving entrypoint.
Extracts ShardedTensors from the given state dict. Rank 0 saves the
"regular" part of the checkpoint to common torch file.
The ShardedTensors are saved according to a strategy specified by the
config.
Steps:
1. Apply factories
2. Extract and discard LocalNonPersistentObject
3. Extract all ShardedBase objects
4. Save all other objects to common.pt
5. (optional) Extract and save ShardedObjects
6. Save all ShardedBase objects
7. Write metadata.json file with backend and version metadata.
Step (6) can be performed asynchronously (see `async_sharded_save`), in this
case the actual save is embodied in the returned async request and can be
scheduled by the external caller. For async request, step (7) is added as
one of the finalization functions, so that metadata.json is written only
if the checkpoint is complete.
Args:
sharded_state_dict (ShardedStateDict): state dict populated with
ShardedTensors. Used as a mapping to determine how local tensors
should be saved as global tensors in the checkpoint.
checkpoint_dir (str): directory to save the checkpoint to
sharded_strategy (SaveShardedStrategy, Tuple[str, int], optional):
configures sharded tensors saving behavior and backend
common_strategy (SaveCommonStrategy, Tuple[str, int], optional):
configures common data saving behavior and backend
validate_access_integrity (bool default = True): checks if each tensor shard is accessed
exactly once (as main replica) by some process.
It also makes sure the common state dict is consistent across all ranks.
async_sharded_save (bool, optional): if True, for the sharded state dict part
an async save implementation will be called, with the AsyncRequest
being returned to the caller. Note that it is the caller's responsibility to
actually schedule the async save. Defaults to False.
preprocess_common_before_consistancy_check (Callable[[CommonStateDict], StateDict], None):
A callable function that will preprocess the common state dict (i.e. it can be used to
remove keys that we expect to be different in the state dict). The function must not
modify the original state dict.
Returns:
AsyncRequest (optional): if `async_sharded_save` is True, returns
async request that should be scheduled by the caller of this function.
None otherwise.
"""
checkpoint_dir = Path(checkpoint_dir)
if torch.distributed.get_rank() == 0:
if not checkpoint_dir.exists():
raise CheckpointingException(
f'Checkpoint destination directory does not exist: {checkpoint_dir}'
)
if next(checkpoint_dir.iterdir(), None) is not None:
# Don't throw exception here since this could cause a cascade of failures
# without human intervention in cases where multiple jobs are queued up.
if torch.distributed.get_rank() == 0:
logger.warning("Overwriting old incomplete / corrupted checkpoint...")
if common_strategy is not None:
raise NotImplementedError('The only supported common strategy is torch')
if sharded_strategy is None:
sharded_strategy = get_default_save_sharded_strategy()
if not isinstance(sharded_strategy, SaveShardedStrategy):
assert isinstance(sharded_strategy, tuple), type(sharded_strategy)
sharded_strategy = get_default_strategy(StrategyAction.SAVE_SHARDED, *sharded_strategy)
if common_strategy is None:
common_strategy = get_default_save_common_strategy()
if not isinstance(common_strategy, SaveCommonStrategy):
assert isinstance(common_strategy, tuple), type(common_strategy)
common_strategy = get_default_strategy(StrategyAction.SAVE_COMMON, *common_strategy)
sharded_state_dict, state_dict = save_preprocess(
sharded_state_dict, validate_access_integrity, preprocess_common_before_consistancy_check
)
common_strategy.save_common(state_dict, checkpoint_dir)
if not sharded_strategy.can_handle_sharded_objects:
validate_sharded_objects_handling(sharded_strategy, common_strategy)
sharded_objects_state_dict, sharded_state_dict = extract_matching_values(
sharded_state_dict, lambda v: isinstance(v, ShardedObject)
)
common_strategy.save_sharded_objects(sharded_objects_state_dict, checkpoint_dir)
def metadata_finalize_fn():
if torch.distributed.get_rank() == 0:
save_config(
CheckpointingConfig(sharded_strategy.backend, sharded_strategy.version),
checkpoint_dir,
)
torch.distributed.barrier()
if not async_sharded_save:
sharded_strategy.save(sharded_state_dict, checkpoint_dir)
metadata_finalize_fn()
return
if not isinstance(sharded_strategy, AsyncSaveShardedStrategy):
raise CheckpointingException(
f'Cannot apply async_save to non-async strategy {sharded_strategy}'
)
async_request = sharded_strategy.async_save(sharded_state_dict, checkpoint_dir)
async_request.finalize_fns.append(metadata_finalize_fn)
return async_request
def get_default_save_sharded_strategy(
backend: str = 'torch_dist', version: int = 1
) -> SaveShardedStrategy:
"""Get default save sharded strategy."""
return get_default_strategy(StrategyAction.SAVE_SHARDED, backend, version)
def get_default_save_common_strategy(
backend: str = 'torch', version: int = 1
) -> SaveCommonStrategy:
"""Get default save common strategy."""
return get_default_strategy(StrategyAction.SAVE_COMMON, backend, version)
def get_default_load_sharded_strategy(checkpoint_dir: str) -> LoadShardedStrategy:
"""Get default load sharded strategy."""
return verify_checkpoint_and_load_strategy(checkpoint_dir)[0]
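# --- Illustrative sketch (not part of the original module) ---
# Minimal end-to-end usage of `save` / `load`: every rank owns one row of a
# global (world_size x 4) tensor. Assumes torch.distributed is already initialized,
# `checkpoint_dir` exists and is empty, and that `ShardedTensor.from_rank_offsets`
# (from the mapping module imported above) is available with the
# (key, data, (axis, axis_rank_offset, axis_fragmentations)) signature.
def _example_save_load_roundtrip(checkpoint_dir: str):
    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
    local_chunk = torch.full((1, 4), float(rank))
    def make_sharded_state_dict(data):
        # Shard axis 0 across ranks: this rank holds fragment `rank` of `world_size`.
        return {
            'example.weight': ShardedTensor.from_rank_offsets(
                'example.weight', data, (0, rank, world_size)
            )
        }
    save(make_sharded_state_dict(local_chunk), checkpoint_dir)
    # Loading requires a sharded state dict describing which shards this rank wants.
    loaded = load(make_sharded_state_dict(torch.empty_like(local_chunk)), checkpoint_dir)
    assert torch.equal(loaded['example.weight'], local_chunk)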
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Utilities for transforming state_dict."""
from typing import Callable, Union
from .dict_utils import dict_list_map_inplace, extract_matching_values
from .mapping import (
CommonStateDict,
ShardedStateDict,
ShardedTensor,
ShardedTensorFactory,
StateDict,
apply_factories,
)
from .utils import extract_nonpersistent, extract_sharded_base
from .validation import determine_global_metadata, validate_sharding_integrity
def save_preprocess(
sharded_state_dict: ShardedStateDict,
validate_access_integrity: bool = True,
preprocess_common_before_consistancy_check: Callable[[CommonStateDict], StateDict] = None,
):
"""Preprocesses the given state dictionary by applying factories,
discarding non-persistent data and extracting the common state dictionary.
Optionally, it can validate sharding integrity.
Args:
sharded_state_dict (ShardedStateDict): The initial state dictionary to be preprocessed.
validate_access_integrity (bool): If True, triggers validation of sharding integrity.
preprocess_common_before_consistancy_check (callable, None): A callable function
that will preprocess the common state dict (i.e. it can be used to remove keys
that we expect to be different in the state dict).
Returns:
Tuple[ShardedStateDict, dict]:
The preprocessed sharded state dictionary and the common state dictionary.
"""
apply_factories(sharded_state_dict)
_, sharded_state_dict = extract_nonpersistent(sharded_state_dict)
sharded_part, common_state_dict = extract_sharded_base(sharded_state_dict)
sharded_part = filter_out_empty_flatten_tensor(sharded_part)
if validate_access_integrity:
preprocessed_common_state_dict = common_state_dict
if preprocess_common_before_consistancy_check:
preprocessed_common_state_dict = preprocess_common_before_consistancy_check(
common_state_dict
)
validate_sharding_integrity(
determine_global_metadata(sharded_part)[1],
common_state_dict=preprocessed_common_state_dict,
)
return sharded_part, common_state_dict
def load_preprocess(sharded_state_dict: ShardedStateDict):
"""Preprocesses the given state dictionary by applying factories
and extracting non-persistent data, without modifying the original dictionary.
Args:
sharded_state_dict (ShardedStateDict):
The initial state dictionary to be processed (remains unchanged).
Returns:
Tuple[ShardedStateDict, dict, dict]:
- A preprocessed copy of the sharded state dictionary.
- A dictionary containing non-persistent state data.
- A dictionary of `ShardedTensorFactory` instances.
"""
# Create a copy of sharded_state_dict as the passed in state dict may have
# references that prevent tensors from being deallocated
sharded_state_dict, _ = extract_matching_values(sharded_state_dict, lambda x: True)
sharded_state_dict = filter_out_empty_flatten_tensor(sharded_state_dict)
sh_ten_factories, _ = extract_matching_values(
sharded_state_dict,
lambda x: isinstance(x, ShardedTensorFactory),
return_lists_as_dicts=True,
)
apply_factories(sharded_state_dict)
# Data inside sh_ten_factories no longer needed so delete them to reduce memory usage
dict_list_map_inplace(ShardedTensorFactory.without_data, sh_ten_factories)
# Non-persistent objects
nonpersistent_state_dict, sharded_state_dict = extract_nonpersistent(sharded_state_dict)
dict_list_map_inplace(lambda o: o.unwrap(), nonpersistent_state_dict)
return sharded_state_dict, nonpersistent_state_dict, sh_ten_factories
def filter_out_empty_flatten_tensor(sharded_state_dict: Union[dict, list]):
"""
Filter out ShardedTensors with an empty `flattened_range`.
These tensors can cause the PyTorch check in
`TorchShardedTensor._init_from_local_shards_and_global_metadata` to fail.
Args:
sharded_state_dict: state dict possibly containing ShardedTensor objects
"""
# Filter out ShardedTensors with empty flatten_range.
# These tensors can cause the PyTorch check in
# `TorchShardedTensor._init_from_local_shards_and_global_metadata` to fail.
# This situation may occur in custom Fully Sharded Data Parallel (FSDP) cases.
sharded_state_dict, _ = extract_matching_values(
sharded_state_dict,
lambda v: not (
isinstance(v, ShardedTensor)
and v.flattened_range
and v.flattened_range.start == v.flattened_range.stop
),
)
return sharded_state_dict
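# --- Illustrative sketch (not part of the original module) ---
# Example of a callable that could be passed as `preprocess_common_before_consistancy_check`
# to `save_preprocess` (or `save`): it drops keys that legitimately differ across ranks
# before the common-state-dict consistency check, without modifying the original dict.
# The 'rank_local_stats' key is purely hypothetical.
def _example_preprocess_common(common_state_dict):
    filtered = dict(common_state_dict)  # shallow copy; the original must stay untouched
    filtered.pop('rank_local_stats', None)
    return filtered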
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Various loading and saving strategies """
from megatron.core.dist_checkpointing.strategies.common import register_default_common_strategies
# We load "common" strategies by default to be always available
register_default_common_strategies()
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""
This module provides async utilities that allow starting
a checkpoint save process in the background.
"""
import gc
import logging
from abc import ABC, abstractmethod
from collections import deque
from contextlib import contextmanager
from queue import Empty
from time import sleep, time
from typing import Callable, Dict, List, NamedTuple, Optional, Tuple
import torch
from torch import multiprocessing as mp
from ..utils import debug_time
logger = logging.getLogger(__name__)
@contextmanager
def _disable_gc():
"""Temporarily disables GC."""
gc_enabled = gc.isenabled()
try:
if gc_enabled:
gc.disable()
yield
finally:
if gc_enabled:
gc.enable()
class AsyncRequest(NamedTuple):
"""Represents an async request that needs to be scheduled for execution.
Args:
async_fn (Callable, optional): async function to call. None represents noop.
async_fn_args (Tuple): args to pass to `async_fn`.
finalize_fns (List[Callable]): list of functions to call to finalize the request.
These functions will be called synchronously after `async_fn` is done
*on all ranks*.
async_fn_kwargs (Tuple): kwargs to pass to `async_fn`.
preload_fn (Callable): preload function to stage tensors from GPU to Host.
It should be self-contained, with its arguments already bound (e.g. via `functools.partial`).
is_frozen (bool): a flag indicating whether this async request can still be modified.
call_idx (int): index variable used to order async requests for synchronization
in preloading and writing tensors on the async caller
"""
async_fn: Optional[Callable]
async_fn_args: Tuple
finalize_fns: List[Callable]
async_fn_kwargs: Dict = {}
preload_fn: Callable = None
is_frozen: bool = False
call_idx: int = 0
def add_finalize_fn(self, fn: Callable) -> None:
"""Adds a new finalize function to the request.
Args:
fn (Callable): function to add to the async request. This function
will be called *after* existing finalization functions.
Returns:
None
"""
if self.is_frozen:
raise RuntimeError('Cannot add finalization functions to a frozen AsyncRequest')
self.finalize_fns.append(fn)
def execute_sync(self) -> None:
"""Helper to synchronously execute the request.
This logic is equivalent to what should happen in case of the async call.
"""
if self.async_fn is not None:
self.async_fn(*self.async_fn_args)
torch.distributed.barrier()
for finalize_fn in self.finalize_fns:
finalize_fn()
def freeze(self) -> 'AsyncRequest':
"""Freezes the async request, disallowing adding new finalization functions.
Returns:
AsyncRequest: new async request with all same fields except for the
`is_frozen` flag.
"""
return self._replace(is_frozen=True)
class AsyncCaller(ABC):
"""Wrapper around mp.Process that ensures correct semantic of distributed finalization.
Starts process asynchronously and allows checking if all processes on all ranks are done.
"""
@abstractmethod
def schedule_async_call(self, async_req: AsyncRequest) -> None:
"""Schedule `async_req` with some process forking or reusing
persistent worker
This method must be called on all ranks.
Args:
async_req (AsyncRequest): `AsyncRequest` object containing to
start async process
"""
raise NotImplementedError("This should be implemented")
@abstractmethod
def is_current_async_call_done(self, blocking: bool, no_dist: bool) -> bool:
"""Check if async save is finished on all ranks.
For semantic correctness, requires rank synchronization in each check.
This method must be called on all ranks.
Args:
blocking (bool, optional): if True, will wait until the call is done
on all ranks. Otherwise, returns immediately if at least one rank
is still active. Defaults to False.
no_dist (bool, optional): if True, each training rank simply checks its own
asynchronous checkpoint writer without cross-rank synchronization.
Returns:
bool: True if all ranks are done (immediately or after an active wait
if `blocking` is True), False if at least one rank is still active.
"""
raise NotImplementedError("This should be implemented")
def sync_all_async_calls(self, is_alive: int) -> bool:
"""Check if all ranks have completed async checkpoint writing
Args:
is_alive (int): non-zero (truthy) if the current async request on this rank is not yet completed
Returns:
bool: True if all ranks are done, False if at least one rank is still active.
"""
ten = torch.tensor([is_alive], dtype=torch.int, device=torch.cuda.current_device())
torch.distributed.all_reduce(ten)
return ten[0] == 0
@abstractmethod
def close(self):
"""Terminate the async caller at exit of an application or some termination conditions"""
logger.info(f"AsyncCaller: {torch.distributed.get_rank()}, Destroying Async Caller")
def __del__(self):
raise NotImplementedError("This should be implemented")
class TemporalAsyncCaller(AsyncCaller):
"""Wrapper around mp.Process that ensures correct semantic of distributed finalization.
Starts process asynchronously and allows checking if all processes on all ranks are done.
"""
def __init__(self):
self.process: Optional[mp.Process] = None
self.start_time: Optional[float] = None
@_disable_gc()
def schedule_async_call(self, async_req: AsyncRequest) -> None:
"""Spawn a process with `async_fn` as the target.
This method must be called on all ranks.
Args:
async_req (AsyncRequest): `AsyncRequest` object describing the async
process to start. If `async_req.async_fn` is None, no process is started.
"""
if async_req.async_fn is None:
return # nothing to do
async_fn_args = list(async_req.async_fn_args)
if async_req.preload_fn:
# If `async_req` provides a preload_fn, call it here to stage the GPU tensors
# to their defined destination (e.g. host memory) before forking the process
async_fn_args[1] = async_req.preload_fn()
rank = torch.distributed.get_rank()
start_sync = time()
torch.cuda.synchronize()
end_sync = time()
logger.debug(f"rank: {rank}, takes {end_sync - start_sync} to finish D2H ")
ctx = mp.get_context('fork')
self.start_time = time()
self.process = ctx.Process(
target=async_req.async_fn, args=async_fn_args, kwargs=async_req.async_fn_kwargs
)
self.process.start()
init_time = time()
logger.debug(f"rank: {rank}, takes {init_time - self.start_time} to schedule async ckpt ")
def is_current_async_call_done(self, blocking: bool = False, no_dist: bool = False) -> bool:
"""Check if async save is finished on all ranks.
For semantic correctness, requires rank synchronization in each check.
This method must be called on all ranks.
Args:
blocking (bool, optional): if True, will wait until the call is done
on all ranks. Otherwise, returns immediately if at least one rank
is still active. Defaults to False.
no_dist (bool, optional): if True, each training rank simply checks its own
asynchronous checkpoint writer without cross-rank synchronization.
Returns:
bool: True if all ranks are done (immediately or after an active wait
if `blocking` is True), False if at least one rank is still active.
"""
# The following has the same overhead
# as torch.distributed.barrier (a single integer all-reduce)
is_alive = int(self.process.is_alive()) if self.process is not None else 0
is_done = not is_alive if no_dist else self.sync_all_async_calls(is_alive)
if is_done or blocking:
# `self.process.join()` (via `self.close()`) is called in the following cases:
# 1. blocking == True -> regardless of is_done
# 2. blocking == False (non-blocking) and is_done == True:
#    async requests on all ranks have been identified as finished
# `self.close()` makes sure the forked async caller process has terminated
self.close()
is_done = True
return is_done
def close(self):
"""For TemporalAsyncCaller, this method is called explictly in `is_current_async_calls_done`
This method make sure the TemporalAsyncCaller terminated
with all its assigned async request completed
"""
if self.process:
logger.debug(f"rank: {torch.distributed.get_rank()}, joining self.process")
self.process.join()
self.process = None
logger.debug(
"TemporalAsyncCaller: Async process join finished "
f"after {time() - self.start_time:.2f}s from forking"
)
self.start_time = None
def __del__(self):
pass
class PersistentAsyncCaller(AsyncCaller):
"""Wrapper around mp.Process that ensures correct semantic of distributed finalization.
Starts process asynchronously and allows checking if all processes on all ranks are done.
"""
def __init__(self):
self.process: mp.Process = None
self.start_time: Optional[float] = None
ctx = mp.get_context('spawn')
# main queue to deliver `AsyncRequest` from host to the ckpt worker
self.queue: mp.JoinableQueue = ctx.JoinableQueue()
# Queue used to synchronize for the completion of preloading tensors to host
# between a trainer and ckpt worker
self.preload_q: mp.JoinableQueue = ctx.JoinableQueue()
# Queue used to inform trainer when the saving is completed
self.comp_q: mp.Queue = ctx.Queue()
self.cur_item: int = None
self.cur_idx: int = -1
def schedule_async_call(self, async_req: AsyncRequest) -> None:
"""Put `AsyncRequest` to the Persistent Async Caller
This method must be called on all ranks.
Args:
async_req (AsyncRequest): `AsyncRequest` object describing the checkpointing
request to schedule. If `async_req.async_fn` is None, nothing is scheduled.
"""
if async_req.async_fn is None:
return # nothing to do
start_sync = end_sync = None
self.start_time = time()
if self.process is None:
ctx = mp.get_context('spawn')
logger.info(
f"PersistentAsyncCaller: {torch.distributed.get_rank()}, Starting Async Caller"
)
self.process: mp.Process = ctx.Process(
target=PersistentAsyncCaller.async_loop,
args=(
torch.distributed.get_rank(),
self.queue,
self.preload_q,
self.comp_q,
logger.getEffectiveLevel(),
),
)
self.process.start()
logger.info(
f"PersistentAsyncCaller: {torch.distributed.get_rank()}, Started Async Caller"
)
if async_req.preload_fn:
self.preload_q.put(async_req.call_idx)
self.queue.put(async_req)
logger.debug(f"rank: {torch.distributed.get_rank()}, put {async_req.call_idx}")
if async_req.preload_fn:
start_sync = time()
# Synchronize for pre-staging tensors
self.preload_q.join()
end_sync = time()
logger.debug(
f"rank: {torch.distributed.get_rank()}, "
f"takes {end_sync - start_sync} to finish D2H "
)
init_time = time()
logger.debug(
f"rank: {torch.distributed.get_rank()}, takes {init_time - self.start_time} "
"to schedule async ckpt "
)
def is_current_async_call_done(self, blocking: bool = False, no_dist: bool = False) -> bool:
"""Check if async save is finished on all ranks.
For semantic correctness, requires rank synchronization in each check.
This method must be called on all ranks.
Args:
blocking (bool, optional): if True, will wait until the call is done
on all ranks. Otherwise, returns immediately if at least one rank
is still active. Defaults to False.
no_dist (bool, optional): if True, each training rank simply checks its own
asynchronous checkpoint writer without cross-rank synchronization.
Returns:
bool: True if all ranks are done (immediately or after an active wait
if `blocking` is True), False if at least one rank is still active.
"""
is_alive: bool = False
if self.process:
while self.cur_item is None:
try:
# Retrieve comp call_idx without waiting
self.cur_item = self.comp_q.get_nowait()
except Empty:
# This method is called only after an `AsyncRequest` has been pushed to the main loop,
# so the background writing is still considered active
# until the worker puts its call_idx into `comp_q`
if not blocking:
is_alive = True
break
sleep(0.1)
if self.cur_item is not None:
logger.debug(
f"rank: {torch.distributed.get_rank()}, item: {self.cur_item}"
f" is completed, {is_alive}"
)
is_done = not is_alive if no_dist else self.sync_all_async_calls(is_alive)
# This is set to False when blocking == False so this routine is called again
# to simply call `sync_all_async_calls` to check if other ranks complete the writing
if is_done:
# The current request is completed globally. Reset the current item for polling.
logger.debug(
f"rank: {torch.distributed.get_rank()}, item: {self.cur_item}"
f" is completed globally, {is_done}"
)
self.cur_item = None
return is_done
def close(self):
"""Wait on the left async requests and terminate the PersistentAsyncCaller
Signals the PersistentAsyncCaller by sending a 'DONE' message to make it terminated
"""
logger.info(
f"PersistentAsyncCaller: {torch.distributed.get_rank()}, Destroying Async Caller"
)
if self.process:
self.queue.put('DONE')
self.queue.join()
self.process.join()
self.process = None
def __del__(self):
self.close()
@staticmethod
@_disable_gc()
def async_loop(
rank: int,
queue: mp.JoinableQueue,
preload_q: mp.JoinableQueue,
comp_q: mp.Queue,
log_level: int = logging.INFO,
):
"""Main function for the persistent checkpoint worker
The persistent worker is created once and terminated at exit or
when the application calls `close()` explicitly.
This routine receives an `AsyncRequest`, runs `preload_fn` first and
puts an integer value in `preload_q` to inform the trainer to proceed.
When the `async_fn` from the request is completed (background saving is done),
it puts an integer value in `comp_q` to notify the trainer of the completion.
Args:
rank (int): the rank of the trainer where the persistent worker is created.
queue (mp.JoinableQueue): the main queue used to receive `AsyncRequest` objects
from the training rank
preload_q (mp.JoinableQueue): a queue to inform trainer that preloading of tensors
from GPU to Host or dedicated location is completed
comp_q (mp.Queue): a queue used to inform the training rank of the completion of a scheduled
async checkpoint request
log_level (int, Optional): an integer to set log-level in this spawned process
to get aligned with the training rank's logging level
"""
logger = logging.getLogger(__name__)
logger.setLevel(log_level)
logger.info(f"PersistentAsyncCaller: persistent ckpt worker for {rank} has started")
while True:
item = queue.get()
if isinstance(item, str) and item == 'DONE':
queue.task_done()
break
elif isinstance(item, AsyncRequest):
async_fn_args = list(item.async_fn_args)
if item.preload_fn:
call_idx = preload_q.get()
# the 2nd arg is state dict
async_fn_args[1] = item.preload_fn()
logger.debug(f"{rank} has completed D2H of {call_idx}")
preload_q.task_done()
item.async_fn(*async_fn_args, **item.async_fn_kwargs)
logger.debug(f"{rank} has completed saving {item.call_idx}")
comp_q.put(item.call_idx)
queue.task_done()
logger.info(f"PersistentAsyncCaller: persistent ckpt worker for {rank} has terminated")
class _ActiveAsyncRequest(NamedTuple):
"""Helper to represent an active async call.
Args:
idx (int): index of the call (starting from 0)
async_caller (AsyncCaller): async caller instance that represents
the async process handling the async request
async_request (AsyncRequest): async request that is being called
"""
idx: int
async_caller: AsyncCaller
async_request: AsyncRequest
class AsyncCallsQueue:
"""Manages a queue of async calls.
Allows adding a new async call with `schedule_async_request` and finalizing
active calls with `maybe_finalize_async_calls`.
"""
def __init__(self, persistent: bool = False):
self.async_calls: deque[_ActiveAsyncRequest] = deque([])
self.call_idx: int = -1
self.persistent: bool = persistent
self.persistent_caller: AsyncCaller = None
def _get_async_caller(self):
if not self.persistent:
return TemporalAsyncCaller()
if self.persistent_caller is None:
self.persistent_caller = PersistentAsyncCaller()
return self.persistent_caller
def schedule_async_request(self, async_request: AsyncRequest) -> int:
"""Start a new async call and add it to a queue of active async calls.
This method must be called on all ranks.
Args:
async_request (AsyncRequest): async request to start.
Returns:
int: index of the async call that was started.
This can help the user keep track of the async calls.
"""
self.call_idx += 1
async_caller = self._get_async_caller()
# Backward compatibility for local checkpointing built with the old AsyncRequest
if len(async_request._fields) != len(AsyncRequest._fields):
async_request = AsyncRequest(**async_request._asdict())
async_request = async_request.freeze()
async_caller.schedule_async_call(
async_request._replace(call_idx=self.call_idx, finalize_fns=[])
)
self.async_calls.append(_ActiveAsyncRequest(self.call_idx, async_caller, async_request))
return self.call_idx
def maybe_finalize_async_calls(self, blocking=False, no_dist=False) -> List[int]:
"""Finalizes all available calls.
This method must be called on all ranks.
Args:
blocking (bool, optional): if True, will wait until all active requests
are done. Otherwise, finalizes only the async requests that have already
finished. Defaults to False.
Returns:
List[int]: list of indices (as returned by `schedule_async_request`)
of async calls that have been successfully finalized.
"""
call_idx_finalized = []
while self.async_calls:
next_async_done = self.async_calls[0].async_caller.is_current_async_call_done(
blocking, no_dist
)
if not next_async_done:
break
with debug_time("finalize", logger):
call_idx, _, async_request = self.async_calls.popleft()
for finalize_fn in async_request.finalize_fns:
finalize_fn()
ten = torch.tensor([call_idx], dtype=torch.int, device=torch.cuda.current_device())
torch.distributed.all_reduce(ten, op=torch.distributed.ReduceOp.MAX)
assert ten.item() == call_idx, (
'Unmatched async calls. '
'That probably means not all ranks are participating in async finalization'
)
call_idx_finalized.append(call_idx)
return call_idx_finalized
def get_num_unfinalized_calls(self):
"""Get the number of active async calls."""
return len(self.async_calls)
def close(self):
"""Finalize all calls upon closing."""
self.maybe_finalize_async_calls(blocking=True)
if self.persistent and self.persistent_caller:
self.persistent_caller.close()
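# --- Illustrative sketch (not part of the original module) ---
# Basic usage of `AsyncCallsQueue` with a trivial async function. Assumes
# torch.distributed is initialized and a CUDA device is available (both are
# required by the synchronization logic above). `_example_write_fn` and
# `_example_async_queue_usage` are hypothetical helpers.
def _example_write_fn(path, payload):
    with open(path, 'wb') as f:
        f.write(payload)
def _example_async_queue_usage(path):
    calls_queue = AsyncCallsQueue(persistent=False)
    request = AsyncRequest(
        async_fn=_example_write_fn,
        async_fn_args=(path, b'checkpoint shard bytes'),
        finalize_fns=[lambda: logger.info('async write finalized on all ranks')],
    )
    calls_queue.schedule_async_request(request)
    # In a real training loop this would be polled with blocking=False between steps.
    finalized = calls_queue.maybe_finalize_async_calls(blocking=True)
    calls_queue.close()
    return finalized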
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Strategies base interfaces. """
from abc import ABC, abstractmethod
from collections import defaultdict
from enum import Enum
from pathlib import Path
from typing import Any, DefaultDict, Union
from ..mapping import CheckpointingException, ShardedStateDict, StateDict
from .async_utils import AsyncCallsQueue, AsyncRequest
class StrategyAction(Enum):
"""Specifies save vs load and sharded vs common action."""
LOAD_COMMON = 'load_common'
LOAD_SHARDED = 'load_sharded'
SAVE_COMMON = 'save_common'
SAVE_SHARDED = 'save_sharded'
default_strategies: DefaultDict[str, dict[tuple, Any]] = defaultdict(dict)
async_calls = AsyncCallsQueue()
def get_default_strategy(action: StrategyAction, backend: str, version: int):
"""Retrieves a default strategy for a given action, backend and version."""
error_hint: str = None
try:
if backend == 'zarr':
error_hint = ' Please install `zarr` and `tensorstore!=0.1.46` packages'
from .tensorstore import register_default_tensorstore_strategies
register_default_tensorstore_strategies()
from .zarr import register_default_zarr_strategies
register_default_zarr_strategies()
elif backend == 'torch_dist':
error_hint = ' Please use PyTorch version >=2.1'
from .torch import register_default_torch_strategies
register_default_torch_strategies()
except ImportError as e:
raise CheckpointingException(
f'Cannot import a default strategy for: {(action.value, backend, version)}. '
f'Error: {e}. Hint: {error_hint}'
) from e
try:
return default_strategies[action.value][(backend, version)]
except KeyError as e:
raise CheckpointingException(
f'Cannot find a default strategy for: {(action.value, backend, version)}'
) from e
def register_default_strategy(
action: StrategyAction,
backend: str,
version: int,
strategy: Union['SaveStrategyBase', 'LoadStrategyBase'],
):
"""Adds a given strategy to the registry of default strategies.
Args:
action (StrategyAction): specifies save/load and sharded/common
backend (str): backend that the strategy becomes a default for
version (int): version that the strategy becomes a default for
strategy (SaveStrategyBase, LoadStrategyBase): strategy to register
"""
default_strategies[action.value][(backend, version)] = strategy
class LoadStrategyBase(ABC):
"""Base class for a load strategy. Requires implementing checks for compatibility with a
given checkpoint version."""
@abstractmethod
def check_backend_compatibility(self, loaded_backend):
"""Verifies if this strategy is compatible with `loaded_backend`."""
raise NotImplementedError
@abstractmethod
def check_version_compatibility(self, loaded_version):
"""Verifies if this strategy is compatible with `loaded_version`."""
raise NotImplementedError
@property
def can_handle_sharded_objects(self):
"""Returns whether or not this strategy can handle loading ShardedObjects."""
return False
class SaveStrategyBase(ABC):
"""Base class for a save strategy. Requires defining a backend type and
version of the saved format."""
def __init__(self, backend: str, version: int):
self.backend = backend
self.version = version
@property
def can_handle_sharded_objects(self):
"""Returns whether or not this strategy can handle saving ShardedObjects."""
return False
def __str__(self):
return f'{self.__class__.__name__}({self.backend}, {self.version})'
class LoadCommonStrategy(LoadStrategyBase):
"""Load strategy for common (non-sharded) objects"""
@abstractmethod
def load_common(self, checkpoint_dir: Path):
"""Load common part of the checkpoint."""
raise NotImplementedError
@abstractmethod
def load_sharded_objects(
self, sharded_objects_state_dict: ShardedStateDict, checkpoint_dir: Path
):
"""Load sharded objects from the checkpoint."""
raise NotImplementedError
def load_sharded_metadata(self, checkpoint_dir: Path) -> ShardedStateDict:
"""Load just the metadata from the checkpoint."""
if not self.can_handle_sharded_objects:
return {}
raise NotImplementedError
class LoadShardedStrategy(LoadStrategyBase):
"""Load strategy for sharded tensors"""
@abstractmethod
def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
"""Load the sharded part of the checkpoint."""
raise NotImplementedError
@abstractmethod
def load_tensors_metadata(self, checkpoint_dir: Path):
"""Load tensors metadata from the checkpoint for ShardedTensors.
Returns a dictionary similar to a sharded state dict, but note that
the dictionary keys are simply ShardedTensor keys (contrary to the
actual sharded state dicts where keys correspond to state dict keys).
Dict values are ShardedTensors without any data and sharding (so, the
only useful information is tensors global shape and dtype).
"""
raise NotImplementedError(
f'Loading only tensors metadata not implemented for {self.__class__.__name__}'
)
def load_sharded_metadata(self, checkpoint_dir: Path):
"""Load sharded metadata from the checkpoint for ShardedTensors and ShardedObjects.
Returns a dictionary similar to a sharded state dict, but note that
the dictionary keys are simply sharded keys (contrary to the
actual sharded state dicts where keys correspond to state dict keys).
Dict values are ShardedTensors or ShardedObjects without any data and sharding.
"""
if not self.can_handle_sharded_objects:
return self.load_tensors_metadata(checkpoint_dir)
raise NotImplementedError(
f'Loading only sharded metadata not implemented for {self.__class__.__name__}'
)
def remove_sharded_tensors(self, checkpoint_dir: str, key_prefix: str):
"""Remove all tensors whose key starts with key_prefix"""
raise NotImplementedError
class SaveCommonStrategy(SaveStrategyBase):
"""Save strategy for common (non-sharded) objects"""
@abstractmethod
def save_common(self, common_state_dict: StateDict, checkpoint_dir: Path):
"""Save common part of the state dict."""
raise NotImplementedError
def save_sharded_objects(
self, sharded_objects_state_dict: ShardedStateDict, checkpoint_dir: Path
):
"""Save sharded objects from the state dict."""
raise NotImplementedError
class SaveShardedStrategy(SaveStrategyBase):
"""Save strategy for sharded tensors"""
@abstractmethod
def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
"""Save the sharded part of the state dict."""
raise NotImplementedError
class AsyncSaveShardedStrategy(SaveShardedStrategy):
"""Save strategy suitable for async save."""
@abstractmethod
def async_save(
self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path
) -> AsyncRequest:
"""Perform preparation and return an AsyncRequest to the external caller.
Args:
sharded_state_dict (ShardedStateDict): sharded state dict to save
checkpoint_dir (Path): checkpoint target directory
Returns:
AsyncRequest: represents the async save function and finalization function.
It is the caller's responsibility to actually schedule the async save.
"""
raise NotImplementedError
def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
"""Each async strategy can be trivially used as a sync strategy."""
async_request = self.async_save(sharded_state_dict, checkpoint_dir)
# multiprocessing routines may cause issue when called on parent process
# We keep this verbose call for now
global async_calls
async_calls.schedule_async_request(async_request)
async_calls.maybe_finalize_async_calls(blocking=True)
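# --- Illustrative sketch (not part of the original module) ---
# Registering and retrieving a custom strategy through the default-strategy
# registry above. The 'example_backend' name and the no-op strategy are
# hypothetical and only demonstrate the registration flow.
class _ExampleNoOpCommonSaveStrategy(SaveCommonStrategy):
    def save_common(self, common_state_dict: StateDict, checkpoint_dir: Path):
        # A real strategy would persist `common_state_dict` under `checkpoint_dir`.
        pass
def _example_register_and_fetch_strategy():
    register_default_strategy(
        StrategyAction.SAVE_COMMON,
        'example_backend',
        1,
        _ExampleNoOpCommonSaveStrategy('example_backend', 1),
    )
    return get_default_strategy(StrategyAction.SAVE_COMMON, 'example_backend', 1)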
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
""" FS Reader with metadata cached support. """
import os
from typing import Union
from torch.distributed.checkpoint import FileSystemReader, Metadata
class CachedMetadataFileSystemReader(FileSystemReader):
"""
Extends FileSystemReader to cache metadata for improved performance.
Attributes:
_cached_metadata (Metadata or None): Cached metadata from the file system.
"""
def __init__(self, path: Union[str, os.PathLike]) -> None:
"""
Initialize with file system path.
Args:
path (Union[str, os.PathLike]): Path to the checkpoint directory or file.
"""
super().__init__(path=path)
self._cached_metadata = None
def read_metadata(self) -> Metadata:
"""
Read metadata from file system, caching for subsequent calls.
Returns:
Metadata: Checkpoint metadata.
"""
if self._cached_metadata is None:
self._cached_metadata = super().read_metadata()
return self._cached_metadata
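# --- Illustrative sketch (not part of the original module) ---
# The cached reader is a drop-in replacement for FileSystemReader: repeated
# metadata reads (e.g. during planning and loading) are served from memory
# instead of the file system. `checkpoint_dir` is assumed to contain a
# PyTorch Distributed checkpoint.
def _example_cached_metadata_read(checkpoint_dir: str) -> Metadata:
    reader = CachedMetadataFileSystemReader(checkpoint_dir)
    first = reader.read_metadata()   # reads from storage
    second = reader.read_metadata()  # served from the in-memory cache
    assert first is second
    return first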
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
""" Common strategies. """
import logging
import os
from pathlib import Path
import torch
from megatron.core.dist_checkpointing.mapping import ShardedStateDict, StateDict
from megatron.core.dist_checkpointing.strategies.base import (
SaveCommonStrategy,
StrategyAction,
register_default_strategy,
)
from ..dict_utils import dict_list_map_inplace, nested_values
from ..mapping import CheckpointingException, ShardedObject, is_main_replica
from ..strategies.base import LoadCommonStrategy
COMMON_STATE_FNAME = 'common.pt'
logger = logging.getLogger(__name__)
def register_default_common_strategies():
"""Register default common strategies."""
register_default_strategy(StrategyAction.LOAD_COMMON, 'torch', 1, TorchCommonLoadStrategy())
register_default_strategy(
StrategyAction.SAVE_COMMON, 'torch', 1, TorchCommonSaveStrategy('torch', 1)
)
class TorchCommonSaveStrategy(SaveCommonStrategy):
"""Common save strategy leveraging native torch save/load."""
def save_common(self, common_state_dict: StateDict, checkpoint_dir: Path):
"""Save common part of the state dict."""
if torch.distributed.get_rank() == 0:
torch.save(common_state_dict, checkpoint_dir / COMMON_STATE_FNAME)
def save_sharded_objects(
self, sharded_objects_state_dict: ShardedStateDict, checkpoint_dir: Path
):
"""Save sharded objects from the state dict."""
for sh_obj in nested_values(sharded_objects_state_dict):
if is_main_replica(sh_obj.replica_id):
save_path = checkpoint_dir / f'{sh_obj.unique_key}.pt'
os.makedirs(save_path.parent, exist_ok=True)
torch.save(sh_obj.data, save_path)
@property
def can_handle_sharded_objects(self):
"""This strategy can handle ShardedObjects."""
return True
class TorchCommonLoadStrategy(LoadCommonStrategy):
"""Common load strategy leveraging native torch save/load."""
def load_common(self, checkpoint_dir: Path):
"""Load common (non-sharded) objects state dict from the checkpoint.
Args:
checkpoint_dir (Path): checkpoint directory
Returns:
StateDict: state dict with non-sharded objects from the checkpoint
"""
load_path = Path(checkpoint_dir) / COMMON_STATE_FNAME
try:
return torch.load(load_path, map_location='cpu', weights_only=False)
except FileNotFoundError as e:
err_msg = f'Common file {load_path} does not exist'
ckpt_files = [f.name for f in checkpoint_dir.iterdir()]
logger.debug(f'{err_msg}. Checkpoint directory content: {ckpt_files}')
raise CheckpointingException(err_msg) from e
def load_sharded_objects(
self, sharded_objects_state_dict: ShardedStateDict, checkpoint_dir: Path
):
"""Replaces all ShardedObject from a given state dict with values loaded from the
checkpoint.
Args:
sharded_objects_state_dict (ShardedStateDict):
sharded state dict defining what objects should be loaded.
checkpoint_dir (Path): checkpoint directory
Returns:
None: sharded state dict is modified in place
"""
def load_sharded_object(sh_obj: ShardedObject):
sh_obj.data = None
load_path = checkpoint_dir / f'{sh_obj.unique_key}.pt'
try:
loaded_obj = torch.load(load_path, weights_only=False)
except FileNotFoundError as e:
# Backward compatible logic: previously the save format was incorrect
old_load_path = (checkpoint_dir / sh_obj.unique_key).with_suffix('.pt')
try:
loaded_obj = torch.load(old_load_path, weights_only=False)
except FileNotFoundError:
err_msg = f'Object shard {load_path} not found'
obj_subdir = checkpoint_dir / sh_obj.key
if obj_subdir.exists():
obj_files = [f.name for f in obj_subdir.iterdir()]
logger.debug(
f'{err_msg}. Object {sh_obj.key} directory content: {obj_files}'
)
else:
ckpt_files = [f.name for f in checkpoint_dir.iterdir()]
logger.debug(
f'{err_msg}. Object {sh_obj.key} directory does not exist. Checkpoint'
f' directory content: {ckpt_files}'
)
raise CheckpointingException(err_msg) from e
return loaded_obj
return dict_list_map_inplace(load_sharded_object, sharded_objects_state_dict)
def load_sharded_metadata(self, checkpoint_dir: Path) -> ShardedStateDict:
sharded_metadata = {}
for subdir in checkpoint_dir.iterdir():
if not subdir.is_dir():
continue
shard_files = list(subdir.glob('shard_*.pt'))
if not shard_files:
continue
sh_objs = []
for shard_file in shard_files:
full_key = f'{subdir.name}/{shard_file.stem}'
sh_objs.append(ShardedObject.empty_from_unique_key(full_key))
# This is a backward-compatibility fix for cases where the last global shape
# dimension is missing from the name
if sh_objs[0].global_shape[-1] < 0:
max_last_offset = max(map(lambda sh_obj: sh_obj.global_offset[-1], sh_objs))
for sh_obj in sh_objs:
sh_obj.global_shape = (*sh_obj.global_shape[:-1], max_last_offset + 1)
# Update the sharded state dict
for sh_obj in sh_objs:
sharded_metadata[sh_obj.unique_key] = sh_obj
return sharded_metadata
@property
def can_handle_sharded_objects(self):
"""This strategy can handle ShardedObjects."""
return True
def check_backend_compatibility(self, loaded_version):
pass
def check_version_compatibility(self, loaded_version):
pass
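# Illustrative sketch (hypothetical helper, not part of the original module): loading the
# non-sharded part of a checkpoint with the strategy above. `checkpoint_dir` is an assumed
# local directory produced by the corresponding save strategy.
def _example_load_common_state(checkpoint_dir: Path):
    strategy = TorchCommonLoadStrategy()
    # Returns the state dict stored in COMMON_STATE_FNAME (non-sharded objects only).
    return strategy.load_common(checkpoint_dir)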
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
""" Storage writer for PyT Distributed format allowing asynchronous save. """
import dataclasses
import logging
import os
import queue
from functools import partial
from heapq import heappop, heappush
from itertools import chain
from operator import itemgetter
from pathlib import Path
from time import time
from typing import Callable, Dict, List, Optional, Tuple, Union
import psutil
import torch
from torch import multiprocessing as mp
from torch.distributed.checkpoint import FileSystemWriter
from torch.distributed.checkpoint.filesystem import DEFAULT_SUFFIX, _StoragePrefix, _write_item
from torch.distributed.checkpoint.planner import SavePlan, SavePlanner, WriteItem, WriteItemType
from torch.distributed.checkpoint.storage import WriteResult
from torch.futures import Future
from .async_utils import _disable_gc
logger = logging.getLogger(__name__)
WriteBucket = Tuple[Path, str, Tuple[list, list]] # represents writes to a single file
_results_queue = None
def _get_write_results_queue():
global _results_queue
if _results_queue is None:
ctx = mp.get_context('fork')
with _disable_gc():
_results_queue = ctx.Manager().Queue()
return _results_queue
class FileSystemWriterAsync(FileSystemWriter):
"""
Async-enabled implementation of FileSystemWriter using file IO.
This class doesn't spawn the async process itself; it relies on an external async mechanism.
Flow:
1. Call `prepare_write_data`
2. Externally start an async process with the `get_save_function_and_args` function and args
3. The async function to call is `write_preloaded_data_multiproc`, which calls
`write_preloaded_data` in multiple processes
After saving is finalized on all ranks:
4. Call `super().finish` with the results gathered via `retrieve_write_results`
Note that step (3) above can also be called synchronously.
Currently, it's assumed that a separate writer is created for each ckpt save
(intermediate state is stored as writer attributes).
"""
def __init__(self, *args, separation_hint: Optional[str] = None, **kwargs):
super().__init__(*args, **kwargs)
if not self.single_file_per_rank:
raise NotImplementedError(
'single_file_per_rank flag not supported for FileSystemWriterAsync'
)
self.can_run_decentralized_global_plan: bool = True
# Intermediate state between preparation and finalization
self.write_buckets: Optional[List[WriteBucket]] = None
self.results_queue: Optional[mp.Queue] = None
self.separation_hint = separation_hint
def prepare_write_data(self, plan: SavePlan, planner: SavePlanner) -> None:
"""
First stage of async saving. Copy data to CPU and plan the local saving.
Args:
plan (SavePlan): save plan generated by the PyT Distributed compatible planner
planner (SavePlanner): save planner used to resolve the bytes and tensor data
Returns: None, but stores the save plan in `self.write_buckets`
"""
storage_plan: _StoragePrefix = plan.storage_data
start = time()
logger.debug(f"thread_count: {self.thread_count}, time: {start}")
if self.separation_hint:
assert (
self.thread_count > 1
), "thread_count must be at least 2 if separation_hint is provided"
bins = self.thread_count // 2 if self.separation_hint is not None else self.thread_count
item_buckets = _split_by_size_and_type(bins, plan.items)
logger.debug(f"bucket_prep, time: {time() - start}")
start = time()
# move tensors from GPU to CPU before starting async writing
# We do D2H synchronously for now
file_count = 0
def gen_file(prefix=""):
nonlocal file_count
file_name = f"{prefix}{storage_plan.prefix}{file_count}{DEFAULT_SUFFIX}"
file_count += 1
return file_name
def _clone_if_needed(ten: torch.Tensor):
"""Clone if we detect incontiguous storage for CPU tensors
Makes sure we perform a `clone` only if we detect incontiguous storage,
so that we don't blow up host memory unnecessarily.
TODO: For persistent worker, this work should be changed to move the cpu tensor
to shared_memory.
"""
ten = ten.detach()
if ten.device.type != "cpu":
# We do D2H later when the async_request is scheduled for both sync / async
# checkpointing
return ten
is_view = ten.untyped_storage().size() != ten.numel() * ten.itemsize
return ten.clone() if is_view else ten
# Prepare bytes / tensor data in each bucket, which will be assigned to each writer process
self.write_buckets = []
for group_name, group_buckets in _split_by_separation_hint(
item_buckets, self.separation_hint
).items():
for bucket in group_buckets:
bytes_data = [
(item, planner.resolve_data(item))
for item in bucket
if item.type == WriteItemType.BYTE_IO
]
tensor_data = [
(item, _clone_if_needed(planner.resolve_data(item)))
for item in bucket
if item.type != WriteItemType.BYTE_IO
]
if len(bytes_data) > 0 or len(tensor_data) > 0:
file_name = gen_file(prefix=group_name)
self.write_buckets.append(
(self.path / file_name, file_name, (bytes_data, tensor_data))
)
# Check if there is anything to write on this rank
if len(self.write_buckets) > 0:
assert len(self.write_buckets) <= self.thread_count, (
len(self.write_buckets),
self.thread_count,
)
self.results_queue = _get_write_results_queue()
else:
self.results_queue = None
end = time()
logger.debug(f"D2H and push, time: {end - start}")
def get_save_function_and_args(self) -> Tuple[Optional[Callable], Optional[Callable], List]:
"""
Get function that saves the data to storage along with its arguments.
Allows the external caller to apply the save function synchronously or asynchronously.
Returns: a (None, None, ()) tuple if there is nothing to write on this rank, otherwise a tuple of:
1) the function that saves the data.
2) the function that stages the GPU tensors to a destination for async checkpointing.
This function should be self-contained.
3) arguments to that function in 1).
"""
if not self.write_buckets:
return None, None, ()
transform_list = [self.transforms] if hasattr(self, 'transforms') else []
return (
partial(self.write_preloaded_data_multiproc, transform_list),
partial(self.preload_tensors, self.write_buckets, True),
[torch.distributed.get_rank(), self.write_buckets, self.results_queue],
)
@staticmethod
def preload_tensors(write_buckets: List[WriteBucket], non_blocking=True) -> List[WriteBucket]:
"""Preload tensors in state_dict to host memory through CPU memory
Args:
write_buckets(List): List of `WriteBucket`,
which includes what to be saved in a checkpoint
non_blocking (bool, optional): knob to enable pinned D2H memcpy. Default is True.
"""
result = []
for bucket in write_buckets:
file_name, storage_key, (bytes_data, tensor_data) = bucket
tensor_data = [
(item, tensor.to("cpu", non_blocking=non_blocking)) for item, tensor in tensor_data
]
result.append((file_name, storage_key, (bytes_data, tensor_data)))
if non_blocking:
torch.cuda.synchronize()
return result
@staticmethod
@_disable_gc()
def write_preloaded_data_multiproc(
transform_list, rank, write_buckets: List[WriteBucket], global_results_queue: mp.Queue
) -> None:
"""
Performs saving data to storage with multiple processes.
Starts a predefined number of processes and uses two queues to make sure the results
are complete:
- local_results_queue - to send the actual results
- count_queue - small queue to mark a worker as completed
Using just one queue would not allow proper exception handling.
This method is meant to be run in a forked subprocess.
Triggering GC during execution leads to CUDA errors
(cleaning up tensors owned by the parent process).
To prevent this, we disable the GC explicitly for this function with _disable_gc.
Args:
transform_list: optional serialization transforms forwarded to `_write_item` (may be empty)
rank (int): global rank of the writer (used for logging)
write_buckets (List[WriteBucket]): write plan
global_results_queue (mp.Queue): mp.Queue to collect Dict[List[WriteResults]]
(or an Exception) from parallel write processes to the main training process
Returns: None
"""
logger = logging.getLogger(__name__)
w_start = time()
write_results_or_exc: Union[dict, Exception] = dict()
ctx = mp.get_context('fork')
local_results_queue = ctx.Queue()
count_queue = ctx.JoinableQueue()
p_list = []
for i, write_bucket in enumerate(write_buckets):
try:
count_queue.put(i)
p_list.append(
ctx.Process(
target=partial(FileSystemWriterAsync.write_preloaded_data, transform_list),
args=(i, write_bucket, local_results_queue, count_queue, True),
)
)
except Exception as e:
err_msg = f'An error was caught while creating proc {i}: {e}'
logger.error(err_msg)
write_results_or_exc = RuntimeError(err_msg)
if not isinstance(write_results_or_exc, Exception):
for p in p_list:
p.start()
logger.debug('FileSystemWriterAsync: collecting worker results...')
# Make sure all workers have completed
count_queue.join()
# At this point, all workers completed, so the queue should have exactly
# `len(write_buckets)` items
for proc_idx in range(len(write_buckets)):
try:
local_proc_idx, local_results_or_exc = local_results_queue.get()
except queue.Empty:
write_results_or_exc = RuntimeError(
f'Unexpected empty `local_results_queue`'
f' (got only {proc_idx}/{len(write_buckets)} items)'
)
break
else:
if isinstance(local_results_or_exc, Exception):
err_msg = (
f"Local process {local_proc_idx} encountered"
f" an error: {local_results_or_exc}"
)
logger.error(err_msg)
write_results_or_exc = local_results_or_exc
break
assert isinstance(local_results_or_exc, list), type(local_results_or_exc)
write_results_or_exc[local_proc_idx] = local_results_or_exc
p_list[local_proc_idx].join()
logger.debug('FileSystemWriterAsync: collected worker results successfully')
global_results_queue.put(write_results_or_exc)
w_end = time()
logger.debug(f"{w_end}, rank: {rank}," f" write(sync,parallel): {w_end - w_start}")
@staticmethod
@_disable_gc()
def write_preloaded_data(
transform_list,
local_proc_idx: int,
write_bucket: WriteBucket,
results_queue: mp.SimpleQueue,
count_queue: mp.JoinableQueue,
use_fsync: bool,
) -> None:
"""
Performs actual data saving to storage.
Args:
transform_list: optional serialization transforms forwarded to `_write_item` (may be empty)
local_proc_idx (int): index of a local process that performs writing
write_bucket (WriteBucket): data to write to storage
results_queue (mp.Queue): queue to return the write results
to the proxy checkpoint process.
count_queue (mp.JoinableQueue): queue used to mark the worker task as completed
use_fsync (bool): if True, calls os.fsync at the end of saving
Returns: None, the write results are put into `results_queue`
"""
logger = logging.getLogger(__name__)
logger.debug(f'{local_proc_idx} started')
mem_before = _process_memory()
local_results = []
try:
file_name, storage_key, (bytes_data, tensor_data) = write_bucket
with open(file_name, "wb") as stream:
for write_item, data in bytes_data:
local_results.append(
_write_item(*transform_list, stream, data, write_item, storage_key)
)
for write_item, tensor in tensor_data:
assert tensor.is_cpu
local_results.append(
_write_item(*transform_list, stream, tensor, write_item, storage_key)
)
if use_fsync:
os.fsync(stream.fileno())
local_output = (local_proc_idx, local_results)
except Exception as e:
logger.debug(f'{local_proc_idx} failed')
local_output = (local_proc_idx, e)
results_queue.put(local_output)
# Signal this process is done.
count_queue.get()
count_queue.task_done()
mem_after = _process_memory()
logger.debug(
f"{local_proc_idx} consumed: {mem_after - mem_before},"
f" before: {mem_before}, after: {mem_after}"
)
def write_data(self, plan: SavePlan, planner: SavePlanner) -> Future[List[WriteResult]]:
"""Write all items from ``plan``."""
raise NotImplementedError('write_data not implemented for FileSystemWriterAsync')
def retrieve_write_results(self) -> List[WriteResult]:
"""
Turn the latest dict of write results from `self.results_queue`
into a single list of results. Includes an error check.
Returns (List[WriteResult]): the list of write results
from all local processes performing the save.
"""
assert self.write_buckets is not None
if self.results_queue is None:
write_results_or_exc = {}
else:
try:
write_results_or_exc = self.results_queue.get_nowait()
except queue.Empty:
raise RuntimeError('results_queue should not be empty')
if isinstance(write_results_or_exc, Exception):
raise RuntimeError(f'Worker failure: {write_results_or_exc}') from write_results_or_exc
write_results: dict = write_results_or_exc
if len(write_results) != len(self.write_buckets):
raise RuntimeError(
f'Incomplete worker results (expected {len(self.write_buckets)},'
f' got {len(write_results)}). This probably indicates a worker failure.'
)
return list(chain.from_iterable(write_results.values()))
def prepare_decentralized_global_plan(self, local_plan: SavePlan) -> SavePlan:
"""Instead of assigning indices by plan order, uses PyT rank (same outcome).
Args:
local_plan (SavePlan): local plan to turn to a global plan
(without interactions with other ranks)
Returns:
SavePlan - locally transformed plan equivalent to the plan that would be
created by the coordinator
"""
return dataclasses.replace(
local_plan, storage_data=_StoragePrefix(f"__{torch.distributed.get_rank()}_")
)
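# Illustrative sketch (hypothetical helper, not part of the original module; never called here):
# how an external caller could drive the flow from the FileSystemWriterAsync docstring above.
# `central_plan` and `planner` are assumed to come from the PyT Distributed planning step.
def _example_filesystem_writer_async_flow(writer: FileSystemWriterAsync, central_plan, planner):
    # 1. D2H copy and preparation of per-file write buckets.
    writer.prepare_write_data(central_plan, planner)
    # 2. The returned save function may be applied synchronously (as here) or scheduled async.
    save_fn, _stage_fn, save_args = writer.get_save_function_and_args()
    if save_fn is not None:
        save_fn(*save_args)
    # 3. After all ranks have finished, the per-process write results feed `super().finish`.
    return writer.retrieve_write_results()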
def _split_by_size_and_type(bins: int, items: List[WriteItem]) -> List[List[WriteItem]]:
"""
Splits write items according to item size into close to uniform bins.
Same as torch.distributed.checkpoint.filesystem._split_by_size_and_type,
but with a fixed _item_size function.
Args:
bins (int): number of bins to split into
items (List[WriteItem]): list of write items
Returns (List[List[WriteItem]]): write items split into bins
"""
if bins == 1:
return [items]
bytes_items: List[WriteItem] = []
tensor_items: List[WriteItem] = []
for wi in items:
container = bytes_items if wi.type == WriteItemType.BYTE_IO else tensor_items
container.append(wi)
buckets: List[List[WriteItem]] = [[] for _ in range(bins)]
bucket_sizes = [0 for _ in range(bins)]
# Assign bytes with a simple round-robin
for i, item in enumerate(bytes_items):
buckets[i % bins].append(item)
# Sort tensor items by size in decreasing order once and store the size with item
sized_tensors = [(item, _item_size(item)) for item in tensor_items]
sized_tensors.sort(key=itemgetter(1), reverse=True)
# Use a min heap for bin assignment
# Store (total_size_of_bin, bin_index) tuples
heap: List[Tuple[int, int]] = [(0, i) for i in range(bins)]
# Assign tensors using heap
for item, size in sized_tensors:
total_bin_size, bin_idx = heappop(heap)
buckets[bin_idx].append(item)
heappush(heap, (total_bin_size + size, bin_idx))
return buckets
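# Worked example of the greedy heap assignment above (sizes only, for illustration):
# with bins=2 and tensor item sizes [6, 5, 4, 3] (sorted decreasingly),
# 6 -> bin 0 (totals 6, 0), 5 -> bin 1 (6, 5), 4 -> bin 1 (6, 9), 3 -> bin 0 (9, 9),
# i.e. both bins end up with the same total size.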
def _split_by_separation_hint(
buckets: List[List[WriteItem]], separation_hint: Optional[str] = None
) -> Dict[str, List[List[WriteItem]]]:
"""
Splits buckets into those whose keys begin with the separation_hint and those whose keys do not.
Args:
buckets (List[List[WriteItem]]): buckets to split
separation_hint (Optional[str]): optional prefix to split on
Returns (Dict[str, List[List[WriteItem]]]): a dictionary
mapping the prefix to the relevant buckets
"""
bins = len(buckets)
buckets_with_separation_hint = {}
if separation_hint is not None:
buckets_default = [[] for _ in range(bins)]
buckets_hint = [[] for _ in range(bins)]
for i in range(bins):
for item in buckets[i]:
if item.index.fqn.startswith(separation_hint):
buckets_hint[i].append(item)
else:
buckets_default[i].append(item)
buckets_with_separation_hint[""] = buckets_default
buckets_with_separation_hint[separation_hint] = buckets_hint
else:
buckets_with_separation_hint[""] = buckets
return buckets_with_separation_hint
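# Worked example of the split above (keys shown as plain strings for illustration; the real
# items are WriteItems matched on `item.index.fqn`): with separation_hint="optimizer." and
# buckets [["optimizer.state", "model.weight"], ["model.bias"]], the result maps
# "" -> [["model.weight"], ["model.bias"]] and "optimizer." -> [["optimizer.state"], []].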
def _item_size(item: WriteItem) -> int:
"""
Calculates size (in bytes) of a single write item.
Same as torch.distributed.checkpoint.filesystem._item_size,
but fixes computing chunk size (with item.tensor_data.chunk.sizes)
Args:
item (WriteItem): write item to compute the size of
Returns (int): size of an item in bytes
"""
size = 1
assert item.tensor_data is not None
# can't use math.prod as PT needs to support older python
for s in item.tensor_data.chunk.sizes:
size *= s
dtype = item.tensor_data.properties.dtype
return size * torch._utils._element_size(dtype)
def _process_memory() -> int:
"""
Get memory used by current process.
Returns (int): memory used by current process
"""
process = psutil.Process(os.getpid())
mem_info = process.memory_info()
return mem_info.rss
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import logging
from pathlib import Path
from time import time
from typing import Any, Callable, Dict, Optional, Tuple, TypeVar
import torch
import torch.distributed as dist
from torch.distributed.checkpoint import Metadata
from megatron.core.dist_checkpointing import ShardedObject, ShardedTensor
from megatron.core.dist_checkpointing.core import CheckpointingException
from megatron.core.dist_checkpointing.dict_utils import (
dict_list_map_inplace,
extract_matching_values,
merge,
nested_values,
)
from megatron.core.dist_checkpointing.exchange_utils import (
ShardDistribution,
determine_main_replica_uniform_distribution,
exchange_by_distribution,
exchange_loaded_objects_gather_object,
)
from megatron.core.dist_checkpointing.mapping import ShardedStateDict, StateDict, is_main_replica
from megatron.core.dist_checkpointing.strategies.base import (
AsyncSaveShardedStrategy,
LoadShardedStrategy,
SaveShardedStrategy,
)
from megatron.core.dist_checkpointing.utils import (
_sharded_object_id,
_sharded_tensor_shard_id,
_ShardId,
debug_time,
)
from megatron.core.dist_checkpointing.validation import (
determine_global_metadata,
validate_sharding_integrity,
)
logger = logging.getLogger(__name__)
T = TypeVar('T', ShardedObject, ShardedTensor)
class FullyParallelSaveStrategyWrapper(AsyncSaveShardedStrategy):
"""Wraps arbitrary strategy and distributes the save during `save`.
The save distribution happens without any *data* communication.
Only the *metadata* is exchanged and based on data replication on different
ranks, we try to distribute the save as uniformly as possible.
This wrapper assumes that setting `replica_id` to 0 will make the
underlying strategy do the saving on the current rank. All the other `replica_id`s
are set to 1.
Currently, the save distribution is realized with a greedy algorithm
described in `distribute_shards_to_ranks`.
Args:
strategy (SaveShardedStrategy): base strategy to wrap
parallelization_group (ProcessGroup, optional): process group to use for save
distribution. Note that this doesn't have to match exactly the
data distribution, but should cover the replication pattern
to maximize performance. Defaults to the whole world.
do_cache_distribution (bool, optional): whether to cache the save distribution
from previous calls. Should be set to True only if the state dict
structure between the calls is always the same. Defaults to False.
"""
def __init__(
self,
strategy: SaveShardedStrategy,
parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
do_cache_distribution: bool = False,
):
super().__init__(strategy.backend, strategy.version)
self.base_strategy = strategy
self.parallelization_group = parallelization_group
self.do_cache_distribution = do_cache_distribution
self.cached_distribution: Optional[ShardDistribution] = None
def async_save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
if not isinstance(self.base_strategy, AsyncSaveShardedStrategy):
raise CheckpointingException(
f'Cannot apply async_save to non-async base strategy {self.base_strategy}'
)
self.apply_saving_parallelization(sharded_state_dict)
return self.base_strategy.async_save(sharded_state_dict, checkpoint_dir)
def save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
self.apply_saving_parallelization(sharded_state_dict)
return self.base_strategy.save(sharded_state_dict, checkpoint_dir)
def apply_saving_parallelization(self, sharded_state_dict: ShardedStateDict) -> None:
"""Distributes the save across ranks by exchanging metadata.
Exchanges metadata from the state dict and computes the uniform
(as close as possible) distribution of saves among the ranks.
If `self.do_cache_distribution` is True, caches the distribution between
the calls and subsequent distributions happen without any inter-rank
communication.
Args:
sharded_state_dict (ShardedStateDict): state dict to distribute the saving
Returns: None
"""
start = time()
if self.do_cache_distribution and self.cached_distribution is not None:
logger.debug(f'Apply *cached* save parallelization')
precomputed_distribution = self.cached_distribution
else:
logger.debug(f'Apply save parallelization')
precomputed_distribution = determine_main_replica_uniform_distribution(
sharded_state_dict, self.parallelization_group
)
distribute_main_replicas_with_precomputed_distribution(
sharded_state_dict, self.parallelization_group, precomputed_distribution
)
if self.cached_distribution is None:
# First time applying the parallelization
validate_sharding_integrity(determine_global_metadata(sharded_state_dict)[1])
if self.do_cache_distribution:
self.cached_distribution = precomputed_distribution
end = time()
logger.debug(f"parallel save sharding, time: {end - start}")
@property
def can_handle_sharded_objects(self):
return self.base_strategy.can_handle_sharded_objects
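# Illustrative sketch (hypothetical helper, not part of the original module): wrapping a base
# save strategy so that each shard is saved by exactly one rank of `data_parallel_group`.
# `base_strategy`, `sharded_state_dict`, `checkpoint_dir` and `data_parallel_group` are assumed
# to be provided by the caller.
def _example_fully_parallel_save(base_strategy, sharded_state_dict, checkpoint_dir, data_parallel_group):
    wrapper = FullyParallelSaveStrategyWrapper(
        base_strategy, parallelization_group=data_parallel_group, do_cache_distribution=True
    )
    return wrapper.save(sharded_state_dict, checkpoint_dir)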
class FullyParallelLoadStrategyWrapper(LoadShardedStrategy):
"""Wraps arbitrary load strategy and distributes the load during `load`.
See `load` method docs for details.
Args:
strategy (LoadShardedStrategy): base strategy to wrap
parallelization_group (ProcessGroup, optional): process group to use for load
distribution. Note that this doesn't have to match exactly the
data distribution, but should cover the replication pattern
to maximize performance. Defaults to the whole world.
In most cases, it's recommended to set it to the DP group.
do_cache_distribution (bool, optional): whether to cache the load distribution
from previous calls. Should be set to True only if the state dict
structure between the calls is always the same. Defaults to False,
since the loading in general happens only once during training.
Note that the load distribution *cannot* be reused as a save distribution,
because save/load is not fully symmetrical.
exchange_algo (str): algorithm to use for exchanging the data.
Options:
- broadcast (default) - each rank broadcasts individual tensors to others
- gather_object - ranks all_gather_object the whole loaded state dicts
- gather_rounds - ranks all gather individual tensors in rounds
See method docs for more details.
"""
def __init__(
self,
strategy: LoadShardedStrategy,
parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
do_cache_distribution: bool = False,
exchange_algo: str = 'broadcast',
):
super().__init__()
self.base_strategy = strategy
if parallelization_group is None:
parallelization_group = (
dist.GroupMember.WORLD
) # explicit group needed for torch.distributed.get_global_rank call
self.parallelization_group = parallelization_group
self.do_cache_distribution = do_cache_distribution
self.exchange_algo = exchange_algo
self.cached_distribution: Optional[ShardDistribution] = None
self.cached_global_metadata: Optional[Metadata] = None
@debug_time("FullyParallelLoadStrategyWrapper.load", logger)
def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> StateDict:
"""Distributes the load and calls underlying strategy only for parts of the state dict.
Steps:
1. Load metadata is exchanged between the ranks in the parallelization group.
2. Each rank deterministically plans the load for the whole workload
so that the loads are as uniform as possible.
3. Each rank loads its planned shard of the checkpoint.
4. All ranks exchange the loaded shards.
Internode communication is involved in steps (1) (with metadata)
and (4) (with actual data). Storage interaction is involved in step (3).
Currently, the load distribution (step 2) is realized with a greedy algorithm
described in `distribute_shards_to_ranks` (same as for saving distribution).
Currently, the shards are all gathered between all ranks in the parallelization
group. This might not be optimal (some ranks do not need all tensors),
but it's a reasonable approximation for an optimal exchange in most scenarios.
Args:
sharded_state_dict (ShardedStateDict): sharded state dict to load
checkpoint_dir (Path): checkpoint directory to load from
Returns:
StateDict: loaded state dict. The state dict should be equivalent to
a state dict that would be loaded with the underlying strategy
without this wrapper.
"""
loaded_state_dict = {}
if torch.distributed.get_world_size(self.parallelization_group) <= 1:
return self.base_strategy.load(sharded_state_dict, checkpoint_dir)
# Step 1 and 2: exchange load metadata and distribute the load
with debug_time("self.apply_loading_parallelization", logger):
precomputed_distribution: ShardDistribution | None = self.apply_loading_parallelization(
sharded_state_dict
)
assert (
precomputed_distribution is not None
), 'Expecting non-trivial distribution for non-trivial parallelization group'
# Step 3: load part of the checkpoint.
# Load only sharded objects first. ShardedTensors will be loaded separately
# so that we can keep track of sharded tensors loaded by this rank
(sharded_tensors, sharded_state_dict, to_load_shards, unloaded_shards) = (
self._defer_loading_sharded_tensors(sharded_state_dict)
)
(sharded_objects, sharded_state_dict, to_load_objects, unloaded_objects) = (
self._defer_loading_sharded_objects(sharded_state_dict)
)
assert (
len(sharded_state_dict) == 0
), "sharded_state_dict is not empty after deferring tensors and objects"
with debug_time("base_load_ShardedObjects", logger):
# Load sharded objects first
loaded_objects = self.base_strategy.load(to_load_objects, checkpoint_dir)
with debug_time("base_load_ShardedTensors", logger):
# Load sharded tensors separately
loaded_tensors = self.base_strategy.load(to_load_shards, checkpoint_dir)
with debug_time("self.exchange_loaded_tensors", logger):
# Step 4: exchange data between ranks
logger.debug(f'Applying parallel load with algo {self.exchange_algo}')
all_loaded_tensors = exchange_by_distribution(
loaded_tensors,
unloaded_shards,
precomputed_distribution,
self.parallelization_group,
self.exchange_algo,
)
if not set(unloaded_shards.keys()).issubset(all_loaded_tensors.keys()):
missing_shards = set(unloaded_shards.keys()) - all_loaded_tensors.keys()
raise CheckpointingException(
f'Missing shards after fully parallel loading: {missing_shards}'
)
with debug_time("torch.cuda.synchronize", logger):
torch.cuda.synchronize()
all_loaded_objects = exchange_loaded_objects_gather_object(loaded_objects)
if not set(unloaded_objects.keys()).issubset(all_loaded_objects.keys()):
missing_object_shards = set(unloaded_objects.keys()) - all_loaded_objects.keys()
raise CheckpointingException(
f'Missing object shards after fully parallel loading: {missing_object_shards}'
)
torch.cuda.synchronize()
self.fill_in_deferred_sharded_tensors(sharded_tensors, all_loaded_tensors)
self.fill_in_deferred_sharded_objects(sharded_objects, all_loaded_objects)
merge(loaded_state_dict, sharded_objects)
merge(loaded_state_dict, sharded_tensors)
if hasattr(self.base_strategy, "cached_global_metadata"):
self.cached_global_metadata = self.base_strategy.cached_global_metadata
return loaded_state_dict
@staticmethod
def _defer_loading_sharded_objects(
sharded_state_dict: ShardedStateDict,
) -> Tuple[
ShardedStateDict,
ShardedStateDict,
Dict[_ShardId, ShardedObject],
Dict[_ShardId, ShardedObject],
]:
return _defer_loading_sharded_items(sharded_state_dict, ShardedObject, _sharded_object_id)
@staticmethod
def _defer_loading_sharded_tensors(
sharded_state_dict: ShardedStateDict,
) -> Tuple[
ShardedStateDict,
ShardedStateDict,
Dict[_ShardId, ShardedTensor],
Dict[_ShardId, ShardedTensor],
]:
return _defer_loading_sharded_items(
sharded_state_dict, ShardedTensor, _sharded_tensor_shard_id
)
@staticmethod
def fill_in_deferred_sharded_objects(
sharded_state_dict: ShardedStateDict, loaded_objects: Dict[_ShardId, Any]
) -> None:
"""Fill in objects not loaded by current rank with objects from `loaded_objects` map.
Args:
sharded_state_dict (ShardedStateDict): sharded state dict to fill in.
ShardedObjects are completely replaced with corresponding objects.
loaded_objects (Dict[_ShardId, Any]): dict allowing to map
ShardedObject from the sharded_state_dict to loaded objects.
Returns:
None
"""
_fill_in_deferred_sharded_items(
sharded_state_dict, loaded_objects, ShardedObject, _sharded_object_id
)
@staticmethod
def fill_in_deferred_sharded_tensors(
sharded_state_dict: ShardedStateDict, loaded_tensors: Dict[_ShardId, torch.Tensor]
) -> None:
"""Fill in tensors not loaded by current rank with tensors from `loaded_tensors` map.
Args:
sharded_state_dict (ShardedStateDict): sharded state dict to fill in.
ShardedTensors are completely replaced with corresponding torch.Tensors.
loaded_tensors (Dict[_ShardId, torch.Tensor]): dict allowing to map
ShardedTensor from the sharded_state_dict to loaded tensors.
Returns:
None
"""
_fill_in_deferred_sharded_items(
sharded_state_dict, loaded_tensors, ShardedTensor, _sharded_tensor_shard_id
)
def apply_loading_parallelization(
self, sharded_state_dict: ShardedStateDict
) -> Optional[ShardDistribution]:
"""Distributes the load across ranks by exchanging metadata.
Exchanges metadata from the state dict and computes the uniform
(as close as possible) distribution of loads among the ranks.
Marks ShardedTensors to be loaded by the current rank with replica_id 0
(and others with non-zero values).
If `self.do_cache_distribution` is True, caches the distribution between
the calls and subsequent distributions happen without any inter-rank
communication.
Args:
sharded_state_dict (ShardedStateDict): state dict to distribute the loading
Returns:
ShardDistribution (optional): the computed loading distribution
"""
if self.do_cache_distribution and self.cached_distribution is not None:
logger.debug(f'Apply *cached* load parallelization')
precomputed_distribution = self.cached_distribution
else:
logger.debug(f'Apply load parallelization')
precomputed_distribution = determine_main_replica_uniform_distribution(
sharded_state_dict, self.parallelization_group, True
)
distribute_main_replicas_with_precomputed_distribution(
sharded_state_dict, self.parallelization_group, precomputed_distribution
)
if self.do_cache_distribution:
self.cached_distribution = precomputed_distribution
return precomputed_distribution
@property
def can_handle_sharded_objects(self):
return self.base_strategy.can_handle_sharded_objects
def load_tensors_metadata(self, checkpoint_dir: Path):
return self.base_strategy.load_tensors_metadata(checkpoint_dir)
def load_sharded_metadata(self, checkpoint_dir: Path):
return self.base_strategy.load_sharded_metadata(checkpoint_dir)
def check_backend_compatibility(self, loaded_version):
return self.base_strategy.check_backend_compatibility(loaded_version)
def check_version_compatibility(self, loaded_version):
return self.base_strategy.check_version_compatibility(loaded_version)
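# Illustrative sketch (hypothetical helper, not part of the original module): distributing a load
# across `data_parallel_group` and exchanging the loaded shards with the default broadcast
# algorithm. All arguments are assumed to be provided by the caller.
def _example_fully_parallel_load(base_strategy, sharded_state_dict, checkpoint_dir, data_parallel_group):
    wrapper = FullyParallelLoadStrategyWrapper(
        base_strategy, parallelization_group=data_parallel_group
    )
    return wrapper.load(sharded_state_dict, checkpoint_dir)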
def distribute_main_replicas_with_precomputed_distribution(
sharded_state_dict: ShardedStateDict,
parallelization_group: torch.distributed.ProcessGroup,
precomputed_distribution: Optional[ShardDistribution],
):
"""Applies the save distribution computed with `determine_main_replica_uniform_distribution`.
Based on rank assignment, sets replica ids of the shards saved by current rank to 0
and all the other replica ids to 1.
Args:
sharded_state_dict (ShardedStateDict): state dict to apply the save distribution to
parallelization_group (ProcessGroup): distribution will be applied within this
process group. Must match with the process group passed to
`determine_main_replica_uniform_distribution`.
precomputed_distribution (ShardDistribution): distribution computed with
`determine_main_replica_uniform_distribution`
Returns: None
Example replica ids of tensors A, B, C before distribution:
rank0: A: (0, 0, 0), B: (0, 0, 0), C: (0, 0, 0)
rank1: A: (0, 0, 1), B: (0, 0, 1), C: (0, 0, 1)
rank2: A: (0, 0, 2), B: (0, 0, 2), C: (0, 0, 2)
Replicas after distribution for the example above:
rank0: A: 0, B: 1, C: 1
rank1: A: 1, B: 0, C: 1
rank2: A: 1, B: 1, C: 0
"""
if torch.distributed.get_world_size(group=parallelization_group) <= 1:
return
if precomputed_distribution is None:
raise ValueError(
'precomputed_distribution must not be None for a non-trivial parallelization group'
)
local_shards = list(
sh_base
for sh_base in nested_values(sharded_state_dict)
if isinstance(sh_base, ShardedTensor)
)
rank_within_dp_group = torch.distributed.get_rank(parallelization_group)
for sh_ten in local_shards:
shard_id = _sharded_tensor_shard_id(sh_ten)
if (
shard_id in precomputed_distribution.shards_in_this_group
and rank_within_dp_group == precomputed_distribution.main_rank_for_shard[shard_id]
):
sh_ten.replica_id = 0
else:
sh_ten.replica_id = 1
def _defer_loading_sharded_items(
sharded_state_dict: ShardedStateDict, item_type: type, shard_id_func: Callable[[T], _ShardId]
) -> Tuple[ShardedStateDict, ShardedStateDict, Dict[_ShardId, T], Dict[_ShardId, T]]:
"""Divides state dict into parts loaded by this vs other ranks.
Args:
sharded_state_dict (ShardedStateDict): state dict with sharded items
that will be divided.
item_type: The type of sharded item (ShardedObject or ShardedTensor)
shard_id_func: Function to get the shard ID for the item type
Returns: a tuple of:
- ShardedStateDict: sub-state dict only with sharded items
- ShardedStateDict: sub-state dict with non-sharded items
- Dict[_ShardId, T]: mapping from shard id to items loaded by *this* rank
- Dict[_ShardId, T]: mapping from shard id to items loaded by *other* ranks
"""
to_load_shards = {}
unloaded_shards = {}
sharded_items, remaining_state_dict = extract_matching_values(
sharded_state_dict, lambda v: isinstance(v, item_type)
)
def wrap_non_main_replicas(x: Any) -> Any:
if isinstance(x, item_type):
shard_id = shard_id_func(x)
if is_main_replica(x.replica_id):
to_load_shards[shard_id] = x
else:
unloaded_shards[shard_id] = x
return x
dict_list_map_inplace(wrap_non_main_replicas, sharded_items)
return sharded_items, remaining_state_dict, to_load_shards, unloaded_shards
def _fill_in_deferred_sharded_items(
sharded_state_dict: ShardedStateDict,
loaded_items: Dict[_ShardId, Any],
item_type: type,
shard_id_func: Callable[[T], _ShardId],
) -> None:
"""Helper function to fill in items not loaded by current rank."""
def fill_in_sharded_item(x: Any) -> Any:
if isinstance(x, item_type):
try:
x = loaded_items[shard_id_func(x)]
except KeyError as e:
raise CheckpointingException(
f'Missing loaded item shard: {shard_id_func(x)}'
) from e
return x
dict_list_map_inplace(fill_in_sharded_item, sharded_state_dict)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
""" Performant resharding of flattened tensors.
Tensors that are first sharded (e.g. across TP) and then flattened cause
very irregular access patterns during loading. The idea for performant save/load
is to store tensors with global shape [X, Y, Z] and local shape [x, y, z]
as tensors with global shape [X // x, Y // y, Z // z, x * y * z] and
local shape [1, 1, 1, x * y * z]. This allows parallel save of tensors along the
last (flattened) dimension. During loading, some additional resharding is needed.
"""
import logging
import math
from dataclasses import dataclass
from itertools import product
from typing import Any, Dict, Tuple, Union
import numpy as np
import torch
from torch.distributed.checkpoint import ChunkStorageMetadata
from torch.distributed.checkpoint.resharding import _shards_get_overlap_region_wrt_saved_tensor
from megatron.core.dist_checkpointing import ShardedTensor
from megatron.core.dist_checkpointing.core import CheckpointingException
from megatron.core.dist_checkpointing.dict_utils import (
dict_list_map_inplace,
extract_matching_values,
)
from megatron.core.dist_checkpointing.mapping import (
ShardedStateDict,
ShardedTensorFactory,
StateDict,
apply_factories,
apply_factory_merges,
)
logger = logging.getLogger(__name__)
@dataclass
class TensorReformulationMetadata:
"""Metadata needed to restore the original tensor shape.
Args:
ckpt_orig_global_shape (Tuple[int, ...]): original global shape of the tensor
saved in the checkpoint. This is the global shape of the application,
further reformulated into `ckpt_reform_global_shape` while saving.
ckpt_reform_global_shape (Tuple[int, ...]): reformulated global shape of the tensor
saved in the checkpoint. This is the actual saved shape.
"""
ckpt_orig_global_shape: Tuple[int, ...]
ckpt_reform_global_shape: Tuple[int, ...]
def __post_init__(self):
assert self.ckpt_orig_global_shape
def nd_flattened_tensor_reformulated_global_shape(sh_ten: ShardedTensor) -> Tuple[int, ...]:
"""Reformulated global shape of the flattened N-D ShardedTensor.
N-D tensor global shape [X, Y, Z] and local shape [x, y, z]
is reformulated into global shape [X // x, Y // y, Z // z, x * y * z] and
local shape [1, 1, 1, x * y * z], to allow parallel save of tensors along the
last (flattened) dimension.
Args:
sh_ten (ShardedTensor): flattened N-D ShardedTensor (N > 1)
Returns:
Tuple[int, ...]: reformulated tensor shape
"""
assert is_nd_flattened_tensor(sh_ten), sh_ten
return sh_ten.axis_fragmentations + (int(np.prod(sh_ten.local_shape)),)
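# Worked example of the reformulation above: a tensor with global shape (8, 16) sharded into
# local shards of shape (4, 4) has axis_fragmentations == (2, 4), so the reformulated global
# shape is (2, 4, 16) -- one flattened chunk of 4 * 4 = 16 elements per shard.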
def is_nd_flattened_tensor(sh_ten: Any) -> bool:
"""Checks if ShardedTensor is flattened and more than 1-dimensional
Args:
sh_ten (Any): any object
Returns:
bool: whether the given object is a flattened ShardedTensor and is N-dimensional (N > 1)
"""
return isinstance(sh_ten, ShardedTensor) and sh_ten.flattened_range is not None
# information needed to restore. With current implementation, this is a nested state dict
# with ShardedTensorFactories which is basically a ShardedStateDict type
ReformulationRestoreMetadata = ShardedStateDict
def apply_nd_flattened_tensors_reformulation(
sharded_state_dict: ShardedStateDict,
reformulation_metadata: Dict[str, TensorReformulationMetadata],
) -> Tuple[ShardedStateDict, ReformulationRestoreMetadata]:
"""Applies N-D reformulation to a given sharded state dict.
After applying the method and loading the reformulated state dict,
the `restore_nd_flattened_tensors_formulation` needs to be applied.
Current implementation uses ShardedTensorFactories for convenience of
restoring the original structure, but it's just an implementation detail.
Turns N-D ShardedTensors into factories and immediately applies them,
keeping the data needed to restore the original structure.
Args:
sharded_state_dict (ShardedStateDict): sharded state dict potentially
with tensors to reformulate.
reformulation_metadata (Dict[str, TensorReformulationMetadata]): dict
containing all metadata needed for reformulating tensors in `sharded_state_dict`.
For each N-D flattened tensor `sh_ten` in `sharded_state_dict` there must be an
entry keyed by `sh_ten.key`.
Returns:
tuple:
ShardedStateDict - reformulated sharded state dict
ReformulationRestoreMetadata - data needed to restore the original formulation
with `restore_nd_flattened_tensors_formulation`
"""
def maybe_reformulate_nd_flattened_tensor(sh_ten: Any):
if not isinstance(sh_ten, ShardedTensor) or not is_nd_flattened_tensor(sh_ten):
return sh_ten
# N-D flattened ShardedTensor
try:
sh_ten_reformulation_metadata = reformulation_metadata[sh_ten.key]
except KeyError as e:
# Handle legacy checkpoints where 1-D flattened tensor metadata was not saved
if len(sh_ten.global_shape) == 1:
return sh_ten
raise CheckpointingException(
f'Missing reformulation metadata for tensor {sh_ten}. '
f'Existing keys: {reformulation_metadata.keys()}'
) from e
ckpt_actual_saved_shape = sh_ten_reformulation_metadata.ckpt_reform_global_shape
app_actual_load_shape = nd_flattened_tensor_reformulated_global_shape(sh_ten)
if ckpt_actual_saved_shape == app_actual_load_shape:
# Same shape - no need to reshard
return sh_ten
return reformulate_single_nd_flattened_tensor(sh_ten, sh_ten_reformulation_metadata)
# Turn N-D tensors into factories and immediately apply them
dict_list_map_inplace(maybe_reformulate_nd_flattened_tensor, sharded_state_dict)
sh_ten_factories, _ = extract_matching_values(
sharded_state_dict,
lambda x: isinstance(x, ShardedTensorFactory),
return_lists_as_dicts=True,
)
apply_factories(sharded_state_dict)
# Unlink `data` pointers to free memory
def unlink_data(x):
x.data = None
return x
dict_list_map_inplace(unlink_data, sh_ten_factories)
return sharded_state_dict, sh_ten_factories
def restore_nd_flattened_tensors_formulation(
state_dict: StateDict, formulation_restore_metadata: ReformulationRestoreMetadata
) -> StateDict:
"""Restores the original state dict from a reformulated form.
Inverse of `apply_nd_flattened_tensors_reformulation`.
Args:
state_dict (StateDict): state dict obtained by loading a reformulated
sharded state dict.
formulation_restore_metadata (ReformulationRestoreMetadata): metadata returned by
`apply_nd_flattened_tensors_reformulation` function
Returns:
StateDict: state dict with the original tensors formulation restored
"""
return apply_factory_merges(state_dict, formulation_restore_metadata)
def reformulate_single_nd_flattened_tensor(
sh_ten: ShardedTensor, reformulation_metadata: TensorReformulationMetadata
) -> Union[Any, ShardedTensorFactory]:
"""Reformulates shapes of a single N-D flattened ShardedTensor.
We need to define a pair of transformations:
- turn N-D ShardedTensor with original formulation into multiple reformulated ShardedTensors
- merge multiple reformulated loaded torch.Tensors into a single original tensor
Current implementation uses ShardedTensorFactories as a convenient mechanism
for specifying and keeping track of those transformations.
Args:
sh_ten (ShardedTensor): sharded tensor to reformulate.
reformulation_metadata (TensorReformulationMetadata): metadata needed to
perform the reformulation
Returns:
ShardedTensorFactory: factory that keeps information how to reformulate
(build) the ShardedTensor and then restore original formulation (merge)
after loading.
"""
rmd = reformulation_metadata
# Data won't be needed - remove unnecessary tensor references
sh_ten = sh_ten.without_data()
# Based on reformulation_metadata, determine other tensor shapes and metadata
ckpt_axis_fragmentation = rmd.ckpt_reform_global_shape[:-1]
for sh, fragm in zip(rmd.ckpt_orig_global_shape, ckpt_axis_fragmentation):
assert sh % fragm == 0, (sh_ten, rmd.ckpt_reform_global_shape)
ckpt_local_shape_with_prepended_axis = tuple(
sh // fragm for sh, fragm in zip(rmd.ckpt_orig_global_shape, ckpt_axis_fragmentation)
)
assert (
ckpt_local_shape_with_prepended_axis[: sh_ten.prepend_axis_num]
== (1,) * sh_ten.prepend_axis_num
), (ckpt_local_shape_with_prepended_axis, sh_ten)
ckpt_local_shape = ckpt_local_shape_with_prepended_axis[sh_ten.prepend_axis_num :]
# Iterate over reformulated shapes needed by the application and from checkpoint,
# and generate new ShardedTensors that match the checkpoint sharding.
overlap_dim_offsets = []
assert len(ckpt_axis_fragmentation) == len(sh_ten.axis_fragmentations), (
ckpt_axis_fragmentation,
sh_ten,
)
for dim, (app_chunk_dim_offset, ckpt_fragm, app_fragm) in enumerate(
zip(
sh_ten.local_chunk_offset_in_global(),
ckpt_axis_fragmentation,
sh_ten.axis_fragmentations,
)
):
# The argument to `int` is the exact offset of the app shard expressed in ckpt_local_shape units
first_overlap_dim_offset = int(ckpt_fragm / app_fragm * app_chunk_dim_offset)
# The argument to `math.ceil` is the exact offset of the next app shard expressed
# in ckpt_local_shape units
next_overlap_dim_offset = math.ceil(ckpt_fragm / app_fragm * (app_chunk_dim_offset + 1))
overlap_dim_offsets.append(range(first_overlap_dim_offset, next_overlap_dim_offset))
logger.debug(
f'Generated the following number of overlap shards for each dimension: '
f'{list(map(len, overlap_dim_offsets))} for fragmentation ckpt '
f'{ckpt_axis_fragmentation} vs app {sh_ten.axis_fragmentations} '
f'and chunk offset {sh_ten.local_chunk_offset_in_global()}'
)
reformulated_sh_tens = {}
for chunk_offset in product(*overlap_dim_offsets):
global_offset = tuple(
chunk_off * chunk_shape
for chunk_off, chunk_shape in zip(chunk_offset, ckpt_local_shape_with_prepended_axis)
)
reformulated_sh_tens[(global_offset, ckpt_local_shape)] = ShardedTensor(
sh_ten.key,
None,
sh_ten.dtype,
ckpt_local_shape,
rmd.ckpt_orig_global_shape,
global_offset,
ckpt_axis_fragmentation,
sh_ten.replica_id,
sh_ten.prepend_axis_num,
sh_ten.allow_shape_mismatch,
flattened_range=slice(0, rmd.ckpt_reform_global_shape[-1]), # whole ckpt shard
)
# Now, we have to define the transformations from application sharding
# to checkpoint sharding.
@torch.no_grad()
def sh_ten_build_fn(*args, **kwargs):
# Here we simply return the precomputed tensors.
return reformulated_sh_tens
@torch.no_grad()
def sh_ten_merge_fn(sub_state_dict):
# This is the non-flattened local tensor with original formulation
# that we are going to fill with shards loaded from the checkpoint.
app_non_flat_ten = torch.empty(
sh_ten.local_shape,
dtype=sh_ten.dtype,
device=sh_ten.data.device if sh_ten.data is not None else None,
)
assert len(sub_state_dict) > 0
for (ckpt_global_offset, ckpt_local_shape), ckpt_ten in sub_state_dict.items():
# For each ckpt shard, we fill the appropriate application shard part
dest_ten = app_non_flat_ten
src_ten = ckpt_ten.view(ckpt_local_shape)
# We don't need narrowing over `prepend_axis_num` axes so we take
# the [sh_ten.prepend_axis_num:] offsets slice
for (
dim,
offset_for_saved_tensor,
offset_for_current_tensor,
length,
) in _shards_get_overlap_region_wrt_saved_tensor(
saved_shard=ChunkStorageMetadata(
ckpt_global_offset[sh_ten.prepend_axis_num :], ckpt_local_shape
),
current_shard=ChunkStorageMetadata(
sh_ten.global_offset[sh_ten.prepend_axis_num :], sh_ten.local_shape
),
):
src_ten = src_ten.narrow(dim, offset_for_saved_tensor, length)
dest_ten = dest_ten.narrow(dim, offset_for_current_tensor, length)
dest_ten.copy_(src_ten)
return app_non_flat_ten.flatten()[sh_ten.flattened_range]
return ShardedTensorFactory(
sh_ten.key,
sh_ten.data,
sh_ten_build_fn,
sh_ten_merge_fn,
sh_ten.replica_id,
sh_ten.flattened_range,
)
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
""" State dict saver for PyT Distributed format allowing asynchronous save. """
from logging import getLogger
from time import time
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import torch
import torch.distributed as dist
from torch.distributed.checkpoint import CheckpointException
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE, Metadata
from torch.distributed.checkpoint.planner import SavePlan, SavePlanner
from torch.distributed.checkpoint.utils import _DistWrapper, _get_failure_dict
if TYPE_CHECKING:
from .filesystem_async import FileSystemWriterAsync
from .torch import MCoreSavePlanner
logger = getLogger(__name__)
from dataclasses import fields
def _compare_dataclasses(obj1, obj2):
if type(obj1) != type(obj2):
return f"Objects are of different types: {type(obj1)} and {type(obj2)}"
differences = []
for field in fields(obj1):
value1 = getattr(obj1, field.name)
value2 = getattr(obj2, field.name)
if value1 != value2:
differences.append(f"{field.name}: {value1} != {value2}")
return differences if differences else "All fields are equal"
def save_state_dict_async_plan(
state_dict: STATE_DICT_TYPE,
storage_writer: 'FileSystemWriterAsync',
process_group: Optional[dist.ProcessGroup] = None,
coordinator_rank: int = 0,
planner: Optional[Union[SavePlanner, 'MCoreSavePlanner']] = None,
cached_ckpt_structure: Optional[Tuple[SavePlan, SavePlan, bool]] = None,
loaded_all_plans: Optional[List[SavePlan]] = None,
) -> Tuple[
Tuple['FileSystemWriterAsync', Union[Metadata, None], _DistWrapper], SavePlan, SavePlan, bool, bool
]:
"""
First stage of saving a state dict to storage.
This is an async adaptation of torch.distributed.checkpoint.state_dict_saver.
In order to support async save, saving should be split into three parts:
1. Planning
2. Actual saving
3. Finalization
Out of these, step (2) *must* happen asynchronously.
The first step is realized with this function.
The planning part consists of several steps, described here:
https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.SavePlanner
Args:
state_dict (STATE_DICT_TYPE): state dict to save
storage_writer (FileSystemWriterAsync): in current version only an instance of
FileSystemWriterAsync
process_group (dist.ProcessGroup, optional): process group used for save planning
coordinator_rank (int, optional): coordinator rank for planning. Defaults to 0.
planner (SavePlanner, optional): save planner for torch.distributed.checkpoint format
cached_ckpt_structure (Tuple[SavePlan, SavePlan, bool], Optional):
Each object of this tuple will be used in the order as following
cached_central_plan (SavePlan): a globally coordinated save plan
cached in the previous iteration
cached_local_plan (SavePlan): a local plan
cached in the previous iteration
validated_cache_reuse (bool): flag telling whether global_metadata and the planning dict
are consistent over iterations
loaded_all_plans (List[SavePlan], optional): local plans from the loaded checkpoint
metadata, used to verify whether global metadata can be reused
Returns: Tuple of:
- a (storage writer, planning metadata or None, distributed wrapper) tuple; the storage
writer is the one passed as input and the metadata is None if cached global metadata is reused
- the central (global) save plan
- the local save plan
- True if the cached central plan could be reused, False otherwise
- True if global metadata reuse was verified, False otherwise
The first element should be passed as an input to `save_state_dict_async_finalize`;
the cached plans allow skipping `reduce_scatter` at planning.
"""
cached_central_plan, cached_local_plan, validated_cache_reuse = (None, None, False)
if cached_ckpt_structure:
cached_central_plan, cached_local_plan, validated_cache_reuse = cached_ckpt_structure
rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
dist_wrapper = _DistWrapper(process_group, True, coordinator_rank)
if planner is None:
planner = DefaultSavePlanner()
assert planner is not None
global_metadata = None
logger.debug(f"rank: {rank}, starting state dict save")
local_plan = cached_local_plan
global_md_verify_reuse = False
def local_step():
nonlocal local_plan
assert planner is not None
# PyTorch 2.4 introduced an additional `metadata` argument,
# so we have to pass `is_coordinator` by keyword
planner.set_up_planner(state_dict, is_coordinator=dist_wrapper.is_coordinator)
storage_writer.set_up_storage_writer(dist_wrapper.is_coordinator)
if not validated_cache_reuse and local_plan is None:
local_plan = planner.create_local_plan()
local_plan = storage_writer.prepare_local_plan(local_plan)
return local_plan
def global_step(all_local_plans):
nonlocal global_metadata
assert planner is not None
all_local_plans, global_metadata = planner.create_global_plan(all_local_plans)
all_local_plans = storage_writer.prepare_global_plan(all_local_plans)
return all_local_plans
# Execute local and global planning
# Ideally we want to use the cached plan. Otherwise if the planner and storage_writer
# allow it (`can_run_decentralized_global_plan`) we gather the plans to create
# the metadata but prepare the plans independently on each rank.
# In the worst case we have to reduce_scatter all the plans.
start_plan = time()
if validated_cache_reuse and cached_central_plan:
logger.debug(f"rank: {rank}, Passed cache reusable")
local_step()
central_plan = cached_central_plan
elif getattr(planner, 'can_run_decentralized_global_plan', False) and getattr(
storage_writer, 'can_run_decentralized_global_plan', False
):
local_plan = local_step()
global_md_verify_reuse = verify_global_md_reuse(
loaded_all_plans, local_plan, rank, dist_wrapper
)
if not loaded_all_plans or not global_md_verify_reuse:
all_local_plans = dist_wrapper.gather_object(local_plan)
if dist_wrapper.is_coordinator:
_, global_metadata = planner.create_global_plan(all_local_plans)
global_metadata.all_local_plans = all_local_plans
else:
logger.debug(f"rank: {rank}, Passed cached global metadata")
global_metadata = None
local_plan = planner.create_decentralized_global_plan(local_plan)
local_plan = storage_writer.prepare_decentralized_global_plan(local_plan)
central_plan = local_plan
else:
central_plan = dist_wrapper.reduce_scatter("plan", local_step, global_step)
central_plan = planner.finish_plan(central_plan)
end_plan = time()
logger.debug(f"rank: {rank}, plan time: {end_plan - start_plan}")
# Prepare async writing of tensors.
# The `storage_writer` will store the information about tensors it needs to save
start = time()
storage_writer.prepare_write_data(central_plan, planner)
end = time()
logger.debug(f"{time()} rank: {rank}, write(async) time: {end - start}")
return (
(storage_writer, global_metadata, dist_wrapper),
central_plan,
local_plan,
cached_central_plan == central_plan,
global_md_verify_reuse,
)
def verify_global_md_reuse(
loaded_all_plans: List[SavePlan], local_plan: SavePlan, rank: int, dist_wrapper: _DistWrapper
) -> bool:
"""
Verifies that global metadata reuse is possible by checking the loaded plans from the
checkpoint are consistent, which means we have the same settings when resuming training.
Args:
loaded_all_plans: List[SavePlan], The loaded plans from the checkpoint
(stored in checkpoint metadata).
local_plan: SavePlan, The local save plan.
rank: Current process rank.
dist_wrapper (_DistWrapper): distributed wrapper created during planning
Returns: True iff the global metadata reuse is possible.
"""
logger.debug(f"verifying reuse of global metadata")
if not loaded_all_plans:
global_md_verify_reuse = False
logger.debug("loaded global metadata reuse verification: no loaded plans passed")
elif len(loaded_all_plans) == dist_wrapper.get_world_size():
local_verify_reuse = all(
getattr(local_plan, f.name) == getattr(loaded_all_plans[rank], f.name)
for f in fields(local_plan)
if f.name != 'storage_data'
)
if not local_verify_reuse:
logger.debug(
f"local_verify_reuse is False: diffs -"
f" {_compare_dataclasses(local_plan, loaded_all_plans[rank])}"
)
all_results = torch.tensor([local_verify_reuse], dtype=torch.int, device='cuda')
torch.distributed.all_reduce(all_results, op=torch.distributed.ReduceOp.MIN)
# Check if all reduced results are True
global_md_verify_reuse = all_results.item() == 1
else:
global_md_verify_reuse = False
return global_md_verify_reuse
def save_state_dict_async_finalize(
storage_writer: 'FileSystemWriterAsync', global_metadata: Metadata, dist_wrapper: _DistWrapper
) -> None:
"""
Finalization of save_state_dict_async_plan.
The input arguments are the same as the `save_state_dict_async_plan` output;
the `write_results` are retrieved from the `storage_writer`.
Args:
storage_writer (FileSystemWriterAsync): storage writer used for planning
global_metadata (Metadata): metadata created during planning
dist_wrapper (_DistWrapper): distributed wrapper created during planning
Returns: None
"""
write_results = storage_writer.retrieve_write_results()
# Gather the write results that will be saved to the metadata file.
gather_start = time()
all_results = dist_wrapper.gather_object(write_results)
gather_end = time()
logger.debug(f"{gather_end}, {torch.distributed.get_rank()}, gather: {gather_end-gather_start}")
# Store the metadata on coordinator rank
if dist_wrapper.is_coordinator:
node_failures = _get_failure_dict(all_results)
if len(node_failures) == 0:
assert global_metadata is not None
write_start = time()
storage_writer.finish(global_metadata, all_results)
write_end = time()
logger.debug(f"{write_end}, metadata_write: {write_end - write_start}")
else:
raise CheckpointException("write", node_failures)
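# Illustrative sketch (hypothetical helper, not part of this module): driving the two-stage API
# above synchronously. `storage_writer` is assumed to be a FileSystemWriterAsync instance and
# `state_dict` a sharded state dict prepared by the caller.
def _example_sync_save(state_dict, storage_writer):
    # Stage 1: planning (D2H copy and per-file write buckets are prepared inside).
    (storage_writer, global_metadata, dist_wrapper), *_ = save_state_dict_async_plan(
        state_dict, storage_writer
    )
    # Stage 2 (normally run asynchronously): write the prepared buckets to storage.
    save_fn, _, save_args = storage_writer.get_save_function_and_args()
    if save_fn is not None:
        save_fn(*save_args)
    # Stage 3: gather write results and write the metadata file on the coordinator rank.
    save_state_dict_async_finalize(storage_writer, global_metadata, dist_wrapper)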
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Strategies using TensorStore to load and save Zarr arrays. """
from functools import partial
from itertools import starmap
from logging import getLogger
from pathlib import Path
import tensorstore as ts
import torch
from ..core import CheckpointingException
from ..dict_utils import dict_list_map_inplace
from ..mapping import ShardedStateDict, ShardedTensor
from .base import LoadShardedStrategy, StrategyAction, register_default_strategy
from .zarr import load_zarr_based_sharded_metadata, postprocess_numpy_array
logger = getLogger(__name__)
def register_default_tensorstore_strategies():
"""Register default strategies leveraging tensorstore."""
register_default_strategy(
StrategyAction.LOAD_SHARDED, 'zarr', 1, TensorStoreLoadShardedStrategy()
)
class TensorStoreLoadShardedStrategy(LoadShardedStrategy):
"""Load strategy for Zarr backend using `tensorstore` for loading."""
def __init__(self, load_directly_on_device: bool = False):
super().__init__()
self.load_directly_on_device = load_directly_on_device
def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
if torch.distributed.get_rank() == 0:
print(f'Loading distributed checkpoint with {self.__class__.__name__}')
if self.load_directly_on_device:
print(f'Loading distributed checkpoint directly on the GPU')
load_fn = partial(
_load_from_array,
checkpoint_dir=checkpoint_dir,
load_directly_on_device=self.load_directly_on_device,
)
dict_list_map_inplace(load_fn, sharded_state_dict)
return sharded_state_dict
def load_tensors_metadata(self, checkpoint_dir: Path):
def get_ts_shape_dtype(path):
arr = open_ts_array(path)
return arr.shape, arr.dtype.numpy_dtype
return load_zarr_based_sharded_metadata(checkpoint_dir, get_ts_shape_dtype)
def check_backend_compatibility(self, loaded_version):
pass # TODO
def check_version_compatibility(self, loaded_version):
pass # TODO
def merge_global_slice_with_shape(global_slice, actual_shape, key):
"""Intersects the global slice with the actual shape (prevent overflow)."""
def _merge_slice(dim_slice, dim_size):
if isinstance(dim_slice, slice):
assert (
dim_slice.start < dim_size
), f'Got empty slice for ShardedTensor {key} ({dim_slice}, {dim_size})'
if dim_slice.stop > dim_size:
dim_slice = slice(dim_slice.start, dim_size, dim_slice.step)
return dim_slice
assert len(global_slice) == len(actual_shape), (global_slice, actual_shape, key)
return tuple(starmap(_merge_slice, zip(global_slice, actual_shape)))
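# A minimal worked example of the clamping above; the key and shapes are hypothetical.
def _example_merge_global_slice_with_shape():
    # The checkpoint stores an (8, 16) array, but the requesting shard covers
    # rows 4:10, which overflows the first dimension; the slice is clamped to 4:8.
    merged = merge_global_slice_with_shape(
        (slice(4, 10), slice(0, 16)), (8, 16), key='example.weight'
    )
    assert merged == (slice(4, 8, None), slice(0, 16, None))
    return merged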
def _load_from_array(
sharded_tensor: ShardedTensor,
checkpoint_dir: Path,
load_directly_on_device: bool = False,
apply_flattened_range: bool = True,
):
x = _load_regular_chunk(sharded_tensor, checkpoint_dir)
ten = postprocess_numpy_array(x, sharded_tensor, apply_flattened_range)
if load_directly_on_device:
sharded_tensor.data.data.copy_(ten)
return sharded_tensor.data
else:
return ten
def _load_regular_chunk(sharded_tensor: ShardedTensor, checkpoint_dir: Path):
assert isinstance(sharded_tensor, ShardedTensor), type(sharded_tensor)
arr = open_ts_array(checkpoint_dir / sharded_tensor.key)
if sharded_tensor.global_shape == arr.shape:
x = (
arr[sharded_tensor.global_slice()].read().result()
) # flattened tensors loading is delayed
elif sharded_tensor.allow_shape_mismatch:
global_slice = merge_global_slice_with_shape(
sharded_tensor.global_slice(), arr.shape, sharded_tensor.key
)
x = arr[global_slice].read().result() # flattened tensors loading is delayed
else:
_msg = (
f'Global shape mismatch for loaded ({arr.shape})'
f' and expected ({sharded_tensor.global_shape}) tensor'
f' for key {sharded_tensor.key}'
)
raise CheckpointingException(_msg)
return x
def open_ts_array(arr_path: Path):
"""Opens a Zarr file array with Tensorstore with basic setting.
Args:
arr_path (Path): path to a Zarr (Tensorstore) array
"""
spec = {'driver': 'zarr', 'metadata_key': '.zarray', 'kvstore': {}}
spec['kvstore'] = {'driver': 'file', 'path': str(arr_path)}
try:
arr = ts.open(ts.Spec(spec), open=True).result()
except Exception as e:
raise CheckpointingException(f'Array {arr_path} could not be loaded. Error: {e}') from e
return arr
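# A minimal read sketch built on `open_ts_array`; the checkpoint directory and
# array key below are hypothetical.
def _example_read_zarr_slice(checkpoint_dir: Path = Path('/tmp/ckpt')):
    arr = open_ts_array(checkpoint_dir / 'decoder.final_norm.weight')
    # TensorStore reads are asynchronous: `.read().result()` blocks and returns
    # a numpy array holding only the requested slice.
    return arr[0:2].read().result()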
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Strategies using PyTorch distributed.checkpoint as an underlying format. """
import io
import os
import pickle
import warnings
from collections import ChainMap, defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from itertools import product
from logging import getLogger
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, cast
import torch
from packaging.version import Version as PkgVersion
from torch.distributed import checkpoint
from torch.distributed._shard.metadata import ShardMetadata
from torch.distributed._shard.sharded_tensor import Shard
from torch.distributed._shard.sharded_tensor import ShardedTensor as TorchShardedTensor
from torch.distributed._shard.sharded_tensor import ShardedTensorMetadata, TensorProperties
from torch.distributed.checkpoint import (
BytesStorageMetadata,
DefaultLoadPlanner,
DefaultSavePlanner,
FileSystemReader,
FileSystemWriter,
LoadPlan,
Metadata,
ReadItem,
SavePlan,
TensorStorageMetadata,
WriteItem,
)
from torch.distributed.checkpoint._nested_dict import FLATTEN_MAPPING, unflatten_state_dict
from torch.distributed.checkpoint._traverse import OBJ_PATH, traverse_state_dict
from torch.distributed.checkpoint.metadata import Metadata
from torch.distributed.checkpoint.planner_helpers import _create_write_items
from ...utils import get_torch_version, is_torch_min_version
from ..core import CheckpointingException
from ..dict_utils import nested_values
from ..mapping import (
ShardedBase,
ShardedObject,
ShardedStateDict,
ShardedTensor,
StateDict,
is_main_replica,
)
from .async_utils import AsyncRequest
from .base import (
AsyncSaveShardedStrategy,
LoadShardedStrategy,
StrategyAction,
register_default_strategy,
)
from .cached_metadata_filesystem_reader import CachedMetadataFileSystemReader
from .filesystem_async import FileSystemWriterAsync
from .resharding import (
TensorReformulationMetadata,
apply_nd_flattened_tensors_reformulation,
is_nd_flattened_tensor,
nd_flattened_tensor_reformulated_global_shape,
restore_nd_flattened_tensors_formulation,
)
from .state_dict_saver import save_state_dict_async_finalize, save_state_dict_async_plan
try:
if not torch.cuda.is_available():
raise ImportError
from transformer_engine.pytorch.float8_tensor import Float8Tensor
HAVE_TE = True
except ImportError:
HAVE_TE = False
try:
from torch.distributed._tensor import DTensor
HAVE_DTENSOR = True
except ImportError:
HAVE_DTENSOR = False
_metadata_fn: str = ".metadata"
def register_default_torch_strategies():
"""Register default strategies related to PyT Distributed backend."""
register_default_strategy(
StrategyAction.LOAD_SHARDED, 'torch_dist', 1, TorchDistLoadShardedStrategy()
)
register_default_strategy(
StrategyAction.SAVE_SHARDED, 'torch_dist', 1, TorchDistSaveShardedStrategy('torch_dist', 1)
)
logger = getLogger(__name__)
def flatten_state_dict(
state_dict: ShardedStateDict,
) -> Tuple[ShardedStateDict, Dict[str, OBJ_PATH]]:
"""Flattens state dict into a single level dict.
It's a copy of torch.distributed.checkpoint._nested_dict.flatten_state_dict
which also accepts ShardedBase tensors as terminal objects
Args:
state_dict (ShardedStateDict): state dict to be flattened
    Returns (tuple): flattened state dict and a mapping that allows recreating the original one
"""
flattened = {}
mappings = {}
def flat_copy(path: OBJ_PATH, value: Any) -> None:
new_fqn = ".".join(map(str, path))
if new_fqn in flattened:
raise ValueError(f"duplicated flatten key {new_fqn}")
flattened[new_fqn] = value
mappings[new_fqn] = path
traverse_state_dict(state_dict, flat_copy, lambda x: isinstance(x, (torch.Tensor, ShardedBase)))
return flattened, mappings
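# A minimal sketch of the flattening contract; the nesting below is hypothetical.
def _example_flatten_state_dict():
    nested = {'decoder': {'final_norm': {'weight': torch.zeros(2)}}}
    flat, mapping = flatten_state_dict(nested)
    # `flat` maps the dotted FQN 'decoder.final_norm.weight' to the tensor, while
    # `mapping` maps that FQN back to the path ('decoder', 'final_norm', 'weight'),
    # which is what `unflatten_state_dict` needs to rebuild the nesting.
    return flat, mapping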
def sharded_tensor_to_torch_sharded_tensor(
sh_tens: List[ShardedTensor],
rank: Optional[int] = None,
load_legacy_1d_flatten_tensors: bool = False,
) -> TorchShardedTensor:
"""Convert MCore ShardedTensor to PyT ShardedTensor. PyT requires information about all chunks.
On high-level, this function follows the logic of
torch.distributed.fsdp._shard_utils._create_chunk_sharded_tensor.
Additionally, it saves `prepend_axis_num` and `has_flattened_range` (specific to MCore)
as attributes for further restoration in `_unwrap_pyt_sharded_tensor`.
NOTE: this function assumes regular (grid) sharding of the MCore ShardedTensor.
The only local irregularities could be introduced with a `flattened_range` attribute.
This function handles 2 different type of ShardedTensors:
1. Non-flat regular ShardedTensors (`not has_flattened_range`)
2. N-D flattened ShardedTensors (`has_flattened_range`)
(1) type are saved according to their original shape.
Type (2) however requires global shape adjustment for efficiency:
we treat [X, Y, Z] global shape tensor with local shape [x, y, z]
as a [X // x, Y // y, Z // z, x * y * z] tensor with last axis
partitioned according to `flattened_range` slices.
This will need special handling while resharding.
Args:
sh_tens (List[ShardedTensor]): list of sharded tensors to convert
rank (int, optional): current process rank passed to PyT ShardedTensor.
If None, assumes rank in the default pg.
load_legacy_1d_flatten_tensors (bool, optional): flag indicating if 1-D flattened tensors
should be loaded in a legacy way. Defaults to False.
Returns (TorchShardedTensor): PyT ShardedTensor containing all passed shards.
"""
if rank is None:
rank = torch.distributed.get_rank()
some_sh_ten = sh_tens[0]
has_flattened_range = some_sh_ten.flattened_range is not None
for sh_ten in sh_tens:
assert (sh_ten.flattened_range is not None) == has_flattened_range, sh_tens
if not sh_ten.data.is_contiguous():
sh_ten.data = sh_ten.data.contiguous()
if load_legacy_1d_flatten_tensors and len(some_sh_ten.global_shape) == 1:
# Legacy 1-D flattened tensors are loaded as non-flat regular ShardedTensors
has_flattened_range = False
local_global_offsets = {}
prepend_axis_num = sh_tens[0].prepend_axis_num
# Determine local shards according to tensor type (see docs)
if has_flattened_range:
        # Type (2) case: N-D flattened ShardedTensors
for sh_ten in sh_tens:
local_global_offsets.setdefault(sh_ten.local_chunk_offset_in_global(), []).append(
sh_ten
)
assert sh_ten.data.ndim == 1, sh_ten
sh_ten.data = sh_ten.data.view((1,) * len(sh_ten.global_shape) + (-1,))
# Global shape reformulation:
global_shape = nd_flattened_tensor_reformulated_global_shape(some_sh_ten)
offsets_shape = (1,) * len(
some_sh_ten.global_shape
        )  # leading dims of the reformulated global shape equal the number of local chunks
local_shards = [
Shard.from_tensor_and_offsets(
sh_ten.data,
list(
sh_ten.local_chunk_offset_in_global() + (sh_ten.flattened_range.start,)
), # additional flattened offset
rank,
)
for sh_ten in sh_tens
]
else:
# Type (1) case: non-flat regular ShardedTensors
for sh_ten in sh_tens:
local_global_offsets.setdefault(sh_ten.global_offset, []).append(sh_ten)
sh_ten.data = sh_ten.data.view(
(1,) * prepend_axis_num + sh_ten.local_shape
) # adjust to prepended_axis_num
global_shape = some_sh_ten.global_shape
offsets_shape = some_sh_ten.data.shape # includes prepended axes
local_shards = [
Shard.from_tensor_and_offsets(
sh_ten.data, list(sh_ten.global_offset), rank # simple case
)
for sh_ten in sh_tens
]
# Create a ShardedTensor without invoking communication. Determine global shards
world_size = torch.distributed.get_world_size()
shard_metadata = []
# NOTE: here we assume a regular grid of shards
for fragment_offsets in product(*map(range, some_sh_ten.axis_fragmentations)):
offset = tuple(map(lambda x: x[0] * x[1], zip(fragment_offsets, offsets_shape)))
if offset in local_global_offsets:
# local shard
placement = f"rank:{rank}/cuda"
for sh_ten in local_global_offsets[offset]:
if has_flattened_range:
assert offset == sh_ten.local_chunk_offset_in_global()
# This is not an actual offset, but an offset of the whole shard
# This is needed for a PyT Dist internal integrity check
offset = sh_ten.local_chunk_offset_in_global() + (0,)
size = (1,) * len(offsets_shape) + global_shape[-1:]
else:
size = sh_ten.data.shape
shard_metadata.append(ShardMetadata(offset, size, placement))
else:
# pylint: disable=line-too-long
# for shards from other ranks we provide simplistic data - this information will be discarded
# during TorchShardedTensor._init_from_local_shards_and_global_metadata call.
            # Due to a bug in the PyT 24.05 container we must specify some concrete rank within the world size.
            # The exact rank doesn't matter as long as it differs from the current rank - hence (rank + 1) % WS.
placement = f"rank:{(rank + 1) % world_size}/cuda"
if has_flattened_range:
offset = offset + (0,)
size = (1,) * len(offsets_shape) + global_shape[-1:]
else:
size = offsets_shape
shard_metadata.append(ShardMetadata(offset, size, placement))
tensor = some_sh_ten.data
sharded_tensor_metadata = ShardedTensorMetadata(
shards_metadata=shard_metadata,
size=torch.Size(global_shape),
tensor_properties=TensorProperties(
dtype=tensor.dtype,
layout=tensor.layout,
requires_grad=tensor.requires_grad,
memory_format=torch.contiguous_format,
pin_memory=tensor.is_pinned(),
),
)
pyt_sh_ten = TorchShardedTensor._init_from_local_shards_and_global_metadata(
local_shards, sharded_tensor_metadata=sharded_tensor_metadata, process_group=None
)
# Store MCore related data as PyTShardedTensor attribute.
# This won't be stored in the checkpoint, only for runtime purposes
pyt_sh_ten.mcore_sh_ten = sh_ten.without_data()
pyt_sh_ten.mcore_metadata = {}
if has_flattened_range:
pyt_sh_ten.mcore_metadata['nd_reformulated_orig_global_shape'] = sh_ten.global_shape
return pyt_sh_ten
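# A minimal arithmetic sketch of the global shape reformulation described in the
# docstring above, with hypothetical shapes: an [8, 12] tensor split into [4, 6]
# local chunks is treated as a [8 // 4, 12 // 6, 4 * 6] = [2, 2, 24] tensor whose
# last axis is partitioned by each shard's `flattened_range`.
def _example_reformulated_global_shape(global_shape=(8, 12), local_shape=(4, 6)):
    chunks_per_dim = tuple(g // l for g, l in zip(global_shape, local_shape))
    flat_chunk_numel = 1
    for dim in local_shape:
        flat_chunk_numel *= dim
    return chunks_per_dim + (flat_chunk_numel,)  # (2, 2, 24) for the defaults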
def mcore_to_pyt_state_dict(
state_dict: Dict[str, List[ShardedBase]],
is_loading: bool = False,
init_device: torch.device = torch.device("cpu"),
load_legacy_1d_flatten_tensors: bool = False,
) -> Dict[str, Union[TorchShardedTensor, io.BytesIO]]:
"""Convert state dict with ShardedTensors and ShardedObjects
to state dict compatible with PyT Dist format.
    The underlying tensor data is modified in place; a new flat dict with converted values is returned.
Args:
state_dict (Dict[str, List[ShardedBase]]): flattened state dict, where values
are lists of either ShardedTensor or ShardedObjects.
is_loading (bool, optional): flag indicating if loading or saving. Defaults to False.
init_device (torch.device, optional): device to initialize potentially missing tensors
            during loading. Defaults to 'cpu'.
        load_legacy_1d_flatten_tensors (bool, optional): flag indicating if 1-D flattened tensors
            should be loaded in a legacy way. Defaults to False.
    Returns (Dict[str, Union[TorchShardedTensor, io.BytesIO]]): a new flat dictionary with values
        converted either into PyT ShardedTensors or io.BytesIO.
"""
rank = torch.distributed.get_rank()
pyt_state_dict = {}
def _mcore_to_torch_sharded_tensor(sh_tens: List[ShardedTensor]) -> TorchShardedTensor:
"""Build a PyT ShardedTensor from given shards.
During loading:
- if data is None, initialize it with an empty tensor (will be used to copy the data into)
- if `allow_shape_mismatch` is True, the data is initialized with zeros
prior to loading (not all parts of the tensor will be read from the checkpoint)
"""
assert all(isinstance(sh_ten, ShardedTensor) for sh_ten in sh_tens), sh_tens
for sh_ten in sh_tens:
if sh_ten.data is None:
if is_loading:
sh_ten.init_data(
init_device,
init_fn=torch.zeros if sh_ten.allow_shape_mismatch else torch.empty,
)
else:
raise CheckpointingException(f'`data` attr is None for {sh_ten}')
else:
sh_ten.data = sh_ten.data.detach()
if sh_ten.allow_shape_mismatch and is_loading:
sh_ten.data.zero_()
torch_sh_ten = sharded_tensor_to_torch_sharded_tensor(
sh_tens, rank, load_legacy_1d_flatten_tensors
)
torch_sh_ten.key = sh_tens[0].key
return torch_sh_ten
def _mcore_to_torch_sharded_object(sh_objs: List[ShardedObject]) -> io.BytesIO:
"""Build io.BytesIO from given sharded objects data."""
assert all(isinstance(sh_obj, ShardedObject) for sh_obj in sh_objs), sh_objs
serialized_data = io.BytesIO()
torch.save([sh_obj.data for sh_obj in sh_objs], serialized_data)
return serialized_data
for k, v in state_dict.items():
if isinstance(v[0], ShardedTensor):
v = cast(List[ShardedTensor], v)
pyt_state_dict[k] = _mcore_to_torch_sharded_tensor(v)
else:
v = cast(List[ShardedObject], v)
pyt_state_dict[k] = _mcore_to_torch_sharded_object(v)
return pyt_state_dict
def _unwrap_pyt_sharded_tensor(sh_ten: TorchShardedTensor) -> List[torch.Tensor]:
"""Unwrap tensor from PyT ShardedTensor instance.
If `prepend_axis_num` was non-zero (which is specific to MCore ShardedTensor)
then the tensor has additional singleton dimensions which should be squeezed.
"""
mcore_sh_ten = sh_ten.mcore_sh_ten
ret_tensors = []
for sh in sh_ten.local_shards():
ten = sh.tensor
if mcore_sh_ten.flattened_range is not None:
assert ten.shape[:-1] == (1,) * (len(ten.shape) - 1), ten.shape
ten = ten.view(-1)
else:
for _ in range(mcore_sh_ten.prepend_axis_num):
ten = ten.squeeze(0)
ret_tensors.append(ten)
return ret_tensors
def _replace_state_dict_keys_with_sharded_keys(
sharded_state_dict: ShardedStateDict, keep_only_main_replica: bool = False
) -> Tuple[Dict[str, List[ShardedBase]], FLATTEN_MAPPING, Dict[str, List[str]]]:
"""Group ShardedBase objects by keys and
return mappings required for recreating the original dict."""
flat_sd, flat_mapping = flatten_state_dict(sharded_state_dict)
rename_mapping = defaultdict(list)
new_flat_sd = defaultdict(list)
for k, sh_base in flat_sd.items():
assert isinstance(sh_base, ShardedBase), type(sh_base)
key = sh_base.unique_key if isinstance(sh_base, ShardedObject) else sh_base.key
if is_main_replica(sh_base.replica_id) or not keep_only_main_replica:
rename_mapping[key].append(k)
new_flat_sd[key].append(sh_base)
return new_flat_sd, flat_mapping, rename_mapping
def _replace_sharded_keys_with_state_dict_keys(
state_dict: Dict[str, List[Union[torch.Tensor, io.BytesIO]]],
flat_mapping: FLATTEN_MAPPING,
rename_mapping: Dict[str, List[str]],
):
"""Inverse of _replace_state_dict_keys_with_sharded_keys."""
recovered_sd = {}
for k, tensors in state_dict.items():
assert len(tensors) == len(rename_mapping[k])
for ten, recovered_k in zip(tensors, rename_mapping[k]):
recovered_sd[recovered_k] = ten
return unflatten_state_dict(recovered_sd, flat_mapping)
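# A minimal sketch of the key recovery; the keys and mappings below are hypothetical
# (two flattened entries grouped under a single sharded key during saving).
def _example_recover_state_dict_keys():
    state_dict = {'weight': [torch.zeros(2), torch.ones(2)]}
    rename_mapping = {'weight': ['model.weight_a', 'model.weight_b']}
    flat_mapping = {
        'model.weight_a': ('model', 'weight_a'),
        'model.weight_b': ('model', 'weight_b'),
    }
    recovered = _replace_sharded_keys_with_state_dict_keys(
        state_dict, flat_mapping, rename_mapping
    )
    # recovered == {'model': {'weight_a': tensor([0., 0.]), 'weight_b': tensor([1., 1.])}}
    return recovered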
def _restore_dict_types(x: Union[dict, list, Any], keys_template: Union[dict, list, Any]):
"""Recursively update `x` keys, based on `keys_template`."""
if isinstance(keys_template, dict):
assert isinstance(x, dict), type(x)
for k, v in keys_template.items():
if not isinstance(k, str):
assert str(k) in x, (k, x.keys)
x[k] = x.pop(str(k))
_restore_dict_types(x[k], v)
elif isinstance(keys_template, list):
assert isinstance(x, list), type(x)
for x_val, templ_val in zip(x, keys_template):
_restore_dict_types(x_val, templ_val)
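# A minimal sketch of the type restoration; the keys below are hypothetical.
def _example_restore_dict_types():
    loaded = {'layers': {'0': torch.zeros(2), '1': torch.ones(2)}}
    template = {'layers': {0: None, 1: None}}
    # The stringified integer keys coming back from the flattened PyT Dist format
    # are converted back to ints in place, following the template's key types.
    _restore_dict_types(loaded, template)
    return loaded  # {'layers': {0: tensor([0., 0.]), 1: tensor([1., 1.])}}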
@dataclass(frozen=True)
class MCoreSavePlan(SavePlan):
"""SavePlan with MCore specific data."""
    mcore_data: Optional[Dict[str, Dict[str, Any]]] = None  # MCore-related data about each tensor
class MCoreSavePlanner(DefaultSavePlanner):
"""Differs with the default planner by saving BytesIO objects on all ranks.
In the integration of MCore with PyT Distributed format, BytesIO objects
come from ShardedObjects, which should be treated as separate objects on each rank
(not common on all ranks).
Also, the objects are already packed in io.BytesIO, so no need to redo it
in transform_object.
"""
def __init__(
self,
*args,
dedup_replicated_tensors: Optional[bool] = None,
nd_flattened_global_shapes: Optional[Dict[str, Tuple[int, ...]]] = None,
can_run_decentralized_global_plan: bool = True,
**kwargs,
) -> None:
# `dedup_replicated_tensors` was deprecated in 2.3; this check avoids warnings
# during saving.
if get_torch_version() <= PkgVersion("2.2"):
kwargs['dedup_replicated_tensors'] = dedup_replicated_tensors
super().__init__(*args, **kwargs)
self.nd_flattened_global_shapes = nd_flattened_global_shapes or {}
self.can_run_decentralized_global_plan = can_run_decentralized_global_plan
if can_run_decentralized_global_plan:
assert (
not dedup_replicated_tensors
), 'Cannot run decentralized plan with dedup_replicated_tensors=True'
assert (
not self.flatten_state_dict
), 'Cannot run decentralized plan with flatten_state_dict=True'
def create_local_plan(self) -> SavePlan:
"""Adds IOBytes write request on non-coordinator ranks."""
# NOTE: for PyT 2.4.0a0 we can't rely on `create_default_local_save_plan` because
# some alpha versions (specifically 2.4.0a0+f70bd71a48 in 24.06 NGC PyTorch container)
# add iobytes request only on coordinator ranks and some alpha versions
# (specifically 2.4.0a0+3bcc3cddb5 in 24.07 NGC PyTorch container)
# add those requests on all ranks. We inline a simplified version of this method below.
write_items = []
for fqn, obj in self.state_dict.items():
assert not HAVE_DTENSOR or not isinstance(
obj, DTensor
) # translation from MCore ShardedTensors shouldn't result in DTensors
# Create write requests for tensor and bytes values.
# For MCore, these should be already non-duplicates.
write_items += _create_write_items(fqn, obj)
self.plan = MCoreSavePlan(
items=write_items,
planner_data=self.mappings,
mcore_data={
k: sh_ten.mcore_metadata
for k, sh_ten in self.state_dict.items()
if isinstance(sh_ten, TorchShardedTensor)
},
)
return self.plan
def create_global_plan(self, all_plans: List[MCoreSavePlan]) -> Tuple[List[SavePlan], Metadata]:
"""Merges MCore data for all plans."""
global_plan, metadata = super().create_global_plan(all_plans)
metadata.mcore_data = dict(ChainMap(*(plan.mcore_data for plan in all_plans)))
return global_plan, metadata
def create_decentralized_global_plan(self, local_plan: SavePlan) -> SavePlan:
"""Nothing to do, just some checks.
Args:
local_plan (SavePlan): local plan to turn to a global plan
(without interactions with other ranks)
Returns:
SavePlan - locally transformed plan equivalent to the plan that would be
created by the coordinator
"""
assert (
not self.flatten_state_dict
), 'Cannot run decentralized plan with flatten_state_dict=True'
assert not local_plan.planner_data, 'Planner data should be empty with decentralized plan'
return local_plan
def transform_object(self, write_item: WriteItem, object: Any):
"""Make no transformations - bytes objects are already serialized."""
return object
class MCoreLoadPlanner(DefaultLoadPlanner):
"""Adds global shape validation to the default planner.
If global shape validation can be ignored (shouldn't!), the default
load planner can be used.
"""
def __init__(
self,
*args,
shapes_validation_sharded_tensors: Iterable[ShardedTensor] = (),
allow_shape_mismatch_sharded_tensors: Dict[str, ShardedTensor] = None,
**kwargs,
) -> None:
super().__init__(*args, **kwargs)
self.shapes_validation_sharded_tensors = shapes_validation_sharded_tensors
self.allow_shape_mismatch_sharded_tensors = allow_shape_mismatch_sharded_tensors
self._intermediate_read_item_and_target: Optional[Tuple[ReadItem, torch.Tensor]] = None
@staticmethod
def _expected_shape(sh_ten):
return (
nd_flattened_tensor_reformulated_global_shape(sh_ten)
if is_nd_flattened_tensor(sh_ten)
else sh_ten.global_shape
)
def _validate_global_shapes(self, metadata, sharded_tensors):
for sh_ten in sharded_tensors:
if sh_ten.key not in metadata.state_dict_metadata:
raise KeyError(
f"{sh_ten.key} from model not in state dict:"
f" {sorted(metadata.state_dict_metadata.keys())}"
)
loaded_shape = metadata.state_dict_metadata[sh_ten.key].size
expected_shape = self._expected_shape(sh_ten)
if loaded_shape != expected_shape:
if is_nd_flattened_tensor(sh_ten) and len(sh_ten.global_shape) == 1:
# Handle legacy 1-D flattened tensors checkpoint format
# where the global shape is not stored in the metadata
expected_shape = sh_ten.global_shape
if loaded_shape == expected_shape:
continue
_msg = (
f'Global shape mismatch for loaded ({loaded_shape})'
f' and expected ({expected_shape}) tensor'
f' for key {sh_ten.key}'
)
raise CheckpointingException(_msg)
@contextmanager
def _temporarily_bypass_shape_validation(self):
"""
Temporarily set the size of tensors to their expected shapes to bypass DCP shape validation.
This is used when validating the shapes during local plan creation.
"""
if not self.allow_shape_mismatch_sharded_tensors:
yield
return
tensor_metadata = self.metadata.state_dict_metadata
metadata_with_sizes = [
(tensor_metadata[key], tensor_metadata[key].size, sharded_tensor)
for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items()
]
try:
# Temporarily set sizes to expected shapes
for md, _, sharded_tensor in metadata_with_sizes:
md.size = self._expected_shape(sharded_tensor)
yield
finally:
# Restore original sizes after yield
for md, size, _ in metadata_with_sizes:
md.size = size
def create_local_plan(self) -> LoadPlan:
"""Runs additional shapes validation."""
self._validate_global_shapes(self.metadata, self.shapes_validation_sharded_tensors)
with self._temporarily_bypass_shape_validation():
local_plan = super().create_local_plan()
return local_plan
def resolve_tensor(self, read_item: ReadItem):
"""Override to add FP8 support.
        Narrowing a Float8Tensor can produce non-contiguous tensors, and there are
        no `copy` kernels for that case. This method creates a contiguous FP8
        tensor so that the subsequent `copy_` in FileSystemReader succeeds.
Note that this requires tracking the original tensor
(as `self._intermediate_read_item_and_target` attribute)
and restoring it in `commit_tensor` method.
"""
target_tensor = super().resolve_tensor(read_item)
if (
not target_tensor.is_contiguous()
and HAVE_TE
and isinstance(target_tensor, Float8Tensor)
):
self._intermediate_read_item_and_target = (read_item, target_tensor)
target_tensor = Float8Tensor.make_like(
target_tensor, data=target_tensor._data.contiguous()
)
return target_tensor
def commit_tensor(self, read_item: ReadItem, tensor: torch.Tensor) -> None:
"""Restores the original FP8 tensor saved in `resolve_tensor`."""
if self._intermediate_read_item_and_target is not None:
interm_read_item, target_tensor = self._intermediate_read_item_and_target
assert (
interm_read_item is read_item
), '`commit_tensor` method should be called right after `resolve_tensor`'
target_tensor.copy_(tensor)
tensor = target_tensor
self._intermediate_read_item_and_target = None
return super().commit_tensor(read_item, tensor)
class TorchDistSaveShardedStrategy(AsyncSaveShardedStrategy):
"""Async save strategy for the PyT Distributed format.
The idea is to translate MCore ShardedTensors into PyT ShardedTensors
and use the async-adjusted torch.distributed.checkpoint saving mechanism
provided by the FileSystemWriterAsync writer.
"""
def __init__(
self,
backend: str,
version: int,
keep_only_main_replica: bool = True,
thread_count: int = 2,
cached_metadata: bool = False,
separation_hint: str = None,
):
"""Adds parameters specific to PyT Distributed format
Args:
backend (str): format backend string
version (int): format version
keep_only_main_replica (bool, optional): PyT Distributed has a mechanism
for deduplication, but replica_id aware deduplication is more coherent.
Default is True (recommended to keep it).
thread_count (int, optional): threads to use during saving.
Affects the number of files in the checkpoint (saving ranks * num_threads).
cached_metadata (bool, optional): Enables using cached global metadata to avoid
gathering local metadata every checkpointing invocation
separation_hint(str, optional): If provided, all tensors whose keys have this
prefix will be saved to a separate file.
"""
super().__init__(backend, version)
self.keep_only_main_replica = keep_only_main_replica
self.thread_count = thread_count
# Cached SavePlans to skip plan in `save_state_dict_async_plan`
# cached outcome of `SavePlan.prepare_global_plan`,
# which aggregates local plans from all ranks
self.cached_central_plan: SavePlan = None
# cached outcome of `SavePlan.prepare_local_plan` describes how local state_dict is written
self.cached_local_plan: SavePlan = None
# Cached global metadata, only `coordinator` for dist-ckpt holds
# if central plans are consistent over iters
self.cached_global_metadata: Metadata = None
        # This variable records whether the ckpt structures are consistent
        # so that subsequent checkpoint saves can reuse `cached_global_metadata`
self.validated_cache_reuse: bool = False
# The knob to enable cached metadata communication in saving
self.use_cached_ckpt_structure: bool = cached_metadata
self.separation_hint = separation_hint
self.validated_loaded_metadata_reuse = False
def async_save(
self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path
) -> AsyncRequest:
"""Translates MCore ShardedTensors to PyT ShardedTensors & saves in PyT Distributed format.
Args:
sharded_state_dict (ShardedStateDict): sharded state dict to save
checkpoint_dir (Path): checkpoint directory
        Returns: AsyncRequest (the async save request with save and finalize functions)
"""
# Translate the state dict
(sharded_state_dict, flat_mapping, rename_mapping) = (
_replace_state_dict_keys_with_sharded_keys(
sharded_state_dict, self.keep_only_main_replica
)
)
pyt_state_dict = mcore_to_pyt_state_dict(sharded_state_dict, False)
# Use PyT saving mechanism
writer = FileSystemWriterAsync(
checkpoint_dir, separation_hint=self.separation_hint, thread_count=self.thread_count
)
# This should be set differently if we run in a smaller process group than the default
coordinator = 0
# Try twice to validate the generated `central_plan` is the same across iterations
# If so, reuse `cached_central_plan` and `cached_global_metadata`
# From the 3rd iteration, `save_state_dict_async_plan` will not generate `global_metadata`
# (return None) so `self.cached_global_metadata` is reused
args_cached_plans = None
loaded_all_plans = None
if self.use_cached_ckpt_structure:
loaded_all_plans = getattr(self.cached_global_metadata, "all_local_plans", None)
if loaded_all_plans is None:
logger.debug(
"no all_local_plans in metadata - can't verify global metadata reuse..."
)
args_cached_plans = (
self.cached_central_plan,
self.cached_local_plan,
self.validated_cache_reuse,
)
(
save_state_dict_ret,
self.cached_central_plan,
self.cached_local_plan,
self.validated_cache_reuse,
self.validated_loaded_metadata_reuse,
) = save_state_dict_async_plan(
pyt_state_dict,
writer,
None,
coordinator,
planner=MCoreSavePlanner(
dedup_replicated_tensors=not self.keep_only_main_replica, flatten_state_dict=False
),
cached_ckpt_structure=args_cached_plans,
loaded_all_plans=loaded_all_plans,
)
rank = torch.distributed.get_rank()
if self.use_cached_ckpt_structure:
if (
loaded_all_plans
and self.cached_global_metadata
and self.validated_loaded_metadata_reuse
):
if coordinator == rank:
logger.debug(
f"rank: {rank}, reuse global metadata from loaded"
f" .metadata, {save_state_dict_ret[1]}"
)
save_state_dict_ret = list(save_state_dict_ret)
save_state_dict_ret[1] = self.cached_global_metadata
elif self.validated_cache_reuse:
logger.debug(f"rank: {rank}, cache validated")
if save_state_dict_ret[1]: # when global_metadata is not cached
self.cached_global_metadata = save_state_dict_ret[1] # Cache Metadata
# Only Coordinator rank holds cached global_metadata
# (None is returned for global_metadata)
elif coordinator == rank:
logger.debug(
f"rank: {rank}, reuse global metadata cached from previous"
f" save iteration, {save_state_dict_ret[1]}"
)
save_state_dict_ret = list(save_state_dict_ret)
save_state_dict_ret[1] = self.cached_global_metadata
return self._get_save_and_finalize_callbacks(writer, save_state_dict_ret)
def _get_save_and_finalize_callbacks(self, writer, save_state_dict_ret) -> AsyncRequest:
save_fn_args = writer.get_save_function_and_args()
save_fn, preload_fn, save_args = save_fn_args
def finalize_fn():
save_state_dict_async_finalize(*save_state_dict_ret)
torch.distributed.barrier()
return AsyncRequest(save_fn, save_args, [finalize_fn], preload_fn=preload_fn)
def can_handle_sharded_objects(self):
return True
def get_reformulation_metadata(
sharded_state_dict: ShardedStateDict, checkpoint_dir: Path
) -> Dict[str, TensorReformulationMetadata]:
"""Reads MCore data for N-D flattened tensors from checkpoint metadata during ckpt load.
Args:
sharded_state_dict (ShardedStateDict): sharded state dict to load
checkpoint_dir (Path): checkpoint directory
Returns:
Dict[str, TensorReformulationMetadata] - dictionary that maps keys of every
N-D flattened tensor from the sharded_state_dict to its original global shape
as stored in `mcore_data` in the checkpoint.
"""
ckpt_metadata = FileSystemReader(checkpoint_dir).read_metadata()
reformulation_metadata = {}
for sh_ten in nested_values(sharded_state_dict):
if not is_nd_flattened_tensor(sh_ten):
continue
try:
ckpt_global_shape = ckpt_metadata.mcore_data[sh_ten.key][
'nd_reformulated_orig_global_shape'
]
except KeyError as e:
if len(sh_ten.global_shape) == 1:
warnings.warn(
f'Legacy checkpoint format detected for 1-D flattened tensor {sh_ten}. '
'Skip metadata reformulation.'
)
continue
raise CheckpointingException(
f'Cannot find global shape metadata for N-D flattened tensor {sh_ten} '
f'in checkpoint metadata: {ckpt_metadata.mcore_data}'
) from e
reformulation_metadata[sh_ten.key] = TensorReformulationMetadata(
ckpt_global_shape, ckpt_metadata.state_dict_metadata[sh_ten.key].size
)
return reformulation_metadata
class TorchDistLoadShardedStrategy(LoadShardedStrategy):
"""Basic load strategy for the PyT Distributed format."""
def __init__(self):
self.cached_global_metadata: Optional[Metadata] = None
super().__init__()
def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> StateDict:
"""Translates MCore ShardedTensors to PyT ShardedTensors & loads from PyT Distributed fmt.
Args:
sharded_state_dict (ShardedStateDict): sharded state dict with mapping
information to instruct loading
checkpoint_dir (Path): checkpoint directory
Returns: loaded state dict
"""
# Apply N-D tensors resharding
reformulation_metadata = get_reformulation_metadata(sharded_state_dict, checkpoint_dir)
sharded_state_dict, formulation_restore_data = apply_nd_flattened_tensors_reformulation(
sharded_state_dict, reformulation_metadata
)
# Check if there are legacy 1-D flattened tensors in the checkpoint
has_legacy_1d_flattened_tensors = False
for sh_ten in nested_values(sharded_state_dict):
if is_nd_flattened_tensor(sh_ten) and sh_ten.key not in reformulation_metadata:
has_legacy_1d_flattened_tensors = True
break
flexible_shape_sharded_tensors = [
sh_ten
for sh_ten in nested_values(sharded_state_dict)
if isinstance(sh_ten, ShardedTensor) and not sh_ten.allow_shape_mismatch
]
allow_shape_mismatch_sharded_tensors = {
sh_ten.key: sh_ten
for sh_ten in nested_values(sharded_state_dict)
if isinstance(sh_ten, ShardedTensor) and sh_ten.allow_shape_mismatch
}
orig_sharded_state_dict = sharded_state_dict
# MCore state dict to PyT Distributed compatible
(sharded_state_dict, flat_mapping, rename_mapping) = (
_replace_state_dict_keys_with_sharded_keys(sharded_state_dict)
)
pyt_state_dict = mcore_to_pyt_state_dict(
sharded_state_dict, True, load_legacy_1d_flatten_tensors=has_legacy_1d_flattened_tensors
)
# Load PyT Distributed format
fsr = CachedMetadataFileSystemReader(checkpoint_dir)
checkpoint.load_state_dict(
pyt_state_dict,
fsr,
planner=MCoreLoadPlanner(
shapes_validation_sharded_tensors=flexible_shape_sharded_tensors,
allow_shape_mismatch_sharded_tensors=allow_shape_mismatch_sharded_tensors,
),
)
self.cached_global_metadata = (
fsr.read_metadata()
) # no storage interaction thanks to caching
pyt_state_dict = cast(
Dict[str, Union[TorchShardedTensor, List[io.BytesIO]]], pyt_state_dict
)
# Unwrap ShardedTensors and return to original state dict
mcore_state_dict = {
k: v if not isinstance(v, TorchShardedTensor) else _unwrap_pyt_sharded_tensor(v)
for k, v in pyt_state_dict.items()
}
mcore_state_dict = _replace_sharded_keys_with_state_dict_keys(
mcore_state_dict, flat_mapping, rename_mapping
)
_restore_dict_types(mcore_state_dict, orig_sharded_state_dict)
# Apply N-D tensors resharding postprocessing
mcore_state_dict = restore_nd_flattened_tensors_formulation(
mcore_state_dict, formulation_restore_data
)
return mcore_state_dict
def load_tensors_metadata(self, checkpoint_dir: Path, metadata: Metadata = None):
"""Uses tensors metadata stored in the metadata file."""
if metadata is None:
fs_reader = FileSystemReader(checkpoint_dir)
metadata = fs_reader.read_metadata()
mcore_data = getattr(metadata, 'mcore_data', {})
sharded_metadata = {}
for k, tp in metadata.state_dict_metadata.items():
if not isinstance(tp, TensorStorageMetadata):
continue # load only tensors
nd_orig_global_shape = mcore_data.get(k, {}).get('nd_reformulated_orig_global_shape')
if nd_orig_global_shape is None:
# Regular tensor
sharded_metadata[k] = ShardedTensor.from_rank_offsets(
k, torch.empty(tp.size, **tp.properties.__dict__, device='meta')
).without_data()
else:
# N-D flattened tensor
unflat_ten = torch.empty(
nd_orig_global_shape, **tp.properties.__dict__, device='meta'
)
flat_ten = unflat_ten.flatten()
sharded_metadata[k] = ShardedTensor.from_rank_offsets_flat(
k,
flat_ten,
unflat_ten.shape,
flattened_range=slice(0, unflat_ten.numel()), # whole slice
).without_data()
return sharded_metadata
def load_sharded_metadata(self, checkpoint_dir: Path) -> ShardedStateDict:
"""Uses tensors and objects metadata stored in the metadata file."""
fs_reader = FileSystemReader(checkpoint_dir)
metadata = fs_reader.read_metadata()
sharded_metadata = {}
for metadata_key, storage_metadata in metadata.state_dict_metadata.items():
if not isinstance(storage_metadata, BytesStorageMetadata):
continue
sh_obj = ShardedObject.empty_from_unique_key(metadata_key)
sharded_metadata[sh_obj.unique_key] = sh_obj
sharded_metadata.update(self.load_tensors_metadata(checkpoint_dir, metadata))
return sharded_metadata
def remove_sharded_tensors(self, checkpoint_dir: str, key_prefix: str):
"""Removes checkpoint files whose keys have the given prefix.
Performs the following steps:
1. checks whether there are files that start with the key_prefix
2. loads metadata
3. removes all entries from the metadata that start with the key_prefix
4. resaves the new metadata and removes the old metadata
5. removes the relevant files
"""
assert is_torch_min_version(
"2.3.0"
        ), 'torch >= 2.3.0 is required for remove_sharded_tensors'
distckpt_files = [f for f in os.listdir(checkpoint_dir) if f.endswith("distcp")]
files_to_remove = [f for f in distckpt_files if f.startswith(key_prefix)]
if not files_to_remove:
warnings.warn(
f'There are no files in {checkpoint_dir} that begin with "{key_prefix}".'
f' Skipping removal.'
)
return
fs_reader = FileSystemReader(checkpoint_dir)
original_metadata = fs_reader.read_metadata()
new_state_dict_metadata = {}
new_planner_data = {}
new_storage_data = {}
for k in original_metadata.state_dict_metadata.keys():
if k.startswith(key_prefix):
continue
new_state_dict_metadata[k] = original_metadata.state_dict_metadata[k]
for k in original_metadata.planner_data.keys():
if k.startswith(key_prefix):
continue
new_planner_data[k] = original_metadata.planner_data[k]
for k in original_metadata.storage_data.keys():
if k.fqn.startswith(key_prefix):
continue
new_storage_data[k] = original_metadata.storage_data[k]
metadata = Metadata(
state_dict_metadata=new_state_dict_metadata,
planner_data=new_planner_data,
storage_data=new_storage_data,
)
fs_writer = FileSystemWriter(checkpoint_dir)
metadata_filename = cast(Path, fs_writer.fs.concat_path(fs_writer.path, _metadata_fn))
        tmp_path = cast(
            Path, fs_writer.fs.concat_path(fs_writer.path, f"{_metadata_fn}.tmp")
        )
        old_path = cast(
            Path, fs_writer.fs.concat_path(fs_writer.path, f"{_metadata_fn}.bck")
        )
## save the new metadata
with fs_writer.fs.create_stream(tmp_path, "wb") as metadata_file:
pickle.dump(metadata, metadata_file)
try:
os.fsync(metadata_file.fileno())
except AttributeError:
os.sync()
## move the old metadata
fs_writer.fs.rename(fs_writer.metadata_path, old_path)
try:
## rename the new metadata
fs_writer.fs.rename(tmp_path, fs_writer.metadata_path)
## finally, remove the files we want to drop
for f in files_to_remove:
fs_writer.fs.rm_file(checkpoint_dir / f)
except Exception as e:
fs_writer.fs.rename(old_path, fs_writer.metadata_path)
raise e
else:
fs_writer.fs.rm_file(old_path)
def can_handle_sharded_objects(self):
return True
def check_backend_compatibility(self, loaded_version):
pass # TODO
def check_version_compatibility(self, loaded_version):
pass # TODO
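# A minimal save/load round-trip sketch with the strategies above. The sharded state
# dicts and checkpoint directory are hypothetical, and `execute_sync` is assumed to be
# the AsyncRequest helper that runs the save and finalization synchronously; real code
# would typically schedule the request asynchronously instead.
def _example_torch_dist_round_trip(save_sharded_sd, load_sharded_sd, checkpoint_dir: Path):
    save_strategy = TorchDistSaveShardedStrategy('torch_dist', 1)
    async_request = save_strategy.async_save(save_sharded_sd, checkpoint_dir)
    async_request.execute_sync()
    # A fresh sharded state dict (matching the saved structure) describes what to load
    # and where to place it.
    return TorchDistLoadShardedStrategy().load(load_sharded_sd, checkpoint_dir)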