Commit d444a97a authored by yangzhong's avatar yangzhong

Initial upload

# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import os
from collections import deque
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union
import numpy
import torch
from packaging.version import Version as PkgVersion
from megatron.core.datasets.indexed_dataset import IndexedDataset
from megatron.core.datasets.masked_dataset import (
MaskedWordPieceDataset,
MaskedWordPieceDatasetConfig,
)
from megatron.core.datasets.utils import Split
from megatron.core.utils import get_te_version
@dataclass
class T5MaskedWordPieceDatasetConfig(MaskedWordPieceDatasetConfig):
"""Configuration object for Megatron Core T5 WordPiece datasets
NB: This is a temporary holdover from Megatron-LM. The T5 tokenizer has an attribute which defines
a number of special sentinel tokens used during sampling. The assert in __post_init__ serves to
preserve compatibility with Megatron-LM until the T5 tokenizer is in Megatron Core.
"""
sequence_length_encoder: Optional[int] = field(init=False, default=None)
"""A sequence_length alias and the sequence length for the encoder"""
sequence_length_decoder: int = None
"""The sequence length for the decoder"""
def __post_init__(self) -> None:
"""Do asserts and set fields post init"""
super().__post_init__()
self.sequence_length_encoder = self.sequence_length
assert self.sequence_length_encoder is not None
assert self.sequence_length_decoder is not None
assert len(self.tokenizer.additional_special_tokens_ids) > 0
class T5MaskedWordPieceDataset(MaskedWordPieceDataset):
"""The T5 dataset that assumes WordPiece tokenization
Args:
indexed_dataset (IndexedDataset): The IndexedDataset around
which to build the MegatronDataset
dataset_path (str): The real path on disk to the dataset, for bookkeeping
indexed_indices (numpy.ndarray): The set of the documents indices to expose
num_samples (Optional[int]): The number of samples to draw from the indexed
dataset. When None, build as many samples as correspond to one epoch.
index_split (Split): The indexed_indices Split
config (T5MaskedWordPieceDatasetConfig): The config
"""
def __init__(
self,
indexed_dataset: IndexedDataset,
dataset_path: str,
indexed_indices: numpy.ndarray,
num_samples: Optional[int],
index_split: Split,
config: T5MaskedWordPieceDatasetConfig,
) -> None:
super().__init__(
indexed_dataset, dataset_path, indexed_indices, num_samples, index_split, config
)
self.token_lookup = list(self.config.tokenizer.inv_vocab.keys())
# Account for the single <bos> and single <eos> token ids
self.sample_index = self._build_sample_index(self.config.sequence_length - 2, 1)
@staticmethod
def _key_config_attributes() -> List[str]:
"""Inherited method implementation
Returns:
List[str]: The key config attributes
"""
return super(
T5MaskedWordPieceDataset, T5MaskedWordPieceDataset
)._key_config_attributes() + ["sequence_length_decoder"]
@staticmethod
def _build_b1ss_attention_mask(
source_block: torch.tensor, target_block: torch.tensor, make_history_mask: bool = False
) -> torch.tensor:
"""Build an attention-mask having shape (bs, 1, q_len, kv_len)
from source_block and target_block
Args:
source_block (torch.tensor): A 2-D array of tokens (bs, q_len)
target_block (torch.tensor): A 2-D array of tokens (bs, kv_len)
make_history_mask (bool): Whether to turn mask into causal mask
Returns:
torch.tensor: The 4-D attention mask (bs, 1, q_len, kv_len)
"""
batch_size = source_block.shape[0]
attention_mask = []
for i in range(batch_size):
source_sample = source_block[i]
target_sample = target_block[i]
mask = (target_sample[None, :] >= 1) * (source_sample[:, None] >= 1)
if make_history_mask:
arange = numpy.arange(source_sample.shape[0])
history_mask = arange[None,] <= arange[:, None]
history_mask = torch.tensor(history_mask).to(mask.device)
mask = mask * history_mask
mask = ~(mask) # flip True to False
attention_mask.append(mask)
attention_mask = torch.stack(attention_mask)
attention_mask = attention_mask.unsqueeze(1)
return attention_mask
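# Illustrative sketch (toy token ids, assuming pad id 0): padded token blocks expand into a
# (bs, 1, q_len, kv_len) boolean mask where True marks positions to be masked out.
#   >>> src = torch.tensor([[5, 6, 0]])        # (bs=1, q_len=3), last position is padding
#   >>> tgt = torch.tensor([[7, 8, 9, 0]])     # (bs=1, kv_len=4), last position is padding
#   >>> mask = T5MaskedWordPieceDataset._build_b1ss_attention_mask(src, tgt)
#   >>> mask.shape
#   torch.Size([1, 1, 3, 4])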
@staticmethod
def config_attention_mask(
encoder_tokens: torch.tensor,
decoder_tokens: torch.tensor,
encoder_mask: torch.tensor,
decoder_mask: torch.tensor,
use_local: bool = False,
test_te_version: str = None,
) -> torch.tensor:
"""Config attention-mask for encoder_mask, decoder_mask, encoder_decoder_mask
conditioned on transformer-implementation (e.g. TE vs local), TE versions,
and TE backends
Args:
encoder_tokens (torch.tensor): A 2-D array of tokens (bs, kv_len)
decoder_tokens (torch.tensor): A 2-D array of tokens (bs, q_len)
encoder_mask (torch.tensor): A 2-D array of tokens (bs, kv_len)
decoder_mask (torch.tensor): A 2-D array of tokens (bs, q_len)
use_local (bool): Whether the current T5 model uses the local (vs TE)
transformer implementation
test_te_version (str): If not None, a TE version string used in place of the
installed Transformer Engine version (intended for testing)
Returns:
Configured encoder_mask, decoder_mask, encoder_decoder_mask
torch.tensor: configured encoder attention mask
torch.tensor: configured decoder attention mask
torch.tensor: configured encoder-decoder attention mask
"""
# If using local transformer implementation (not transformer_engine):
# re-organize all attention masks, because local and transformer_engine
# backbones use different mask shapes. E.g.:
# (local: b1ss - transformer_engine: b11s)
if use_local:
encoder_mask = T5MaskedWordPieceDataset._build_b1ss_attention_mask(
encoder_tokens, encoder_tokens
)
decoder_mask = T5MaskedWordPieceDataset._build_b1ss_attention_mask(
decoder_tokens, decoder_tokens, make_history_mask=True
)
encoder_decoder_mask = T5MaskedWordPieceDataset._build_b1ss_attention_mask(
decoder_tokens, encoder_tokens
)
else:
# If using transformer_engine transformer implementation:
# 1. For TE version >= 1.10, across all 3 backends,
# The padding mask is configured as
# [bs, 1, 1, seq_len] for self-attention and
# ([bs, 1, 1, q_len], [bs, 1, 1, kv_len]) for cross-attention
# 2. For TE version >=1.7 and <1.10, when using Non-fused backend,
# The padding mask is configured as
# [bs, 1, q_len, kv_len] for both self-attention and for cross-attention
# 3. For TE version <1.7, only support Non-fused backend
# The padding mask is configured as
# [bs, 1, q_len, kv_len] for both self-attention and for cross-attention
# Process for Flash/Fused
encoder_mask = encoder_mask.unsqueeze(1).unsqueeze(1)
decoder_mask = decoder_mask.unsqueeze(1).unsqueeze(1)
encoder_decoder_mask = (decoder_mask, encoder_mask)
# set decoder_mask to None because decoder uses AttnMaskType.causal
decoder_mask = None
# get TE version, using test TE version if not None
if test_te_version is not None:
te_version = PkgVersion(test_te_version)
else:
te_version = get_te_version()
# Check for older TE version than 1.10, adjust attention mask accordingly
flash_attention_enabled = os.getenv('NVTE_FLASH_ATTN') == '1'
fused_attention_enabled = os.getenv('NVTE_FUSED_ATTN') == '1'
if (te_version < PkgVersion("1.10.0")) and (te_version >= PkgVersion("1.7.0")):
if not (flash_attention_enabled) and not (fused_attention_enabled):
encoder_mask = T5MaskedWordPieceDataset._build_b1ss_attention_mask(
encoder_tokens, encoder_tokens
)
encoder_decoder_mask = T5MaskedWordPieceDataset._build_b1ss_attention_mask(
decoder_tokens, encoder_tokens
)
else:
pass
elif te_version < PkgVersion("1.7.0"):
if not (flash_attention_enabled) and not (fused_attention_enabled):
encoder_mask = T5MaskedWordPieceDataset._build_b1ss_attention_mask(
encoder_tokens, encoder_tokens
)
encoder_decoder_mask = T5MaskedWordPieceDataset._build_b1ss_attention_mask(
decoder_tokens, encoder_tokens
)
else:
assert not flash_attention_enabled and not fused_attention_enabled, (
"Flash and fused attention are not supported with transformer "
"engine version < 1.7. Set NVTE_FLASH_ATTN=0 and NVTE_FUSED_ATTN=0 "
"or upgrade transformer engine to >= 1.7"
)
return encoder_mask, decoder_mask, encoder_decoder_mask
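# Usage sketch (a minimal example with toy token ids, pad id 0, and the local transformer
# implementation, which avoids any Transformer Engine dependency):
#   >>> enc_tok = torch.tensor([[5, 6, 0, 0]])
#   >>> dec_tok = torch.tensor([[2, 7, 8, 0]])
#   >>> enc_mask, dec_mask = (enc_tok >= 1).long(), (dec_tok >= 1).long()
#   >>> e, d, ed = T5MaskedWordPieceDataset.config_attention_mask(
#   ...     enc_tok, dec_tok, enc_mask, dec_mask, use_local=True
#   ... )
#   >>> e.shape, d.shape, ed.shape
#   (torch.Size([1, 1, 4, 4]), torch.Size([1, 1, 4, 4]), torch.Size([1, 1, 4, 4]))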
def __getitem__(self, idx: int) -> Dict[str, Union[int, numpy.ndarray]]:
"""Abstract method implementation
Args:
idx (int): The index into the dataset
Returns:
Dict[str, Union[int, numpy.ndarray]]: The sample, containing the encoder and decoder
token ids, the labels, the loss mask, the truncation flag, and the attention masks
"""
idx_beg, idx_end, target_sequence_length = self.sample_index[idx]
sample = [self.dataset[i] for i in range(idx_beg, idx_end)]
numpy_random_state = numpy.random.RandomState(seed=(self.config.random_seed + idx) % 2**32)
assert target_sequence_length <= self.config.sequence_length
# Flatten the sample into a list of tokens
tokens = [token for sentence in sample for token in sentence]
# Truncate the list of tokens to a desired length
truncated = len(tokens) > target_sequence_length
tokens = tokens[:target_sequence_length]
# Masking
(tokens, _, _, _, masked_spans) = self._create_masked_lm_predictions(
tokens, target_sequence_length, numpy_random_state
)
# Prepare the encoder input and decoder input and output
sentinels = deque(self.config.tokenizer.additional_special_tokens_ids)
encoder_input = []
decoder_input = [self.config.tokenizer.bos]
decoder_output = []
idx_beg = 0
for indices, labels in masked_spans:
sentinel = sentinels.popleft()
# set the end index
idx_end = indices[0]
encoder_input.extend(tokens[idx_beg:idx_end])
encoder_input.append(sentinel)
decoder_input.append(sentinel)
decoder_input.extend(labels)
decoder_output.append(sentinel)
decoder_output.extend(labels)
# set the start index
idx_beg = indices[-1] + 1
encoder_input.extend(tokens[idx_beg:])
decoder_output.append(self.config.tokenizer.eos)
# Pad the sequences and convert to NumPy
length_toks_encoder = len(encoder_input)
length_toks_decoder = len(decoder_input)
length_pads_encoder = self.config.sequence_length_encoder - length_toks_encoder
length_pads_decoder = self.config.sequence_length_decoder - length_toks_decoder
assert length_pads_encoder >= 0
assert length_pads_decoder >= 0
encoder_input = numpy.array(encoder_input, dtype=numpy.int64)
encoder_input = numpy.pad(
encoder_input, (0, length_pads_encoder), constant_values=self.config.tokenizer.pad
)
decoder_input = numpy.array(decoder_input, dtype=numpy.int64)
decoder_input = numpy.pad(
decoder_input, (0, length_pads_decoder), constant_values=self.config.tokenizer.pad
)
# Create attention and history masks
mask_encoder = numpy.array([1] * length_toks_encoder + [0] * length_pads_encoder)
mask_decoder = numpy.array([1] * length_toks_decoder + [0] * length_pads_decoder)
mask_encoder_decoder = None
# Mask the labels
decoder_output = numpy.array(decoder_output, dtype=numpy.int64)
decoder_output = numpy.pad(decoder_output, (0, length_pads_decoder), constant_values=-1)
# Get the loss mask
loss_mask = numpy.zeros(self.config.sequence_length_decoder, dtype=numpy.int64)
loss_mask[:length_toks_decoder] = 1
return {
"text_enc": encoder_input,
"text_dec": decoder_input,
"labels": decoder_output,
"loss_mask": loss_mask,
"truncated": int(truncated),
"enc_mask": mask_encoder,
"dec_mask": mask_decoder,
}
def _get_token_mask(self, numpy_random_state: numpy.random.RandomState) -> int:
"""Abstract method implementation
100% of the time, replace the token id with mask token id.
Args:
numpy_random_state (RandomState): The NumPy random state
Returns:
int: The mask token id
"""
return self.config.tokenizer.mask
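# Worked sketch of the span-corruption layout produced by __getitem__ above, assuming toy
# tokens [t1, t2, t3, t4, t5], a single masked span covering (t2, t3), and sentinel <X>:
#   encoder_input  = [t1, <X>, t4, t5]
#   decoder_input  = [<bos>, <X>, t2, t3]
#   decoder_output = [<X>, t2, t3, <eos>]
# Both decoder sequences are then padded to sequence_length_decoder, and the loss mask
# covers only the unpadded decoder positions.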
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import logging
from enum import Enum
from typing import List, Optional, Tuple
import numpy
import torch
from ..utils import log_single_rank
logger = logging.getLogger(__name__)
class Split(Enum):
train = 0
valid = 1
test = 2
def compile_helpers():
"""Compile C++ helper functions at runtime. Make sure this is invoked on a single process."""
import os
import subprocess
command = ["make", "-C", os.path.abspath(os.path.dirname(__file__))]
if subprocess.run(command).returncode != 0:
import sys
log_single_rank(logger, logging.ERROR, "Failed to compile the C++ dataset helper functions")
sys.exit(1)
def normalize(weights: List[float]) -> List[float]:
"""Do non-exponentiated normalization
Args:
weights (List[float]): The weights
Returns:
List[float]: The normalized weights
"""
w = numpy.array(weights, dtype=numpy.float64)
w_sum = numpy.sum(w)
w = (w / w_sum).tolist()
return w
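# Usage sketch (a minimal example):
#   >>> normalize([1.0, 3.0])
#   [0.25, 0.75]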
def get_blend_from_list(
blend: Optional[List[str]],
) -> Optional[Tuple[List[str], Optional[List[float]]]]:
"""Get the megatron.core.datasets.blended_megatron_dataset_config.BlendedMegatronDatasetConfig blend from the blend list
Args:
blend (Optional[List[str]]): The blend list, which can be either (1) a list of prefixes, e.g. ["path/to/dataset_1_prefix", "path/to/dataset_2_prefix"], or (2) a flattened, zipped list of weights and prefixes, e.g. ["30", "path/to/dataset_1_prefix", "70", "path/to/dataset_2_prefix"]
Returns:
Optional[Tuple[List[str], Optional[List[float]]]]: The blend, consisting of a list of dataset prefixes and optionally a list of dataset weights, e.g. [["path/to/dataset_1_prefix", "path/to/dataset_2_prefix"], [30.0, 70.0]].
"""
if blend is None:
return None
if len(blend) % 2 == 1:
weight_per_dataset = None
raw_prefix_per_dataset = blend
else:
raw_weight_per_dataset, raw_prefix_per_dataset = zip(
*[(blend[i], blend[i + 1]) for i in range(0, len(blend), 2)]
)
weight_per_dataset = []
for rwpd in raw_weight_per_dataset:
try:
weight = float(rwpd)
except ValueError:
weight = None
weight_per_dataset.append(weight)
is_none = map(lambda _: _ is None, weight_per_dataset)
if any(is_none):
assert all(is_none)
weight_per_dataset = None
raw_prefix_per_dataset = blend
prefix_per_dataset = [rppd.strip() for rppd in raw_prefix_per_dataset]
return prefix_per_dataset, weight_per_dataset
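# Usage sketch showing both accepted blend-list forms (the prefix paths are placeholders):
#   >>> get_blend_from_list(["30", "path/to/dataset_1_prefix", "70", "path/to/dataset_2_prefix"])
#   (['path/to/dataset_1_prefix', 'path/to/dataset_2_prefix'], [30.0, 70.0])
#   >>> get_blend_from_list(["path/to/dataset_1_prefix", "path/to/dataset_2_prefix"])
#   (['path/to/dataset_1_prefix', 'path/to/dataset_2_prefix'], None)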
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import os
from typing import Any, Dict, NamedTuple, Protocol, Tuple
import torch
try:
import boto3
import botocore.exceptions as exceptions
except ModuleNotFoundError:
pass
S3_PREFIX = "s3://"
class S3Config(NamedTuple):
"""Config when the data (.bin) file and the index (.idx) file are in S3
TODO: These parameters are few and can be consolidated with parameters specific to bin reader
classes - @jkamalu
Attributes:
path_to_idx_cache (str): The local directory where we will store the index (.idx) file
bin_chunk_nbytes (int): If the number of bytes is too small, then we send a request to S3 at each call of the `read` method in _S3BinReader, which is slow, because each request has a fixed cost independent of the size of the byte range requested. If the number of bytes is too large, then we only rarely have to send requests to S3, but it takes a lot of time to complete the request when we do, which can block training. We've found that 256 * 1024 * 1024 (i.e., 256 MiB) has worked well (though we have not put that much effort into tuning it), so we default to it.
"""
path_to_idx_cache: str
bin_chunk_nbytes: int = 256 * 1024 * 1024
class S3Client(Protocol):
"""The protocol which all s3 clients should abide by"""
def download_file(self, Bucket: str, Key: str, Filename: str) -> None: ...
def upload_file(self, Filename: str, Bucket: str, Key: str) -> None: ...
def head_object(self, Bucket: str, Key: str) -> Dict[str, Any]: ...
def get_object(self, Bucket: str, Key: str, Range: str) -> Dict[str, Any]: ...
def close(self) -> None: ...
def is_s3_path(path: str) -> bool:
"""Ascertain whether a path is in S3
Args:
path (str): The path
Returns:
bool: True if the path is in S3, False otherwise
"""
return path.startswith(S3_PREFIX)
def parse_s3_path(path: str) -> Tuple[str, str]:
"""Parses the given S3 path returning correspsonding bucket and key.
Args:
path (str): The S3 path
Returns:
Tuple[str, str]: A (bucket, key) tuple
"""
assert is_s3_path(path)
parts = path.replace(S3_PREFIX, "").split("/")
bucket = parts[0]
if len(parts) > 1:
key = "/".join(parts[1:])
assert S3_PREFIX + bucket + "/" + key == path
else:
key = ""
return bucket, key
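# Usage sketch (bucket and key names are placeholders):
#   >>> is_s3_path("s3://my-bucket/path/to/data.bin")
#   True
#   >>> parse_s3_path("s3://my-bucket/path/to/data.bin")
#   ('my-bucket', 'path/to/data.bin')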
def object_exists(client: S3Client, path: str) -> bool:
"""Ascertain whether the object at the given S3 path exists in S3
Args:
client (S3Client): The S3 client
path (str): The S3 path
Raises:
botocore.exceptions.ClientError: The error code is 404
Returns:
bool: True if the object exists in S3, False otherwise
"""
parsed_s3_path = parse_s3_path(path)
try:
response = client.head_object(Bucket=parsed_s3_path[0], Key=parsed_s3_path[1])
except exceptions.ClientError as e:
if e.response["Error"]["Code"] != "404":
raise e
return True
def _download_file(client: S3Client, s3_path: str, local_path: str) -> None:
"""Download the object at the given S3 path to the given local file system path
Args:
client (S3Client): The S3 client
s3_path (str): The S3 source path
local_path (str): The local destination path
"""
dirname = os.path.dirname(local_path)
os.makedirs(dirname, exist_ok=True)
parsed_s3_path = parse_s3_path(s3_path)
client.download_file(parsed_s3_path[0], parsed_s3_path[1], local_path)
def maybe_download_file(s3_path: str, local_path: str) -> None:
"""Download the object at the given S3 path to the given local file system path
In a distributed setting, the download proceeds in stages so that the minimum
number of processes download the object while still giving all ranks access
to the downloaded object.
Args:
s3_path (str): The S3 source path
local_path (str): The local destination path
"""
if torch.distributed.is_initialized():
rank = torch.distributed.get_rank()
local_rank = rank % torch.cuda.device_count()
else:
rank = 0
local_rank = 0
s3_client = boto3.client("s3")
if (not os.path.exists(local_path)) and (rank == 0):
_download_file(s3_client, s3_path, local_path)
if torch.distributed.is_initialized():
torch.distributed.barrier()
# If the `local_path` is in a file system that is not
# shared across all the ranks, then we assume it's in the
# host file system and each host needs to download the file.
if (not os.path.exists(local_path)) and (local_rank == 0):
_download_file(s3_client, s3_path, local_path)
if torch.distributed.is_initialized():
torch.distributed.barrier()
# If the `local_path` still does not exist, then we assume
# each rank is saving to a separate location.
if not os.path.exists(local_path):
_download_file(s3_client, s3_path, local_path)
if torch.distributed.is_initialized():
torch.distributed.barrier()
assert os.path.exists(local_path)
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
from .core import check_is_distributed_checkpoint
from .mapping import LocalNonpersistentObject, LocalNonpersitentObject, ShardedTensor
from .serialization import (
load,
load_common_state_dict,
load_plain_tensors,
load_tensors_metadata,
remove_sharded_tensors,
save,
)
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Module for managing distributed checkpoints metadata. """
import json
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Optional
CONFIG_FNAME = 'metadata.json'
class CheckpointingException(Exception):
"""Base checkpointing related exception"""
pass
@dataclass
class CheckpointingConfig:
"""Documents backends used in the checkpoint.
Checkpoint config keeps track of formats used for storing the sharded tensors
(sharded_backend) and other objects (common_backend).
Note that versioning is not for the checkpoint content (which is application specific),
but for the checkpoint format itself.
"""
sharded_backend: str
sharded_backend_version: int = 1
common_backend: str = 'torch'
common_backend_version: int = 1
def check_is_distributed_checkpoint(checkpoint_dir):
"""Checks if `metadata.json` exists in the checkpoint and is a valid config.
Args:
checkpoint_dir: checkpoint directory
Returns:
bool: True if `metadata.json` exists in the checkpoint and is a valid config.
"""
return maybe_load_config(checkpoint_dir) is not None
def maybe_load_config(checkpoint_dir: str) -> Optional[CheckpointingConfig]:
"""Returns checkpoint config if `checkpoint_dir` is a distributed checkpoint and None otherwise
Args:
checkpoint_dir: checkpoint directory
Returns:
CheckpointingConfig (optional): None if checkpoint is not a valid distributed checkpoint
"""
config_path = Path(checkpoint_dir, CONFIG_FNAME)
if not config_path.exists():
return None
with config_path.open() as f:
config_dict = json.load(f)
return CheckpointingConfig(**config_dict)
def save_config(config: CheckpointingConfig, checkpoint_dir: str):
"""Save given config to checkpoint directory.
Args:
config: checkpoint config
checkpoint_dir: checkpoint directory
Returns:
None
"""
config_path = Path(checkpoint_dir, CONFIG_FNAME)
with config_path.open('w') as f:
json.dump(asdict(config), f)
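# Usage sketch (a minimal round trip; the backend name and directory are placeholders,
# and the directory is assumed to already exist):
#   >>> cfg = CheckpointingConfig(sharded_backend='torch_dist')
#   >>> save_config(cfg, '/tmp/ckpt')                 # writes /tmp/ckpt/metadata.json
#   >>> check_is_distributed_checkpoint('/tmp/ckpt')
#   True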
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Utilities for operating with dicts and lists.
All functions in this module handle nesting of dicts and lists.
Other objects (e.g. tuples) are treated as atomic leaf types that cannot be traversed.
"""
from collections import defaultdict
from typing import Any, Callable, Dict, Iterable, List, Tuple, TypeVar, Union
import numpy as np
import torch
U, V = TypeVar("U"), TypeVar("V")
def extract_matching_values(
x: Union[dict, list], predicate: Callable[[Any], bool], return_lists_as_dicts: bool = False
) -> Tuple[Union[dict, list], Union[dict, list]]:
"""Return matching and nonmatching values. Keeps hierarchy.
Args:
x (Union[dict, list]) : state dict to process. Top-level argument must be a dict or list
predicate (object -> bool): determines matching values
return_lists_as_dicts (bool): if True, matching lists will be turned
into dicts, with keys indicating the indices of original elements.
Useful for reconstructing the original hierarchy.
"""
def _set_elem(target, k, v):
if return_lists_as_dicts:
target[k] = v
else:
target.append(v)
if isinstance(x, dict):
matching_vals = {}
nonmatching_vals = {}
for k, v in x.items():
if isinstance(v, (list, dict)):
match, nonmatch = extract_matching_values(v, predicate, return_lists_as_dicts)
if match:
matching_vals[k] = match
if nonmatch or not v:
nonmatching_vals[k] = nonmatch
elif predicate(v):
matching_vals[k] = v
else:
nonmatching_vals[k] = v
elif isinstance(x, list): # type: ignore
matching_vals = {} if return_lists_as_dicts else []
nonmatching_vals = {} if return_lists_as_dicts else []
for ind, v in enumerate(x):
if isinstance(v, (list, dict)) and v:
match, nonmatch = extract_matching_values(v, predicate, return_lists_as_dicts)
if match:
_set_elem(matching_vals, ind, match)
if nonmatch or not v:
_set_elem(nonmatching_vals, ind, nonmatch)
else:
target = matching_vals if predicate(v) else nonmatching_vals
_set_elem(target, ind, v)
else:
raise ValueError(f'Unexpected top-level object type: {type(x)}')
return matching_vals, nonmatching_vals
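# Usage sketch (a minimal nested state dict):
#   >>> state = {'a': 1, 'b': {'c': 2, 'd': 'x'}}
#   >>> matching, nonmatching = extract_matching_values(state, lambda v: isinstance(v, int))
#   >>> matching
#   {'a': 1, 'b': {'c': 2}}
#   >>> nonmatching
#   {'b': {'d': 'x'}}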
def diff(x1: Any, x2: Any, prefix: Tuple = ()) -> Tuple[list, list, list]:
"""Recursive diff of dicts.
Args:
x1 (object): left dict
x2 (object): right dict
prefix (tuple): tracks recursive calls. Used for reporting differing keys.
Returns:
Tuple[list, list, list]: tuple of:
- only_left: Prefixes present only in left dict
- only_right: Prefixes present only in right dict
- mismatch: values present in both dicts but not equal across dicts.
For tensors equality of all elems is checked.
Each element is a tuple (prefix, type of left value, type of right value).
"""
mismatch = []
if isinstance(x1, dict) and isinstance(x2, dict):
only_left = [prefix + (k,) for k in x1.keys() - x2.keys()]
only_right = [prefix + (k,) for k in x2.keys() - x1.keys()]
for k in x2.keys() & x1.keys():
_left, _right, _mismatch = diff(x1[k], x2[k], prefix + (k,))
only_left.extend(_left)
only_right.extend(_right)
mismatch.extend(_mismatch)
elif isinstance(x1, list) or isinstance(x1, tuple) or isinstance(x1, np.ndarray):
assert type(x1) == type(x2)
only_left = list(range(len(x1) - 1, len(x2) - 1, -1))
only_right = list(range(len(x2) - 1, len(x1) - 1, -1))
for i, (v1, v2) in enumerate(zip(x1, x2)):
_left, _right, _mismatch = diff(v1, v2, prefix + (i,))
only_left.extend(_left)
only_right.extend(_right)
mismatch.extend(_mismatch)
else:
only_left = []
only_right = []
if isinstance(x1, torch.Tensor) and isinstance(x2, torch.Tensor):
if x1.device != x2.device:
_is_mismatch = not torch.all(x1.cpu() == x2.cpu())
else:
_is_mismatch = not torch.all(x1 == x2)
# TODO: change with concrete type that has both replica_id and data attrs
elif hasattr(x1, 'replica_id') and hasattr(x2, 'replica_id'):
assert type(x1) == type(x2)
only_left, only_right, mismatch = diff(
x1.data, x2.data, prefix + (type(x1),)
) # type: ignore
_is_mismatch = False
else:
try:
_is_mismatch = bool(x1 != x2)
except RuntimeError:
_is_mismatch = True
if _is_mismatch:
mismatch.append((prefix, type(x1), type(x2)))
return only_left, only_right, mismatch
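# Usage sketch (plain dicts; tensors would be compared element-wise):
#   >>> diff({'a': 1, 'b': 2}, {'a': 1, 'c': 3})
#   ([('b',)], [('c',)], [])
#   >>> diff({'a': 1}, {'a': 2})
#   ([], [], [(('a',), <class 'int'>, <class 'int'>)])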
def inspect_types(x: Any, prefix: Tuple = (), indent: int = 4):
"""Helper to print types of (nested) dict values."""
print_indent = lambda: print(' ' * indent * len(prefix), end='')
if isinstance(x, dict):
print()
for k, v in x.items():
print_indent()
print(f'> {k}: ', end='')
inspect_types(v, prefix + (k,), indent)
elif isinstance(x, list):
print()
for i, v in enumerate(x):
print_indent()
print(f'- {i}: ', end='')
inspect_types(v, prefix + (i,), indent)
else:
if isinstance(x, torch.Tensor):
print(f'Tensor of shape {x.shape}')
else:
try:
x_str = str(x)
except:
x_str = '<no string repr>'
if len(x_str) > 30:
x_str = x_str[:30] + '... (truncated)'
print(f'[{type(x)}]: {x_str}')
def nested_values(x: Union[dict, list]):
"""Returns iterator over (nested) values of a given dict or list."""
x_iter = x.values() if isinstance(x, dict) else x
for v in x_iter:
if isinstance(v, (dict, list)):
yield from nested_values(v)
else:
yield v
def nested_items_iter(x: Union[dict, list]):
"""Returns iterator over (nested) tuples (container, key, value) of a given dict or list."""
x_iter = x.items() if isinstance(x, dict) else enumerate(x)
for k, v in x_iter:
if isinstance(v, (dict, list)):
yield from nested_items_iter(v)
else:
yield x, k, v
def dict_map(f: Callable, d: dict):
"""`map` equivalent for dicts."""
for sub_d, k, v in nested_items_iter(d):
sub_d[k] = f(v)
def dict_map_with_key(f: Callable, d: dict):
"""`map` equivalent for dicts with a function that accepts tuple (key, value)."""
for sub_d, k, v in nested_items_iter(d):
sub_d[k] = f(k, v)
def dict_list_map_inplace(f: Callable[[U], V], x: Union[Dict, List, U]):
"""Maps dicts and lists *in-place* with a given function."""
if isinstance(x, dict):
for k, v in x.items():
x[k] = dict_list_map_inplace(f, v)
elif isinstance(x, list):
x[:] = (dict_list_map_inplace(f, v) for v in x)
else:
return f(x)
return x
def dict_list_map_outplace(f: Callable[[U], V], x: Union[Dict, List, U]) -> Union[Dict, List, V]:
"""Maps dicts and lists *out-of-place* with a given function."""
if isinstance(x, dict):
return {k: dict_list_map_outplace(f, v) for k, v in x.items()}
elif isinstance(x, list):
return [dict_list_map_outplace(f, v) for v in x]
else:
return f(x)
def merge(x1: Union[dict, list], x2: Union[dict, list], key: Tuple[Union[str, int], ...] = ()):
"""Merges dicts and lists recursively."""
if isinstance(x1, dict) and isinstance(x2, dict):
for k, v2 in x2.items():
if k not in x1:
x1[k] = v2
else:
x1[k] = merge(x1[k], v2, key=key + (k,))
elif isinstance(x1, list) and isinstance(x2, list):
if len(x1) != len(x2):
raise ValueError(
f'Cannot merge two lists with different lengths ({len(x1)} and {len(x2)}, '
f'encountered at level {key})'
)
for i, v2 in enumerate(x2):
x1[i] = merge(x1[i], v2, key=key + (i,))
else:
raise ValueError(
f'Duplicate non-dict and non-list values encountered: `{x1}` and `{x2}` '
f'(at level {key})'
)
return x1
def map_reduce(
xs: Iterable,
key_fn: Callable = lambda x: x,
value_fn: Callable = lambda x: x,
reduce_fn: Callable = lambda x: x,
) -> dict:
"""Simple map-reduce implementation following `more_itertools.map_reduce` interface."""
res = defaultdict(list)
for x in xs:
res[key_fn(x)].append(value_fn(x))
for k in res:
res[k] = reduce_fn(res[k])
return dict(res)
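# Usage sketch for the nesting helpers above (toy values):
#   >>> merge({'x': {'a': 1}}, {'x': {'b': 2}, 'y': 3})
#   {'x': {'a': 1, 'b': 2}, 'y': 3}
#   >>> map_reduce(range(6), key_fn=lambda n: n % 2, value_fn=lambda n: n * 10, reduce_fn=sum)
#   {0: 60, 1: 90}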
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
"""Utilities for exchanging data between ranks."""
import logging
from collections import defaultdict
from functools import reduce
from itertools import zip_longest
from time import time
from typing import Dict, List, NamedTuple, Optional, Set, Tuple, TypeVar, cast
import numpy as np
import torch
from .core import CheckpointingException
from .dict_utils import nested_values
from .mapping import ShardedStateDict, ShardedTensor, is_main_replica
from .utils import _sharded_tensor_shard_id, _ShardId
# TODO: remove TE references once the TE bug is fixed
# Check if Transformer Engine has Float8Tensor class
HAVE_TE_FLOAT8TENSOR = False
try:
from transformer_engine.pytorch.float8_tensor import Float8Tensor
HAVE_TE_FLOAT8TENSOR = True
except (ImportError, ModuleNotFoundError):
# Float8Tensor not found
pass
def is_float8tensor(tensor: torch.Tensor) -> bool:
"""Check if a tensor is a Transformer Engine Float8Tensor"""
return HAVE_TE_FLOAT8TENSOR and isinstance(tensor, Float8Tensor)
logger = logging.getLogger(__name__)
class ShardDistribution(NamedTuple):
"""Represents a distribution of ShardedTensors.
Given distribution is valid only for a specific parallelization group,
which is implicit here (not referenced by this class).
Args:
main_rank_for_shard (Dict[_ShardId, int]): specifies which rank should hold
the main replica for a given shard
shards_in_this_group (Set[_ShardId]): which shards have a main replica
in this parallelization group
shard_to_metadata (Dict[_ShardId, ShardedTensor]): maps ShardedTensor
identifier to the original ShardedTensor
all_ranks_for_shard (Dict[_ShardId, List[int]]): specifies which ranks
need a given shard in a given parallelization group
"""
main_rank_for_shard: Dict[_ShardId, int]
shards_in_this_group: Set[_ShardId]
shard_to_metadata: Dict[_ShardId, ShardedTensor]
all_ranks_for_shard: Dict[_ShardId, List[int]]
def _shard_size(sh_ten: ShardedTensor):
"""Returns size in bytes of a given sharded tensor."""
if sh_ten.flattened_range is None:
numel = np.product(sh_ten.local_shape)
else:
numel = sh_ten.flattened_range.stop - sh_ten.flattened_range.start
return numel * torch._utils._element_size(sh_ten.dtype)
def _get_empty_tensor_for_exchange(
shard_id: _ShardId,
needed_shards: Dict[_ShardId, ShardedTensor],
unneeded_shards: Dict[_ShardId, ShardedTensor],
loaded_tensors: Dict[_ShardId, torch.Tensor],
) -> Tuple[torch.Tensor, Optional[torch.device]]:
"""Determines the empty tensor to use for exchange.
If shard_id is needed by this rank, it will be in `needed_shards`.
Otherwise, the metadata for this tensor can be found in `unneeded_shards`.
Args:
shard_id (_ShardId): shard_id that will be exchanged
needed_shards (Dict[_ShardId, ShardedTensor]): mapping from shard ids
to metadata for shards needed by this rank
unneeded_shards (Dict[_ShardId, ShardedTensor]): mapping from shard ids
to metadata for shards that can be discarded after exchange
loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping where useful tensors
are placed in
Returns:
Tuple[torch.Tensor, Optional[torch.device]]: empty CUDA tensor to be exchanged,
and the device of the original state dict tensor (if there was any)
"""
local_unloaded_sh_ten = needed_shards.get(shard_id)
if local_unloaded_sh_ten is None:
orig_device = None # this tensor will be discarded anyway
sh_ten = unneeded_shards[shard_id]
if sh_ten.data is None:
sh_ten.init_data('cuda')
tensor = sh_ten.data
sh_ten.data = None # won't be used. free memory
else:
tensor = sh_ten.data
if tensor.device.type == 'cpu':
tensor = torch.empty_like(tensor, device='cuda')
else:
local_unloaded_sh_ten.init_data('cuda')
orig_device = local_unloaded_sh_ten.data.device
tensor = local_unloaded_sh_ten.data
if tensor.device.type == 'cpu':
tensor = torch.empty_like(tensor, device='cuda')
loaded_tensors[shard_id] = tensor
return tensor, orig_device
T = TypeVar('T')
def distribute_shards_to_ranks(
shard_to_ranks: Dict[T, List[int]], shard_to_size: Dict[T, int], num_ranks: int
) -> Dict[T, int]:
"""Computes uniform distribution of workload across ranks, based on sizes.
Currently, the assignment is greedy, based on:
1. Firstly, the coverage of each shard
(how many ranks the shard is available on; lower coverage is assigned first)
2. Secondly, the size of each shard (larger size is assigned first)
3. Finally, shard id for differentiation.
Third step is added because we rely on the fact that
the assignment is deterministic on all ranks.
Args:
shard_to_ranks (Dict[T, List[int]]): mapping of rank access to shards
shard_to_size (Dict[T, int]): sizes of each shard
num_ranks (int): number of ranks in the parallelization group
Returns (Dict[T, int]): assignment of shard to rank (which rank should do the work
to achieve maximal uniformity)
"""
shard_to_ranks = {k: tuple(v) for k, v in shard_to_ranks.items()}
shard_to_saving_rank = {}
rank_sizes = [(0, rank) for rank in range(num_ranks)]
# start from tensors of lowest coverage, then go by tensor size from largest (hence minus size)
for shard_id, shard_ranks in sorted(
shard_to_ranks.items(),
key=lambda sh_id_ranks: (
len(sh_id_ranks[1]),
-shard_to_size[sh_id_ranks[0]],
sh_id_ranks[0],
),
):
# assign greedily to the least occupied rank
size, rank = min((size, rank) for size, rank in rank_sizes if rank in shard_ranks)
shard_to_saving_rank[shard_id] = rank
rank_sizes[rank] = (size + shard_to_size[shard_id], rank)
logger.debug(f'distribute_shards_to_ranks distribution: {rank_sizes}')
return shard_to_saving_rank
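# Usage sketch (hypothetical shard ids and sizes in bytes): the lone-coverage shard 'c' is
# assigned first, then the larger shard 'a', and 'b' goes to the least-loaded rank:
#   >>> distribute_shards_to_ranks(
#   ...     shard_to_ranks={'a': [0, 1], 'b': [0, 1], 'c': [1]},
#   ...     shard_to_size={'a': 100, 'b': 10, 'c': 50},
#   ...     num_ranks=2,
#   ... )
#   {'c': 1, 'a': 0, 'b': 1}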
def determine_main_replica_uniform_distribution(
sharded_state_dict: ShardedStateDict,
parallelization_group: torch.distributed.ProcessGroup,
ignore_groups: bool = False,
) -> Optional[ShardDistribution]:
"""Computes the save distribution.
Should be used in conjunction with `distribute_main_replicas_with_precomputed_distribution`
which applies the computed save distribution.
We rely on the fact that the assignment algorithm is deterministic on all ranks,
so there is no extra communication needed after metadata exchange.
Args:
sharded_state_dict (ShardedStateDict): state dict to compute the distribution of
parallelization_group (ProcessGroup): distribution will be computed
within this process group
ignore_groups (bool, optional): whether the distribution defines groups.
This option is primarily used during loading, as it ensures that all replicas,
including non-main ones, are loaded by this parallelization group
Defaults to False.
Returns (ShardDistribution, optional): distribution that can be used to apply the
parallelization. Returns None if the process_group is trivial (1 rank)
"""
group_size = torch.distributed.get_world_size(group=parallelization_group)
if group_size <= 1:
return
local_shards = list(
sh_base
for sh_base in nested_values(sharded_state_dict)
if isinstance(sh_base, ShardedTensor)
)
local_shards_no_data = [ten.without_data() for ten in local_shards]
all_shards = [None] * torch.distributed.get_world_size(group=parallelization_group)
torch.distributed.all_gather_object(
all_shards, local_shards_no_data, group=parallelization_group
)
shard_to_ranks = defaultdict(list)
shard_to_size = {}
shard_to_metadata = {}
shards_in_this_parallelization_group: Set[_ShardId] = set()
for rank, rank_shards in enumerate(all_shards):
for sh_ten in rank_shards:
shard_id = _sharded_tensor_shard_id(sh_ten)
shard_to_ranks[shard_id].append(rank)
if shard_id not in shard_to_size:
shard_to_size[shard_id] = _shard_size(sh_ten)
shard_to_metadata[shard_id] = sh_ten
if is_main_replica(sh_ten.replica_id) or ignore_groups:
shards_in_this_parallelization_group.add(shard_id)
shard_to_ranks = {
k: v for k, v in shard_to_ranks.items() if k in shards_in_this_parallelization_group
}
shard_to_saving_rank = distribute_shards_to_ranks(
shard_to_ranks, shard_to_size, len(all_shards)
)
return ShardDistribution(
shard_to_saving_rank,
shards_in_this_parallelization_group,
shard_to_metadata,
shard_to_ranks,
)
@torch.no_grad()
def exchange_loaded_tensors_gather_rounds(
loaded_tensors: Dict[_ShardId, torch.Tensor],
unloaded_shards: Dict[_ShardId, ShardedTensor],
shard_distribution: ShardDistribution = None,
parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
) -> Dict[_ShardId, torch.Tensor]:
"""Exchange the tensors loaded by different ranks with several all_gather calls.
Groups tensors by dtype, divides the tensors to be exchanged into rounds,
and executes an all_gather for the tensors in each round.
Note: the loading is distributed across ranks based on total loaded size
in bytes, so there is no guarantee that number of rounds needed for each
rank will be similar, which might result in a lot of almost empty
all_gathers. The solution would be to group all tensors into a single byte
tensor and do a single all_gather (with similarly sized messages).
Args:
loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor
shard ids to tensors already loaded by this rank.
unloaded_shards (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor
shard ids to ShardedTensors that aren't loaded yet.
shard_distribution (ShardDistribution): distribution of all shards
parallelization_group (ProcessGroup, optional): process group used for load
distribution. Tensors will be exchanged within this group
Returns:
Dict[_ShardId, torch.Tensor]: dictionary mapping shard ids to tensors
needed by this rank to load a given state dict. Includes
previously loaded tensors (from `loaded_tensors` input)
"""
main_rank_for_shard, _, shard_to_metadata, all_ranks_for_shard = shard_distribution
local_rank = torch.distributed.get_rank(group=parallelization_group)
all_loaded_tensors = dict(loaded_tensors)
# Group by dtype so that we all_gather tensors of the same dtype
for dtype in sorted(set(map(lambda sh_ten: sh_ten.dtype, shard_to_metadata.values())), key=str):
start = time()
# shards_by_rank maps rank to tensors loaded by this rank
shards_by_rank: List[List[torch.Tensor]] = [
[] for _ in range(torch.distributed.get_world_size(group=parallelization_group))
]
for shard_id, rank in main_rank_for_shard.items():
if len(all_ranks_for_shard[shard_id]) == 1:
assert all_ranks_for_shard[shard_id][0] == main_rank_for_shard[shard_id], (
f'When there is only 1 rank that needs a given shard,'
f' it should be the loading rank.'
f' Got: needs [{all_ranks_for_shard[shard_id][0]}]'
f' vs loads [{main_rank_for_shard[shard_id]}]'
)
# Skipping the exchange since only the loading rank needs this tensor
# TODO: we can employ some optimizations even for `len(shard_to_ranks) > 1`
# case, e.g. P2P exchange. Currently handling this case saves most of the
# work though.
continue
if shard_to_metadata[shard_id].dtype == dtype:
shards_by_rank[rank].append(shard_id)
# Transpose `shards_by_rank` to form exchange rounds
shards_by_round = zip_longest(*shards_by_rank, fillvalue=None)
for round_idx, round_shard_ids in enumerate(shards_by_round):
round_tensors = []
orig_devices = {}
for rank, shard_id in enumerate(round_shard_ids):
if shard_id is None:
# if no more useful data, the given rank will exchange empty tensor
local_ten = torch.empty(0, dtype=dtype, device='cuda')
orig_device = None
else:
assert isinstance(shard_id, tuple), type(shard_id)
if rank == local_rank:
assert shard_id in all_loaded_tensors, (shard_id, all_loaded_tensors.keys())
orig_device = all_loaded_tensors[shard_id].device
all_loaded_tensors[shard_id] = all_loaded_tensors[shard_id].cuda()
local_ten = all_loaded_tensors[shard_id]
else:
local_ten, orig_device = _get_empty_tensor_for_exchange(
shard_id, unloaded_shards, shard_to_metadata, all_loaded_tensors
)
# Because of a TE bug, we have to exchange a nominal dtype instead of FP8
# It's ok to keep the nominal dtype after exchange, because TE will handle
# this during state dict load.
# TODO: remove it once the bug is fixed
if is_float8tensor(local_ten):
local_ten = local_ten.from_float8()
all_loaded_tensors[shard_id] = local_ten
round_tensors.append(local_ten)
if orig_device is not None:
orig_devices[shard_id] = orig_device
torch.distributed.all_gather(
list(round_tensors),
round_tensors[local_rank],
group=parallelization_group,
async_op=False,
)
# Move tensors back to CPU if originally was on CPU
for shard_id, orig_device in orig_devices.items():
all_loaded_tensors[shard_id] = all_loaded_tensors[shard_id].to(orig_device)
del round_tensors # remove tensor references
end = time()
if torch.distributed.get_rank() == 0:
logger.debug(f'{dtype} exchange rounds all_gather schedule took {end - start}s')
return all_loaded_tensors
def exchange_loaded_tensors_gather_object(
loaded_tensors: Dict[_ShardId, torch.Tensor],
unloaded_shards: Dict[_ShardId, ShardedTensor],
shard_distribution: ShardDistribution,
parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
) -> Dict[_ShardId, torch.Tensor]:
"""Exchange the tensors loaded by different ranks with a simple all_gather_object call.
This version can be used for debugging purposes due to its simplistic
implementation. Shouldn't be used if performance is important.
Args:
loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor
shard ids to tensors already loaded by this rank.
unloaded_shards (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor
shard ids to ShardedTensors that aren't loaded yet.
shard_distribution (ShardDistribution): distribution of all shards
parallelization_group (ProcessGroup, optional): process group used for load
distribution. Tensors will be exchanged within this group
Returns:
Dict[_ShardId, torch.Tensor]: dictionary mapping shard ids to tensors
needed by this rank to load a given state dict. Includes
previously loaded tensors (from `loaded_tensors` input)
"""
all_loaded_tensors_list = [None] * torch.distributed.get_world_size(group=parallelization_group)
torch.distributed.all_gather_object(
all_loaded_tensors_list, loaded_tensors, group=parallelization_group
)
all_loaded_tensors_list = cast(List[Dict[_ShardId, torch.Tensor]], all_loaded_tensors_list)
all_loaded_tensors = reduce(lambda x, y: {**x, **y}, all_loaded_tensors_list)
# Error checks
if len(all_loaded_tensors) != sum(map(len, all_loaded_tensors_list)):
err_msg = 'Duplicate shard ids loaded by different ranks'
if torch.distributed.get_rank() == 0:
logger.error(
f'{err_msg}. Shards ids by rank:'
f' {[lt.keys() for lt in all_loaded_tensors_list]}'
)
raise CheckpointingException(err_msg)
return all_loaded_tensors
@torch.no_grad()
def exchange_loaded_tensors_broadcast(
loaded_tensors: Dict[_ShardId, torch.Tensor],
unloaded_shards: Dict[_ShardId, ShardedTensor],
shard_distribution: ShardDistribution,
parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
) -> Dict[_ShardId, torch.Tensor]:
"""Exchange the tensors loaded by different ranks by a series of broadcasts.
For each rank for each loaded tensor do a broadcast to the whole group.
A reasonable tradeoff in terms of performance and simplicity.
Args:
loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor
shard ids to tensors already loaded by this rank.
unloaded_shards (Dict[_ShardId, ShardedTensor]): mapping from ShardedTensor
shard ids to ShardedTensors that aren't loaded yet.
shard_distribution (ShardDistribution): distribution of all shards
parallelization_group (ProcessGroup, optional): process group used for load
distribution. Tensors will be exchanged within this group
Returns:
Dict[_ShardId, torch.Tensor]: dictionary mapping shard ids to tensors
needed by this rank to load a given state dict. Includes
previously loaded tensors (from `loaded_tensors` input)
"""
main_rank_for_shard, _, shard_to_metadata, all_ranks_for_shard = shard_distribution
local_rank = torch.distributed.get_rank(group=parallelization_group)
all_loaded_tensors = dict(loaded_tensors)
start = time()
for idx, (shard_id, rank) in enumerate(main_rank_for_shard.items()):
if len(all_ranks_for_shard[shard_id]) == 1:
assert all_ranks_for_shard[shard_id][0] == main_rank_for_shard[shard_id], (
f'When there is only 1 rank that needs a given shard,'
f' it should be the loading rank.'
f' Got: needs [{all_ranks_for_shard[shard_id][0]}]'
f' vs loads [{main_rank_for_shard[shard_id]}]'
)
# Skipping the exchange since only the loading rank needs this tensor
# TODO: we can employ some optimizations even for `len(shard_to_ranks) > 1` case,
# e.g. P2P exchange. Currently handling this case saves most of the work though.
continue
if rank == local_rank:
assert shard_id in all_loaded_tensors, (shard_id, all_loaded_tensors.keys())
orig_device = all_loaded_tensors[shard_id].device
local_ten = all_loaded_tensors[shard_id].cuda()
else:
local_ten, orig_device = _get_empty_tensor_for_exchange(
shard_id, unloaded_shards, shard_to_metadata, all_loaded_tensors
)
# Because of a TE bug, we have to exchange a nominal dtype instead of FP8
# It's ok to keep the nominal dtype after exchange, because TE will handle
# this during state dict load.
# TODO: remove it once the bug is fixed
if is_float8tensor(local_ten):
local_ten = local_ten.from_float8()
all_loaded_tensors[shard_id] = local_ten
global_src_rank = (
rank
if parallelization_group is None
else torch.distributed.get_global_rank(parallelization_group, rank)
)
# We can do async_op=True only if there is no CPU-copy follow-up
torch.distributed.broadcast(
local_ten,
src=global_src_rank,
group=parallelization_group,
async_op=orig_device is None,
)
# Move tensor back to CPU if originally was on CPU
if orig_device is not None:
all_loaded_tensors[shard_id] = local_ten.to(orig_device)
del local_ten
end = time()
if torch.distributed.get_rank() == 0:
logger.debug(f'exchange broadcast schedule took {end - start}s')
return all_loaded_tensors
def exchange_by_distribution(
loaded_tensors: Dict[_ShardId, torch.Tensor],
unloaded_shards: Dict[_ShardId, ShardedTensor],
shard_distribution: ShardDistribution = None,
parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
exchange_algo='broadcast',
) -> Dict[_ShardId, torch.Tensor]:
"""Exchange tensors loaded by different ranks using the specified exchange_algo.
Args:
loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor
shard ids to tensors already loaded by this rank.
unloaded_shards (Dict[_ShardId, ShardedTensor]): mapping from ShardedTensor
shard ids to ShardedTensors that aren't loaded yet.
shard_distribution (ShardDistribution): distribution of all shards
parallelization_group (ProcessGroup, optional): process group used for load
distribution. Tensors will be exchanged within this group
exchange_algo (str): The algorithm used for performing exchanges.
Defaults to 'broadcast'.
Returns:
Dict[_ShardId, torch.Tensor]: dictionary mapping shard ids to tensors
needed by this rank to load a given state dict. Includes
previously loaded tensors (from `loaded_tensors` input)
"""
if exchange_algo == 'gather_object':
exchange_fn = exchange_loaded_tensors_gather_object
elif exchange_algo == 'gather_rounds':
exchange_fn = exchange_loaded_tensors_gather_rounds
elif exchange_algo == 'broadcast':
exchange_fn = exchange_loaded_tensors_broadcast
else:
raise NotImplementedError(f'Unrecognized gather algorithm: {exchange_algo}')
return exchange_fn(loaded_tensors, unloaded_shards, shard_distribution, parallelization_group)
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Core library classes for representing sharding of tensors and objects.
The main expected usage is wrapping torch.Tensors in state dicts with
ShardedTensor class (mostly with the ShardedTensor.from_rank_offsets classmethod).
"""
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, field, replace
from itertools import chain
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from .core import CheckpointingException
from .dict_utils import dict_list_map_inplace
logger = logging.getLogger(__name__)
# These type definitions are just hints to differentiate a plain model state
# dict (StateDict) from a state dict with tensors replaced with ShardedTensors
# (ShardedStateDict).
StateDict = Dict[str, Any]
CommonStateDict = Dict[str, Any]
ShardedStateDict = Dict[str, Any]
ReplicaId = Union[int, Tuple[int, ...]]
class ShardedBase(ABC):
"""Base class for ShardedTensor and ShardedStateDict."""
key: str
data: object
replica_id: ReplicaId
@abstractmethod
def validate_metadata_integrity(self):
"""Codifies the constraints on metadata attributes."""
@abstractmethod
def without_data(self) -> 'ShardedBase':
"""Returns a new ShardedBase instance with data=None."""
raise NotImplementedError
@dataclass
class ShardedTensor(ShardedBase):
"""Represents a mapping between a local tensor and a global tensor.
Global tensor is assumed to consist of many local tensors distributed
between different processes.
Args:
key: unique identifier of a global tensor
data: local tensor data. Can be None only for consistency validation
dtype: tensor dtype
local_shape: local tensor shape
global_shape: global tensor shape
global_offset: offset of a local tensor in a global tensor,
specified in number of tensor elements
axis_fragmentations: global tensor fragmentation of each axis
replica_id: indicates given local tensor's replication wrt.
local tensors in different processes
prepend_axis_num: number of axes prepended to the local tensor to
reflect global tensor shape. The behavior is similar to
unsqueezing the local tensor.
allow_shape_mismatch: if True, during loading, the global shape of
a stored tensor does not have to match the expected global shape.
Useful for representing tensors with flexible shape,
e.g. padded.
flattened_range: specifies a slice that should be applied to a
flattened tensor with `local_shape` in order to get
the tensor stored as `data`
"""
key: str
data: Optional[torch.Tensor] = field(repr=False)
dtype: torch.dtype
local_shape: Tuple[int, ...]
global_shape: Tuple[int, ...]
global_offset: Tuple[int, ...]
axis_fragmentations: Optional[Tuple[int, ...]]
replica_id: ReplicaId = 0
prepend_axis_num: int = 0
allow_shape_mismatch: bool = False
flattened_range: Optional[slice] = None
def __post_init__(self):
self.validate_metadata_integrity()
def validate_metadata_integrity(self) -> None:
"""Codifies the constraints on metadata attributes.
Meeting those constraints is guaranteed when instantiating a ShardedTensor
class with `from_rank_offsets` or `from_rank_offsets_flat` constructors.
Returns:
None
"""
has_flattened_range = self.flattened_range is not None
if self.data is not None:
if self.data.dtype != self.dtype:
raise CheckpointingException(
f'Data dtype should match `dtype` attribute for {self}'
)
if not has_flattened_range and self.data.shape != self.local_shape:
raise CheckpointingException(
f'Data shape should match `local_shape` attribute for {self}'
)
if has_flattened_range:
if self.data.ndim != 1:
raise CheckpointingException(f'Data should be 1D for a flattened {self}')
real_data = self.data
try:
self.data = None
self.init_data(device='meta')
if self.data.shape != real_data.shape:
raise CheckpointingException(
f'Data shape {real_data.shape} does not match'
f' expected {self.data.shape} for {self}'
)
finally:
self.data = real_data
if len(self.global_shape) != len(self.global_offset):
raise CheckpointingException(
f'Global offset dimensions should be equal to global shape dimensions for {self}'
)
if len(self.local_shape) + self.prepend_axis_num != len(self.global_shape):
raise CheckpointingException(
f'Local shape together with `prepend_axis_num` dimensions should be '
f'equal to global shape dimensions for {self}'
)
for off, sh in zip(self.global_offset[self.prepend_axis_num :], self.local_shape):
if off % sh != 0:
raise CheckpointingException(
f'Global offset ({off}) must be divisible by local shape ({sh}) for {self}.'
)
if has_flattened_range and self.flattened_range.step is not None:
raise CheckpointingException(
f'`step` argument in the flattened range of a ShardedTensor is not supported.'
)
def global_slice(self) -> Tuple[Union[int, slice], ...]:
"""
Returns a tuple of int and slice objects representing a slice of the
global tensor that this ShardedTensor corresponds to.
"""
assert len(self.global_offset) == len(self.local_shape) + self.prepend_axis_num
return tuple(
chain(
(off for off in self.global_offset[: self.prepend_axis_num]),
(
slice(off, off + sh)
for off, sh in zip(
self.global_offset[self.prepend_axis_num :], self.local_shape
)
),
)
)
def global_coordinates(self) -> Tuple[np.ndarray, ...]:
"""
Returns a tuple of np.ndarrays representing the coordinates of the global tensor
that this ShardedTensor corresponds to.
"""
if self.flattened_range is None:
raise CheckpointingException(
f'`global_coordinates` is undefined for'
f' {self.__class__.__name__} without `flattened_range`'
)
local_coords = self.local_coordinates()
assert len(local_coords) + self.prepend_axis_num == len(self.global_offset), (
len(local_coords),
self,
)
global_coords = tuple(
c + off
for c, off in zip((0,) * self.prepend_axis_num + local_coords, self.global_offset)
)
return global_coords
def local_coordinates(self) -> Tuple[np.ndarray, ...]:
"""
Returns a tuple of np.ndarrays representing the coordinates of the local tensor
that this ShardedTensor corresponds to.
"""
if self.flattened_range is None:
raise CheckpointingException(
f'`local_coordinates` is undefined for'
f' {self.__class__.__name__} without `flattened_range`'
)
# TODO: np.unravel_index?
mask = np.zeros(np.product(self.local_shape), dtype=bool)
mask[self.flattened_range] = True
return np.nonzero(mask.reshape(self.local_shape))
def local_chunk_offset_in_global(self) -> Tuple[int, ...]:
"""Offset of a local chunk in a global array of chunks.
Returns:
Tuple[int, ...]: the offset of the whole local chunk in a global array of chunks.
"""
assert len(self.global_offset) == len(self.local_shape) + self.prepend_axis_num
chunk_offset = list(self.global_offset[: self.prepend_axis_num])
for off, sh in zip(self.global_offset[self.prepend_axis_num :], self.local_shape):
assert off % sh == 0, str(self)
chunk_offset.append(off // sh)
return tuple(chunk_offset)
def max_allowed_chunks(self) -> Tuple[int, ...]:
"""
Returns the maximum allowed chunks for this ShardedTensor.
"""
chunks = []
for axis_sh, axis_fragm in zip(self.global_shape, self.axis_fragmentations):
if not self.allow_shape_mismatch and axis_sh % axis_fragm != 0:
raise CheckpointingException(
f'Axis shape ({axis_sh}) not divisible by axis fragmentation ({axis_fragm})'
)
axis_chunk_size = axis_sh // axis_fragm
chunks.append(axis_chunk_size)
return tuple(chunks)
def without_data(self):
return replace(self, data=None)
@classmethod
def from_rank_offsets(
cls,
key: str,
data: torch.Tensor,
*rank_offsets: Tuple[int, int, int],
replica_id: ReplicaId = 0,
prepend_axis_num: int = 0,
flattened_range: None = None,
**init_kwargs,
):
"""Allows to construct the ShardedTensor given offset specified in process ranks.
Args:
key (str): unique key
data (torch.Tensor): local tensor data
rank_offsets (Tuple[int, int, int]): each tuple
(axis, axis_rank_offset, axis_fragm) says that if
the global tensor is divided into `axis_fragm` fragments along axis `axis`,
then the local tensor data corresponds to chunk number `axis_rank_offset`.
replica_id (ReplicaId): see ShardedTensor
prepend_axis_num (int): see ShardedTensor
flattened_range (None): must be None when using this constructor
init_kwargs: passed to ShardedTensor.__init__
"""
if flattened_range is not None:
raise ValueError(
'Cannot instantiate a flat ShardedTensor with `from_rank_offsets` method.'
' Use `from_rank_offsets_flat` instead'
)
global_offset = [0] * (data.ndim + prepend_axis_num)
global_shape = ([1] * prepend_axis_num) + list(data.shape)
axis_fragmentations = [1] * (data.ndim + prepend_axis_num)
_seen_axis = set()
for axis, axis_rank_offset, axis_fragm in rank_offsets:
if axis < 0 or axis_rank_offset < 0 or axis_fragm < 1 or axis_rank_offset >= axis_fragm:
raise CheckpointingException(f'Invalid rank offsets: {rank_offsets} for key {key}.')
_seen_axis.add(axis)
local_axis_shape = 1 if axis < prepend_axis_num else data.shape[axis - prepend_axis_num]
global_shape[axis] = axis_fragm * local_axis_shape
global_offset[axis] = axis_rank_offset * local_axis_shape
axis_fragmentations[axis] = axis_fragm
return cls(
key,
data,
data.dtype,
tuple(data.shape),
tuple(global_shape),
tuple(global_offset),
tuple(axis_fragmentations),
replica_id,
prepend_axis_num,
flattened_range=flattened_range,
**init_kwargs,
)
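# Usage sketch (hypothetical key): a local [4, 8] chunk that is the third of four column
# blocks of a global [4, 32] tensor, i.e. split 4 ways along axis 1 at rank offset 2:
#   >>> local = torch.zeros(4, 8)
#   >>> sh_ten = ShardedTensor.from_rank_offsets('decoder.weight', local, (1, 2, 4))
#   >>> sh_ten.global_shape, sh_ten.global_offset, sh_ten.axis_fragmentations
#   ((4, 32), (0, 16), (1, 4))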
@classmethod
def from_rank_offsets_flat(
cls,
key: str,
data: torch.Tensor,
non_flat_local_shape: Tuple[int, ...],
*args,
flattened_range: Optional[slice] = None,
**kwargs,
):
"""Allows to construct a *flattened* ShardedTensor given offset specified in process ranks.
Args:
key (str):
data (torch.Tensor): this should be a flattened data tensor
non_flat_local_shape (Tuple[int, ...]): expected local shape of a non-flat chunk
*args: passed unchanged to the `from_rank_offsets` constructor
flattened_range (slice): see ShardedTensor. Defaults to None, but must be set to
a non-None slice.
**kwargs:
Returns:
ShardedTensor: constructed ShardedTensor instance
"""
if flattened_range is None:
raise CheckpointingException(
'Cannot instantiate a non-flat ShardedTensor with `from_rank_offsets_flat` method.'
' Use `from_rank_offsets` instead'
)
if data.ndim != 1:
raise CheckpointingException(
f'Flattened ShardedTensor requires 1D data, got shape: {data.shape}'
)
if flattened_range.stop - flattened_range.start != data.numel():
raise CheckpointingException(
f'Flattened ShardedTensor data length ({data.numel()}) must match the '
f'slice length: {flattened_range.stop - flattened_range.start}'
)
non_flat_data_meta = torch.empty(*non_flat_local_shape, dtype=data.dtype, device='meta')
sh_ten = cls.from_rank_offsets(key, non_flat_data_meta, *args, **kwargs)
instance = replace(sh_ten, data=data, flattened_range=flattened_range)
instance.validate_metadata_integrity()
return instance
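# Illustrative sketch (hypothetical names): the same chunk as above, but stored
# as a flattened, contiguous slice of the local tensor (as a distributed
# optimizer bucket might do):
#
#     flat_slice = slice(10, 70)
#     sh_ten = ShardedTensor.from_rank_offsets_flat(
#         'decoder.linear.weight',
#         local_weight.flatten()[flat_slice],   # 1-D tensor with 60 elements
#         local_weight.shape,                   # non-flat local shape (128, 256)
#         (0, tp_rank, tp_size),
#         flattened_range=flat_slice,
#     )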
def init_data(self, device: Union[str, torch.device], init_fn=torch.empty):
"""
Initialize the tensor data of this ShardedTensor.
Only called if `data` attribute is None.
Args:
device (Union[str, torch.device]): device to place the tensor on
init_fn (Callable, optional): function to use to initialize the tensor.
Defaults to `torch.empty`.
"""
if self.data is not None:
return
self.data = init_fn(self.local_shape, dtype=self.dtype, device=device)
if self.flattened_range is not None:
self.data = self.data.flatten()[self.flattened_range.start : self.flattened_range.stop]
def narrow(self, dim: int, start: int, length: int) -> List['ShardedTensor']:
"""This is an analogue of torch.narrow for ShardedTensors.
Narrowing assumes that we narrow a local tensor on each rank.
This has consequences on local_shape, global_shape, global_offset, etc.
Args:
dim (int): dimension to narrow. Doesn't include prepended axes.
start (int): start element
length (int): length of the slice
Returns:
List[ShardedTensor]: narrowed ShardedTensors. For non-flat tensors,
the list will always have 1 element. For flat ShardedTensors the number of
elements varies depending on `dim` and on overlap, because flat
tensors must be contiguous. In particular the list can be empty.
"""
prepended_dim = dim + self.prepend_axis_num
local_length_along_dim = self.local_shape[dim]
def _update_tuple(x, ind, val):
x = list(x)
x[ind] = val
return tuple(x)
def _safe_div(x, y):
assert x % y == 0, (x, y)
return x // y
# Decrease global shape and global offset by `length / local_length_along_dim`
assert (
self.global_shape[prepended_dim] % local_length_along_dim == 0
), f'Only regular grid of local tensors is supported for narrowing, got: {self}'
assert (
self.global_offset[prepended_dim] % local_length_along_dim == 0
), f'Only regular grid of local tensors is supported for narrowing, got: {self}'
global_shape = _update_tuple(
self.global_shape,
prepended_dim,
_safe_div(self.global_shape[prepended_dim] * length, local_length_along_dim),
)
global_offset = _update_tuple(
self.global_offset,
prepended_dim,
_safe_div(self.global_offset[prepended_dim] * length, local_length_along_dim),
)
if self.flattened_range is None:
new_data = self.data.narrow(dim, start, length)
# always a single result tensor
return [
replace(
self,
data=new_data,
local_shape=new_data.shape,
global_shape=global_shape,
global_offset=global_offset,
)
]
else:
if dim != 0:
raise CheckpointingException(
f'For now, only narrowing along the first axis is supported, got dim={dim}'
)
# If dim=0, we will always get 0 or 1 resulting tensors.
# If dim>0, in general there can be more result tensors (e.g. at most 3 for dim=1).
# For an original flat ShardedTensor of local shape [3, 4] and
# flattened_range=slice(5, 10),
# the X signs mark the actual (flat) data in `self.data`;
# notice 12 (3*4) total "virtual" elements, out of which 5 are actual data.
# flat original: [.....XXXXX..]
# If we narrow to start=1, length=1 in the original local shape dimensions,
# the overlapping flat slice would be:
# narrow to: [....XXXX....]
# flat overlap: [.....XXX....]
# Now `data` is flattened and sliced, so we must compute local_shape manually
local_shape = _update_tuple(self.local_shape, dim, length)
other_dims_volume = np.prod(
_update_tuple(local_shape, dim, 1)
) # 4 in the example above
volume_before_split = other_dims_volume * start # 4 in the example above
volume_of_split = other_dims_volume * length # 4 in the example above
flat_slice_start_shifted = (
self.flattened_range.start - volume_before_split
) # 5 - 4 = 1 in the example above
flat_slice_stop_shifted = (
self.flattened_range.stop - volume_before_split
) # 10 - 4 = 6 in the example above
# Find an intersection of
# (flat_slice_start_shifted, flat_slice_stop_shifted) vs (0, volume_of_split)
if flat_slice_stop_shifted <= 0 or flat_slice_start_shifted >= volume_of_split:
return [] # no intersection
# new_flattened_range = slice(1, 4) in the example above
new_flattened_range = slice(
max(flat_slice_start_shifted, 0), min(flat_slice_stop_shifted, volume_of_split)
)
# Apply the intersection to the flattened data tensor.
# Compute start and slice appropriate length
intersection_slice_start = (
new_flattened_range.start - flat_slice_start_shifted
) # 0 in the example above
new_data = self.data[
intersection_slice_start : intersection_slice_start
+ new_flattened_range.stop
- new_flattened_range.start
]
return [
replace(
self,
data=new_data,
local_shape=local_shape,
global_shape=global_shape,
global_offset=global_offset,
flattened_range=new_flattened_range,
)
]
def is_main_replica(replica_id: ReplicaId):
"""Checks if given `replica_id` is considered as main.
"Main" replica is:
- integer 0
- or an iterable with all 0 elements
It is the application's responsibility to set correct replica ids for sharded tensors.
Args:
replica_id (Union[int, Tuple[int, ...]]): replica id
Returns:
(bool): True for a "main" replica
"""
if isinstance(replica_id, int):
return replica_id == 0
return all(r == 0 for r in replica_id)
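# Illustrative examples:
#     is_main_replica(0)          # True
#     is_main_replica(1)          # False
#     is_main_replica((0, 0, 0))  # True  - all elements are 0
#     is_main_replica((0, 1, 0))  # False - a non-zero element means a non-main replica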
class LocalNonpersistentObject:
"""Object that should not be stored in a checkpoint, but restored locally.
Wrapping any object inside the state dict with LocalNonpersistentObject
will result in:
- during saving, this object will *not* be stored in the checkpoint
- during loading, a local version of this object will be placed in a state dict
"""
def __init__(self, obj):
self.obj = obj
def unwrap(self):
"""Returns the original object."""
return self.obj
# TODO: Delete once NeMo fixes typo.
LocalNonpersitentObject = LocalNonpersistentObject
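# Illustrative sketch (hypothetical key and value): keeping run-local state out
# of the checkpoint while still restoring a local value on load:
#
#     state_dict['rng_tracker_states'] = LocalNonpersistentObject(rng_states)
#     # on save: the wrapped object is not written to the checkpoint;
#     # on load: the locally provided value is placed back under the same key.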
@dataclass
class ShardedObject(ShardedBase):
"""Represents a mapping between a local object and a global object.
Global object is assumed to consist of many local objects distributed
between different processes.
NOTE: Contrary to ShardedTensor, it's impossible to change global object
sharding. Conceptually, ShardedObject is a fully-sharded ShardedTensor
with atomic arbitrary typed elements.
Args:
key: unique identifier of a global tensor
data: local object data. Can be None only for consistency validation
global_shape: global object shape
global_offset: offset of a local object in a global object, specified in number of shards
replica_id: indicates local object replication wrt. local objects in different processes
"""
key: str
data: object
global_shape: Tuple[int, ...]
global_offset: Tuple[int, ...]
replica_id: ReplicaId = 0
def __post_init__(self):
self.validate_metadata_integrity()
def validate_metadata_integrity(self):
if len(self.global_shape) != len(self.global_offset):
raise CheckpointingException(
f'Global offset dimensions should be equal to global shape dimensions for {self}'
)
def without_data(self):
return replace(self, data=None)
@property
def unique_key(self):
"""returns a unique key for this object"""
return (
f'{self.key}/shard_'
f'{".".join(map(str, self.global_offset))}_'
f'{".".join(map(str, self.global_shape))}'
)
def __str__(self):
return f'{self.__class__.__name__}(key=\'{self.key}\')'
@classmethod
def empty_from_unique_key(cls, unique_key, replica_id: ReplicaId = 0) -> 'ShardedObject':
"""Instantiates a ShardedObject from a unique key.
Args:
unique_key: a string of the form
<key>/shard_<global_offset>_<global_shape>
replica_id: indicates local object replication wrt.
local objects in different processes
Returns:
a ShardedObject with data=None
"""
key, shard_key = unique_key.split('/')
shard_str, offset, shape = shard_key.split('_')
assert shard_str == 'shard'
offset = tuple(map(int, offset.split('.')))
shape = tuple(map(int, shape.split('.')))
if len(shape) + 1 == len(offset):
# This is a backward-compatible fix. We don't know the last
# element of global shape so set it to -1.
shape += (-1,)
return cls(key, None, shape, offset, replica_id)
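# Illustrative example (hypothetical key and data): round-tripping the unique key:
#
#     sh_obj = ShardedObject('optimizer.step', 7, global_shape=(4,), global_offset=(1,))
#     sh_obj.unique_key                                                     # 'optimizer.step/shard_1_4'
#     ShardedObject.empty_from_unique_key(sh_obj.unique_key).global_shape   # (4,)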
FactoryBuildFn = Callable[[str, torch.Tensor, ReplicaId, Optional[slice]], ShardedStateDict]
FactoryMergeFn = Callable[[StateDict], torch.Tensor]
@dataclass
class ShardedTensorFactory(ShardedBase):
"""Allows to apply transformations to tensors before/after serialization.
The essence of those transformations is that they can be applied to
optimizer states the same way they are applied to the model params.
The ultimate state dict with sharded tensors must depend functionally on
`build_fn` arguments (key, data, replica_id, flattened_range),
which will be provided by the optimizer.
Builder creates a sub-state-dict out of a tensor before saving, and merger
merges the corresponding state dict after loading.
Args:
key (str): unique identifier of the factory
data (torch.Tensor): original model parameter that will be further
transformed by this factory
build_fn (callable): function that transforms the original tensor
to a sharded state dict
merge_fn (callable): function that transforms loaded subtree back
into a single tensor (inverse of `build_fn`)
replica_id (ReplicaId): indicates factory replication wrt.
factories in different processes
flattened_range (slice, optional): indicates additional flattening
applied to the ShardedTensors produced by the factory
"""
key: str
data: torch.Tensor
build_fn: FactoryBuildFn
merge_fn: FactoryMergeFn
replica_id: ReplicaId = 0
flattened_range: Optional[slice] = None
def build(self):
"""Builds a ShardedStateDict from the original tensor"""
return self.build_fn(self.key, self.data, self.replica_id, self.flattened_range)
def validate_metadata_integrity(self):
"""No reasonable checks can be applied"""
pass
def without_data(self):
return replace(self, data=None)
def apply_factories(sharded_state_dict: ShardedStateDict):
"""Turn ShardedTensorFactories into ShardedTensors *in-place*.
Args:
sharded_state_dict (ShardedStateDict): state dict possibly
containing ShardedTensorFactory objects
Returns:
None: state dict is modified in place
"""
def apply(x):
if isinstance(x, ShardedTensorFactory):
x = x.build()
return x
dict_list_map_inplace(apply, sharded_state_dict)
def apply_factory_merges(
x1: StateDict, x2: ShardedStateDict, key: Tuple[str, ...] = ()
) -> StateDict:
"""Apply merges defined by ShardedTensorFactories *in-place*.
Args:
x1 (StateDict): state dict loaded from the checkpoint
x2 (ShardedStateDict): subset of `x1` (in terms of dict keys)
with ShardedTensorFactory
as (possibly nested) values that define how to
merge objects from the `x1` state dict
key (Tuple[str, ...]): current key in a recursive call.
Used only for reporting meaningful errors
Returns:
StateDict: `x1` modified in-place
"""
if isinstance(x2, ShardedTensorFactory):
return x2.merge_fn(x1)
# The rest is almost the same as the `merge` function from `dict_utils`
if isinstance(x1, dict) and isinstance(x2, dict):
for k, v2 in x2.items():
if k not in x1:
raise ValueError(
f'Different dict keys encountered in `apply_factory_merges` '
f'({x1.keys()} vs {x2.keys()})'
)
else:
x1[k] = apply_factory_merges(x1[k], v2, key=key + (k,))
elif isinstance(x1, list) and isinstance(x2, list):
if len(x1) != len(x2):
err_msg = (
f'Cannot merge two lists with different lengths '
f'({len(x1)} and {len(x2)}, encountered at key {key})'
)
logger.error(err_msg + f'\nx1: {x1}\nx2: {x2}')
raise ValueError(err_msg)
for i, v2 in enumerate(x2):
x1[i] = apply_factory_merges(x1[i], v2, key=key + (i,))
elif isinstance(x1, list) and isinstance(x2, dict):
for k, v2 in x2.items():
if not isinstance(k, int):
raise ValueError(
f'Invalid dict key {k} non-integer type encountered '
f'in a list-dict merge at level {key}'
)
if k >= len(x1):
raise ValueError(
f'Dict key {k} out of bound for list of length '
f'{len(x1)} (encountered at level {key})'
)
x1[k] = apply_factory_merges(x1[k], v2, key=key + (k,))
else:
raise ValueError(
f'Duplicate non-dict and non-list values encountered: `{x1}` and `{x2}` (at key {key})'
)
return x1
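# Illustrative sketch (hypothetical names, flattened_range handling omitted): a
# factory that saves a fused QKV weight as three ShardedTensors and merges them
# back into a single tensor on load:
#
#     def build_fn(key, data, replica_id, flattened_range):
#         q, k, v = torch.chunk(data, 3)
#         return {
#             'q': ShardedTensor.from_rank_offsets(f'{key}.q', q, replica_id=replica_id),
#             'k': ShardedTensor.from_rank_offsets(f'{key}.k', k, replica_id=replica_id),
#             'v': ShardedTensor.from_rank_offsets(f'{key}.v', v, replica_id=replica_id),
#         }
#
#     def merge_fn(sub_state_dict):
#         return torch.cat([sub_state_dict['q'], sub_state_dict['k'], sub_state_dict['v']])
#
#     sharded_state_dict['qkv.weight'] = ShardedTensorFactory(
#         'qkv.weight', qkv_weight, build_fn, merge_fn
#     )
#     apply_factories(sharded_state_dict)                        # before saving
#     apply_factory_merges(loaded_state_dict, sh_ten_factories)  # after loading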
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Helpers for defining sharding for optimizer states based on existing sharding
for model parameters.
"""
import logging
from copy import deepcopy
from dataclasses import replace
from typing import Dict, Iterable, Tuple, Union
logger = logging.getLogger(__name__)
import torch
from megatron.core.utils import to_local_if_dtensor
from .dict_utils import nested_values
from .mapping import (
LocalNonpersistentObject,
ShardedStateDict,
ShardedTensor,
ShardedTensorFactory,
StateDict,
)
from .utils import extract_sharded_tensors_and_factories
def get_optim_param_to_id_map(optim_params_iter: Iterable[torch.nn.Parameter]) -> Dict[int, int]:
"""Generate mapping from optimizer param to optimizer state id."""
param_mappings = {}
for i, param in enumerate(optim_params_iter):
param = to_local_if_dtensor(param)
if id(param) not in param_mappings:
param_mappings[id(param)] = i
return param_mappings
def get_param_id_to_sharded_param_map(
model_sharded_state_dict: ShardedStateDict, optim_params_iter: Iterable[torch.nn.Parameter]
) -> Dict[int, Union[ShardedTensor, ShardedTensorFactory]]:
"""Generate mapping from optimizer state ids to model sharded parameters.
Args:
model_sharded_state_dict: sharded state dict with all model sharded tensors
(can have any structure)
optim_params_iter: iterable which iterates over model parameters tracked by the optimizer.
The iteration must be in the same order as in the optimizer parameters.
Returns:
Dict[int, Union[ShardedTensor, ShardedTensorFactory]]: mapping from optimizer state ids
to model sharded parameters.
"""
model_sharded_state_dict, _ = extract_sharded_tensors_and_factories(model_sharded_state_dict)
id_to_sharded_param_map = {}
param_to_id_map = get_optim_param_to_id_map(optim_params_iter)
# If using PyTorch FSDP2 the values in model_sharded_state_dict would
# have been converted to local tensors during initialization.
# See the make_(tp)_sharded_tensor_for_checkpoint functions.
for ten in nested_values(model_sharded_state_dict):
if id(ten.data) in param_to_id_map:
id_to_sharded_param_map[param_to_id_map[id(ten.data)]] = ten
else:
logger.debug(f'{ten} is not tracked by the optimizer')
if not id_to_sharded_param_map:
logger.warning(
"Sharded parameters mapping is empty. It means tensors in model state dict"
" do not correspond to tensors in optimizer parameters map."
" Make sure to call state_dict with `keep_vars=True`."
)
return id_to_sharded_param_map
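# Illustrative sketch (hypothetical names): a typical call inside an optimizer's
# sharded_state_dict implementation:
#
#     id_to_sharded_param_map = get_param_id_to_sharded_param_map(
#         model.sharded_state_dict(),   # must be built with keep_vars=True
#         (param for group in optimizer.param_groups for param in group['params']),
#     )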
def make_sharded_optimizer_tensor(
model_param: Union[ShardedTensor, ShardedTensorFactory], optim_param: torch.Tensor, prefix: str
) -> Union[ShardedTensor, ShardedTensorFactory]:
"""Build a ShardedTensor or ShardedTensorFactory for optimizer param based on model param
Args:
model_param (Union[ShardedTensor, ShardedTensorFactory]): model param
optim_param (torch.Tensor): corresponding optimizer param
prefix (str): optimizer prefix for the ShardedTensor or ShardedTensorFactory
Returns:
Union[ShardedTensor, ShardedTensorFactory]: wrapped optimizer parameter
"""
optim_param = to_local_if_dtensor(optim_param)
if isinstance(model_param, ShardedTensorFactory):
return replace(model_param, key=f'{prefix}.{model_param.key}', data=optim_param)
assert tuple(optim_param.shape) == model_param.local_shape, (
f'Optimizer shape ({tuple(optim_param.shape)}) does not match model shape '
f'({model_param.local_shape})'
)
sh_ten = replace(
model_param, key=f'{prefix}.{model_param.key}', data=optim_param, dtype=optim_param.dtype
)
sh_ten.validate_metadata_integrity()
return sh_ten
def optim_state_to_sharding_state(
optim_state_dict: StateDict,
id_to_sharded_param_map: Dict[int, ShardedTensor],
exclude_keys: Tuple[str] = (),
):
"""Turn optimizer state dict to sharded state dict based on model state dict *in-place*.
Can be used to add sharding information to the most common optimizer state dicts.
Creates separate ShardedTensors for each key in `optim_state_dict['state']`
(e.g. for torch.optim.Adam there will be separate tensors for `exp_avg` and `exp_avg_sq`)
Args:
optim_state_dict (StateDict): optimizer state dict with
state parameters under `state` key and group hyperparameters under
`param_groups` -> `params` key.
id_to_sharded_param_map (Dict[int, ShardedTensor]): mapping from optimizer param ids
to model sharded tensors. Can be generated with `get_param_id_to_sharded_param_map`
function.
exclude_keys (Tuple[str]): optimizer state keys to exclude from the final state dict.
Returns:
None: state dict is modified in place
"""
sharded_state = {}
for param_id, param_state in optim_state_dict['state'].items():
sharded_state[param_id] = {}
for state_key, param in param_state.items():
if state_key in exclude_keys:
continue
if param_id in id_to_sharded_param_map:
sharded_state[param_id][state_key] = make_sharded_optimizer_tensor(
id_to_sharded_param_map[param_id], param, prefix=f'optimizer.state.{state_key}'
)
else:
raise ValueError(f'Param id {param_id} does not match any model sharded param')
optim_state_dict['param_groups'] = deepcopy(optim_state_dict['param_groups'])
for group in optim_state_dict['param_groups']:
group['params'] = LocalNonpersistentObject(group['params'])
optim_state_dict['state'] = sharded_state
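# Illustrative sketch (hypothetical names): adding sharding information to a
# plain torch.optim.Adam state dict before saving:
#
#     optim_state_dict = optimizer.state_dict()
#     optim_state_to_sharding_state(
#         optim_state_dict, id_to_sharded_param_map, exclude_keys=('step',)
#     )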
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Entrypoints for saving and loading the distributed checkpoints.
Functions `load` and `save` are equivalents of `torch.load` and `torch.save`
but expect torch.Tensors to be wrapped with classes from the `mapping` module.
Additionally, `load` expects the sharded state dict argument as a guidance for
loading the sharded tensors.
"""
import logging
from pathlib import Path
from typing import Callable, Dict, Optional, Set, Tuple, Union
import torch
from . import ShardedTensor
from .core import CheckpointingConfig, save_config
from .dict_utils import extract_matching_values, merge
from .mapping import (
CheckpointingException,
CommonStateDict,
ShardedObject,
ShardedStateDict,
StateDict,
apply_factory_merges,
)
from .state_dict_transformation import load_preprocess, save_preprocess
from .strategies.async_utils import AsyncRequest
from .strategies.base import (
AsyncSaveShardedStrategy,
LoadCommonStrategy,
LoadShardedStrategy,
SaveCommonStrategy,
SaveShardedStrategy,
StrategyAction,
get_default_strategy,
)
from .utils import extract_sharded_base
from .validation import (
StrictHandling,
determine_global_metadata,
parse_strict_flag,
validate_integrity_and_strict_load,
validate_sharded_objects_handling,
verify_checkpoint_and_load_strategy,
)
logger = logging.getLogger(__name__)
# flat state dict with sharded objects without any data
CkptShardedMetadata = Dict[str, Union[ShardedTensor, ShardedObject]]
def load(
sharded_state_dict: ShardedStateDict,
checkpoint_dir: str,
sharded_strategy: Union[LoadShardedStrategy, Tuple[str, int], None] = None,
common_strategy: Union[LoadCommonStrategy, Tuple[str, int], None] = None,
validate_access_integrity: bool = True,
strict: Union[str, StrictHandling] = StrictHandling.ASSUME_OK_UNEXPECTED,
) -> Union[StateDict, Tuple[StateDict, Set[str], Set[str]]]:
"""Loading entrypoint.
In the steps below, the following verbs refer to corresponding objects:
- load = load from checkpoint
- extract = extract from sharded_state_dict
- add = add to the final state dict
Steps:
1. Load common state dict and form the base of the result state dict
2. Apply factories to sharded_state_dict
3. Extract LocalNonpersistentObject and add
4. (optional) Extract ShardedObjects, load and add
5. Extract ShardedBase, load, apply factory merges and add
Args:
sharded_state_dict (ShardedStateDict): state dict of the existing model
populated with ShardedTensors. Used as a mapping to determine which
parts of global tensors stored in the checkpoint should be loaded.
checkpoint_dir (str): directory with the checkpoint
sharded_strategy (LoadShardedStrategy, Tuple[str, int], optional):
configures loading behavior for sharded tensors
common_strategy (LoadCommonStrategy, Tuple[str, int], optional):
configures loading behavior for common data
validate_access_integrity (bool default = True): checks if each tensor shard is accessed
exactly once (as main replica) by some process
strict (StrictHandling, str, optional): determines the behavior in case of a mismatch
between the requested sharded state dict and the checkpoint. See `StrictHandling` docs
for more details. Some values affect the return value of this function
(missing and unexpected keys are returned).
Defaults to `StrictHandling.ASSUME_OK_UNEXPECTED`, which doesn't
incur any performance overhead. Other recommended values
are `StrictHandling.LOG_UNEXPECTED`, which logs only unexpected keys,
or `StrictHandling.RETURN_ALL`, which returns all mismatched keys.
Returns:
StateDict or Tuple[StateDict, Set[str], Set[str]]: in most cases only
the loaded state dict is returned. If the `strict` flag was set to a value
that requires returning mismatched keys (e.g. `StrictHandling.RETURN_ALL`),
the sets of missing and unexpected keys are returned as well.
"""
sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy(
checkpoint_dir, sharded_strategy, common_strategy
)
checkpoint_dir = Path(checkpoint_dir)
common_state_dict = common_strategy.load_common(checkpoint_dir)
sharded_state_dict, nonpersistent_state_dict, sh_ten_factories = load_preprocess(
sharded_state_dict
)
merge(common_state_dict, nonpersistent_state_dict)
# At this point we are only dealing with ShardedBase objects
sharded_state_dict, _ = extract_sharded_base(sharded_state_dict)
# Validation
ckpt_sharded_metadata = None
local_metadata, global_metadata = None, None
strict = parse_strict_flag(strict)
if StrictHandling.requires_explicit_ckpt_mismatch_check(strict):
ckpt_sharded_metadata = load_sharded_metadata(
str(checkpoint_dir), sharded_strategy, common_strategy
)
if validate_access_integrity or StrictHandling.requires_global_app_metadata(strict):
local_metadata, global_metadata = determine_global_metadata(sharded_state_dict)
sharded_state_dict, missing_keys, unexpected_keys = validate_integrity_and_strict_load(
sharded_state_dict,
strict,
validate_access_integrity,
local_metadata,
global_metadata,
ckpt_sharded_metadata,
)
# ShardedBase loading
if not sharded_strategy.can_handle_sharded_objects:
validate_sharded_objects_handling(sharded_strategy, common_strategy)
sharded_objects_state_dict, sharded_state_dict = extract_matching_values(
sharded_state_dict, lambda v: isinstance(v, ShardedObject)
)
sharded_objects = common_strategy.load_sharded_objects(
sharded_objects_state_dict, checkpoint_dir
)
merge(common_state_dict, sharded_objects)
loaded_state_dict = sharded_strategy.load(sharded_state_dict, checkpoint_dir)
merge(common_state_dict, loaded_state_dict)
loaded_state_dict = apply_factory_merges(common_state_dict, sh_ten_factories)
if StrictHandling.requires_returning_mismatch_keys(strict):
return common_state_dict, missing_keys, unexpected_keys
else:
return common_state_dict
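# Illustrative sketch (hypothetical names and path): loading a checkpoint into
# an existing model:
#
#     sharded_state_dict = model.sharded_state_dict()
#     state_dict = load(
#         sharded_state_dict,
#         '/checkpoints/iter_0001000',
#         strict=StrictHandling.LOG_UNEXPECTED,
#     )
#     model.load_state_dict(state_dict)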
def load_common_state_dict(checkpoint_dir: Path) -> StateDict:
"""Load common (non-sharded) objects state dict from the checkpoint.
Args:
checkpoint_dir (Path): checkpoint directory
Returns:
StateDict: state dict with non-sharded objects from the checkpoint
"""
sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy(str(checkpoint_dir))
return common_strategy.load_common(checkpoint_dir)
def load_tensors_metadata(
checkpoint_dir: str, sharded_strategy: Union[LoadShardedStrategy, None] = None
) -> CkptShardedMetadata:
"""Load tensors metadata from the checkpoint.
Returns a dictionary similar to a sharded state dict, but note that
the dictionary keys are simply ShardedTensor keys (contrary to the
actual sharded state dicts where keys correspond to state dict keys).
Dict values are ShardedTensors without any sharding (so the only useful
information is the tensors' global shape and dtype).
Concrete implementation depends on the loading strategy. If no strategy is
given, a default for a given backend is used.
Args:
checkpoint_dir (str): checkpoint directory to load from
sharded_strategy (LoadShardedStrategy, optional): sharded strategy to load metadata.
Defaults to None - in this case a default load strategy for a given checkpoint type
is used.
Returns:
CkptShardedMetadata: flat state dict without data describing ShardedTensors
in the checkpoint
"""
sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy(
checkpoint_dir, sharded_strategy
)
return sharded_strategy.load_tensors_metadata(Path(checkpoint_dir))
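# Illustrative sketch (hypothetical path): inspecting what a checkpoint contains
# without loading any tensor data:
#
#     for key, sh_ten in load_tensors_metadata('/checkpoints/iter_0001000').items():
#         print(key, sh_ten.global_shape, sh_ten.dtype)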
def load_sharded_metadata(
checkpoint_dir: str,
sharded_strategy: Union[LoadShardedStrategy, None] = None,
common_strategy: Union[LoadCommonStrategy, None] = None,
) -> CkptShardedMetadata:
"""Load sharded metadata from the checkpoint.
Similar to `load_tensors_metadata`, but also includes ShardedObjects.
Returns a dictionary similar to a sharded state dict, but note that
the dictionary keys are simply ShardedTensor keys (contrary to the
actual sharded state dicts where keys correspond to state dict keys).
Dict values are ShardedTensors without any sharding (so the only useful
information is the tensors' global shape and dtype).
Concrete implementation depends on the loading strategy. If no strategy is
given, a default for a given backend is used.
Args:
checkpoint_dir (str): checkpoint directory to load from
sharded_strategy (LoadShardedStrategy, optional): sharded strategy to load metadata.
Defaults to None - in this case a default load strategy for a given checkpoint type
is used.
common_strategy (LoadCommonStrategy, optional): common strategy to load metadata.
Defaults to None - in this case a default load strategy for a given checkpoint type is
used. This strategy won't be used unless `sharded_strategy` can't handle ShardedObjects
Returns:
CkptShardedMetadata: flat state dict without data describing ShardedTensors
and ShardedObjects in the checkpoint
"""
sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy(
checkpoint_dir, sharded_strategy, common_strategy
)
sharded_metadata = sharded_strategy.load_sharded_metadata(Path(checkpoint_dir))
if not sharded_strategy.can_handle_sharded_objects:
validate_sharded_objects_handling(sharded_strategy, common_strategy)
common_metadata = common_strategy.load_sharded_metadata(Path(checkpoint_dir))
sharded_metadata = merge(sharded_metadata, common_metadata)
return sharded_metadata
def load_plain_tensors(checkpoint_dir: str) -> StateDict:
"""Load checkpoint tensors without any sharding and plain structure.
NOTE: common state dict is NOT included.
Args:
checkpoint_dir (str): checkpoint directory to load the tensors from.
Returns:
StateDict: checkpoint state dict containing only torch.Tensors.
"""
sharded_state_dict = load_tensors_metadata(checkpoint_dir)
# Don't validate integrity because shards will overlap
# if world_size > 1 (all processes load whole tensors)
return load(sharded_state_dict, checkpoint_dir, validate_access_integrity=False)
#
# def load_plain_tensors_and_objects(checkpoint_dir: str) -> StateDict:
# """Load checkpoint tensors and objects without any sharding and plain structure.
#
# NOTE: state dict structure might be different than the one used for checkpoint saving.
# NOTE: common state dict is NOT included.
#
# Args:
# checkpoint_dir (str): checkpoint directory to load the state dict from.
#
# Returns:
# StateDict: complete checkpoint state dict without any sharding.
# """
# sharded_state_dict = load_tensors_metadata(checkpoint_dir)
# # Don't validate integrity because shards will be overlapped
# # if world_size > 1 (all processes load whole tensors)
# return load(sharded_state_dict, checkpoint_dir, validate_access_integrity=False)
def remove_sharded_tensors(checkpoint_dir: str, key_prefix: str):
"""determine the appropriate sharding strategy and delegate removal to the sharded strategy"""
sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy(checkpoint_dir)
sharded_strategy.remove_sharded_tensors(checkpoint_dir, key_prefix)
def save(
sharded_state_dict: ShardedStateDict,
checkpoint_dir: str,
sharded_strategy: Union[SaveShardedStrategy, Tuple[str, int], None] = None,
common_strategy: Union[SaveCommonStrategy, Tuple[str, int], None] = None,
validate_access_integrity: bool = True,
async_sharded_save: bool = False,
preprocess_common_before_consistancy_check: Callable[[CommonStateDict], StateDict] = None,
) -> Optional[AsyncRequest]:
"""Saving entrypoint.
Extracts ShardedTensors from the given state dict. Rank 0 saves the
"regular" part of the checkpoint to common torch file.
The ShardedTensors are saved according to a strategy specified by the
config.
Steps:
1. Apply factories
2. Extract and discard LocalNonpersistentObject
3. Extract all ShardedBase objects
4. Save all other objects to common.pt
5. (optional) Extract and save ShardedObjects
6. Save all ShardedBase objects
7. Write metadata.json file with backend and version metadata.
Step (6) can be performed asynchronously (see `async_sharded_save`), in this
case the actual save is embodied in the returned async request and can be
scheduled by the external caller. For async request, step (7) is added as
one of the finalization functions, so that metadata.json is written only
if the checkpoint is complete.
Args:
sharded_state_dict (ShardedStateDict): state dict populated with
ShardedTensors. Used as a mapping to determine how local tensors
should be saved as global tensors in the checkpoint.
checkpoint_dir (str): directory to save the checkpoint to
sharded_strategy (SaveShardedStrategy, Tuple[str, int], optional):
configures sharded tensors saving behavior and backend
common_strategy (SaveCommonStrategy, Tuple[str, int], optional):
configures common data saving behavior and backend
validate_access_integrity (bool default = True): checks if each tensor shard is accessed
exactly once (as main replica) by some process.
It also makes sure the common state dict is consistent across all ranks
async_sharded_save (bool, optional): if True, for the sharded state dict part
an async save implementation will be called, with the AsyncRequest
being returned to the caller. Note that it is the caller's responsibility to
actually schedule the async save. Defaults to False.
preprocess_common_before_consistancy_check (Callable[[CommonStateDict], StateDict], None):
A callable function that will preprocess the common state dict (e.g. it can be used to
remove keys that we expect to be different in the state dict). The function must not
modify the original state dict.
Returns:
AsyncRequest (optional): if `async_sharded_save` is True, returns
async request that should be scheduled by the caller of this function.
None otherwise.
"""
checkpoint_dir = Path(checkpoint_dir)
if torch.distributed.get_rank() == 0:
if not checkpoint_dir.exists():
raise CheckpointingException(
f'Checkpoint destination directory does not exist: {checkpoint_dir}'
)
if next(checkpoint_dir.iterdir(), None) is not None:
raise CheckpointingException(
f'Checkpoint destination directory ({checkpoint_dir}) is not empty'
)
if common_strategy is not None:
raise NotImplementedError('The only supported common strategy is torch')
if sharded_strategy is None:
sharded_strategy = get_default_save_sharded_strategy()
if not isinstance(sharded_strategy, SaveShardedStrategy):
assert isinstance(sharded_strategy, tuple), type(sharded_strategy)
sharded_strategy = get_default_strategy(StrategyAction.SAVE_SHARDED, *sharded_strategy)
if common_strategy is None:
common_strategy = get_default_save_common_strategy()
if not isinstance(common_strategy, SaveCommonStrategy):
assert isinstance(common_strategy, tuple), type(common_strategy)
common_strategy = get_default_strategy(StrategyAction.SAVE_COMMON, *common_strategy)
sharded_state_dict, state_dict = save_preprocess(
sharded_state_dict, validate_access_integrity, preprocess_common_before_consistancy_check
)
common_strategy.save_common(state_dict, checkpoint_dir)
if not sharded_strategy.can_handle_sharded_objects:
validate_sharded_objects_handling(sharded_strategy, common_strategy)
sharded_objects_state_dict, sharded_state_dict = extract_matching_values(
sharded_state_dict, lambda v: isinstance(v, ShardedObject)
)
common_strategy.save_sharded_objects(sharded_objects_state_dict, checkpoint_dir)
def metadata_finalize_fn():
if torch.distributed.get_rank() == 0:
save_config(
CheckpointingConfig(sharded_strategy.backend, sharded_strategy.version),
checkpoint_dir,
)
torch.distributed.barrier()
if not async_sharded_save:
sharded_strategy.save(sharded_state_dict, checkpoint_dir)
metadata_finalize_fn()
return
if not isinstance(sharded_strategy, AsyncSaveShardedStrategy):
raise CheckpointingException(
f'Cannot apply async_save to non-async strategy {sharded_strategy}'
)
async_request = sharded_strategy.async_save(sharded_state_dict, checkpoint_dir)
async_request.finalize_fns.append(metadata_finalize_fn)
return async_request
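# Illustrative sketch (hypothetical names and path): synchronous and asynchronous saving:
#
#     save(model.sharded_state_dict(), '/checkpoints/iter_0001000')
#
#     async_request = save(
#         model.sharded_state_dict(),
#         '/checkpoints/iter_0001000',
#         async_sharded_save=True,
#     )
#     # The caller is responsible for scheduling `async_request`, e.g. with the
#     # helpers from `strategies.async_utils` (assumption).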
def get_default_save_sharded_strategy(
backend: str = 'torch_dist', version: int = 1
) -> SaveShardedStrategy:
"""Get default save sharded strategy."""
return get_default_strategy(StrategyAction.SAVE_SHARDED, backend, version)
def get_default_save_common_strategy(
backend: str = 'torch', version: int = 1
) -> SaveCommonStrategy:
"""Get default save common strategy."""
return get_default_strategy(StrategyAction.SAVE_COMMON, backend, version)
def get_default_load_sharded_strategy(checkpoint_dir: str) -> LoadShardedStrategy:
"""Get default load sharded strategy."""
return verify_checkpoint_and_load_strategy(checkpoint_dir)[0]