Commit 1106877d authored by jerrrrry's avatar jerrrrry
Browse files

“13.0”

parents
Pipeline #2934 failed with stages
in 0 seconds
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""
Exports:
- FaissBaseIndex: Unoptimized Faiss index wrapper
- FaissParallelAddIndex: Optimized index.add() for Faiss index.
"""
from .faiss_base import FaissBaseIndex
from .faiss_par_add import FaissParallelAddIndex
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""
This class implements a simple, un-optimized wrapper around a Faiss index, that
implements the Index interface (see ..index.py). While this class is
instantiable, it is meant to be extended with optimizations in classes that
inherit from this class (see FaissParAddIndex, for an example).
"""
import os
import numpy as np
import torch
from megatron.core.datasets.retro.config import RetroPreprocessingConfig
from megatron.core.datasets.retro.index.index import Index
from megatron.core.datasets.retro.index.utils import (
get_training_data_merged_path,
num_samples_to_block_ranges,
)
from megatron.core.datasets.retro.utils import GPTToTextDataset, log_retro_rank_0
try:
import faiss
HAVE_FAISS = True
except ImportError:
HAVE_FAISS = False
try:
from tqdm import tqdm
HAVE_TQDM = True
except ImportError:
HAVE_TQDM = False
class FaissBaseIndex(Index):
"""Base class for Faiss-base indexes.
This class wraps a Faiss index, and adds additional functionality for training
and adding codes. This base class performs a naive sequential code adding,
while the optimized FaissParallelAddIndex class performs a parallel
index.add().
"""
def _train(self, config: RetroPreprocessingConfig) -> None:
"""Train index (rank 0's method).
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
"""
if not HAVE_FAISS:
raise ImportError(
"faiss is required to use the FaissBaseIndex class. Please install faiss."
)
assert torch.distributed.get_rank() == 0
# Set num threads (torch.distributed reset it to 1).
faiss.omp_set_num_threads(64)
empty_index_path = self.get_empty_index_path(config)
# Index already exists? -> return.
if os.path.isfile(empty_index_path):
return
# Load data.
merged_path = get_training_data_merged_path(config)
inp = np.memmap(merged_path, dtype="f4", mode="r").reshape((-1, config.hidden_size))
# Init index.
index = faiss.index_factory(config.hidden_size, config.retro_index_str)
# Move to GPU.
log_retro_rank_0("> move faiss index to gpu.")
index_ivf = faiss.extract_index_ivf(index)
clustering_index = faiss.index_cpu_to_all_gpus(faiss.IndexFlatL2(index_ivf.d))
index_ivf.clustering_index = clustering_index
log_retro_rank_0("> finished moving to gpu.")
self.make_object_verbose(index, True)
self.make_object_verbose(index_ivf, True)
self.make_object_verbose(index_ivf.quantizer, True)
self.make_object_verbose(index_ivf.clustering_index, True)
# Train index.
index.train(inp)
# Save index.
faiss.write_index(index, empty_index_path)
def train(self, config: RetroPreprocessingConfig) -> None:
"""Train index.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
"""
# Single process only.
if torch.distributed.get_rank() == 0:
self._train(config)
torch.distributed.barrier()
def _add(self, config: RetroPreprocessingConfig, text_dataset: GPTToTextDataset) -> None:
"""Add to index (rank 0's method).
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
text_dataset (GPTToTextDataset): Text dataset that will be embedded
and added to the index.
"""
if not HAVE_FAISS:
raise ImportError(
"faiss is required to use the FaissBaseIndex class. Please install faiss."
)
if not HAVE_TQDM:
raise ImportError(
"tqdm is required to use the FaissBaseIndex class. Please install tqdm."
)
assert torch.distributed.get_rank() == 0
dataset_sample_ranges = num_samples_to_block_ranges(len(text_dataset))
# Set num threads (torch.distributed reset it to 1).
faiss.omp_set_num_threads(64)
# Bert embedder.
embedder = config.bert_embedders.mem
# Empty/added index paths.
empty_index_path = self.get_empty_index_path()
added_index_path = self.get_added_index_path()
# Skip adding, if index exists.
if os.path.isfile(added_index_path):
return
# Read trained index.
index = faiss.read_index(empty_index_path)
# Iterate data blocks & add.
for sample_range in tqdm(dataset_sample_ranges, "faiss_base.add"):
# Embed text.
embeds = self.embed_text_dataset_block(embedder, text_dataset, sample_range)
# Add to index.
index.add(embeds)
# Write index.
faiss.write_index(index, added_index_path)
def add(self, config: RetroPreprocessingConfig, text_dataset: GPTToTextDataset) -> str:
"""Add to index.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
text_dataset (GPTToTextDataset): Text dataset that will be embedded
and added to the index.
Returns:
File path to the populated index.
"""
# Single process only.
if torch.distributed.get_rank() == 0:
self._add(config, text_dataset)
# Wait for rank 0.
torch.distributed.barrier()
# Get output index path, for return.
return self.get_added_index_path(config)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Multi-process & multi-node version of Faiss's index.add().
This class inherits from FaissBaseIndex, and optimizes the 'add()' method by
making it multi-node and multi-process, with bit-wise equivalence to
FaissBaseIndex. This allows 'add()' to scale out to very large datasets, since
the vast majority of the computational effort is embarrassingly parallel.
"""
import os
import shutil
from typing import Tuple
import numpy as np
import torch
from megatron.core.datasets.retro.config import Embedder, RetroPreprocessingConfig
from megatron.core.datasets.retro.index.utils import get_added_code_paths, get_added_codes_dir
from megatron.core.datasets.retro.utils import (
GPTToTextDataset,
get_blocks_by_rank,
log_retro_rank_0,
retro_makedir,
)
from .faiss_base import FaissBaseIndex
try:
import psutil
HAVE_PSUTIL = True
except ImportError:
HAVE_PSUTIL = False
try:
from tqdm import tqdm
HAVE_TQDM = True
except ImportError:
HAVE_TQDM = False
try:
import h5py
HAVE_H5PY = True
except ImportError:
HAVE_H5PY = False
try:
import faiss
HAVE_FAISS = True
except ImportError:
HAVE_FAISS = False
class FaissParallelAddIndex(FaissBaseIndex):
"""
This class parallelizes both 1) encoding vectors, and 2) adding codes to the
index. This class is more performant than naive use of Faiss, because most
of the computational work is in encoding the vectors, which is an
embarassingly parallel operation.
"""
def encode_block(
self, index: "faiss.Index", embedder: Embedder, text_dataset: GPTToTextDataset, block: dict
) -> Tuple[np.ndarray, np.ndarray]:
"""Encode sub-dataset block, to be later added to index.
Encode the data subset, generally in blocks of 1M vectors each. For
each block, the empty/trained index is loaded, codes are computed
via index.sa_encode(), and the resulting codes are saved to disk.
Args:
index (faiss.Index): Faiss index object.
embedder (Embedder): Embedder used to embed text dataset.
text_dataset (GPTToTextDataset): Text dataset to be embedded and encoded.
block (dict): Range information specifying start/end indices within text dataset.
Returns:
A tuple of (embeddings, encodings) for the given block subset of the text dataset.
"""
# Embed block.
embeddings = self.embed_text_dataset_block(embedder, text_dataset, block["range"])
# Encode block.
log_retro_rank_0("encode.")
codes = index.sa_encode(embeddings)
# Return embeddings for validation purposes.
return embeddings, codes
def save_block(self, config: RetroPreprocessingConfig, block: dict, codes: np.ndarray) -> None:
"""Save block of codes to disk.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
block (dict): Range information specifying the start/end indices within
the encoded text dataset. Here, the 'path' item is used for writing
the encodings to storage.
codes (np.ndarray): Block of encodings to be saved to storage.
"""
# Save neighbors.
log_retro_rank_0("save codes.")
retro_makedir(config, os.path.dirname(block["path"]))
with h5py.File(block["path"], "w") as f:
f.create_dataset("data", data=codes)
def encode(self, config: RetroPreprocessingConfig, text_dataset: GPTToTextDataset) -> None:
"""Encode text dataset, to be later added to index.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
text_dataset (GPTToTextDataset): Text dataset to be encoded by the index.
"""
codes_dir = get_added_codes_dir(config)
retro_makedir(config, codes_dir)
# Index.
index = self.get_empty_index(config)
# Bert embedder.
embedder = config.retro_bert_embedders.mem
# Missing code blocks.
def validate(f: h5py.File) -> None:
"""Validation method for validating loaded encodings.
Args:
f (h5py.File): File that contains encodings.
"""
assert len(f["data"].shape) == 2
blocks = get_blocks_by_rank(
codes_dir, len(text_dataset), config.retro_block_size, validate=validate
)
# Encode each block.
for block_index, block in enumerate(blocks.missing):
if block is not None:
# Progress.
log_retro_rank_0(
"encode block %d / %d ... %s."
% (block_index, len(blocks.missing), block["path"])
)
# Encode and save.
_, codes = self.encode_block(index, embedder, text_dataset, block)
self.save_block(config, block, codes)
# Synchronize progress across all ranks. (for easier observation)
log_retro_rank_0(" > waiting for other ranks to finish block.")
torch.distributed.barrier()
def add_codes(self, config: RetroPreprocessingConfig) -> None:
"""Read codes from disk, and add them to the index.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
"""
if not HAVE_PSUTIL:
raise ImportError(
"psutil is required to use the FaissParallelAddIndex class. Please install psutil."
)
if not HAVE_TQDM:
raise ImportError(
"tqdm is required to use the FaissParallelAddIndex class. Please install tqdm."
)
if not HAVE_FAISS:
raise ImportError(
"faiss is required to use the FaissParallelAddIndex class. Please install faiss."
)
if not HAVE_H5PY:
raise ImportError(
"h5py is required to use the FaissParallelAddIndex class. Please install h5py."
)
if torch.distributed.get_rank() != 0:
return
added_index_path = self.get_added_index_path(config)
if os.path.exists(added_index_path):
return
# Index.
log_retro_rank_0("read empty index.")
index = self.get_empty_index(config)
index_ivf = faiss.extract_index_ivf(index)
# Add codes.
log_retro_rank_0("add codes.")
code_paths = get_added_code_paths(config)
pbar = tqdm(code_paths)
for code_path in pbar:
pbar.set_description(
"add codes, mem %.3f gb, %.1f%%"
% (psutil.virtual_memory()[3] / 1024**3, psutil.virtual_memory()[2])
)
with h5py.File(code_path) as f:
nload = int(config.retro_index_add_load_fraction * f["data"].shape[0])
offset = int(os.path.basename(code_path).split("-")[0])
xids = np.arange(offset, offset + nload)
codes = np.copy(f["data"][:nload])
index_ivf.add_sa_codes(codes, xids)
# Update index's ntotal.
index.ntotal = index_ivf.ntotal
# Write index.
log_retro_rank_0("write added index.")
faiss.write_index(index, added_index_path)
def remove_codes(self, config: RetroPreprocessingConfig) -> None:
"""Remove added codes after adding to index.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
"""
if torch.distributed.get_rank() != 0:
return
assert os.path.isfile(self.get_added_index_path(config))
if config.retro_index_delete_added_codes:
raise Exception("remove?")
shutil.rmtree(get_added_codes_dir(config), ignore_errors=True)
def add(self, config: RetroPreprocessingConfig, text_dataset: GPTToTextDataset) -> None:
"""Add vectors to index.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
text_dataset (GPTToTextDataset): Text dataset that will be embedded
and added to the index.
"""
# Encode chunks.
self.encode(config, text_dataset)
# Add codes to index.
self.add_codes(config)
# Wait for (single-process) adding to complete.
torch.distributed.barrier()
# Remove codes.
self.remove_codes(config)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Utilities for building an index."""
import glob
import os
from typing import List, Tuple
from megatron.core.datasets.retro.config import RetroPreprocessingConfig
from megatron.core.datasets.retro.utils import retro_makedir
def get_index_dir(config: RetroPreprocessingConfig) -> str:
"""Create sub-directory for this index.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
Returns:
Path to index sub-directory within Retro project.
"""
# Directory path.
index_dir_path = os.path.join(
config.retro_project_dir, "index", config.retro_index_type, config.retro_index_str
)
# Make directory.
retro_makedir(config, index_dir_path)
return index_dir_path
def num_samples_to_block_ranges(
config: RetroPreprocessingConfig, num_samples: int
) -> List[Tuple[int, int]]:
"""Split a range (length num_samples) into sequence of block ranges
of size block_size.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
num_samples (int): Split `num_samples` into consecutive block ranges, where each block is size `config.retro_block_size`.
Returns:
A list of tuples where each item is the (start, end) index for a given block.
"""
block_size = config.retro_block_size
start_idxs = list(range(0, num_samples, block_size))
end_idxs = [min(num_samples, s + block_size) for s in start_idxs]
ranges = list(zip(start_idxs, end_idxs))
return ranges
def get_training_data_root_dir(config: RetroPreprocessingConfig) -> str:
"""Get root directory for embeddings (blocks and merged data).
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
Returns:
Path to the training data directory, which contains both training embedding blocks and the final merged training embeddings.
"""
return os.path.join(config.retro_project_dir, "index", "train_emb")
def get_training_data_block_dir(config: RetroPreprocessingConfig) -> str:
"""Get directory for of saved embedding blocks.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
Returns:
Path to the directory containing the training embedding blocks, which will be later merged into a single embedding array.
"""
return os.path.join(get_training_data_root_dir(config), "blocks")
def get_training_data_block_paths(config: RetroPreprocessingConfig) -> List[str]:
"""Get paths to saved embedding blocks.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
Returns:
Paths of all training embedding blocks.
"""
return sorted(glob.glob(get_training_data_block_dir(config) + "/*.hdf5"))
def get_training_data_merged_path(config: RetroPreprocessingConfig) -> str:
"""Get path to merged training embeddings.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
Returns:
Path to the merged training embedding binary file.
"""
return os.path.join(
get_training_data_root_dir(config),
"train_%.3f.bin" % config.retro_index_train_load_fraction,
)
def get_added_codes_dir(config: RetroPreprocessingConfig) -> str:
"""Get directory of saved encodings.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
Returns:
Path to the directory containing the vector encodings for adding to the index.
"""
return os.path.join(get_index_dir(config), "add_codes")
def get_added_code_paths(config: RetroPreprocessingConfig) -> List[str]:
"""Get paths to all saved encodings.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
Returns:
Paths of all vector encoding blocks, for adding to the index.
"""
return sorted(glob.glob(get_added_codes_dir(config) + "/*.hdf5"))
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Validate an index's data.
This module contains functionality for checking for bitwise equality across code
changes. The training and adding steps of index construction can be validated
separately. The following high-level checks are supported:
- Training: Validate that saved training embeddings are bitwise equal with a
sample set of freshly computed embeddings. (*Note*:
`--no-retro-index-delete-training-embeddings` must be used.)
- Adding: Validate that the saved encodings are bitwise equal with a sample of
sample set of freshly computed encodings. (*Note*:
`--no-retro-index-delete-added-codes` must be used.)
"""
import numpy as np
import torch
from torch.utils.data import Subset
from megatron.core.datasets.retro.config import RetroPreprocessingConfig
from megatron.core.datasets.retro.utils import get_blocks_by_rank, log_retro_rank_0
from .build import get_text_dataset_for_adding, get_text_dataset_for_training
from .factory import IndexFactory
from .utils import get_added_codes_dir, get_training_data_block_dir
try:
import h5py
HAVE_H5PY = True
except ImportError:
HAVE_H5PY = False
##################################################
# Validate trained index.
##################################################
def validate_training_embeddings(config: RetroPreprocessingConfig) -> None:
"""Validate training embeddings.
Steps:
- Randomly sample subset of text dataset blocks.
- Embed each block.
- Compare against saved embeddings.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
"""
if not HAVE_H5PY:
raise ImportError(
"h5py is required to use the validate_training_embeddings function. "
"Please install h5py."
)
# Training text dataset.
text_dataset = get_text_dataset_for_training(config)
# Sample existing blocks.
blocks = get_blocks_by_rank(
dirname=get_training_data_block_dir(config),
n_samples=len(text_dataset),
block_size=config.retro_block_size,
validate=None,
sample=config.retro_task_validate,
)
assert blocks.n_missing_world == 0
# Embed & validate blocks.
embedder = config.retro_bert_embedders.mem
for block_idx, block in enumerate(blocks.existing):
# Missing block lists are extended with None to have equal-length
# lists. Skip the Nones.
if block is not None:
# Progress. (*note*: move world progress to here.)
log_retro_rank_0(
"embed training block %d / %d ... %s."
% (block_idx, len(blocks.existing), block["path"])
)
# Load existing block embeddings.
with h5py.File(block["path"]) as f:
existing_embeddings = np.copy(f["data"])
# Embed block.
sub_dataset = Subset(text_dataset, range(*block["range"]))
embeddings = embedder.embed_text_dataset(sub_dataset, "train")
# Check equality.
log_retro_rank_0(" > validate.")
assert np.array_equal(existing_embeddings, embeddings)
# Synchronize progress across all ranks. (for easier observation)
log_retro_rank_0(" > waiting for other ranks to finish block.")
torch.distributed.barrier()
log_retro_rank_0(" > finished validating training embeddings.")
##################################################
# Validate filled index.
##################################################
def validate_added_encodings(config: RetroPreprocessingConfig) -> None:
"""Validate added encodings.
Steps:
- Randomly sample subset of text dataset blocks.
- Encode each block.
- Compare against saved encodings.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
"""
# Index.
index = IndexFactory.get_index(config.retro_index_type)
inner_index = index.get_empty_index(config)
# Text dataset.
text_dataset = get_text_dataset_for_adding(config)
# Sample existing blocks.
def validate(f: h5py.File) -> None:
"""Validation method for validating encoding blocks.
Args:
f (h5py.File): File with block of encodings.
"""
assert len(f["data"].shape) == 2
blocks = get_blocks_by_rank(
dirname=get_added_codes_dir(config),
n_samples=len(text_dataset),
block_size=config.retro_block_size,
validate=validate,
sample=config.retro_task_validate,
)
assert blocks.n_missing_world == 0
# Encode and validate blocks.
embedder = config.retro_bert_embedders.mem
for block_idx, block in enumerate(blocks.existing):
if block is not None:
# Progress.
log_retro_rank_0(
"encode block %d / %d ... %s." % (block_idx, len(blocks.existing), block["path"])
)
# Load existing codes.
with h5py.File(block["path"]) as f:
existing_codes = np.copy(f["data"])
# Encode block.
embeddings, codes = index.encode_block(inner_index, embedder, text_dataset, block)
# Check equality.
log_retro_rank_0(" > validate.")
assert np.array_equal(existing_codes, codes)
# Synchronize progress across all ranks. (for easier observation)
log_retro_rank_0(" > waiting for other ranks to finish block.")
torch.distributed.barrier()
log_retro_rank_0(" > finished validating added encodings.")
##################################################
# Validate index (trained + filled).
##################################################
def validate_index(config: RetroPreprocessingConfig) -> None:
"""Validate index.
Validating index involves sequentially running stages above:
- Validate trained index.
- Validate filled index.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
"""
# Validate training embeddings.
validate_training_embeddings(config)
# Validate added codes.
validate_added_encodings(config)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""
A GPTChunkDataset is a wrapper around a regular GPTDataset, that sequentially
chunks the sample tokens into `retro_chunk_length` sized smaller samples.
For example, if the GPTDataset has 100 samples and a sequence length of 2048, and
retro_chunk_length is 64, then the GPTChunkDataset will contain 100*(2048/64) =
3200 samples, each with length 64.
"""
import torch
from megatron.core.datasets.gpt_dataset import GPTDataset
from megatron.core.datasets.retro.utils import get_num_chunks_per_sample
from .utils import get_neighbor_dir
class GPTChunkDataset(torch.utils.data.Dataset):
"""Pretraining chunk dataset wraps a standard GPT dataset.
This dataset conceptually divides each sample (e.g., length 2048)
into chunks (e.g., length 64) and restructures them into a list of
chunks (e.g., length num_samples * num_chunks_per_sample).
Args:
sample_dataset (GPTDataset): Original GPT dataset, with `sequence_length` size samples.
sample_length (int): Alias for `sequence_length`.
chunk_length (int): Retro chunk length (e.g., 64).
"""
def __init__(self, sample_dataset: GPTDataset, sample_length: int, chunk_length: int):
super().__init__()
self.sample_dataset = sample_dataset
self.chunk_length = chunk_length
self.n_chunks_per_sample = get_num_chunks_per_sample(sample_length, chunk_length)
self.n_samples = len(sample_dataset)
self.n_chunks = self.n_samples * self.n_chunks_per_sample
def __len__(self) -> int:
"""Get dataset length.
Returns:
Dataset length.
"""
return self.n_chunks
def __getitem__(self, idx: int) -> dict:
"""Get sample, including represented document IDs.
Args:
idx (int): Sample index.
Returns:
A sample, which contains both the chunk-length token sample ('text') along with all document_ids ('doc_ids') contained withing the full `sequence_length` sample.
"""
# Convert global chunk index to global sample index & local chunk index.
sample_idx = idx // self.n_chunks_per_sample
chunk_idx = idx % self.n_chunks_per_sample
# Extract sample data.
sample = self.sample_dataset[sample_idx]
sample_token_ids = sample["text"]
sample_doc_ids = sample["document_ids"]
# Chunk start/end token idxs.
token_start_idx = chunk_idx * self.chunk_length
token_end_idx = token_start_idx + self.chunk_length
chunk_token_ids = sample_token_ids[token_start_idx:token_end_idx]
# Sample.
return {"doc_ids": sample_doc_ids, "text": chunk_token_ids}
def build_gpt_chunk_datasets_from_gpt_datasets(
project_dir: str, gpt_datasets: dict, sample_length: int, chunk_length: int
) -> dict:
"""Get train, valid, test GPT chunk datasets.
Args:
project_dir (str): Retro project dir.
gpt_datasets (dict): Mapping of 'train', 'valid', and 'test' GPT datasets (original, unchunked datasets).
sample_length (int): Alias of `sequence_length`.
chunk_length (int): Retro chunk length (e.g., 64).
Returns:
A <dict> ?
"""
# GPT chunk datasets.
chunk_datasets = {
key: (
{
"dataset": GPTChunkDataset(sample_ds, sample_length, chunk_length),
"neighbor_dir": get_neighbor_dir(project_dir, key, sample_ds),
"num_active_chunks": num_active_samples
* get_num_chunks_per_sample(sample_length, chunk_length),
}
if sample_ds
else None
)
for key, (sample_ds, num_active_samples) in gpt_datasets.items()
}
return chunk_datasets
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""A MultiSplitGPTDataset can handle multiple intersecting split strings, as well
as returning all of the document IDs of a sample."""
import logging
from dataclasses import dataclass
from typing import Dict, List
import numpy
from megatron.core.datasets.blended_megatron_dataset_config import (
convert_split_vector_to_split_matrix,
parse_and_normalize_split,
)
from megatron.core.datasets.gpt_dataset import GPTDataset, GPTDatasetConfig
from megatron.core.datasets.indexed_dataset import IndexedDataset
from megatron.core.datasets.utils import Split
from megatron.core.utils import log_single_rank
logger = logging.getLogger(__name__)
@dataclass
class MultiSplitGPTDatasetConfig(GPTDatasetConfig):
"""Configuration object for Megatron Core blended and Retro datasets.
Args:
return_document_ids (bool): Whether to return the document ids when querying the dataset.
Turn this option on during preprocessing.
split_preprocessing (str): The Retro preprocessing split string.
It follows the same pattern convention as 'split'.
Not to be used with 'blend_per_split'.
"""
return_document_ids: bool = None
split_preprocessing: str = None
def __post_init__(self) -> None:
"""Validate config attributes."""
super().__post_init__()
assert self.split is not None, "the Retro data pipeline does not support 'blend_per_split'"
assert self.return_document_ids is not None, "this attribute must be user defined"
assert self.split_preprocessing is not None, "this attribute must be user defined"
split_vector = parse_and_normalize_split(self.split)
split_preprocessing_vector = parse_and_normalize_split(self.split_preprocessing)
if not numpy.allclose(split_vector, split_preprocessing_vector):
self.split_matrix = convert_split_vector_to_split_matrix(
split_vector, split_preprocessing_vector
)
log_single_rank(
logger,
logging.WARNING,
f"split =/= split_preprocessing. Let split_matrix = {self.split_matrix}",
)
class MultiSplitGPTDataset(GPTDataset):
"""Retro's customized GPT dataset.
Args:
indexed_dataset (IndexedDataset): The IndexedDataset around which
to build the MegatronDataset.
dataset_path (str): The real path on disk to the dataset, for bookkeeping.
indexed_indices (numpy.ndarray): The set of the documents indices to expose.
num_samples (int): The number of samples to draw from the indexed dataset.
index_split (Split): The indexed_indices Split.
config (MultiSplitGPTDatasetConfig): The Retro-specific container for all
config sourced parameters.
"""
def __init__(
self,
indexed_dataset: IndexedDataset,
dataset_path: str,
indexed_indices: numpy.ndarray,
num_samples: int,
index_split: Split,
config: MultiSplitGPTDatasetConfig,
) -> None:
super().__init__(
indexed_dataset, dataset_path, indexed_indices, num_samples, index_split, config
)
def __getitem__(self, idx: int) -> Dict[str, numpy.ndarray]:
"""Get dataset sample.
Args:
idx (int): The index into the dataset.
Returns:
Dict[str, numpy.ndarray]: The text ids and (optionally)
the document ids wrapped in a dictionary.
"""
text, document_ids = self._query_document_sample_shuffle_indices(idx)
if self.config.return_document_ids:
return {"text": text, "document_ids": document_ids}
else:
return {"text": text}
@staticmethod
def _key_config_attributes() -> List[str]:
"""Add custom attributes for building unique dataset hash.
The preprocessing split used for preprocessing will constrain
the samples available for pretraining.
Returns:
List[str]: The key config attributes.
"""
return super(MultiSplitGPTDataset, MultiSplitGPTDataset)._key_config_attributes() + [
"split_preprocessing"
]
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Entry point for querying an index using a GPTChunkDataset.
Querying involves:
- Iterate all chunks in the GPTChunkDataset.
- Query index for neighbor chunk IDs (i.e., chunks from the chunk database).
- Save neighbor chunk IDs to disk, for use in building a RetroDataset sample
during pretraining.
"""
import os
import time
import typing
import numpy as np
import torch
from megatron.core.datasets.retro.config import RetroPreprocessingConfig
from megatron.core.datasets.retro.db.dataset import DBDataset
from megatron.core.datasets.retro.db.utils import (
get_merged_train_dataset as get_db_merged_train_dataset,
)
from megatron.core.datasets.retro.index.factory import IndexFactory
from megatron.core.datasets.retro.index.index import Index
from megatron.core.datasets.retro.index.utils import get_index_dir
from megatron.core.datasets.retro.query.gpt_chunk_dataset import GPTChunkDataset
from megatron.core.datasets.retro.utils import (
GPTToTextDataset,
get_blocks_by_rank,
log_retro_rank_0,
retro_makedir,
)
try:
import psutil
HAVE_PSUTIL = True
except ImportError:
HAVE_PSUTIL = False
try:
from tqdm import tqdm
HAVE_TQDM = True
except ImportError:
HAVE_TQDM = False
try:
import h5py
HAVE_H5PY = True
except ImportError:
HAVE_H5PY = False
try:
import faiss
HAVE_FAISS = True
except ImportError:
HAVE_FAISS = False
def get_index(config: RetroPreprocessingConfig, ondisk: bool = False) -> "faiss.Index":
"""Read index from disk.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
ondisk (bool): If `ondisk = True`, memory map the index.
(For debugging purposes only; very non-performant.)
Returns:
A Faiss index, loaded from storage.
"""
if not HAVE_FAISS:
raise ImportError(
"faiss is required to use the query_neighbors function. " "Please install faiss."
)
# Load index.
index_wrapper = IndexFactory.get_index(config.retro_index_type)
index_dir = get_index_dir(config)
added_index_path = index_wrapper.get_added_index_path(config)
if ondisk:
index = faiss.read_index(added_index_path, faiss.IO_FLAG_MMAP)
else:
index = faiss.read_index(added_index_path)
# Search parameters.
faiss.ParameterSpace().set_index_parameter(index, "efSearch", config.retro_query_ef_search)
faiss.ParameterSpace().set_index_parameter(index, "nprobe", config.retro_query_nprobe)
return index
def embed_block(
config: RetroPreprocessingConfig, gpt_dataset: GPTChunkDataset, block: dict
) -> np.ndarray:
"""Embed block of chunks.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
gpt_dataset (GPTChunkDataset): Chunk dataset to be embedded.
block (dict): Range information containing start/end indices of subset of chunk dataset.
Returns:
Embeddings array, with shape (len(block["range"]), dimension(embedder)).
"""
text_block_dataset = torch.utils.data.Subset(
GPTToTextDataset(gpt_dataset, config.retro_tokenizers.gpt), range(*block["range"])
)
return config.retro_bert_embedders.mem.embed_text_dataset(text_block_dataset)
def query_embeddings(
config: RetroPreprocessingConfig,
db_dataset: DBDataset,
index: Index,
embeddings: np.ndarray,
chunk_id_range: range,
sample_map: dict,
n_chunks_per_sample: int,
verbose: bool = True,
) -> typing.Tuple[np.ndarray, np.ndarray]:
"""Query neighbors of a block of embeddings.
Querying includes:
- Query index for neighbor chunk IDs.
- Filter chunk IDs that have the same document ID as the queried embedding.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
db_dataset (DBDataset): Dataset containing chunk database entries.
index (Index): Vector index populated with chunk database indices.
embeddings (np.ndarray): Embeddings from GPT chunk dataset.
chunk_id_range (range): Chunk ID range from GPT chunk dataset.
sample_map (dict): Mapping of sample_idx to dataset_idx and document_ids.
Used for document filtering.
n_chunks_per_sample (int): Number of chunks per sample
(e.g., sequence_length / chunk_length).
verbose (bool): Log querying progress.
Returns:
A tuple of original (unfiltered) neighbor IDs, and filtered (by document ID) neighbor IDs.
"""
# Query neighbor ids.
if verbose:
log_retro_rank_0("search.")
t = time.time()
assert index.ntotal > 0, "check we don't accidentally have an empty index."
_, query_neighbor_ids = index.search(embeddings, config.retro_query_num_neighbors_query)
if verbose:
log_retro_rank_0(" time : %.3f sec." % (time.time() - t))
# Filter banned neighbor ids.
if verbose:
log_retro_rank_0("filter banned neighbor ids.")
filtered_neighbor_ids = np.full(
shape=(len(query_neighbor_ids), config.retro_query_num_neighbors_save),
fill_value=-1,
dtype="int64",
)
min_chunk_id, max_chunk_id = chunk_id_range
for chunk_id in range(min_chunk_id, max_chunk_id):
sample_id = chunk_id // n_chunks_per_sample
sample = sample_map[sample_id]
sample_dataset_idx = sample["dataset_idx"].item()
sample_doc_ids = sample["doc_ids"].tolist()
sample_doc_tuples = [(sample_dataset_idx, d) for d in sample_doc_ids]
# Get valid neighbors (!= -1).
query_row = [i for i in query_neighbor_ids[chunk_id - min_chunk_id] if i >= 0]
# Filter row.
filtered_row = [
i
for i in query_row
if tuple(db_dataset.doc_tuples[i].tolist()) not in sample_doc_tuples
]
filtered_row = filtered_row[: config.retro_query_num_neighbors_save]
filtered_row += [-1] * (config.retro_query_num_neighbors_save - len(filtered_row))
filtered_neighbor_ids[chunk_id - min_chunk_id] = filtered_row
return query_neighbor_ids, filtered_neighbor_ids
def query_embedding_block(
config: RetroPreprocessingConfig,
db_dataset: DBDataset,
index: Index,
embeddings: np.ndarray,
chunk_id_range: range,
sample_map: dict,
n_chunks_per_sample: int,
) -> typing.Tuple[np.ndarray, np.ndarray]:
"""Query a block of embeddings.
The block is broken into smaller sub-blocks, for easier tracking of progress.
Both the raw neighbor IDs and the filtered neighbor IDs (i.e., chunks with the
same document ID are removed) are collected.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
db_dataset (DBDataset): Dataset containing chunk database entries.
index (Index): Vector index populated with chunk database indices.
embeddings (np.ndarray): Embeddings from GPT chunk dataset.
chunk_id_range (range): Chunk ID range from GPT chunk dataset.
sample_map (dict): Mapping of sample_idx to dataset_idx and document_ids.
Used for document filtering.
n_chunks_per_sample (int): Number of chunks per sample
(e.g., sequence_length / chunk_length).
Returns:
A tuple of original (unfiltered) neighbor IDs, and filtered (by document ID) neighbor IDs.
"""
if not HAVE_TQDM:
raise ImportError(
"tqdm is required to use the query_embeddings function. Please install tqdm."
)
query_neighbor_ids = []
filtered_neighbor_ids = []
# Query in sub-blocks.
partial_block_size = 1000
for partial_start_idx in tqdm(
range(0, len(embeddings), partial_block_size),
" search",
miniters=(len(embeddings) // partial_block_size) // 10,
disable=torch.distributed.get_rank() != 0,
):
partial_end_idx = min(len(embeddings), partial_start_idx + partial_block_size)
partial_embeddings = embeddings[partial_start_idx:partial_end_idx]
partial_chunk_id_range = (
chunk_id_range[0] + partial_start_idx,
chunk_id_range[0] + partial_end_idx,
)
partial_query_neighbor_ids, partial_filtered_neighbor_ids = query_embeddings(
config,
db_dataset,
index,
partial_embeddings,
partial_chunk_id_range,
sample_map,
n_chunks_per_sample,
verbose=False,
)
query_neighbor_ids.append(partial_query_neighbor_ids)
filtered_neighbor_ids.append(partial_filtered_neighbor_ids)
# Concatenate.
query_neighbor_ids = np.concatenate(query_neighbor_ids, axis=0)
filtered_neighbor_ids = np.concatenate(filtered_neighbor_ids, axis=0)
return query_neighbor_ids, filtered_neighbor_ids
def query_block_neighbors(
config: RetroPreprocessingConfig,
db_dataset: DBDataset,
query_dataset: GPTChunkDataset,
index: Index,
block: dict,
) -> None:
"""Query neighbors of a dataset block (i.e., range).
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
db_dataset (DBDataset): Dataset containing chunk database entries.
query_dataset (GPTChunkDataset): GPT chunk dataset to be queried.
index (Index): Vector index populated with chunk database indices.
block (dict): Range information containing start/end indices
for querying GPT chunk dataset.
"""
if not HAVE_H5PY:
raise ImportError(
"h5py is required to use the query_block_neighbors function. Please install h5py."
)
n_chunks_per_sample = query_dataset.n_chunks_per_sample
# Sample map.
sample_ids = sorted(
list(set(chunk_id // n_chunks_per_sample for chunk_id in range(*block["range"])))
)
sample_map = {}
for i in sample_ids:
sample = query_dataset.sample_dataset[i]
sample_map[i] = {"dataset_idx": sample["dataset_id"], "doc_ids": sample["document_ids"]}
# Embed block.
embeddings = embed_block(config, query_dataset, block)
# Query embeddings.
_, filtered_neighbor_ids = query_embedding_block(
config, db_dataset, index, embeddings, block["range"], sample_map, n_chunks_per_sample
)
if config.retro_task_validate is None:
# Save neighbors.
log_retro_rank_0("save neighbors.")
retro_makedir(config, os.path.dirname(block["path"]))
f = h5py.File(block["path"], "w")
f.create_dataset("neighbors", data=filtered_neighbor_ids)
f.close()
else:
# Validate neighbors.
with h5py.File(block["path"]) as f:
existing_neighbor_ids = np.copy(f["neighbors"])
assert np.array_equal(existing_neighbor_ids, filtered_neighbor_ids)
def query_dataset_neighbors(
config: RetroPreprocessingConfig,
db_dataset: DBDataset,
query_dataset: GPTChunkDataset,
num_active_chunks: int,
prefix: str,
neighbor_dir: str,
index: Index,
) -> None:
"""Query neighbors of each chunk within a dataset.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
db_dataset (DBDataset): Dataset containing chunk database entries.
query_dataset (GPTChunkDataset): GPT chunk dataset to be queried.
num_active_chunks (int): The 'active' chunks are the subset of the GPT chunk dataset
that aren't being queried. This argument is used when validating the correctness
of a subset of the GPT chunk dataset.
prefix (str): Extra string for logging progress.
neighbor_dir (str): File path to directory for saving neighbor IDs.
index (Index): Vector index populated with chunk database indices.
"""
if not HAVE_H5PY:
raise ImportError(
"h5py is required to use the query_dataset_neighbors function. Please install h5py."
)
def validate(f: h5py.File) -> None:
"""Validation method for validating saved neighbor IDs.
Args:
f (h5py.File): File containing save neighbor IDs.
"""
assert (
f["neighbors"].shape[1] == config.retro_query_num_neighbors_save
), "neighbors.shape == %s; num_neighbors_target == %d." % (
str(f["neighbors"].shape),
config.retro_num_neighbors_target,
)
if config.retro_task_validate is None:
retro_makedir(config, neighbor_dir)
blocks = get_blocks_by_rank(
neighbor_dir, num_active_chunks, config.retro_block_size, validate=validate
)
active_blocks = blocks.missing
else:
blocks = get_blocks_by_rank(
neighbor_dir,
num_active_chunks,
config.retro_block_size,
validate=validate,
sample=config.retro_task_validate,
)
assert blocks.n_missing_world == 0
active_blocks = blocks.existing
if not HAVE_PSUTIL:
raise ImportError(
"psutil is required to use the query_dataset_neighbors function. Please install psutil."
)
# Query each block.
for block_index, block in enumerate(active_blocks):
if block is not None:
# Progress.
log_retro_rank_0(
"%squery '%s' block %d / %d ... %s ... mem %.3f gb, %.1f%%."
% (
"" if config.retro_task_validate is None else "[validate] ",
prefix,
block_index,
len(active_blocks),
os.path.basename(block["path"]),
psutil.virtual_memory()[3] / 1024**3,
psutil.virtual_memory()[2],
)
)
# Query block neighbors.
query_block_neighbors(config, db_dataset, query_dataset, index, block)
# Synchronize progress across all ranks. (for easier observation)
log_retro_rank_0(" > waiting for other ranks to finish block.")
torch.distributed.barrier()
def query_neighbors(config: RetroPreprocessingConfig) -> None:
"""Query pretraining datasets (train & valid).
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
"""
if not HAVE_FAISS:
raise ImportError(
"faiss is required to use the query_neighbors function. Please install faiss."
)
# Num threads.
faiss.omp_set_num_threads(64)
# Load chunk db dataset.
log_retro_rank_0("load chunk db dataset.")
db_dataset = get_db_merged_train_dataset(
project_dir=config.retro_project_dir,
chunk_length=config.retro_gpt_chunk_length,
eod_token_id=config.retro_tokenizers.gpt.eod,
)
db_dataset.load_doc_tuples()
# Load index.
log_retro_rank_0(" > get index.")
index = get_index(config)
# Query each (i.e., train, valid, test) dataset.
log_retro_rank_0(" > query.")
for prefix, info in vars(config.retro_gpt_chunk_datasets).items():
if info is None:
continue
log_retro_rank_0(
" > query '%s' dataset ... %d samples." % (prefix, info["num_active_chunks"])
)
query_dataset_neighbors(
config,
db_dataset,
info["dataset"],
info["num_active_chunks"],
prefix,
info["neighbor_dir"],
index,
)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""
A RetroDataset wraps both:
- A GPTDataset (which is nested as GPTChunkDataset -> MultiSplitGPTDataset ->
GPTDataset).
- Neighbor IDs of chunks in the chunk database, that were saved during
preprocessing.
Both the GPT sample data and the neighbor IDs are returned within a sample from
this dataset.
"""
import os
from typing import Dict, Optional, Tuple
import numpy as np
import torch
from megatron.core.datasets.retro.db.dataset import DBDataset
from megatron.core.datasets.retro.db.utils import get_merged_train_dataset as get_db_dataset
from megatron.core.datasets.retro.utils import BlockPathMap, log_retro_rank_0
from megatron.core.models.retro import RetroConfig
from .gpt_chunk_dataset import GPTChunkDataset, build_gpt_chunk_datasets_from_gpt_datasets
from .utils import get_query_dir
try:
import h5py
HAVE_H5PY = True
except ImportError:
HAVE_H5PY = False
class RetroDataset(torch.utils.data.Dataset):
"""Dataset of retro samples.
Each sample contains the original GPT sample, along with the token IDs
of each neighbor of each chunk within the sequence. Neighbor array has
shape (num_chunks_per_sample, num_neighbors, num_retrieved_tokens).
** Note: chunk dataset wraps original GPT dataset (see gpt_chunk_dataset.py).
Args:
num_queried_samples (int): Total number of queried samples.
num_neighbors (int): Total number of saved neighbors.
num_retrieved_chunks (int): Number of retrieved chunks
(e.g., 2 for neighbor + continuation).
block_size (int): Number of neighbor entries per file.
db_dataset (DBDataset): Chunk database used for retrieval.
chunk_dataset (GPTChunkDataset): GPT chunk dataset, which is a wrapper
around a standard GPT dataset that breaks each sample into chunks.
neighbor_path_map (BlockPathMap): Mapping of neighbor ID to file path.
"""
def __init__(
self,
num_queried_samples: int,
num_neighbors: int,
num_retrieved_chunks: int,
block_size: int,
db_dataset: DBDataset,
chunk_dataset: GPTChunkDataset,
neighbor_path_map: BlockPathMap,
):
super().__init__()
self.num_queried_samples = num_queried_samples
self.num_neighbors = num_neighbors
self.num_retrieved_chunks = num_retrieved_chunks
self.block_size = block_size
self.db_dataset = db_dataset
self.chunk_dataset = chunk_dataset
self.neighbor_path_map = neighbor_path_map
def __len__(self) -> int:
"""Dataset length.
Returns:
Number of samples in dataset.
"""
return len(self.chunk_dataset.sample_dataset)
def __getitem__(self, sample_idx: int) -> dict:
"""Get dataset sample.
Args:
sample_idx (int): Index of sample in dataset.
Returns:
A dict consisting of GPT sample (attribute 'text') and corresponding neighbor chunk IDs
('neighbor_chunks', for indexing chunk database) and neighbor token IDs
(corresponding chunk database GPT tokens).
"""
if not HAVE_H5PY:
raise ImportError("h5py is required to use the RetroDataset. Please install h5py.")
n_chunks_per_sample = self.chunk_dataset.n_chunks_per_sample
# Wrap sample idx around number of queried samples.
sample_idx = sample_idx % self.num_queried_samples
# Get standard sample.
sample = self.chunk_dataset.sample_dataset[sample_idx]
# Sample idx to chunk idxs.
chunk_idxs = list(
range(sample_idx * n_chunks_per_sample, (sample_idx + 1) * n_chunks_per_sample)
)
# Collect retrieved tokens.
all_retrieved_chunk_ids = []
all_retrieved_token_ids = []
for chunk_idx in chunk_idxs:
# Neighbor chunk ids.
neighbor_path = self.neighbor_path_map[chunk_idx]
with h5py.File(neighbor_path, "r") as f:
neighbor_chunk_ids = f["neighbors"][
chunk_idx % self.block_size, : self.num_neighbors
].tolist()
# Retrieved (neighbor + continuation) token ids.
retrieved_chunk_ids = []
retrieved_token_ids = []
for neighbor_chunk_id in neighbor_chunk_ids:
current_chunk_ids = [
i % len(self.db_dataset)
for i in range(neighbor_chunk_id, neighbor_chunk_id + self.num_retrieved_chunks)
]
current_token_ids = [self.db_dataset[ci]["text"] for ci in current_chunk_ids]
retrieved_chunk_ids.append(current_chunk_ids)
retrieved_token_ids.append(current_token_ids)
# Collect retrieved tokens.
all_retrieved_chunk_ids.append(retrieved_chunk_ids)
all_retrieved_token_ids.append(retrieved_token_ids)
# Reshape retrieved tokens.
all_retrieved_chunk_ids = np.array(all_retrieved_chunk_ids).reshape(
(n_chunks_per_sample, self.num_neighbors, -1)
)
all_retrieved_token_ids = np.array(all_retrieved_token_ids).reshape(
(n_chunks_per_sample, self.num_neighbors, -1)
)
# Sample.
sample: Dict[str, np.ndarray] = {
**sample,
"neighbor_chunks": all_retrieved_chunk_ids,
"neighbor_tokens": all_retrieved_token_ids,
}
return sample
def get_retro_datasets(
config: RetroConfig, gpt_datasets: dict, sample_length: int, eod_token_id: int
) -> Tuple[Optional[RetroDataset], Optional[RetroDataset], Optional[RetroDataset]]:
"""Get train, valid, test retro datasets.
Args:
config (RetroConfig): Retro preprocessing config.
gpt_datasets (dict): Mapping of data split key
('train', 'valid', or 'test') to the original sequence-length
GPT dataset (i.e., not the chunk dataset).
sample_length (int): Alias to `sequence_length`.
eod_token_id (int): GPT EOD token ID.
Returns:
A tuple of 'train', 'valid', and 'test' `RetroDataset`s.
"""
# DB dataset.
db_dataset = get_db_dataset(
project_dir=config.retro_project_dir,
chunk_length=config.retro_chunk_length,
eod_token_id=eod_token_id,
)
# GPT chunk datasets.
chunk_ds_info_map = build_gpt_chunk_datasets_from_gpt_datasets(
project_dir=config.retro_project_dir,
gpt_datasets=gpt_datasets,
sample_length=sample_length,
chunk_length=config.retro_chunk_length,
)
# Retro datasets.
retro_dataset_map: Dict[str, Optional[RetroDataset]] = {}
query_dir = get_query_dir(config.retro_project_dir)
for data_key, chunk_ds_info in chunk_ds_info_map.items():
# Skip unused datasets.
if chunk_ds_info is None:
retro_dataset_map[data_key] = None
continue
# For consistency with preprocessing, the neighbor_dir is overwritten
# (from its setting in `build_gpt_chunk_datasets_from_gpt_datasets()`
# above). This is one piece -- along with setting data_path and
# train_samples from config.json -- of ensuring consistency between
# preprocessing and pretraining.
chunk_dataset = chunk_ds_info["dataset"]
chunk_ds_info["neighbor_dir"] = os.path.join(
query_dir, config.retro_neighbor_dirs[data_key]
)
neighbor_dir = chunk_ds_info["neighbor_dir"]
neighbor_path_map = BlockPathMap.from_dir(
dir=neighbor_dir, block_size=config.retro_block_size
)
# Verify num chunks.
n_active_chunks = chunk_ds_info["num_active_chunks"]
n_neighbor_chunks = neighbor_path_map.max_idx
if not os.path.isdir(neighbor_dir):
if torch.distributed.get_rank() == 0:
raise Exception(
"neighbor directory '%s' not found; please "
"compare --train-samples, --seq-length, --seed, "
"--eval-iters, and --eval-interval, with "
"retro preprocessing args." % neighbor_dir
)
torch.distributed.barrier()
exit()
if config.retro_verify_neighbor_count and n_active_chunks != n_neighbor_chunks:
if torch.distributed.get_rank() == 0:
log_retro_rank_0("neighbor_dir : %s" % neighbor_dir)
log_retro_rank_0("neighbor_path_map : %s" % neighbor_path_map)
raise Exception(
"num sampled chunks (%d) != num neighbor chunks "
"(%d); did you complete querying the entire "
"pretraining dataset?" % (n_active_chunks, n_neighbor_chunks)
)
torch.distributed.barrier()
exit()
# Retro dataset.
retro_dataset_map[data_key] = RetroDataset(
num_queried_samples=gpt_datasets[data_key][1],
num_neighbors=config.retro_num_neighbors,
num_retrieved_chunks=config.retro_num_retrieved_chunks,
block_size=config.retro_block_size,
db_dataset=db_dataset,
chunk_dataset=chunk_dataset,
neighbor_path_map=neighbor_path_map,
)
return (retro_dataset_map["train"], retro_dataset_map["valid"], retro_dataset_map["test"])
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Utilities for querying the pretraining dataset."""
import os
from megatron.core.datasets.megatron_dataset import MegatronDataset
def get_query_dir(project_dir: str) -> str:
"""Get root directory of all saved query data.
Args:
project_dir (str): Retro project dir.
Returns:
Path to query sub-directory in Retro project.
"""
return os.path.join(project_dir, "query")
def get_neighbor_dir(project_dir: str, key: str, dataset: MegatronDataset) -> str:
"""Get directory containing neighbor IDs for a dataset (i.e., train, valid, or test).
Args:
project_dir (str): Retro project dir.
key (str): Dataset split key; 'train', 'valid', or 'test'.
dataset (MegatronDataset): Dataset containing unique hash for finding corresponding neighbors.
Returns:
Path to directory containing this dataset's neighbors within Retro project.
"""
return os.path.join(
get_query_dir(project_dir), os.path.basename(f"{key}_{dataset.unique_description_hash}")
)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Utilities for Retro preprocessing."""
import glob
import logging
import os
from types import SimpleNamespace
from typing import Any, Callable, Dict, List, Optional, Tuple, TypedDict
import numpy as np
import torch
from torch.distributed import ProcessGroup
from megatron.core import parallel_state
from megatron.core.datasets.retro.config import RetroPreprocessingConfig
from megatron.core.datasets.retro.query.multi_split_gpt_dataset import (
MultiSplitGPTDataset,
MultiSplitGPTDatasetConfig,
)
from megatron.core.utils import log_single_rank
logger = logging.getLogger(__name__)
try:
from tqdm import tqdm
HAVE_TQDM = True
except ImportError:
HAVE_TQDM = False
try:
import h5py
HAVE_H5PY = True
except ImportError:
HAVE_H5PY = False
class Block(TypedDict):
"""Specific block arg type to mute mypy."""
range: Tuple[int, int]
path: str
def log_retro_rank_0(message: str) -> None:
"""Log on rank 0.
Args:
message (str): Message to log.
"""
log_single_rank(logger, logging.INFO, "[RETRO] " + message)
def retro_makedir(config: RetroPreprocessingConfig, path: str) -> None:
"""Make a directory, conditional on not being in validation mode.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
path (str): Path to directory.
"""
if config.retro_task_validate is None:
os.makedirs(path, exist_ok=True)
def extract_data_config(config: RetroPreprocessingConfig) -> MultiSplitGPTDatasetConfig:
"""Extract data config from dataset.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
Returns:
The config object used to build the dataset.
"""
return config.retro_gpt_chunk_datasets.train["dataset"].sample_dataset.config
def get_num_chunks_per_sample(sample_length: int, chunk_length: int) -> int:
"""Compute seq_length // chunk_length.
Args:
sample_length (int): Alias of `sequence_length`.
chunk_length (int): Retro chunk length (e.g., 64).
Returns:
Number of chunks per sample (i.e., `sequence_length` / `chunk_length`).
"""
assert sample_length % chunk_length == 0
return sample_length // chunk_length
class GPTToTextDataset(torch.utils.data.Dataset):
"""Dataset to convert GPT tokens to text.
Args:
gpt_dataset (MultiSplitGPTDataset): GPT dataset, which outputs GPT token samples.
gpt_tokenizer (Any): GPT tokenizer.
"""
def __init__(self, gpt_dataset: MultiSplitGPTDataset, gpt_tokenizer: Any):
super().__init__()
self.gpt_dataset = gpt_dataset
self.gpt_tokenizer = gpt_tokenizer
def __len__(self) -> int:
"""Dataset length.
Returns:
Number of samples in the dataset.
"""
return len(self.gpt_dataset)
def __getitem__(self, idx: int) -> dict:
"""Get dataset sample.
Args:
idx (int): Index of sample.
Returns:
A dict containing attribute 'text' of type string.
"""
gpt_token_ids = self.gpt_dataset[idx]["text"].tolist()
text = self.gpt_tokenizer.detokenize(gpt_token_ids)
return {"text": text}
def get_blocks(
dirname: str, n_samples: int, block_size: int, validate: Optional[Callable] = None
) -> SimpleNamespace:
"""Divide range [0, num_samples) to sequence of block ranges.
This is a core method within the concept of block processing. The idea
is to divide a range (size n_samples) into a sequence of blocks. Each
block corresponds to a file within 'dirname' with name
'{start_idx}-{end_idx}.hdf5'. This method checks for the existence of
these files, and returns two lists, one for existing blocks and one for
missing blocks.
Args:
dirname (str): Path to directory containing block files.
n_samples (int): Ideal number of samples.
The total number of saved block data is <=n_samples.
block_size (int): Max number of samples per block file (e.g., 100000).
validate (Callable): Method for validating each block file during load.
Returns:
A namespace consisting of 2 lists: existing blocks, and missing blocks.
The total number of samples between the existing and missing blocks should
equal n_samples above.
"""
if not HAVE_TQDM:
raise ImportError("tqdm is required to use the RetroDataset. Please install tqdm.")
if not HAVE_H5PY:
raise ImportError("h5py is required to use the RetroDataset. Please install h5py.")
assert os.path.isdir(dirname), "missing directory '%s.'" % dirname
# Block ranges.
block_start_idxs = list(range(0, n_samples, block_size))
block_end_idxs = [min(n_samples, i + block_size) for i in block_start_idxs]
block_ranges = list(zip(block_start_idxs, block_end_idxs))
# All block files (existing + missing).
n_digits = int(np.ceil(np.log(n_samples) / np.log(10)) + 1)
all_blocks: List[Block] = [
{
"range": r,
"path": os.path.join(
dirname, "%s-%s.hdf5" % tuple([str(i).zfill(n_digits) for i in r])
),
}
for r in block_ranges
]
all_block_path_set = set(block["path"] for block in all_blocks)
# Validate function.
validate = (lambda f: None) if validate is None else validate
# Delete corrupt files.
if torch.distributed.get_rank() == 0:
existing_block_paths = [
block["path"] for block in all_blocks if os.path.exists(block["path"])
]
for index, path in enumerate(tqdm(existing_block_paths, "validating block.")):
assert path in all_block_path_set, "unexpected filename, '%s'." % path
try:
f = h5py.File(path, "r")
except Exception:
os.remove(path)
continue
try:
validate(f)
except Exception:
os.remove(path)
finally:
f.close()
# Wait for files to be deleted.
torch.distributed.barrier()
# Collect blocks.
blocks = SimpleNamespace(
existing=[b for b in all_blocks if os.path.exists(b["path"])],
missing=[b for b in all_blocks if not os.path.exists(b["path"])],
)
return blocks
def get_blocks_by_rank(
dirname: str,
n_samples: int,
block_size: int,
validate: Optional[Callable] = None,
sample: Optional[float] = None,
process_group: Optional[ProcessGroup] = None,
) -> SimpleNamespace:
"""Divide existing and missing blocks evenly across all ranks.
See 'get_blocks()' above for description. The returned lists of existing and
missing blocks are split evenly across ranks via interleaving. This way,
each rank has a roughly equal number of blocks to process for a
downstream operation.
Args:
dirname (str): Path to directory containing block files.
n_samples (int): Ideal number of samples. The total number of saved block data
is <=n_samples.
block_size (int): Max number of samples per block file (e.g., 100000).
validate (Callable): Method for validating each block file during load.
sample (Optional[float]): If provided, sample a random subset of the blocks.
Used for validating preprocessing correctness.
process_group (Optional[ProcessGroup]): Process group for distributed operations.
If None, uses data parallel group.
Returns:
A namespace consisting of 2 lists: existing blocks, and missing blocks.
Each of these two lists is potentially a sub-sample of the total set of
existing and missing blocks, depending on whether sampling is used.
Additionally, the attributes n_existing_world and n_missing_world are the
total number of existing and missing blocks, independent of samples.
Therefore, (n_existing_world + n_missing_world) * block_size == n_samples.
"""
if process_group is None:
process_group = parallel_state.get_data_parallel_group()
# Get world blocks.
blocks = get_blocks(dirname, n_samples, block_size, validate)
# This rank's existing and missing files.
rank_existing_blocks = blocks.existing[
process_group.rank() : len(blocks.existing) : process_group.size()
]
rank_missing_blocks = blocks.missing[
process_group.rank() : len(blocks.missing) : process_group.size()
]
# Extend rank's existing and missing blocks (with None) such that all ranks
# have equal length lists. This allows for easier tracking of global progress.
def get_world_max(n: int) -> int:
"""Get max value across ranks.
Args:
n (int): Value on this rank.
Returns:
Max value across all ranks.
"""
n_tensor = torch.cuda.LongTensor([n])
torch.distributed.all_reduce(n_tensor, op=torch.distributed.ReduceOp.MAX)
return n_tensor.item()
max_n_existing = get_world_max(len(rank_existing_blocks))
max_n_missing = get_world_max(len(rank_missing_blocks))
rank_existing_blocks += [None] * (max_n_existing - len(rank_existing_blocks))
rank_missing_blocks += [None] * (max_n_missing - len(rank_missing_blocks))
# Collect blocks.
blocks = SimpleNamespace(
n_existing_world=len(blocks.existing),
n_missing_world=len(blocks.missing),
existing=rank_existing_blocks,
missing=rank_missing_blocks,
)
if sample is not None:
# Sample existing and missing blocks evenly across all ranks. The
# returned lists of blocks are randomly sampled (without replacement)
# to yield `sample * len(blocks)` number of blocks.
# Randomly sample blocks.
def sample_blocks(_blocks: List[Optional[Dict]]) -> List[Optional[Dict]]:
"""Sample a random subset of all blocks.
Args:
_blocks (List[Optional[Dict]]): List of all blocks.
Returns:
A random subset of the blocks.
"""
n_blocks_sample = int(np.ceil(sample * len(_blocks)))
sampled_blocks: List[Optional[Dict]] = [b for b in _blocks if b is not None]
np.random.seed(None)
np.random.shuffle(sampled_blocks)
sampled_blocks = sampled_blocks[:n_blocks_sample]
sampled_blocks += [None] * (n_blocks_sample - len(sampled_blocks))
return sampled_blocks
blocks.existing = sample_blocks(blocks.existing)
blocks.missing = sample_blocks(blocks.missing)
return blocks
class BlockPathMap:
"""Map an index to its containing block path.
The common use for this class is to have a directory of files containing
blocks of processed data, of uniform block size (e.g., 100k samples per
file). Each file must follow a naming convention of 'startIdx-endIdx.[ext]',
where 'endIdx' minus 'startIdx' must equal the block size, with the possible
exception of the final block. Given an input index, this class maps the
index to the containing block file.
Args:
block_paths (List[str]): List of paths to saved block files.
block_size (int): Max number of samples per block file (e.g., 100000).
"""
@classmethod
def from_dir(cls, dir: str, block_size: int, ext: str = "hdf5") -> Any:
"""Get list of block files, and create map.
Args:
dir (str): Path to directory containing saved block files.
block_size (int): Max number of samples per block file (e.g., 100000).
ext (str): Block file extension (e.g., 'hdf5').
Returns:
A mapping of sample index to block file path.
"""
assert os.path.isdir(dir), f"directory not found, '{dir}'."
return cls(sorted(glob.glob(dir + f"/*.{ext}")), block_size)
def __init__(self, block_paths: List[str], block_size: int):
self.max_idx = 0
self.block_path_map = {}
for block_path in block_paths:
name = os.path.splitext(os.path.basename(block_path))[0]
start_idx, end_idx = [int(i) for i in name.split("-")]
self.block_path_map[start_idx] = block_path
self.max_idx = max(self.max_idx, end_idx)
self.block_size = block_size
def __str__(self) -> str:
"""Stringify the mapping.
Returns:
A string representation of this block path map.
"""
return "%d paths" % len(self.block_path_map)
def __getitem__(self, idx: int) -> str:
"""Get block path from index.
Args:
idx (int): Index of sample.
Returns:
The path to the block file containing the sample index.
"""
block_start_idx = self.block_size * (idx // self.block_size)
block_path = self.block_path_map[block_start_idx]
return block_path
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import os
from collections import deque
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union
import numpy
import torch
from packaging.version import Version as PkgVersion
from megatron.core.datasets.indexed_dataset import IndexedDataset
from megatron.core.datasets.masked_dataset import (
MaskedWordPieceDataset,
MaskedWordPieceDatasetConfig,
)
from megatron.core.datasets.utils import Split
from megatron.core.utils import get_te_version
@dataclass
class T5MaskedWordPieceDatasetConfig(MaskedWordPieceDatasetConfig):
"""Configuration object for Megatron Core T5 WordPiece datasets
NB: As a temporary holdover from Megatron-LM. The T5 tokenizer has an attribute which defines
a number of special sentinel tokens used during sampling. The assert in __post_init__ serves to
preserve compatibility with Megatron-LM until the T5 tokenizer is in Megatron Core.
"""
sequence_length_encoder: Optional[int] = field(init=False, default=None)
"""A sequence_length alias and the sequence length for the encoder"""
sequence_length_decoder: int = None
"""The sequence length for the decoder"""
def __post_init__(self) -> None:
"""Do asserts and set fields post init"""
super().__post_init__()
self.sequence_length_encoder = self.sequence_length
assert self.sequence_length_encoder is not None
assert self.sequence_length_decoder is not None
assert len(self.tokenizer.additional_special_tokens_ids) > 0
class T5MaskedWordPieceDataset(MaskedWordPieceDataset):
"""The T5 dataset that assumes WordPiece tokenization
Args:
indexed_dataset (IndexedDataset): The IndexedDataset around
which to build the MegatronDataset
dataset_path (str): The real path on disk to the dataset, for bookkeeping
indexed_indices (numpy.ndarray): The set of the documents indices to expose
num_samples (Optional[int]): The number of samples to draw from the indexed
dataset. When None, build as many samples as correspond to one epoch.
index_split (Split): The indexed_indices Split
config (T5MaskedWordPieceDatasetConfig): The config
"""
def __init__(
self,
indexed_dataset: IndexedDataset,
dataset_path: str,
indexed_indices: numpy.ndarray,
num_samples: Optional[int],
index_split: Split,
config: T5MaskedWordPieceDatasetConfig,
) -> None:
super().__init__(
indexed_dataset, dataset_path, indexed_indices, num_samples, index_split, config
)
self.token_lookup = list(self.config.tokenizer.inv_vocab.keys())
# Account for the single <bos> and single <eos> token ids
self.sample_index = self._build_sample_index(self.config.sequence_length - 2, 1)
@staticmethod
def _key_config_attributes() -> List[str]:
"""Inherited method implementation
Returns:
List[str]: The key config attributes
"""
return super(
T5MaskedWordPieceDataset, T5MaskedWordPieceDataset
)._key_config_attributes() + ["sequence_length_decoder"]
@staticmethod
def _build_b1ss_attention_mask(
source_block: torch.tensor, target_block: torch.tensor, make_history_mask: bool = False
) -> torch.tensor:
"""Build an attention-mask having shape (bs, 1, q_len, kv_len)
from source_block and target_block
Args:
source_block (torch.tensor): A 2-D array of tokens (bs, q_len)
target_block (torch.tensor): A 2-D array of tokens (bs, kv_len)
make_history_mask (bool): Whether to turn mask into causal mask
Returns:
torch.tensor: The 4-D attention mask (bs, 1, q_len, kv_len)
"""
batch_size = source_block.shape[0]
attention_mask = []
for i in range(batch_size):
source_sample = source_block[i]
target_sample = target_block[i]
mask = (target_sample[None, :] >= 1) * (source_sample[:, None] >= 1)
if make_history_mask:
arange = numpy.arange(source_sample.shape[0])
history_mask = arange[None,] <= arange[:, None]
history_mask = torch.tensor(history_mask).to(mask.device)
mask = mask * history_mask
mask = ~(mask) # flip True to False
attention_mask.append(mask)
attention_mask = torch.stack(attention_mask)
attention_mask = attention_mask.unsqueeze(1)
return attention_mask
@staticmethod
def config_attention_mask(
encoder_tokens: torch.tensor,
decoder_tokens: torch.tensor,
encoder_mask: torch.tensor,
decoder_mask: torch.tensor,
use_local: bool = False,
test_te_version: str = None,
) -> torch.tensor:
"""Config attention-mask for encoder_mask, decoder_mask, encoder_decoder_mask
conditioned on transformer-implementation (e.g. TE vs local), TE versions,
and TE backends
Args:
encoder_tokens (torch.tensor): A 2-D array of tokens (bs, kv_len)
decoder_tokens (torch.tensor): A 2-D array of tokens (bs, q_len)
encoder_mask (torch.tensor): A 2-D array of tokens (bs, kv_len)
decoder_mask (torch.tensor): A 2-D array of tokens (bs, q_len)
use_local (bool): Whether the current T5 model uses local (vs TE)
transformer implmentation
Returns:
Configured encoder_mask, decoder_mask, encoder_decoder_mask
torch.tensor: configured encoder attention mask
torch.tensor: configured decoder attention mask
torch.tensor: configured encoder-decoder attention mask
"""
# If using local transformer implementation (not transformer_engine):
# re-organize all attention masks, because local and transformer_engine
# backbones use different masks shapes. E.g.:
# (local: b1ss - transformer_engine: b11s)
if use_local:
encoder_mask = T5MaskedWordPieceDataset._build_b1ss_attention_mask(
encoder_tokens, encoder_tokens
)
decoder_mask = T5MaskedWordPieceDataset._build_b1ss_attention_mask(
decoder_tokens, decoder_tokens, make_history_mask=True
)
encoder_decoder_mask = T5MaskedWordPieceDataset._build_b1ss_attention_mask(
decoder_tokens, encoder_tokens
)
else:
# If using transformer_engine transformer implementation:
# 1. For TE version >= 1.10, across all 3 backends,
# The padding mask is configued as
# [bs, 1, 1, seq_len] for self-attention and
# ([bs, 1, 1, q_len], [bs, 1, 1, kv_len]) for cross-attention
# 2. For TE version >=1.7 and <1.10, when using Non-fused backend,
# The padding mask is configued as
# [bs, 1, q_len, kv_len] for both self-attention and for cross-attention
# 3. For TE version <1.7, only support Non-fused backend
# The padding mask is configued as
# [bs, 1, q_len, kv_len] for both self-attention and for cross-attention
# Process for Flash/Fused
encoder_mask = encoder_mask.unsqueeze(1).unsqueeze(1)
decoder_mask = decoder_mask.unsqueeze(1).unsqueeze(1)
encoder_decoder_mask = (decoder_mask, encoder_mask)
# set decoder_mask to None because decoder uses AttnMaskType.causal
decoder_mask = None
# get TE version, using test TE version if not None
if test_te_version is not None:
te_version = PkgVersion(test_te_version)
else:
te_version = get_te_version()
# Check for older TE version than 1.10, adjust attention mask accordingly
flash_attention_enabled = os.getenv("NVTE_FLASH_ATTN") == "1"
fused_attention_enabled = os.getenv("NVTE_FUSED_ATTN") == "1"
if (te_version < PkgVersion("1.10.0")) and (te_version >= PkgVersion("1.7.0")):
if not (flash_attention_enabled) and not (fused_attention_enabled):
encoder_mask = T5MaskedWordPieceDataset._build_b1ss_attention_mask(
encoder_tokens, encoder_tokens
)
encoder_decoder_mask = T5MaskedWordPieceDataset._build_b1ss_attention_mask(
decoder_tokens, encoder_tokens
)
else:
pass
elif te_version < PkgVersion("1.7.0"):
if not (flash_attention_enabled) and not (fused_attention_enabled):
encoder_mask = T5MaskedWordPieceDataset._build_b1ss_attention_mask(
encoder_tokens, encoder_tokens
)
encoder_decoder_mask = T5MaskedWordPieceDataset._build_b1ss_attention_mask(
decoder_tokens, encoder_tokens
)
else:
assert not flash_attention_enabled and not fused_attention_enabled, (
"Flash and fused attention is not supported with transformer "
"engine version < 1.7. Set NVTE_FLASH_ATTN=0 and NVTE_FUSED_ATTN=0"
"or upgrade transformer engine >= 1.7"
)
return encoder_mask, decoder_mask, encoder_decoder_mask
def __getitem__(self, idx: int) -> Dict[str, Union[int, numpy.ndarray]]:
"""Abstract method implementation
Args:
idx (int): The index into the dataset
Returns:
Dict[str, Union[int, numpy.ndarray]]: The
"""
idx_beg, idx_end, target_sequence_length = self.sample_index[idx]
sample = [self.dataset[i] for i in range(idx_beg, idx_end)]
numpy_random_state = numpy.random.RandomState(seed=(self.config.random_seed + idx) % 2**32)
assert target_sequence_length <= self.config.sequence_length
# Flatten the sample into a list of tokens
tokens = [token for sentence in sample for token in sentence]
# Truncate the list of tokens to a desired length
truncated = len(tokens) > target_sequence_length
tokens = tokens[:target_sequence_length]
# Masking
(tokens, _, _, _, masked_spans) = self._create_masked_lm_predictions(
tokens, target_sequence_length, numpy_random_state
)
# Prepare the encoder input and decoder input and output
sentinels = deque(self.config.tokenizer.additional_special_tokens_ids)
encoder_input = []
decoder_input = [self.config.tokenizer.bos]
decoder_output = []
idx_beg = 0
for indices, labels in masked_spans:
sentinel = sentinels.popleft()
# set the end index
idx_end = indices[0]
encoder_input.extend(tokens[idx_beg:idx_end])
encoder_input.append(sentinel)
decoder_input.append(sentinel)
decoder_input.extend(labels)
decoder_output.append(sentinel)
decoder_output.extend(labels)
# set the start index
idx_beg = indices[-1] + 1
encoder_input.extend(tokens[idx_beg:])
decoder_output.append(self.config.tokenizer.eos)
# Pad the sequences and convert to NumPy
length_toks_encoder = len(encoder_input)
length_toks_decoder = len(decoder_input)
length_pads_encoder = self.config.sequence_length_encoder - length_toks_encoder
length_pads_decoder = self.config.sequence_length_decoder - length_toks_decoder
assert length_pads_encoder >= 0
assert length_pads_decoder >= 0
encoder_input = numpy.array(encoder_input, dtype=numpy.int64)
encoder_input = numpy.pad(
encoder_input, (0, length_pads_encoder), constant_values=self.config.tokenizer.pad
)
decoder_input = numpy.array(decoder_input, dtype=numpy.int64)
decoder_input = numpy.pad(
decoder_input, (0, length_pads_decoder), constant_values=self.config.tokenizer.pad
)
# Create attention and history masks
mask_encoder = numpy.array([1] * length_toks_encoder + [0] * length_pads_encoder)
mask_decoder = numpy.array([1] * length_toks_decoder + [0] * length_pads_decoder)
mask_encoder_decoder = None
# Mask the labels
decoder_output = numpy.array(decoder_output, dtype=numpy.int64)
decoder_output = numpy.pad(decoder_output, (0, length_pads_decoder), constant_values=-1)
# Get the loss mask
loss_mask = numpy.zeros(self.config.sequence_length_decoder, dtype=numpy.int64)
loss_mask[:length_toks_decoder] = 1
return {
"text_enc": encoder_input,
"text_dec": decoder_input,
"labels": decoder_output,
"loss_mask": loss_mask,
"truncated": int(truncated),
"enc_mask": mask_encoder,
"dec_mask": mask_decoder,
}
def _get_token_mask(self, numpy_random_state: numpy.random.RandomState) -> int:
"""Abstract method implementation
100% of the time, replace the token id with mask token id.
Args:
numpy_random_state (RandomState): The NumPy random state
Returns:
int: The mask token id
"""
return self.config.tokenizer.mask
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import logging
from enum import Enum
from typing import List, Optional, Tuple
import numpy
from ..utils import log_single_rank
logger = logging.getLogger(__name__)
class Split(Enum):
train = 0
valid = 1
test = 2
def compile_helpers():
"""Compile C++ helper functions at runtime. Make sure this is invoked on a single process."""
import os
import subprocess
command = ["make", "-C", os.path.abspath(os.path.dirname(__file__))]
if subprocess.run(command).returncode != 0:
import sys
log_single_rank(logger, logging.ERROR, "Failed to compile the C++ dataset helper functions")
sys.exit(1)
def normalize(weights: List[float]) -> List[float]:
"""Do non-exponentiated normalization
Args:
weights (List[float]): The weights
Returns:
List[float]: The normalized weights
"""
w = numpy.array(weights, dtype=numpy.float64)
w_sum = numpy.sum(w)
w = (w / w_sum).tolist()
return w
def get_blend_from_list(
blend: Optional[List[str]],
) -> Optional[Tuple[List[str], Optional[List[float]]]]:
# pylint: disable=line-too-long
"""Get the blended_megatron_dataset_config.BlendedMegatronDatasetConfig blend
from the blend list
Args:
blend (Optional[List[str]]): The blend list, which can be either
(1) a list of prefixes, e.g. ["path/to/dataset_1_prefix", "path/to/dataset_2_prefix"], or
(2) a flattened, zipped list of weights and prefixes, e.g. ["30", "path/to/dataset_1_prefix", "70", "path/to/dataset_2_prefix"]
Returns:
Optional[Tuple[List[str], Optional[List[float]]]]: The blend, consisting of a list of dataset prefixes and optionally a list of dataset weights, e.g. [["path/to/dataset_1_prefix", "path/to/dataset_2_prefix"], [30.0, 70.0]].
"""
# pylint: enable=line-too-long
if blend is None:
return None
if len(blend) % 2 == 1:
weight_per_dataset = None
raw_prefix_per_dataset = blend
else:
raw_weight_per_dataset, raw_prefix_per_dataset = zip(
*[(blend[i], blend[i + 1]) for i in range(0, len(blend), 2)]
)
weight_per_dataset = []
for rwpd in raw_weight_per_dataset:
try:
weight = float(rwpd)
except ValueError:
weight = None
weight_per_dataset.append(weight)
is_none = map(lambda _: _ is None, weight_per_dataset)
if any(is_none):
assert all(is_none)
weight_per_dataset = None
raw_prefix_per_dataset = blend
prefix_per_dataset = [rppd.strip() for rppd in raw_prefix_per_dataset]
return prefix_per_dataset, weight_per_dataset
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import os
from dataclasses import dataclass
from typing import Any, Dict, Protocol, Tuple
import torch
try:
import boto3
import botocore.exceptions as exceptions
except ModuleNotFoundError:
pass
from megatron.core.msc_utils import MultiStorageClientFeature
S3_PREFIX = "s3://"
MSC_PREFIX = "msc://"
@dataclass
class ObjectStorageConfig:
"""Config when the data (.bin) file and the index (.idx) file are in object storage
Attributes:
path_to_idx_cache (str): The local directory where we will store the index (.idx) file
bin_chunk_nbytes (int): If the number of bytes is too small, then we send a request to S3
at each call of the `read` method in _S3BinReader, which is slow, because each request
has a fixed cost independent of the size of the byte range requested. If the number of
bytes is too large, then we only rarely have to send requests to S3, but it takes a lot
of time to complete the request when we do, which can block training. We've found that
256 * 1024 * 1024 (i.e., 256 MiB) has worked well (though we have not put that much
effort into tuning it), so we default to it.
"""
path_to_idx_cache: str
bin_chunk_nbytes: int = 256 * 1024 * 1024
class S3Client(Protocol):
"""The protocol which all s3 clients should abide by"""
def download_file(self, Bucket: str, Key: str, Filename: str) -> None:
"""Download the file from S3 to the local file system"""
...
def upload_file(self, Filename: str, Bucket: str, Key: str) -> None:
"""Upload the file to S3"""
...
def head_object(self, Bucket: str, Key: str) -> Dict[str, Any]:
"""Get the metadata of the file in S3"""
...
def get_object(self, Bucket: str, Key: str, Range: str) -> Dict[str, Any]:
"""Get the file from S3"""
...
def close(self) -> None:
"""Close the S3 client"""
...
def _remove_s3_prefix(path: str) -> str:
"""Remove the S3 prefix from a path
Args:
path (str): The path
Returns:
str: The path without the S3 prefix
"""
return path.removeprefix(S3_PREFIX)
def _is_s3_path(path: str) -> bool:
"""Ascertain whether a path is in S3
Args:
path (str): The path
Returns:
bool: True if the path is in S3, False otherwise
"""
return path.startswith(S3_PREFIX)
def _remove_msc_prefix(path: str) -> str:
"""
Remove the MSC prefix from a path
Args:
path (str): The path
Returns:
str: The path without the MSC prefix
"""
return path.removeprefix(MSC_PREFIX)
def _is_msc_path(path: str) -> bool:
"""Checks whether a path is in MSC path (msc://profile/path/to/file)
Args:
path (str): The path
Returns:
bool: True if the path is in MSC path, False otherwise
"""
return path.startswith(MSC_PREFIX)
def _s3_download_file(client: S3Client, s3_path: str, local_path: str) -> None:
"""Download the object at the given S3 path to the given local file system path
Args:
client (S3Client): The S3 client
s3_path (str): The S3 source path
local_path (str): The local destination path
"""
dirname = os.path.dirname(local_path)
os.makedirs(dirname, exist_ok=True)
parsed_s3_path = parse_s3_path(s3_path)
client.download_file(parsed_s3_path[0], parsed_s3_path[1], local_path)
def _s3_object_exists(client: S3Client, path: str) -> bool:
"""Ascertain whether the object at the given S3 path exists in S3
Args:
client (S3Client): The S3 client
path (str): The S3 path
Raises:
botocore.exceptions.ClientError: The error code is 404
Returns:
bool: True if the object exists in S3, False otherwise
"""
parsed_s3_path = parse_s3_path(path)
try:
_ = client.head_object(bucket=parsed_s3_path[0], key=parsed_s3_path[1])
except exceptions.ClientError as e:
if e.response["Error"]["Code"] != "404":
raise e
return True
def is_object_storage_path(path: str) -> bool:
"""Ascertain whether a path is in object storage
Args:
path (str): The path
Returns:
bool: True if the path is in object storage (s3:// or msc://), False otherwise
"""
return _is_s3_path(path) or _is_msc_path(path)
def get_index_cache_path(idx_path: str, object_storage_config: ObjectStorageConfig) -> str:
"""Get the index cache path for the given path
Args:
idx_path (str): The path to the index file
object_storage_config (ObjectStorageConfig): The object storage config
Returns:
str: The index cache path
"""
if _is_s3_path(idx_path):
cache_idx_path = os.path.join(
object_storage_config.path_to_idx_cache, _remove_s3_prefix(idx_path)
)
elif _is_msc_path(idx_path):
cache_idx_path = os.path.join(
object_storage_config.path_to_idx_cache, _remove_msc_prefix(idx_path)
)
else:
raise ValueError(f"Invalid path: {idx_path}")
return cache_idx_path
def parse_s3_path(path: str) -> Tuple[str, str]:
"""Parses the given S3 path returning correspsonding bucket and key.
Args:
path (str): The S3 path
Returns:
Tuple[str, str]: A (bucket, key) tuple
"""
assert _is_s3_path(path)
parts = path.replace(S3_PREFIX, "").split("/")
bucket = parts[0]
if len(parts) > 1:
key = "/".join(parts[1:])
assert S3_PREFIX + bucket + "/" + key == path
else:
key = ""
return bucket, key
def get_object_storage_access(path: str) -> str:
"""Get the object storage access"""
return "s3" if _is_s3_path(path) else "msc"
def dataset_exists(path_prefix: str, idx_path: str, bin_path: str) -> bool:
"""Check if the dataset exists on object storage
Args:
path_prefix (str): The prefix to the index (.idx) and data (.bin) files
idx_path (str): The path to the index file
bin_path (str): The path to the data file
Returns:
bool: True if the dataset exists on object storage, False otherwise
"""
if _is_s3_path(path_prefix):
s3_client = boto3.client("s3")
return _s3_object_exists(s3_client, idx_path) and _s3_object_exists(s3_client, bin_path)
elif _is_msc_path(path_prefix):
msc = MultiStorageClientFeature.import_package()
return msc.exists(idx_path) and msc.exists(bin_path)
else:
raise ValueError(f"Invalid path: {path_prefix}")
def cache_index_file(remote_path: str, local_path: str) -> None:
"""Download a file from object storage to a local path with distributed training support.
The download only happens on Rank 0, and other ranks will wait for the file to be available.
Note that this function does not include any barrier synchronization. The caller (typically
in blended_megatron_dataset_builder.py) is responsible for ensuring proper synchronization
between ranks using torch.distributed.barrier() after this function returns.
Args:
remote_path (str): The URL of the file to download (e.g., s3://bucket/path/file.idx
or msc://profile/path/file.idx)
local_path (str): The local destination path where the file should be saved
Raises:
ValueError: If the remote_path is not a valid S3 or MSC path
"""
torch_dist_enabled = torch.distributed.is_initialized()
if torch_dist_enabled:
rank = torch.distributed.get_rank()
else:
rank = 0
if _is_s3_path(remote_path):
s3_client = boto3.client("s3")
if not torch_dist_enabled or rank == 0:
_s3_download_file(s3_client, remote_path, local_path)
assert os.path.exists(local_path)
elif _is_msc_path(remote_path):
msc = MultiStorageClientFeature.import_package()
if not torch_dist_enabled or rank == 0:
msc.download_file(remote_path, local_path)
assert os.path.exists(local_path)
else:
raise ValueError(f"Invalid path: {remote_path}")
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from megatron.core.datasets.object_storage_utils import ( # pylint: disable=unused-import
S3_PREFIX,
S3Client,
)
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
from .core import check_is_distributed_checkpoint
from .mapping import LocalNonpersistentObject, ShardedObject, ShardedTensor
from .serialization import (
load,
load_common_state_dict,
load_content_metadata,
load_plain_tensors,
load_tensors_metadata,
remove_sharded_tensors,
save,
)
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
""" Module for managing distributed checkpoints metadata. """
import json
import os
from dataclasses import asdict, dataclass
from typing import Optional
from megatron.core.msc_utils import MultiStorageClientFeature
CONFIG_FNAME = 'metadata.json'
class CheckpointingException(Exception):
"""Base checkpointing related exception"""
pass
@dataclass
class CheckpointingConfig:
"""Documents backends used in the checkpoint.
Checkpoint config keeps track of formats used for storing the sharded tensors
(sharded_backend) and other objects (common_backend).
Note that versioning is not for the checkpoint content (which is application specific),
but for the checkpoint format itself.
"""
sharded_backend: str
sharded_backend_version: int = 1
common_backend: str = 'torch'
common_backend_version: int = 1
def check_is_distributed_checkpoint(checkpoint_dir):
"""Checks if `metadata.json` exists in the checkpoint and is a valid config.
Args:
checkpoint_dir: checkpoint directory
Returns:
bool: True if `metadata.json` exists in the checkpoint and is a valid config.
"""
return maybe_load_config(checkpoint_dir) is not None
def maybe_load_config(checkpoint_dir: str) -> Optional[CheckpointingConfig]:
"""Returns checkpoint config if `checkpoint_dir` is a distributed checkpoint and None otherwise
Args:
checkpoint_dir: checkpoint directory
Returns:
CheckpointingConfig (optional): None if checkpoint is not a valid distributed checkpoint
"""
config_path = os.path.join(checkpoint_dir, CONFIG_FNAME)
if checkpoint_dir:
if MultiStorageClientFeature.is_enabled():
msc = MultiStorageClientFeature.import_package()
if not msc.os.path.exists(config_path):
return None
with msc.open(config_path) as f:
config_dict = json.load(f)
else:
if not os.path.exists(config_path):
return None
with open(config_path) as f:
config_dict = json.load(f)
return CheckpointingConfig(**config_dict)
return None
def save_config(config: CheckpointingConfig, checkpoint_dir: str):
"""Save given config to checkpoint directory.
Args:
config: checkpoint config
checkpoint_dir: checkpoint directory
Returns:
None
"""
config_path = os.path.join(checkpoint_dir, CONFIG_FNAME)
if MultiStorageClientFeature.is_enabled():
msc = MultiStorageClientFeature.import_package()
with msc.open(config_path, 'w') as f:
json.dump(asdict(config), f)
else:
with open(config_path, 'w') as f:
json.dump(asdict(config), f)
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
"""Utilities for operating with dicts and lists.
All functions in this module handle nesting of dicts and lists.
Other objects (e.g. tuples) are treated as atomic leaf types that cannot be traversed.
"""
from collections import defaultdict
from typing import Any, Callable, Dict, Iterable, List, Tuple, TypeVar, Union
import numpy as np
import torch
U, V = TypeVar("U"), TypeVar("V")
def extract_matching_values(
x: Union[dict, list], predicate: Callable[[Any], bool], return_lists_as_dicts: bool = False
) -> Tuple[Union[dict, list], Union[dict, list]]:
"""Return matching and nonmatching values. Keeps hierarchy.
Args:
x (Union[dict, list]) : state dict to process. Top-level argument must be a dict or list
predicate (object -> bool): determines matching values
return_lists_as_dicts (bool): if True, matching lists will be turned
into dicts, with keys indicating the indices of original elements.
Useful for reconstructing the original hierarchy.
"""
def _set_elem(target, k, v):
if return_lists_as_dicts:
target[k] = v
else:
target.append(v)
if isinstance(x, dict):
matching_vals = {}
nonmatching_vals = {}
for k, v in x.items():
if isinstance(v, (list, dict)):
match, nonmatch = extract_matching_values(v, predicate, return_lists_as_dicts)
if match:
matching_vals[k] = match
if nonmatch or not v:
nonmatching_vals[k] = nonmatch
elif predicate(v):
matching_vals[k] = v
else:
nonmatching_vals[k] = v
elif isinstance(x, list): # type: ignore
matching_vals = {} if return_lists_as_dicts else []
nonmatching_vals = {} if return_lists_as_dicts else []
for ind, v in enumerate(x):
if isinstance(v, (list, dict)) and v:
match, nonmatch = extract_matching_values(v, predicate, return_lists_as_dicts)
if match:
_set_elem(matching_vals, ind, match)
if nonmatch or not v:
_set_elem(nonmatching_vals, ind, nonmatch)
else:
target = matching_vals if predicate(v) else nonmatching_vals
_set_elem(target, ind, v)
else:
raise ValueError(f"Unexpected top-level object type: {type(x)}")
return matching_vals, nonmatching_vals
def diff(x1: Any, x2: Any, prefix: Tuple = ()) -> Tuple[list, list, list]:
"""Recursive diff of dicts.
Args:
x1 (object): left dict
x2 (object): right dict
prefix (tuple): tracks recursive calls. Used for reporting differing keys.
Returns:
Tuple[list, list, list]: tuple of:
- only_left: Prefixes present only in left dict
- only_right: Prefixes present only in right dict
- mismatch: values present in both dicts but not equal across dicts.
For tensors equality of all elems is checked.
Each element is a tuple (prefix, type of left value, type of right value).
"""
mismatch = []
if isinstance(x1, dict) and isinstance(x2, dict):
only_left = [prefix + (k,) for k in x1.keys() - x2.keys()]
only_right = [prefix + (k,) for k in x2.keys() - x1.keys()]
for k in x2.keys() & x1.keys():
_left, _right, _mismatch = diff(x1[k], x2[k], prefix + (k,))
only_left.extend(_left)
only_right.extend(_right)
mismatch.extend(_mismatch)
elif isinstance(x1, list) or isinstance(x1, tuple) or isinstance(x1, np.ndarray):
assert type(x1) == type(x2)
only_left = list(range(len(x1) - 1, len(x2) - 1, -1))
only_right = list(range(len(x2) - 1, len(x1) - 1, -1))
for i, (v1, v2) in enumerate(zip(x1, x2)):
_left, _right, _mismatch = diff(v1, v2, prefix + (i,))
only_left.extend(_left)
only_right.extend(_right)
mismatch.extend(_mismatch)
else:
only_left = []
only_right = []
if isinstance(x1, torch.Tensor) and isinstance(x2, torch.Tensor):
if x1.device != x2.device:
_is_mismatch = not torch.all(x1.cpu() == x2.cpu())
else:
_is_mismatch = not torch.all(x1 == x2)
# TODO: change with concrete type that has both replica_id and data attrs
elif hasattr(x1, "replica_id") and hasattr(x2, "replica_id"):
assert type(x1) == type(x2)
only_left, only_right, mismatch = diff(
x1.data, x2.data, prefix + (type(x1),)
) # type: ignore
_is_mismatch = False
else:
try:
_is_mismatch = bool(x1 != x2)
except RuntimeError:
_is_mismatch = True
if _is_mismatch:
mismatch.append((prefix, type(x1), type(x2)))
return only_left, only_right, mismatch
def inspect_types(x: Any, prefix: Tuple = (), indent: int = 4):
"""Helper to print types of (nested) dict values."""
print_indent = lambda: print(" " * indent * len(prefix), end="")
if isinstance(x, dict):
print()
for k, v in x.items():
print_indent()
print(f"> {k}: ", end="")
inspect_types(v, prefix + (k,), indent)
elif isinstance(x, list):
print()
for i, v in enumerate(x):
print_indent()
print(f"- {i}: ", end="")
inspect_types(v, prefix + (i,), indent)
else:
if isinstance(x, torch.Tensor):
print(f"Tensor of shape {x.shape}")
else:
try:
x_str = str(x)
except:
x_str = "<no string repr>"
if len(x_str) > 30:
x_str = x_str[:30] + "... (truncated)"
print(f"[{type(x)}]: {x_str}")
def nested_values(x: Union[dict, list]):
"""Returns iterator over (nested) values of a given dict or list."""
x_iter = x.values() if isinstance(x, dict) else x
for v in x_iter:
if isinstance(v, (dict, list)):
yield from nested_values(v)
else:
yield v
def nested_items_iter(x: Union[dict, list]):
"""Returns iterator over (nested) tuples (container, key, value) of a given dict or list."""
x_iter = x.items() if isinstance(x, dict) else enumerate(x)
for k, v in x_iter:
if isinstance(v, (dict, list)):
yield from nested_items_iter(v)
else:
yield x, k, v
def dict_map(f: Callable, d: dict):
"""`map` equivalent for dicts."""
for sub_d, k, v in nested_items_iter(d):
sub_d[k] = f(v)
def dict_map_with_key(f: Callable, d: dict):
"""`map` equivalent for dicts with a function that accepts tuple (key, value)."""
for sub_d, k, v in nested_items_iter(d):
sub_d[k] = f(k, v)
def dict_list_map_inplace(f: Callable[[U], V], x: Union[Dict, List, U]):
"""Maps dicts and lists *in-place* with a given function."""
if isinstance(x, dict):
for k, v in x.items():
x[k] = dict_list_map_inplace(f, v)
elif isinstance(x, list):
x[:] = (dict_list_map_inplace(f, v) for v in x)
else:
return f(x)
return x
def dict_list_map_outplace(f: Callable[[U], V], x: Union[Dict, List, U]) -> Union[Dict, List, V]:
"""Maps dicts and lists *out-of-place* with a given function."""
if isinstance(x, dict):
return {k: dict_list_map_outplace(f, v) for k, v in x.items()}
elif isinstance(x, list):
return [dict_list_map_outplace(f, v) for v in x]
else:
return f(x)
def merge(x1: Union[dict, list], x2: Union[dict, list], key: Tuple[Union[str, int], ...] = ()):
"""Merges dicts and lists recursively."""
if isinstance(x1, dict) and isinstance(x2, dict):
for k, v2 in x2.items():
if k not in x1:
x1[k] = v2
else:
x1[k] = merge(x1[k], v2, key=key + (k,))
elif isinstance(x1, list) and isinstance(x2, list):
if len(x1) != len(x2):
raise ValueError(
f"Cannot merge two lists with different lengths ({len(x1)} and {len(x2)}, "
f"encountered at level {key})"
)
for i, v2 in enumerate(x2):
x1[i] = merge(x1[i], v2, key=key + (i,))
else:
raise ValueError(
f"Duplicate non-dict and non-list values encountered: `{x1}` and `{x2}` "
f"(at level {key})"
)
return x1
def map_reduce(
xs: Iterable,
key_fn: Callable = lambda x: x,
value_fn: Callable = lambda x: x,
reduce_fn: Callable = lambda x: x,
) -> dict:
"""Simple map-reduce implementation following `more_itertools.map_reduce` interface."""
res = defaultdict(list)
for x in xs:
res[key_fn(x)].append(value_fn(x))
for k in res:
res[k] = reduce_fn(res[k])
return dict(res)
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
"""Utilities for exchanging data between ranks."""
import logging
from collections import defaultdict
from functools import reduce
from itertools import zip_longest
from typing import Any, Dict, List, NamedTuple, Optional, Set, Tuple, TypeVar, cast
import numpy as np
import torch
from ..utils import get_pg_rank, get_pg_size
from .core import CheckpointingException
from .dict_utils import nested_values
from .mapping import ShardedStateDict, ShardedTensor, is_main_replica
from .utils import _sharded_tensor_shard_id, _ShardId, debug_time
# TODO: remove TE references once the TE bug is fixed
# Check if Transformer Engine has Float8Tensor class
try:
from transformer_engine.pytorch.float8_tensor import Float8Tensor
HAVE_TE_FLOAT8TENSOR = True
except (ImportError, ModuleNotFoundError):
# Float8Tensor not found
HAVE_TE_FLOAT8TENSOR = False
def is_float8tensor(tensor: torch.Tensor) -> bool:
"""Check if a tensor is a Transformer Engine Float8Tensor"""
return HAVE_TE_FLOAT8TENSOR and isinstance(tensor, Float8Tensor)
logger = logging.getLogger(__name__)
class ShardDistribution(NamedTuple):
"""Represents a distribution of ShardedTensors.
Given distribution is valid only for a specific parallelization group,
which is implicit here (not referenced by this class).
Args:
main_rank_for_shard (Dict[_ShardId, int]): specifies which rank should hold
the main replica for a given shard
shards_in_this_group (Set[_ShardId]): which shards have a main replica
in this parallelization group
shard_to_metadata (Dict[_ShardId, ShardedTensor]): maps ShardedTensor
identifier to the original ShardedTensor
all_ranks_for_shard (Dict[_ShardId, List[int]]): specifies which ranks
need a given shard in a given parallelization group
"""
main_rank_for_shard: Dict[_ShardId, int]
shards_in_this_group: Set[_ShardId]
shard_to_metadata: Dict[_ShardId, ShardedTensor]
all_ranks_for_shard: Dict[_ShardId, List[int]]
def _shard_size(sh_ten: ShardedTensor):
"""Returns size in bytes of a given sharded tensor."""
if sh_ten.flattened_range is None:
numel = np.product(sh_ten.local_shape)
else:
numel = sh_ten.flattened_range.stop - sh_ten.flattened_range.start
return numel * torch._utils._element_size(sh_ten.dtype)
def _get_empty_tensor_for_exchange(
shard_id: _ShardId,
needed_shards: Dict[_ShardId, ShardedTensor],
unneeded_shards: Dict[_ShardId, ShardedTensor],
loaded_tensors: Dict[_ShardId, torch.Tensor],
) -> Tuple[torch.Tensor, Optional[torch.device]]:
"""Determines the empty tensor to use for exchange.
If shard_id is needed by this rank, it will be in the `unloaded_shards`.
Otherwise, the metadata for this tensor can be found in `shard_to_metadata`
Args:
shard_id (_ShardId): shard_id that will be exchanged
needed_shards (Dict[_ShardId, ShardedTensor]): mapping from shard ids
to metadata for shards needed by this rank
unneeded_shards (Dict[_ShardId, ShardedTensor]): mapping from shard ids
to metadata for shards that can be discarded after exchange
loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping where useful tensors
are placed in
Returns:
Tuple[torch.Tensor, Optional[torch.device]]: empty CUDA tensor to be exchanged,
and the device of the original state dict tensor (if there was any)
"""
local_unloaded_sh_ten = needed_shards.get(shard_id)
if local_unloaded_sh_ten is None:
orig_device = None # this tensor will be discarded anyway
sh_ten = unneeded_shards[shard_id]
if sh_ten.data is None:
sh_ten.init_data("cuda")
tensor = sh_ten.data
sh_ten.data = None # won't be used. free memory
else:
tensor = sh_ten.data
if tensor.device.type == "cpu":
tensor = torch.empty_like(tensor, device="cuda")
else:
local_unloaded_sh_ten.init_data("cuda")
orig_device = local_unloaded_sh_ten.data.device
tensor = local_unloaded_sh_ten.data
if tensor.device.type == "cpu":
tensor = torch.empty_like(tensor, device="cuda")
loaded_tensors[shard_id] = tensor
return tensor, orig_device
T = TypeVar("T")
def distribute_shards_to_ranks(
shard_to_ranks: Dict[T, List[int]],
shard_to_size: Dict[T, int],
num_ranks: int,
cross_parallelization_group_loads: Set[T],
) -> Dict[T, int]:
"""Computes uniform distribution of workload across ranks, based on sizes.
Currently, the assignment is greedy, based on:
1. Cross-parallelization group dependencies (shards with main rank in another group
are assigned at the end to make sure the distribution for load and save
is as similar as possible).
2. Secondly, the coverage of each shard
(how many ranks the shard is available on; lower coverage is assigned first)
3. Then, the size of each shard (larger size is assigned first)
4. Finally, shard id for differentiation.
Last step is added because we rely on the fact that
the assignment is deterministic on all ranks.
Args:
shard_to_ranks (Dict[T, List[int]]): mapping of rank access to shards
shard_to_size (Dict[T, int]): sizes of each shard
num_ranks (int): number of ranks in the parallelization group
cross_parallelization_group_loads (Set[T]): Shards to load that are not in the main replica
Returns (Dict[T, int]): assignment of shard to rank (which rank should do the work
to achieve maximal uniformity)
"""
shard_to_ranks = {k: tuple(v) for k, v in shard_to_ranks.items()}
shard_to_saving_rank = {}
rank_sizes = [(0, rank) for rank in range(num_ranks)]
# start from tensors of lowest coverage, then go by tensor size from largest (hence minus size)
for shard_id, shard_ranks in sorted(
shard_to_ranks.items(),
key=lambda sh_id_ranks: (
# 0 if rank is not in cross_parallelization_group_loads
# which means it has higher priority
int(sh_id_ranks[0] in cross_parallelization_group_loads),
len(sh_id_ranks[1]),
-shard_to_size[sh_id_ranks[0]],
sh_id_ranks[0],
),
):
# assign greedily to the least occupied rank
size, rank = min((size, rank) for size, rank in rank_sizes if rank in shard_ranks)
shard_to_saving_rank[shard_id] = rank
rank_sizes[rank] = (size + shard_to_size[shard_id], rank)
logger.debug(f"distribute_shards_to_ranks distribution: {rank_sizes}")
return shard_to_saving_rank
def determine_main_replica_uniform_distribution(
sharded_state_dict: ShardedStateDict,
parallelization_group: torch.distributed.ProcessGroup,
ignore_groups: bool = False,
) -> Optional[ShardDistribution]:
"""Computes the save distribution.
Should be used in conjunction with `distribute_main_replicas_with_precomputed_distribution`
which applies the computed save distribution.
We rely on the fact that the assignment algorithm is deterministic on all ranks,
so there is no extra communication needed after metadata exchange.
Args:
sharded_state_dict (ShardedStateDict): state dict to compute the distribution of
parallelization_group (ProcessGroup): distribution will be computed
within this process group
ignore_groups (bool, optional): whether the distribution defines groups.
This option is primarily used during loading, as it ensures that all replicas,
including non-main ones, are loaded by this parallelization group
Defaults to False.
Returns (ShardDistribution, optional): distribution that can be used to apply the
parallelization. Returns None if the process_group is trivial (1 rank)
"""
if parallelization_group is None:
parallelization_group = torch.distributed.group.WORLD
group_size = get_pg_size(group=parallelization_group)
if group_size <= 1:
return
local_shards = list(
sh_base
for sh_base in nested_values(sharded_state_dict)
if isinstance(sh_base, ShardedTensor)
)
local_shards_no_data = [ten.without_data() for ten in local_shards]
all_shards = [None] * get_pg_size(group=parallelization_group)
torch.distributed.all_gather_object(
all_shards, local_shards_no_data, group=parallelization_group
)
shard_to_ranks = defaultdict(list)
shard_to_size = {}
shard_to_metadata = {}
group_has_main_replica: Set[_ShardId] = set()
group_has_non_main_replica: Set[_ShardId] = set()
for rank, rank_shards in enumerate(all_shards):
for sh_ten in rank_shards:
shard_id = _sharded_tensor_shard_id(sh_ten)
shard_to_ranks[shard_id].append(rank)
if shard_id not in shard_to_size:
shard_to_size[shard_id] = _shard_size(sh_ten)
shard_to_metadata[shard_id] = sh_ten
if is_main_replica(sh_ten.replica_id):
group_has_main_replica.add(shard_id)
else:
group_has_non_main_replica.add(shard_id)
# we always include all main replicas, and non-main only if `ignore_groups`
shards_in_this_group: Set[_ShardId] = group_has_main_replica
if ignore_groups:
shards_in_this_group = shards_in_this_group | group_has_non_main_replica
# cross-parallel-group references are empty if `not ignore_groups`,
# otherwise it's `group_has_non_main_replica - group_has_main_replica`
cross_parallelization_group_loads = shards_in_this_group - group_has_main_replica
# Filter out shards that don't belong to this group
shard_to_ranks = {k: v for k, v in shard_to_ranks.items() if k in shards_in_this_group}
shard_to_saving_rank = distribute_shards_to_ranks(
shard_to_ranks, shard_to_size, len(all_shards), cross_parallelization_group_loads
)
return ShardDistribution(
shard_to_saving_rank, shards_in_this_group, shard_to_metadata, shard_to_ranks
)
@torch.no_grad()
@debug_time(f"exchange_loaded_tensors_gather_rounds", logger)
def exchange_loaded_tensors_gather_rounds(
loaded_tensors: Dict[_ShardId, torch.Tensor],
unloaded_shards: Dict[_ShardId, ShardedTensor],
shard_distribution: ShardDistribution = None,
parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
) -> Dict[_ShardId, torch.Tensor]:
"""Exchange the tensors loaded by different ranks with several all_gather calls.
Groups tensors by dtype, divide tensors that will be exchanged into rounds
and execute all_gather for tensors from each round.
Note: the loading is distributed across ranks based on total loaded size
in bytes, so there is no guarantee that number of rounds needed for each
rank will be similar, which might result in a lot of almost empty
all_gathers. The solution would be to group all tensors into a one
bytes tensor and do a single all_gather (with similarly sized messages).
Args:
loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor
shard ids to tensors already loaded by this rank.
unloaded_shards (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor
shard ids to ShardedTensors that aren't loaded yet.
shard_distribution (ShardDistribution): distribution of all shards
parallelization_group (ProcessGroup, optional): process group used for load
distribution. Tensors will be exchanged within this group
Returns:
Dict[_ShardId, torch.Tensor]: dictionary mapping shard ids to tensors
needed by this rank to load a given state dict. Includes
previously loaded tensors (from `loaded_tensors` input)
"""
if parallelization_group is None:
parallelization_group = torch.distributed.group.WORLD
main_rank_for_shard, _, shard_to_metadata, all_ranks_for_shard = shard_distribution
local_rank = get_pg_rank(group=parallelization_group)
all_loaded_tensors = dict(loaded_tensors)
# Group by dtype so that we all_gather tensors of the same dtype
for dtype in sorted(set(map(lambda sh_ten: sh_ten.dtype, shard_to_metadata.values())), key=str):
with debug_time(f"dtype_{dtype}"):
# shards_by_rank maps rank to tensors loaded by this rank
shards_by_rank: List[List[torch.Tensor]] = [
[] for _ in range(get_pg_size(group=parallelization_group))
]
for shard_id, rank in main_rank_for_shard.items():
if len(all_ranks_for_shard[shard_id]) == 1:
assert all_ranks_for_shard[shard_id][0] == main_rank_for_shard[shard_id], (
f"When there is only 1 ranks that needs a given shard,"
f" it should be the loading rank."
f" Got: needs [{all_ranks_for_shard[shard_id][0]}]"
f" vs loads [{main_rank_for_shard[shard_id]}]"
)
# Skipping the exchange since only the loading rank needs this tensor
# TODO: we can employ some optimizations even for `len(shard_to_ranks) > 1`
# case, e.g. P2P exchange. Currently handling this case saves most of the
# work though.
continue
if shard_to_metadata[shard_id].dtype == dtype:
shards_by_rank[rank].append(shard_id)
# Transpose `shards_by_rank` to form exchange rounds
shards_by_round = zip_longest(*shards_by_rank, fillvalue=None)
for round_idx, round_shard_ids in enumerate(shards_by_round):
round_tensors = []
orig_devices = {}
for rank, shard_id in enumerate(round_shard_ids):
if shard_id is None:
# if no more useful data, the given rank will exchange empty tensor
local_ten = torch.empty(0, dtype=dtype, device="cuda")
orig_device = None
else:
assert isinstance(shard_id, tuple), type(shard_id)
if rank == local_rank:
assert shard_id in all_loaded_tensors, (
shard_id,
all_loaded_tensors.keys(),
)
orig_device = all_loaded_tensors[shard_id]
all_loaded_tensors[shard_id] = all_loaded_tensors[shard_id].cuda()
local_ten = all_loaded_tensors[shard_id]
else:
local_ten, orig_device = _get_empty_tensor_for_exchange(
shard_id, unloaded_shards, shard_to_metadata, all_loaded_tensors
)
# Because of a TE bug, we have to exchange a nominal dtype instead of FP8
# It's ok to keep the nominal dtype after exchange, because TE will handle
# this during state dict load.
# TODO: remove it once the bug is fixed
if is_float8tensor(local_ten):
try:
local_ten = local_ten.from_float8()
except Exception as e:
local_ten = local_ten.dequantize()
all_loaded_tensors[shard_id] = local_ten
round_tensors.append(local_ten)
if orig_device is not None:
orig_devices[shard_id] = orig_device
torch.distributed.all_gather(
list(round_tensors),
round_tensors[local_rank],
group=parallelization_group,
async_op=False,
)
# Move tensors back to CPU if originally was on CPU
for shard_id, orig_device in orig_devices.items():
all_loaded_tensors[shard_id] = all_loaded_tensors[shard_id].to(orig_device)
del round_tensors # remove tensor references
return all_loaded_tensors
def exchange_loaded_tensors_gather_object(
loaded_tensors: Dict[_ShardId, torch.Tensor],
unloaded_shards: Dict[_ShardId, ShardedTensor],
shard_distribution: ShardDistribution,
parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
) -> Dict[_ShardId, torch.Tensor]:
"""Exchange the tensors loaded by different ranks with a simple all_gather_object call.
This version can be used for debugging purposes do to its simplistic
implementation. Shouldn't be used if performance is important.
Args:
loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor
shard ids to tensors already loaded by this rank.
unloaded_shards (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor
shard ids to ShardedTensors that aren't loaded yet.
shard_distribution (ShardDistribution): distribution of all shards
parallelization_group (ProcessGroup, optional): process group used for load
distribution. Tensors will be exchanged within this group
Returns:
Dict[_ShardId, torch.Tensor]: dictionary mapping shard ids to tensors
needed by this rank to load a given state dict. Includes
previously loaded tensors (from `loaded_tensors` input)
"""
all_loaded_tensors_list = [None] * torch.distributed.get_world_size(group=parallelization_group)
torch.distributed.all_gather_object(
all_loaded_tensors_list, loaded_tensors, group=parallelization_group
)
all_loaded_tensors_list = cast(List[Dict[_ShardId, torch.Tensor]], all_loaded_tensors_list)
all_loaded_tensors = reduce(lambda x, y: {**x, **y}, all_loaded_tensors_list)
# Error checks
if len(all_loaded_tensors) != sum(map(len, all_loaded_tensors_list)):
err_msg = "Duplicate shard ids loaded by different ranks"
if torch.distributed.get_rank() == 0:
logger.error(
f"{err_msg}. Shards ids by rank:"
f" {[lt.keys() for lt in all_loaded_tensors_list]}"
)
raise CheckpointingException(err_msg)
return all_loaded_tensors
def exchange_loaded_objects_gather_object(
loaded_objects: Dict[_ShardId, Any]
) -> Dict[_ShardId, Any]:
"""Exchange the objects loaded by different ranks with a simple all_gather_object call.
Args:
loaded_objects (Dict[_ShardId, Any]): mapping from shard ids to objects
already loaded by this rank.
Returns:
Dict[_ShardId, Any]: dictionary mapping shard ids to objects needed by this rank to
load a given state dict.
"""
all_loaded_objects_list = [None] * torch.distributed.get_world_size()
torch.distributed.all_gather_object(all_loaded_objects_list, loaded_objects, group=None)
all_loaded_objects_list = cast(List[Dict[_ShardId, Any]], all_loaded_objects_list)
all_loaded_objects = reduce(lambda x, y: {**x, **y}, all_loaded_objects_list)
# Error checks
if len(all_loaded_objects) != sum(map(len, all_loaded_objects_list)):
err_msg = "Duplicate shard ids loaded by different ranks"
if torch.distributed.get_rank() == 0:
logger.error(
f"{err_msg}. Shards ids by rank:"
f" {[lt.keys() for lt in all_loaded_objects_list]}"
)
raise CheckpointingException(err_msg)
return all_loaded_objects
@torch.no_grad()
@debug_time("exchange_loaded_tensors_broadcast", logger)
def exchange_loaded_tensors_broadcast(
loaded_tensors: Dict[_ShardId, torch.Tensor],
unloaded_shards: Dict[_ShardId, ShardedTensor],
shard_distribution: ShardDistribution,
parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
) -> Dict[_ShardId, torch.Tensor]:
"""Exchange the tensors loaded by different ranks by a series of broadcasts.
For each rank for each loaded tensor do a broadcast to the whole group.
A reasonable tradeoff in terms of performance and simplicity.
Args:
loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor
shard ids to tensors already loaded by this rank.
unloaded_shards (Dict[_ShardId, ShardedTensor]): mapping from ShardedTensor
shard ids to ShardedTensors that aren't loaded yet.
shard_distribution (ShardDistribution): distribution of all shards
parallelization_group (ProcessGroup, optional): process group used for load
distribution. Tensors will be exchanged within this group
Returns:
Dict[_ShardId, torch.Tensor]: dictionary mapping shard ids to tensors
needed by this rank to load a given state dict. Includes
previously loaded tensors (from `loaded_tensors` input)
"""
main_rank_for_shard, _, shard_to_metadata, all_ranks_for_shard = shard_distribution
local_rank = torch.distributed.get_rank(group=parallelization_group)
all_loaded_tensors = dict(loaded_tensors)
for idx, (shard_id, rank) in enumerate(main_rank_for_shard.items()):
if len(all_ranks_for_shard[shard_id]) == 1:
assert all_ranks_for_shard[shard_id][0] == main_rank_for_shard[shard_id], (
f"When there is only 1 ranks that needs a given shard,"
f" it should be the loading rank."
f"Got: needs [{all_ranks_for_shard[shard_id][0]}]"
f" vs loads [{main_rank_for_shard[shard_id]}]"
)
# Skipping the exchange since only the loading rank needs this tensor
# TODO: we can employ some optimizations even for `len(shard_to_ranks) > 1` case,
# e.g. P2P exchange. Currently handling this case saves most of the work though.
continue
if rank == local_rank:
assert shard_id in all_loaded_tensors, (shard_id, all_loaded_tensors.keys())
orig_device = all_loaded_tensors[shard_id].device
local_ten = all_loaded_tensors[shard_id].cuda()
else:
local_ten, orig_device = _get_empty_tensor_for_exchange(
shard_id, unloaded_shards, shard_to_metadata, all_loaded_tensors
)
# Because of a TE bug, we have to exchange a nominal dtype instead of FP8
# It's ok to keep the nominal dtype after exchange, because TE will handle
# this during state dict load.
# TODO: remove it once the bug is fixed
if is_float8tensor(local_ten):
try:
local_ten = local_ten.from_float8()
except Exception as e:
local_ten = local_ten.dequantize()
all_loaded_tensors[shard_id] = local_ten
global_src_rank = (
rank
if parallelization_group == None
else torch.distributed.get_global_rank(parallelization_group, rank)
)
# We can do async_op=True only if there is no CPU-copy follow-up
torch.distributed.broadcast(
local_ten,
src=global_src_rank,
group=parallelization_group,
async_op=orig_device is None,
)
# Move tensor back to CPU if originally was on CPU
if orig_device is not None:
all_loaded_tensors[shard_id] = local_ten.to(orig_device)
del local_ten
return all_loaded_tensors
def exchange_by_distribution(
loaded_tensors: Dict[_ShardId, torch.Tensor],
unloaded_shards: Dict[_ShardId, ShardedTensor],
shard_distribution: ShardDistribution,
parallelization_group: Optional[torch.distributed.ProcessGroup] = None,
exchange_algo="broadcast",
) -> Dict[_ShardId, torch.Tensor]:
"""Exchange tensors loaded by different ranks using the specified exchange_algo.
Args:
loaded_tensors (Dict[_ShardId, torch.Tensor]): mapping from ShardedTensor
shard ids to tensors already loaded by this rank.
unloaded_shards (Dict[_ShardId, ShardedTensor]): mapping from ShardedTensor
shard ids to ShardedTensors that aren't loaded yet.
shard_distribution (ShardDistribution): distribution of all shards
parallelization_group (ProcessGroup, optional): process group used for load
distribution. Tensors will be exchanged within this group
exchange_algo (str): The algorithm used for performing exchanges.
Defaults to 'broadcast'.
Returns:
Dict[_ShardId, torch.Tensor]: dictionary mapping shard ids to tensors
needed by this rank to load a given state dict. Includes
previously loaded tensors (from `loaded_tensors` input)
"""
assert shard_distribution is not None, "Expecting distribution to perform exchange"
if exchange_algo == "gather_object":
exchange_fn = exchange_loaded_tensors_gather_object
elif exchange_algo == "gather_rounds":
exchange_fn = exchange_loaded_tensors_gather_rounds
elif exchange_algo == "broadcast":
exchange_fn = exchange_loaded_tensors_broadcast
else:
raise NotImplementedError(f"Unrecognized gather algorithm: {exchange_algo}")
return exchange_fn(loaded_tensors, unloaded_shards, shard_distribution, parallelization_group)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment