"sgl-router/git@developer.sourcefind.cn:change/sglang.git" did not exist on "7590f5224b69b73d7b75480fe8553fe9381b82d0"
Commit 7c19b3a8 authored by wangsen's avatar wangsen
Browse files

Initial commit

parents
Pipeline #1721 failed with stages
in 0 seconds
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Construct an index.
Constructing an index generally happens in two phases:
- index.train(): Train an index on a representative set of vectors.
- index.add(): Add vectors to an index, to be available for retrieval.
"""
import os
import shutil
import numpy as np
import torch
from tqdm import tqdm
from megatron.core.datasets.retro.config import RetroPreprocessingConfig
from megatron.core.datasets.retro.db.utils import (
get_merged_sampled_dataset,
get_merged_train_dataset,
)
from megatron.core.datasets.retro.external_libs import h5py
from megatron.core.datasets.retro.utils import GPTToTextDataset
from .factory import IndexFactory
from .utils import (
get_training_data_block_dir,
get_training_data_block_paths,
get_training_data_merged_path,
get_training_data_root_dir,
)
##################################################
# Train index.
##################################################
def get_empty_index_path(config: RetroPreprocessingConfig) -> str:
"""Path of empty index.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
Returns:
Path to the empty (trained, but without added samples) vector index.
"""
index = IndexFactory.get_index(config.retro_index_type)
empty_index_path = index.get_empty_index_path(config)
return empty_index_path
def get_block_nload(block_path: str, load_fraction: float) -> int:
"""Compute number of blocks to load.
This is computed by multiplying the total number of available blocks with the
fraction of blocks to load.
Args:
block_path (str): Path to HDF5 file containing block of data. File must contain key 'data'.
load_fraction (float): Fraction (0 < load_fraction <= 1) of block samples to load.
Returns:
Number of block samples to load.
"""
with h5py.File(block_path) as fi:
return int(load_fraction * fi["data"].shape[0])
def merge_embedding_blocks(config: RetroPreprocessingConfig) -> None:
"""Merge individual embedding blocks into a single binary mmap file.
The embeddings are initially stored in block-sized (e.g., ~100k embeddings per
block) HDF5 files. These individual block files must be merged into a single
    file before training, to be passed as a NumPy mmap array to the index.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
"""
if torch.distributed.get_rank() != 0:
return
# Get block, merged paths.
load_fraction = config.retro_index_train_load_fraction
block_paths = get_training_data_block_paths(config)
bin_path = get_training_data_merged_path(config)
# Skip, if already built.
if os.path.exists(bin_path):
return
# Merge blocks.
with open(bin_path, "wb") as fo:
byte_offset = 0
for block_idx, block_path in enumerate(
tqdm(
block_paths,
"merge train embeddings",
miniters=len(block_paths) // 10,
disable=torch.distributed.get_rank() != 0,
)
):
with h5py.File(block_path) as fi:
nload = get_block_nload(block_path, load_fraction)
block = np.array(fi["data"][:nload], copy=False)
fo.write(block.tobytes())
byte_offset += block.size * block.itemsize
fo.seek(byte_offset)
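# Illustrative sketch of how the merged binary file produced above is later
# consumed: the flat float32 file is re-viewed as an (N, hidden_size) array via
# a NumPy memmap, mirroring the memmap call used when training the index. The
# path and hidden size arguments are placeholders.
def _example_read_merged_embeddings(bin_path: str, hidden_size: int) -> np.ndarray:
    """Memory-map a merged embedding file without loading it all into RAM."""
    return np.memmap(bin_path, dtype="f4", mode="r").reshape((-1, hidden_size))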
def get_text_dataset_for_training(config: RetroPreprocessingConfig) -> GPTToTextDataset:
"""Convert GPT token chunk dataset to a text dataset for passing to the
embedder.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
Returns:
        The text dataset consisting of tokens converted from the sampled chunk database.
"""
gpt_dataset = get_merged_sampled_dataset(
project_dir=config.retro_project_dir,
chunk_length=config.retro_gpt_chunk_length,
eod_token_id=config.retro_tokenizers.gpt.eod,
)
text_dataset = GPTToTextDataset(gpt_dataset, config.retro_tokenizers.gpt)
return text_dataset
def embed_training_chunks(config: RetroPreprocessingConfig) -> None:
"""Embed DB chunks.
    Store the chunk embeddings in blocks on disk. These blocks will later be merged into
a single dataset for training the index.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
"""
merged_train_data_path = get_training_data_merged_path(config)
if os.path.exists(merged_train_data_path):
return
# Get training text dataset.
text_dataset = get_text_dataset_for_training(config)
# Embed dataset.
embedder = config.retro_bert_embedders.disk
embedder.embed_text_dataset("index", get_training_data_block_dir(config), text_dataset)
# Merge embeddings.
merge_embedding_blocks(config)
def train_on_embeddings(config: RetroPreprocessingConfig) -> None:
"""Train index on embedded DB chunks.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
"""
index = IndexFactory.get_index(config.retro_index_type)
index.train(config)
def remove_embeddings(config: RetroPreprocessingConfig) -> None:
"""Remove embeddings after training.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
"""
torch.distributed.barrier()
if torch.distributed.get_rank() != 0:
return
empty_index_path = get_empty_index_path(config)
assert os.path.isfile(empty_index_path)
shutil.rmtree(get_training_data_root_dir(config), ignore_errors=True)
def _train_index(config: RetroPreprocessingConfig) -> None:
"""Train index on DB chunks.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
"""
# Check if trained index already exists.
if not os.path.isfile(get_empty_index_path(config)):
# Embed training chunks.
embed_training_chunks(config)
# Train index on embeddings.
train_on_embeddings(config)
# Wait for (single-process) training to complete.
torch.distributed.barrier()
# Remove embeddings.
if config.retro_index_delete_training_embeddings:
remove_embeddings(config)
def train_index(config: RetroPreprocessingConfig) -> None:
"""Entry point for training the index.
We select whether to train a new index, or validate an existing index.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
"""
# Train new index.
if config.retro_task_validate is None:
_train_index(config)
# Validate existing trained index.
else:
from .validate import validate_training_embeddings
validate_training_embeddings(config)
##################################################
# Add to index.
##################################################
def get_text_dataset_for_adding(config: RetroPreprocessingConfig) -> GPTToTextDataset:
"""Convert GPT token chunk dataset to a text dataset for passing to the
embedder.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
Returns:
The text dataset that consists of tokens converted from the 'train' chunk database. These are the chunks used for retrieval by the pretraining 'train' dataset.
"""
gpt_dataset = get_merged_train_dataset(
project_dir=config.retro_project_dir,
chunk_length=config.retro_gpt_chunk_length,
eod_token_id=config.retro_tokenizers.gpt.eod,
)
text_dataset = GPTToTextDataset(gpt_dataset, config.retro_tokenizers.gpt)
return text_dataset
def _add_to_index(config: RetroPreprocessingConfig) -> str:
"""Add DB chunks to index.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
Returns:
Path to the populated index.
"""
# Get index.
index = IndexFactory.get_index(config.retro_index_type)
# Get text dataset.
text_dataset = get_text_dataset_for_adding(config)
# Add to index.
output_index_path = index.add(config, text_dataset)
return output_index_path
def add_to_index(config: RetroPreprocessingConfig) -> None:
"""Entry point for adding to the index.
We select whether to add to a new index, or validate an existing index.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
"""
# Add to new index.
if config.retro_task_validate is None:
_add_to_index(config)
# Validate existing encodings.
else:
from .validate import validate_added_encodings
validate_added_encodings(config)
##################################################
# Build index (train + add).
##################################################
def build_index(config: RetroPreprocessingConfig) -> None:
"""Build index.
Building index involves sequentially running stages above:
- Train index (on sampled training chunks).
- Add to index (on all training chunks).
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
"""
# Train index.
train_index(config)
# Add to index.
add_to_index(config)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""The IndexFactory constructs an index from an index type string."""
from megatron.core.datasets.retro.index.index import Index
from .indexes import FaissBaseIndex, FaissParallelAddIndex
class IndexFactory:
"""Get index.
    The index type is generally read from the argument '--retro-index-type'.
"""
@classmethod
def get_index_class(cls, index_type: str) -> type:
"""Get an index class, given a type string.
Args:
            index_type (str): One of 'faiss-base' (naive Faiss index wrapper) or 'faiss-par-add' (Faiss index wrapper with near embarrassingly parallel index.add()).
Returns:
An `Index` sub-type corresponding to the `index_type`.
"""
return {"faiss-base": FaissBaseIndex, "faiss-par-add": FaissParallelAddIndex,}[index_type]
@classmethod
def get_index(cls, index_type: str) -> Index:
"""Construct an index from an index type string.
Args:
            index_type (str): One of 'faiss-base' (naive Faiss index wrapper) or 'faiss-par-add' (Faiss index wrapper with near embarrassingly parallel index.add()).
Returns:
An `Index` instance corresponding to the `index_type`.
"""
index_class = cls.get_index_class(index_type)
index = index_class()
return index
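# Illustrative usage of the factory above: the index type string (the same
# value carried in config.retro_index_type) selects the wrapper class.
def _example_factory_usage() -> None:
    """Sketch of resolving an index wrapper from its type string."""
    index = IndexFactory.get_index("faiss-par-add")
    assert isinstance(index, FaissParallelAddIndex)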
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Base class for all vector indexes.
A vector index is a type of retrieval database that is queried using vectors,
and returns vectors that are 'similar' (e.g., by cosine distance) to the query
vector. The construction and usage of an index generally has the following
pattern:
- Train the index on representative vectors.
- Add vectors to the index (i.e., vectors available for retrieval)
- Query index with new vector, to retrieve similar vector indexes.
"""
import abc
import os
from typing import List, Tuple
import numpy as np
import torch
from megatron.core.datasets.retro.config import Embedder, RetroPreprocessingConfig
from megatron.core.datasets.retro.external_libs import faiss
from megatron.core.datasets.retro.utils import GPTToTextDataset
from .utils import get_index_dir
class Index(abc.ABC):
"""Abstract base class for indexes.
*Note* : While currently only Faiss-based classes are implemented, in the
future, this class will be extended with other types of indexes that have
different performance-accuracy trade-offs.
The primary methods to override are:
- train() : Train index on the sampled training chunks.
- add() : Add all training chunks to index.
"""
@classmethod
def make_object_verbose(cls, index: faiss.Index, verbose: bool) -> None:
"""Make index object verbose.
Args:
index (faiss.Index): Faiss object to set verbose.
verbose (bool): Sets whether index should log status updates during training and adding.
"""
assert isinstance(verbose, bool)
faiss.ParameterSpace().set_index_parameter(index, "verbose", verbose)
def get_empty_index_path(self, config: RetroPreprocessingConfig) -> str:
"""Get file path to empty index (i.e., trained, but unpopulated).
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
Returns:
File path to empty index (i.e., this index has had index.train() called, but not yet index.add()).
"""
return os.path.join(
get_index_dir(config), "empty_%.3f.faissindex" % config.retro_index_train_load_fraction,
)
def get_empty_index(self, config: RetroPreprocessingConfig) -> faiss.Index:
"""Get empty index (i.e., trained, but unpopulated).
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
Returns:
Empty Faiss index, loaded from storage.
"""
return faiss.read_index(self.get_empty_index_path(config))
def get_added_index_path(self, config: RetroPreprocessingConfig) -> str:
"""Get file path to index that has been populated with vectors.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
Returns:
File path to added index (i.e., this index has had both index.train() and index.add() called).
"""
return os.path.join(
get_index_dir(config),
"added_%.3f_%.3f.faissindex"
% (config.retro_index_train_load_fraction, config.retro_index_add_load_fraction,),
)
def get_added_index(self, config: RetroPreprocessingConfig) -> faiss.Index:
"""Get index that has been populated with vectors.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
Returns:
'Added' (i.e., populated) Faiss index, loaded from storage.
"""
return faiss.read_index(self.get_added_index_path(config))
@abc.abstractmethod
def train(self, config: RetroPreprocessingConfig) -> None:
"""Train index on a representative set of vectors.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
"""
@abc.abstractmethod
def add(self, config: RetroPreprocessingConfig, text_dataset: GPTToTextDataset) -> None:
"""Add vectors to index.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
text_dataset (GPTToTextDataset): Text dataset that will be embedded and added to the index.
"""
def embed_text_dataset_block(
self, embedder: Embedder, text_dataset: GPTToTextDataset, _range: Tuple[int, int]
) -> np.ndarray:
"""Embed a range of a text dataset.
Args:
embedder (Embedder): Embedder used for embedding a text dataset.
text_dataset (GPTToTextDataset): Text dataset that will be embedded.
_range (Tuple[int, int]): Start/end sample indices within text dataset used for embedding.
Returns:
            An array of embeddings for the given range, with shape (_range[1] - _range[0], embedding dimension).
"""
sub_dataset = torch.utils.data.Subset(text_dataset, range(*_range))
return embedder.embed_text_dataset(sub_dataset)
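# Illustrative sketch of the range-to-Subset pattern used by
# embed_text_dataset_block above: torch.utils.data.Subset restricts any
# map-style dataset to an explicit set of indices. The toy tensor dataset is a
# stand-in for GPTToTextDataset.
def _example_subset_of_range() -> None:
    """Take samples [10, 20) of a toy map-style dataset."""
    full = torch.utils.data.TensorDataset(torch.arange(100))
    sub = torch.utils.data.Subset(full, range(10, 20))
    assert len(sub) == 10 and sub[0][0].item() == 10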
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""
Exports:
- FaissBaseIndex: Unoptimized Faiss index wrapper
- FaissParallelAddIndex: Optimized index.add() for Faiss index.
"""
from .faiss_base import FaissBaseIndex
from .faiss_par_add import FaissParallelAddIndex
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""
This class implements a simple, un-optimized wrapper around a Faiss index, that
implements the Index interface (see ..index.py). While this class is
instantiable, it is meant to be extended with optimizations in classes that
inherit from this class (see FaissParAddIndex, for an example).
"""
import os
import numpy as np
import torch
from tqdm import tqdm
from megatron.core.datasets.retro.config import RetroPreprocessingConfig
from megatron.core.datasets.retro.external_libs import faiss
from megatron.core.datasets.retro.index.index import Index
from megatron.core.datasets.retro.index.utils import (
get_training_data_merged_path,
num_samples_to_block_ranges,
)
from megatron.core.datasets.retro.utils import GPTToTextDataset, log_retro_rank_0
class FaissBaseIndex(Index):
"""Base class for Faiss-base indexes.
This class wraps a Faiss index, and adds additional functionality for training
and adding codes. This base class performs a naive sequential code adding,
while the optimized FaissParallelAddIndex class performs a parallel
index.add().
"""
def _train(self, config: RetroPreprocessingConfig) -> None:
"""Train index (rank 0's method).
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
"""
assert torch.distributed.get_rank() == 0
# Set num threads (torch.distributed reset it to 1).
faiss.omp_set_num_threads(64)
empty_index_path = self.get_empty_index_path(config)
# Index already exists? -> return.
if os.path.isfile(empty_index_path):
return
# Load data.
merged_path = get_training_data_merged_path(config)
inp = np.memmap(merged_path, dtype="f4", mode="r",).reshape((-1, config.hidden_size))
# Init index.
index = faiss.index_factory(config.hidden_size, config.retro_index_str)
# Move to GPU.
log_retro_rank_0("> move faiss index to gpu.")
index_ivf = faiss.extract_index_ivf(index)
clustering_index = faiss.index_cpu_to_all_gpus(faiss.IndexFlatL2(index_ivf.d))
index_ivf.clustering_index = clustering_index
log_retro_rank_0("> finished moving to gpu.")
self.make_object_verbose(index, True)
self.make_object_verbose(index_ivf, True)
self.make_object_verbose(index_ivf.quantizer, True)
self.make_object_verbose(index_ivf.clustering_index, True)
# Train index.
index.train(inp)
# Save index.
faiss.write_index(index, empty_index_path)
def train(self, config: RetroPreprocessingConfig) -> None:
"""Train index.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
"""
# Single process only.
if torch.distributed.get_rank() == 0:
self._train(config)
torch.distributed.barrier()
def _add(self, config: RetroPreprocessingConfig, text_dataset: GPTToTextDataset) -> None:
"""Add to index (rank 0's method).
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
text_dataset (GPTToTextDataset): Text dataset that will be embedded and added to the index.
"""
assert torch.distributed.get_rank() == 0
        dataset_sample_ranges = num_samples_to_block_ranges(config, len(text_dataset))
# Set num threads (torch.distributed reset it to 1).
faiss.omp_set_num_threads(64)
# Bert embedder.
        embedder = config.retro_bert_embedders.mem
# Empty/added index paths.
        empty_index_path = self.get_empty_index_path(config)
        added_index_path = self.get_added_index_path(config)
# Skip adding, if index exists.
if os.path.isfile(added_index_path):
return
# Read trained index.
index = faiss.read_index(empty_index_path)
# Iterate data blocks & add.
for sample_range in tqdm(dataset_sample_ranges, "faiss_base.add"):
# Embed text.
embeds = self.embed_text_dataset_block(embedder, text_dataset, sample_range)
# Add to index.
index.add(embeds)
# Write index.
faiss.write_index(index, added_index_path)
def add(self, config: RetroPreprocessingConfig, text_dataset: GPTToTextDataset) -> str:
"""Add to index.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
text_dataset (GPTToTextDataset): Text dataset that will be embedded and added to the index.
Returns:
File path to the populated index.
"""
# Single process only.
if torch.distributed.get_rank() == 0:
self._add(config, text_dataset)
# Wait for rank 0.
torch.distributed.barrier()
# Get output index path, for return.
return self.get_added_index_path(config)
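# Illustrative sketch of the rank-0-plus-barrier pattern used by train() and
# add() above: a single rank performs the (single-process) Faiss work while
# every other rank waits at the barrier. Assumes torch.distributed has already
# been initialized; `do_work` is a placeholder callable.
def _example_rank0_then_barrier(do_work) -> None:
    """Run `do_work()` on rank 0 only, then synchronize all ranks."""
    if torch.distributed.get_rank() == 0:
        do_work()
    torch.distributed.barrier()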
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Multi-process & multi-node version of Faiss's index.add().
This class inherits from FaissBaseIndex, and optimizes the 'add()' method by
making it multi-node and multi-process, with bit-wise equivalence to
FaissBaseIndex. This allows 'add()' to scale out to very large datasets, since
the vast majority of the computational effort is embarrassingly parallel.
"""
import os
import shutil
from typing import Tuple
import numpy as np
import psutil
import torch
from tqdm import tqdm
from megatron.core.datasets.retro.config import Embedder, RetroPreprocessingConfig
from megatron.core.datasets.retro.external_libs import faiss, h5py
from megatron.core.datasets.retro.index.utils import get_added_code_paths, get_added_codes_dir
from megatron.core.datasets.retro.utils import (
GPTToTextDataset,
get_blocks_by_rank,
log_retro_rank_0,
retro_makedir,
)
from .faiss_base import FaissBaseIndex
class FaissParallelAddIndex(FaissBaseIndex):
"""
This class parallelizes both 1) encoding vectors, and 2) adding codes to the
index. This class is more performant than naive use of Faiss, because most
of the computational work is in encoding the vectors, which is an
    embarrassingly parallel operation.
"""
def encode_block(
self, index: faiss.Index, embedder: Embedder, text_dataset: GPTToTextDataset, block: dict
) -> Tuple[np.ndarray, np.ndarray]:
"""Encode sub-dataset block, to be later added to index.
Encode the data subset, generally in blocks of 1M vectors each. For
each block, the empty/trained index is loaded, codes are computed
via index.sa_encode(), and the resulting codes are saved to disk.
Args:
index (faiss.Index): Faiss index object.
embedder (Embedder): Embedder used to embed text dataset.
text_dataset (GPTToTextDataset): Text dataset to be embedded and encoded.
block (dict): Range information specifying start/end indices within text dataset.
Returns:
A tuple of (embeddings, encodings) for the given block subset of the text dataset.
"""
# Embed block.
embeddings = self.embed_text_dataset_block(embedder, text_dataset, block["range"],)
# Encode block.
log_retro_rank_0("encode.")
codes = index.sa_encode(embeddings)
# Return embeddings for validation purposes.
return embeddings, codes
def save_block(self, config: RetroPreprocessingConfig, block: dict, codes: np.ndarray) -> None:
"""Save block of codes to disk.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
block (dict): Range information specifying the start/end indices within the encoded text dataset. Here, the 'path' item is used for writing the encodings to storage.
codes (np.ndarray): Block of encodings to be saved to storage.
"""
# Save neighbors.
log_retro_rank_0("save codes.")
retro_makedir(config, os.path.dirname(block["path"]))
with h5py.File(block["path"], "w") as f:
f.create_dataset("data", data=codes)
def encode(self, config: RetroPreprocessingConfig, text_dataset: GPTToTextDataset) -> None:
"""Encode text dataset, to be later added to index.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
text_dataset (GPTToTextDataset): Text dataset to be encoded by the index.
"""
codes_dir = get_added_codes_dir(config)
retro_makedir(config, codes_dir)
# Index.
index = self.get_empty_index(config)
# Bert embedder.
embedder = config.retro_bert_embedders.mem
# Missing code blocks.
def validate(f: h5py.File) -> None:
"""Validation method for validating loaded encodings.
Args:
f (h5py.File): File that contains encodings.
"""
assert len(f["data"].shape) == 2
blocks = get_blocks_by_rank(
codes_dir, len(text_dataset), config.retro_block_size, validate=validate,
)
# Encode each block.
for block_index, block in enumerate(blocks.missing):
if block is not None:
# Progress.
log_retro_rank_0(
"encode block %d / %d ... %s."
% (block_index, len(blocks.missing), block["path"],)
)
# Encode and save.
_, codes = self.encode_block(index, embedder, text_dataset, block)
self.save_block(config, block, codes)
# Synchronize progress across all ranks. (for easier observation)
log_retro_rank_0(" > waiting for other ranks to finish block.")
torch.distributed.barrier()
def add_codes(self, config: RetroPreprocessingConfig) -> None:
"""Read codes from disk, and add them to the index.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
"""
if torch.distributed.get_rank() != 0:
return
added_index_path = self.get_added_index_path(config)
if os.path.exists(added_index_path):
return
# Index.
log_retro_rank_0("read empty index.")
index = self.get_empty_index(config)
index_ivf = faiss.extract_index_ivf(index)
# Add codes.
log_retro_rank_0("add codes.")
code_paths = get_added_code_paths(config)
pbar = tqdm(code_paths)
for code_path in pbar:
pbar.set_description(
"add codes, mem %.3f gb, %.1f%%"
% (psutil.virtual_memory()[3] / 1024 ** 3, psutil.virtual_memory()[2],)
)
with h5py.File(code_path) as f:
nload = int(config.retro_index_add_load_fraction * f["data"].shape[0])
offset = int(os.path.basename(code_path).split("-")[0])
xids = np.arange(offset, offset + nload)
codes = np.copy(f["data"][:nload])
index_ivf.add_sa_codes(codes, xids)
# Update index's ntotal.
index.ntotal = index_ivf.ntotal
# Write index.
log_retro_rank_0("write added index.")
faiss.write_index(index, added_index_path)
def remove_codes(self, config: RetroPreprocessingConfig) -> None:
"""Remove added codes after adding to index.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
"""
if torch.distributed.get_rank() != 0:
return
assert os.path.isfile(self.get_added_index_path(config))
if config.retro_index_delete_added_codes:
raise Exception("remove?")
shutil.rmtree(get_added_codes_dir(config), ignore_errors=True)
def add(self, config: RetroPreprocessingConfig, text_dataset: GPTToTextDataset) -> None:
"""Add vectors to index.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
text_dataset (GPTToTextDataset): Text dataset that will be embedded and added to the index.
"""
# Encode chunks.
self.encode(config, text_dataset)
# Add codes to index.
self.add_codes(config)
# Wait for (single-process) adding to complete.
torch.distributed.barrier()
# Remove codes.
self.remove_codes(config)
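# Illustrative sketch of the Faiss standalone-codec API used by encode_block()
# above: a trained index turns raw embeddings into compact codes with
# sa_encode(), and can (approximately) reconstruct them with sa_decode(). The
# random data and the "IVF16,PQ8" index string are placeholders.
def _example_sa_encode_roundtrip() -> None:
    """Encode random vectors with a trained index and decode them back."""
    import faiss
    import numpy as np
    dim = 32
    xs = np.random.rand(1_000, dim).astype("f4")
    index = faiss.index_factory(dim, "IVF16,PQ8")
    index.train(xs)
    codes = index.sa_encode(xs)  # shape: (n, index.sa_code_size())
    approx = index.sa_decode(codes)  # lossy reconstruction
    assert approx.shape == xs.shape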
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Utilities for building an index."""
import glob
import os
from typing import List, Tuple
from megatron.core.datasets.retro.config import RetroPreprocessingConfig
from megatron.core.datasets.retro.utils import retro_makedir
def get_index_dir(config: RetroPreprocessingConfig) -> str:
"""Create sub-directory for this index.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
Returns:
Path to index sub-directory within Retro project.
"""
# Directory path.
index_dir_path = os.path.join(
config.retro_project_dir, "index", config.retro_index_type, config.retro_index_str,
)
# Make directory.
retro_makedir(config, index_dir_path)
return index_dir_path
def num_samples_to_block_ranges(
config: RetroPreprocessingConfig, num_samples: int
) -> List[Tuple[int, int]]:
"""Split a range (length num_samples) into sequence of block ranges
of size block_size.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
num_samples (int): Split `num_samples` into consecutive block ranges, where each block is size `config.retro_block_size`.
Returns:
A list of tuples where each item is the (start, end) index for a given block.
"""
block_size = config.retro_block_size
start_idxs = list(range(0, num_samples, block_size))
end_idxs = [min(num_samples, s + block_size) for s in start_idxs]
ranges = list(zip(start_idxs, end_idxs))
return ranges
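# Illustrative worked example of the block-range split above, using a
# placeholder namespace in place of a full RetroPreprocessingConfig: with a
# block size of 4, ten samples split into [(0, 4), (4, 8), (8, 10)].
def _example_block_ranges() -> None:
    """Show the (start, end) ranges produced for 10 samples, block size 4."""
    from types import SimpleNamespace
    toy_config = SimpleNamespace(retro_block_size=4)
    assert num_samples_to_block_ranges(toy_config, 10) == [(0, 4), (4, 8), (8, 10)]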
def get_training_data_root_dir(config: RetroPreprocessingConfig) -> str:
"""Get root directory for embeddings (blocks and merged data).
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
Returns:
Path to the training data directory, which contains both training embedding blocks and the final merged training embeddings.
"""
return os.path.join(config.retro_project_dir, "index", "train_emb")
def get_training_data_block_dir(config: RetroPreprocessingConfig) -> str:
"""Get directory for of saved embedding blocks.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
Returns:
Path to the directory containing the training embedding blocks, which will be later merged into a single embedding array.
"""
return os.path.join(get_training_data_root_dir(config), "blocks")
def get_training_data_block_paths(config: RetroPreprocessingConfig) -> List[str]:
"""Get paths to saved embedding blocks.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
Returns:
Paths of all training embedding blocks.
"""
return sorted(glob.glob(get_training_data_block_dir(config) + "/*.hdf5"))
def get_training_data_merged_path(config: RetroPreprocessingConfig) -> str:
"""Get path to merged training embeddings.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
Returns:
Path to the merged training embedding binary file.
"""
return os.path.join(
get_training_data_root_dir(config),
"train_%.3f.bin" % config.retro_index_train_load_fraction,
)
def get_added_codes_dir(config: RetroPreprocessingConfig) -> str:
"""Get directory of saved encodings.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
Returns:
Path to the directory containing the vector encodings for adding to the index.
"""
return os.path.join(get_index_dir(config), "add_codes")
def get_added_code_paths(config: RetroPreprocessingConfig) -> List[str]:
"""Get paths to all saved encodings.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
Returns:
Paths of all vector encoding blocks, for adding to the index.
"""
return sorted(glob.glob(get_added_codes_dir(config) + "/*.hdf5"))
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Validate an index's data.
This module contains functionality for checking for bitwise equality across code
changes. The training and adding steps of index construction can be validated
separately. The following high-level checks are supported:
- Training: Validate that saved training embeddings are bitwise equal with a
sample set of freshly computed embeddings. (*Note*:
`--no-retro-index-delete-training-embeddings` must be used.)
- Adding: Validate that the saved encodings are bitwise equal with a sample set
  of freshly computed encodings. (*Note*:
`--no-retro-index-delete-added-codes` must be used.)
"""
import typing
import numpy as np
import torch
from torch.utils.data import Subset
from megatron.core.datasets.retro.config import RetroPreprocessingConfig
from megatron.core.datasets.retro.external_libs import h5py
from megatron.core.datasets.retro.utils import (
GPTToTextDataset,
get_blocks_by_rank,
log_retro_rank_0,
)
from .build import get_text_dataset_for_adding, get_text_dataset_for_training
from .factory import IndexFactory
from .utils import get_added_codes_dir, get_training_data_block_dir
##################################################
# Validate trained index.
##################################################
def validate_training_embeddings(config: RetroPreprocessingConfig) -> None:
"""Validate training embeddings.
Steps:
- Randomly sample subset of text dataset blocks.
- Embed each block.
- Compare against saved embeddings.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
"""
# Training text dataset.
text_dataset = get_text_dataset_for_training(config)
# Sample existing blocks.
blocks = get_blocks_by_rank(
dirname=get_training_data_block_dir(config),
n_samples=len(text_dataset),
block_size=config.retro_block_size,
validate=None,
sample=config.retro_task_validate,
)
assert blocks.n_missing_world == 0
# Embed & validate blocks.
embedder = config.retro_bert_embedders.mem
for block_idx, block in enumerate(blocks.existing):
# Missing block lists are extended with None to have equal-length
# lists. Skip the Nones.
if block is not None:
# Progress. (*note*: move world progress to here.)
log_retro_rank_0(
"embed training block %d / %d ... %s."
% (block_idx, len(blocks.existing), block["path"],)
)
# Load existing block embeddings.
with h5py.File(block["path"]) as f:
existing_embeddings = np.copy(f["data"])
# Embed block.
sub_dataset = Subset(text_dataset, range(*block["range"]))
embeddings = embedder.embed_text_dataset(sub_dataset, "train")
# Check equality.
log_retro_rank_0(" > validate.")
assert np.array_equal(existing_embeddings, embeddings)
# Synchronize progress across all ranks. (for easier observation)
log_retro_rank_0(" > waiting for other ranks to finish block.")
torch.distributed.barrier()
log_retro_rank_0(" > finished validating training embeddings.")
##################################################
# Validate filled index.
##################################################
def validate_added_encodings(config: RetroPreprocessingConfig) -> None:
"""Validate added encodings.
Steps:
- Randomly sample subset of text dataset blocks.
- Encode each block.
- Compare against saved encodings.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
"""
# Index.
index = IndexFactory.get_index(config.retro_index_type)
inner_index = index.get_empty_index(config)
# Text dataset.
text_dataset = get_text_dataset_for_adding(config)
# Sample existing blocks.
def validate(f: h5py.File) -> None:
"""Validation method for validating encoding blocks.
Args:
f (h5py.File): File with block of encodings.
"""
assert len(f["data"].shape) == 2
blocks = get_blocks_by_rank(
dirname=get_added_codes_dir(config),
n_samples=len(text_dataset),
block_size=config.retro_block_size,
validate=validate,
sample=config.retro_task_validate,
)
assert blocks.n_missing_world == 0
# Encode and validate blocks.
embedder = config.retro_bert_embedders.mem
for block_idx, block in enumerate(blocks.existing):
if block is not None:
# Progress.
log_retro_rank_0(
"encode block %d / %d ... %s." % (block_idx, len(blocks.existing), block["path"],)
)
# Load existing codes.
with h5py.File(block["path"]) as f:
existing_codes = np.copy(f["data"])
# Encode block.
embeddings, codes = index.encode_block(inner_index, embedder, text_dataset, block)
# Check equality.
log_retro_rank_0(" > validate.")
assert np.array_equal(existing_codes, codes)
# Synchronize progress across all ranks. (for easier observation)
log_retro_rank_0(" > waiting for other ranks to finish block.")
torch.distributed.barrier()
log_retro_rank_0(" > finished validating added encodings.")
##################################################
# Validate index (trained + filled).
##################################################
def validate_index(config: RetroPreprocessingConfig) -> None:
"""Validate index.
Validating index involves sequentially running stages above:
- Validate trained index.
- Validate filled index.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
"""
# Validate training embeddings.
validate_training_embeddings(config)
# Validate added codes.
validate_added_encodings(config)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""
A GPTChunkDataset is a wrapper around a regular GPTDataset that sequentially
chunks the sample tokens into `retro_chunk_length` sized smaller samples.
For example, if the GPTDataset has 100 samples and a sequence length of 2048, and
retro_chunk_length is 64, then the GPTChunkDataset will contain 100*(2048/64) =
3200 samples, each with length 64.
"""
import torch
from megatron.core.datasets.gpt_dataset import GPTDataset
from megatron.core.datasets.retro.utils import get_num_chunks_per_sample
from .utils import get_neighbor_dir
class GPTChunkDataset(torch.utils.data.Dataset):
"""Pretraining chunk dataset wraps a standard GPT dataset.
This dataset conceptually divides each sample (e.g., length 2048)
into chunks (e.g., length 64) and restructures them into a list of
chunks (e.g., length num_samples * num_chunks_per_sample).
Args:
sample_dataset (GPTDataset): Original GPT dataset, with `sequence_length` size samples.
sample_length (int): Alias for `sequence_length`.
chunk_length (int): Retro chunk length (e.g., 64).
"""
def __init__(self, sample_dataset: GPTDataset, sample_length: int, chunk_length: int):
super().__init__()
self.sample_dataset = sample_dataset
self.chunk_length = chunk_length
self.n_chunks_per_sample = get_num_chunks_per_sample(sample_length, chunk_length)
self.n_samples = len(sample_dataset)
self.n_chunks = self.n_samples * self.n_chunks_per_sample
def __len__(self) -> int:
"""Get dataset length.
Returns:
Dataset length.
"""
return self.n_chunks
def __getitem__(self, idx: int) -> dict:
"""Get sample, including represented document IDs.
Args:
idx (int): Sample index.
Returns:
            A sample, which contains both the chunk-length token sample ('text') along with all document IDs ('doc_ids') contained within the full `sequence_length` sample.
"""
# Convert global chunk index to global sample index & local chunk index.
sample_idx = idx // self.n_chunks_per_sample
chunk_idx = idx % self.n_chunks_per_sample
# Extract sample data.
sample = self.sample_dataset[sample_idx]
sample_token_ids = sample["text"]
sample_doc_ids = sample["document_ids"]
# Chunk start/end token idxs.
token_start_idx = chunk_idx * self.chunk_length
token_end_idx = token_start_idx + self.chunk_length
chunk_token_ids = sample_token_ids[token_start_idx:token_end_idx]
# Sample.
return {
"doc_ids": sample_doc_ids,
"text": chunk_token_ids,
}
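# Illustrative worked example of the index arithmetic in __getitem__ above:
# with 2048-token samples and 64-token chunks there are 32 chunks per sample,
# so global chunk index 70 maps to sample 2, local chunk 6, which covers tokens
# [384, 448) of that sample.
def _example_chunk_index_arithmetic() -> None:
    """Map a global chunk index back to (sample_idx, chunk_idx, token_start)."""
    n_chunks_per_sample = 2048 // 64  # 32
    idx = 70
    sample_idx = idx // n_chunks_per_sample  # 2
    chunk_idx = idx % n_chunks_per_sample  # 6
    token_start = chunk_idx * 64  # 384
    assert (sample_idx, chunk_idx, token_start) == (2, 6, 384)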
def build_gpt_chunk_datasets_from_gpt_datasets(
project_dir: str, gpt_datasets: dict, sample_length: int, chunk_length: int,
) -> dict:
"""Get train, valid, test GPT chunk datasets.
Args:
project_dir (str): Retro project dir.
gpt_datasets (dict): Mapping of 'train', 'valid', and 'test' GPT datasets (original, unchunked datasets).
sample_length (int): Alias of `sequence_length`.
chunk_length (int): Retro chunk length (e.g., 64).
Returns:
        A mapping of each split key ('train', 'valid', 'test') to a dict containing its GPT chunk dataset ('dataset'), neighbor directory ('neighbor_dir'), and number of active chunks ('num_active_chunks'); unused splits map to None.
"""
# GPT chunk datasets.
chunk_datasets = {
key: {
"dataset": GPTChunkDataset(sample_ds, sample_length, chunk_length),
"neighbor_dir": get_neighbor_dir(project_dir, key, sample_ds),
"num_active_chunks": num_active_samples
* get_num_chunks_per_sample(sample_length, chunk_length),
}
if sample_ds
else None
for key, (sample_ds, num_active_samples) in gpt_datasets.items()
}
return chunk_datasets
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""A MultiSplitGPTDataset can handle multiple intersecting split strings, as well
as returning all of the document IDs of a sample."""
import logging
from dataclasses import dataclass
from typing import Dict, List
import numpy
from megatron.core.datasets.blended_megatron_dataset_config import (
convert_split_vector_to_split_matrix,
parse_and_normalize_split,
)
from megatron.core.datasets.gpt_dataset import GPTDataset, GPTDatasetConfig
from megatron.core.datasets.indexed_dataset import IndexedDataset
from megatron.core.datasets.utils import Split
from megatron.core.utils import log_single_rank
logger = logging.getLogger(__name__)
@dataclass
class MultiSplitGPTDatasetConfig(GPTDatasetConfig):
"""Configuration object for Megatron Core blended and Retro datasets.
Args:
return_document_ids (bool): Whether to return the document ids when querying the dataset. Turn this option on during preprocessing.
split_preprocessing (str): The Retro preprocessing split string. It follows the same pattern convention as 'split'. Not to be used with 'blend_per_split'.
"""
return_document_ids: bool = None
split_preprocessing: str = None
def __post_init__(self) -> None:
"""Validate config attributes."""
super().__post_init__()
assert self.split is not None, "the Retro data pipeline does not support 'blend_per_split'"
assert self.return_document_ids is not None, "this attribute must be user defined"
assert self.split_preprocessing is not None, "this attribute must be user defined"
split_vector = parse_and_normalize_split(self.split)
split_preprocessing_vector = parse_and_normalize_split(self.split_preprocessing)
if not numpy.allclose(split_vector, split_preprocessing_vector):
self.split_matrix = convert_split_vector_to_split_matrix(
split_vector, split_preprocessing_vector
)
log_single_rank(
logger,
logging.WARNING,
f"split =/= split_preprocessing. Let split_matrix = {self.split_matrix}",
)
class MultiSplitGPTDataset(GPTDataset):
"""Retro's customized GPT dataset.
Args:
indexed_dataset (IndexedDataset): The IndexedDataset around which to build the MegatronDataset.
dataset_path (str): The real path on disk to the dataset, for bookkeeping.
indexed_indices (numpy.ndarray): The set of the documents indices to expose.
num_samples (int): The number of samples to draw from the indexed dataset.
index_split (Split): The indexed_indices Split.
config (MultiSplitGPTDatasetConfig): The Retro-specific container for all config sourced parameters.
"""
def __init__(
self,
indexed_dataset: IndexedDataset,
dataset_path: str,
indexed_indices: numpy.ndarray,
num_samples: int,
index_split: Split,
config: MultiSplitGPTDatasetConfig,
) -> None:
super().__init__(
indexed_dataset, dataset_path, indexed_indices, num_samples, index_split, config
)
def __getitem__(self, idx: int) -> Dict[str, numpy.ndarray]:
"""Get dataset sample.
Args:
idx (int): The index into the dataset.
Returns:
Dict[str, numpy.ndarray]: The text ids and (optionally) the document ids wrapped in a dictionary.
"""
text, document_ids = self._query_document_sample_shuffle_indices(idx)
if self.config.return_document_ids:
return {"text": text, "document_ids": document_ids}
else:
return {"text": text}
@staticmethod
def _key_config_attributes() -> List[str]:
"""Add custom attributes for building unique dataset hash.
        The split used during preprocessing constrains the samples available for pretraining.
Returns:
List[str]: The key config attributes.
"""
return super(MultiSplitGPTDataset, MultiSplitGPTDataset)._key_config_attributes() + [
"split_preprocessing"
]
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Entry point for querying an index using a GPTChunkDataset.
Querying involves:
- Iterate all chunks in the GPTChunkDataset.
- Query index for neighbor chunk IDs (i.e., chunks from the chunk database).
- Save neighbor chunk IDs to disk, for use in building a RetroDataset sample
during pretraining.
"""
import os
import time
import typing
import numpy as np
import psutil
import torch
from tqdm import tqdm
from megatron.core.datasets.retro.config import RetroPreprocessingConfig
from megatron.core.datasets.retro.db.dataset import DBDataset
from megatron.core.datasets.retro.db.utils import (
get_merged_train_dataset as get_db_merged_train_dataset,
)
from megatron.core.datasets.retro.external_libs import faiss, h5py
from megatron.core.datasets.retro.index.factory import IndexFactory
from megatron.core.datasets.retro.index.index import Index
from megatron.core.datasets.retro.index.utils import get_index_dir
from megatron.core.datasets.retro.query.gpt_chunk_dataset import GPTChunkDataset
from megatron.core.datasets.retro.utils import (
GPTToTextDataset,
get_blocks_by_rank,
log_retro_rank_0,
retro_makedir,
)
from .gpt_chunk_dataset import build_gpt_chunk_datasets_from_gpt_datasets
def get_index(config: RetroPreprocessingConfig, ondisk: bool = False,) -> faiss.Index:
"""Read index from disk.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
ondisk (bool): If `ondisk = True`, memory map the index. (For debugging purposes only; very non-performant.)
Returns:
A Faiss index, loaded from storage.
"""
# Load index.
index_wrapper = IndexFactory.get_index(config.retro_index_type)
index_dir = get_index_dir(config)
added_index_path = index_wrapper.get_added_index_path(config)
if ondisk:
index = faiss.read_index(added_index_path, faiss.IO_FLAG_MMAP)
else:
index = faiss.read_index(added_index_path)
# Search parameters.
faiss.ParameterSpace().set_index_parameter(index, "efSearch", config.retro_query_ef_search)
faiss.ParameterSpace().set_index_parameter(index, "nprobe", config.retro_query_nprobe)
return index
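# Illustrative sketch of how the loaded index is queried once the search
# parameters above are set: index.search() returns the distances and ids of the
# k nearest stored vectors for each query row. The random queries are a
# placeholder for the Bert embeddings used by the real pipeline.
def _example_search(index: faiss.Index, num_neighbors: int = 8) -> np.ndarray:
    """Query an index with random vectors and return the neighbor ids."""
    queries = np.random.rand(16, index.d).astype("f4")
    _, neighbor_ids = index.search(queries, num_neighbors)  # shape: (16, num_neighbors)
    return neighbor_ids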
def embed_block(
config: RetroPreprocessingConfig, gpt_dataset: GPTChunkDataset, block: dict,
) -> np.ndarray:
"""Embed block of chunks.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
gpt_dataset (GPTChunkDataset): Chunk dataset to be embedded.
block (dict): Range information containing start/end indices of subset of chunk dataset.
Returns:
Embeddings array, with shape (len(block["range"]), dimension(embedder)).
"""
text_block_dataset = torch.utils.data.Subset(
GPTToTextDataset(gpt_dataset, config.retro_tokenizers.gpt), range(*block["range"]),
)
return config.retro_bert_embedders.mem.embed_text_dataset(text_block_dataset)
def query_embeddings(
config: RetroPreprocessingConfig,
db_dataset: DBDataset,
index: Index,
embeddings: np.ndarray,
chunk_id_range: range,
sample_map: dict,
n_chunks_per_sample: int,
verbose: bool = True,
) -> typing.Tuple[np.ndarray, np.ndarray]:
"""Query neighbors of a block of embeddings.
Querying includes:
- Query index for neighbor chunk IDs.
- Filter chunk IDs that have the same document ID as the queried embedding.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
db_dataset (DBDataset): Dataset containing chunk database entries.
index (Index): Vector index populated with chunk database indices.
embeddings (np.ndarray): Embeddings from GPT chunk dataset.
chunk_id_range (range): Chunk ID range from GPT chunk dataset.
sample_map (dict): Mapping of sample_idx to dataset_idx and document_ids. Used for document filtering.
n_chunks_per_sample (int): Number of chunks per sample (e.g., sequence_length / chunk_length).
verbose (bool): Log querying progress.
Returns:
A tuple of original (unfiltered) neighbor IDs, and filtered (by document ID) neighbor IDs.
"""
# Query neighbor ids.
if verbose:
log_retro_rank_0("search.")
t = time.time()
assert index.ntotal > 0, "check we don't accidentally have an empty index."
_, query_neighbor_ids = index.search(embeddings, config.retro_query_num_neighbors_query)
if verbose:
log_retro_rank_0(" time : %.3f sec." % (time.time() - t))
# Filter banned neighbor ids.
if verbose:
log_retro_rank_0("filter banned neighbor ids.")
filtered_neighbor_ids = np.full(
shape=(len(query_neighbor_ids), config.retro_query_num_neighbors_save),
fill_value=-1,
dtype="int64",
)
min_chunk_id, max_chunk_id = chunk_id_range
for chunk_id in range(min_chunk_id, max_chunk_id):
sample_id = chunk_id // n_chunks_per_sample
sample = sample_map[sample_id]
sample_dataset_idx = sample["dataset_idx"].item()
sample_doc_ids = sample["doc_ids"].tolist()
sample_doc_tuples = [(sample_dataset_idx, d) for d in sample_doc_ids]
# Get valid neighbors (!= -1).
query_row = [i for i in query_neighbor_ids[chunk_id - min_chunk_id] if i >= 0]
# Filter row.
filtered_row = [
i
for i in query_row
if tuple(db_dataset.doc_tuples[i].tolist()) not in sample_doc_tuples
]
filtered_row = filtered_row[: config.retro_query_num_neighbors_save]
filtered_row += [-1] * (config.retro_query_num_neighbors_save - len(filtered_row))
filtered_neighbor_ids[chunk_id - min_chunk_id] = filtered_row
return query_neighbor_ids, filtered_neighbor_ids
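# Illustrative sketch of the filter-and-pad step above, stripped of the Faiss
# and dataset machinery: neighbors whose document also appears in the query
# sample are dropped, the row is truncated to the save width, and padded with
# -1. The ids and banned set below are made-up values.
def _example_filter_and_pad() -> None:
    """Filter banned neighbor ids and pad the row to a fixed width with -1."""
    num_neighbors_save = 4
    query_row = [12, 7, 33, 90, 5]  # candidate neighbor chunk ids
    banned = {7, 90}  # chunks sharing a document with the query sample
    filtered_row = [i for i in query_row if i not in banned]
    filtered_row = filtered_row[:num_neighbors_save]
    filtered_row += [-1] * (num_neighbors_save - len(filtered_row))
    assert filtered_row == [12, 33, 5, -1]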
def query_embedding_block(
config: RetroPreprocessingConfig,
db_dataset: DBDataset,
index: Index,
embeddings: np.ndarray,
chunk_id_range: range,
sample_map: dict,
n_chunks_per_sample: int,
) -> typing.Tuple[np.ndarray, np.ndarray]:
"""Query a block of embeddings.
The block is broken into smaller sub-blocks, for easier tracking of progress.
Both the raw neighbor IDs and the filtered neighbor IDs (i.e., chunks with the
same document ID are removed) are collected.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
db_dataset (DBDataset): Dataset containing chunk database entries.
index (Index): Vector index populated with chunk database indices.
embeddings (np.ndarray): Embeddings from GPT chunk dataset.
chunk_id_range (range): Chunk ID range from GPT chunk dataset.
sample_map (dict): Mapping of sample_idx to dataset_idx and document_ids. Used for document filtering.
n_chunks_per_sample (int): Number of chunks per sample (e.g., sequence_length / chunk_length).
Returns:
A tuple of original (unfiltered) neighbor IDs, and filtered (by document ID) neighbor IDs.
"""
query_neighbor_ids = []
filtered_neighbor_ids = []
# Query in sub-blocks.
partial_block_size = 1000
for partial_start_idx in tqdm(
range(0, len(embeddings), partial_block_size),
" search",
miniters=(len(embeddings) // partial_block_size) // 10,
disable=torch.distributed.get_rank() != 0,
):
partial_end_idx = min(len(embeddings), partial_start_idx + partial_block_size)
partial_embeddings = embeddings[partial_start_idx:partial_end_idx]
partial_chunk_id_range = (
chunk_id_range[0] + partial_start_idx,
chunk_id_range[0] + partial_end_idx,
)
partial_query_neighbor_ids, partial_filtered_neighbor_ids = query_embeddings(
config,
db_dataset,
index,
partial_embeddings,
partial_chunk_id_range,
sample_map,
n_chunks_per_sample,
verbose=False,
)
query_neighbor_ids.append(partial_query_neighbor_ids)
filtered_neighbor_ids.append(partial_filtered_neighbor_ids)
# Concatenate.
query_neighbor_ids = np.concatenate(query_neighbor_ids, axis=0)
filtered_neighbor_ids = np.concatenate(filtered_neighbor_ids, axis=0)
return query_neighbor_ids, filtered_neighbor_ids
def query_block_neighbors(
config: RetroPreprocessingConfig,
db_dataset: DBDataset,
query_dataset: GPTChunkDataset,
index: Index,
block: dict,
) -> None:
"""Query neighbors of a dataset block (i.e., range).
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
db_dataset (DBDataset): Dataset containing chunk database entries.
query_dataset (GPTChunkDataset): GPT chunk dataset to be queried.
index (Index): Vector index populated with chunk database indices.
block (dict): Range information containing start/end indices for querying GPT chunk dataset.
"""
n_chunks_per_sample = query_dataset.n_chunks_per_sample
# Sample map.
sample_ids = sorted(
list(set(chunk_id // n_chunks_per_sample for chunk_id in range(*block["range"])))
)
sample_map = {}
for i in sample_ids:
sample = query_dataset.sample_dataset[i]
sample_map[i] = {
"dataset_idx": sample["dataset_id"],
"doc_ids": sample["document_ids"],
}
# Embed block.
embeddings = embed_block(config, query_dataset, block)
# Query embeddings.
_, filtered_neighbor_ids = query_embedding_block(
config, db_dataset, index, embeddings, block["range"], sample_map, n_chunks_per_sample,
)
if config.retro_task_validate is None:
# Save neighbors.
log_retro_rank_0("save neighbors.")
retro_makedir(config, os.path.dirname(block["path"]))
f = h5py.File(block["path"], "w")
f.create_dataset("neighbors", data=filtered_neighbor_ids)
f.close()
else:
# Validate neighbors.
with h5py.File(block["path"]) as f:
existing_neighbor_ids = np.copy(f["neighbors"])
assert np.array_equal(existing_neighbor_ids, filtered_neighbor_ids)
def query_dataset_neighbors(
config: RetroPreprocessingConfig,
db_dataset: DBDataset,
query_dataset: GPTChunkDataset,
num_active_chunks: int,
prefix: str,
neighbor_dir: str,
index: Index,
) -> None:
"""Query neighbors of each chunk within a dataset.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
db_dataset (DBDataset): Dataset containing chunk database entries.
query_dataset (GPTChunkDataset): GPT chunk dataset to be queried.
        num_active_chunks (int): The 'active' chunks are the subset of the GPT chunk dataset that is being queried. This argument is also used when validating the correctness of a subset of the GPT chunk dataset.
prefix (str): Extra string for logging progress.
neighbor_dir (str): File path to directory for saving neighbor IDs.
index (Index): Vector index populated with chunk database indices.
"""
def validate(f: h5py.File) -> None:
"""Validation method for validating saved neighbor IDs.
Args:
            f (h5py.File): File containing saved neighbor IDs.
"""
assert f["neighbors"].shape[1] == config.retro_query_num_neighbors_save, (
"neighbors.shape == %s; num_neighbors_target == %d."
% (str(f["neighbors"].shape), config.retro_num_neighbors_target,)
)
if config.retro_task_validate is None:
retro_makedir(config, neighbor_dir)
blocks = get_blocks_by_rank(
neighbor_dir, num_active_chunks, config.retro_block_size, validate=validate,
)
active_blocks = blocks.missing
else:
blocks = get_blocks_by_rank(
neighbor_dir,
num_active_chunks,
config.retro_block_size,
validate=validate,
sample=config.retro_task_validate,
)
assert blocks.n_missing_world == 0
active_blocks = blocks.existing
# Query each block.
for block_index, block in enumerate(active_blocks):
if block is not None:
# Progress.
log_retro_rank_0(
"%squery '%s' block %d / %d ... %s ... mem %.3f gb, %.1f%%."
% (
"" if config.retro_task_validate is None else "[validate] ",
prefix,
block_index,
len(active_blocks),
os.path.basename(block["path"]),
psutil.virtual_memory()[3] / 1024 ** 3,
psutil.virtual_memory()[2],
)
)
# Query block neighbors.
query_block_neighbors(config, db_dataset, query_dataset, index, block)
# Synchronize progress across all ranks. (for easier observation)
log_retro_rank_0(" > waiting for other ranks to finish block.")
torch.distributed.barrier()
def query_neighbors(config: RetroPreprocessingConfig) -> None:
"""Query pretraining datasets (train & valid).
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
"""
# Num threads.
faiss.omp_set_num_threads(64)
# Load chunk db dataset.
log_retro_rank_0("load chunk db dataset.")
db_dataset = get_db_merged_train_dataset(
project_dir=config.retro_project_dir,
chunk_length=config.retro_gpt_chunk_length,
eod_token_id=config.retro_tokenizers.gpt.eod,
)
db_dataset.load_doc_tuples()
# Load index.
log_retro_rank_0(" > get index.")
index = get_index(config)
# Query each (i.e., train, valid, test) dataset.
log_retro_rank_0(" > query.")
for prefix, info in vars(config.retro_gpt_chunk_datasets).items():
if info is None:
continue
log_retro_rank_0(
" > query '%s' dataset ... %d samples." % (prefix, info["num_active_chunks"])
)
query_dataset_neighbors(
config,
db_dataset,
info["dataset"],
info["num_active_chunks"],
prefix,
info["neighbor_dir"],
index,
)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""
A RetroDataset wraps both:
- A GPTDataset (which is nested as GPTChunkDataset -> MultiSplitGPTDataset ->
GPTDataset).
- Neighbor IDs of chunks in the chunk database, that were saved during
preprocessing.
Both the GPT sample data and the neighbor IDs are returned within a sample from
this dataset.
"""
import os
from typing import Any, Dict, Optional, Tuple
import numpy as np
import torch
from megatron.core.datasets.retro.db.dataset import DBDataset
from megatron.core.datasets.retro.db.utils import get_merged_train_dataset as get_db_dataset
from megatron.core.datasets.retro.external_libs import h5py
from megatron.core.datasets.retro.utils import BlockPathMap, log_retro_rank_0
from megatron.core.models.retro import RetroConfig
from .gpt_chunk_dataset import GPTChunkDataset, build_gpt_chunk_datasets_from_gpt_datasets
from .utils import get_query_dir
class RetroDataset(torch.utils.data.Dataset):
"""Dataset of retro samples.
Each sample contains the original GPT sample, along with the token IDs
of each neighbor of each chunk within the sequence. Neighbor array has
shape (num_chunks_per_sample, num_neighbors, num_retrieved_tokens).
** Note: chunk dataset wraps original GPT dataset (see gpt_chunk_dataset.py).
Args:
num_queried_samples (int): Total number of queried samples.
num_neighbors (int): Total number of saved neighbors.
num_retrieved_chunks (int): Number of retrieved chunks (e.g., 2 for neighbor + continuation).
block_size (int): Number of neighbor entries per file.
db_dataset (DBDataset): Chunk database used for retrieval.
chunk_dataset (GPTChunkDataset): GPT chunk dataset, which is a wrapper around a standard GPT dataset that breaks each sample into chunks.
neighbor_path_map (BlockPathMap): Mapping of neighbor ID to file path.
"""
def __init__(
self,
num_queried_samples: int,
num_neighbors: int,
num_retrieved_chunks: int,
block_size: int,
db_dataset: DBDataset,
chunk_dataset: GPTChunkDataset,
neighbor_path_map: BlockPathMap,
):
super().__init__()
self.num_queried_samples = num_queried_samples
self.num_neighbors = num_neighbors
self.num_retrieved_chunks = num_retrieved_chunks
self.block_size = block_size
self.db_dataset = db_dataset
self.chunk_dataset = chunk_dataset
self.neighbor_path_map = neighbor_path_map
def __len__(self) -> int:
"""Dataset length.
Returns:
Number of samples in dataset.
"""
return len(self.chunk_dataset.sample_dataset)
def __getitem__(self, sample_idx: int) -> dict:
"""Get dataset sample.
Args:
sample_idx (int): Index of sample in dataset.
Returns:
            A dict containing the GPT sample ('text'), the corresponding neighbor chunk IDs ('neighbor_chunks', for indexing the chunk database), and the neighbor token IDs ('neighbor_tokens', the corresponding chunk database GPT tokens).
"""
n_chunks_per_sample = self.chunk_dataset.n_chunks_per_sample
# Wrap sample idx around number of queried samples.
sample_idx = sample_idx % self.num_queried_samples
# Get standard sample.
sample = self.chunk_dataset.sample_dataset[sample_idx]
# Sample idx to chunk idxs.
chunk_idxs = list(
range(sample_idx * n_chunks_per_sample, (sample_idx + 1) * n_chunks_per_sample,)
)
# Collect retrieved tokens.
all_retrieved_chunk_ids = []
all_retrieved_token_ids = []
for chunk_idx in chunk_idxs:
# Neighbor chunk ids.
neighbor_path = self.neighbor_path_map[chunk_idx]
with h5py.File(neighbor_path, "r") as f:
neighbor_chunk_ids = f["neighbors"][
chunk_idx % self.block_size, : self.num_neighbors
].tolist()
# Retrieved (neighbor + continuation) token ids.
retrieved_chunk_ids = []
retrieved_token_ids = []
for neighbor_chunk_id in neighbor_chunk_ids:
current_chunk_ids = [
i % len(self.db_dataset)
for i in range(neighbor_chunk_id, neighbor_chunk_id + self.num_retrieved_chunks)
]
current_token_ids = [self.db_dataset[ci]["text"] for ci in current_chunk_ids]
retrieved_chunk_ids.append(current_chunk_ids)
retrieved_token_ids.append(current_token_ids)
# Collect retrieved tokens.
all_retrieved_chunk_ids.append(retrieved_chunk_ids)
all_retrieved_token_ids.append(retrieved_token_ids)
# Reshape retrieved tokens.
all_retrieved_chunk_ids = np.array(all_retrieved_chunk_ids).reshape(
(n_chunks_per_sample, self.num_neighbors, -1)
)
all_retrieved_token_ids = np.array(all_retrieved_token_ids).reshape(
(n_chunks_per_sample, self.num_neighbors, -1)
)
# Sample.
sample: Dict[str, np.ndarray] = {
**sample,
"neighbor_chunks": all_retrieved_chunk_ids,
"neighbor_tokens": all_retrieved_token_ids,
}
return sample
def get_retro_datasets(
config: RetroConfig, gpt_datasets: dict, sample_length: int, eod_token_id: int,
) -> Tuple[Optional[RetroDataset], Optional[RetroDataset], Optional[RetroDataset]]:
"""Get train, valid, test retro datasets.
Args:
config (RetroConfig): Retro preprocessing config.
gpt_datasets (dict): Mapping of data split key ('train', 'valid', or 'test') to the original sequence-length GPT dataset (i.e., not the chunk dataset).
        sample_length (int): Alias of `sequence_length`.
eod_token_id (int): GPT EOD token ID.
Returns:
A tuple of 'train', 'valid', and 'test' `RetroDataset`s.
"""
# DB dataset.
db_dataset = get_db_dataset(
project_dir=config.retro_project_dir,
chunk_length=config.retro_chunk_length,
eod_token_id=eod_token_id,
)
# GPT chunk datasets.
chunk_ds_info_map = build_gpt_chunk_datasets_from_gpt_datasets(
project_dir=config.retro_project_dir,
gpt_datasets=gpt_datasets,
sample_length=sample_length,
chunk_length=config.retro_chunk_length,
)
# Retro datasets.
retro_dataset_map: Dict[str, Optional[RetroDataset]] = {}
query_dir = get_query_dir(config.retro_project_dir)
for data_key, chunk_ds_info in chunk_ds_info_map.items():
# Skip unused datasets.
if chunk_ds_info is None:
retro_dataset_map[data_key] = None
continue
# For consistency with preprocessing, the neighbor_dir is overwritten
# (from its setting in `build_gpt_chunk_datasets_from_gpt_datasets()`
# above). This is one piece -- along with setting data_path and
# train_samples from config.json -- of ensuring consistency between
# preprocessing and pretraining.
chunk_dataset = chunk_ds_info["dataset"]
chunk_ds_info["neighbor_dir"] = os.path.join(
query_dir, config.retro_neighbor_dirs[data_key],
)
        neighbor_dir = chunk_ds_info["neighbor_dir"]
        # Verify that the neighbor directory exists before building the block
        # path map; BlockPathMap.from_dir() asserts on the directory, so this
        # check must come first for the friendlier error below to be reachable.
        if not os.path.isdir(neighbor_dir):
            if torch.distributed.get_rank() == 0:
                raise Exception(
                    "neighbor directory '%s' not found; please "
                    "compare --train-samples, --seq-length, --seed, "
                    "--eval-iters, and --eval-interval, with "
                    "retro preprocessing args." % neighbor_dir
                )
            torch.distributed.barrier()
            exit()
        neighbor_path_map = BlockPathMap.from_dir(
            dir=neighbor_dir, block_size=config.retro_block_size
        )
        # Verify num chunks.
        n_active_chunks = chunk_ds_info["num_active_chunks"]
        n_neighbor_chunks = neighbor_path_map.max_idx
if config.retro_verify_neighbor_count and n_active_chunks != n_neighbor_chunks:
if torch.distributed.get_rank() == 0:
log_retro_rank_0("neighbor_dir : %s" % neighbor_dir)
log_retro_rank_0("neighbor_path_map : %s" % neighbor_path_map)
raise Exception(
"num sampled chunks (%d) != num neighbor chunks "
"(%d); did you complete querying the entire "
"pretraining dataset?" % (n_active_chunks, n_neighbor_chunks)
)
torch.distributed.barrier()
exit()
# Retro dataset.
retro_dataset_map[data_key] = RetroDataset(
num_queried_samples=gpt_datasets[data_key][1],
num_neighbors=config.retro_num_neighbors,
num_retrieved_chunks=config.retro_num_retrieved_chunks,
block_size=config.retro_block_size,
db_dataset=db_dataset,
chunk_dataset=chunk_dataset,
neighbor_path_map=neighbor_path_map,
)
return (
retro_dataset_map["train"],
retro_dataset_map["valid"],
retro_dataset_map["test"],
)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Utilities for querying the pretraining dataset."""
import os
from megatron.core.datasets.megatron_dataset import MegatronDataset
def get_query_dir(project_dir: str) -> str:
"""Get root directory of all saved query data.
Args:
project_dir (str): Retro project dir.
Returns:
Path to query sub-directory in Retro project.
"""
return os.path.join(project_dir, "query")
def get_neighbor_dir(project_dir: str, key: str, dataset: MegatronDataset) -> str:
"""Get directory containing neighbor IDs for a dataset (i.e., train, valid, or test).
Args:
project_dir (str): Retro project dir.
key (str): Dataset split key; 'train', 'valid', or 'test'.
dataset (MegatronDataset): Dataset containing unique hash for finding corresponding neighbors.
Returns:
Path to directory containing this dataset's neighbors within Retro project.
"""
return os.path.join(
get_query_dir(project_dir), os.path.basename(f"{key}_{dataset.unique_description_hash}"),
)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Utilities for Retro preprocessing."""
import glob
import logging
import os
from collections import defaultdict
from types import SimpleNamespace
from typing import Any, Callable, Dict, List, Optional
import numpy as np
import torch
from tqdm import tqdm
from megatron.core import parallel_state
from megatron.core.datasets.retro.config import RetroPreprocessingConfig
from megatron.core.datasets.retro.query.multi_split_gpt_dataset import (
MultiSplitGPTDataset,
MultiSplitGPTDatasetConfig,
)
from megatron.core.utils import log_single_rank
from .external_libs import h5py
logger = logging.getLogger(__name__)
def log_retro_rank_0(message: str) -> None:
"""Log on rank 0.
Args:
message (str): Message to log.
"""
log_single_rank(logger, logging.INFO, "[RETRO] " + message)
def retro_makedir(config: RetroPreprocessingConfig, path: str) -> None:
"""Make a directory, conditional on not being in validation mode.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
path (str): Path to directory.
"""
if config.retro_task_validate is None:
os.makedirs(path, exist_ok=True)
def extract_data_config(config: RetroPreprocessingConfig) -> MultiSplitGPTDatasetConfig:
"""Extract data config from dataset.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
Returns:
The config object used to build the dataset.
"""
return config.retro_gpt_chunk_datasets.train["dataset"].sample_dataset.config
def get_num_chunks_per_sample(sample_length: int, chunk_length: int) -> int:
"""Compute seq_length // chunk_length.
Args:
sample_length (int): Alias of `sequence_length`.
chunk_length (int): Retro chunk length (e.g., 64).
Returns:
Number of chunks per sample (i.e., `sequence_length` / `chunk_length`).
"""
assert sample_length % chunk_length == 0
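    # Illustrative example (not from the original source):
    # get_num_chunks_per_sample(2048, 64) == 32, i.e., a 2048-token sample
    # divides evenly into 32 chunks of 64 tokens.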
return sample_length // chunk_length
class GPTToTextDataset(torch.utils.data.Dataset):
"""Dataset to convert GPT tokens to text.
Args:
gpt_dataset (MultiSplitGPTDataset): GPT dataset, which outputs GPT token samples.
gpt_tokenizer (Any): GPT tokenizer.
"""
def __init__(self, gpt_dataset: MultiSplitGPTDataset, gpt_tokenizer: Any):
super().__init__()
self.gpt_dataset = gpt_dataset
self.gpt_tokenizer = gpt_tokenizer
def __len__(self) -> int:
"""Dataset length.
Returns:
Number of samples in the dataset.
"""
return len(self.gpt_dataset)
def __getitem__(self, idx: int) -> dict:
"""Get dataset sample.
Args:
idx (int): Index of sample.
Returns:
A dict containing attribute 'text' of type string.
"""
gpt_token_ids = self.gpt_dataset[idx]["text"].tolist()
text = self.gpt_tokenizer.detokenize(gpt_token_ids)
return {"text": text}
def get_blocks(
dirname: str, n_samples: int, block_size: int, validate: Callable = None,
) -> SimpleNamespace:
"""Divide range [0, num_samples) to sequence of block ranges.
This is a core method within the concept of block processing. The idea
is to divide a range (size n_samples) into a sequence of blocks. Each
block corresponds to a file within 'dirname' with name
'{start_idx}-{end_idx}.hdf5'. This method checks for the existence of
these files, and returns two lists, one for existing blocks and one for
missing blocks.
Args:
dirname (str): Path to directory containing block files.
        n_samples (int): Ideal number of samples. The total number of samples across all saved blocks is <= n_samples.
block_size (int): Max number of samples per block file (e.g., 100000).
validate (Callable): Method for validating each block file during load.
Returns:
        A namespace consisting of 2 lists: existing blocks, and missing blocks. The total number of samples across the existing and missing blocks should equal n_samples above.
"""
    assert os.path.isdir(dirname), "missing directory '%s'." % dirname
# Block ranges.
block_start_idxs = list(range(0, n_samples, block_size))
block_end_idxs = [min(n_samples, i + block_size) for i in block_start_idxs]
block_ranges = list(zip(block_start_idxs, block_end_idxs))
# All block files (existing + missing).
n_digits = int(np.ceil(np.log(n_samples) / np.log(10)) + 1)
all_blocks = [
{
"range": r,
"path": os.path.join(
dirname, "%s-%s.hdf5" % tuple([str(i).zfill(n_digits) for i in r]),
),
}
for r in block_ranges
]
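    # Illustrative example (hypothetical values, not from the original source):
    # for n_samples=250 and block_size=100, block_ranges is
    # [(0, 100), (100, 200), (200, 250)] and, with n_digits = 4, the expected
    # filenames are '0000-0100.hdf5', '0100-0200.hdf5', and '0200-0250.hdf5'.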
all_block_path_set = set(block["path"] for block in all_blocks)
# Validate function.
validate = (lambda f: None) if validate is None else validate
# Delete corrupt files.
if torch.distributed.get_rank() == 0:
existing_block_paths = [
block["path"] for block in all_blocks if os.path.exists(block["path"])
]
        for path in tqdm(existing_block_paths, "validating block."):
assert path in all_block_path_set, "unexpected filename, '%s'." % path
            try:
                f = h5py.File(path, "r")
            except Exception:
                # The file cannot be opened; treat it as corrupt and delete it.
                os.remove(path)
                continue
            try:
                validate(f)
            except Exception:
                # Validation failed; treat the file as corrupt and delete it.
                os.remove(path)
            finally:
                f.close()
# Wait for files to be deleted.
torch.distributed.barrier()
# Collect blocks.
blocks = SimpleNamespace(
existing=[b for b in all_blocks if os.path.exists(b["path"])],
missing=[b for b in all_blocks if not os.path.exists(b["path"])],
)
return blocks
def get_blocks_by_rank(
dirname: str,
n_samples: int,
block_size: int,
validate: Callable = None,
sample: Optional[float] = None,
) -> SimpleNamespace:
"""Divide existing and missing blocks evenly across all ranks.
See 'get_blocks()' above for description. The returned lists of existing and
missing blocks are split evenly across ranks via interleaving. This way,
each rank has a roughly equal number of blocks to process for a
downstream operation.
Args:
dirname (str): Path to directory containing block files.
        n_samples (int): Ideal number of samples. The total number of samples across all saved blocks is <= n_samples.
block_size (int): Max number of samples per block file (e.g., 100000).
validate (Callable): Method for validating each block file during load.
sample (Optional[float]): If provided, sample a random subset of the blocks. Used for validating preprocessing correctness.
Returns:
        A namespace consisting of 2 lists: existing blocks, and missing blocks. Each of these two lists is potentially a sub-sample of the total set of existing and missing blocks, depending on whether sampling is used. Additionally, the attributes n_existing_world and n_missing_world are the total number of existing and missing blocks, independent of sampling. Therefore, (n_existing_world + n_missing_world) * block_size >= n_samples, with equality only when block_size evenly divides n_samples (the final block may contain fewer than block_size samples).
"""
# Get world blocks.
blocks = get_blocks(dirname, n_samples, block_size, validate)
# This rank's existing and missing files.
data_parallel_rank = parallel_state.get_data_parallel_rank()
data_parallel_world_size = parallel_state.get_data_parallel_world_size()
rank_existing_blocks = blocks.existing[
data_parallel_rank : len(blocks.existing) : data_parallel_world_size
]
rank_missing_blocks = blocks.missing[
data_parallel_rank : len(blocks.missing) : data_parallel_world_size
]
# Extend rank's existing and missing blocks (with None) such that all ranks
# have equal length lists. This allows for easier tracking of global progress.
def get_world_max(n: int) -> int:
"""Get max value across ranks.
Args:
n (int): Value on this rank.
Returns:
Max value across all ranks.
"""
n_tensor = torch.cuda.LongTensor([n])
torch.distributed.all_reduce(n_tensor, op=torch.distributed.ReduceOp.MAX)
return n_tensor.item()
max_n_existing = get_world_max(len(rank_existing_blocks))
max_n_missing = get_world_max(len(rank_missing_blocks))
rank_existing_blocks += [None] * (max_n_existing - len(rank_existing_blocks))
rank_missing_blocks += [None] * (max_n_missing - len(rank_missing_blocks))
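    # Illustrative example (hypothetical values, not from the original source):
    # with 5 existing blocks [b0, b1, b2, b3, b4] and a data-parallel world
    # size of 2, the interleaved slicing above gives rank 0 -> [b0, b2, b4] and
    # rank 1 -> [b1, b3]; rank 1 is then padded with one None so both ranks
    # iterate the same number of steps.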
# Collect blocks.
blocks = SimpleNamespace(
n_existing_world=len(blocks.existing),
n_missing_world=len(blocks.missing),
existing=rank_existing_blocks,
missing=rank_missing_blocks,
)
if sample is not None:
# Sample existing and missing blocks evenly across all ranks. The
# returned lists of blocks are randomly sampled (without replacement)
# to yield `sample * len(blocks)` number of blocks.
# Randomly sample blocks.
def sample_blocks(_blocks: List[Optional[Dict]]) -> List[Optional[Dict]]:
"""Sample a random subset of all blocks.
Args:
_blocks (List[Optional[Dict]]): List of all blocks.
Returns:
A random subset of the blocks.
"""
n_blocks_sample = int(np.ceil(sample * len(_blocks)))
sampled_blocks: List[Optional[Dict]] = [b for b in _blocks if b is not None]
np.random.seed(None)
np.random.shuffle(sampled_blocks)
sampled_blocks = sampled_blocks[:n_blocks_sample]
sampled_blocks += [None] * (n_blocks_sample - len(sampled_blocks))
return sampled_blocks
blocks.existing = sample_blocks(blocks.existing)
blocks.missing = sample_blocks(blocks.missing)
return blocks
class BlockPathMap:
"""Map an index to its containing block path.
The common use for this class is to have a directory of files containing
blocks of processed data, of uniform block size (e.g., 100k samples per
file). Each file must follow a naming convention of 'startIdx-endIdx.[ext]',
where 'endIdx' minus 'startIdx' must equal the block size, with the possible
exception of the final block. Given an input index, this class maps the
index to the containing block file.
Args:
block_paths (List[str]): List of paths to saved block files.
block_size (int): Max number of samples per block file (e.g., 100000).
"""
@classmethod
def from_dir(cls, dir: str, block_size: int, ext: str = "hdf5") -> Any:
"""Get list of block files, and create map.
Args:
dir (str): Path to directory containing saved block files.
block_size (int): Max number of samples per block file (e.g., 100000).
ext (str): Block file extension (e.g., 'hdf5').
Returns:
A mapping of sample index to block file path.
"""
assert os.path.isdir(dir), f"directory not found, '{dir}'."
return cls(sorted(glob.glob(dir + f"/*.{ext}")), block_size)
def __init__(self, block_paths: List[str], block_size: int):
self.max_idx = 0
self.block_path_map = {}
for block_path in block_paths:
name = os.path.splitext(os.path.basename(block_path))[0]
start_idx, end_idx = [int(i) for i in name.split("-")]
self.block_path_map[start_idx] = block_path
self.max_idx = max(self.max_idx, end_idx)
self.block_size = block_size
def __str__(self) -> str:
"""Stringify the mapping.
Returns:
A string representation of this block path map.
"""
return "%d paths" % len(self.block_path_map)
def __getitem__(self, idx: int) -> str:
"""Get block path from index.
Args:
idx (int): Index of sample.
Returns:
The path to the block file containing the sample index.
"""
block_start_idx = self.block_size * (idx // self.block_size)
block_path = self.block_path_map[block_start_idx]
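        # Illustrative example (hypothetical paths, not from the original
        # source): with block_size=100000 and files '0-100000.hdf5' and
        # '100000-200000.hdf5', idx=123456 rounds down to block_start_idx
        # 100000 and resolves to '100000-200000.hdf5'.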
return block_path
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from collections import deque
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union
import numpy
from megatron.core.datasets.indexed_dataset import IndexedDataset
from megatron.core.datasets.masked_dataset import (
MaskedWordPieceDataset,
MaskedWordPieceDatasetConfig,
)
from megatron.core.datasets.utils import Split
@dataclass
class T5MaskedWordPieceDatasetConfig(MaskedWordPieceDatasetConfig):
"""Configuration object for Megatron Core T5 WordPiece datasets
    NB: This is a temporary holdover from Megatron-LM. The T5 tokenizer has an attribute that defines
    a number of special sentinel tokens used during sampling. The assert in __post_init__ serves to
    preserve compatibility with Megatron-LM until the T5 tokenizer is in Megatron Core.
"""
sequence_length_encoder: Optional[int] = field(init=False, default=None)
"""A sequence_length alias and the sequence length for the encoder"""
sequence_length_decoder: int = None
"""The sequence length for the decoder"""
def __post_init__(self) -> None:
"""Do asserts and set fields post init
"""
super().__post_init__()
self.sequence_length_encoder = self.sequence_length
assert self.sequence_length_encoder is not None
assert self.sequence_length_decoder is not None
assert len(self.tokenizer.additional_special_tokens_ids) > 0
class T5MaskedWordPieceDataset(MaskedWordPieceDataset):
"""The T5 dataset that assumes WordPiece tokenization
Args:
indexed_dataset (IndexedDataset): The IndexedDataset around which to build the MegatronDataset
dataset_path (str): The real path on disk to the dataset, for bookkeeping
indexed_indices (numpy.ndarray): The set of the documents indices to expose
num_samples (Optional[int]): The number of samples to draw from the indexed dataset. When None, build as many samples as correspond to one epoch.
index_split (Split): The indexed_indices Split
config (T5MaskedWordPieceDatasetConfig): The config
"""
def __init__(
self,
indexed_dataset: IndexedDataset,
dataset_path: str,
indexed_indices: numpy.ndarray,
num_samples: Optional[int],
index_split: Split,
config: T5MaskedWordPieceDatasetConfig,
) -> None:
super().__init__(
indexed_dataset, dataset_path, indexed_indices, num_samples, index_split, config
)
self.token_lookup = list(self.config.tokenizer.inv_vocab.keys())
# Account for the single <bos> and single <eos> token ids
self.sample_index = self._build_sample_index(self.config.sequence_length - 2, 1)
@staticmethod
def _key_config_attributes() -> List[str]:
"""Inherited method implementation
Returns:
List[str]: The key config attributes
"""
return super(
T5MaskedWordPieceDataset, T5MaskedWordPieceDataset
)._key_config_attributes() + ["sequence_length_decoder",]
def __getitem__(self, idx: int) -> Dict[str, Union[int, numpy.ndarray]]:
"""Abstract method implementation
Args:
idx (int): The index into the dataset
Returns:
            Dict[str, Union[int, numpy.ndarray]]: The sample, as a dict containing the encoder and decoder token ids, attention masks, labels, loss mask, and truncation flag
"""
idx_beg, idx_end, target_sequence_length = self.sample_index[idx]
sample = [self.dataset[i] for i in range(idx_beg, idx_end)]
numpy_random_state = numpy.random.RandomState(
seed=(self.config.random_seed + idx) % 2 ** 32
)
assert target_sequence_length <= self.config.sequence_length
# Flatten the sample into a list of tokens
tokens = [token for sentence in sample for token in sentence]
# Truncate the list of tokens to a desired length
truncated = len(tokens) > target_sequence_length
tokens = tokens[:target_sequence_length]
# Masking
(tokens, _, _, _, masked_spans,) = self._create_masked_lm_predictions(
tokens, target_sequence_length, numpy_random_state
)
# Prepare the encoder input and decoder input and output
sentinels = deque(self.config.tokenizer.additional_special_tokens_ids)
encoder_input = []
decoder_input = [self.config.tokenizer.bos]
decoder_output = []
idx_beg = 0
for indices, labels in masked_spans:
sentinel = sentinels.popleft()
# set the end index
idx_end = indices[0]
encoder_input.extend(tokens[idx_beg:idx_end])
encoder_input.append(sentinel)
decoder_input.append(sentinel)
decoder_input.extend(labels)
decoder_output.append(sentinel)
decoder_output.extend(labels)
# set the start index
idx_beg = indices[-1] + 1
encoder_input.extend(tokens[idx_beg:])
decoder_output.append(self.config.tokenizer.eos)
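        # Illustrative example (not from the original source): for tokens
        # [t0, t1, t2, t3, t4] with a single masked span at indices [2, 3]
        # (labels [t2, t3]) and sentinel <X>, the loop above yields
        #   encoder_input  = [t0, t1, <X>, t4]
        #   decoder_input  = [<bos>, <X>, t2, t3]
        #   decoder_output = [<X>, t2, t3, <eos>]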
# Pad the sequences and convert to NumPy
length_toks_encoder = len(encoder_input)
length_toks_decoder = len(decoder_input)
length_pads_encoder = self.config.sequence_length_encoder - length_toks_encoder
length_pads_decoder = self.config.sequence_length_decoder - length_toks_decoder
assert length_pads_encoder >= 0
assert length_pads_decoder >= 0
encoder_input = numpy.array(encoder_input, dtype=numpy.int64)
encoder_input = numpy.pad(
encoder_input, (0, length_pads_encoder), constant_values=self.config.tokenizer.pad
)
decoder_input = numpy.array(decoder_input, dtype=numpy.int64)
decoder_input = numpy.pad(
decoder_input, (0, length_pads_decoder), constant_values=self.config.tokenizer.pad
)
# Create attention and history masks
mask_encoder = self._make_attention_mask(encoder_input, encoder_input)
mask_encoder_decoder = self._make_attention_mask(decoder_input, encoder_input)
mask_decoder = self._make_attention_mask(decoder_input, decoder_input)
mask_decoder = mask_decoder * self._make_history_mask(decoder_input)
# Mask the labels
decoder_output = numpy.array(decoder_output, dtype=numpy.int64)
decoder_output = numpy.pad(decoder_output, (0, length_pads_decoder), constant_values=-1)
# Get the loss mask
loss_mask = numpy.zeros(self.config.sequence_length_decoder, dtype=numpy.int64)
loss_mask[:length_toks_decoder] = 1
return {
"text_enc": encoder_input,
"text_dec": decoder_input,
"labels": decoder_output,
"loss_mask": loss_mask,
"truncated": int(truncated),
"enc_mask": mask_encoder,
"dec_mask": mask_decoder,
"enc_dec_mask": mask_encoder_decoder,
}
@staticmethod
def _make_attention_mask(
source_block: numpy.ndarray, target_block: numpy.ndarray
) -> numpy.ndarray:
"""Return a 2-D attention mask
Args:
source_block (numpy.ndarray): A 1-D array
target_block (numpy.ndarray): A 1-D array
Returns:
numpy.ndarray: The 2-D attention mask
"""
mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
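        # Illustrative example (not from the original source): with
        # source_block = [5, 6, 0] and target_block = [7, 0] (0 denoting pad),
        # the broadcast comparison yields the (3, 2) mask
        # [[1, 0], [1, 0], [0, 0]]: entry (i, j) is 1 only when both source
        # token i and target token j are non-pad (id >= 1).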
return mask.astype(numpy.int64)
@staticmethod
def _make_history_mask(block: numpy.ndarray) -> numpy.ndarray:
"""Return a 2-D history (lower-left-triangular) mask
Args:
block (numpy.ndarray): A 1-D array
Returns:
numpy.ndarray: The 2-D history (lower-left-triangular) mask
"""
arange = numpy.arange(block.shape[0])
mask = arange[None,] <= arange[:, None]
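        # Illustrative example (not from the original source): for a block of
        # length 3, arange = [0, 1, 2] and the comparison produces the
        # lower-triangular mask [[1, 0, 0], [1, 1, 0], [1, 1, 1]], i.e.,
        # position i may attend only to positions j <= i.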
return mask.astype(numpy.int64)
def _get_token_mask(self, numpy_random_state: numpy.random.RandomState) -> int:
"""Abstract method implementation
100% of the time, replace the token id with mask token id.
Args:
numpy_random_state (RandomState): The NumPy random state
Returns:
int: The mask token id
"""
return self.config.tokenizer.mask
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import logging
from enum import Enum
from typing import List, Optional, Tuple
import numpy
import torch
from ..utils import log_single_rank
logger = logging.getLogger(__name__)
class Split(Enum):
train = 0
valid = 1
test = 2
def compile_helpers():
"""Compile C++ helper functions at runtime. Make sure this is invoked on a single process.
"""
import os
import subprocess
command = ["make", "-C", os.path.abspath(os.path.dirname(__file__))]
if subprocess.run(command).returncode != 0:
import sys
log_single_rank(logger, logging.ERROR, "Failed to compile the C++ dataset helper functions")
sys.exit(1)
def normalize(weights: List[float]) -> List[float]:
"""Do non-exponentiated normalization
Args:
weights (List[float]): The weights
Returns:
List[float]: The normalized weights
"""
w = numpy.array(weights, dtype=numpy.float64)
w_sum = numpy.sum(w)
w = (w / w_sum).tolist()
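    # Illustrative example (not from the original source): normalize([30.0, 70.0])
    # returns [0.3, 0.7]; weights are divided by their sum with no
    # exponentiation or temperature applied.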
return w
def get_blend_from_list(
blend: Optional[List[str]],
) -> Optional[Tuple[List[str], Optional[List[float]]]]:
"""Get the megatron.core.datasets.blended_megatron_dataset_config.BlendedMegatronDatasetConfig blend from the blend list
Args:
blend (Optional[List[str]]): The blend list, which can be either (1) a list of prefixes, e.g. ["path/to/dataset_1_prefix", "path/to/dataset_2_prefix"], or (2) a flattened, zipped list of weights and prefixes, e.g. ["30", "path/to/dataset_1_prefix", "70", "path/to/dataset_2_prefix"]
Returns:
Optional[Tuple[List[str], Optional[List[float]]]]: The blend, consisting of a list of dataset prefixes and optionally a list of dataset weights, e.g. [["path/to/dataset_1_prefix", "path/to/dataset_2_prefix"], [30.0, 70.0]].
"""
if blend is None:
return None
if len(blend) % 2 == 1:
weight_per_dataset = None
raw_prefix_per_dataset = blend
else:
raw_weight_per_dataset, raw_prefix_per_dataset = zip(
*[(blend[i], blend[i + 1]) for i in range(0, len(blend), 2)]
)
weight_per_dataset = []
for rwpd in raw_weight_per_dataset:
try:
weight = float(rwpd)
except ValueError:
weight = None
weight_per_dataset.append(weight)
        # Materialize as a list so that both any() and all() below evaluate every element.
        is_none = [weight is None for weight in weight_per_dataset]
if any(is_none):
assert all(is_none)
weight_per_dataset = None
raw_prefix_per_dataset = blend
prefix_per_dataset = [rppd.strip() for rppd in raw_prefix_per_dataset]
return prefix_per_dataset, weight_per_dataset
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
from .core import check_is_distributed_checkpoint
from .mapping import LocalNonpersitentObject, ShardedTensor
from .serialization import (
load,
load_common_state_dict,
load_plain_tensors,
load_tensors_metadata,
save,
)