Commit d520d24f authored by silencealiang's avatar silencealiang
Browse files

Merge branch 'main' into 'main'

megatron升级v0.10

See merge request OpenDAS/megatron-lm!3
parents 3aca1415 481609bb
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Multi-process & multi-node version of Faiss's index.add(). """Multi-process & multi-node version of Faiss's index.add().
...@@ -8,122 +8,159 @@ FaissBaseIndex. This allows 'add()' to scale out to very large datasets, since ...@@ -8,122 +8,159 @@ FaissBaseIndex. This allows 'add()' to scale out to very large datasets, since
the vast majority of the computational effort is embarrassingly parallel. the vast majority of the computational effort is embarrassingly parallel.
""" """
import numpy as np
import os import os
import psutil
import shutil import shutil
from typing import Tuple
import numpy as np
import psutil
import torch import torch
from tqdm import tqdm from tqdm import tqdm
from megatron import get_retro_args, print_rank_0 from megatron.core.datasets.retro.config import Embedder, RetroPreprocessingConfig
from tools.bert_embedding import BertEmbedder from megatron.core.datasets.retro.external_libs import faiss, h5py
from tools.bert_embedding.utils import get_missing_blocks_by_rank from megatron.core.datasets.retro.index.utils import get_added_code_paths, get_added_codes_dir
from tools.retro.external_libs import faiss, h5py from megatron.core.datasets.retro.utils import (
from tools.retro.index.utils import get_added_codes_dir, get_added_code_paths GPTToTextDataset,
get_blocks_by_rank,
log_retro_rank_0,
retro_makedir,
)
from .faiss_base import FaissBaseIndex from .faiss_base import FaissBaseIndex
class FaissParallelAddIndex(FaissBaseIndex): class FaissParallelAddIndex(FaissBaseIndex):
"""
def encode_block(self, index, embedder, text_dataset, block): This class parallelizes both 1) encoding vectors, and 2) adding codes to the
'''Encode sub-dataset block, to be later added to index. index. This class is more performant than naive use of Faiss, because most
of the computational work is in encoding the vectors, which is an
embarassingly parallel operation.
"""
def encode_block(
self, index: faiss.Index, embedder: Embedder, text_dataset: GPTToTextDataset, block: dict
) -> Tuple[np.ndarray, np.ndarray]:
"""Encode sub-dataset block, to be later added to index.
Encode the data subset, generally in blocks of 1M vectors each. For Encode the data subset, generally in blocks of 1M vectors each. For
each block, the empty/trained index is loaded, codes are computed each block, the empty/trained index is loaded, codes are computed
via index.sa_encode(), and the resulting codes are saved to disk. via index.sa_encode(), and the resulting codes are saved to disk.
'''
args = get_retro_args() Args:
index (faiss.Index): Faiss index object.
embedder (Embedder): Embedder used to embed text dataset.
text_dataset (GPTToTextDataset): Text dataset to be embedded and encoded.
block (dict): Range information specifying start/end indices within text dataset.
Returns:
A tuple of (embeddings, encodings) for the given block subset of the text dataset.
"""
# Embed block. # Embed block.
embeddings = self.embed_text_dataset_block( embeddings = self.embed_text_dataset_block(embedder, text_dataset, block["range"])
embedder,
text_dataset,
block["range"],
)
# Encode block. # Encode block.
print_rank_0("encode.") log_retro_rank_0("encode.")
codes = index.sa_encode(embeddings) codes = index.sa_encode(embeddings)
# Return embeddings for validation purposes.
return embeddings, codes
def save_block(self, config: RetroPreprocessingConfig, block: dict, codes: np.ndarray) -> None:
"""Save block of codes to disk.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
block (dict): Range information specifying the start/end indices within the encoded text dataset. Here, the 'path' item is used for writing the encodings to storage.
codes (np.ndarray): Block of encodings to be saved to storage.
"""
# Save neighbors. # Save neighbors.
print_rank_0("save codes.") log_retro_rank_0("save codes.")
os.makedirs(os.path.dirname(block["path"]), exist_ok=True) retro_makedir(config, os.path.dirname(block["path"]))
with h5py.File(block["path"], "w") as f: with h5py.File(block["path"], "w") as f:
f.create_dataset("data", data=codes) f.create_dataset("data", data=codes)
def encode(self, text_dataset): def encode(self, config: RetroPreprocessingConfig, text_dataset: GPTToTextDataset) -> None:
'''Encode text dataset, to be later added to index.''' """Encode text dataset, to be later added to index.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
text_dataset (GPTToTextDataset): Text dataset to be encoded by the index.
"""
args = get_retro_args() codes_dir = get_added_codes_dir(config)
codes_dir = get_added_codes_dir() retro_makedir(config, codes_dir)
# Index. # Index.
index = self.get_empty_index() index = self.get_empty_index(config)
# Bert embedder. # Bert embedder.
embedder = BertEmbedder(args.retro_bert_batch_size, embedder = config.retro_bert_embedders.mem
args.retro_bert_max_chunk_length,
args.bert_embedder_type)
# Missing code blocks. # Missing code blocks.
def validate(f): def validate(f: h5py.File) -> None:
"""Validation method for validating loaded encodings.
Args:
f (h5py.File): File that contains encodings.
"""
assert len(f["data"].shape) == 2 assert len(f["data"].shape) == 2
n_missing_blocks, missing_code_blocks = get_missing_blocks_by_rank(
codes_dir, blocks = get_blocks_by_rank(
len(text_dataset), codes_dir, len(text_dataset), config.retro_block_size, validate=validate
args.retro_block_size,
validate=validate,
) )
# Encode each block. # Encode each block.
for block_index, block in enumerate(missing_code_blocks): for block_index, block in enumerate(blocks.missing):
if block is not None: if block is not None:
# Progress. # Progress.
print_rank_0("encode block %d / %d ... %s." % ( log_retro_rank_0(
block_index, "encode block %d / %d ... %s."
len(missing_code_blocks), % (block_index, len(blocks.missing), block["path"])
block["path"], )
))
# Query block neighbors. # Encode and save.
self.encode_block(index, embedder, text_dataset, block) _, codes = self.encode_block(index, embedder, text_dataset, block)
self.save_block(config, block, codes)
# Synchronize progress across all ranks. (for easier observation) # Synchronize progress across all ranks. (for easier observation)
print_rank_0(" > waiting for other ranks to finish block.") log_retro_rank_0(" > waiting for other ranks to finish block.")
torch.distributed.barrier() torch.distributed.barrier()
def add_codes(self): def add_codes(self, config: RetroPreprocessingConfig) -> None:
"""Read codes from disk, and add them to the index.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
"""
if torch.distributed.get_rank() != 0: if torch.distributed.get_rank() != 0:
return return
added_index_path = self.get_added_index_path() added_index_path = self.get_added_index_path(config)
if os.path.exists(added_index_path): if os.path.exists(added_index_path):
return return
args = get_retro_args()
# Index. # Index.
print_rank_0("read empty index.") log_retro_rank_0("read empty index.")
index = self.get_empty_index() index = self.get_empty_index(config)
index_ivf = faiss.extract_index_ivf(index) index_ivf = faiss.extract_index_ivf(index)
# Add codes. # Add codes.
print_rank_0("add codes.") log_retro_rank_0("add codes.")
code_paths = get_added_code_paths() code_paths = get_added_code_paths(config)
pbar = tqdm(code_paths) pbar = tqdm(code_paths)
for code_path in pbar: for code_path in pbar:
pbar.set_description("add codes, mem %.3f gb, %.1f%%" % ( pbar.set_description(
psutil.virtual_memory()[3] / 1024**3, "add codes, mem %.3f gb, %.1f%%"
psutil.virtual_memory()[2], % (psutil.virtual_memory()[3] / 1024**3, psutil.virtual_memory()[2])
)) )
with h5py.File(code_path) as f: with h5py.File(code_path) as f:
nload = int(args.retro_index_add_load_fraction*f["data"].shape[0]) nload = int(config.retro_index_add_load_fraction * f["data"].shape[0])
offset = int(os.path.basename(code_path).split("-")[0]) offset = int(os.path.basename(code_path).split("-")[0])
xids = np.arange(offset, offset + nload) xids = np.arange(offset, offset + nload)
codes = np.copy(f["data"][:nload]) codes = np.copy(f["data"][:nload])
...@@ -133,30 +170,39 @@ class FaissParallelAddIndex(FaissBaseIndex): ...@@ -133,30 +170,39 @@ class FaissParallelAddIndex(FaissBaseIndex):
index.ntotal = index_ivf.ntotal index.ntotal = index_ivf.ntotal
# Write index. # Write index.
print_rank_0("write added index.") log_retro_rank_0("write added index.")
faiss.write_index(index, added_index_path) faiss.write_index(index, added_index_path)
def remove_codes(self): def remove_codes(self, config: RetroPreprocessingConfig) -> None:
'''Remove added codes after adding to index.''' """Remove added codes after adding to index.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
"""
if torch.distributed.get_rank() != 0: if torch.distributed.get_rank() != 0:
return return
assert os.path.isfile(self.get_added_index_path()) assert os.path.isfile(self.get_added_index_path(config))
args = get_retro_args() if config.retro_index_delete_added_codes:
if args.retro_index_delete_added_codes:
raise Exception("remove?") raise Exception("remove?")
shutil.rmtree(get_added_codes_dir(), ignore_errors=True) shutil.rmtree(get_added_codes_dir(config), ignore_errors=True)
def add(self, config: RetroPreprocessingConfig, text_dataset: GPTToTextDataset) -> None:
"""Add vectors to index.
def add(self, text_dataset): Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
text_dataset (GPTToTextDataset): Text dataset that will be embedded and added to the index.
"""
# Encode chunks. # Encode chunks.
self.encode(text_dataset) self.encode(config, text_dataset)
# Add codes to index. # Add codes to index.
self.add_codes() self.add_codes(config)
# Wait for (single-process) adding to complete. # Wait for (single-process) adding to complete.
torch.distributed.barrier() torch.distributed.barrier()
# Remove codes. # Remove codes.
self.remove_codes() self.remove_codes(config)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Utilities for building an index."""
import glob
import os
from typing import List, Tuple
from megatron.core.datasets.retro.config import RetroPreprocessingConfig
from megatron.core.datasets.retro.utils import retro_makedir
def get_index_dir(config: RetroPreprocessingConfig) -> str:
"""Create sub-directory for this index.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
Returns:
Path to index sub-directory within Retro project.
"""
# Directory path.
index_dir_path = os.path.join(
config.retro_project_dir, "index", config.retro_index_type, config.retro_index_str
)
# Make directory.
retro_makedir(config, index_dir_path)
return index_dir_path
def num_samples_to_block_ranges(
config: RetroPreprocessingConfig, num_samples: int
) -> List[Tuple[int, int]]:
"""Split a range (length num_samples) into sequence of block ranges
of size block_size.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
num_samples (int): Split `num_samples` into consecutive block ranges, where each block is size `config.retro_block_size`.
Returns:
A list of tuples where each item is the (start, end) index for a given block.
"""
block_size = config.retro_block_size
start_idxs = list(range(0, num_samples, block_size))
end_idxs = [min(num_samples, s + block_size) for s in start_idxs]
ranges = list(zip(start_idxs, end_idxs))
return ranges
def get_training_data_root_dir(config: RetroPreprocessingConfig) -> str:
"""Get root directory for embeddings (blocks and merged data).
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
Returns:
Path to the training data directory, which contains both training embedding blocks and the final merged training embeddings.
"""
return os.path.join(config.retro_project_dir, "index", "train_emb")
def get_training_data_block_dir(config: RetroPreprocessingConfig) -> str:
"""Get directory for of saved embedding blocks.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
Returns:
Path to the directory containing the training embedding blocks, which will be later merged into a single embedding array.
"""
return os.path.join(get_training_data_root_dir(config), "blocks")
def get_training_data_block_paths(config: RetroPreprocessingConfig) -> List[str]:
"""Get paths to saved embedding blocks.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
Returns:
Paths of all training embedding blocks.
"""
return sorted(glob.glob(get_training_data_block_dir(config) + "/*.hdf5"))
def get_training_data_merged_path(config: RetroPreprocessingConfig) -> str:
"""Get path to merged training embeddings.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
Returns:
Path to the merged training embedding binary file.
"""
return os.path.join(
get_training_data_root_dir(config),
"train_%.3f.bin" % config.retro_index_train_load_fraction,
)
def get_added_codes_dir(config: RetroPreprocessingConfig) -> str:
"""Get directory of saved encodings.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
Returns:
Path to the directory containing the vector encodings for adding to the index.
"""
return os.path.join(get_index_dir(config), "add_codes")
def get_added_code_paths(config: RetroPreprocessingConfig) -> List[str]:
"""Get paths to all saved encodings.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
Returns:
Paths of all vector encoding blocks, for adding to the index.
"""
return sorted(glob.glob(get_added_codes_dir(config) + "/*.hdf5"))
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Validate an index's data.
This module contains functionality for checking for bitwise equality across code
changes. The training and adding steps of index construction can be validated
separately. The following high-level checks are supported:
- Training: Validate that saved training embeddings are bitwise equal with a
sample set of freshly computed embeddings. (*Note*:
`--no-retro-index-delete-training-embeddings` must be used.)
- Adding: Validate that the saved encodings are bitwise equal with a sample of
sample set of freshly computed encodings. (*Note*:
`--no-retro-index-delete-added-codes` must be used.)
"""
import typing
import numpy as np
import torch
from torch.utils.data import Subset
from megatron.core.datasets.retro.config import RetroPreprocessingConfig
from megatron.core.datasets.retro.external_libs import h5py
from megatron.core.datasets.retro.utils import (
GPTToTextDataset,
get_blocks_by_rank,
log_retro_rank_0,
)
from .build import get_text_dataset_for_adding, get_text_dataset_for_training
from .factory import IndexFactory
from .utils import get_added_codes_dir, get_training_data_block_dir
##################################################
# Validate trained index.
##################################################
def validate_training_embeddings(config: RetroPreprocessingConfig) -> None:
"""Validate training embeddings.
Steps:
- Randomly sample subset of text dataset blocks.
- Embed each block.
- Compare against saved embeddings.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
"""
# Training text dataset.
text_dataset = get_text_dataset_for_training(config)
# Sample existing blocks.
blocks = get_blocks_by_rank(
dirname=get_training_data_block_dir(config),
n_samples=len(text_dataset),
block_size=config.retro_block_size,
validate=None,
sample=config.retro_task_validate,
)
assert blocks.n_missing_world == 0
# Embed & validate blocks.
embedder = config.retro_bert_embedders.mem
for block_idx, block in enumerate(blocks.existing):
# Missing block lists are extended with None to have equal-length
# lists. Skip the Nones.
if block is not None:
# Progress. (*note*: move world progress to here.)
log_retro_rank_0(
"embed training block %d / %d ... %s."
% (block_idx, len(blocks.existing), block["path"])
)
# Load existing block embeddings.
with h5py.File(block["path"]) as f:
existing_embeddings = np.copy(f["data"])
# Embed block.
sub_dataset = Subset(text_dataset, range(*block["range"]))
embeddings = embedder.embed_text_dataset(sub_dataset, "train")
# Check equality.
log_retro_rank_0(" > validate.")
assert np.array_equal(existing_embeddings, embeddings)
# Synchronize progress across all ranks. (for easier observation)
log_retro_rank_0(" > waiting for other ranks to finish block.")
torch.distributed.barrier()
log_retro_rank_0(" > finished validating training embeddings.")
##################################################
# Validate filled index.
##################################################
def validate_added_encodings(config: RetroPreprocessingConfig) -> None:
"""Validate added encodings.
Steps:
- Randomly sample subset of text dataset blocks.
- Encode each block.
- Compare against saved encodings.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
"""
# Index.
index = IndexFactory.get_index(config.retro_index_type)
inner_index = index.get_empty_index(config)
# Text dataset.
text_dataset = get_text_dataset_for_adding(config)
# Sample existing blocks.
def validate(f: h5py.File) -> None:
"""Validation method for validating encoding blocks.
Args:
f (h5py.File): File with block of encodings.
"""
assert len(f["data"].shape) == 2
blocks = get_blocks_by_rank(
dirname=get_added_codes_dir(config),
n_samples=len(text_dataset),
block_size=config.retro_block_size,
validate=validate,
sample=config.retro_task_validate,
)
assert blocks.n_missing_world == 0
# Encode and validate blocks.
embedder = config.retro_bert_embedders.mem
for block_idx, block in enumerate(blocks.existing):
if block is not None:
# Progress.
log_retro_rank_0(
"encode block %d / %d ... %s." % (block_idx, len(blocks.existing), block["path"])
)
# Load existing codes.
with h5py.File(block["path"]) as f:
existing_codes = np.copy(f["data"])
# Encode block.
embeddings, codes = index.encode_block(inner_index, embedder, text_dataset, block)
# Check equality.
log_retro_rank_0(" > validate.")
assert np.array_equal(existing_codes, codes)
# Synchronize progress across all ranks. (for easier observation)
log_retro_rank_0(" > waiting for other ranks to finish block.")
torch.distributed.barrier()
log_retro_rank_0(" > finished validating added encodings.")
##################################################
# Validate index (trained + filled).
##################################################
def validate_index(config: RetroPreprocessingConfig) -> None:
"""Validate index.
Validating index involves sequentially running stages above:
- Validate trained index.
- Validate filled index.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
"""
# Validate training embeddings.
validate_training_embeddings(config)
# Validate added codes.
validate_added_encodings(config)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""
A GPTChunkDataset is a wrapper around a regular GPTDataset, that sequentially
chunks the sample tokens into `retro_chunk_length` sized smaller samples.
For example, if the GPTDataset has 100 samples and a sequence length of 2048, and
retro_chunk_length is 64, then the GPTChunkDataset will contain 100*(2048/64) =
3200 samples, each with length 64.
"""
import torch
from megatron.core.datasets.gpt_dataset import GPTDataset
from megatron.core.datasets.retro.utils import get_num_chunks_per_sample
from .utils import get_neighbor_dir
class GPTChunkDataset(torch.utils.data.Dataset):
"""Pretraining chunk dataset wraps a standard GPT dataset.
This dataset conceptually divides each sample (e.g., length 2048)
into chunks (e.g., length 64) and restructures them into a list of
chunks (e.g., length num_samples * num_chunks_per_sample).
Args:
sample_dataset (GPTDataset): Original GPT dataset, with `sequence_length` size samples.
sample_length (int): Alias for `sequence_length`.
chunk_length (int): Retro chunk length (e.g., 64).
"""
def __init__(self, sample_dataset: GPTDataset, sample_length: int, chunk_length: int):
super().__init__()
self.sample_dataset = sample_dataset
self.chunk_length = chunk_length
self.n_chunks_per_sample = get_num_chunks_per_sample(sample_length, chunk_length)
self.n_samples = len(sample_dataset)
self.n_chunks = self.n_samples * self.n_chunks_per_sample
def __len__(self) -> int:
"""Get dataset length.
Returns:
Dataset length.
"""
return self.n_chunks
def __getitem__(self, idx: int) -> dict:
"""Get sample, including represented document IDs.
Args:
idx (int): Sample index.
Returns:
A sample, which contains both the chunk-length token sample ('text') along with all document_ids ('doc_ids') contained withing the full `sequence_length` sample.
"""
# Convert global chunk index to global sample index & local chunk index.
sample_idx = idx // self.n_chunks_per_sample
chunk_idx = idx % self.n_chunks_per_sample
# Extract sample data.
sample = self.sample_dataset[sample_idx]
sample_token_ids = sample["text"]
sample_doc_ids = sample["document_ids"]
# Chunk start/end token idxs.
token_start_idx = chunk_idx * self.chunk_length
token_end_idx = token_start_idx + self.chunk_length
chunk_token_ids = sample_token_ids[token_start_idx:token_end_idx]
# Sample.
return {"doc_ids": sample_doc_ids, "text": chunk_token_ids}
def build_gpt_chunk_datasets_from_gpt_datasets(
project_dir: str, gpt_datasets: dict, sample_length: int, chunk_length: int
) -> dict:
"""Get train, valid, test GPT chunk datasets.
Args:
project_dir (str): Retro project dir.
gpt_datasets (dict): Mapping of 'train', 'valid', and 'test' GPT datasets (original, unchunked datasets).
sample_length (int): Alias of `sequence_length`.
chunk_length (int): Retro chunk length (e.g., 64).
Returns:
A <dict> ?
"""
# GPT chunk datasets.
chunk_datasets = {
key: (
{
"dataset": GPTChunkDataset(sample_ds, sample_length, chunk_length),
"neighbor_dir": get_neighbor_dir(project_dir, key, sample_ds),
"num_active_chunks": num_active_samples
* get_num_chunks_per_sample(sample_length, chunk_length),
}
if sample_ds
else None
)
for key, (sample_ds, num_active_samples) in gpt_datasets.items()
}
return chunk_datasets
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""A MultiSplitGPTDataset can handle multiple intersecting split strings, as well
as returning all of the document IDs of a sample."""
import logging
from dataclasses import dataclass
from typing import Dict, List
import numpy
from megatron.core.datasets.blended_megatron_dataset_config import (
convert_split_vector_to_split_matrix,
parse_and_normalize_split,
)
from megatron.core.datasets.gpt_dataset import GPTDataset, GPTDatasetConfig
from megatron.core.datasets.indexed_dataset import IndexedDataset
from megatron.core.datasets.utils import Split
from megatron.core.utils import log_single_rank
logger = logging.getLogger(__name__)
@dataclass
class MultiSplitGPTDatasetConfig(GPTDatasetConfig):
"""Configuration object for Megatron Core blended and Retro datasets.
Args:
return_document_ids (bool): Whether to return the document ids when querying the dataset. Turn this option on during preprocessing.
split_preprocessing (str): The Retro preprocessing split string. It follows the same pattern convention as 'split'. Not to be used with 'blend_per_split'.
"""
return_document_ids: bool = None
split_preprocessing: str = None
def __post_init__(self) -> None:
"""Validate config attributes."""
super().__post_init__()
assert self.split is not None, "the Retro data pipeline does not support 'blend_per_split'"
assert self.return_document_ids is not None, "this attribute must be user defined"
assert self.split_preprocessing is not None, "this attribute must be user defined"
split_vector = parse_and_normalize_split(self.split)
split_preprocessing_vector = parse_and_normalize_split(self.split_preprocessing)
if not numpy.allclose(split_vector, split_preprocessing_vector):
self.split_matrix = convert_split_vector_to_split_matrix(
split_vector, split_preprocessing_vector
)
log_single_rank(
logger,
logging.WARNING,
f"split =/= split_preprocessing. Let split_matrix = {self.split_matrix}",
)
class MultiSplitGPTDataset(GPTDataset):
"""Retro's customized GPT dataset.
Args:
indexed_dataset (IndexedDataset): The IndexedDataset around which to build the MegatronDataset.
dataset_path (str): The real path on disk to the dataset, for bookkeeping.
indexed_indices (numpy.ndarray): The set of the documents indices to expose.
num_samples (int): The number of samples to draw from the indexed dataset.
index_split (Split): The indexed_indices Split.
config (MultiSplitGPTDatasetConfig): The Retro-specific container for all config sourced parameters.
"""
def __init__(
self,
indexed_dataset: IndexedDataset,
dataset_path: str,
indexed_indices: numpy.ndarray,
num_samples: int,
index_split: Split,
config: MultiSplitGPTDatasetConfig,
) -> None:
super().__init__(
indexed_dataset, dataset_path, indexed_indices, num_samples, index_split, config
)
def __getitem__(self, idx: int) -> Dict[str, numpy.ndarray]:
"""Get dataset sample.
Args:
idx (int): The index into the dataset.
Returns:
Dict[str, numpy.ndarray]: The text ids and (optionally) the document ids wrapped in a dictionary.
"""
text, document_ids = self._query_document_sample_shuffle_indices(idx)
if self.config.return_document_ids:
return {"text": text, "document_ids": document_ids}
else:
return {"text": text}
@staticmethod
def _key_config_attributes() -> List[str]:
"""Add custom attributes for building unique dataset hash.
The preprocessing split used for preprocessing will constrain the samples available for pretraining.
Returns:
List[str]: The key config attributes.
"""
return super(MultiSplitGPTDataset, MultiSplitGPTDataset)._key_config_attributes() + [
"split_preprocessing"
]
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Entry point for querying an index using a GPTChunkDataset.
Querying involves:
- Iterate all chunks in the GPTChunkDataset.
- Query index for neighbor chunk IDs (i.e., chunks from the chunk database).
- Save neighbor chunk IDs to disk, for use in building a RetroDataset sample
during pretraining.
"""
import os
import time
import typing
import numpy as np
import psutil
import torch
from tqdm import tqdm
from megatron.core.datasets.retro.config import RetroPreprocessingConfig
from megatron.core.datasets.retro.db.dataset import DBDataset
from megatron.core.datasets.retro.db.utils import (
get_merged_train_dataset as get_db_merged_train_dataset,
)
from megatron.core.datasets.retro.external_libs import faiss, h5py
from megatron.core.datasets.retro.index.factory import IndexFactory
from megatron.core.datasets.retro.index.index import Index
from megatron.core.datasets.retro.index.utils import get_index_dir
from megatron.core.datasets.retro.query.gpt_chunk_dataset import GPTChunkDataset
from megatron.core.datasets.retro.utils import (
GPTToTextDataset,
get_blocks_by_rank,
log_retro_rank_0,
retro_makedir,
)
from .gpt_chunk_dataset import build_gpt_chunk_datasets_from_gpt_datasets
def get_index(config: RetroPreprocessingConfig, ondisk: bool = False) -> faiss.Index:
"""Read index from disk.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
ondisk (bool): If `ondisk = True`, memory map the index. (For debugging purposes only; very non-performant.)
Returns:
A Faiss index, loaded from storage.
"""
# Load index.
index_wrapper = IndexFactory.get_index(config.retro_index_type)
index_dir = get_index_dir(config)
added_index_path = index_wrapper.get_added_index_path(config)
if ondisk:
index = faiss.read_index(added_index_path, faiss.IO_FLAG_MMAP)
else:
index = faiss.read_index(added_index_path)
# Search parameters.
faiss.ParameterSpace().set_index_parameter(index, "efSearch", config.retro_query_ef_search)
faiss.ParameterSpace().set_index_parameter(index, "nprobe", config.retro_query_nprobe)
return index
def embed_block(
config: RetroPreprocessingConfig, gpt_dataset: GPTChunkDataset, block: dict
) -> np.ndarray:
"""Embed block of chunks.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
gpt_dataset (GPTChunkDataset): Chunk dataset to be embedded.
block (dict): Range information containing start/end indices of subset of chunk dataset.
Returns:
Embeddings array, with shape (len(block["range"]), dimension(embedder)).
"""
text_block_dataset = torch.utils.data.Subset(
GPTToTextDataset(gpt_dataset, config.retro_tokenizers.gpt), range(*block["range"])
)
return config.retro_bert_embedders.mem.embed_text_dataset(text_block_dataset)
def query_embeddings(
config: RetroPreprocessingConfig,
db_dataset: DBDataset,
index: Index,
embeddings: np.ndarray,
chunk_id_range: range,
sample_map: dict,
n_chunks_per_sample: int,
verbose: bool = True,
) -> typing.Tuple[np.ndarray, np.ndarray]:
"""Query neighbors of a block of embeddings.
Querying includes:
- Query index for neighbor chunk IDs.
- Filter chunk IDs that have the same document ID as the queried embedding.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
db_dataset (DBDataset): Dataset containing chunk database entries.
index (Index): Vector index populated with chunk database indices.
embeddings (np.ndarray): Embeddings from GPT chunk dataset.
chunk_id_range (range): Chunk ID range from GPT chunk dataset.
sample_map (dict): Mapping of sample_idx to dataset_idx and document_ids. Used for document filtering.
n_chunks_per_sample (int): Number of chunks per sample (e.g., sequence_length / chunk_length).
verbose (bool): Log querying progress.
Returns:
A tuple of original (unfiltered) neighbor IDs, and filtered (by document ID) neighbor IDs.
"""
# Query neighbor ids.
if verbose:
log_retro_rank_0("search.")
t = time.time()
assert index.ntotal > 0, "check we don't accidentally have an empty index."
_, query_neighbor_ids = index.search(embeddings, config.retro_query_num_neighbors_query)
if verbose:
log_retro_rank_0(" time : %.3f sec." % (time.time() - t))
# Filter banned neighbor ids.
if verbose:
log_retro_rank_0("filter banned neighbor ids.")
filtered_neighbor_ids = np.full(
shape=(len(query_neighbor_ids), config.retro_query_num_neighbors_save),
fill_value=-1,
dtype="int64",
)
min_chunk_id, max_chunk_id = chunk_id_range
for chunk_id in range(min_chunk_id, max_chunk_id):
sample_id = chunk_id // n_chunks_per_sample
sample = sample_map[sample_id]
sample_dataset_idx = sample["dataset_idx"].item()
sample_doc_ids = sample["doc_ids"].tolist()
sample_doc_tuples = [(sample_dataset_idx, d) for d in sample_doc_ids]
# Get valid neighbors (!= -1).
query_row = [i for i in query_neighbor_ids[chunk_id - min_chunk_id] if i >= 0]
# Filter row.
filtered_row = [
i
for i in query_row
if tuple(db_dataset.doc_tuples[i].tolist()) not in sample_doc_tuples
]
filtered_row = filtered_row[: config.retro_query_num_neighbors_save]
filtered_row += [-1] * (config.retro_query_num_neighbors_save - len(filtered_row))
filtered_neighbor_ids[chunk_id - min_chunk_id] = filtered_row
return query_neighbor_ids, filtered_neighbor_ids
def query_embedding_block(
config: RetroPreprocessingConfig,
db_dataset: DBDataset,
index: Index,
embeddings: np.ndarray,
chunk_id_range: range,
sample_map: dict,
n_chunks_per_sample: int,
) -> typing.Tuple[np.ndarray, np.ndarray]:
"""Query a block of embeddings.
The block is broken into smaller sub-blocks, for easier tracking of progress.
Both the raw neighbor IDs and the filtered neighbor IDs (i.e., chunks with the
same document ID are removed) are collected.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
db_dataset (DBDataset): Dataset containing chunk database entries.
index (Index): Vector index populated with chunk database indices.
embeddings (np.ndarray): Embeddings from GPT chunk dataset.
chunk_id_range (range): Chunk ID range from GPT chunk dataset.
sample_map (dict): Mapping of sample_idx to dataset_idx and document_ids. Used for document filtering.
n_chunks_per_sample (int): Number of chunks per sample (e.g., sequence_length / chunk_length).
Returns:
A tuple of original (unfiltered) neighbor IDs, and filtered (by document ID) neighbor IDs.
"""
query_neighbor_ids = []
filtered_neighbor_ids = []
# Query in sub-blocks.
partial_block_size = 1000
for partial_start_idx in tqdm(
range(0, len(embeddings), partial_block_size),
" search",
miniters=(len(embeddings) // partial_block_size) // 10,
disable=torch.distributed.get_rank() != 0,
):
partial_end_idx = min(len(embeddings), partial_start_idx + partial_block_size)
partial_embeddings = embeddings[partial_start_idx:partial_end_idx]
partial_chunk_id_range = (
chunk_id_range[0] + partial_start_idx,
chunk_id_range[0] + partial_end_idx,
)
partial_query_neighbor_ids, partial_filtered_neighbor_ids = query_embeddings(
config,
db_dataset,
index,
partial_embeddings,
partial_chunk_id_range,
sample_map,
n_chunks_per_sample,
verbose=False,
)
query_neighbor_ids.append(partial_query_neighbor_ids)
filtered_neighbor_ids.append(partial_filtered_neighbor_ids)
# Concatenate.
query_neighbor_ids = np.concatenate(query_neighbor_ids, axis=0)
filtered_neighbor_ids = np.concatenate(filtered_neighbor_ids, axis=0)
return query_neighbor_ids, filtered_neighbor_ids
def query_block_neighbors(
config: RetroPreprocessingConfig,
db_dataset: DBDataset,
query_dataset: GPTChunkDataset,
index: Index,
block: dict,
) -> None:
"""Query neighbors of a dataset block (i.e., range).
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
db_dataset (DBDataset): Dataset containing chunk database entries.
query_dataset (GPTChunkDataset): GPT chunk dataset to be queried.
index (Index): Vector index populated with chunk database indices.
block (dict): Range information containing start/end indices for querying GPT chunk dataset.
"""
n_chunks_per_sample = query_dataset.n_chunks_per_sample
# Sample map.
sample_ids = sorted(
list(set(chunk_id // n_chunks_per_sample for chunk_id in range(*block["range"])))
)
sample_map = {}
for i in sample_ids:
sample = query_dataset.sample_dataset[i]
sample_map[i] = {"dataset_idx": sample["dataset_id"], "doc_ids": sample["document_ids"]}
# Embed block.
embeddings = embed_block(config, query_dataset, block)
# Query embeddings.
_, filtered_neighbor_ids = query_embedding_block(
config, db_dataset, index, embeddings, block["range"], sample_map, n_chunks_per_sample
)
if config.retro_task_validate is None:
# Save neighbors.
log_retro_rank_0("save neighbors.")
retro_makedir(config, os.path.dirname(block["path"]))
f = h5py.File(block["path"], "w")
f.create_dataset("neighbors", data=filtered_neighbor_ids)
f.close()
else:
# Validate neighbors.
with h5py.File(block["path"]) as f:
existing_neighbor_ids = np.copy(f["neighbors"])
assert np.array_equal(existing_neighbor_ids, filtered_neighbor_ids)
def query_dataset_neighbors(
config: RetroPreprocessingConfig,
db_dataset: DBDataset,
query_dataset: GPTChunkDataset,
num_active_chunks: int,
prefix: str,
neighbor_dir: str,
index: Index,
) -> None:
"""Query neighbors of each chunk within a dataset.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
db_dataset (DBDataset): Dataset containing chunk database entries.
query_dataset (GPTChunkDataset): GPT chunk dataset to be queried.
num_active_chunks (int): The 'active' chunks are the subset of the GPT chunk dataset that aren't being queried. This argument is used when validating the correctness of a subset of the GPT chunk dataset.
prefix (str): Extra string for logging progress.
neighbor_dir (str): File path to directory for saving neighbor IDs.
index (Index): Vector index populated with chunk database indices.
"""
def validate(f: h5py.File) -> None:
"""Validation method for validating saved neighbor IDs.
Args:
f (h5py.File): File containing save neighbor IDs.
"""
assert (
f["neighbors"].shape[1] == config.retro_query_num_neighbors_save
), "neighbors.shape == %s; num_neighbors_target == %d." % (
str(f["neighbors"].shape),
config.retro_num_neighbors_target,
)
if config.retro_task_validate is None:
retro_makedir(config, neighbor_dir)
blocks = get_blocks_by_rank(
neighbor_dir, num_active_chunks, config.retro_block_size, validate=validate
)
active_blocks = blocks.missing
else:
blocks = get_blocks_by_rank(
neighbor_dir,
num_active_chunks,
config.retro_block_size,
validate=validate,
sample=config.retro_task_validate,
)
assert blocks.n_missing_world == 0
active_blocks = blocks.existing
# Query each block.
for block_index, block in enumerate(active_blocks):
if block is not None:
# Progress.
log_retro_rank_0(
"%squery '%s' block %d / %d ... %s ... mem %.3f gb, %.1f%%."
% (
"" if config.retro_task_validate is None else "[validate] ",
prefix,
block_index,
len(active_blocks),
os.path.basename(block["path"]),
psutil.virtual_memory()[3] / 1024**3,
psutil.virtual_memory()[2],
)
)
# Query block neighbors.
query_block_neighbors(config, db_dataset, query_dataset, index, block)
# Synchronize progress across all ranks. (for easier observation)
log_retro_rank_0(" > waiting for other ranks to finish block.")
torch.distributed.barrier()
def query_neighbors(config: RetroPreprocessingConfig) -> None:
"""Query pretraining datasets (train & valid).
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
"""
# Num threads.
faiss.omp_set_num_threads(64)
# Load chunk db dataset.
log_retro_rank_0("load chunk db dataset.")
db_dataset = get_db_merged_train_dataset(
project_dir=config.retro_project_dir,
chunk_length=config.retro_gpt_chunk_length,
eod_token_id=config.retro_tokenizers.gpt.eod,
)
db_dataset.load_doc_tuples()
# Load index.
log_retro_rank_0(" > get index.")
index = get_index(config)
# Query each (i.e., train, valid, test) dataset.
log_retro_rank_0(" > query.")
for prefix, info in vars(config.retro_gpt_chunk_datasets).items():
if info is None:
continue
log_retro_rank_0(
" > query '%s' dataset ... %d samples." % (prefix, info["num_active_chunks"])
)
query_dataset_neighbors(
config,
db_dataset,
info["dataset"],
info["num_active_chunks"],
prefix,
info["neighbor_dir"],
index,
)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""
A RetroDataset wraps both:
- A GPTDataset (which is nested as GPTChunkDataset -> MultiSplitGPTDataset ->
GPTDataset).
- Neighbor IDs of chunks in the chunk database, that were saved during
preprocessing.
Both the GPT sample data and the neighbor IDs are returned within a sample from
this dataset.
"""
import os
from typing import Any, Dict, Optional, Tuple
import numpy as np
import torch
from megatron.core.datasets.retro.db.dataset import DBDataset
from megatron.core.datasets.retro.db.utils import get_merged_train_dataset as get_db_dataset
from megatron.core.datasets.retro.external_libs import h5py
from megatron.core.datasets.retro.utils import BlockPathMap, log_retro_rank_0
from megatron.core.models.retro import RetroConfig
from .gpt_chunk_dataset import GPTChunkDataset, build_gpt_chunk_datasets_from_gpt_datasets
from .utils import get_query_dir
class RetroDataset(torch.utils.data.Dataset):
"""Dataset of retro samples.
Each sample contains the original GPT sample, along with the token IDs
of each neighbor of each chunk within the sequence. Neighbor array has
shape (num_chunks_per_sample, num_neighbors, num_retrieved_tokens).
** Note: chunk dataset wraps original GPT dataset (see gpt_chunk_dataset.py).
Args:
num_queried_samples (int): Total number of queried samples.
num_neighbors (int): Total number of saved neighbors.
num_retrieved_chunks (int): Number of retrieved chunks (e.g., 2 for neighbor + continuation).
block_size (int): Number of neighbor entries per file.
db_dataset (DBDataset): Chunk database used for retrieval.
chunk_dataset (GPTChunkDataset): GPT chunk dataset, which is a wrapper around a standard GPT dataset that breaks each sample into chunks.
neighbor_path_map (BlockPathMap): Mapping of neighbor ID to file path.
"""
def __init__(
self,
num_queried_samples: int,
num_neighbors: int,
num_retrieved_chunks: int,
block_size: int,
db_dataset: DBDataset,
chunk_dataset: GPTChunkDataset,
neighbor_path_map: BlockPathMap,
):
super().__init__()
self.num_queried_samples = num_queried_samples
self.num_neighbors = num_neighbors
self.num_retrieved_chunks = num_retrieved_chunks
self.block_size = block_size
self.db_dataset = db_dataset
self.chunk_dataset = chunk_dataset
self.neighbor_path_map = neighbor_path_map
def __len__(self) -> int:
"""Dataset length.
Returns:
Number of samples in dataset.
"""
return len(self.chunk_dataset.sample_dataset)
def __getitem__(self, sample_idx: int) -> dict:
"""Get dataset sample.
Args:
sample_idx (int): Index of sample in dataset.
Returns:
A dict consisting of GPT sample (attribute 'text') and corresponding neighbor chunk IDs ('neighbor_chunks', for indexing chunk database) and neighbor token IDs (corresponding chunk database GPT tokens).
"""
n_chunks_per_sample = self.chunk_dataset.n_chunks_per_sample
# Wrap sample idx around number of queried samples.
sample_idx = sample_idx % self.num_queried_samples
# Get standard sample.
sample = self.chunk_dataset.sample_dataset[sample_idx]
# Sample idx to chunk idxs.
chunk_idxs = list(
range(sample_idx * n_chunks_per_sample, (sample_idx + 1) * n_chunks_per_sample)
)
# Collect retrieved tokens.
all_retrieved_chunk_ids = []
all_retrieved_token_ids = []
for chunk_idx in chunk_idxs:
# Neighbor chunk ids.
neighbor_path = self.neighbor_path_map[chunk_idx]
with h5py.File(neighbor_path, "r") as f:
neighbor_chunk_ids = f["neighbors"][
chunk_idx % self.block_size, : self.num_neighbors
].tolist()
# Retrieved (neighbor + continuation) token ids.
retrieved_chunk_ids = []
retrieved_token_ids = []
for neighbor_chunk_id in neighbor_chunk_ids:
current_chunk_ids = [
i % len(self.db_dataset)
for i in range(neighbor_chunk_id, neighbor_chunk_id + self.num_retrieved_chunks)
]
current_token_ids = [self.db_dataset[ci]["text"] for ci in current_chunk_ids]
retrieved_chunk_ids.append(current_chunk_ids)
retrieved_token_ids.append(current_token_ids)
# Collect retrieved tokens.
all_retrieved_chunk_ids.append(retrieved_chunk_ids)
all_retrieved_token_ids.append(retrieved_token_ids)
# Reshape retrieved tokens.
all_retrieved_chunk_ids = np.array(all_retrieved_chunk_ids).reshape(
(n_chunks_per_sample, self.num_neighbors, -1)
)
all_retrieved_token_ids = np.array(all_retrieved_token_ids).reshape(
(n_chunks_per_sample, self.num_neighbors, -1)
)
# Sample.
sample: Dict[str, np.ndarray] = {
**sample,
"neighbor_chunks": all_retrieved_chunk_ids,
"neighbor_tokens": all_retrieved_token_ids,
}
return sample
def get_retro_datasets(
config: RetroConfig, gpt_datasets: dict, sample_length: int, eod_token_id: int
) -> Tuple[Optional[RetroDataset], Optional[RetroDataset], Optional[RetroDataset]]:
"""Get train, valid, test retro datasets.
Args:
config (RetroConfig): Retro preprocessing config.
gpt_datasets (dict): Mapping of data split key ('train', 'valid', or 'test') to the original sequence-length GPT dataset (i.e., not the chunk dataset).
sample_length (int): Alias to `sequence_length`.
eod_token_id (int): GPT EOD token ID.
Returns:
A tuple of 'train', 'valid', and 'test' `RetroDataset`s.
"""
# DB dataset.
db_dataset = get_db_dataset(
project_dir=config.retro_project_dir,
chunk_length=config.retro_chunk_length,
eod_token_id=eod_token_id,
)
# GPT chunk datasets.
chunk_ds_info_map = build_gpt_chunk_datasets_from_gpt_datasets(
project_dir=config.retro_project_dir,
gpt_datasets=gpt_datasets,
sample_length=sample_length,
chunk_length=config.retro_chunk_length,
)
# Retro datasets.
retro_dataset_map: Dict[str, Optional[RetroDataset]] = {}
query_dir = get_query_dir(config.retro_project_dir)
for data_key, chunk_ds_info in chunk_ds_info_map.items():
# Skip unused datasets.
if chunk_ds_info is None:
retro_dataset_map[data_key] = None
continue
# For consistency with preprocessing, the neighbor_dir is overwritten
# (from its setting in `build_gpt_chunk_datasets_from_gpt_datasets()`
# above). This is one piece -- along with setting data_path and
# train_samples from config.json -- of ensuring consistency between
# preprocessing and pretraining.
chunk_dataset = chunk_ds_info["dataset"]
chunk_ds_info["neighbor_dir"] = os.path.join(
query_dir, config.retro_neighbor_dirs[data_key]
)
neighbor_dir = chunk_ds_info["neighbor_dir"]
neighbor_path_map = BlockPathMap.from_dir(
dir=neighbor_dir, block_size=config.retro_block_size
)
# Verify num chunks.
n_active_chunks = chunk_ds_info["num_active_chunks"]
n_neighbor_chunks = neighbor_path_map.max_idx
if not os.path.isdir(neighbor_dir):
if torch.distributed.get_rank() == 0:
raise Exception(
"neighbor directory '%s' not found; please "
"compare --train-samples, --seq-length, --seed, "
"--eval-iters, and --eval-interval, with "
"retro preprocessing args." % neighbor_dir
)
torch.distributed.barrier()
exit()
if config.retro_verify_neighbor_count and n_active_chunks != n_neighbor_chunks:
if torch.distributed.get_rank() == 0:
log_retro_rank_0("neighbor_dir : %s" % neighbor_dir)
log_retro_rank_0("neighbor_path_map : %s" % neighbor_path_map)
raise Exception(
"num sampled chunks (%d) != num neighbor chunks "
"(%d); did you complete querying the entire "
"pretraining dataset?" % (n_active_chunks, n_neighbor_chunks)
)
torch.distributed.barrier()
exit()
# Retro dataset.
retro_dataset_map[data_key] = RetroDataset(
num_queried_samples=gpt_datasets[data_key][1],
num_neighbors=config.retro_num_neighbors,
num_retrieved_chunks=config.retro_num_retrieved_chunks,
block_size=config.retro_block_size,
db_dataset=db_dataset,
chunk_dataset=chunk_dataset,
neighbor_path_map=neighbor_path_map,
)
return (retro_dataset_map["train"], retro_dataset_map["valid"], retro_dataset_map["test"])
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Utilities for querying the pretraining dataset."""
import os
from megatron.core.datasets.megatron_dataset import MegatronDataset
def get_query_dir(project_dir: str) -> str:
"""Get root directory of all saved query data.
Args:
project_dir (str): Retro project dir.
Returns:
Path to query sub-directory in Retro project.
"""
return os.path.join(project_dir, "query")
def get_neighbor_dir(project_dir: str, key: str, dataset: MegatronDataset) -> str:
"""Get directory containing neighbor IDs for a dataset (i.e., train, valid, or test).
Args:
project_dir (str): Retro project dir.
key (str): Dataset split key; 'train', 'valid', or 'test'.
dataset (MegatronDataset): Dataset containing unique hash for finding corresponding neighbors.
Returns:
Path to directory containing this dataset's neighbors within Retro project.
"""
return os.path.join(
get_query_dir(project_dir), os.path.basename(f"{key}_{dataset.unique_description_hash}")
)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Utilities for Retro preprocessing."""
import glob
import logging
import os
from collections import defaultdict
from types import SimpleNamespace
from typing import Any, Callable, Dict, List, Optional
import numpy as np
import torch
from tqdm import tqdm
from megatron.core import parallel_state
from megatron.core.datasets.retro.config import RetroPreprocessingConfig
from megatron.core.datasets.retro.query.multi_split_gpt_dataset import (
MultiSplitGPTDataset,
MultiSplitGPTDatasetConfig,
)
from megatron.core.utils import log_single_rank
from .external_libs import h5py
logger = logging.getLogger(__name__)
def log_retro_rank_0(message: str) -> None:
"""Log on rank 0.
Args:
message (str): Message to log.
"""
log_single_rank(logger, logging.INFO, "[RETRO] " + message)
def retro_makedir(config: RetroPreprocessingConfig, path: str) -> None:
"""Make a directory, conditional on not being in validation mode.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
path (str): Path to directory.
"""
if config.retro_task_validate is None:
os.makedirs(path, exist_ok=True)
def extract_data_config(config: RetroPreprocessingConfig) -> MultiSplitGPTDatasetConfig:
"""Extract data config from dataset.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
Returns:
The config object used to build the dataset.
"""
return config.retro_gpt_chunk_datasets.train["dataset"].sample_dataset.config
def get_num_chunks_per_sample(sample_length: int, chunk_length: int) -> int:
"""Compute seq_length // chunk_length.
Args:
sample_length (int): Alias of `sequence_length`.
chunk_length (int): Retro chunk length (e.g., 64).
Returns:
Number of chunks per sample (i.e., `sequence_length` / `chunk_length`).
"""
assert sample_length % chunk_length == 0
return sample_length // chunk_length
class GPTToTextDataset(torch.utils.data.Dataset):
"""Dataset to convert GPT tokens to text.
Args:
gpt_dataset (MultiSplitGPTDataset): GPT dataset, which outputs GPT token samples.
gpt_tokenizer (Any): GPT tokenizer.
"""
def __init__(self, gpt_dataset: MultiSplitGPTDataset, gpt_tokenizer: Any):
super().__init__()
self.gpt_dataset = gpt_dataset
self.gpt_tokenizer = gpt_tokenizer
def __len__(self) -> int:
"""Dataset length.
Returns:
Number of samples in the dataset.
"""
return len(self.gpt_dataset)
def __getitem__(self, idx: int) -> dict:
"""Get dataset sample.
Args:
idx (int): Index of sample.
Returns:
A dict containing attribute 'text' of type string.
"""
gpt_token_ids = self.gpt_dataset[idx]["text"].tolist()
text = self.gpt_tokenizer.detokenize(gpt_token_ids)
return {"text": text}
def get_blocks(
dirname: str, n_samples: int, block_size: int, validate: Callable = None
) -> SimpleNamespace:
"""Divide range [0, num_samples) to sequence of block ranges.
This is a core method within the concept of block processing. The idea
is to divide a range (size n_samples) into a sequence of blocks. Each
block corresponds to a file within 'dirname' with name
'{start_idx}-{end_idx}.hdf5'. This method checks for the existence of
these files, and returns two lists, one for existing blocks and one for
missing blocks.
Args:
dirname (str): Path to directory containing block files.
n_samples (int): Ideal number of samples. The total number of saved block data is <=n_samples.
block_size (int): Max number of samples per block file (e.g., 100000).
validate (Callable): Method for validating each block file during load.
Returns:
A namespace consisting of 2 lists: existing blocks, and missing blocks. The total number of samples between the existing and missing blocks should equal n_samples above.
"""
assert os.path.isdir(dirname), "missing directory '%s.'" % dirname
# Block ranges.
block_start_idxs = list(range(0, n_samples, block_size))
block_end_idxs = [min(n_samples, i + block_size) for i in block_start_idxs]
block_ranges = list(zip(block_start_idxs, block_end_idxs))
# All block files (existing + missing).
n_digits = int(np.ceil(np.log(n_samples) / np.log(10)) + 1)
all_blocks = [
{
"range": r,
"path": os.path.join(
dirname, "%s-%s.hdf5" % tuple([str(i).zfill(n_digits) for i in r])
),
}
for r in block_ranges
]
all_block_path_set = set(block["path"] for block in all_blocks)
# Validate function.
validate = (lambda f: None) if validate is None else validate
# Delete corrupt files.
if torch.distributed.get_rank() == 0:
existing_block_paths = [
block["path"] for block in all_blocks if os.path.exists(block["path"])
]
for index, path in enumerate(tqdm(existing_block_paths, "validating block.")):
assert path in all_block_path_set, "unexpected filename, '%s'." % path
try:
f = h5py.File(path, "r")
except Exception:
os.remove(path)
continue
try:
validate(f)
except Exception:
os.remove(path)
finally:
f.close()
# Wait for files to be deleted.
torch.distributed.barrier()
# Collect blocks.
blocks = SimpleNamespace(
existing=[b for b in all_blocks if os.path.exists(b["path"])],
missing=[b for b in all_blocks if not os.path.exists(b["path"])],
)
return blocks
def get_blocks_by_rank(
dirname: str,
n_samples: int,
block_size: int,
validate: Callable = None,
sample: Optional[float] = None,
) -> SimpleNamespace:
"""Divide existing and missing blocks evenly across all ranks.
See 'get_blocks()' above for description. The returned lists of existing and
missing blocks are split evenly across ranks via interleaving. This way,
each rank has a roughly equal number of blocks to process for a
downstream operation.
Args:
dirname (str): Path to directory containing block files.
n_samples (int): Ideal number of samples. The total number of saved block data is <=n_samples.
block_size (int): Max number of samples per block file (e.g., 100000).
validate (Callable): Method for validating each block file during load.
sample (Optional[float]): If provided, sample a random subset of the blocks. Used for validating preprocessing correctness.
Returns:
A namespace consisting of 2 lists: existing blocks, and missing blocks. Each of these two lists is potentially a sub-sample of the total set of existing and missing blocks, depending on whether sampling is used. Additionally, the attributes n_existing_world and n_missing_world are the total number of existing and missing blocks, independent of samples. Therefore, (n_existing_world + n_missing_world) * block_size == n_samples.
"""
# Get world blocks.
blocks = get_blocks(dirname, n_samples, block_size, validate)
# This rank's existing and missing files.
data_parallel_rank = parallel_state.get_data_parallel_rank()
data_parallel_world_size = parallel_state.get_data_parallel_world_size()
rank_existing_blocks = blocks.existing[
data_parallel_rank : len(blocks.existing) : data_parallel_world_size
]
rank_missing_blocks = blocks.missing[
data_parallel_rank : len(blocks.missing) : data_parallel_world_size
]
# Extend rank's existing and missing blocks (with None) such that all ranks
# have equal length lists. This allows for easier tracking of global progress.
def get_world_max(n: int) -> int:
"""Get max value across ranks.
Args:
n (int): Value on this rank.
Returns:
Max value across all ranks.
"""
n_tensor = torch.cuda.LongTensor([n])
torch.distributed.all_reduce(n_tensor, op=torch.distributed.ReduceOp.MAX)
return n_tensor.item()
max_n_existing = get_world_max(len(rank_existing_blocks))
max_n_missing = get_world_max(len(rank_missing_blocks))
rank_existing_blocks += [None] * (max_n_existing - len(rank_existing_blocks))
rank_missing_blocks += [None] * (max_n_missing - len(rank_missing_blocks))
# Collect blocks.
blocks = SimpleNamespace(
n_existing_world=len(blocks.existing),
n_missing_world=len(blocks.missing),
existing=rank_existing_blocks,
missing=rank_missing_blocks,
)
if sample is not None:
# Sample existing and missing blocks evenly across all ranks. The
# returned lists of blocks are randomly sampled (without replacement)
# to yield `sample * len(blocks)` number of blocks.
# Randomly sample blocks.
def sample_blocks(_blocks: List[Optional[Dict]]) -> List[Optional[Dict]]:
"""Sample a random subset of all blocks.
Args:
_blocks (List[Optional[Dict]]): List of all blocks.
Returns:
A random subset of the blocks.
"""
n_blocks_sample = int(np.ceil(sample * len(_blocks)))
sampled_blocks: List[Optional[Dict]] = [b for b in _blocks if b is not None]
np.random.seed(None)
np.random.shuffle(sampled_blocks)
sampled_blocks = sampled_blocks[:n_blocks_sample]
sampled_blocks += [None] * (n_blocks_sample - len(sampled_blocks))
return sampled_blocks
blocks.existing = sample_blocks(blocks.existing)
blocks.missing = sample_blocks(blocks.missing)
return blocks
class BlockPathMap:
"""Map an index to its containing block path.
The common use for this class is to have a directory of files containing
blocks of processed data, of uniform block size (e.g., 100k samples per
file). Each file must follow a naming convention of 'startIdx-endIdx.[ext]',
where 'endIdx' minus 'startIdx' must equal the block size, with the possible
exception of the final block. Given an input index, this class maps the
index to the containing block file.
Args:
block_paths (List[str]): List of paths to saved block files.
block_size (int): Max number of samples per block file (e.g., 100000).
"""
@classmethod
def from_dir(cls, dir: str, block_size: int, ext: str = "hdf5") -> Any:
"""Get list of block files, and create map.
Args:
dir (str): Path to directory containing saved block files.
block_size (int): Max number of samples per block file (e.g., 100000).
ext (str): Block file extension (e.g., 'hdf5').
Returns:
A mapping of sample index to block file path.
"""
assert os.path.isdir(dir), f"directory not found, '{dir}'."
return cls(sorted(glob.glob(dir + f"/*.{ext}")), block_size)
def __init__(self, block_paths: List[str], block_size: int):
self.max_idx = 0
self.block_path_map = {}
for block_path in block_paths:
name = os.path.splitext(os.path.basename(block_path))[0]
start_idx, end_idx = [int(i) for i in name.split("-")]
self.block_path_map[start_idx] = block_path
self.max_idx = max(self.max_idx, end_idx)
self.block_size = block_size
def __str__(self) -> str:
"""Stringify the mapping.
Returns:
A string representation of this block path map.
"""
return "%d paths" % len(self.block_path_map)
def __getitem__(self, idx: int) -> str:
"""Get block path from index.
Args:
idx (int): Index of sample.
Returns:
The path to the block file containing the sample index.
"""
block_start_idx = self.block_size * (idx // self.block_size)
block_path = self.block_path_map[block_start_idx]
return block_path
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import os
from collections import deque
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union
import numpy
import torch
from packaging.version import Version as PkgVersion
from megatron.core.datasets.indexed_dataset import IndexedDataset
from megatron.core.datasets.masked_dataset import (
MaskedWordPieceDataset,
MaskedWordPieceDatasetConfig,
)
from megatron.core.datasets.utils import Split
from megatron.core.utils import get_te_version
@dataclass
class T5MaskedWordPieceDatasetConfig(MaskedWordPieceDatasetConfig):
"""Configuration object for Megatron Core T5 WordPiece datasets
NB: As a temporary holdover from Megatron-LM. The T5 tokenizer has an attribute which defines
a number of special sentinel tokens used during sampling. The assert in __post_init__ serves to
preserve compatibility with Megatron-LM until the T5 tokenizer is in Megatron Core.
"""
sequence_length_encoder: Optional[int] = field(init=False, default=None)
"""A sequence_length alias and the sequence length for the encoder"""
sequence_length_decoder: int = None
"""The sequence length for the decoder"""
def __post_init__(self) -> None:
"""Do asserts and set fields post init"""
super().__post_init__()
self.sequence_length_encoder = self.sequence_length
assert self.sequence_length_encoder is not None
assert self.sequence_length_decoder is not None
assert len(self.tokenizer.additional_special_tokens_ids) > 0
class T5MaskedWordPieceDataset(MaskedWordPieceDataset):
"""The T5 dataset that assumes WordPiece tokenization
Args:
indexed_dataset (IndexedDataset): The IndexedDataset around
which to build the MegatronDataset
dataset_path (str): The real path on disk to the dataset, for bookkeeping
indexed_indices (numpy.ndarray): The set of the documents indices to expose
num_samples (Optional[int]): The number of samples to draw from the indexed
dataset. When None, build as many samples as correspond to one epoch.
index_split (Split): The indexed_indices Split
config (T5MaskedWordPieceDatasetConfig): The config
"""
def __init__(
self,
indexed_dataset: IndexedDataset,
dataset_path: str,
indexed_indices: numpy.ndarray,
num_samples: Optional[int],
index_split: Split,
config: T5MaskedWordPieceDatasetConfig,
) -> None:
super().__init__(
indexed_dataset, dataset_path, indexed_indices, num_samples, index_split, config
)
self.token_lookup = list(self.config.tokenizer.inv_vocab.keys())
# Account for the single <bos> and single <eos> token ids
self.sample_index = self._build_sample_index(self.config.sequence_length - 2, 1)
@staticmethod
def _key_config_attributes() -> List[str]:
"""Inherited method implementation
Returns:
List[str]: The key config attributes
"""
return super(
T5MaskedWordPieceDataset, T5MaskedWordPieceDataset
)._key_config_attributes() + ["sequence_length_decoder"]
@staticmethod
def _build_b1ss_attention_mask(
source_block: torch.tensor, target_block: torch.tensor, make_history_mask: bool = False
) -> torch.tensor:
"""Build an attention-mask having shape (bs, 1, q_len, kv_len)
from source_block and target_block
Args:
source_block (torch.tensor): A 2-D array of tokens (bs, q_len)
target_block (torch.tensor): A 2-D array of tokens (bs, kv_len)
make_history_mask (bool): Whether to turn mask into causal mask
Returns:
torch.tensor: The 4-D attention mask (bs, 1, q_len, kv_len)
"""
batch_size = source_block.shape[0]
attention_mask = []
for i in range(batch_size):
source_sample = source_block[i]
target_sample = target_block[i]
mask = (target_sample[None, :] >= 1) * (source_sample[:, None] >= 1)
if make_history_mask:
arange = numpy.arange(source_sample.shape[0])
history_mask = arange[None,] <= arange[:, None]
history_mask = torch.tensor(history_mask).to(mask.device)
mask = mask * history_mask
mask = ~(mask) # flip True to False
attention_mask.append(mask)
attention_mask = torch.stack(attention_mask)
attention_mask = attention_mask.unsqueeze(1)
return attention_mask
@staticmethod
def config_attention_mask(
encoder_tokens: torch.tensor,
decoder_tokens: torch.tensor,
encoder_mask: torch.tensor,
decoder_mask: torch.tensor,
use_local: bool = False,
test_te_version: str = None,
) -> torch.tensor:
"""Config attention-mask for encoder_mask, decoder_mask, encoder_decoder_mask
conditioned on transformer-implementation (e.g. TE vs local), TE versions,
and TE backends
Args:
encoder_tokens (torch.tensor): A 2-D array of tokens (bs, kv_len)
decoder_tokens (torch.tensor): A 2-D array of tokens (bs, q_len)
encoder_mask (torch.tensor): A 2-D array of tokens (bs, kv_len)
decoder_mask (torch.tensor): A 2-D array of tokens (bs, q_len)
use_local (bool): Whether the current T5 model uses local (vs TE)
transformer implmentation
Returns:
Configured encoder_mask, decoder_mask, encoder_decoder_mask
torch.tensor: configured encoder attention mask
torch.tensor: configured decoder attention mask
torch.tensor: configured encoder-decoder attention mask
"""
# If using local transformer implementation (not transformer_engine):
# re-organize all attention masks, because local and transformer_engine
# backbones use different masks shapes. E.g.:
# (local: b1ss - transformer_engine: b11s)
if use_local:
encoder_mask = T5MaskedWordPieceDataset._build_b1ss_attention_mask(
encoder_tokens, encoder_tokens
)
decoder_mask = T5MaskedWordPieceDataset._build_b1ss_attention_mask(
decoder_tokens, decoder_tokens, make_history_mask=True
)
encoder_decoder_mask = T5MaskedWordPieceDataset._build_b1ss_attention_mask(
decoder_tokens, encoder_tokens
)
else:
# If using transformer_engine transformer implementation:
# 1. For TE version >= 1.10, across all 3 backends,
# The padding mask is configued as
# [bs, 1, 1, seq_len] for self-attention and
# ([bs, 1, 1, q_len], [bs, 1, 1, kv_len]) for cross-attention
# 2. For TE version >=1.7 and <1.10, when using Non-fused backend,
# The padding mask is configued as
# [bs, 1, q_len, kv_len] for both self-attention and for cross-attention
# 3. For TE version <1.7, only support Non-fused backend
# The padding mask is configued as
# [bs, 1, q_len, kv_len] for both self-attention and for cross-attention
# Process for Flash/Fused
encoder_mask = encoder_mask.unsqueeze(1).unsqueeze(1)
decoder_mask = decoder_mask.unsqueeze(1).unsqueeze(1)
encoder_decoder_mask = (decoder_mask, encoder_mask)
# set decoder_mask to None because decoder uses AttnMaskType.causal
decoder_mask = None
# get TE version, using test TE version if not None
if test_te_version is not None:
te_version = PkgVersion(test_te_version)
else:
te_version = get_te_version()
# Check for older TE version than 1.10, adjust attention mask accordingly
flash_attention_enabled = os.getenv('NVTE_FLASH_ATTN') == '1'
fused_attention_enabled = os.getenv('NVTE_FUSED_ATTN') == '1'
if (te_version < PkgVersion("1.10.0")) and (te_version >= PkgVersion("1.7.0")):
if not (flash_attention_enabled) and not (fused_attention_enabled):
encoder_mask = T5MaskedWordPieceDataset._build_b1ss_attention_mask(
encoder_tokens, encoder_tokens
)
encoder_decoder_mask = T5MaskedWordPieceDataset._build_b1ss_attention_mask(
decoder_tokens, encoder_tokens
)
else:
pass
elif te_version < PkgVersion("1.7.0"):
if not (flash_attention_enabled) and not (fused_attention_enabled):
encoder_mask = T5MaskedWordPieceDataset._build_b1ss_attention_mask(
encoder_tokens, encoder_tokens
)
encoder_decoder_mask = T5MaskedWordPieceDataset._build_b1ss_attention_mask(
decoder_tokens, encoder_tokens
)
else:
assert not flash_attention_enabled and not fused_attention_enabled, (
"Flash and fused attention is not supported with transformer "
"engine version < 1.7. Set NVTE_FLASH_ATTN=0 and NVTE_FUSED_ATTN=0"
"or upgrade transformer engine >= 1.7"
)
return encoder_mask, decoder_mask, encoder_decoder_mask
def __getitem__(self, idx: int) -> Dict[str, Union[int, numpy.ndarray]]:
"""Abstract method implementation
Args:
idx (int): The index into the dataset
Returns:
Dict[str, Union[int, numpy.ndarray]]: The
"""
idx_beg, idx_end, target_sequence_length = self.sample_index[idx]
sample = [self.dataset[i] for i in range(idx_beg, idx_end)]
numpy_random_state = numpy.random.RandomState(seed=(self.config.random_seed + idx) % 2**32)
assert target_sequence_length <= self.config.sequence_length
# Flatten the sample into a list of tokens
tokens = [token for sentence in sample for token in sentence]
# Truncate the list of tokens to a desired length
truncated = len(tokens) > target_sequence_length
tokens = tokens[:target_sequence_length]
# Masking
(tokens, _, _, _, masked_spans) = self._create_masked_lm_predictions(
tokens, target_sequence_length, numpy_random_state
)
# Prepare the encoder input and decoder input and output
sentinels = deque(self.config.tokenizer.additional_special_tokens_ids)
encoder_input = []
decoder_input = [self.config.tokenizer.bos]
decoder_output = []
idx_beg = 0
for indices, labels in masked_spans:
sentinel = sentinels.popleft()
# set the end index
idx_end = indices[0]
encoder_input.extend(tokens[idx_beg:idx_end])
encoder_input.append(sentinel)
decoder_input.append(sentinel)
decoder_input.extend(labels)
decoder_output.append(sentinel)
decoder_output.extend(labels)
# set the start index
idx_beg = indices[-1] + 1
encoder_input.extend(tokens[idx_beg:])
decoder_output.append(self.config.tokenizer.eos)
# Pad the sequences and convert to NumPy
length_toks_encoder = len(encoder_input)
length_toks_decoder = len(decoder_input)
length_pads_encoder = self.config.sequence_length_encoder - length_toks_encoder
length_pads_decoder = self.config.sequence_length_decoder - length_toks_decoder
assert length_pads_encoder >= 0
assert length_pads_decoder >= 0
encoder_input = numpy.array(encoder_input, dtype=numpy.int64)
encoder_input = numpy.pad(
encoder_input, (0, length_pads_encoder), constant_values=self.config.tokenizer.pad
)
decoder_input = numpy.array(decoder_input, dtype=numpy.int64)
decoder_input = numpy.pad(
decoder_input, (0, length_pads_decoder), constant_values=self.config.tokenizer.pad
)
# Create attention and history masks
mask_encoder = numpy.array([1] * length_toks_encoder + [0] * length_pads_encoder)
mask_decoder = numpy.array([1] * length_toks_decoder + [0] * length_pads_decoder)
mask_encoder_decoder = None
# Mask the labels
decoder_output = numpy.array(decoder_output, dtype=numpy.int64)
decoder_output = numpy.pad(decoder_output, (0, length_pads_decoder), constant_values=-1)
# Get the loss mask
loss_mask = numpy.zeros(self.config.sequence_length_decoder, dtype=numpy.int64)
loss_mask[:length_toks_decoder] = 1
return {
"text_enc": encoder_input,
"text_dec": decoder_input,
"labels": decoder_output,
"loss_mask": loss_mask,
"truncated": int(truncated),
"enc_mask": mask_encoder,
"dec_mask": mask_decoder,
}
def _get_token_mask(self, numpy_random_state: numpy.random.RandomState) -> int:
"""Abstract method implementation
100% of the time, replace the token id with mask token id.
Args:
numpy_random_state (RandomState): The NumPy random state
Returns:
int: The mask token id
"""
return self.config.tokenizer.mask
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import logging
from enum import Enum
from typing import List, Optional, Tuple
import numpy
import torch
from ..utils import log_single_rank
logger = logging.getLogger(__name__)
class Split(Enum):
train = 0
valid = 1
test = 2
def compile_helpers():
"""Compile C++ helper functions at runtime. Make sure this is invoked on a single process."""
import os
import subprocess
command = ["make", "-C", os.path.abspath(os.path.dirname(__file__))]
if subprocess.run(command).returncode != 0:
import sys
log_single_rank(logger, logging.ERROR, "Failed to compile the C++ dataset helper functions")
sys.exit(1)
def normalize(weights: List[float]) -> List[float]:
"""Do non-exponentiated normalization
Args:
weights (List[float]): The weights
Returns:
List[float]: The normalized weights
"""
w = numpy.array(weights, dtype=numpy.float64)
w_sum = numpy.sum(w)
w = (w / w_sum).tolist()
return w
def get_blend_from_list(
blend: Optional[List[str]],
) -> Optional[Tuple[List[str], Optional[List[float]]]]:
"""Get the megatron.core.datasets.blended_megatron_dataset_config.BlendedMegatronDatasetConfig blend from the blend list
Args:
blend (Optional[List[str]]): The blend list, which can be either (1) a list of prefixes, e.g. ["path/to/dataset_1_prefix", "path/to/dataset_2_prefix"], or (2) a flattened, zipped list of weights and prefixes, e.g. ["30", "path/to/dataset_1_prefix", "70", "path/to/dataset_2_prefix"]
Returns:
Optional[Tuple[List[str], Optional[List[float]]]]: The blend, consisting of a list of dataset prefixes and optionally a list of dataset weights, e.g. [["path/to/dataset_1_prefix", "path/to/dataset_2_prefix"], [30.0, 70.0]].
"""
if blend is None:
return None
if len(blend) % 2 == 1:
weight_per_dataset = None
raw_prefix_per_dataset = blend
else:
raw_weight_per_dataset, raw_prefix_per_dataset = zip(
*[(blend[i], blend[i + 1]) for i in range(0, len(blend), 2)]
)
weight_per_dataset = []
for rwpd in raw_weight_per_dataset:
try:
weight = float(rwpd)
except ValueError:
weight = None
weight_per_dataset.append(weight)
is_none = map(lambda _: _ is None, weight_per_dataset)
if any(is_none):
assert all(is_none)
weight_per_dataset = None
raw_prefix_per_dataset = blend
prefix_per_dataset = [rppd.strip() for rppd in raw_prefix_per_dataset]
return prefix_per_dataset, weight_per_dataset
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import os
from typing import Any, Dict, NamedTuple, Protocol, Tuple
import torch
try:
import boto3
import botocore.exceptions as exceptions
except ModuleNotFoundError:
pass
S3_PREFIX = "s3://"
class S3Config(NamedTuple):
"""Config when the data (.bin) file and the index (.idx) file are in S3
TODO: These parameters are few and can be consolidated with parameters specific to bin reader
classes - @jkamalu
Attributes:
path_to_idx_cache (str): The local directory where we will store the index (.idx) file
bin_chunk_nbytes (int): If the number of bytes is too small, then we send a request to S3 at each call of the `read` method in _S3BinReader, which is slow, because each request has a fixed cost independent of the size of the byte range requested. If the number of bytes is too large, then we only rarely have to send requests to S3, but it takes a lot of time to complete the request when we do, which can block training. We've found that 256 * 1024 * 1024 (i.e., 256 MiB) has worked well (though we have not put that much effort into tuning it), so we default to it.
"""
path_to_idx_cache: str
bin_chunk_nbytes: int = 256 * 1024 * 1024
class S3Client(Protocol):
"""The protocol which all s3 clients should abide by"""
def download_file(self, Bucket: str, Key: str, Filename: str) -> None: ...
def upload_file(self, Filename: str, Bucket: str, Key: str) -> None: ...
def head_object(self, Bucket: str, Key: str) -> Dict[str, Any]: ...
def get_object(self, Bucket: str, Key: str, Range: str) -> Dict[str, Any]: ...
def close(self) -> None: ...
def is_s3_path(path: str) -> bool:
"""Ascertain whether a path is in S3
Args:
path (str): The path
Returns:
bool: True if the path is in S3, False otherwise
"""
return path.startswith(S3_PREFIX)
def parse_s3_path(path: str) -> Tuple[str, str]:
"""Parses the given S3 path returning correspsonding bucket and key.
Args:
path (str): The S3 path
Returns:
Tuple[str, str]: A (bucket, key) tuple
"""
assert is_s3_path(path)
parts = path.replace(S3_PREFIX, "").split("/")
bucket = parts[0]
if len(parts) > 1:
key = "/".join(parts[1:])
assert S3_PREFIX + bucket + "/" + key == path
else:
key = ""
return bucket, key
def object_exists(client: S3Client, path: str) -> bool:
"""Ascertain whether the object at the given S3 path exists in S3
Args:
client (S3Client): The S3 client
path (str): The S3 path
Raises:
botocore.exceptions.ClientError: The error code is 404
Returns:
bool: True if the object exists in S3, False otherwise
"""
parsed_s3_path = parse_s3_path(path)
try:
response = client.head_object(bucket=parsed_s3_path[0], key=parsed_s3_path[1])
except exceptions.ClientError as e:
if e.response["Error"]["Code"] != "404":
raise e
return True
def _download_file(client: S3Client, s3_path: str, local_path: str) -> None:
"""Download the object at the given S3 path to the given local file system path
Args:
client (S3Client): The S3 client
s3_path (str): The S3 source path
local_path (str): The local destination path
"""
dirname = os.path.dirname(local_path)
os.makedirs(dirname, exist_ok=True)
parsed_s3_path = parse_s3_path(s3_path)
client.download_file(parsed_s3_path[0], parsed_s3_path[1], local_path)
def maybe_download_file(s3_path: str, local_path: str) -> None:
"""Download the object at the given S3 path to the given local file system path
In a distributed setting, downloading the S3 object proceeds in stages in order
to try to have the minimum number of processes download the object in order for
all the ranks to have access to the downloaded object.
Args:
s3_path (str): The S3 source path
local_path (str): The local destination path
"""
if torch.distributed.is_initialized():
rank = torch.distributed.get_rank()
local_rank = rank % torch.cuda.device_count()
else:
rank = 0
local_rank = 0
s3_client = boto3.client("s3")
if (not os.path.exists(local_path)) and (rank == 0):
_download_file(s3_client, s3_path, local_path)
if torch.distributed.is_initialized():
torch.distributed.barrier()
# If the `local_path` is in a file system that is not
# shared across all the ranks, then we assume it's in the
# host file system and each host needs to download the file.
if (not os.path.exists(local_path)) and (local_rank == 0):
_download_file(s3_client, s3_path, local_path)
if torch.distributed.is_initialized():
torch.distributed.barrier()
# If the `local_path` still does not exist, then we assume
# each rank is saving to a separate location.
if not os.path.exists(local_path):
_download_file(s3_client, s3_path, local_path)
if torch.distributed.is_initialized():
torch.distributed.barrier()
assert os.path.exists(local_path)
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
from .core import check_is_distributed_checkpoint from .core import check_is_distributed_checkpoint
from .mapping import LocalNonpersitentObject, ShardedTensor from .mapping import LocalNonpersistentObject, LocalNonpersitentObject, ShardedTensor
from .serialization import load, load_common_state_dict, save from .serialization import (
load,
load_common_state_dict,
load_plain_tensors,
load_tensors_metadata,
remove_sharded_tensors,
save,
)
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Module for managing distributed checkpoints metadata. """
import json import json
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from pathlib import Path from pathlib import Path
...@@ -9,12 +11,21 @@ CONFIG_FNAME = 'metadata.json' ...@@ -9,12 +11,21 @@ CONFIG_FNAME = 'metadata.json'
class CheckpointingException(Exception): class CheckpointingException(Exception):
"""Base checkpointing related exception"""
pass pass
@dataclass @dataclass
class CheckpointingConfig: class CheckpointingConfig:
""" Documents backends used in the checkpoint. """ """Documents backends used in the checkpoint.
Checkpoint config keeps track of formats used for storing the sharded tensors
(sharded_backend) and other objects (common_backend).
Note that versioning is not for the checkpoint content (which is application specific),
but for the checkpoint format itself.
"""
sharded_backend: str sharded_backend: str
sharded_backend_version: int = 1 sharded_backend_version: int = 1
...@@ -23,10 +34,26 @@ class CheckpointingConfig: ...@@ -23,10 +34,26 @@ class CheckpointingConfig:
def check_is_distributed_checkpoint(checkpoint_dir): def check_is_distributed_checkpoint(checkpoint_dir):
"""Checks if `metadata.json` exists in the checkpoint and is a valid config.
Args:
checkpoint_dir: checkpoint directory
Returns:
bool: True if `metadata.json` exists in the checkpoint and is a valid config.
"""
return maybe_load_config(checkpoint_dir) is not None return maybe_load_config(checkpoint_dir) is not None
def maybe_load_config(checkpoint_dir: str) -> Optional[CheckpointingConfig]: def maybe_load_config(checkpoint_dir: str) -> Optional[CheckpointingConfig]:
"""Returns checkpoint config if `checkpoint_dir` is a distributed checkpoint and None otherwise
Args:
checkpoint_dir: checkpoint directory
Returns:
CheckpointingConfig (optional): None if checkpoint is not a valid distributed checkpoint
"""
config_path = Path(checkpoint_dir, CONFIG_FNAME) config_path = Path(checkpoint_dir, CONFIG_FNAME)
if not config_path.exists(): if not config_path.exists():
return None return None
...@@ -36,6 +63,15 @@ def maybe_load_config(checkpoint_dir: str) -> Optional[CheckpointingConfig]: ...@@ -36,6 +63,15 @@ def maybe_load_config(checkpoint_dir: str) -> Optional[CheckpointingConfig]:
def save_config(config: CheckpointingConfig, checkpoint_dir: str): def save_config(config: CheckpointingConfig, checkpoint_dir: str):
"""Save given config to checkpoint directory.
Args:
config: checkpoint config
checkpoint_dir: checkpoint directory
Returns:
None
"""
config_path = Path(checkpoint_dir, CONFIG_FNAME) config_path = Path(checkpoint_dir, CONFIG_FNAME)
with config_path.open('w') as f: with config_path.open('w') as f:
json.dump(asdict(config), f) json.dump(asdict(config), f)
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Utilities for operating with dicts and lists. """ """ Utilities for operating with dicts and lists.
All functions in this module handle nesting of dicts and lists.
Other objects (e.g. tuples) are treated as atomic leaf types that cannot be traversed.
"""
from collections import defaultdict from collections import defaultdict
from typing import Any, Callable, Iterable, Optional, Tuple, Union from typing import Any, Callable, Dict, Iterable, List, Tuple, TypeVar, Union
import numpy as np
import torch import torch
U, V = TypeVar("U"), TypeVar("V")
def extract_matching_values( def extract_matching_values(
x: Union[dict, list], predicate: Callable x: Union[dict, list], predicate: Callable[[Any], bool], return_lists_as_dicts: bool = False
) -> Tuple[Union[dict, list], Union[dict, list]]: ) -> Tuple[Union[dict, list], Union[dict, list]]:
""" Return matching and nonmatching values. Keeps hierarchy. """ """Return matching and nonmatching values. Keeps hierarchy.
Args:
x (Union[dict, list]) : state dict to process. Top-level argument must be a dict or list
predicate (object -> bool): determines matching values
return_lists_as_dicts (bool): if True, matching lists will be turned
into dicts, with keys indicating the indices of original elements.
Useful for reconstructing the original hierarchy.
"""
def _set_elem(target, k, v):
if return_lists_as_dicts:
target[k] = v
else:
target.append(v)
if isinstance(x, dict): if isinstance(x, dict):
matching_vals = {} matching_vals = {}
nonmatching_vals = {} nonmatching_vals = {}
for k, v in x.items(): for k, v in x.items():
if isinstance(v, (list, dict)): if isinstance(v, (list, dict)):
match, nonmatch = extract_matching_values(v, predicate) match, nonmatch = extract_matching_values(v, predicate, return_lists_as_dicts)
if match: if match:
matching_vals[k] = match matching_vals[k] = match
if nonmatch or not v: if nonmatch or not v:
...@@ -26,25 +48,40 @@ def extract_matching_values( ...@@ -26,25 +48,40 @@ def extract_matching_values(
matching_vals[k] = v matching_vals[k] = v
else: else:
nonmatching_vals[k] = v nonmatching_vals[k] = v
else: elif isinstance(x, list): # type: ignore
assert isinstance(x, list) matching_vals = {} if return_lists_as_dicts else []
matching_vals = [] nonmatching_vals = {} if return_lists_as_dicts else []
nonmatching_vals = [] for ind, v in enumerate(x):
for v in x:
if isinstance(v, (list, dict)) and v: if isinstance(v, (list, dict)) and v:
match, nonmatch = extract_matching_values(v, predicate) match, nonmatch = extract_matching_values(v, predicate, return_lists_as_dicts)
if match: if match:
matching_vals.append(match) _set_elem(matching_vals, ind, match)
if nonmatch or not v: if nonmatch or not v:
nonmatching_vals.append(nonmatch) _set_elem(nonmatching_vals, ind, nonmatch)
elif predicate(v):
matching_vals.append(v)
else: else:
nonmatching_vals.append(v) target = matching_vals if predicate(v) else nonmatching_vals
_set_elem(target, ind, v)
else:
raise ValueError(f'Unexpected top-level object type: {type(x)}')
return matching_vals, nonmatching_vals return matching_vals, nonmatching_vals
def diff(x1: Any, x2: Any, prefix: Tuple = ()) -> Tuple[list, list, list]: def diff(x1: Any, x2: Any, prefix: Tuple = ()) -> Tuple[list, list, list]:
"""Recursive diff of dicts.
Args:
x1 (object): left dict
x2 (object): right dict
prefix (tuple): tracks recursive calls. Used for reporting differing keys.
Returns:
Tuple[list, list, list]: tuple of:
- only_left: Prefixes present only in left dict
- only_right: Prefixes present only in right dict
- mismatch: values present in both dicts but not equal across dicts.
For tensors equality of all elems is checked.
Each element is a tuple (prefix, type of left value, type of right value).
"""
mismatch = [] mismatch = []
if isinstance(x1, dict) and isinstance(x2, dict): if isinstance(x1, dict) and isinstance(x2, dict):
only_left = [prefix + (k,) for k in x1.keys() - x2.keys()] only_left = [prefix + (k,) for k in x1.keys() - x2.keys()]
...@@ -54,7 +91,8 @@ def diff(x1: Any, x2: Any, prefix: Tuple = ()) -> Tuple[list, list, list]: ...@@ -54,7 +91,8 @@ def diff(x1: Any, x2: Any, prefix: Tuple = ()) -> Tuple[list, list, list]:
only_left.extend(_left) only_left.extend(_left)
only_right.extend(_right) only_right.extend(_right)
mismatch.extend(_mismatch) mismatch.extend(_mismatch)
elif isinstance(x1, list) and isinstance(x2, list): elif isinstance(x1, list) or isinstance(x1, tuple) or isinstance(x1, np.ndarray):
assert type(x1) == type(x2)
only_left = list(range(len(x1) - 1, len(x2) - 1, -1)) only_left = list(range(len(x1) - 1, len(x2) - 1, -1))
only_right = list(range(len(x1) - 1, len(x2) - 1, -1)) only_right = list(range(len(x1) - 1, len(x2) - 1, -1))
for i, (v1, v2) in enumerate(zip(x1, x2)): for i, (v1, v2) in enumerate(zip(x1, x2)):
...@@ -66,7 +104,17 @@ def diff(x1: Any, x2: Any, prefix: Tuple = ()) -> Tuple[list, list, list]: ...@@ -66,7 +104,17 @@ def diff(x1: Any, x2: Any, prefix: Tuple = ()) -> Tuple[list, list, list]:
only_left = [] only_left = []
only_right = [] only_right = []
if isinstance(x1, torch.Tensor) and isinstance(x2, torch.Tensor): if isinstance(x1, torch.Tensor) and isinstance(x2, torch.Tensor):
_is_mismatch = not torch.all(x1 == x2) if x1.device != x2.device:
_is_mismatch = not torch.all(x1.cpu() == x2.cpu())
else:
_is_mismatch = not torch.all(x1 == x2)
# TODO: change with concrete type that has both replica_id and data attrs
elif hasattr(x1, 'replica_id') and hasattr(x2, 'replica_id'):
assert type(x1) == type(x2)
only_left, only_right, mismatch = diff(
x1.data, x2.data, prefix + (type(x1),)
) # type: ignore
_is_mismatch = False
else: else:
try: try:
_is_mismatch = bool(x1 != x2) _is_mismatch = bool(x1 != x2)
...@@ -79,22 +127,8 @@ def diff(x1: Any, x2: Any, prefix: Tuple = ()) -> Tuple[list, list, list]: ...@@ -79,22 +127,8 @@ def diff(x1: Any, x2: Any, prefix: Tuple = ()) -> Tuple[list, list, list]:
return only_left, only_right, mismatch return only_left, only_right, mismatch
def inspect_keys_types(d: dict, prefix: Tuple = (), indent: int = 4):
print_indent = lambda: print(' ' * indent * len(prefix), end='')
for k, v in d.items():
if isinstance(v, dict):
print_indent()
print(f'> {k}:')
inspect_keys_types(v, prefix + (k,), indent)
else:
print_indent()
if isinstance(v, torch.Tensor):
print(f'> {k}: {type(v)} of shape {v.shape}')
else:
print(f'> {k}: {type(v)}')
def inspect_types(x: Any, prefix: Tuple = (), indent: int = 4): def inspect_types(x: Any, prefix: Tuple = (), indent: int = 4):
"""Helper to print types of (nested) dict values."""
print_indent = lambda: print(' ' * indent * len(prefix), end='') print_indent = lambda: print(' ' * indent * len(prefix), end='')
if isinstance(x, dict): if isinstance(x, dict):
print() print()
...@@ -122,6 +156,7 @@ def inspect_types(x: Any, prefix: Tuple = (), indent: int = 4): ...@@ -122,6 +156,7 @@ def inspect_types(x: Any, prefix: Tuple = (), indent: int = 4):
def nested_values(x: Union[dict, list]): def nested_values(x: Union[dict, list]):
"""Returns iterator over (nested) values of a given dict or list."""
x_iter = x.values() if isinstance(x, dict) else x x_iter = x.values() if isinstance(x, dict) else x
for v in x_iter: for v in x_iter:
if isinstance(v, (dict, list)): if isinstance(v, (dict, list)):
...@@ -131,6 +166,7 @@ def nested_values(x: Union[dict, list]): ...@@ -131,6 +166,7 @@ def nested_values(x: Union[dict, list]):
def nested_items_iter(x: Union[dict, list]): def nested_items_iter(x: Union[dict, list]):
"""Returns iterator over (nested) tuples (container, key, value) of a given dict or list."""
x_iter = x.items() if isinstance(x, dict) else enumerate(x) x_iter = x.items() if isinstance(x, dict) else enumerate(x)
for k, v in x_iter: for k, v in x_iter:
if isinstance(v, (dict, list)): if isinstance(v, (dict, list)):
...@@ -140,16 +176,19 @@ def nested_items_iter(x: Union[dict, list]): ...@@ -140,16 +176,19 @@ def nested_items_iter(x: Union[dict, list]):
def dict_map(f: Callable, d: dict): def dict_map(f: Callable, d: dict):
"""`map` equivalent for dicts."""
for sub_d, k, v in nested_items_iter(d): for sub_d, k, v in nested_items_iter(d):
sub_d[k] = f(v) sub_d[k] = f(v)
def dict_map_with_key(f: Callable, d: dict): def dict_map_with_key(f: Callable, d: dict):
"""`map` equivalent for dicts with a function that accepts tuple (key, value)."""
for sub_d, k, v in nested_items_iter(d): for sub_d, k, v in nested_items_iter(d):
sub_d[k] = f(k, v) sub_d[k] = f(k, v)
def dict_list_map_inplace(f: Callable, x: Union[dict, list]): def dict_list_map_inplace(f: Callable[[U], V], x: Union[Dict, List, U]):
"""Maps dicts and lists *in-place* with a given function."""
if isinstance(x, dict): if isinstance(x, dict):
for k, v in x.items(): for k, v in x.items():
x[k] = dict_list_map_inplace(f, v) x[k] = dict_list_map_inplace(f, v)
...@@ -160,7 +199,8 @@ def dict_list_map_inplace(f: Callable, x: Union[dict, list]): ...@@ -160,7 +199,8 @@ def dict_list_map_inplace(f: Callable, x: Union[dict, list]):
return x return x
def dict_list_map_outplace(f: Callable, x: Union[dict, list]): def dict_list_map_outplace(f: Callable[[U], V], x: Union[Dict, List, U]) -> Union[Dict, List, V]:
"""Maps dicts and lists *out-of-place* with a given function."""
if isinstance(x, dict): if isinstance(x, dict):
return {k: dict_list_map_outplace(f, v) for k, v in x.items()} return {k: dict_list_map_outplace(f, v) for k, v in x.items()}
elif isinstance(x, list): elif isinstance(x, list):
...@@ -169,20 +209,27 @@ def dict_list_map_outplace(f: Callable, x: Union[dict, list]): ...@@ -169,20 +209,27 @@ def dict_list_map_outplace(f: Callable, x: Union[dict, list]):
return f(x) return f(x)
def merge(x1: dict, x2: dict): def merge(x1: Union[dict, list], x2: Union[dict, list], key: Tuple[Union[str, int], ...] = ()):
"""Merges dicts and lists recursively."""
if isinstance(x1, dict) and isinstance(x2, dict): if isinstance(x1, dict) and isinstance(x2, dict):
for k, v2 in x2.items(): for k, v2 in x2.items():
if k not in x1: if k not in x1:
x1[k] = v2 x1[k] = v2
else: else:
x1[k] = merge(x1[k], v2) x1[k] = merge(x1[k], v2, key=key + (k,))
elif isinstance(x1, list) and isinstance(x2, list): elif isinstance(x1, list) and isinstance(x2, list):
if len(x1) != len(x2): if len(x1) != len(x2):
raise ValueError('Cannot merge two lists with different lengths') raise ValueError(
f'Cannot merge two lists with different lengths ({len(x1)} and {len(x2)}, '
f'encountered at level {key})'
)
for i, v2 in enumerate(x2): for i, v2 in enumerate(x2):
x1[i] = merge(x1[i], v2) x1[i] = merge(x1[i], v2, key=key + (i,))
else: else:
raise ValueError(f'Duplicate non-dict and non-list values encountered: `{x1}` and `{x2}`') raise ValueError(
f'Duplicate non-dict and non-list values encountered: `{x1}` and `{x2}` '
f'(at level {key})'
)
return x1 return x1
...@@ -192,6 +239,7 @@ def map_reduce( ...@@ -192,6 +239,7 @@ def map_reduce(
value_fn: Callable = lambda x: x, value_fn: Callable = lambda x: x,
reduce_fn: Callable = lambda x: x, reduce_fn: Callable = lambda x: x,
) -> dict: ) -> dict:
"""Simple map-reduce implementation following `more_itertools.map_reduce` interface."""
res = defaultdict(list) res = defaultdict(list)
for x in xs: for x in xs:
res[key_fn(x)].append(value_fn(x)) res[key_fn(x)].append(value_fn(x))
......
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
"""Utilities for exchanging data between ranks."""
import logging
from collections import defaultdict
from functools import reduce
from itertools import zip_longest
from time import time
from typing import Dict, List, NamedTuple, Optional, Set, Tuple, TypeVar, cast
import numpy as np
import torch
from .core import CheckpointingException
from .dict_utils import nested_values
from .mapping import ShardedStateDict, ShardedTensor, is_main_replica
from .utils import _sharded_tensor_shard_id, _ShardId
# TODO: remove TE references once the TE bug is fixed
# Check if Transformer Engine has Float8Tensor class
HAVE_TE_FLOAT8TENSOR = False
try:
from transformer_engine.pytorch.float8_tensor import Float8Tensor
HAVE_TE_FLOAT8TENSOR = True
except (ImportError, ModuleNotFoundError):
# Float8Tensor not found
pass
def is_float8tensor(tensor: torch.Tensor) -> bool:
"""Check if a tensor is a Transformer Engine Float8Tensor"""
return HAVE_TE_FLOAT8TENSOR and isinstance(tensor, Float8Tensor)
logger = logging.getLogger(__name__)
class ShardDistribution(NamedTuple):
"""Represents a distribution of ShardedTensors.
Given distribution is valid only for a specific parallelization group,
which is implicit here (not referenced by this class).
Args:
main_rank_for_shard (Dict[_ShardId, int]): specifies which rank should hold
the main replica for a given shard
shards_in_this_group (Set[_ShardId]): which shards have a main replica
in this parallelization group
shard_to_metadata (Dict[_ShardId, ShardedTensor]): maps ShardedTensor
identifier to the original ShardedTensor
all_ranks_for_shard (Dict[_ShardId, List[int]]): specifies which ranks
need a given shard in a given parallelization group
"""
main_rank_for_shard: Dict[_ShardId, int]
shards_in_this_group: Set[_ShardId]
shard_to_metadata: Dict[_ShardId, ShardedTensor]
all_ranks_for_shard: Dict[_ShardId, List[int]]
def _shard_size(sh_ten: ShardedTensor):
"""Returns size in bytes of a given sharded tensor."""
if sh_ten.flattened_range is None:
numel = np.product(sh_ten.local_shape)
else:
numel = sh_ten.flattened_range.stop - sh_ten.flattened_range.start
return numel * torch._utils._element_size(sh_ten.dtype)
def _get_empty_tensor_for_exchange(
shard_id: _ShardId,
needed_shards: Dict[_ShardId, ShardedTensor],
unneeded_shards: Dict[_ShardId, ShardedTensor],
loaded_tensors: Dict[_ShardId, torch.Tensor],
) -> Tuple[torch.Tensor, Optional[torch.device]]:
"""Determines the empty tensor to use for exchange.
If shard_id is needed by this rank, it will be in the `unloaded_shards`.
Otherwise, the metadata for this tensor can be found in `shard_to_metadata`
Args:
shard_id (_ShardId): shard_id that will be exchanged
needed_shards (Dict[_ShardId, ShardedTensor]): mapping from shard ids
to metadata for shards needed by this rank
unneeded_shards (Dict[_ShardId, ShardedTensor]): mapping from shard ids
to metadata for shards that can be discarded after exchange
loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping where useful tensors
are placed in
Returns:
Tuple[torch.Tensor, Optional[torch.device]]: empty CUDA tensor to be exchanged,
and the device of the original state dict tensor (if there was any)
"""
local_unloaded_sh_ten = needed_shards.get(shard_id)
if local_unloaded_sh_ten is None:
orig_device = None # this tensor will be discarded anyway
sh_ten = unneeded_shards[shard_id]
if sh_ten.data is None:
sh_ten.init_data('cuda')
tensor = sh_ten.data
sh_ten.data = None # won't be used. free memory
else:
tensor = sh_ten.data
if tensor.device.type == 'cpu':
tensor = torch.empty_like(tensor, device='cuda')
else:
local_unloaded_sh_ten.init_data('cuda')
orig_device = local_unloaded_sh_ten.data.device
tensor = local_unloaded_sh_ten.data
if tensor.device.type == 'cpu':
tensor = torch.empty_like(tensor, device='cuda')
loaded_tensors[shard_id] = tensor
return tensor, orig_device
T = TypeVar('T')
def distribute_shards_to_ranks(
shard_to_ranks: Dict[T, List[int]], shard_to_size: Dict[T, int], num_ranks: int
) -> Dict[T, int]:
"""Computes uniform distribution of workload across ranks, based on sizes.
Currently, the assignment is greedy, based on:
1. Firstly, the coverage of each shard
(how many ranks the shard is available on; lower coverage is assigned first)
2. Secondly, the size of each shard (larger size is assigned first)
3. Finally, shard id for differentiation.
Third step is added because we rely on the fact that
the assignment is deterministic on all ranks.
Args:
shard_to_ranks (Dict[T, List[int]]): mapping of rank access to shards
shard_to_size (Dict[T, int]): sizes of each shard
num_ranks (int): number of ranks in the parallelization group
Returns (Dict[T, int]): assignment of shard to rank (which rank should do the work
to achieve maximal uniformity)
"""
shard_to_ranks = {k: tuple(v) for k, v in shard_to_ranks.items()}
shard_to_saving_rank = {}
rank_sizes = [(0, rank) for rank in range(num_ranks)]
# start from tensors of lowest coverage, then go by tensor size from largest (hence minus size)
for shard_id, shard_ranks in sorted(
shard_to_ranks.items(),
key=lambda sh_id_ranks: (
len(sh_id_ranks[1]),
-shard_to_size[sh_id_ranks[0]],
sh_id_ranks[0],
),
):
# assign greedily to the least occupied rank
size, rank = min((size, rank) for size, rank in rank_sizes if rank in shard_ranks)
shard_to_saving_rank[shard_id] = rank
rank_sizes[rank] = (size + shard_to_size[shard_id], rank)
logger.debug(f'distribute_shards_to_ranks distribution: {rank_sizes}')
return shard_to_saving_rank
def determine_main_replica_uniform_distribution(
sharded_state_dict: ShardedStateDict,
parallelization_group: torch.distributed.ProcessGroup,
ignore_groups: bool = False,
) -> Optional[ShardDistribution]:
"""Computes the save distribution.
Should be used in conjunction with `distribute_main_replicas_with_precomputed_distribution`
which applies the computed save distribution.
We rely on the fact that the assignment algorithm is deterministic on all ranks,
so there is no extra communication needed after metadata exchange.
Args:
sharded_state_dict (ShardedStateDict): state dict to compute the distribution of
parallelization_group (ProcessGroup): distribution will be computed
within this process group
ignore_groups (bool, optional): whether the distribution defines groups.
This option is primarily used during loading, as it ensures that all replicas,
including non-main ones, are loaded by this parallelization group
Defaults to False.
Returns (ShardDistribution, optional): distribution that can be used to apply the
parallelization. Returns None if the process_group is trivial (1 rank)
"""
group_size = torch.distributed.get_world_size(group=parallelization_group)
if group_size <= 1:
return
local_shards = list(
sh_base
for sh_base in nested_values(sharded_state_dict)
if isinstance(sh_base, ShardedTensor)
)
local_shards_no_data = [ten.without_data() for ten in local_shards]
all_shards = [None] * torch.distributed.get_world_size(group=parallelization_group)
torch.distributed.all_gather_object(
all_shards, local_shards_no_data, group=parallelization_group
)
shard_to_ranks = defaultdict(list)
shard_to_size = {}
shard_to_metadata = {}
shards_in_this_parallelization_group: Set[_ShardId] = set()
for rank, rank_shards in enumerate(all_shards):
for sh_ten in rank_shards:
shard_id = _sharded_tensor_shard_id(sh_ten)
shard_to_ranks[shard_id].append(rank)
if shard_id not in shard_to_size:
shard_to_size[shard_id] = _shard_size(sh_ten)
shard_to_metadata[shard_id] = sh_ten
if is_main_replica(sh_ten.replica_id) or ignore_groups:
shards_in_this_parallelization_group.add(shard_id)
shard_to_ranks = {
k: v for k, v in shard_to_ranks.items() if k in shards_in_this_parallelization_group
}
shard_to_saving_rank = distribute_shards_to_ranks(
shard_to_ranks, shard_to_size, len(all_shards)
)
return ShardDistribution(
shard_to_saving_rank,
shards_in_this_parallelization_group,
shard_to_metadata,
shard_to_ranks,
)
@torch.no_grad()
def exchange_loaded_tensors_gather_rounds(
loaded_tensors: Dict[_ShardId, torch.Tensor],
unloaded_shards: Dict[_ShardId, ShardedTensor],
shard_distribution: ShardDistribution = None,
parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
) -> Dict[_ShardId, torch.Tensor]:
"""Exchange the tensors loaded by different ranks with several all_gather calls.
Groups tensors by dtype, divide tensors that will be exchanged into rounds
and execute all_gather for tensors from each round.
Note: the loading is distributed across ranks based on total loaded size
in bytes, so there is no guarantee that number of rounds needed for each
rank will be similar, which might result in a lot of almost empty
all_gathers. The solution would be to group all tensors into a one
bytes tensor and do a single all_gather (with similarly sized messages).
Args:
loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor
shard ids to tensors already loaded by this rank.
unloaded_shards (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor
shard ids to ShardedTensors that aren't loaded yet.
shard_distribution (ShardDistribution): distribution of all shards
parallelization_group (ProcessGroup, optional): process group used for load
distribution. Tensors will be exchanged within this group
Returns:
Dict[_ShardId, torch.Tensor]: dictionary mapping shard ids to tensors
needed by this rank to load a given state dict. Includes
previously loaded tensors (from `loaded_tensors` input)
"""
main_rank_for_shard, _, shard_to_metadata, all_ranks_for_shard = shard_distribution
local_rank = torch.distributed.get_rank(group=parallelization_group)
all_loaded_tensors = dict(loaded_tensors)
# Group by dtype so that we all_gather tensors of the same dtype
for dtype in sorted(set(map(lambda sh_ten: sh_ten.dtype, shard_to_metadata.values())), key=str):
start = time()
# shards_by_rank maps rank to tensors loaded by this rank
shards_by_rank: List[List[torch.Tensor]] = [
[] for _ in range(torch.distributed.get_world_size(group=parallelization_group))
]
for shard_id, rank in main_rank_for_shard.items():
if len(all_ranks_for_shard[shard_id]) == 1:
assert all_ranks_for_shard[shard_id][0] == main_rank_for_shard[shard_id], (
f'When there is only 1 ranks that needs a given shard,'
f' it should be the loading rank.'
f' Got: needs [{all_ranks_for_shard[shard_id][0]}]'
f' vs loads [{main_rank_for_shard[shard_id]}]'
)
# Skipping the exchange since only the loading rank needs this tensor
# TODO: we can employ some optimizations even for `len(shard_to_ranks) > 1`
# case, e.g. P2P exchange. Currently handling this case saves most of the
# work though.
continue
if shard_to_metadata[shard_id].dtype == dtype:
shards_by_rank[rank].append(shard_id)
# Transpose `shards_by_rank` to form exchange rounds
shards_by_round = zip_longest(*shards_by_rank, fillvalue=None)
for round_idx, round_shard_ids in enumerate(shards_by_round):
round_tensors = []
orig_devices = {}
for rank, shard_id in enumerate(round_shard_ids):
if shard_id is None:
# if no more useful data, the given rank will exchange empty tensor
local_ten = torch.empty(0, dtype=dtype, device='cuda')
orig_device = None
else:
assert isinstance(shard_id, tuple), type(shard_id)
if rank == local_rank:
assert shard_id in all_loaded_tensors, (shard_id, all_loaded_tensors.keys())
orig_device = all_loaded_tensors[shard_id]
all_loaded_tensors[shard_id] = all_loaded_tensors[shard_id].cuda()
local_ten = all_loaded_tensors[shard_id]
else:
local_ten, orig_device = _get_empty_tensor_for_exchange(
shard_id, unloaded_shards, shard_to_metadata, all_loaded_tensors
)
# Because of a TE bug, we have to exchange a nominal dtype instead of FP8
# It's ok to keep the nominal dtype after exchange, because TE will handle
# this during state dict load.
# TODO: remove it once the bug is fixed
if is_float8tensor(local_ten):
local_ten = local_ten.from_float8()
all_loaded_tensors[shard_id] = local_ten
round_tensors.append(local_ten)
if orig_device is not None:
orig_devices[shard_id] = orig_device
torch.distributed.all_gather(
list(round_tensors),
round_tensors[local_rank],
group=parallelization_group,
async_op=False,
)
# Move tensors back to CPU if originally was on CPU
for shard_id, orig_device in orig_devices.items():
all_loaded_tensors[shard_id] = all_loaded_tensors[shard_id].to(orig_device)
del round_tensors # remove tensor references
end = time()
if torch.distributed.get_rank() == 0:
logger.debug(f'{dtype} exchange rounds all_gather schedule took {end - start}s')
return all_loaded_tensors
def exchange_loaded_tensors_gather_object(
loaded_tensors: Dict[_ShardId, torch.Tensor],
unloaded_shards: Dict[_ShardId, ShardedTensor],
shard_distribution: ShardDistribution,
parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
) -> Dict[_ShardId, torch.Tensor]:
"""Exchange the tensors loaded by different ranks with a simple all_gather_object call.
This version can be used for debugging purposes do to its simplistic
implementation. Shouldn't be used if performance is important.
Args:
loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor
shard ids to tensors already loaded by this rank.
unloaded_shards (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor
shard ids to ShardedTensors that aren't loaded yet.
shard_distribution (ShardDistribution): distribution of all shards
parallelization_group (ProcessGroup, optional): process group used for load
distribution. Tensors will be exchanged within this group
Returns:
Dict[_ShardId, torch.Tensor]: dictionary mapping shard ids to tensors
needed by this rank to load a given state dict. Includes
previously loaded tensors (from `loaded_tensors` input)
"""
all_loaded_tensors_list = [None] * torch.distributed.get_world_size(group=parallelization_group)
torch.distributed.all_gather_object(
all_loaded_tensors_list, loaded_tensors, group=parallelization_group
)
all_loaded_tensors_list = cast(List[Dict[_ShardId, torch.Tensor]], all_loaded_tensors_list)
all_loaded_tensors = reduce(lambda x, y: {**x, **y}, all_loaded_tensors_list)
# Error checks
if len(all_loaded_tensors) != sum(map(len, all_loaded_tensors_list)):
err_msg = 'Duplicate shard ids loaded by different ranks'
if torch.distributed.get_rank() == 0:
logger.error(
f'{err_msg}. Shards ids by rank:'
f' {[lt.keys() for lt in all_loaded_tensors_list]}'
)
raise CheckpointingException(err_msg)
return all_loaded_tensors
@torch.no_grad()
def exchange_loaded_tensors_broadcast(
loaded_tensors: Dict[_ShardId, torch.Tensor],
unloaded_shards: Dict[_ShardId, ShardedTensor],
shard_distribution: ShardDistribution,
parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
) -> Dict[_ShardId, torch.Tensor]:
"""Exchange the tensors loaded by different ranks by a series of broadcasts.
For each rank for each loaded tensor do a broadcast to the whole group.
A reasonable tradeoff in terms of performance and simplicity.
Args:
loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor
shard ids to tensors already loaded by this rank.
unloaded_shards (Dict[_ShardId, ShardedTensor]): mapping from ShardedTensor
shard ids to ShardedTensors that aren't loaded yet.
shard_distribution (ShardDistribution): distribution of all shards
parallelization_group (ProcessGroup, optional): process group used for load
distribution. Tensors will be exchanged within this group
Returns:
Dict[_ShardId, torch.Tensor]: dictionary mapping shard ids to tensors
needed by this rank to load a given state dict. Includes
previously loaded tensors (from `loaded_tensors` input)
"""
main_rank_for_shard, _, shard_to_metadata, all_ranks_for_shard = shard_distribution
local_rank = torch.distributed.get_rank(group=parallelization_group)
all_loaded_tensors = dict(loaded_tensors)
start = time()
for idx, (shard_id, rank) in enumerate(main_rank_for_shard.items()):
if len(all_ranks_for_shard[shard_id]) == 1:
assert all_ranks_for_shard[shard_id][0] == main_rank_for_shard[shard_id], (
f'When there is only 1 ranks that needs a given shard,'
f' it should be the loading rank.'
f'Got: needs [{all_ranks_for_shard[shard_id][0]}]'
f' vs loads [{main_rank_for_shard[shard_id]}]'
)
# Skipping the exchange since only the loading rank needs this tensor
# TODO: we can employ some optimizations even for `len(shard_to_ranks) > 1` case,
# e.g. P2P exchange. Currently handling this case saves most of the work though.
continue
if rank == local_rank:
assert shard_id in all_loaded_tensors, (shard_id, all_loaded_tensors.keys())
orig_device = all_loaded_tensors[shard_id].device
local_ten = all_loaded_tensors[shard_id].cuda()
else:
local_ten, orig_device = _get_empty_tensor_for_exchange(
shard_id, unloaded_shards, shard_to_metadata, all_loaded_tensors
)
# Because of a TE bug, we have to exchange a nominal dtype instead of FP8
# It's ok to keep the nominal dtype after exchange, because TE will handle
# this during state dict load.
# TODO: remove it once the bug is fixed
if is_float8tensor(local_ten):
local_ten = local_ten.from_float8()
all_loaded_tensors[shard_id] = local_ten
global_src_rank = (
rank
if parallelization_group == None
else torch.distributed.get_global_rank(parallelization_group, rank)
)
# We can do async_op=True only if there is no CPU-copy follow-up
torch.distributed.broadcast(
local_ten,
src=global_src_rank,
group=parallelization_group,
async_op=orig_device is None,
)
# Move tensor back to CPU if originally was on CPU
if orig_device is not None:
all_loaded_tensors[shard_id] = local_ten.to(orig_device)
del local_ten
end = time()
if torch.distributed.get_rank() == 0:
logger.debug(f'exchange broadcast schedule took {end - start}s')
return all_loaded_tensors
def exchange_by_distribution(
loaded_tensors: Dict[_ShardId, torch.Tensor],
unloaded_shards: Dict[_ShardId, ShardedTensor],
shard_distribution: ShardDistribution = None,
parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
exchange_algo='broadcast',
) -> Dict[_ShardId, torch.Tensor]:
"""Exchange tensors loaded by different ranks using the specified exchange_algo.
Args:
loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor
shard ids to tensors already loaded by this rank.
unloaded_shards (Dict[_ShardId, ShardedTensor]): mapping from ShardedTensor
shard ids to ShardedTensors that aren't loaded yet.
shard_distribution (ShardDistribution): distribution of all shards
parallelization_group (ProcessGroup, optional): process group used for load
distribution. Tensors will be exchanged within this group
exchange_algo (str): The algorithm used for performing exchanges.
Defaults to 'broadcast'.
Returns:
Dict[_ShardId, torch.Tensor]: dictionary mapping shard ids to tensors
needed by this rank to load a given state dict. Includes
previously loaded tensors (from `loaded_tensors` input)
"""
if exchange_algo == 'gather_object':
exchange_fn = exchange_loaded_tensors_gather_object
elif exchange_algo == 'gather_rounds':
exchange_fn = exchange_loaded_tensors_gather_rounds
elif exchange_algo == 'broadcast':
exchange_fn = exchange_loaded_tensors_broadcast
else:
raise NotImplementedError(f'Unrecognized gather algorithm: {exchange_algo}')
return exchange_fn(loaded_tensors, unloaded_shards, shard_distribution, parallelization_group)
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Core library classes. """ """ Core library classes for representing sharding of tensors and objects.
from dataclasses import dataclass, replace The main expected usage is wrapping torch.Tensors in state dicts with
ShardedTensor class (mostly with the ShardedTensor.from_rank_offsets classmethod).
"""
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, field, replace
from itertools import chain from itertools import chain
from typing import Any, Callable, Dict, Optional, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
from .core import CheckpointingException from .core import CheckpointingException
from .dict_utils import dict_list_map_inplace, dict_list_map_outplace from .dict_utils import dict_list_map_inplace
logger = logging.getLogger(__name__)
# These type definitions are just hints to differentiate a plain model state # These type definitions are just hints to differentiate a plain model state
# dict (StateDict) from a state dict with tensors replaced with ShardedTensors # dict (StateDict) from a state dict with tensors replaced with ShardedTensors
# (ShardedStateDict). # (ShardedStateDict).
StateDict = Dict[str, Any] StateDict = Dict[str, Any]
CommonStateDict = Dict[str, Any]
ShardedStateDict = Dict[str, Any] ShardedStateDict = Dict[str, Any]
ReplicaId = Union[int, Tuple[int, ...]] ReplicaId = Union[int, Tuple[int, ...]]
class ShardedBase(ABC):
"""Base class for ShardedTensor and ShardedStateDict."""
key: str
data: object
replica_id: ReplicaId
@abstractmethod
def validate_metadata_integrity(self):
"""Codifies the constraints on metadata attributes."""
@abstractmethod
def without_data(self) -> 'ShardedBase':
"""Returns a new ShardedBase instance with data=None."""
raise NotImplementedError
@dataclass @dataclass
class ShardedTensor: class ShardedTensor(ShardedBase):
"""Represents a mapping between a local tensor and a global tensor. """Represents a mapping between a local tensor and a global tensor.
Global tensor is assumed to consist of many local tensors distributed Global tensor is assumed to consist of many local tensors distributed
between different processes. between different processes.
Attributes: Args:
key: unique identifier of a global tensor key: unique identifier of a global tensor
data: local tensor data. Can be None only for consistency validation data: local tensor data. Can be None only for consistency validation
dtype: tensor dtype dtype: tensor dtype
local_shape: local tensor shape local_shape: local tensor shape
global_shape: global tensor shape global_shape: global tensor shape
global_offset: offset of a local tensor in a global tensor, specified global_offset: offset of a local tensor in a global tensor,
in number of tensor elements specified in number of tensor elements
axis_fragmentations: global tensor fragmentation of each axis axis_fragmentations: global tensor fragmentation of each axis
replica_id: indicates given local tensor's replication wrt. local replica_id: indicates given local tensor's replication wrt.
tensors in different processes local tensors in different processes
prepend_axis_num: number of axes prepended to the local tensor prepend_axis_num: number of axes prepended to the local tensor to
to reflect global tensor shape. reflect global tensor shape. The behavior is similar to
The behavior is similar to unsqueezing the local tensor. unsqueezing the local tensor.
allow_shape_mismatch: if True, during loading, the global shape of a allow_shape_mismatch: if True, during loading, the global shape of
stored tensor does not have to match the expected global shape. a stored tensor does not have to match the expected global shape.
Useful for representing tensors with flexible shape, e.g. padded. Useful for representing tensors with flexible shape,
flattened_range: specifies a slice that should be applied to a flattened e.g. padded.
tensor with `local_shape` in order to get the tensor stored as `data` flattened_range: specifies a slice that should be applied to a
flattened tensor with `local_shape` in order to get
the tensor stored as `data`
""" """
key: str key: str
data: Optional[torch.Tensor] data: Optional[torch.Tensor] = field(repr=False)
dtype: torch.dtype dtype: torch.dtype
local_shape: Tuple[int, ...] local_shape: Tuple[int, ...]
global_shape: Tuple[int, ...] global_shape: Tuple[int, ...]
...@@ -60,7 +88,68 @@ class ShardedTensor: ...@@ -60,7 +88,68 @@ class ShardedTensor:
allow_shape_mismatch: bool = False allow_shape_mismatch: bool = False
flattened_range: Optional[slice] = None flattened_range: Optional[slice] = None
def __post_init__(self):
self.validate_metadata_integrity()
def validate_metadata_integrity(self) -> None:
"""Codifies the constraints on metadata attributes.
Meeting those constraints is guaranteed when instantiating a ShardedTensor
class with `from_rank_offsets` or `from_rank_offsets_flat` constructors.
Returns:
None
"""
has_flattened_range = self.flattened_range is not None
if self.data is not None:
if self.data.dtype != self.dtype:
raise CheckpointingException(
f'Data dtype should match `dtype` attribute for {self}'
)
if not has_flattened_range and self.data.shape != self.local_shape:
raise CheckpointingException(
f'Data shape should match `local_shape` attribute for {self}'
)
if has_flattened_range:
if self.data.ndim != 1:
raise CheckpointingException(f'Data should be 1D for a flattened {self}')
real_data = self.data
try:
self.data = None
self.init_data(device='meta')
if self.data.shape != real_data.shape:
raise CheckpointingException(
f'Data shape doesnt match expected {self.data.shape} for {self}'
)
finally:
self.data = real_data
if len(self.global_shape) != len(self.global_offset):
raise CheckpointingException(
f'Global offset dimensions should be equal to global shape dimensions for {self}'
)
if len(self.local_shape) + self.prepend_axis_num != len(self.global_shape):
raise CheckpointingException(
f'Local shape together with `prepend_axis_num` dimensions should be '
f'equal to global shape dimensions for {self}'
)
for off, sh in zip(self.global_offset[self.prepend_axis_num :], self.local_shape):
if off % sh != 0:
raise CheckpointingException(
f'Global offset ({off}) must be divisible by local shape ({sh}) for {self}.'
)
if has_flattened_range and self.flattened_range.step is not None:
raise CheckpointingException(
f'`step` argument in the flattened range of a ShardedTensor is not supported.'
)
def global_slice(self) -> Tuple[Union[int, slice], ...]: def global_slice(self) -> Tuple[Union[int, slice], ...]:
"""
Returns a tuple of int and slice objects representing a slice of the
global tensor that this ShardedTensor corresponds to.
"""
assert len(self.global_offset) == len(self.local_shape) + self.prepend_axis_num assert len(self.global_offset) == len(self.local_shape) + self.prepend_axis_num
return tuple( return tuple(
chain( chain(
...@@ -75,6 +164,10 @@ class ShardedTensor: ...@@ -75,6 +164,10 @@ class ShardedTensor:
) )
def global_coordinates(self) -> Tuple[np.ndarray, ...]: def global_coordinates(self) -> Tuple[np.ndarray, ...]:
"""
Returns a tuple of np.ndarrays representing the coordinates of the global tensor
that this ShardedTensor corresponds to.
"""
if self.flattened_range is None: if self.flattened_range is None:
raise CheckpointingException( raise CheckpointingException(
f'`global_coordinates` is undefined for' f'`global_coordinates` is undefined for'
...@@ -93,6 +186,10 @@ class ShardedTensor: ...@@ -93,6 +186,10 @@ class ShardedTensor:
return global_coords return global_coords
def local_coordinates(self) -> Tuple[np.ndarray, ...]: def local_coordinates(self) -> Tuple[np.ndarray, ...]:
"""
Returns a tuple of np.ndarrays representing the coordinates of the local tensor
that this ShardedTensor corresponds to.
"""
if self.flattened_range is None: if self.flattened_range is None:
raise CheckpointingException( raise CheckpointingException(
f'`local_coordinates` is undefined for' f'`local_coordinates` is undefined for'
...@@ -104,12 +201,28 @@ class ShardedTensor: ...@@ -104,12 +201,28 @@ class ShardedTensor:
mask[self.flattened_range] = True mask[self.flattened_range] = True
return np.nonzero(mask.reshape(self.local_shape)) return np.nonzero(mask.reshape(self.local_shape))
def local_chunk_offset_in_global(self) -> Tuple[int, ...]:
"""Offset of a local chunk in a global array of chunks.
Returns:
Tuple[int, ...]: the offset of the whole local chunk in a global array of chunks.
"""
assert len(self.global_offset) == len(self.local_shape) + self.prepend_axis_num
chunk_offset = list(self.global_offset[: self.prepend_axis_num])
for off, sh in zip(self.global_offset[self.prepend_axis_num :], self.local_shape):
assert off % sh == 0, str(self)
chunk_offset.append(off // sh)
return tuple(chunk_offset)
def max_allowed_chunks(self) -> Tuple[int, ...]: def max_allowed_chunks(self) -> Tuple[int, ...]:
"""
Returns the maximum allowed chunks for this ShardedTensor.
"""
chunks = [] chunks = []
for axis_sh, axis_fragm in zip(self.global_shape, self.axis_fragmentations): for axis_sh, axis_fragm in zip(self.global_shape, self.axis_fragmentations):
if not self.allow_shape_mismatch and axis_sh % axis_fragm != 0: if not self.allow_shape_mismatch and axis_sh % axis_fragm != 0:
raise CheckpointingException( raise CheckpointingException(
f'Axis shape ({axis_sh}) not divisible' f' by axis fragmentation ({axis_fragm}' f'Axis shape ({axis_sh}) not divisible by axis fragmentation ({axis_fragm}'
) )
axis_chunk_size = axis_sh // axis_fragm axis_chunk_size = axis_sh // axis_fragm
chunks.append(axis_chunk_size) chunks.append(axis_chunk_size)
...@@ -126,35 +239,35 @@ class ShardedTensor: ...@@ -126,35 +239,35 @@ class ShardedTensor:
*rank_offsets: Tuple[int, int, int], *rank_offsets: Tuple[int, int, int],
replica_id: ReplicaId = 0, replica_id: ReplicaId = 0,
prepend_axis_num: int = 0, prepend_axis_num: int = 0,
allow_shape_mismatch: bool = False, flattened_range: None = None,
**init_kwargs,
): ):
"""Allows to construct the ShardedTensor given offset specified in process ranks. """Allows to construct the ShardedTensor given offset specified in process ranks.
Arguments:
key: unique key Args:
data: local tensor data key (str): unique key
rank_offsets: each tuple (axis, axis_rank_offset, axis_fragm) data (torch.Tensor): local tensor data
says that if global tensor is divided into `axis_fragm` rank_offsets (Tuple[int, int, int]): each tuple
fragment along `axis` axis, then local tensor data (axis, axis_rank_offset, axis_fragm) says that if
corresponds to the `axis_rank_offset` chunk. global tensor is divided into `axis_fragm` fragment along `axis`
replica_id: see ShardedTensor axis, then local tensor data corresponds to the `axis_rank_offset` chunk.
prepend_axis_num: see ShardedTensor replica_id (ReplicaId): see ShardedTensor
allow_shape_mismatch: see ShardedTensor prepend_axis_num (int): see ShardedTensor
flattened_range (None): must be None when using this constructor
init_kwargs: passed to ShardedTensor.__init__
""" """
if flattened_range is not None:
raise ValueError(
'Cannot instantiate a flat ShardedTensor with `from_rank_offsets` method.'
' Use `from_rank_offsets_flat` instead'
)
global_offset = [0] * (data.ndim + prepend_axis_num) global_offset = [0] * (data.ndim + prepend_axis_num)
global_shape = ([1] * prepend_axis_num) + list(data.shape) global_shape = ([1] * prepend_axis_num) + list(data.shape)
axis_fragmentations = [1] * (data.ndim + prepend_axis_num) axis_fragmentations = [1] * (data.ndim + prepend_axis_num)
_seen_axis = set() _seen_axis = set()
for axis, axis_rank_offset, axis_fragm in rank_offsets: for axis, axis_rank_offset, axis_fragm in rank_offsets:
assert axis >= 0 and axis_rank_offset >= 0 and axis_fragm >= 0, ( if axis < 0 or axis_rank_offset < 0 or axis_fragm < 1 or axis_rank_offset >= axis_fragm:
axis, raise CheckpointingException(f'Invalid rank offsets: {rank_offsets} for key {key}.')
axis_rank_offset,
axis_fragm,
)
assert (
axis_rank_offset < axis_fragm
), 'Rank offset must be lower than axis fragmentation'
if axis in _seen_axis:
raise CheckpointingException('Duplicated axis specified')
_seen_axis.add(axis) _seen_axis.add(axis)
local_axis_shape = 1 if axis < prepend_axis_num else data.shape[axis - prepend_axis_num] local_axis_shape = 1 if axis < prepend_axis_num else data.shape[axis - prepend_axis_num]
...@@ -172,23 +285,223 @@ class ShardedTensor: ...@@ -172,23 +285,223 @@ class ShardedTensor:
tuple(axis_fragmentations), tuple(axis_fragmentations),
replica_id, replica_id,
prepend_axis_num, prepend_axis_num,
allow_shape_mismatch, flattened_range=flattened_range,
**init_kwargs,
) )
def __str__(self): @classmethod
return f'{self.__class__.__name__}(key=\'{self.key}\')' def from_rank_offsets_flat(
cls,
key: str,
data: torch.Tensor,
non_flat_local_shape: Tuple[int, ...],
*args,
flattened_range: Optional[slice] = None,
**kwargs,
):
"""Allows to construct a *flattened* ShardedTensor given offset specified in process ranks.
Args:
key (str):
data (torch.Tensor): this should be a flattened data tensor
non_flat_local_shape (Tuple[int, ...]): expected local shape of a non-flat chunk
*args: passed unchanged to the `from_rank_offsets` constructor
flattened_range (slice): see ShardedTensor. Defaults to None, but must be set to
a non-None slice.
**kwargs:
Returns:
ShardedTensor: constructed ShardedTensor instance
"""
if flattened_range is None:
raise CheckpointingException(
'Cannot instantiate a non-flat ShardedTensor with `from_rank_offsets_flat` method.'
' Use `from_rank_offsets` instead'
)
if data.ndim != 1:
raise CheckpointingException(
f'Flattened ShardedTensor requires 1D data, got shape: {data.shape}'
)
if flattened_range.stop - flattened_range.start != data.numel():
raise CheckpointingException(
f'Flattened ShardedTensor data length ({data.numel()}) must meet the '
f'slice length: {flattened_range.stop - flattened_range.start}'
)
non_flat_data_meta = torch.empty(*non_flat_local_shape, dtype=data.dtype, device='meta')
sh_ten = cls.from_rank_offsets(key, non_flat_data_meta, *args, **kwargs)
instance = replace(sh_ten, data=data, flattened_range=flattened_range)
instance.validate_metadata_integrity()
return instance
def is_main_replica(replica_id): def init_data(self, device: Union[str, torch.device], init_fn=torch.empty):
"""
Initialize the tensor data of this ShardedTensor.
Only called if `data` attribute is None.
Args:
device (Union[str, torch.device]): device to place the tensor on
init_fn (Callable, optional): function to use to initialize the tensor.
Defaults to `torch.empty`.
"""
if self.data is not None:
return
self.data = init_fn(self.local_shape, dtype=self.dtype, device=device)
if self.flattened_range is not None:
self.data = self.data.flatten()[self.flattened_range.start : self.flattened_range.stop]
def narrow(self, dim: int, start: int, length: int) -> List['ShardedTensor']:
"""This is an analogue of torch.narrow for ShardedTensors.
Narrowing assumes that we narrow a local tensor on each rank.
This has consequences on local_shape, global_shape, global_offset, etc.
Args:
dim (int): dimension to narrow. Doesn't include prepended axes.
start (int): start element
length (int): length of the slice
Returns:
List[ShardedTensor]: narrowed ShardedTensors. For non-flat tensors,
the list will always have 1 element. For flat ShardedTensors the number of
elements varies depending on `dim` and on overlap, because flat
tensors must be contiguous. In particular the list can be empty.
"""
prepended_dim = dim + self.prepend_axis_num
local_length_along_dim = self.local_shape[dim]
def _update_tuple(x, ind, val):
x = list(x)
x[ind] = val
return tuple(x)
def _safe_div(x, y):
assert x % y == 0, (x, y)
return x // y
# Decrease global shape and global offset by `length / local_length_along_dim`
assert (
self.global_shape[prepended_dim] % local_length_along_dim == 0
), f'Only regular grid of local tensors is supported for narrowing, got: {self}'
assert (
self.global_offset[prepended_dim] % local_length_along_dim == 0
), f'Only regular grid of local tensors is supported for narrowing, got: {self}'
global_shape = _update_tuple(
self.global_shape,
prepended_dim,
_safe_div(self.global_shape[prepended_dim] * length, local_length_along_dim),
)
global_offset = _update_tuple(
self.global_offset,
prepended_dim,
_safe_div(self.global_offset[prepended_dim] * length, local_length_along_dim),
)
if self.flattened_range is None:
new_data = self.data.narrow(dim, start, length)
# always a single result tensor
return [
replace(
self,
data=new_data,
local_shape=new_data.shape,
global_shape=global_shape,
global_offset=global_offset,
)
]
else:
if dim != 0:
raise CheckpointingException(
f'Narrowing along the first axis is supported for now only, got dim={dim}'
)
# If dim=0, we will always get 0 or 1 resulting tensor.
# If dim>1, in general there can be more result tensors (e.g. max 3 for dim=1)
# For on original flat ShardedTensor of local shape [3, 4] and
# flattened_range=slice(5, 10),
# the X signs mark the actual (flat) data in `self.data`
# notice 12 (3*4) total "virtual" elements, out of which 5 is actual data.
# flat original: [.....XXXXX..]
# If we narrow to start=1, length=1 in the original local shape dimensions,
# the overlapping flat slice would be:
# narrow to: [....XXXX....]
# flat overlap: [.....XXX....]
# Now `data` is flattened and sliced, so we must compute local_shape manually
local_shape = _update_tuple(self.local_shape, dim, length)
other_dims_volume = np.prod(
_update_tuple(local_shape, dim, 1)
) # 4 in the example above
volume_before_split = other_dims_volume * start # 4 in the example above
volume_of_split = other_dims_volume * length # 4 in the example above
flat_slice_start_shifted = (
self.flattened_range.start - volume_before_split
) # 5 - 4 = 1 in the example above
flat_slice_stop_shifted = (
self.flattened_range.stop - volume_before_split
) # 10 - 4 = 6 in the example above
# Find an intersection of
# (flat_slice_start_shifted, flat_slice_stop_shifted) vs (0, volume_of_split)
if flat_slice_stop_shifted <= 0 or flat_slice_start_shifted >= volume_of_split:
return [] # no intersection
# new_flattened_range = slice(1, 4) in the example above
new_flattened_range = slice(
max(flat_slice_start_shifted, 0), min(flat_slice_stop_shifted, volume_of_split)
)
# Apply the intersection to the flattened data tensor.
# Compute start and slice appropriate length
intersection_slice_start = (
new_flattened_range.start - flat_slice_start_shifted
) # 0 in the example above
new_data = self.data[
intersection_slice_start : intersection_slice_start
+ new_flattened_range.stop
- new_flattened_range.start
]
return [
replace(
self,
data=new_data,
local_shape=local_shape,
global_shape=global_shape,
global_offset=global_offset,
flattened_range=new_flattened_range,
)
]
def is_main_replica(replica_id: ReplicaId):
"""Checks if given `replica_id` is considered as main.
"Main" replica is:
- integer 0
- or an iterable with all 0 elements
It is the application responsibility to set correct replicas for sharded tensors.
Args:
replica_id (Union[int, Tuple[int, ...]]): replica id
Returns:
(bool): True for a "main" replica
"""
if isinstance(replica_id, int): if isinstance(replica_id, int):
return replica_id == 0 return replica_id == 0
return all(r == 0 for r in replica_id) return all(r == 0 for r in replica_id)
class LocalNonpersitentObject: class LocalNonpersistentObject:
"""Object that should not be stored in a checkpoint, but restored locally. """Object that should not be stored in a checkpoint, but restored locally.
Wrapping any object inside the state dict with LocalNonpersitentObject Wrapping any object inside the state dict with LocalNonpersistentObject
will result in: will result in:
- during saving, this object will *not* be stored in the checkpoint - during saving, this object will *not* be stored in the checkpoint
- during loading, a local version of this object will be placed in a state dict - during loading, a local version of this object will be placed in a state dict
...@@ -198,11 +511,16 @@ class LocalNonpersitentObject: ...@@ -198,11 +511,16 @@ class LocalNonpersitentObject:
self.obj = obj self.obj = obj
def unwrap(self): def unwrap(self):
"""Returns the original object."""
return self.obj return self.obj
# TODO: Delete once NeMo fixes typo.
LocalNonpersitentObject = LocalNonpersistentObject
@dataclass @dataclass
class ShardedObject: class ShardedObject(ShardedBase):
"""Represents a mapping between a local object and a global object. """Represents a mapping between a local object and a global object.
Global object is assumed to consist of many local objects distributed Global object is assumed to consist of many local objects distributed
...@@ -212,14 +530,12 @@ class ShardedObject: ...@@ -212,14 +530,12 @@ class ShardedObject:
sharding. Conceptually, ShardedObject is a fully-sharded ShardedTensor sharding. Conceptually, ShardedObject is a fully-sharded ShardedTensor
with atomic arbitrary typed elements. with atomic arbitrary typed elements.
Attributes: Args:
key: unique identifier of a global tensor key: unique identifier of a global tensor
data: local object data. Can be None only for consistency validation data: local object data. Can be None only for consistency validation
global_shape: global object shape global_shape: global object shape
global_offset: offset of a local object in a global object, specified global_offset: offset of a local object in a global object, specified in number of shards
in number of shards replica_id: indicates local object replication wrt. local objects in different processes
replica_id: indicates local object replication wrt. local
objects in different processes
""" """
key: str key: str
...@@ -228,38 +544,116 @@ class ShardedObject: ...@@ -228,38 +544,116 @@ class ShardedObject:
global_offset: Tuple[int, ...] global_offset: Tuple[int, ...]
replica_id: ReplicaId = 0 replica_id: ReplicaId = 0
def __post_init__(self):
self.validate_metadata_integrity()
def validate_metadata_integrity(self):
if len(self.global_shape) != len(self.global_offset):
raise CheckpointingException(
f'Global offset dimensions should be equal to global shape dimensions for {self}'
)
def without_data(self): def without_data(self):
return replace(self, data=None) return replace(self, data=None)
@property @property
def unique_key(self): def unique_key(self):
return f'{self.key}/shard_{".".join(map(str, self.global_offset))}_{".".join(map(str, self.global_shape))}' """returns a unique key for this object"""
return (
f'{self.key}/shard_'
f'{".".join(map(str, self.global_offset))}_'
f'{".".join(map(str, self.global_shape))}'
)
def __str__(self): def __str__(self):
return f'{self.__class__.__name__}(key=\'{self.key}\')' return f'{self.__class__.__name__}(key=\'{self.key}\')'
@classmethod
def empty_from_unique_key(cls, unique_key, replica_id: ReplicaId = 0) -> 'ShardedObject':
"""Instantiates a ShardedObject from a unique key.
Args:
unique_key: a string of the form
<key>/shard_<global_offset>_<global_shape>
replica_id: indicates local object replication wrt.
local objects in different processes
Returns:
a ShardedObject with data=None
"""
key, shard_key = unique_key.split('/')
shard_str, offset, shape = shard_key.split('_')
assert shard_str == 'shard'
offset = tuple(map(int, offset.split('.')))
shape = tuple(map(int, shape.split('.')))
if len(shape) + 1 == len(offset):
# This is a backward-compatible fix. We don't know the last
# element of global shape so set it to -1.
shape += (-1,)
return cls(key, None, shape, offset, replica_id)
FactoryBuildFn = Callable[[str, torch.Tensor, ReplicaId, Optional[slice]], ShardedStateDict]
FactoryMergeFn = Callable[[StateDict], torch.Tensor]
@dataclass @dataclass
class ShardedTensorFactory: class ShardedTensorFactory(ShardedBase):
""" Allows to apply transformations to tensors before/after serialization. """Allows to apply transformations to tensors before/after serialization.
The essence of those transformations is that they can be applied to The essence of those transformations is that they can be applied to
optimizer states the same way they are applied to the model params. optimizer states the same way they are applied to the model params.
The ultimate state dict with sharded tensors must depend functionally on
`build_fn` arguments (key, data, replica_id, flattened_range),
which will be provided by the optimizer.
Builder creates a sub-state-dict out of a tensor before saving, and merger Builder creates a sub-state-dict out of a tensor before saving, and merger
merges the corresponding state dict after loading. merges the corresponding state dict after loading.
Args:
key (str): unique identifier of the factory
data (torch.Tensor): original model parameter that will be further
transformed by this factory
build_fn (callable): function that transforms the original tensor
to a sharded state dict
merge_fn (callable): function that transforms loaded subtree back
into a single tensor (inverse of `build_fn`)
replica_id (ReplicaId): indicates factory replication wrt.
factories in different processes
flattened_range (slice, optional): indicates additional flattening
applied to the ShardedTensors produced by the factory
""" """
key: str key: str
data: torch.Tensor data: torch.Tensor
build_fn: Callable[[str, torch.Tensor], ShardedStateDict] build_fn: FactoryBuildFn
merge_fn: Callable[[StateDict], torch.Tensor] merge_fn: FactoryMergeFn
replica_id: ReplicaId = 0
flattened_range: Optional[slice] = None
def build(self): def build(self):
return self.build_fn(self.key, self.data) """Builds a ShardedStateDict from the original tensor"""
return self.build_fn(self.key, self.data, self.replica_id, self.flattened_range)
def validate_metadata_integrity(self):
"""No reasonable checks can be applied"""
pass
def without_data(self):
return replace(self, data=None)
def apply_factories(sharded_state_dict: ShardedStateDict): def apply_factories(sharded_state_dict: ShardedStateDict):
"""Turn ShardedTensorFactories into ShardedTensors *in-place*.
Args:
sharded_state_dict (ShardedStateDict): state dict possibly
containing ShardedTensorFactory objects
Returns:
None: state dict is modified in place
"""
def apply(x): def apply(x):
if isinstance(x, ShardedTensorFactory): if isinstance(x, ShardedTensorFactory):
x = x.build() x = x.build()
...@@ -268,7 +662,23 @@ def apply_factories(sharded_state_dict: ShardedStateDict): ...@@ -268,7 +662,23 @@ def apply_factories(sharded_state_dict: ShardedStateDict):
dict_list_map_inplace(apply, sharded_state_dict) dict_list_map_inplace(apply, sharded_state_dict)
def apply_factory_merges(x1: StateDict, x2: ShardedStateDict): def apply_factory_merges(
x1: StateDict, x2: ShardedStateDict, key: Tuple[str, ...] = ()
) -> StateDict:
"""Apply merges defined by ShardedTensorFactories *in-place*.
Args:
x1 (StateDict): state dict loaded from the checkpoint
x2 (ShardedStateDict): subset of `x1` (in terms of dict keys)
with ShardedTensorFactory
as (possibly nested) values that define how to
merge objects from the `x1` state dict
key (Tuple[str, ...]): current key in a recursive call.
Used only for reporting meaningful errors
Returns:
StateDict: `x1` modified in-place
"""
if isinstance(x2, ShardedTensorFactory): if isinstance(x2, ShardedTensorFactory):
return x2.merge_fn(x1) return x2.merge_fn(x1)
...@@ -276,14 +686,37 @@ def apply_factory_merges(x1: StateDict, x2: ShardedStateDict): ...@@ -276,14 +686,37 @@ def apply_factory_merges(x1: StateDict, x2: ShardedStateDict):
if isinstance(x1, dict) and isinstance(x2, dict): if isinstance(x1, dict) and isinstance(x2, dict):
for k, v2 in x2.items(): for k, v2 in x2.items():
if k not in x1: if k not in x1:
raise ValueError('Different dict keys encountered in `apply_factory_merges`') raise ValueError(
f'Different dict keys encountered in `apply_factory_merges` '
f'({x1.keys()} vs {x2.keys()})'
)
else: else:
x1[k] = apply_factory_merges(x1[k], v2) x1[k] = apply_factory_merges(x1[k], v2, key=key + (k,))
elif isinstance(x1, list) and isinstance(x2, list): elif isinstance(x1, list) and isinstance(x2, list):
if len(x1) != len(x2): if len(x1) != len(x2):
raise ValueError('Cannot merge two lists with different lengths') err_msg = (
f'Cannot merge two lists with different lengths '
f'({len(x1)} and {len(x2)}, encountered at key {key})'
)
logger.error(err_msg + f'\nx1: {x1}\nx2: {x2}')
raise ValueError(err_msg)
for i, v2 in enumerate(x2): for i, v2 in enumerate(x2):
x1[i] = apply_factory_merges(x1[i], v2) x1[i] = apply_factory_merges(x1[i], v2, key=key + (i,))
elif isinstance(x1, list) and isinstance(x2, dict):
for k, v2 in x2.items():
if not isinstance(k, int):
raise ValueError(
f'Invalid dict key {k} non-integer type encountered '
f'in a list-dict merge at level {key}'
)
if k >= len(x1):
raise ValueError(
f'Dict key {k} out of bound for list of length'
f'{len(x1)} (encountered at level {key})'
)
x1[k] = apply_factory_merges(x1[k], v2, key=key + (k,))
else: else:
raise ValueError(f'Duplicate non-dict and non-list values encountered: `{x1}` and `{x2}`') raise ValueError(
f'Duplicate non-dict and non-list values encountered: `{x1}` and `{x2} (at key {key})`'
)
return x1 return x1
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Optimizer related helpers. """ """ Helpers for defining sharding for optimizer states based on existing sharding
for model parameters.
"""
import logging import logging
from copy import deepcopy from copy import deepcopy
from dataclasses import replace from dataclasses import replace
from itertools import chain from typing import Dict, Iterable, Tuple, Union
from typing import Dict, Iterable, List, Union
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
import torch import torch
from megatron.core.utils import to_local_if_dtensor
from .dict_utils import nested_values from .dict_utils import nested_values
from .mapping import ( from .mapping import (
LocalNonpersitentObject, LocalNonpersistentObject,
ShardedStateDict, ShardedStateDict,
ShardedTensor, ShardedTensor,
ShardedTensorFactory, ShardedTensorFactory,
StateDict, StateDict,
) )
from .utils import extract_sharded_tensors, extract_sharded_tensors_and_factories from .utils import extract_sharded_tensors_and_factories
def get_optim_param_to_id_map(optim_params_iter: Iterable[torch.nn.Parameter]) -> Dict[int, int]: def get_optim_param_to_id_map(optim_params_iter: Iterable[torch.nn.Parameter]) -> Dict[int, int]:
"""Generate mapping from optimizer param to optimizer state id."""
param_mappings = {} param_mappings = {}
for i, param in enumerate(optim_params_iter): for i, param in enumerate(optim_params_iter):
param = to_local_if_dtensor(param)
if id(param) not in param_mappings: if id(param) not in param_mappings:
param_mappings[id(param)] = i param_mappings[id(param)] = i
return param_mappings return param_mappings
...@@ -34,9 +39,24 @@ def get_optim_param_to_id_map(optim_params_iter: Iterable[torch.nn.Parameter]) - ...@@ -34,9 +39,24 @@ def get_optim_param_to_id_map(optim_params_iter: Iterable[torch.nn.Parameter]) -
def get_param_id_to_sharded_param_map( def get_param_id_to_sharded_param_map(
model_sharded_state_dict: ShardedStateDict, optim_params_iter: Iterable[torch.nn.Parameter] model_sharded_state_dict: ShardedStateDict, optim_params_iter: Iterable[torch.nn.Parameter]
) -> Dict[int, Union[ShardedTensor, ShardedTensorFactory]]: ) -> Dict[int, Union[ShardedTensor, ShardedTensorFactory]]:
"""Generate mapping from optimizer state ids to model sharded parameters.
Args:
model_sharded_state_dict: sharded state dict with all model sharded tensors
(can have any structure)
optim_params_iter: iterable which iterates over model parameters tracked by the optimizer.
The iteration must be in the same order as in the optimizer parameters.
Returns:
Dict[int, Union[ShardedTensor, ShardedTensorFactory]]: mapping from optimizer state ids
to model sharded parameters.
"""
model_sharded_state_dict, _ = extract_sharded_tensors_and_factories(model_sharded_state_dict) model_sharded_state_dict, _ = extract_sharded_tensors_and_factories(model_sharded_state_dict)
id_to_sharded_param_map = {} id_to_sharded_param_map = {}
param_to_id_map = get_optim_param_to_id_map(optim_params_iter) param_to_id_map = get_optim_param_to_id_map(optim_params_iter)
# If using PyTorch FSDP2 the values in model_sharded_state_dict would
# have been converted to local tensors during initialization.
# See the make_(tp)_sharded_tensor_for_checkpoint functions.
for ten in nested_values(model_sharded_state_dict): for ten in nested_values(model_sharded_state_dict):
if id(ten.data) in param_to_id_map: if id(ten.data) in param_to_id_map:
id_to_sharded_param_map[param_to_id_map[id(ten.data)]] = ten id_to_sharded_param_map[param_to_id_map[id(ten.data)]] = ten
...@@ -55,24 +75,60 @@ def get_param_id_to_sharded_param_map( ...@@ -55,24 +75,60 @@ def get_param_id_to_sharded_param_map(
def make_sharded_optimizer_tensor( def make_sharded_optimizer_tensor(
model_param: Union[ShardedTensor, ShardedTensorFactory], optim_param: torch.Tensor, prefix: str model_param: Union[ShardedTensor, ShardedTensorFactory], optim_param: torch.Tensor, prefix: str
) -> Union[ShardedTensor, ShardedTensorFactory]: ) -> Union[ShardedTensor, ShardedTensorFactory]:
"""Build a ShardedTensor or ShardedTensorFactory for optimizer param based on model param
Args:
model_param (Union[ShardedTensor, ShardedTensorFactory]): model param
optim_param (torch.Tensor): corresponding optimizer param
prefix (str): optimizer prefix for the ShardedTensor or ShardedTensorFactory
Returns:
Union[ShardedTensor, ShardedTensorFactory]: wrapped optimizer parameter
"""
optim_param = to_local_if_dtensor(optim_param)
if isinstance(model_param, ShardedTensorFactory): if isinstance(model_param, ShardedTensorFactory):
return replace(model_param, key=f'{prefix}.{model_param.key}', data=optim_param) return replace(model_param, key=f'{prefix}.{model_param.key}', data=optim_param)
assert ( assert tuple(optim_param.shape) == model_param.local_shape, (
tuple(optim_param.shape) == model_param.local_shape f'Optimizer shape ({tuple(optim_param.shape)} does not match model shape '
), f'Optimizer shape ({tuple(optim_param.shape)} does not match model shape ({model_param.local_shape})' f'({model_param.local_shape})'
return replace( )
sh_ten = replace(
model_param, key=f'{prefix}.{model_param.key}', data=optim_param, dtype=optim_param.dtype model_param, key=f'{prefix}.{model_param.key}', data=optim_param, dtype=optim_param.dtype
) )
sh_ten.validate_metadata_integrity()
return sh_ten
def optim_state_to_sharding_state( def optim_state_to_sharding_state(
optim_state_dict: StateDict, id_to_sharded_param_map: Dict[int, ShardedTensor] optim_state_dict: StateDict,
id_to_sharded_param_map: Dict[int, ShardedTensor],
exclude_keys: Tuple[str] = (),
): ):
"""Turn optimizer state dict to sharded state dict based on model state dict *in-place*.
Can be used to add sharding information to most common optimizer state dict.
Creates separate ShardedTensors for each key in `optim_state_dict['state']`
(e.g. for torch.optim.Adam there will be separate tensors for `exp_avg` and `exp_avg_sq`)
Args:
optim_state_dict (StateDict): optimizer state dict with
state parameters under `state` key and group hyperparameters under
`param_groups` -> `params` key.
id_to_sharded_param_map (Dict[int, ShardedTensor]): mapping from optimizer param ids
to model sharded tensors. Can be generated with `get_param_id_to_sharded_param_map`
function.
exclude_keys (Tuple[str]): optimizer state keys to exclude from the final state dict.
Returns:
None: state dict is modified in place
"""
sharded_state = {} sharded_state = {}
for param_id, param_state in optim_state_dict['state'].items(): for param_id, param_state in optim_state_dict['state'].items():
sharded_state[param_id] = {} sharded_state[param_id] = {}
for state_key, param in param_state.items(): for state_key, param in param_state.items():
if state_key in exclude_keys:
continue
if param_id in id_to_sharded_param_map: if param_id in id_to_sharded_param_map:
sharded_state[param_id][state_key] = make_sharded_optimizer_tensor( sharded_state[param_id][state_key] = make_sharded_optimizer_tensor(
id_to_sharded_param_map[param_id], param, prefix=f'optimizer.state.{state_key}' id_to_sharded_param_map[param_id], param, prefix=f'optimizer.state.{state_key}'
...@@ -82,5 +138,5 @@ def optim_state_to_sharding_state( ...@@ -82,5 +138,5 @@ def optim_state_to_sharding_state(
optim_state_dict['param_groups'] = deepcopy(optim_state_dict['param_groups']) optim_state_dict['param_groups'] = deepcopy(optim_state_dict['param_groups'])
for group in optim_state_dict['param_groups']: for group in optim_state_dict['param_groups']:
group['params'] = LocalNonpersitentObject(group['params']) group['params'] = LocalNonpersistentObject(group['params'])
optim_state_dict['state'] = sharded_state optim_state_dict['state'] = sharded_state
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Entrypoints for saving and loading the distributed checkpoints.
Functions `load` and `save` are equivalents of `torch.load` and `torch.save`
but expect torch.Tensors to be wrapped with classes from the `mapping module`.
Additionally, `load` expects the sharded state dict argument as a guidance for
loading the sharded tensors.
"""
import logging import logging
import os
from collections import Counter, defaultdict
from itertools import chain
from pathlib import Path from pathlib import Path
from typing import Iterable, List, Tuple, Union from typing import Callable, Dict, Optional, Set, Tuple, Union
import numpy as np
import torch import torch
from .core import CheckpointingConfig, maybe_load_config, save_config from . import ShardedTensor
from .dict_utils import ( from .core import CheckpointingConfig, save_config
dict_list_map_inplace, from .dict_utils import extract_matching_values, merge
diff,
extract_matching_values,
map_reduce,
merge,
nested_values,
)
from .mapping import ( from .mapping import (
CheckpointingException, CheckpointingException,
CommonStateDict,
ShardedObject, ShardedObject,
ShardedStateDict, ShardedStateDict,
ShardedTensor,
ShardedTensorFactory,
StateDict, StateDict,
apply_factories,
apply_factory_merges, apply_factory_merges,
is_main_replica,
) )
from .state_dict_transformation import load_preprocess, save_preprocess
from .strategies.async_utils import AsyncRequest
from .strategies.base import ( from .strategies.base import (
AsyncSaveShardedStrategy,
LoadCommonStrategy, LoadCommonStrategy,
LoadShardedStrategy, LoadShardedStrategy,
SaveCommonStrategy, SaveCommonStrategy,
...@@ -38,97 +36,266 @@ from .strategies.base import ( ...@@ -38,97 +36,266 @@ from .strategies.base import (
StrategyAction, StrategyAction,
get_default_strategy, get_default_strategy,
) )
from .utils import extract_sharded_tensors, extract_sharded_tensors_or_nonpersistent from .utils import extract_sharded_base
from .validation import (
COMMON_STATE_FNAME = 'common.pt' StrictHandling,
determine_global_metadata,
parse_strict_flag,
validate_integrity_and_strict_load,
validate_sharded_objects_handling,
verify_checkpoint_and_load_strategy,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# flat state dict with sharded objects without any data
CkptShardedMetadata = Dict[str, Union[ShardedTensor, ShardedObject]]
def load( def load(
sharded_state_dict: ShardedStateDict, sharded_state_dict: ShardedStateDict,
checkpoint_dir: str, checkpoint_dir: str,
sharded_strategy: Union[LoadShardedStrategy, None] = None, sharded_strategy: Union[LoadShardedStrategy, Tuple[str, int], None] = None,
common_strategy: Union[LoadCommonStrategy, None] = None, common_strategy: Union[LoadCommonStrategy, Tuple[str, int], None] = None,
) -> StateDict: validate_access_integrity: bool = True,
strict: Union[str, StrictHandling] = StrictHandling.ASSUME_OK_UNEXPECTED,
) -> Union[StateDict, Tuple[StateDict, Set[str], Set[str]]]:
"""Loading entrypoint. """Loading entrypoint.
Arguments: In the steps below, the following verbs refer to corresponding objects:
sharded_state_dict: state dict of the existing model populated with - load = load from checkpoint
ShardedTensors. Used as a mapping to determine which parts of - extract = extract from sharded_state_dict
global tensors stored in the checkpoint should be loaded. - add = add to the final state dict
checkpoint_dir: directory with the checkpoint Steps:
sharded_strategy: configures loading behavior for sharded tensors 1. Load common state dict and form the base of the result state dict
common_strategy: configures loading behavior for common data 2. Apply factories to sharded_state_dict
3. Extract LocalNonPersistentObject and add
4. (optional) Extract ShardedObjects, load and add
5. Extract ShardedBase, load, apply factory merges and add
Args:
sharded_state_dict (ShardedStateDict): state dict of the existing model
populated with ShardedTensors. Used as a mapping to determine which
parts of global tensors stored in the checkpoint should be loaded.
checkpoint_dir (str): directory with the checkpoint
sharded_strategy (LoadShardedStrategy, Tuple[str, int], optional):
configures loading behavior for sharded tensors
common_strategy (LoadCommonStrategy, Tuple[str, int], optional):
configures loading behavior for common data
validate_access_integrity (bool default = True): checks if each tensor shard is accessed
exactly once (as main replica) by some process
strict (StrictHandling, str, optional): determines the behavior in case of a mismatch
between the requested sharded state dict and the checkpoint. See `StrictHandling` docs
for more details. Some values affect the return value of this function
(missing and unexpected keys are returned).
Defaults to `True` (StrictHandling.ASSUME_OK_UNEXPECTED) which doesn't
incur any performance overhead. Other recommended values
are: `False` (StrictHandling.LOG_UNEXPECTED) which logs only unexpected keys
or `StrictHandling.RETURN_ALL` which returns all mismatch keys.
Returns:
StateDict or Tuple[StateDict, Set[str], Set[str]]: in most cases only
the loaded state dict is returned. If `strict` flag was set to
""" """
if common_strategy is not None: sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy(
raise NotImplementedError('The only supported common strategy is torch') checkpoint_dir, sharded_strategy, common_strategy
)
checkpoint_dir = Path(checkpoint_dir) checkpoint_dir = Path(checkpoint_dir)
common_state_dict = load_common_state_dict(checkpoint_dir) common_state_dict = common_strategy.load_common(checkpoint_dir)
if not sharded_state_dict: if not sharded_state_dict:
return common_state_dict return common_state_dict
sharded_objects, sharded_state_dict = load_sharded_objects(sharded_state_dict, checkpoint_dir) sharded_state_dict, nonpersistent_state_dict, sh_ten_factories = load_preprocess(
merge(common_state_dict, sharded_objects) sharded_state_dict
saved_config = maybe_load_config(checkpoint_dir)
if saved_config is None:
raise CheckpointingException(f'{checkpoint_dir} is not a distributed checkpoint')
sh_ten_factories, _ = extract_matching_values(
sharded_state_dict, lambda x: isinstance(x, ShardedTensorFactory)
) )
apply_factories(sharded_state_dict)
sharded_state_dict, _ = extract_sharded_tensors_or_nonpersistent(sharded_state_dict)
sharded_state_dict, nonpersistent_state_dict = extract_sharded_tensors(sharded_state_dict)
dict_list_map_inplace(lambda o: o.unwrap(), nonpersistent_state_dict)
merge(common_state_dict, nonpersistent_state_dict) merge(common_state_dict, nonpersistent_state_dict)
validate_sharding_integrity(nested_values(sharded_state_dict)) # At this point we are only dealing with ShardedBase objects
sharded_state_dict, _ = extract_sharded_base(sharded_state_dict)
if sharded_strategy is None: # Validation
sharded_strategy = get_default_strategy( ckpt_sharded_metadata = None
StrategyAction.LOAD_SHARDED, local_metadata, global_metadata = None, None
saved_config.sharded_backend, strict = parse_strict_flag(strict)
saved_config.sharded_backend_version, if StrictHandling.requires_explicit_ckpt_mismatch_check(strict):
ckpt_sharded_metadata = load_sharded_metadata(
str(checkpoint_dir), sharded_strategy, common_strategy
) )
else: if validate_access_integrity or StrictHandling.requires_global_app_metadata(strict):
# TODO: implement consistency checks here local_metadata, global_metadata = determine_global_metadata(sharded_state_dict)
pass
loaded_state_dict = sharded_strategy.load(sharded_state_dict, checkpoint_dir) sharded_state_dict, missing_keys, unexpected_keys = validate_integrity_and_strict_load(
sharded_state_dict,
strict,
validate_access_integrity,
local_metadata,
global_metadata,
ckpt_sharded_metadata,
)
loaded_state_dict = apply_factory_merges(loaded_state_dict, sh_ten_factories) # ShardedBase loading
if not sharded_strategy.can_handle_sharded_objects:
validate_sharded_objects_handling(sharded_strategy, common_strategy)
sharded_objects_state_dict, sharded_state_dict = extract_matching_values(
sharded_state_dict, lambda v: isinstance(v, ShardedObject)
)
sharded_objects = common_strategy.load_sharded_objects(
sharded_objects_state_dict, checkpoint_dir
)
merge(common_state_dict, sharded_objects)
loaded_state_dict = sharded_strategy.load(sharded_state_dict, checkpoint_dir)
merge(common_state_dict, loaded_state_dict) merge(common_state_dict, loaded_state_dict)
return common_state_dict
loaded_state_dict = apply_factory_merges(common_state_dict, sh_ten_factories)
if StrictHandling.requires_returning_mismatch_keys(strict):
return common_state_dict, missing_keys, unexpected_keys
else:
return common_state_dict
def load_common_state_dict(checkpoint_dir: Path) -> StateDict:
"""Load common (non-sharded) objects state dict from the checkpoint.
Args:
checkpoint_dir (Path): checkpoint directory
Returns:
StateDict: state dict with non-sharded objects from the checkpoint
"""
sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy(str(checkpoint_dir))
return common_strategy.load_common(checkpoint_dir)
def load_tensors_metadata(
checkpoint_dir: str, sharded_strategy: Union[LoadShardedStrategy, None] = None
) -> CkptShardedMetadata:
"""Load tensors metadata from the checkpoint.
Returns a dictionary similar to a sharded state dict, but note that
the dictionary keys are simply ShardedTensor keys (contrary to the
actual sharded state dicts where keys correspond to state dict keys).
Dict values are ShardedTensors without any sharding (so, the only useful
information is tensors global shape and dtype).
Concrete implementation depends on the loading strategy. If no strategy is
given, a default for a given backend is used.
Args:
checkpoint_dir (str): checkpoint directory to load from
sharded_strategy (LoadShardedStrategy, optional): sharded strategy to load metadata.
Defaults to None - in this case a default load strategy for a given checkpoint type
is used.
Returns:
CkptShardedMetadata: flat state dict without data describing ShardedTensors
in the checkpoint
"""
sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy(
checkpoint_dir, sharded_strategy
)
return sharded_strategy.load_tensors_metadata(Path(checkpoint_dir))
def load_sharded_metadata(
checkpoint_dir: str,
sharded_strategy: Union[LoadShardedStrategy, None] = None,
common_strategy: Union[LoadCommonStrategy, None] = None,
) -> CkptShardedMetadata:
"""Load sharded metadata from the checkpoint.
Similar to `load_tensors_metadata`, but includes also ShardedObjects.
Returns a dictionary similar to a sharded state dict, but note that
the dictionary keys are simply ShardedTensor keys (contrary to the
actual sharded state dicts where keys correspond to state dict keys).
Dict values are ShardedTensors without any sharding (so, the only useful
information is tensors global shape and dtype).
# TODO: implement it as common torch strategy Concrete implementation depends on the loading strategy. If no strategy is
def load_common_state_dict(checkpoint_dir: Path): given, a default for a given backend is used.
return torch.load(Path(checkpoint_dir) / COMMON_STATE_FNAME, map_location='cpu')
Args:
checkpoint_dir (str): checkpoint directory to load from
sharded_strategy (LoadShardedStrategy, optional): sharded strategy to load metadata.
Defaults to None - in this case a default load strategy for a given checkpoint type
is used.
common_strategy (LoadCommonStrategy, optional): common strategy to load metadata.
Defaults to None - in this case a default load strategy for a given checkpoint type is
used. This strategy won't be used unless `sharded_strategy` can't handle ShardedObjects
def load_sharded_objects(sharded_state_dict: ShardedStateDict, checkpoint_dir: Path): Returns:
sharded_objects, sharded_state_dict = extract_matching_values( CkptShardedMetadata: flat state dict without data describing ShardedTensors
sharded_state_dict, lambda v: isinstance(v, ShardedObject) and ShardedObjects in the checkpoint
"""
sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy(
checkpoint_dir, sharded_strategy, common_strategy
) )
sharded_metadata = sharded_strategy.load_sharded_metadata(Path(checkpoint_dir))
if not sharded_strategy.can_handle_sharded_objects:
validate_sharded_objects_handling(sharded_strategy, common_strategy)
common_metadata = common_strategy.load_sharded_metadata(Path(checkpoint_dir))
sharded_metadata = merge(sharded_metadata, common_metadata)
return sharded_metadata
def load_plain_tensors(checkpoint_dir: str) -> StateDict:
"""Load checkpoint tensors without any sharding and plain structure.
def load_sharded_object(sh_obj: ShardedObject): NOTE: common state dict is NOT included.
sh_obj.data = None
load_path = (checkpoint_dir / sh_obj.unique_key).with_suffix('.pt')
loaded_obj = torch.load(load_path)
return loaded_obj
return dict_list_map_inplace(load_sharded_object, sharded_objects), sharded_state_dict Args:
checkpoint_dir (str): checkpoint directory to load the tensors from.
Returns:
StateDict: checkpoint state dict containing only torch.Tensors.
"""
sharded_state_dict = load_tensors_metadata(checkpoint_dir)
# Don't validate integrity because shards will be overlapped
# if world_size > 1 (all processes load whole tensors)
return load(sharded_state_dict, checkpoint_dir, validate_access_integrity=False)
#
# def load_plain_tensors_and_objects(checkpoint_dir: str) -> StateDict:
# """Load checkpoint tensors and objects without any sharding and plain structure.
#
# NOTE: state dict structure might be different than the one used for checkpoint saving.
# NOTE: common state dict is NOT included.
#
# Args:
# checkpoint_dir (str): checkpoint directory to load the state dict from.
#
# Returns:
# StateDict: complete checkpoint state dict without any sharding.
# """
# sharded_state_dict = load_tensors_metadata(checkpoint_dir)
# # Don't validate integrity because shards will be overlapped
# # if world_size > 1 (all processes load whole tensors)
# return load(sharded_state_dict, checkpoint_dir, validate_access_integrity=False)
def remove_sharded_tensors(checkpoint_dir: str, key_prefix: str):
"""determine the appropriate sharding strategy and delegate removal to the sharded strategy"""
sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy(checkpoint_dir)
sharded_strategy.remove_sharded_tensors(checkpoint_dir, key_prefix)
def save( def save(
sharded_state_dict: ShardedStateDict, sharded_state_dict: ShardedStateDict,
checkpoint_dir: str, checkpoint_dir: str,
sharded_strategy: Union[SaveShardedStrategy, None] = None, sharded_strategy: Union[SaveShardedStrategy, Tuple[str, int], None] = None,
common_strategy: Union[SaveCommonStrategy, None] = None, common_strategy: Union[SaveCommonStrategy, Tuple[str, int], None] = None,
): validate_access_integrity: bool = True,
async_sharded_save: bool = False,
preprocess_common_before_consistancy_check: Callable[[CommonStateDict], StateDict] = None,
) -> Optional[AsyncRequest]:
"""Saving entrypoint. """Saving entrypoint.
Extracts ShardedTensors from the given state dict. Rank 0 saves the Extracts ShardedTensors from the given state dict. Rank 0 saves the
...@@ -136,13 +303,46 @@ def save( ...@@ -136,13 +303,46 @@ def save(
The ShardedTensors are saved according to a strategy specified by the The ShardedTensors are saved according to a strategy specified by the
config. config.
Arguments: Steps:
sharded_state_dict: state dict of the populated with 1. Apply factories
2. Extract and discard LocalNonPersistentObject
3. Extract all ShardedBase object
4. Save all other objects to common.pt
5. (optional) Extract and save ShardedObjects
6. Save all ShardedBase objects
7. Write metadata.json file with backend and version metadata.
Step (6) can be performed asynchronously (see `async_sharded_save`), in this
case the actual save is embodied in the returned async request and can be
scheduled by the external caller. For async request, step (7) is added as
one of the finalization functions, so that metadata.json is written only
if the checkpoint is complete.
Args:
sharded_state_dict (ShardedStateDict): state dict of the populated with
ShardedTensors. Used as a mapping to determine how local tensors ShardedTensors. Used as a mapping to determine how local tensors
should be saved as global tensors in the checkpoint. should be saved as global tensors in the checkpoint.
checkpoint_dir: directory to save the checkpoint to checkpoint_dir (str): directory to save the checkpoint to
sharded_strategy: configures sharded tensors saving behavior and backend sharded_strategy (SaveShardedStrategy, Tuple[str, int], optional):
common_strategy: configures common data saving behavior and backend configures sharded tensors saving behavior and backend
common_strategy (SaveCommonStrategy, Tuple[str, int], optional):
configures common data saving behavior and backend
validate_access_integrity (bool default = True): checks if each tensor shard is accessed
exactly once (as main replica) by some process.
It also makes sure the common state dict is consistant across all ranks
async_sharded_save (bool, optional): if True, for the sharded state dict part
an async save implementation will be called, with the AsyncRequest
being returned to the caller. Note that it is the caller responsibility to
actually schedule the async save. Defaults to False.
preprocess_common_before_consistancy_check (Callable[[CommonStateDict], StateDict], None):
A callable function that will preprocess the common state dict (i.e can be used to
remove keys that we expect to be different in the state dict). The function must not
modify the original state dict
Returns:
AsyncRequest (optional): if `async_sharded_save` is True, returns
async request that should be scheduled by the caller of this function.
None otherwise.
""" """
checkpoint_dir = Path(checkpoint_dir) checkpoint_dir = Path(checkpoint_dir)
...@@ -161,174 +361,66 @@ def save( ...@@ -161,174 +361,66 @@ def save(
raise NotImplementedError('The only supported common strategy is torch') raise NotImplementedError('The only supported common strategy is torch')
if sharded_strategy is None: if sharded_strategy is None:
sharded_strategy = get_default_strategy(StrategyAction.SAVE_SHARDED, 'zarr', 1) sharded_strategy = get_default_save_sharded_strategy()
if not isinstance(sharded_strategy, SaveShardedStrategy):
apply_factories(sharded_state_dict) assert isinstance(sharded_strategy, tuple), type(sharded_strategy)
sharded_state_dict, state_dict = extract_sharded_tensors_or_nonpersistent(sharded_state_dict) sharded_strategy = get_default_strategy(StrategyAction.SAVE_SHARDED, *sharded_strategy)
sharded_state_dict, _ = extract_sharded_tensors(sharded_state_dict)
sharded_tensors = list(nested_values(sharded_state_dict)) if common_strategy is None:
validate_sharding_integrity(sharded_tensors) common_strategy = get_default_save_common_strategy()
if not isinstance(common_strategy, SaveCommonStrategy):
_save_common_dict(state_dict, checkpoint_dir, True) assert isinstance(common_strategy, tuple), type(common_strategy)
common_strategy = get_default_strategy(StrategyAction.SAVE_COMMON, *common_strategy)
sharded_strategy.save(sharded_tensors, checkpoint_dir)
save_config( sharded_state_dict, state_dict = save_preprocess(
CheckpointingConfig(sharded_strategy.backend, sharded_strategy.version), checkpoint_dir sharded_state_dict, validate_access_integrity, preprocess_common_before_consistancy_check
) )
common_strategy.save_common(state_dict, checkpoint_dir)
# TODO: implement it as common torch strategy if not sharded_strategy.can_handle_sharded_objects:
def _save_common_dict( validate_sharded_objects_handling(sharded_strategy, common_strategy)
state_dict: StateDict, checkpoint_dir: Path, validate_consistency: bool = False sharded_objects_state_dict, sharded_state_dict = extract_matching_values(
): sharded_state_dict, lambda v: isinstance(v, ShardedObject)
common_state_dict = _extract_and_save_sharded_objects(
state_dict, checkpoint_dir, validate_consistency
)
if torch.distributed.get_rank() == 0:
torch.save(common_state_dict, checkpoint_dir / COMMON_STATE_FNAME)
if validate_consistency:
# TODO: implement checking consistency with rank 0 common dict on other ranks
pass
# torch.distributed.barrier()
# if not torch.distributed.get_rank() == 0:
# rank_0_state_dict = torch.load(checkpoint_dir / COMMON_STATE_FNAME)
# print(diff(common_state_dict, rank_0_state_dict))
def _extract_and_save_sharded_objects(
state_dict: StateDict, checkpoint_dir: Path, validate_consistency: bool = False
):
sharded_objects, state_dict = extract_matching_values(
state_dict, lambda v: isinstance(v, ShardedObject)
)
sharded_objects = list(nested_values(sharded_objects))
if validate_consistency:
validate_objects_sharding_integrity(sharded_objects)
for sh_obj in sharded_objects:
if is_main_replica(sh_obj.replica_id):
save_path = (checkpoint_dir / sh_obj.unique_key).with_suffix('.pt')
os.makedirs(save_path.parent, exist_ok=True)
torch.save(sh_obj.data, save_path)
return state_dict
def validate_sharding_integrity(sharded_tensors: Iterable[ShardedTensor]):
sharding = [ten.without_data() for ten in sharded_tensors]
all_sharding = [None] * torch.distributed.get_world_size()
torch.distributed.all_gather_object(all_sharding, sharding)
if torch.distributed.get_rank() != 0:
return
key_shardings = defaultdict(list)
for rank, rank_shardings in enumerate(all_sharding):
for sharding in rank_shardings:
key_shardings[sharding.key].append((rank, sharding))
for key, shardings in key_shardings.items():
_validate_sharding_for_key(shardings)
def _validate_sharding_for_key(rank_sharding: List[Tuple[int, ShardedTensor]]):
some_rank_shard = rank_sharding[0][1]
global_shape = some_rank_shard.global_shape
local_shape = some_rank_shard.local_shape
dtype = some_rank_shard.dtype
has_flattened_range = some_rank_shard.flattened_range is not None
for rank, sharding in rank_sharding:
assert sharding.dtype == dtype, (sharding.dtype, dtype, some_rank_shard)
assert sharding.global_shape == global_shape, (
sharding.global_shape,
global_shape,
some_rank_shard,
)
assert sharding.local_shape == local_shape, (
sharding.local_shape,
local_shape,
some_rank_shard,
)
assert (sharding.flattened_range is not None) == has_flattened_range, (
(sharding.flattened_range is not None),
has_flattened_range,
some_rank_shard,
) )
common_strategy.save_sharded_objects(sharded_objects_state_dict, checkpoint_dir)
shard_access_cnt = _compute_shards_access(rank_sharding) def metadata_finalize_fn():
if has_flattened_range: if torch.distributed.get_rank() == 0:
map_reduce( save_config(
rank_sharding, CheckpointingConfig(sharded_strategy.backend, sharded_strategy.version),
lambda x: x[1].global_offset, checkpoint_dir,
lambda x: x[1],
_validate_sharding_for_key_flattened,
)
else:
if not torch.all(shard_access_cnt == 1):
logger.error(f'Invalid access pattern for {rank_sharding[0][1]}: {shard_access_cnt}')
raise CheckpointingException(f'Invalid access pattern for {rank_sharding[0][1]}')
def _compute_shards_access(rank_sharding):
def chunk_offset(sharding):
assert len(sharding.global_offset) == len(sharding.local_shape) + sharding.prepend_axis_num
return tuple(
chain(
(off for off in sharding.global_offset[: sharding.prepend_axis_num]),
(
off // sh
for off, sh in zip(
sharding.global_offset[sharding.prepend_axis_num :], sharding.local_shape
)
),
) )
) torch.distributed.barrier()
shard_access_cnt = torch.zeros( if not async_sharded_save:
rank_sharding[0][1].axis_fragmentations, dtype=torch.int, device='cpu' sharded_strategy.save(sharded_state_dict, checkpoint_dir)
) metadata_finalize_fn()
for rank, sharding in rank_sharding: return
if is_main_replica(sharding.replica_id):
shard_access_cnt[chunk_offset(sharding)] += 1 if not isinstance(sharded_strategy, AsyncSaveShardedStrategy):
# TODO: consider validating different replicas too
return shard_access_cnt
def _validate_sharding_for_key_flattened(tensors_by_shard):
all_slices = []
local_shape = tensors_by_shard[0].local_shape
for sharding in tensors_by_shard:
assert sharding.local_shape == local_shape
sharding: ShardedTensor
if not is_main_replica(sharding.replica_id):
# TODO: this checks only saving (and loading replica_id=0) consistency
continue
all_slices.append((sharding.flattened_range.start, sharding.flattened_range.stop))
starts, stops = map(np.asarray, zip(*sorted(all_slices)))
if (
starts[0] != 0
or stops[-1] != np.product(local_shape)
or not np.all(starts[1:] == stops[:-1])
):
logger.error(
f'Flattened ranges dont cover the whole shard {tensors_by_shard[0]}. Ranges: {(starts, stops)}'
)
raise CheckpointingException( raise CheckpointingException(
f'Flattened ranges dont cover the whole shard {tensors_by_shard[0]}' f'Cannot apply async_save to non-async strategy {sharded_strategy}'
) )
async_request = sharded_strategy.async_save(sharded_state_dict, checkpoint_dir)
async_request.finalize_fns.append(metadata_finalize_fn)
return async_request
def validate_objects_sharding_integrity(sharded_objects: List[ShardedObject]): def get_default_save_sharded_strategy(
""" Ensure uniqueness of saved objects. """ backend: str = 'torch_dist', version: int = 1
local_sh_objs = [sh_obj.without_data() for sh_obj in sharded_objects] ) -> SaveShardedStrategy:
all_sh_objs = [None] * torch.distributed.get_world_size() """Get default save sharded strategy."""
torch.distributed.all_gather_object(all_sh_objs, local_sh_objs) return get_default_strategy(StrategyAction.SAVE_SHARDED, backend, version)
if torch.distributed.get_rank() != 0:
return
unique_keys = [ def get_default_save_common_strategy(
sh_obj.unique_key backend: str = 'torch', version: int = 1
for sh_obj in chain.from_iterable(all_sh_objs) ) -> SaveCommonStrategy:
if is_main_replica(sh_obj.replica_id) """Get default save common strategy."""
] return get_default_strategy(StrategyAction.SAVE_COMMON, backend, version)
if len(unique_keys) != len(set(unique_keys)):
duplicates = {k: cnt for k, cnt in Counter(unique_keys).items() if cnt > 1}
logger.error(f'Duplicate ShardedObject keys and counts: {duplicates}') def get_default_load_sharded_strategy(checkpoint_dir: str) -> LoadShardedStrategy:
raise CheckpointingException(f'Duplicate ShardedObject keys: {list(duplicates.keys())}') """Get default load sharded strategy."""
return verify_checkpoint_and_load_strategy(checkpoint_dir)[0]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment