# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Utilities for Retro preprocessing."""
import glob
import logging
import os
from collections import defaultdict
from types import SimpleNamespace
from typing import Any, Callable, Dict, List, Optional
import numpy as np
import torch
from tqdm import tqdm
from megatron.core import parallel_state
from megatron.core.datasets.retro.config import RetroPreprocessingConfig
from megatron.core.datasets.retro.query.multi_split_gpt_dataset import (
MultiSplitGPTDataset,
MultiSplitGPTDatasetConfig,
)
from megatron.core.utils import log_single_rank
from .external_libs import h5py
logger = logging.getLogger(__name__)
def log_retro_rank_0(message: str) -> None:
"""Log on rank 0.
Args:
message (str): Message to log.
"""
log_single_rank(logger, logging.INFO, "[RETRO] " + message)
def retro_makedir(config: RetroPreprocessingConfig, path: str) -> None:
"""Make a directory, conditional on not being in validation mode.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
path (str): Path to directory.
"""
if config.retro_task_validate is None:
os.makedirs(path, exist_ok=True)
def extract_data_config(config: RetroPreprocessingConfig) -> MultiSplitGPTDatasetConfig:
"""Extract data config from dataset.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
Returns:
The config object used to build the dataset.
"""
return config.retro_gpt_chunk_datasets.train["dataset"].sample_dataset.config
def get_num_chunks_per_sample(sample_length: int, chunk_length: int) -> int:
"""Compute seq_length // chunk_length.
Args:
sample_length (int): Alias of `sequence_length`.
chunk_length (int): Retro chunk length (e.g., 64).
Returns:
Number of chunks per sample (i.e., `sequence_length` / `chunk_length`).
"""
assert sample_length % chunk_length == 0
return sample_length // chunk_length
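# Illustrative example (values assumed): a 2048-token sample with the typical
# Retro chunk length of 64 yields 2048 // 64 = 32 chunks.
#   >>> get_num_chunks_per_sample(2048, 64)
#   32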
class GPTToTextDataset(torch.utils.data.Dataset):
"""Dataset to convert GPT tokens to text.
Args:
gpt_dataset (MultiSplitGPTDataset): GPT dataset, which outputs GPT token samples.
gpt_tokenizer (Any): GPT tokenizer.
"""
def __init__(self, gpt_dataset: MultiSplitGPTDataset, gpt_tokenizer: Any):
super().__init__()
self.gpt_dataset = gpt_dataset
self.gpt_tokenizer = gpt_tokenizer
def __len__(self) -> int:
"""Dataset length.
Returns:
Number of samples in the dataset.
"""
return len(self.gpt_dataset)
def __getitem__(self, idx: int) -> dict:
"""Get dataset sample.
Args:
idx (int): Index of sample.
Returns:
A dict containing attribute 'text' of type string.
"""
gpt_token_ids = self.gpt_dataset[idx]["text"].tolist()
text = self.gpt_tokenizer.detokenize(gpt_token_ids)
return {"text": text}
def get_blocks(
dirname: str, n_samples: int, block_size: int, validate: Optional[Callable] = None,
) -> SimpleNamespace:
"""Divide range [0, num_samples) to sequence of block ranges.
This is a core method within the concept of block processing. The idea
is to divide a range (size n_samples) into a sequence of blocks. Each
block corresponds to a file within 'dirname' with name
'{start_idx}-{end_idx}.hdf5'. This method checks for the existence of
these files, and returns two lists, one for existing blocks and one for
missing blocks.
Args:
dirname (str): Path to directory containing block files.
n_samples (int): Total number of samples to cover. The number of samples across all saved blocks is <= n_samples.
block_size (int): Max number of samples per block file (e.g., 100000).
validate (Callable): Method for validating each block file during load.
Returns:
A namespace consisting of 2 lists: existing blocks, and missing blocks. The total number of samples between the existing and missing blocks should equal n_samples above.
"""
assert os.path.isdir(dirname), "missing directory '%s'." % dirname
# Block ranges.
block_start_idxs = list(range(0, n_samples, block_size))
block_end_idxs = [min(n_samples, i + block_size) for i in block_start_idxs]
block_ranges = list(zip(block_start_idxs, block_end_idxs))
# All block files (existing + missing).
n_digits = int(np.ceil(np.log(n_samples) / np.log(10)) + 1)
all_blocks = [
{
"range": r,
"path": os.path.join(
dirname, "%s-%s.hdf5" % tuple([str(i).zfill(n_digits) for i in r]),
),
}
for r in block_ranges
]
all_block_path_set = set(block["path"] for block in all_blocks)
# Validate function.
validate = (lambda f: None) if validate is None else validate
# Delete corrupt files.
if torch.distributed.get_rank() == 0:
existing_block_paths = [
block["path"] for block in all_blocks if os.path.exists(block["path"])
]
for path in tqdm(existing_block_paths, "validating blocks"):
assert path in all_block_path_set, "unexpected filename, '%s'." % path
try:
f = h5py.File(path, "r")
except Exception:
# Unreadable file; delete it so it is regenerated downstream.
os.remove(path)
continue
try:
validate(f)
except Exception:
# Validation failed; delete the corrupt block file.
os.remove(path)
finally:
f.close()
# Wait for files to be deleted.
torch.distributed.barrier()
# Collect blocks.
blocks = SimpleNamespace(
existing=[b for b in all_blocks if os.path.exists(b["path"])],
missing=[b for b in all_blocks if not os.path.exists(b["path"])],
)
return blocks
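# Illustrative sketch (assumes an initialized torch.distributed process group
# and an existing `dirname`): with n_samples=250000 and block_size=100000, the
# block ranges are (0, 100000), (100000, 200000), (200000, 250000), and the
# corresponding filenames are zero-padded to 7 digits, e.g.
# '0000000-0100000.hdf5', '0100000-0200000.hdf5', '0200000-0250000.hdf5'.
# Files already on disk are returned in blocks.existing; absent ones in
# blocks.missing.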
def get_blocks_by_rank(
dirname: str,
n_samples: int,
block_size: int,
validate: Optional[Callable] = None,
sample: Optional[float] = None,
) -> SimpleNamespace:
"""Divide existing and missing blocks evenly across all ranks.
See 'get_blocks()' above for description. The returned lists of existing and
missing blocks are split evenly across ranks via interleaving. This way,
each rank has a roughly equal number of blocks to process for a
downstream operation.
Args:
dirname (str): Path to directory containing block files.
n_samples (int): Total number of samples to cover. The number of samples across all saved blocks is <= n_samples.
block_size (int): Max number of samples per block file (e.g., 100000).
validate (Callable): Method for validating each block file during load.
sample (Optional[float]): If provided, sample a random subset of the blocks. Used for validating preprocessing correctness.
Returns:
A namespace consisting of 2 lists: existing blocks, and missing blocks. Each of these two lists is potentially a sub-sample of the total set of existing and missing blocks, depending on whether sampling is used. Additionally, the attributes n_existing_world and n_missing_world are the total number of existing and missing blocks, independent of sampling. Together, the existing and missing world blocks cover all n_samples samples (i.e., (n_existing_world + n_missing_world) * block_size >= n_samples).
"""
# Get world blocks.
blocks = get_blocks(dirname, n_samples, block_size, validate)
# This rank's existing and missing files.
data_parallel_rank = parallel_state.get_data_parallel_rank()
data_parallel_world_size = parallel_state.get_data_parallel_world_size()
rank_existing_blocks = blocks.existing[
data_parallel_rank : len(blocks.existing) : data_parallel_world_size
]
rank_missing_blocks = blocks.missing[
data_parallel_rank : len(blocks.missing) : data_parallel_world_size
]
# Extend rank's existing and missing blocks (with None) such that all ranks
# have equal length lists. This allows for easier tracking of global progress.
def get_world_max(n: int) -> int:
"""Get max value across ranks.
Args:
n (int): Value on this rank.
Returns:
Max value across all ranks.
"""
n_tensor = torch.cuda.LongTensor([n])
torch.distributed.all_reduce(n_tensor, op=torch.distributed.ReduceOp.MAX)
return n_tensor.item()
max_n_existing = get_world_max(len(rank_existing_blocks))
max_n_missing = get_world_max(len(rank_missing_blocks))
rank_existing_blocks += [None] * (max_n_existing - len(rank_existing_blocks))
rank_missing_blocks += [None] * (max_n_missing - len(rank_missing_blocks))
# Collect blocks.
blocks = SimpleNamespace(
n_existing_world=len(blocks.existing),
n_missing_world=len(blocks.missing),
existing=rank_existing_blocks,
missing=rank_missing_blocks,
)
if sample is not None:
# Sample existing and missing blocks evenly across all ranks. The
# returned lists of blocks are randomly sampled (without replacement)
# to yield `sample * len(blocks)` number of blocks.
# Randomly sample blocks.
def sample_blocks(_blocks: List[Optional[Dict]]) -> List[Optional[Dict]]:
"""Sample a random subset of all blocks.
Args:
_blocks (List[Optional[Dict]]): List of all blocks.
Returns:
A random subset of the blocks.
"""
n_blocks_sample = int(np.ceil(sample * len(_blocks)))
sampled_blocks: List[Optional[Dict]] = [b for b in _blocks if b is not None]
np.random.seed(None)
np.random.shuffle(sampled_blocks)
sampled_blocks = sampled_blocks[:n_blocks_sample]
sampled_blocks += [None] * (n_blocks_sample - len(sampled_blocks))
return sampled_blocks
blocks.existing = sample_blocks(blocks.existing)
blocks.missing = sample_blocks(blocks.missing)
return blocks
class BlockPathMap:
"""Map an index to its containing block path.
The common use for this class is to have a directory of files containing
blocks of processed data, of uniform block size (e.g., 100k samples per
file). Each file must follow a naming convention of 'startIdx-endIdx.[ext]',
where 'endIdx' minus 'startIdx' must equal the block size, with the possible
exception of the final block. Given an input index, this class maps the
index to the containing block file.
Args:
block_paths (List[str]): List of paths to saved block files.
block_size (int): Max number of samples per block file (e.g., 100000).
"""
@classmethod
def from_dir(cls, dir: str, block_size: int, ext: str = "hdf5") -> Any:
"""Get list of block files, and create map.
Args:
dir (str): Path to directory containing saved block files.
block_size (int): Max number of samples per block file (e.g., 100000).
ext (str): Block file extension (e.g., 'hdf5').
Returns:
A mapping of sample index to block file path.
"""
assert os.path.isdir(dir), f"directory not found, '{dir}'."
return cls(sorted(glob.glob(dir + f"/*.{ext}")), block_size)
def __init__(self, block_paths: List[str], block_size: int):
self.max_idx = 0
self.block_path_map = {}
for block_path in block_paths:
name = os.path.splitext(os.path.basename(block_path))[0]
start_idx, end_idx = [int(i) for i in name.split("-")]
self.block_path_map[start_idx] = block_path
self.max_idx = max(self.max_idx, end_idx)
self.block_size = block_size
def __str__(self) -> str:
"""Stringify the mapping.
Returns:
A string representation of this block path map.
"""
return "%d paths" % len(self.block_path_map)
def __getitem__(self, idx: int) -> str:
"""Get block path from index.
Args:
idx (int): Index of sample.
Returns:
The path to the block file containing the sample index.
"""
block_start_idx = self.block_size * (idx // self.block_size)
block_path = self.block_path_map[block_start_idx]
return block_path
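# Illustrative example (paths assumed): mapping a sample index to its block file.
#   >>> path_map = BlockPathMap(
#   ...     ["/data/db/0000000-0100000.hdf5", "/data/db/0100000-0200000.hdf5"],
#   ...     block_size=100000,
#   ... )
#   >>> path_map[150123]
#   '/data/db/0100000-0200000.hdf5'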
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from collections import deque
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union
import numpy
from megatron.core.datasets.indexed_dataset import IndexedDataset
from megatron.core.datasets.masked_dataset import (
MaskedWordPieceDataset,
MaskedWordPieceDatasetConfig,
)
from megatron.core.datasets.utils import Split
@dataclass
class T5MaskedWordPieceDatasetConfig(MaskedWordPieceDatasetConfig):
"""Configuration object for Megatron Core T5 WordPiece datasets
NB: This is a temporary holdover from Megatron-LM. The T5 tokenizer has an attribute which defines
a number of special sentinel tokens used during sampling. The assert in __post_init__ serves to
preserve compatibility with Megatron-LM until the T5 tokenizer is in Megatron Core.
"""
sequence_length_encoder: Optional[int] = field(init=False, default=None)
"""A sequence_length alias and the sequence length for the encoder"""
sequence_length_decoder: Optional[int] = None
"""The sequence length for the decoder"""
def __post_init__(self) -> None:
"""Do asserts and set fields post init
"""
super().__post_init__()
self.sequence_length_encoder = self.sequence_length
assert self.sequence_length_encoder is not None
assert self.sequence_length_decoder is not None
assert len(self.tokenizer.additional_special_tokens_ids) > 0
class T5MaskedWordPieceDataset(MaskedWordPieceDataset):
"""The T5 dataset that assumes WordPiece tokenization
Args:
indexed_dataset (IndexedDataset): The IndexedDataset around which to build the MegatronDataset
dataset_path (str): The real path on disk to the dataset, for bookkeeping
indexed_indices (numpy.ndarray): The set of the documents indices to expose
num_samples (Optional[int]): The number of samples to draw from the indexed dataset. When None, build as many samples as correspond to one epoch.
index_split (Split): The indexed_indices Split
config (T5MaskedWordPieceDatasetConfig): The config
"""
def __init__(
self,
indexed_dataset: IndexedDataset,
dataset_path: str,
indexed_indices: numpy.ndarray,
num_samples: Optional[int],
index_split: Split,
config: T5MaskedWordPieceDatasetConfig,
) -> None:
super().__init__(
indexed_dataset, dataset_path, indexed_indices, num_samples, index_split, config
)
self.token_lookup = list(self.config.tokenizer.inv_vocab.keys())
# Account for the single <bos> and single <eos> token ids
self.sample_index = self._build_sample_index(self.config.sequence_length - 2, 1)
@staticmethod
def _key_config_attributes() -> List[str]:
"""Inherited method implementation
Returns:
List[str]: The key config attributes
"""
return super(
T5MaskedWordPieceDataset, T5MaskedWordPieceDataset
)._key_config_attributes() + ["sequence_length_decoder",]
def __getitem__(self, idx: int) -> Dict[str, Union[int, numpy.ndarray]]:
"""Abstract method implementation
Args:
idx (int): The index into the dataset
Returns:
Dict[str, Union[int, numpy.ndarray]]: The sample, a dictionary containing the encoder/decoder token ids, attention masks, labels, loss mask, and truncation flag
"""
idx_beg, idx_end, target_sequence_length = self.sample_index[idx]
sample = [self.dataset[i] for i in range(idx_beg, idx_end)]
numpy_random_state = numpy.random.RandomState(
seed=(self.config.random_seed + idx) % 2 ** 32
)
assert target_sequence_length <= self.config.sequence_length
# Flatten the sample into a list of tokens
tokens = [token for sentence in sample for token in sentence]
# Truncate the list of tokens to a desired length
truncated = len(tokens) > target_sequence_length
tokens = tokens[:target_sequence_length]
# Masking
(tokens, _, _, _, masked_spans,) = self._create_masked_lm_predictions(
tokens, target_sequence_length, numpy_random_state
)
# Prepare the encoder input and decoder input and output
sentinels = deque(self.config.tokenizer.additional_special_tokens_ids)
encoder_input = []
decoder_input = [self.config.tokenizer.bos]
decoder_output = []
idx_beg = 0
for indices, labels in masked_spans:
sentinel = sentinels.popleft()
# set the end index
idx_end = indices[0]
encoder_input.extend(tokens[idx_beg:idx_end])
encoder_input.append(sentinel)
decoder_input.append(sentinel)
decoder_input.extend(labels)
decoder_output.append(sentinel)
decoder_output.extend(labels)
# set the start index
idx_beg = indices[-1] + 1
encoder_input.extend(tokens[idx_beg:])
decoder_output.append(self.config.tokenizer.eos)
# Pad the sequences and convert to NumPy
length_toks_encoder = len(encoder_input)
length_toks_decoder = len(decoder_input)
length_pads_encoder = self.config.sequence_length_encoder - length_toks_encoder
length_pads_decoder = self.config.sequence_length_decoder - length_toks_decoder
assert length_pads_encoder >= 0
assert length_pads_decoder >= 0
encoder_input = numpy.array(encoder_input, dtype=numpy.int64)
encoder_input = numpy.pad(
encoder_input, (0, length_pads_encoder), constant_values=self.config.tokenizer.pad
)
decoder_input = numpy.array(decoder_input, dtype=numpy.int64)
decoder_input = numpy.pad(
decoder_input, (0, length_pads_decoder), constant_values=self.config.tokenizer.pad
)
# Create attention and history masks
mask_encoder = self._make_attention_mask(encoder_input, encoder_input)
mask_encoder_decoder = self._make_attention_mask(decoder_input, encoder_input)
mask_decoder = self._make_attention_mask(decoder_input, decoder_input)
mask_decoder = mask_decoder * self._make_history_mask(decoder_input)
# Mask the labels
decoder_output = numpy.array(decoder_output, dtype=numpy.int64)
decoder_output = numpy.pad(decoder_output, (0, length_pads_decoder), constant_values=-1)
# Get the loss mask
loss_mask = numpy.zeros(self.config.sequence_length_decoder, dtype=numpy.int64)
loss_mask[:length_toks_decoder] = 1
return {
"text_enc": encoder_input,
"text_dec": decoder_input,
"labels": decoder_output,
"loss_mask": loss_mask,
"truncated": int(truncated),
"enc_mask": mask_encoder,
"dec_mask": mask_decoder,
"enc_dec_mask": mask_encoder_decoder,
}
@staticmethod
def _make_attention_mask(
source_block: numpy.ndarray, target_block: numpy.ndarray
) -> numpy.ndarray:
"""Return a 2-D attention mask
Args:
source_block (numpy.ndarray): A 1-D array
target_block (numpy.ndarray): A 1-D array
Returns:
numpy.ndarray: The 2-D attention mask
"""
mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
return mask.astype(numpy.int64)
@staticmethod
def _make_history_mask(block: numpy.ndarray) -> numpy.ndarray:
"""Return a 2-D history (lower-left-triangular) mask
Args:
block (numpy.ndarray): A 1-D array
Returns:
numpy.ndarray: The 2-D history (lower-left-triangular) mask
"""
arange = numpy.arange(block.shape[0])
mask = arange[None,] <= arange[:, None]
return mask.astype(numpy.int64)
def _get_token_mask(self, numpy_random_state: numpy.random.RandomState) -> int:
"""Abstract method implementation
100% of the time, replace the token id with the mask token id.
Args:
numpy_random_state (RandomState): The NumPy random state
Returns:
int: The mask token id
"""
return self.config.tokenizer.mask
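# Illustrative walk-through (token ids and sentinel placement assumed for
# demonstration) of the span-corruption construction in __getitem__ above.
# Suppose the truncated sample is tokens = [t0, t1, t2, t3, t4, t5] and masking
# selects the single span at positions [2, 3] with labels [t2, t3], and the
# first sentinel id is S0. Then:
#   encoder_input  = [t0, t1, S0, t4, t5]
#   decoder_input  = [<bos>, S0, t2, t3]
#   decoder_output = [S0, t2, t3, <eos>]
# Both decoder sequences are padded to sequence_length_decoder, labels are
# padded with -1, and loss_mask is 1 over the unpadded decoder positions.
# For the decoder self-attention, _make_history_mask yields a lower-triangular
# mask, e.g. for a length-3 input: [[1, 0, 0], [1, 1, 0], [1, 1, 1]].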
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import logging
from enum import Enum
from typing import List, Optional, Tuple
import numpy
import torch
from ..utils import log_single_rank
logger = logging.getLogger(__name__)
class Split(Enum):
train = 0
valid = 1
test = 2
def compile_helpers():
"""Compile C++ helper functions at runtime. Make sure this is invoked on a single process.
"""
import os
import subprocess
command = ["make", "-C", os.path.abspath(os.path.dirname(__file__))]
if subprocess.run(command).returncode != 0:
import sys
log_single_rank(logger, logging.ERROR, "Failed to compile the C++ dataset helper functions")
sys.exit(1)
def normalize(weights: List[float]) -> List[float]:
"""Do non-exponentiated normalization
Args:
weights (List[float]): The weights
Returns:
List[float]: The normalized weights
"""
w = numpy.array(weights, dtype=numpy.float64)
w_sum = numpy.sum(w)
w = (w / w_sum).tolist()
return w
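# Illustrative example:
#   >>> normalize([30, 70])
#   [0.3, 0.7]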
def get_blend_from_list(
blend: Optional[List[str]],
) -> Optional[Tuple[List[str], Optional[List[float]]]]:
"""Get the megatron.core.datasets.blended_megatron_dataset_config.BlendedMegatronDatasetConfig blend from the blend list
Args:
blend (Optional[List[str]]): The blend list, which can be either (1) a list of prefixes, e.g. ["path/to/dataset_1_prefix", "path/to/dataset_2_prefix"], or (2) a flattened, zipped list of weights and prefixes, e.g. ["30", "path/to/dataset_1_prefix", "70", "path/to/dataset_2_prefix"]
Returns:
Optional[Tuple[List[str], Optional[List[float]]]]: The blend, consisting of a list of dataset prefixes and optionally a list of dataset weights, e.g. [["path/to/dataset_1_prefix", "path/to/dataset_2_prefix"], [30.0, 70.0]].
"""
if blend is None:
return None
if len(blend) % 2 == 1:
weight_per_dataset = None
raw_prefix_per_dataset = blend
else:
raw_weight_per_dataset, raw_prefix_per_dataset = zip(
*[(blend[i], blend[i + 1]) for i in range(0, len(blend), 2)]
)
weight_per_dataset = []
for rwpd in raw_weight_per_dataset:
try:
weight = float(rwpd)
except ValueError:
weight = None
weight_per_dataset.append(weight)
# Use a list (not a one-shot iterator) so both any() and all() below see every element.
is_none = [weight is None for weight in weight_per_dataset]
if any(is_none):
assert all(is_none)
weight_per_dataset = None
raw_prefix_per_dataset = blend
prefix_per_dataset = [rppd.strip() for rppd in raw_prefix_per_dataset]
return prefix_per_dataset, weight_per_dataset
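# Illustrative examples (prefixes assumed):
#   >>> get_blend_from_list(["30", "path/to/dataset_1_prefix", "70", "path/to/dataset_2_prefix"])
#   (['path/to/dataset_1_prefix', 'path/to/dataset_2_prefix'], [30.0, 70.0])
#   >>> get_blend_from_list(["path/to/dataset_1_prefix"])
#   (['path/to/dataset_1_prefix'], None)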
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
from .core import check_is_distributed_checkpoint
from .mapping import LocalNonpersitentObject, ShardedTensor
from .serialization import (
load,
load_common_state_dict,
load_plain_tensors,
load_tensors_metadata,
save,
)
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Module for managing distributed checkpoints metadata. """
import json
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Optional
CONFIG_FNAME = 'metadata.json'
class CheckpointingException(Exception):
""" Base checkpointing related exception """
pass
@dataclass
class CheckpointingConfig:
""" Documents backends used in the checkpoint.
Checkpoint config keeps track of formats used for storing the sharded tensors
(sharded_backend) and other objects (common_backend).
Note that versioning is not for the checkpoint content (which is application specific),
but for the checkpoint format itself.
"""
sharded_backend: str
sharded_backend_version: int = 1
common_backend: str = 'torch'
common_backend_version: int = 1
def check_is_distributed_checkpoint(checkpoint_dir):
""" Checks if `metadata.json` exists in the checkpoint and is a valid config.
Args:
checkpoint_dir: checkpoint directory
Returns:
bool: True if `metadata.json` exists in the checkpoint and is a valid config.
"""
return maybe_load_config(checkpoint_dir) is not None
def maybe_load_config(checkpoint_dir: str) -> Optional[CheckpointingConfig]:
""" Returns checkpoint config if `checkpoint_dir` is a distributed checkpoint and None otherwise
Args:
checkpoint_dir: checkpoint directory
Returns:
CheckpointingConfig (optional): None if checkpoint is not a valid distributed checkpoint
"""
config_path = Path(checkpoint_dir, CONFIG_FNAME)
if not config_path.exists():
return None
with config_path.open() as f:
config_dict = json.load(f)
return CheckpointingConfig(**config_dict)
def save_config(config: CheckpointingConfig, checkpoint_dir: str):
""" Save given config to checkpoint directory.
Args:
config: checkpoint config
checkpoint_dir: checkpoint directory
Returns:
None
"""
config_path = Path(checkpoint_dir, CONFIG_FNAME)
with config_path.open('w') as f:
json.dump(asdict(config), f)
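# Illustrative sketch (backend name and directory assumed; the directory must
# already exist): saving and re-reading the checkpoint format metadata.
#   >>> cfg = CheckpointingConfig(sharded_backend='torch_dist')
#   >>> save_config(cfg, '/tmp/ckpt')  # writes /tmp/ckpt/metadata.json
#   >>> check_is_distributed_checkpoint('/tmp/ckpt')
#   True
#   >>> maybe_load_config('/tmp/ckpt').sharded_backend
#   'torch_dist'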
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Utilities for operating with dicts and lists.
All functions in this module handle nesting of dicts and lists.
Other objects (e.g. tuples) are treated as atomic leaf types that cannot be traversed.
"""
from collections import defaultdict
from typing import Any, Callable, Iterable, Optional, Tuple, Union
import torch
def extract_matching_values(
x: Union[dict, list], predicate: Callable[[Any], bool], return_lists_as_dicts: bool = False
) -> Tuple[Union[dict, list], Union[dict, list]]:
""" Return matching and nonmatching values. Keeps hierarchy.
Args:
x (Union[dict, list]) : state dict to process. Top-level argument must be a dict or list
predicate (object -> bool): determines matching values
return_lists_as_dicts (bool): if True, matching lists will be turned
into dicts, with keys indicating the indices of original elements.
Useful for reconstructing the original hierarchy.
"""
def _set_elem(target, k, v):
if return_lists_as_dicts:
target[k] = v
else:
target.append(v)
if isinstance(x, dict):
matching_vals = {}
nonmatching_vals = {}
for k, v in x.items():
if isinstance(v, (list, dict)):
match, nonmatch = extract_matching_values(v, predicate, return_lists_as_dicts)
if match:
matching_vals[k] = match
if nonmatch or not v:
nonmatching_vals[k] = nonmatch
elif predicate(v):
matching_vals[k] = v
else:
nonmatching_vals[k] = v
elif isinstance(x, list):
matching_vals = {} if return_lists_as_dicts else []
nonmatching_vals = {} if return_lists_as_dicts else []
for ind, v in enumerate(x):
if isinstance(v, (list, dict)) and v:
match, nonmatch = extract_matching_values(v, predicate, return_lists_as_dicts)
if match:
_set_elem(matching_vals, ind, match)
if nonmatch or not v:
_set_elem(nonmatching_vals, ind, nonmatch)
else:
target = matching_vals if predicate(v) else nonmatching_vals
_set_elem(target, ind, v)
else:
raise ValueError(f'Unexpected top-level object type: {type(x)}')
return matching_vals, nonmatching_vals
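# Illustrative example: splitting a nested dict into int-valued and other leaves.
#   >>> matching, nonmatching = extract_matching_values(
#   ...     {"a": 1, "b": {"c": 2, "d": "x"}}, lambda v: isinstance(v, int)
#   ... )
#   >>> matching
#   {'a': 1, 'b': {'c': 2}}
#   >>> nonmatching
#   {'b': {'d': 'x'}}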
def diff(x1: Any, x2: Any, prefix: Tuple = ()) -> Tuple[list, list, list]:
""" Recursive diff of dicts.
Args:
x1 (object): left dict
x2 (object): right dict
prefix (tuple): tracks recursive calls. Used for reporting differing keys.
Returns:
Tuple[list, list, list]: tuple of:
- only_left: Prefixes present only in left dict
- only_right: Prefixes present only in right dict
- mismatch: values present in both dicts but not equal across dicts.
For tensors equality of all elems is checked.
Each element is a tuple (prefix, type of left value, type of right value).
"""
mismatch = []
if isinstance(x1, dict) and isinstance(x2, dict):
only_left = [prefix + (k,) for k in x1.keys() - x2.keys()]
only_right = [prefix + (k,) for k in x2.keys() - x1.keys()]
for k in x2.keys() & x1.keys():
_left, _right, _mismatch = diff(x1[k], x2[k], prefix + (k,))
only_left.extend(_left)
only_right.extend(_right)
mismatch.extend(_mismatch)
elif isinstance(x1, list) and isinstance(x2, list):
# Indices present only in the longer of the two lists.
only_left = list(range(len(x1) - 1, len(x2) - 1, -1))
only_right = list(range(len(x2) - 1, len(x1) - 1, -1))
for i, (v1, v2) in enumerate(zip(x1, x2)):
_left, _right, _mismatch = diff(v1, v2, prefix + (i,))
only_left.extend(_left)
only_right.extend(_right)
mismatch.extend(_mismatch)
else:
only_left = []
only_right = []
if isinstance(x1, torch.Tensor) and isinstance(x2, torch.Tensor):
_is_mismatch = not torch.all(x1 == x2)
else:
try:
_is_mismatch = bool(x1 != x2)
except RuntimeError:
_is_mismatch = True
if _is_mismatch:
mismatch.append((prefix, type(x1), type(x2)))
return only_left, only_right, mismatch
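# Illustrative example: 'a' only in the left dict, 'c' only in the right dict,
# and 'b' present in both but with different values.
#   >>> diff({"a": 1, "b": 2}, {"b": 3, "c": 4})
#   ([('a',)], [('c',)], [(('b',), <class 'int'>, <class 'int'>)])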
def inspect_types(x: Any, prefix: Tuple = (), indent: int = 4):
""" Helper to print types of (nested) dict values. """
print_indent = lambda: print(' ' * indent * len(prefix), end='')
if isinstance(x, dict):
print()
for k, v in x.items():
print_indent()
print(f'> {k}: ', end='')
inspect_types(v, prefix + (k,), indent)
elif isinstance(x, list):
print()
for i, v in enumerate(x):
print_indent()
print(f'- {i}: ', end='')
inspect_types(v, prefix + (i,), indent)
else:
if isinstance(x, torch.Tensor):
print(f'Tensor of shape {x.shape}')
else:
try:
x_str = str(x)
except Exception:
x_str = '<no string repr>'
if len(x_str) > 30:
x_str = x_str[:30] + '... (truncated)'
print(f'[{type(x)}]: {x_str}')
def nested_values(x: Union[dict, list]):
""" Returns iterator over (nested) values of a given dict or list. """
x_iter = x.values() if isinstance(x, dict) else x
for v in x_iter:
if isinstance(v, (dict, list)):
yield from nested_values(v)
else:
yield v
def nested_items_iter(x: Union[dict, list]):
""" Returns iterator over (nested) tuples (container, key, value) of a given dict or list. """
x_iter = x.items() if isinstance(x, dict) else enumerate(x)
for k, v in x_iter:
if isinstance(v, (dict, list)):
yield from nested_items_iter(v)
else:
yield x, k, v
def dict_map(f: Callable, d: dict):
""" `map` equivalent for dicts. """
for sub_d, k, v in nested_items_iter(d):
sub_d[k] = f(v)
def dict_map_with_key(f: Callable, d: dict):
""" `map` equivalent for dicts with a function that accepts tuple (key, value). """
for sub_d, k, v in nested_items_iter(d):
sub_d[k] = f(k, v)
def dict_list_map_inplace(f: Callable, x: Union[dict, list]):
""" Maps dicts and lists *in-place* with a given function. """
if isinstance(x, dict):
for k, v in x.items():
x[k] = dict_list_map_inplace(f, v)
elif isinstance(x, list):
x[:] = (dict_list_map_inplace(f, v) for v in x)
else:
return f(x)
return x
def dict_list_map_outplace(f: Callable, x: Union[dict, list]):
""" Maps dicts and lists *out-of-place* with a given function. """
if isinstance(x, dict):
return {k: dict_list_map_outplace(f, v) for k, v in x.items()}
elif isinstance(x, list):
return [dict_list_map_outplace(f, v) for v in x]
else:
return f(x)
def merge(x1: dict, x2: dict, key: Tuple[str, ...] = ()):
""" Merges dicts and lists recursively. """
if isinstance(x1, dict) and isinstance(x2, dict):
for k, v2 in x2.items():
if k not in x1:
x1[k] = v2
else:
x1[k] = merge(x1[k], v2, key=key + (k,))
elif isinstance(x1, list) and isinstance(x2, list):
if len(x1) != len(x2):
raise ValueError(
f'Cannot merge two lists with different lengths ({len(x1)} and {len(x2)}, encountered at level {key})'
)
for i, v2 in enumerate(x2):
x1[i] = merge(x1[i], v2, key=key + (i,))
else:
raise ValueError(
f'Duplicate non-dict and non-list values encountered: `{x1}` and `{x2}` (at level {key})'
)
return x1
def map_reduce(
xs: Iterable,
key_fn: Callable = lambda x: x,
value_fn: Callable = lambda x: x,
reduce_fn: Callable = lambda x: x,
) -> dict:
""" Simple map-reduce implementation following `more_itertools.map_reduce` interface. """
res = defaultdict(list)
for x in xs:
res[key_fn(x)].append(value_fn(x))
for k in res:
res[k] = reduce_fn(res[k])
return dict(res)
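# Illustrative example: group integers by parity and sum each group.
#   >>> map_reduce(range(6), key_fn=lambda x: x % 2, reduce_fn=sum)
#   {0: 6, 1: 9}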
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Core library classes for representing sharding of tensors and objects.
The main expected usage is wrapping torch.Tensors in state dicts with
ShardedTensor class (mostly with the ShardedTensor.from_rank_offsets classmethod).
"""
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, replace
from itertools import chain
from typing import Any, Callable, Dict, Optional, Tuple, Union
import numpy as np
import torch
from .core import CheckpointingException
from .dict_utils import dict_list_map_inplace, dict_list_map_outplace
logger = logging.getLogger(__name__)
# These type definitions are just hints to differentiate a plain model state
# dict (StateDict) from a state dict with tensors replaced with ShardedTensors
# (ShardedStateDict).
StateDict = Dict[str, Any]
ShardedStateDict = Dict[str, Any]
ReplicaId = Union[int, Tuple[int, ...]]
class ShardedBase(ABC):
key: str
data: object
replica_id: ReplicaId
@abstractmethod
def validate_metadata_integrity(self):
"""Codifies the constraints on metadata attributes."""
@dataclass
class ShardedTensor(ShardedBase):
"""Represents a mapping between a local tensor and a global tensor.
Global tensor is assumed to consist of many local tensors distributed
between different processes.
Args:
key: unique identifier of a global tensor
data: local tensor data. Can be None only for consistency validation
dtype: tensor dtype
local_shape: local tensor shape
global_shape: global tensor shape
global_offset: offset of a local tensor in a global tensor, specified in number of tensor elements
axis_fragmentations: global tensor fragmentation of each axis
replica_id: indicates given local tensor's replication wrt. local tensors in different processes
prepend_axis_num: number of axes prepended to the local tensor to reflect global tensor shape. The behavior is similar to unsqueezing the local tensor.
allow_shape_mismatch: if True, during loading, the global shape of a stored tensor does not have to match the expected global shape. Useful for representing tensors with flexible shape, e.g. padded.
flattened_range: specifies a slice that should be applied to a flattened tensor with `local_shape` in order to get the tensor stored as `data`
"""
key: str
data: Optional[torch.Tensor]
dtype: torch.dtype
local_shape: Tuple[int, ...]
global_shape: Tuple[int, ...]
global_offset: Tuple[int, ...]
axis_fragmentations: Optional[Tuple[int, ...]]
replica_id: ReplicaId = 0
prepend_axis_num: int = 0
allow_shape_mismatch: bool = False
flattened_range: Optional[slice] = None
def __post_init__(self):
self.validate_metadata_integrity()
def validate_metadata_integrity(self) -> None:
"""Codifies the constraints on metadata attributes.
Meeting those constraints is guaranteed when instantiating a ShardedTensor
class with `from_rank_offsets` or `from_rank_offsets_flat` constructors.
Returns:
None
"""
has_flattened_range = self.flattened_range is not None
if self.data is not None:
if self.data.dtype != self.dtype:
raise CheckpointingException(
f'Data dtype should match `dtype` attribute for {self}'
)
if not has_flattened_range and self.data.shape != self.local_shape:
raise CheckpointingException(
f'Data shape should match `local_shape` attribute for {self}'
)
if has_flattened_range:
if self.data.ndim != 1:
raise CheckpointingException(f'Data should be 1D for a flattened {self}')
real_data = self.data
try:
self.data = None
self.init_data(device='meta')
if self.data.shape != real_data.shape:
raise CheckpointingException(
f'Data shape does not match expected {self.data.shape} for {self}'
)
finally:
self.data = real_data
if len(self.global_shape) != len(self.global_offset):
raise CheckpointingException(
f'Global offset dimensions should be equal to global shape dimensions for {self}'
)
if len(self.local_shape) + self.prepend_axis_num != len(self.global_shape):
raise CheckpointingException(
f'Local shape together with `prepend_axis_num` dimensions should be equal to global shape dimensions for {self}'
)
for off, sh in zip(self.global_offset[self.prepend_axis_num :], self.local_shape):
if off % sh != 0:
raise CheckpointingException(
f'Global offset ({off}) must be divisible by local shape ({sh}) for {self}.'
)
if has_flattened_range and self.flattened_range.step is not None:
raise CheckpointingException(
f'`step` argument in the flattened range of a ShardedTensor is not supported.'
)
def global_slice(self) -> Tuple[Union[int, slice], ...]:
assert len(self.global_offset) == len(self.local_shape) + self.prepend_axis_num
return tuple(
chain(
(off for off in self.global_offset[: self.prepend_axis_num]),
(
slice(off, off + sh)
for off, sh in zip(
self.global_offset[self.prepend_axis_num :], self.local_shape
)
),
)
)
def global_coordinates(self) -> Tuple[np.ndarray, ...]:
if self.flattened_range is None:
raise CheckpointingException(
f'`global_coordinates` is undefined for'
f' {self.__class__.__name__} without `flattened_range`'
)
local_coords = self.local_coordinates()
assert len(local_coords) + self.prepend_axis_num == len(self.global_offset), (
len(local_coords),
self,
)
global_coords = tuple(
c + off
for c, off in zip((0,) * self.prepend_axis_num + local_coords, self.global_offset)
)
return global_coords
def local_coordinates(self) -> Tuple[np.ndarray, ...]:
if self.flattened_range is None:
raise CheckpointingException(
f'`local_coordinates` is undefined for'
f' {self.__class__.__name__} without `flattened_range`'
)
# TODO: np.unravel_index?
mask = np.zeros(np.prod(self.local_shape), dtype=bool)
mask[self.flattened_range] = True
return np.nonzero(mask.reshape(self.local_shape))
def local_chunk_offset_in_global(self) -> Tuple[int, ...]:
"""Offset of a local chunk in a global array of chunks.
Returns:
Tuple[int, ...]: the offset of the whole local chunk in a global array of chunks.
"""
assert len(self.global_offset) == len(self.local_shape) + self.prepend_axis_num
chunk_offset = list(self.global_offset[: self.prepend_axis_num])
for off, sh in zip(self.global_offset[self.prepend_axis_num :], self.local_shape):
assert off % sh == 0, str(self)
chunk_offset.append(off // sh)
return tuple(chunk_offset)
def max_allowed_chunks(self) -> Tuple[int, ...]:
chunks = []
for axis_sh, axis_fragm in zip(self.global_shape, self.axis_fragmentations):
if not self.allow_shape_mismatch and axis_sh % axis_fragm != 0:
raise CheckpointingException(
f'Axis shape ({axis_sh}) not divisible by axis fragmentation ({axis_fragm})'
)
axis_chunk_size = axis_sh // axis_fragm
chunks.append(axis_chunk_size)
return tuple(chunks)
def without_data(self):
return replace(self, data=None)
@classmethod
def from_rank_offsets(
cls,
key: str,
data: torch.Tensor,
*rank_offsets: Tuple[int, int, int],
replica_id: ReplicaId = 0,
prepend_axis_num: int = 0,
flattened_range: None = None,
**init_kwargs,
):
"""Allows to construct the ShardedTensor given offset specified in process ranks.
Args:
key (str): unique key
data (torch.Tensor): local tensor data
rank_offsets (Tuple[int, int, int]): each tuple (axis, axis_rank_offset, axis_fragm) says that if global tensor is divided into `axis_fragm` fragment along `axis` axis, then local tensor data corresponds to the `axis_rank_offset` chunk.
replica_id (ReplicaId): see ShardedTensor
prepend_axis_num (int): see ShardedTensor
flattened_range (None): must be None when using this constructor
init_kwargs: passed to ShardedTensor.__init__
"""
if flattened_range is not None:
raise ValueError(
'Cannot instantiate a flat ShardedTensor with `from_rank_offsets` method.'
' Use `from_rank_offsets_flat` instead'
)
global_offset = [0] * (data.ndim + prepend_axis_num)
global_shape = ([1] * prepend_axis_num) + list(data.shape)
axis_fragmentations = [1] * (data.ndim + prepend_axis_num)
_seen_axis = set()
for axis, axis_rank_offset, axis_fragm in rank_offsets:
assert axis >= 0 and axis_rank_offset >= 0 and axis_fragm >= 0, (
axis,
axis_rank_offset,
axis_fragm,
)
assert (
axis_rank_offset < axis_fragm
), 'Rank offset must be lower than axis fragmentation'
if axis in _seen_axis:
raise CheckpointingException('Duplicated axis specified')
_seen_axis.add(axis)
local_axis_shape = 1 if axis < prepend_axis_num else data.shape[axis - prepend_axis_num]
global_shape[axis] = axis_fragm * local_axis_shape
global_offset[axis] = axis_rank_offset * local_axis_shape
axis_fragmentations[axis] = axis_fragm
return cls(
key,
data,
data.dtype,
tuple(data.shape),
tuple(global_shape),
tuple(global_offset),
tuple(axis_fragmentations),
replica_id,
prepend_axis_num,
flattened_range=flattened_range,
**init_kwargs,
)
@classmethod
def from_rank_offsets_flat(
cls,
key: str,
data: torch.Tensor,
non_flat_local_shape: Tuple[int, ...],
*args,
flattened_range: Optional[slice] = None,
**kwargs,
):
"""Allows to construct a *flattened* ShardedTensor given offset specified in process ranks.
Args:
key (str):
data (torch.Tensor): this should be a flattened data tensor
non_flat_local_shape (Tuple[int, ...]): expected local shape of a non-flat chunk
*args: passed unchanged to the `from_rank_offsets` constructor
flattened_range (slice): see ShardedTensor. Defaults to None, but must be set to
a non-None slice.
**kwargs:
Returns:
ShardedTensor: constructed ShardedTensor instance
"""
if flattened_range is None:
raise CheckpointingException(
'Cannot instantiate a non-flat ShardedTensor with `from_rank_offsets_flat` method.'
' Use `from_rank_offsets` instead'
)
if data.ndim != 1:
raise CheckpointingException(
f'Flattened ShardedTensor requires 1D data, got shape: {data.shape}'
)
if flattened_range.stop - flattened_range.start != data.numel():
raise CheckpointingException(
f'Flattened ShardedTensor data length ({data.numel()}) must meet the slice length: {flattened_range.stop - flattened_range.start}'
)
non_flat_data_meta = torch.empty(*non_flat_local_shape, dtype=data.dtype, device='meta')
sh_ten = cls.from_rank_offsets(key, non_flat_data_meta, *args, **kwargs)
instance = replace(sh_ten, data=data, flattened_range=flattened_range)
instance.validate_metadata_integrity()
return instance
def init_data(self, device: Union[str, torch.device], init_fn=torch.empty):
if self.data is not None:
return
self.data = init_fn(self.local_shape, dtype=self.dtype, device=device)
if self.flattened_range is not None:
self.data = self.data.flatten()[self.flattened_range.start : self.flattened_range.stop]
def __str__(self):
return f'{self.__class__.__name__}(key=\'{self.key}\')'
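# Illustrative example (key and shapes assumed): a rank holding chunk 1 of 4
# along axis 0 of an (8, 4) global tensor wraps its (2, 4) local shard as:
#   >>> sh_ten = ShardedTensor.from_rank_offsets(
#   ...     'decoder.weight', torch.zeros(2, 4), (0, 1, 4)
#   ... )
#   >>> sh_ten.global_shape, sh_ten.global_offset, sh_ten.axis_fragmentations
#   ((8, 4), (2, 0), (4, 1))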
def is_main_replica(replica_id: ReplicaId):
""" Checks if given `replica_id` is considered as main.
"Main" replica is:
- integer 0
- or an iterable with all 0 elements
It is the application responsibility to set correct replicas for sharded tensors.
Args:
replica_id (Union[int, Tuple[int, ...]]): replica id
Returns:
(bool): True for a "main" replica
"""
if isinstance(replica_id, int):
return replica_id == 0
return all(r == 0 for r in replica_id)
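# Illustrative example:
#   >>> is_main_replica(0), is_main_replica((0, 0)), is_main_replica((0, 1))
#   (True, True, False)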
class LocalNonpersitentObject:
"""Object that should not be stored in a checkpoint, but restored locally.
Wrapping any object inside the state dict with LocalNonpersitentObject
will result in:
- during saving, this object will *not* be stored in the checkpoint
- during loading, a local version of this object will be placed in a state dict
"""
def __init__(self, obj):
self.obj = obj
def unwrap(self):
return self.obj
@dataclass
class ShardedObject(ShardedBase):
"""Represents a mapping between a local object and a global object.
Global object is assumed to consist of many local objects distributed
between different processes.
NOTE: Contrary to ShardedTensor, it's impossible to change global object
sharding. Conceptually, ShardedObject is a fully-sharded ShardedTensor
with atomic arbitrary typed elements.
Args:
key: unique identifier of a global tensor
data: local object data. Can be None only for consistency validation
global_shape: global object shape
global_offset: offset of a local object in a global object, specified in number of shards
replica_id: indicates local object replication wrt. local objects in different processes
"""
key: str
data: object
global_shape: Tuple[int, ...]
global_offset: Tuple[int, ...]
replica_id: ReplicaId = 0
def __post_init__(self):
self.validate_metadata_integrity()
def validate_metadata_integrity(self):
if len(self.global_shape) != len(self.global_offset):
raise CheckpointingException(
f'Global offset dimensions should be equal to global shape dimensions for {self}'
)
def without_data(self):
return replace(self, data=None)
@property
def unique_key(self):
return f'{self.key}/shard_{".".join(map(str, self.global_offset))}_{".".join(map(str, self.global_shape))}'
def __str__(self):
return f'{self.__class__.__name__}(key=\'{self.key}\')'
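# Illustrative example (key and sharding assumed): a ShardedObject holding
# shard (1, 0) of a (2, 4) grid of per-rank objects.
#   >>> sh_obj = ShardedObject('optimizer.buckets', {'some': 'state'}, (2, 4), (1, 0))
#   >>> sh_obj.unique_key
#   'optimizer.buckets/shard_1.0_2.4'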
@dataclass
class ShardedTensorFactory(ShardedBase):
""" Allows to apply transformations to tensors before/after serialization.
The essence of those transformations is that they can be applied to
optimizer states the same way they are applied to the model params.
The ultimate state dict with sharded tensors must depend functionally on
`build_fn` arguments (key, data, replica_id, flattened_range),
which will be provided by the optimizer.
Builder creates a sub-state-dict out of a tensor before saving, and merger
merges the corresponding state dict after loading.
Args:
key (str): unique identifier of the factory
data (torch.Tensor): original model parameter that will be further transformed by this factory
build_fn (callable): function that transforms the original tensor to a sharded state dict
merge_fn (callable): function that transforms loaded subtree back into a single tensor (inverse of `build_fn`)
replica_id (ReplicaId): indicates factory replication wrt. factories in different processes
flattened_range (slice, optional): indicates additional flattening applied to the ShardedTensors produced by the factory
"""
key: str
data: torch.Tensor
build_fn: Callable[[str, torch.Tensor, ReplicaId, Optional[slice]], ShardedStateDict]
merge_fn: Callable[[StateDict], torch.Tensor]
replica_id: ReplicaId = 0
flattened_range: Optional[slice] = None
def build(self):
return self.build_fn(self.key, self.data, self.replica_id, self.flattened_range)
def validate_metadata_integrity(self):
"""No reasonable checks can be applied"""
pass
def apply_factories(sharded_state_dict: ShardedStateDict):
""" Turn ShardedTensorFactories into ShardedTensors *in-place*.
Args:
sharded_state_dict (ShardedStateDict): state dict possibly containing ShardedTensorFactory objects
Returns:
None: state dict is modified in place
"""
def apply(x):
if isinstance(x, ShardedTensorFactory):
x = x.build()
return x
dict_list_map_inplace(apply, sharded_state_dict)
def apply_factory_merges(
x1: StateDict, x2: ShardedStateDict, key: Tuple[str, ...] = ()
) -> StateDict:
""" Apply merges defined by ShardedTensorFactories *in-place*.
Args:
x1 (StateDict): state dict loaded from the checkpoint
x2 (ShardedStateDict): subset of `x1` (in terms of dict keys) with ShardedTensorFactory
as (possibly nested) values that define how to merge objects from the `x1` state dict
key (Tuple[str, ...]): current key in a recursive call. Used only for reporting meaningful errors
Returns:
StateDict: `x1` modified in-place
"""
if isinstance(x2, ShardedTensorFactory):
return x2.merge_fn(x1)
# The rest is almost the same as the `merge` function from `dict_utils`
if isinstance(x1, dict) and isinstance(x2, dict):
for k, v2 in x2.items():
if k not in x1:
raise ValueError(
f'Different dict keys encountered in `apply_factory_merges` ({x1.keys()} vs {x2.keys()})'
)
else:
x1[k] = apply_factory_merges(x1[k], v2, key=key + (k,))
elif isinstance(x1, list) and isinstance(x2, list):
if len(x1) != len(x2):
err_msg = f'Cannot merge two lists with different lengths ({len(x1)} and {len(x2)}, encountered at key {key})'
logger.error(err_msg + f'\nx1: {x1}\nx2: {x2}')
raise ValueError(err_msg)
for i, v2 in enumerate(x2):
x1[i] = apply_factory_merges(x1[i], v2, key=key + (i,))
elif isinstance(x1, list) and isinstance(x2, dict):
for k, v2 in x2.items():
if not isinstance(k, int):
raise ValueError(
f'Invalid dict key {k} non-integer type encountered in a list-dict merge at level {key}'
)
if k >= len(x1):
raise ValueError(
f'Dict key {k} out of bound for list of length {len(x1)} (encountered at level {key})'
)
x1[k] = apply_factory_merges(x1[k], v2, key=key + (k,))
else:
raise ValueError(
f'Duplicate non-dict and non-list values encountered: `{x1}` and `{x2} (at key {key})`'
)
return x1
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Helpers for defining sharding for optimizer states based on existing sharding for model parameters. """
import logging
from copy import deepcopy
from dataclasses import replace
from itertools import chain
from typing import Dict, Iterable, List, Tuple, Union
logger = logging.getLogger(__name__)
import torch
from .dict_utils import nested_values
from .mapping import (
LocalNonpersitentObject,
ShardedStateDict,
ShardedTensor,
ShardedTensorFactory,
StateDict,
)
from .utils import extract_sharded_tensors_and_factories
def get_optim_param_to_id_map(optim_params_iter: Iterable[torch.nn.Parameter]) -> Dict[int, int]:
param_mappings = {}
for i, param in enumerate(optim_params_iter):
if id(param) not in param_mappings:
param_mappings[id(param)] = i
return param_mappings
def get_param_id_to_sharded_param_map(
model_sharded_state_dict: ShardedStateDict, optim_params_iter: Iterable[torch.nn.Parameter]
) -> Dict[int, Union[ShardedTensor, ShardedTensorFactory]]:
""" Generate mapping from optimizer state ids to model sharded parameters.
Args:
model_sharded_state_dict: sharded state dict with all model sharded tensors (can have any structure)
optim_params_iter: iterable which iterates over model parameters tracked by the optimizer.
The iteration must be in the same order as in the optimizer parameters.
Returns:
Dict[int, Union[ShardedTensor, ShardedTensorFactory]]: mapping from optimizer state ids
to model sharded parameters.
"""
model_sharded_state_dict, _ = extract_sharded_tensors_and_factories(model_sharded_state_dict)
id_to_sharded_param_map = {}
param_to_id_map = get_optim_param_to_id_map(optim_params_iter)
for ten in nested_values(model_sharded_state_dict):
if id(ten.data) in param_to_id_map:
id_to_sharded_param_map[param_to_id_map[id(ten.data)]] = ten
else:
logger.debug(f'{ten} is not tracked by the optimizer')
if not id_to_sharded_param_map:
logger.warning(
"Sharded parameters mapping is empty. It means tensors in model state dict"
" do not correspond to tensors in optimizer parameters map."
" Make sure to call state_dict with `keep_vars=True`."
)
return id_to_sharded_param_map
def make_sharded_optimizer_tensor(
model_param: Union[ShardedTensor, ShardedTensorFactory], optim_param: torch.Tensor, prefix: str
) -> Union[ShardedTensor, ShardedTensorFactory]:
""" Build a ShardedTensor or ShardedTensorFactory for optimizer param based on model param
Args:
model_param (Union[ShardedTensor, ShardedTensorFactory]): model param
optim_param (torch.Tensor): corresponding optimizer param
prefix (str): optimizer prefix for the ShardedTensor or ShardedTensorFactory
Returns:
Union[ShardedTensor, ShardedTensorFactory]: wrapped optimizer parameter
"""
if isinstance(model_param, ShardedTensorFactory):
return replace(model_param, key=f'{prefix}.{model_param.key}', data=optim_param)
assert (
tuple(optim_param.shape) == model_param.local_shape
), f'Optimizer shape ({tuple(optim_param.shape)}) does not match model shape ({model_param.local_shape})'
sh_ten = replace(
model_param, key=f'{prefix}.{model_param.key}', data=optim_param, dtype=optim_param.dtype
)
sh_ten.validate_metadata_integrity()
return sh_ten
def optim_state_to_sharding_state(
optim_state_dict: StateDict,
id_to_sharded_param_map: Dict[int, ShardedTensor],
exclude_keys: Tuple[str] = (),
):
""" Turn optimizer state dict to sharded state dict based on model state dict *in-place*.
Can be used to add sharding information to most common optimizer state dict.
Creates separate ShardedTensors for each key in `optim_state_dict['state']`
(e.g. for torch.optim.Adam there will be separate tensors for `exp_avg` and `exp_avg_sq`)
Args:
optim_state_dict (StateDict): optimizer state dict with
state parameters under `state` key and group hyperparameters under `param_groups` -> `params` key.
id_to_sharded_param_map (Dict[int, ShardedTensor]): mapping from optimizer param ids to model sharded tensors.
Can be generated with `get_param_id_to_sharded_param_map` function
exclude_keys (Tuple[str]): optimizer state keys to exclude from the final state dict.
Returns:
None: state dict is modified in place
"""
sharded_state = {}
for param_id, param_state in optim_state_dict['state'].items():
sharded_state[param_id] = {}
for state_key, param in param_state.items():
if state_key in exclude_keys:
continue
if param_id in id_to_sharded_param_map:
sharded_state[param_id][state_key] = make_sharded_optimizer_tensor(
id_to_sharded_param_map[param_id], param, prefix=f'optimizer.state.{state_key}'
)
else:
raise ValueError(f'Param id {param_id} does not match any model sharded param')
optim_state_dict['param_groups'] = deepcopy(optim_state_dict['param_groups'])
for group in optim_state_dict['param_groups']:
group['params'] = LocalNonpersitentObject(group['params'])
optim_state_dict['state'] = sharded_state
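# Illustrative sketch (model and optimizer objects assumed; single param group;
# model_sharded_state_dict must wrap the live parameter tensors, e.g. built with
# keep_vars=True) of the typical call sequence:
#   >>> id_to_sharded_param_map = get_param_id_to_sharded_param_map(
#   ...     model_sharded_state_dict, optimizer.param_groups[0]['params']
#   ... )
#   >>> optim_state_dict = optimizer.state_dict()
#   >>> optim_state_to_sharding_state(optim_state_dict, id_to_sharded_param_map)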
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Entrypoints for saving and loading the distributed checkpoints.
Functions `load` and `save` are equivalents of `torch.load` and `torch.save`
but expect torch.Tensors to be wrapped with classes from the `mapping module`.
Additionally, `load` expects the sharded state dict argument as a guidance for loading the sharded tensors.
"""
import logging
import os
from collections import Counter, defaultdict
from itertools import chain
from pathlib import Path
from typing import Iterable, List, Optional, Tuple, Union
import numpy as np
import torch
from .core import CheckpointingConfig, maybe_load_config, save_config
from .dict_utils import (
dict_list_map_inplace,
diff,
extract_matching_values,
map_reduce,
merge,
nested_values,
)
from .mapping import (
CheckpointingException,
ShardedObject,
ShardedStateDict,
ShardedTensor,
ShardedTensorFactory,
StateDict,
apply_factories,
apply_factory_merges,
is_main_replica,
)
from .strategies.async_utils import AsyncRequest
from .strategies.base import (
AsyncSaveShardedStrategy,
LoadCommonStrategy,
LoadShardedStrategy,
SaveCommonStrategy,
SaveShardedStrategy,
StrategyAction,
get_default_strategy,
)
from .utils import (
extract_nonpersistent,
extract_sharded_base,
extract_sharded_tensors,
extract_sharded_tensors_or_nonpersistent,
)
COMMON_STATE_FNAME = 'common.pt'
logger = logging.getLogger(__name__)
def load(
sharded_state_dict: ShardedStateDict,
checkpoint_dir: str,
sharded_strategy: Union[LoadShardedStrategy, Tuple[str, int], None] = None,
common_strategy: Union[LoadCommonStrategy, Tuple[str, int], None] = None,
validate_access_integrity: bool = True,
) -> StateDict:
"""Loading entrypoint.
In the steps below, the following verbs refer to corresponding objects:
- load = load from checkpoint
- extract = extract from sharded_state_dict
- add = add to the final state dict
Steps:
1. Load common state dict and form the base of the result state dict
2. Apply factories to sharded_state_dict
3. Extract LocalNonPersistentObject and add
4. (optional) Extract ShardedObjects, load and add
5. Extract ShardedBase, load, apply factory merges and add
Args:
sharded_state_dict (ShardedStateDict): state dict of the existing model
populated with ShardedTensors. Used as a mapping to determine which
parts of global tensors stored in the checkpoint should be loaded.
checkpoint_dir (str): directory with the checkpoint
sharded_strategy (LoadShardedStrategy, Tuple[str, int], optional): configures loading behavior for sharded tensors
common_strategy (LoadCommonStrategy, Tuple[str, int], optional): configures loading behavior for common data
validate_access_integrity (bool default = True): checks if each tensor shard is accessed
exactly once (as main replica) by some process
"""
if common_strategy is not None:
raise NotImplementedError('The only supported common strategy is torch')
sharded_strategy = _verify_checkpoint_and_load_strategy(checkpoint_dir, sharded_strategy)
checkpoint_dir = Path(checkpoint_dir)
common_state_dict = load_common_state_dict(checkpoint_dir)
if not sharded_state_dict:
return common_state_dict
# Create a copy of sharded_state_dict as the passed in state dict may have
# references that prevent tensors from being deallocated
sharded_state_dict, _ = extract_matching_values(sharded_state_dict, lambda x: True)
sh_ten_factories, _ = extract_matching_values(
sharded_state_dict,
lambda x: isinstance(x, ShardedTensorFactory),
return_lists_as_dicts=True,
)
apply_factories(sharded_state_dict)
# Data inside sh_ten_factories no longer needed so delete them to reduce memory usage
def unlink_data(x):
x.data = None
return x
dict_list_map_inplace(unlink_data, sh_ten_factories)
# Non-persistent objects
nonpersistent_state_dict, sharded_state_dict = extract_nonpersistent(sharded_state_dict)
dict_list_map_inplace(lambda o: o.unwrap(), nonpersistent_state_dict)
merge(common_state_dict, nonpersistent_state_dict)
# Sharded base
if not sharded_strategy.can_handle_sharded_objects:
# TODO: implement as part of the common strategy
sharded_objects, sharded_state_dict = load_sharded_objects(
sharded_state_dict, checkpoint_dir
)
merge(common_state_dict, sharded_objects)
sharded_state_dict, _ = extract_sharded_base(sharded_state_dict)
if validate_access_integrity:
validate_sharding_integrity(nested_values(sharded_state_dict))
loaded_state_dict = sharded_strategy.load(sharded_state_dict, checkpoint_dir)
loaded_state_dict = apply_factory_merges(loaded_state_dict, sh_ten_factories)
merge(common_state_dict, loaded_state_dict)
return common_state_dict
def _verify_checkpoint_and_load_strategy(
checkpoint_dir: str, sharded_strategy: Union[LoadShardedStrategy, Tuple[str, int], None] = None,
) -> LoadShardedStrategy:
""" Verifies if checkpoint metadata exists and matches given strategy.
Args:
checkpoint_dir (str): checkpoint directory
sharded_strategy (LoadShardedStrategy, Tuple[str, int], optional): load strategy to be verified
if compatible with the checkpoint content. If None, the default load strategy
for the checkpoint backend will be returned.
"""
if not Path(checkpoint_dir).exists():
raise CheckpointingException(f'Checkpoint directory {checkpoint_dir} does not exist')
saved_config = maybe_load_config(checkpoint_dir)
if saved_config is None:
raise CheckpointingException(f'{checkpoint_dir} is not a distributed checkpoint')
if sharded_strategy is None:
sharded_strategy = get_default_strategy(
StrategyAction.LOAD_SHARDED,
saved_config.sharded_backend,
saved_config.sharded_backend_version,
)
elif isinstance(sharded_strategy, tuple):
sharded_strategy = get_default_strategy(StrategyAction.LOAD_SHARDED, *sharded_strategy)
# TODO: implement consistency checks here
return sharded_strategy
# TODO: implement it as common torch strategy
def load_common_state_dict(checkpoint_dir: Path) -> StateDict:
""" Load common (non-sharded) objects state dict from the checkpoint.
Args:
checkpoint_dir (Path): checkpoint directory
Returns:
StateDict: state dict with non-sharded objects from the checkpoint
"""
load_path = Path(checkpoint_dir) / COMMON_STATE_FNAME
try:
return torch.load(load_path, map_location='cpu')
except FileNotFoundError as e:
err_msg = f'Common file {load_path} does not exist'
ckpt_files = [f.name for f in checkpoint_dir.iterdir()]
logger.debug(f'{err_msg}. Checkpoint directory content: {ckpt_files}')
raise CheckpointingException(err_msg) from e
def load_sharded_objects(sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
""" Replaces all ShardedObject from a given state dict with values loaded from the checkpoint.
Args:
sharded_state_dict (ShardedStateDict): sharded state dict defining what objects should be loaded.
checkpoint_dir (Path): checkpoint directory
Returns:
        Tuple[StateDict, ShardedStateDict]: the extracted ShardedObjects (with data
            loaded from the checkpoint) and the remaining sharded state dict.
"""
sharded_objects, sharded_state_dict = extract_matching_values(
sharded_state_dict, lambda v: isinstance(v, ShardedObject)
)
def load_sharded_object(sh_obj: ShardedObject):
sh_obj.data = None
load_path = (checkpoint_dir / sh_obj.unique_key).with_suffix('.pt')
try:
loaded_obj = torch.load(load_path)
except FileNotFoundError as e:
err_msg = f'Object shard {load_path} not found'
obj_subdir = checkpoint_dir / sh_obj.key
if obj_subdir.exists():
obj_files = [f.name for f in obj_subdir.iterdir()]
logger.debug(f'{err_msg}. Object {sh_obj.key} directory content: {obj_files}')
else:
ckpt_files = [f.name for f in checkpoint_dir.iterdir()]
logger.debug(
f'{err_msg}. Object {sh_obj.key} directory does not exist. Checkpoint directory content: {ckpt_files}'
)
raise CheckpointingException(err_msg) from e
return loaded_obj
return dict_list_map_inplace(load_sharded_object, sharded_objects), sharded_state_dict
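

# Illustrative sketch only: the on-disk layout assumed by `load_sharded_objects`
# above (and produced by `_extract_and_save_sharded_objects` below) stores each
# main replica of a ShardedObject at `<checkpoint_dir>/<unique_key>.pt`. The
# unique_key below is a made-up example, not a real checkpoint entry.
def _example_sharded_object_path(checkpoint_dir: Path) -> Path:
    example_unique_key = 'optimizer/state/0.0'  # hypothetical ShardedObject unique_key
    return (checkpoint_dir / example_unique_key).with_suffix('.pt')
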
def load_tensors_metadata(
checkpoint_dir: str, sharded_strategy: Union[LoadShardedStrategy, None] = None
) -> ShardedStateDict:
"""Load tensors metadata from the checkpoint.
Returns a dictionary similar to a sharded state dict, but note that
the dictionary keys are simply ShardedTensor keys (contrary to the
actual sharded state dicts where keys correspond to state dict keys).
    Dict values are ShardedTensors without any sharding (so the only useful
    information is the tensors' global shape and dtype).
Concrete implementation depends on the loading strategy. If no strategy is
given, a default for a given backend is used.
"""
sharded_strategy = _verify_checkpoint_and_load_strategy(checkpoint_dir, sharded_strategy)
return sharded_strategy.load_tensors_metadata(Path(checkpoint_dir))
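

# Illustrative sketch only: inspecting a checkpoint without loading tensor data.
# Since the returned ShardedTensors carry no sharding, only the global shape and
# dtype of each entry are meaningful here.
def _example_print_checkpoint_metadata(checkpoint_dir: str) -> None:
    metadata = load_tensors_metadata(checkpoint_dir)
    for key, sh_ten in metadata.items():
        print(key, sh_ten.global_shape, sh_ten.dtype)
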
def load_plain_tensors(checkpoint_dir: str):
"""Load checkpoint tensors without any sharding.
NOTE: common state dict is NOT included."""
sharded_state_dict = load_tensors_metadata(checkpoint_dir)
# Don't validate integrity because shards will be overlapped
# if world_size > 1 (all processes load whole tensors)
return load(sharded_state_dict, checkpoint_dir, validate_access_integrity=False)
def save(
sharded_state_dict: ShardedStateDict,
checkpoint_dir: str,
sharded_strategy: Union[SaveShardedStrategy, Tuple[str, int], None] = None,
common_strategy: Union[SaveCommonStrategy, Tuple[str, int], None] = None,
validate_access_integrity: bool = True,
async_sharded_save: bool = False,
) -> Optional[AsyncRequest]:
"""Saving entrypoint.
Extracts ShardedTensors from the given state dict. Rank 0 saves the
"regular" part of the checkpoint to common torch file.
The ShardedTensors are saved according to a strategy specified by the
config.
Steps:
1. Apply factories
2. Extract and discard LocalNonPersistentObject
    3. Extract all ShardedBase objects
4. Save all other objects to common.pt
5. (optional) Extract and save ShardedObjects
6. Save all ShardedBase objects
7. Write metadata.json file with backend and version metadata.
    Step (6) can be performed asynchronously (see `async_sharded_save`); in that
    case the actual save is embodied in the returned async request and can be
    scheduled by the external caller. For an async request, step (7) is added as
    one of the finalization functions, so that metadata.json is written only
    once the checkpoint is complete.
Args:
        sharded_state_dict (ShardedStateDict): sharded state dict populated with
            ShardedTensors. Used as a mapping to determine how local tensors
            should be saved as global tensors in the checkpoint.
checkpoint_dir (str): directory to save the checkpoint to
sharded_strategy (SaveShardedStrategy, Tuple[str, int], optional): configures sharded tensors saving behavior and backend
common_strategy (SaveCommonStrategy, Tuple[str, int], optional): configures common data saving behavior and backend
        validate_access_integrity (bool, default = True): checks if each tensor shard is accessed
            exactly once (as the main replica) by some process
async_sharded_save (bool, optional): if True, for the sharded state dict part
an async save implementation will be called, with the AsyncRequest
            being returned to the caller. Note that it is the caller's responsibility to
actually schedule the async save. Defaults to False.
Returns:
AsyncRequest (optional): if `async_sharded_save` is True, returns
async request that should be scheduled by the caller of this function.
None otherwise.
"""
checkpoint_dir = Path(checkpoint_dir)
if torch.distributed.get_rank() == 0:
if not checkpoint_dir.exists():
raise CheckpointingException(
f'Checkpoint destination directory does not exist: {checkpoint_dir}'
)
if next(checkpoint_dir.iterdir(), None) is not None:
raise CheckpointingException(
f'Checkpoint destination directory ({checkpoint_dir}) is not empty'
)
if common_strategy is not None:
raise NotImplementedError('The only supported common strategy is torch')
if sharded_strategy is None:
sharded_strategy = get_default_save_sharded_strategy()
if not isinstance(sharded_strategy, SaveShardedStrategy):
assert isinstance(sharded_strategy, tuple), type(sharded_strategy)
sharded_strategy = get_default_strategy(StrategyAction.SAVE_SHARDED, *sharded_strategy)
apply_factories(sharded_state_dict)
_, sharded_state_dict = extract_nonpersistent(sharded_state_dict)
sharded_state_dict, state_dict = extract_sharded_base(sharded_state_dict)
_save_common_dict(state_dict, checkpoint_dir, True)
if validate_access_integrity:
validate_sharding_integrity(list(nested_values(sharded_state_dict)))
if not sharded_strategy.can_handle_sharded_objects:
        # TODO: implement this as part of the common strategy
sharded_state_dict = _extract_and_save_sharded_objects(
sharded_state_dict, checkpoint_dir, validate_access_integrity
)
def metadata_finalize_fn():
if torch.distributed.get_rank() == 0:
save_config(
CheckpointingConfig(sharded_strategy.backend, sharded_strategy.version),
checkpoint_dir,
)
torch.distributed.barrier()
if not async_sharded_save:
sharded_strategy.save(sharded_state_dict, checkpoint_dir)
metadata_finalize_fn()
return
if not isinstance(sharded_strategy, AsyncSaveShardedStrategy):
raise CheckpointingException(
f'Cannot apply async_save to non-async strategy {sharded_strategy}'
)
async_request = sharded_strategy.async_save(sharded_state_dict, checkpoint_dir)
async_request.finalize_fns.append(metadata_finalize_fn)
return async_request
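

# Illustrative usage sketches only (not part of this module's API): synchronous
# and asynchronous calls into the `save` entrypoint above. `sharded_sd` would
# typically come from e.g. a model's sharded_state_dict() method (an assumption
# here); scheduling the returned AsyncRequest is left to the caller.
def _example_save_sync(sharded_sd: ShardedStateDict, checkpoint_dir: str) -> None:
    # Blocking save using the default ('torch_dist', 1) sharded strategy.
    save(sharded_sd, checkpoint_dir)


def _example_save_async(sharded_sd: ShardedStateDict, checkpoint_dir: str) -> AsyncRequest:
    # The sharded part of the save is deferred; metadata.json is written by a
    # finalize fn only after the async save completes.
    return save(sharded_sd, checkpoint_dir, async_sharded_save=True)
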
def get_default_save_sharded_strategy(
backend: str = 'torch_dist', version: int = 1
) -> SaveShardedStrategy:
return get_default_strategy(StrategyAction.SAVE_SHARDED, backend, version)
def get_default_load_sharded_strategy(checkpoint_dir: str) -> LoadShardedStrategy:
return _verify_checkpoint_and_load_strategy(checkpoint_dir)
# TODO: implement this as a common torch strategy
def _save_common_dict(
state_dict: StateDict, checkpoint_dir: Path, validate_consistency: bool = False
):
if torch.distributed.get_rank() == 0:
torch.save(state_dict, checkpoint_dir / COMMON_STATE_FNAME)
if validate_consistency:
# TODO: implement checking consistency with rank 0 common dict on other ranks
pass
# torch.distributed.barrier()
# if not torch.distributed.get_rank() == 0:
# rank_0_state_dict = torch.load(checkpoint_dir / COMMON_STATE_FNAME)
# print(diff(common_state_dict, rank_0_state_dict))
def _extract_and_save_sharded_objects(
state_dict: StateDict, checkpoint_dir: Path, validate_consistency: bool = False
):
sharded_objects, state_dict = extract_matching_values(
state_dict, lambda v: isinstance(v, ShardedObject)
)
sharded_objects = list(nested_values(sharded_objects))
for sh_obj in sharded_objects:
if is_main_replica(sh_obj.replica_id):
save_path = (checkpoint_dir / sh_obj.unique_key).with_suffix('.pt')
os.makedirs(save_path.parent, exist_ok=True)
torch.save(sh_obj.data, save_path)
return state_dict
def validate_sharding_integrity(sharded_tensors: Iterable[ShardedTensor]):
""" Validate if the ShardedTensors from multiple processes define correct sharding of a global tensor.
Local ShardedTensors metadata is exchanged with `torch.distributed.all_gather_object`
and then process with global rank 0 checks if main replicas of the shards:
- cover the whole global tensors
- don't overlap
Args:
sharded_tensors (Iterable[ShardedTensor]): sharded tensors local to this process
Returns:
None
Raises:
        CheckpointingException: if the access pattern is invalid
"""
sharding = [ten.without_data() for ten in sharded_tensors]
all_sharding = [None] * torch.distributed.get_world_size()
torch.distributed.all_gather_object(all_sharding, sharding)
if torch.distributed.get_rank() != 0:
return
key_shardings = defaultdict(list)
for rank, rank_shardings in enumerate(all_sharding):
for sharding in rank_shardings:
key_shardings[sharding.key].append((rank, sharding))
for key, shardings in key_shardings.items():
if isinstance(shardings[0][1], ShardedObject):
_validate_objects_for_key(shardings)
else:
_validate_sharding_for_key(shardings)
def _validate_sharding_for_key(rank_sharding: List[Tuple[int, ShardedTensor]]):
some_rank_shard = rank_sharding[0][1]
global_shape = some_rank_shard.global_shape
local_shape = some_rank_shard.local_shape
dtype = some_rank_shard.dtype
has_flattened_range = some_rank_shard.flattened_range is not None
for rank, sharding in rank_sharding:
assert sharding.dtype == dtype, (sharding.dtype, dtype, some_rank_shard)
assert sharding.global_shape == global_shape, (
sharding.global_shape,
global_shape,
some_rank_shard,
)
assert sharding.local_shape == local_shape, (
sharding.local_shape,
local_shape,
some_rank_shard,
)
assert (sharding.flattened_range is not None) == has_flattened_range, (
(sharding.flattened_range is not None),
has_flattened_range,
some_rank_shard,
)
shard_access_cnt = _compute_shards_access(rank_sharding)
if has_flattened_range:
map_reduce(
rank_sharding,
lambda x: x[1].global_offset,
lambda x: x[1],
_validate_sharding_for_key_flattened,
)
else:
if not torch.all(shard_access_cnt == 1):
logger.error(f'Invalid access pattern for {rank_sharding[0][1]}: {shard_access_cnt}')
raise CheckpointingException(f'Invalid access pattern for {rank_sharding[0][1]}')
def _compute_shards_access(rank_sharding):
shard_access_cnt = torch.zeros(
rank_sharding[0][1].axis_fragmentations, dtype=torch.int, device='cpu'
)
for rank, sharding in rank_sharding:
if is_main_replica(sharding.replica_id):
shard_access_cnt[sharding.local_chunk_offset_in_global()] += 1
# TODO: consider validating different replicas too
return shard_access_cnt
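

# Illustrative worked example (never called by the library): for a 1-D tensor
# split into 4 chunks with one main replica per chunk, `axis_fragmentations`
# is (4,), each chunk offset is counted exactly once, and the resulting access
# counts [1, 1, 1, 1] satisfy the non-flattened check in `_validate_sharding_for_key`.
def _example_uniform_shard_access_counts() -> torch.Tensor:
    axis_fragmentations = (4,)  # hypothetical fragmentation of a 1-D global tensor
    shard_access_cnt = torch.zeros(axis_fragmentations, dtype=torch.int, device='cpu')
    for chunk_offset in [(0,), (1,), (2,), (3,)]:  # one main replica per chunk
        shard_access_cnt[chunk_offset] += 1
    return shard_access_cnt  # tensor([1, 1, 1, 1], dtype=torch.int32)
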
def _validate_sharding_for_key_flattened(tensors_by_shard):
all_slices = []
local_shape = tensors_by_shard[0].local_shape
for sharding in tensors_by_shard:
assert sharding.local_shape == local_shape
sharding: ShardedTensor
if not is_main_replica(sharding.replica_id):
# TODO: this checks only saving (and loading replica_id=0) consistency
continue
all_slices.append((sharding.flattened_range.start, sharding.flattened_range.stop))
starts, stops = map(np.asarray, zip(*sorted(all_slices)))
    if (
        starts[0] != 0
        or stops[-1] != np.prod(local_shape)
        or not np.all(starts[1:] == stops[:-1])
    ):
        logger.error(
            f'Flattened ranges do not cover the whole shard {tensors_by_shard[0]}. Ranges: {(starts, stops)}'
        )
        raise CheckpointingException(
            f'Flattened ranges do not cover the whole shard {tensors_by_shard[0]}'
        )
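

# Illustrative worked example (never called by the library): the coverage check
# above applied to made-up flattened ranges. A (2, 3) local shard has 6 elements;
# the ranges [0, 2), [2, 4), [4, 6) start at 0, end at prod(local_shape) and are
# contiguous, so together they cover the shard exactly once.
def _example_flattened_ranges_cover_shard() -> bool:
    local_shape = (2, 3)  # hypothetical local shard shape
    all_slices = [(0, 2), (2, 4), (4, 6)]  # hypothetical (start, stop) ranges
    starts, stops = map(np.asarray, zip(*sorted(all_slices)))
    return bool(
        starts[0] == 0
        and stops[-1] == np.prod(local_shape)
        and np.all(starts[1:] == stops[:-1])
    )
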
def _validate_objects_for_key(sharded_objects: List[Tuple[int, ShardedObject]]):
""" Ensure uniqueness of saved objects. """
unique_keys = [
sh_obj.unique_key for _, sh_obj in sharded_objects if is_main_replica(sh_obj.replica_id)
]
if len(unique_keys) != len(set(unique_keys)):
duplicates = {k: cnt for k, cnt in Counter(unique_keys).items() if cnt > 1}
logger.error(f'Duplicate ShardedObject keys and counts: {duplicates}')
raise CheckpointingException(f'Duplicate ShardedObject keys: {list(duplicates.keys())}')
expected_shard_num = np.prod(sharded_objects[0][1].global_shape)
if len(unique_keys) != expected_shard_num:
        err_msg = f'Invalid access pattern: {expected_shard_num - len(unique_keys)} ShardedObjects are missing.'
logger.error(f'{err_msg} Existing shards: {unique_keys}')
raise CheckpointingException(err_msg)
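

# Illustrative worked example (never called by the library): the duplicate-key
# detection above applied to made-up unique keys. 'sd/obj.0' appears twice, so
# it would be reported as a duplicate and raise a CheckpointingException.
def _example_duplicate_object_keys() -> dict:
    unique_keys = ['sd/obj.0', 'sd/obj.1', 'sd/obj.0']  # hypothetical unique keys
    return {k: cnt for k, cnt in Counter(unique_keys).items() if cnt > 1}
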
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Various loading and saving strategies """