Commit 0816dd4a authored by libo11's avatar libo11
Browse files

Initial commit

parents
Pipeline #1728 canceled with stages
# Data Pipeline
## Data pre-processing
Data preprocessing is built around the following classes:
1. `IndexedDatasetBuilder`
2. `IndexedDataset`
At the moment, an end-to-end data preprocessing implementation is left to the user. See the class docstring(s) for more details.
#### IndexedDatasetBuilder
The `IndexedDatasetBuilder` is capable of building and merging `IndexedDataset` instances.
#### IndexedDataset
The `IndexedDataset` class is the lowest-level data interface in Megatron Core. Internally, an `IndexedDataset` instance references two binaries: the data file (`.bin`) contains document/sequence data and the index file (`.idx`) contains document/sequence metadata.
The index file stores dataset-level metadata first:
- The index header, for backward compatibility
- The index version, for backward compatibility
- A numeric code corresponding to the data type used to write data to the data file
- The number of sequences in the dataset
- The number of documents in the dataset
The index file stores document-level and sequence-level metadata second:
- In order, the number of elements per sequence
- In order, the byte offset (pointer) per sequence
- In order, the consecutive sequence index range `[...)` per document
- In order, the mode per sequence (in the multimodal case)
## Data loading: construction
Building the data loaders is a distributed-aware process built around the following classes:
1. `BlendedMegatronDatasetConfig`
2. `BlendedMegatronDatasetBuilder`
3. `IndexedDataset`
3. `MegatronDataset`
4. `BlendedDataset`
See the class docstrings for more details.
#### BlendedMegatronDatasetConfig (extendable)
The `BlendedMegatronDatasetConfig` class parameterizes the `BlendedMegatronDatasetBuilder` and in turn the `MegatronDataset` and `BlendedDataset`.
Different training/inference regimes will require different extensions e.g. the `GPTDatasetConfig`
#### BlendedMegatronDatasetBuilder
The `BlendedMegatronDatasetBuilder` class builds the highest-level data interfaces in Megatron Core.
**NB:** All ranks should attempt to build the dataset via the `BlendedMegatronDatasetBuilder` or the program will hang. Which ranks follow through on their attempts can be controlled via the `BlendedMegatronDatasetConfig`.
#### IndexedDataset
The `IndexedDataset` class is the lowest-level data interface in Megatron Core.
The `IndexedDataset` should already exist on disk before attempting to build any of the high-level data interfaces.
#### MegatronDataset (extendable)
The `MegatronDataset` abstract class is a high-level data interface in Megatron Core. It is an abstraction built upon the `IndexedDataset`.
Different training/inference regimes will require different extensions e.g. the `GPTDataset`
#### BlendedDataset
The `BlendedDataset` class is a high-level data interface in Megatron Core. It is an abstraction built upon the `MegatronDataset`.
The `BlendedDataset` is only necessary when a blend multiple data distributions, i.e. multiple `MegatronDataset` instances, should contribute to a certain dataset split. The blend can be controlled via the `BlendedMegatronDatasetConfig`.
## Data loading: implementation
### GPTDataset
The `GPTDataset` is parameterized by the following variables: the underlying `IndexedDataset` instance `indexed_dataset`, the split indices `indexed_indices` (the congituous subset of document or sequence indices used for training, validation, and testing), the number of samples `N`, the sequence length `S`, and the random seed `R`.
The `GPTDataset` creates three index mappings to facilitate lookup: (1) the document index, (2) the sample index, and (3) the shuffle index.
1. The document index _Do_idx_ is a 1-D array mapping from _i_ to document index of length `E * |indexed_indices|` where `E` corresponds to the minimum number of epochs such that `E * |indexed_indices| >= N`. The document index is shuffled according to `R`.
```
Given:
N = 15
indexed_indices = [5, 6, 7, 8, 9]
E = 3
Then, for example:
Do_idx = [8, 8, 9, 6, 7, 5, 8, 5, 6, 6, 5, 9, 7, 7, 9]
```
2. The sample index _Sa_idx_ is a 2-D array mapping from _j_ to pairs of (_i_, _Do_idx_[ _i_ ] offset) of shape `[N + 1, 2]`. The rows _j_ and _j_ + 1 serve as the left and right bounds for the _j_-th sample.
```
Given:
S = 1024
Then, for example:
Sa_idx[0] = (0, 0)
Sa_idx[1] = (0, 1024) => Do_idx[0] has length greater than S
Sa_idx[2] = (1, 512) => Do_idx[0] has length 1536
Sa_idx[3] = (2, 0) => Do_idx[1] has length 1536
Sa_idx[4] = (5, 300) => Do_idx[2:5] are shorter documents relative to Do_idx[0:2]
Sa_idx[5] = (6, 24) => Do_idx[5] has length 1300
```
3. The shuffle index _Sh_idx_ is a 1-D array mapping from _k_ to _j_ of length `N`. The shuffle index is shuffled according to `R`.
```
Given
N = 10
Then, for example:
Sh_idx = [4, 0, 2, 6, 1, 9, 5, 8, 7, 3]
```
To query the `GPTDataset` for the _k_-th sample we do the following
- Use the shuffle index to get the index _j_ into the sample index.
```
j = Sh_idx[k]
```
- Use the sample index to get the left and right sample-bounding indices into the document index and the starting token offset for each document.
```
i, offset = Sa_idx[j]
i_next, offset_next = Sa_idx[j + 1]
```
- Use the document index to retrieve `S` tokens from consecutive (in the document index) documents.
```
sample = []
sample += indexed_dataset[Do_idx[i]][offset:]
if i != i_next:
sample += indexed_dataset[Do_idx[i + 1:i_next]]
sample += indexed_dataset[Do_idx[i_next]][:offset_next]
```
To save time during initialization, each index is built/cached sequentially on one process rank and subsequently loaded in parallel on other process ranks. The cached indices are unique to a hash generated in the `MegatronDataset.__init__` function.
### BlendedDataset
The `BlendedDataset` is parameterized by the following variables: the underlying `MegatronDataset` instances `D`, the weights `W` (one per dataset), and the size `S`. The `BlendedDataset` will draw samples from contributing datasets in proportion to the weights until achieving a composite dataset of the desired size. During each sampling step, we draw a single sample from the dataset which has the greatest sampling error.
The `BlendedDataset` creates two "blending" indices to facilitate lookup: (1) the dataset index and (2) the dataset sample index.
1. The dataset index _Da_idx_ is a 1-D array mapping from _i_ to dataset index of length `S`.
```
Given
D = [d0, d1, d2]
W = [1/2, 1/4, 1/4]
S = 4
Then, for example:
Da_idx = [0, 1, 2, 0]
```
2. The dataset sample index _Sa_idx_ is a 1-D mapping from _i_ to the sample index for dataset _Da_idx[i]_ of length `S`.
```
Given
Da_idx = [0, 1, 2, 0]
Then, for example:
Sa_idx = [0, 0, 0, 1]
```
To query the `BlendedDataset` for the _k_-th sample we do the following
- Use the dataset index to retrieve the corresponding dataset from `D` and the dataset sample index to retrieve the corresponding sample from that dataset.
```
sample = D[Da_idx[k]][Sa_idx[k]]
```
To save time during initialization, each index is built/cached sequentially on one process rank and subsequently loaded in parallel on other process ranks. The cached indices are unique to a hash generated in the `BlendedDataset.__init__` function.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from .config import RetroGPTChunkDatasets
from .query.multi_split_gpt_dataset import MultiSplitGPTDataset, MultiSplitGPTDatasetConfig
from .query.retro_dataset import get_retro_datasets
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""
Exports:
- Embedder: Base class for all Bert embedders.
- RetroBertEmbedders: Container class for in-memory and on-disk embedders.
- RetroPreprocessingConfig: Configuration class for all of Retro preprocessing.
- RetroGPTChunkDatasets: Container class for train, valid, and test datasets.
- RetroTokenizers: Container class for GPT and Bert tokenizers.
"""
from .bert_embedders import Embedder, RetroBertEmbedders
from .config import RetroPreprocessingConfig
from .gpt_chunk_datasets import RetroGPTChunkDatasets
from .tokenizers import RetroTokenizers
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Container dataclass for holding both in-memory and on-disk Bert embedders."""
import abc
from dataclasses import dataclass
from typing import Any
import numpy as np
import torch
class Embedder(abc.ABC):
"""Base class for all Bert embedders.
All embedders should be able to embed either an entire text dataset (to a 2D
numpy array), or a single text string (to a 1D numpy array).
"""
@abc.abstractmethod
def embed_text_dataset(self, text_dataset: torch.utils.data.Dataset) -> np.ndarray:
"""Embed a text dataset.
Args:
text_dataset (torch.utils.data.Dataset): Text dataset to embed. Each sample of the text dataset should output a dict with a key 'text' and a string value.
Returns:
A 2D ndarray with shape (len(text_dataset), dimension(embedder)).
"""
@abc.abstractmethod
def embed_text(self, text: str) -> np.ndarray:
"""Embed a simple string of text.
Args:
text (str): A single text sample.
Returns:
A 1D ndarray with shape (dimensions(embedder),).
"""
@dataclass
class RetroBertEmbedders:
"""Container dataclass for in-memory and on-disk Bert embedders."""
disk: Embedder
mem: Embedder
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Retro preprocessing config."""
from dataclasses import dataclass
from megatron.core.transformer import TransformerConfig
from .bert_embedders import RetroBertEmbedders
from .gpt_chunk_datasets import RetroGPTChunkDatasets
from .tokenizers import RetroTokenizers
@dataclass
class RetroPreprocessingConfig(TransformerConfig):
"""Configuration object for Retro preprocessing.
*Note* : Arguments prefixed with '--retro-gpt-*' or '--retro-bert-*' are
included and named as such to more easily handle managing both models
running at the same time. Megatron is not optimized to run two models at
once, so this naming convention makes it clearer.
Args:
retro_project_dir (str): Retro project directory, which contains the preprocessed data for for pretraining. This directory is built during preprocessing (see tools/retro/README.md), and contains subdirectories for the chunk database and pretraining neighbors.
retro_tasks (str): Comma-separated list of tasks to run. Run entire preprocesing pipeline by using '--retro-tasks build'. Alternatively, run individual stages with tasks (in this order) 'db-build', 'index-build', or 'query-pretraining-neighbors'. For example, '--retro-tasks db-build,index-build,query-pretraining-neighbors' is equivalent to '--retro-tasks build'; or the argument can contain a subset of these tasks. Stages must always be run in the correct order (listed above).
retro_task_validate (float): If defined, validate a randomly sampled subset of the existing results of the given task. Each task implements a 'validate' method that is responsible for sampling a `retro_task_validate` fraction of the existing results, and then checking for bitwise equality with the current code base. (E.g., `--retro-task-validate 0.01`.)
retro_block_size (int): Number of chunks to process at a time when generating Bert embeddings and querying the search index. Partial results for each block are generally saved to disk in separate files.
retro_doc_block_size (int): Number of documents to processe at time when processing token datasets into chunk databases. The partial chunk database for each block is saved into a separate file.
retro_gpt_seed (int): Random seed used for python, numpy, pytorch, and cuda.
retro_gpt_data_path (str): Path to the training dataset. Accepted format: 1) a single data path, 2) multiple datasets in the form: dataset1-weight dataset1-path dataset2-weight dataset2-path ... It is used with --split when a single dataset used for all three: train, valid and test. It is exclusive to the other --*-data-path args.
retro_gpt_data_cache_path (str): Path to a directory to hold cached index files.
retro_gpt_split (str): Comma-separated list of proportions for training, validation, and test split. For example the split `90,5,5` will use 90%% of data for training, 5%% for validation and 5%% for test.
retro_gpt_train_samples (int): Total number of samples to train over all training runs.
retro_gpt_eval_interval (int): GPT evaluation interval.
retro_gpt_eval_iters (int): GPT evaluation iterations.
retro_gpt_tokenizer_type (str): GPT tokenizer type.
retro_gpt_tokenizer_model (str): GPT tokenizer model file.
retro_gpt_vocab_file (str): GPT vocab file.
retro_gpt_merge_file (str): GPT merge file.
retro_gpt_seq_length (int): GPT sequence length.
retro_gpt_global_batch_size (int): GPT global batch size.
retro_gpt_chunk_length (int): GPT chunk length.
retro_bert_tokenizer_type (str): Bert tokenizer type (for when using '--bert-embedder-type megatron').
retro_bert_vocab_file (str): Bert vocab file.
retro_bert_batch_size (int): Micro-batch size for processing Bert embeddings.
retro_bert_max_chunk_length (int): Maximum sequence length for Bert embeddings. (Named 'chunk' here in reference to these Bert sequences being converted from GPT chunks.)
retro_index_type (str): A 'faiss-base' index is a simple, un-optimized wrapper around a Faiss index. A 'faiss-par-add' index optimizes the 'add()' method by making it multi-node and multi-process, but with bit-wise equivalent results.
retro_index_str (str): Index string used for calling faiss.index_factory(). For example, 'IVF262144_HNSW32,Flat' or 'OPQ32_256,IVF4194304_HNSW32,PQ32'.
retro_index_ntrain (int): Number of database chunks to use for training the index. This value must be less or equal to the total number of chunks in the database.
retro_index_train_load_fraction (float): Fraction of sampled chunks to use for training the index. Useful when our total sampled embeddings use too much memory; lowering the load fraction is less costly than re-embedding a new sampled dataset from scratch.
retro_index_add_load_fraction (float): Fraction of database chunks to use for adding to the index. Useful when our total index size would use too much memory; lowering the load fraction is less costly than re-designing our token datasets.
retro_index_delete_training_embeddings (bool): Delete training embeddings for the search index. Useful for debugging.
retro_index_delete_added_codes (bool): Delete added codes for the search index. Useful for debugging.
retro_query_ef_search (int): Index ef-search parameter for Hierarchical Navigable Small Worlds (HNSW) during querying.
retro_query_nprobe (int): Index nprobe parameter for Inverted File (IVF) during querying.
retro_query_num_neighbors_query (int): Number of neighbors to retrieve when calling index.search().
retro_query_num_neighbors_save (int): Number of neighbors to save to disk after the index's returned neighbors. If longer than target value, neighbors truncated; and if shorter than target value, neighbors are padded with -1's.
retro_bert_embedders (RetroBertEmbedders): Set of Bert embedders used for embedding chunks. Contains entries: 1) 'mem' for an in-memory embedder, and 2) 'disk' for an embedder that saves results in blocks to disk.
retro_gpt_chunk_datasets (RetroGPTChunkDatasets): GPT datasets for 'train', 'valid', and 'test'.
retro_tokenizers (RetroTokenizers): GPT ('gpt') and Bert ('bert') tokenizers.
"""
# Basic.
retro_project_dir: str = None
retro_tasks: str = 'build'
retro_task_validate: float = None
retro_block_size: int = 100000
retro_doc_block_size: int = 100000
# GPT.
retro_gpt_seed: int = 1234
retro_gpt_data_path: list = None # basic list here, for parsing purposes
retro_gpt_data_cache_path: str = None
retro_gpt_split: str = '969,30,1'
retro_gpt_train_samples: int = None
retro_gpt_eval_interval: int = None
retro_gpt_eval_iters: int = None
retro_gpt_tokenizer_type: str = None
retro_gpt_tokenizer_model: str = None
retro_gpt_vocab_file: str = None
retro_gpt_merge_file: str = None
retro_gpt_seq_length: int = None
retro_gpt_global_batch_size: int = None
retro_gpt_chunk_length: int = 64
# Bert.
retro_bert_tokenizer_type: str = None
retro_bert_vocab_file: str = None
retro_bert_batch_size: int = 128
retro_bert_max_chunk_length: int = 256
# Index.
retro_index_type: str = 'faiss-par-add'
retro_index_str: str = None
retro_index_ntrain: int = None
retro_index_train_load_fraction: float = 1.0
retro_index_add_load_fraction: float = 1.0
retro_index_delete_training_embeddings: bool = True
retro_index_delete_added_codes: bool = True
# Query.
retro_query_ef_search: int = 256
retro_query_nprobe: int = 65536
retro_query_num_neighbors_query: int = 200
retro_query_num_neighbors_save: int = 20
# Tools.
retro_bert_embedders: RetroBertEmbedders = None
retro_gpt_chunk_datasets: RetroGPTChunkDatasets = None
retro_tokenizers: RetroTokenizers = None
def __post_init__(self) -> None:
"""Validate Retro config."""
# Validate required attributes.
assert self.retro_project_dir is not None
assert self.retro_tasks is not None
assert self.retro_gpt_data_path is not None or self.retro_gpt_data_cache_path is not None
assert self.retro_gpt_train_samples is not None
assert self.retro_gpt_eval_interval is not None
assert self.retro_gpt_eval_iters is not None
assert self.retro_gpt_tokenizer_type is not None
assert self.retro_gpt_tokenizer_model is not None or (
self.retro_gpt_vocab_file is not None and self.retro_gpt_merge_file is not None
)
assert self.retro_gpt_seq_length is not None
assert self.retro_gpt_global_batch_size is not None
assert self.retro_bert_tokenizer_type is not None
assert self.retro_bert_vocab_file is not None
assert self.retro_index_str is not None
assert self.retro_index_ntrain is not None
# Split retro tasks.
self.retro_tasks = self.retro_tasks.split(",")
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Container dataclass for GPT chunk datasets (train, valid, and test)."""
from dataclasses import dataclass
@dataclass
class RetroGPTChunkDatasets:
"""Container dataclass for GPT chunk datasets."""
# Each dict contains 'dataset', 'neighbor_dir', and 'num_active_chunks'.
train: dict = None
valid: dict = None
test: dict = None
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Container class for GPT and Bert tokenizers."""
from dataclasses import dataclass
from megatron.core.datasets.megatron_tokenizer import MegatronTokenizer
@dataclass
class RetroTokenizers:
"""Container class for GPT and Bert tokenizers."""
gpt: MegatronTokenizer = None
bert: MegatronTokenizer = None
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""
Exports:
- build_db: Build a chunk database from a list of indexed datasets.
"""
from .build import build_db
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Build a chunk database from a list of indexed datasets.
Building a chunk database consists of.
- Breaking each document of each indexed dataset into consecutive
retro_gpt_chunk_length chunks.
- Re-tokenize each chunk into Bert, and discard any chunks with empty Bert
tokens.
- Save chunk offsets to disk for each indexed dataset.
"""
import glob
import os
import types
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import Dict, List, Tuple
import numpy as np
import torch
from tqdm import tqdm
from megatron.core.datasets.indexed_dataset import IndexedDataset
from megatron.core.datasets.retro.config import RetroPreprocessingConfig
from megatron.core.datasets.retro.external_libs import h5py
from megatron.core.datasets.retro.utils import (
extract_data_config,
get_blocks_by_rank,
log_retro_rank_0,
retro_makedir,
)
from .utils import (
get_indexed_dataset_infos,
get_indexed_dataset_infos_path,
get_individual_chunk_db,
get_individual_db_dir,
get_individual_db_paths,
get_individual_doc_offsets,
get_merged_db_path_map,
init_indexed_dataset_infos,
load_indexed_datasets,
save_indexed_dataset_infos,
)
def build_partial_db(
config: types.SimpleNamespace,
dataset_idx: int,
n_datasets: int,
indexed_dataset: IndexedDataset,
block_id: int,
n_blocks: int,
block: dict,
proc_id: int,
n_procs: int,
) -> Tuple[int, list, list, dict]:
"""Process a document index range of the indexed dataset.
The chunk database is built in parallel blocks, since de-tokenizing &
re-tokenizing for Bert-length computation is expensive. This method
iterates each document and extracts sequential 'chunk-length' sequences
from each document.
Args:
config (types.SimpleNamespace): Subset of Retro config, containing 'chunk_length', 'gpt_eod', 'gpt_detokenize', 'bert_tokenize', and 'task_validate'.
dataset_idx (int): Index of this dataset out of all blended datasets.
n_datasets (int): Total number of blended datasets.
indexed_dataset (IndexedDataset): Indexed dataset to be chunked.
block_id (int): Block index out of all blocks to be processed.
n_blocks (int): Total number of blocks to be processed.
block (dict): Range information such as start/end points for chunking idnexed dataset.
proc_id (int): Process ID for tracking parallel process order.
n_procs (int): Total number of parallel processes.
Returns:
A tuple containing:
- Process ID.
- List of valid chunks.
- List of invalid chunks (i.e., chunks that converted to empty Bert embeddings.).
- Dict mapping document ID to number of valid chunks.
"""
# Document start/end indexes.
doc_range = block["range"]
n_docs = doc_range[1] - doc_range[0]
n_docs_per_proc = int(np.ceil(n_docs / n_procs))
doc_start_id = doc_range[0] + proc_id * n_docs_per_proc
doc_end_id = min(doc_range[1], doc_start_id + n_docs_per_proc)
# Print progress.
progress_proc_ids = set(range(n_procs)) if torch.distributed.get_rank() == 0 else set()
if proc_id in progress_proc_ids:
log_retro_rank_0(
" > building partial chunk db, proc %d / %d, docs %d:%d / %d."
% (proc_id, n_procs, doc_start_id, doc_end_id, n_docs,)
)
# Progress bars (snapshot of overall progress).
doc_id_iter = range(doc_start_id, doc_end_id)
pbar = (
tqdm(doc_id_iter, "parse doc chunks", miniters=len(doc_id_iter) // 20,)
if proc_id in progress_proc_ids
else doc_id_iter
)
# Iterate documents & parse chunks.
chunk_db_valid: List[Tuple] = []
chunk_db_invalid: List[Tuple] = []
doc_size_map = {}
for doc_id in pbar:
# Progress description.
try:
pbar.set_description(
"%sds %d / %d, block %d / %d, proc %d / %d."
% (
"" if config.task_validate is None else "[validate] ",
dataset_idx,
n_datasets,
block_id,
n_blocks,
proc_id,
n_procs,
)
)
except:
pass
# Remove EOD token.
doc = indexed_dataset.get(doc_id)
if doc[-1].item() == config.gpt_eod:
doc = doc[:-1]
doc_len = len(doc)
# Chunk start/end indexes.
chunk_start_idxs = list(range(0, doc_len, config.chunk_length))
chunk_end_idxs = [min(doc_len, s + config.chunk_length) for s in chunk_start_idxs]
# Re-tokenize each chunk to Bert/Wordpiece (empty bert -> 'invalid').
doc_size_map[doc_id] = 0
for i, chunk_start_idx in enumerate(chunk_start_idxs):
# Re-tokenize.
chunk_end_idx = chunk_end_idxs[i]
gpt_token_ids = indexed_dataset.get(
idx=doc_id, offset=chunk_start_idx, length=chunk_end_idx - chunk_start_idx,
)
text = config.gpt_detokenize(gpt_token_ids.tolist())
bert_token_ids = config.bert_tokenize(text)
# 'Valid' for non-empty Bert chunks; 'invalid' otherwise.
if len(bert_token_ids) == 0:
_chunk_db = chunk_db_invalid
else:
_chunk_db = chunk_db_valid
doc_size_map[doc_id] += 1
_chunk_db.append((doc_id, chunk_start_idx, chunk_end_idx, len(bert_token_ids),))
return proc_id, chunk_db_valid, chunk_db_invalid, doc_size_map
def build_block_db(
config: RetroPreprocessingConfig,
dataset_idx: int,
n_datasets: int,
indexed_dataset: IndexedDataset,
n_procs: int,
executor: ProcessPoolExecutor,
n_missing_blocks: int,
block_idx: int,
block: dict,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Split each document within block into consecutive retro_gpt_chunk_length size chunks.
Args:
config (RetroPreprocessingConfig): For DB building, we make use of attributes 'chunk_length', 'gpt_eod', 'gpt_detokenize', 'bert_tokenize', and 'task_validate'.
dataset_idx (int): Index of this dataset out of all blended datasets.
n_datasets (int): Total number of blended datasets.
indexed_dataset (IndexedDataset): Indexed dataset to be chunked.
n_procs (int): Total number of parallel processes.
executor (ProcessPoolExecutor): Executor for launching parallel processes.
n_missing_blocks (int): Total number of blocks to be processed.
block_idx (int): Block index out of all blocks to be processed.
block (dict): Range information such as start/end points for chunking idnexed dataset.
Returns:
A tuple containing:
- List of valid chunks.
- List of invalid chunks (i.e., chunks that converted to empty Bert embeddings.).
- Dict mapping document ID to number of valid chunks.
"""
# Build partial dbs.
log_retro_rank_0(' > build partial dbs.')
futures = []
for proc_id in range(n_procs): # not true process id
futures.append(
executor.submit(
build_partial_db,
types.SimpleNamespace(
chunk_length=config.retro_gpt_chunk_length,
gpt_eod=config.retro_tokenizers.gpt.eod,
gpt_detokenize=config.retro_tokenizers.gpt.detokenize,
bert_tokenize=config.retro_tokenizers.bert.tokenize,
task_validate=config.retro_task_validate,
),
dataset_idx,
n_datasets,
indexed_dataset,
block_idx,
n_missing_blocks,
block,
proc_id,
n_procs,
)
)
partial_chunk_dbs = []
for future in as_completed(futures):
partial_chunk_dbs.append(future.result())
# Concatenate chunks.
partial_chunk_dbs.sort(key=lambda item: item[0]) # sort by proc_id
chunk_db_valid = [
item for partial_chunk_db in partial_chunk_dbs for item in partial_chunk_db[1]
]
chunk_db_invalid = [
item for partial_chunk_db in partial_chunk_dbs for item in partial_chunk_db[2]
]
# Convert to numpy.
log_retro_rank_0(' > converting chunk db to numpy.')
chunk_db_valid = np.array(chunk_db_valid, dtype="uint32")
chunk_db_invalid = np.array(chunk_db_invalid, dtype="uint32")
# Document offsets.
doc_sizes = [
(d, s) for partial_chunk_db in partial_chunk_dbs for d, s in partial_chunk_db[3].items()
]
doc_sizes.sort(key=lambda item: item[0])
doc_offsets = np.cumsum([item[1] for item in doc_sizes]).astype("uint64")
doc_offsets = np.stack(
(np.array([item[0] for item in doc_sizes], dtype="uint64"), doc_offsets), axis=1
)
return chunk_db_valid, chunk_db_invalid, doc_offsets
def save_block_db(
block: dict, chunk_db_valid: np.ndarray, chunk_db_invalid: np.ndarray, doc_offsets: np.ndarray,
) -> None:
"""Save block of chunked tokens to disk. These blocks are later used for
training and adding to the vector index.
Args:
block (dict): Range information such as start/end points for chunking idnexed dataset.
chunk_db_valid (np.ndarray): Array of valid chunk indexes.
chunk_db_invalid (np.ndarray): Array of invalid chunk indexes.
doc_offsets (np.ndarray): Array of document offsets by chunks.
"""
log_retro_rank_0(" > saving individual db.")
with h5py.File(block["path"], "w") as f:
dset = f.create_dataset("chunks_valid", data=chunk_db_valid)
dset = f.create_dataset("chunks_invalid", data=chunk_db_invalid)
dset = f.create_dataset("doc_offsets", data=doc_offsets)
def build_individual_db(
config: RetroPreprocessingConfig, dataset_idx: int, n_datasets: int, dataset_info: dict,
) -> None:
"""Process a single indexed dataset & extract chunks.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
dataset_idx (int): Dataset index within blended dataset.
n_datasets (int): Total number of datasets within blended dataset.
dataset_info (dict): Metadata for dataset (see `save_indexed_dataset_infos()` in `utils.py` for more detail).
"""
# Make directory.
db_dir = get_individual_db_dir(config.retro_project_dir, dataset_info["prefix"])
retro_makedir(config, db_dir)
# Indexed dataset.
indexed_dataset = dataset_info["dataset"]
# Missing DB blocks (split by documents).
blocks = get_blocks_by_rank(
db_dir,
len(indexed_dataset),
config.retro_doc_block_size,
validate=lambda f: f["chunks_valid"].shape == (0,) or f["chunks_valid"].shape[1] == 4,
sample=config.retro_task_validate,
)
if config.retro_task_validate is None:
active_blocks = blocks.missing
else:
assert blocks.n_missing_world == 0
active_blocks = blocks.existing
# Prevent missing-path-write race condition.
torch.distributed.barrier()
# Nothing to do?
if config.retro_task_validate is None and not active_blocks:
return
# Num processes.
if blocks.n_missing_world == 1:
n_procs = 128
elif blocks.n_missing_world <= 2:
n_procs = 64
elif blocks.n_missing_world <= 4:
n_procs = 32
elif blocks.n_missing_world <= 8:
n_procs = 16
else:
n_procs = 8
# Process documents in parallel.
with ProcessPoolExecutor(max_workers=n_procs) as executor:
for block_idx, block in enumerate(active_blocks):
if block is not None:
# Build block DB.
chunk_db_valid, chunk_db_invalid, doc_offsets = build_block_db(
config=config,
dataset_idx=dataset_idx,
n_datasets=n_datasets,
indexed_dataset=indexed_dataset,
n_procs=n_procs,
executor=executor,
n_missing_blocks=len(active_blocks),
block_idx=block_idx,
block=block,
)
if config.retro_task_validate is None:
# Save block DB.
save_block_db(
block=block,
chunk_db_valid=chunk_db_valid,
chunk_db_invalid=chunk_db_invalid,
doc_offsets=doc_offsets,
)
else:
# Load existing block DB.
with h5py.File(block["path"]) as f:
existing_chunks_valid = np.copy(f["chunks_valid"])
existing_chunks_invalid = np.copy(f["chunks_invalid"])
existing_doc_offsets = np.copy(f["doc_offsets"])
# Check equality.
log_retro_rank_0(" > validate.")
assert np.array_equal(existing_chunks_valid, chunk_db_valid)
assert np.array_equal(existing_chunks_invalid, chunk_db_invalid)
assert np.array_equal(existing_doc_offsets, doc_offsets)
# Wait for all ranks to finish block.
log_retro_rank_0(" > waiting for all ranks to finish block.")
torch.distributed.barrier()
log_retro_rank_0(" > finished saving individual db.")
def build_individual_dbs(
config: RetroPreprocessingConfig, indexed_dataset_infos: List[Dict],
) -> None:
"""Iterate each indexed dataset & process its chunks.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
indexed_dataset_infos (List[Dict]): Preprocessing metadata for each dataset.
"""
# Build individual DBs.
log_retro_rank_0(" > build individual chunk dbs.")
for ds_idx, ds_info in enumerate(indexed_dataset_infos):
# Progress.
log_retro_rank_0(
" > building individual db, dataset %d / %d ... '%s'."
% (ds_idx, len(indexed_dataset_infos), ds_info["prefix"],)
)
# Process single dataset.
build_individual_db(config, ds_idx, len(indexed_dataset_infos), ds_info)
def update_chunk_counts(
config: RetroPreprocessingConfig, indexed_dataset_infos: List[Dict]
) -> None:
"""Set n_chunks_train & n_chunks sampled for each individual DB.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
indexed_dataset_infos (List[Dict]): Preprocessing metadata for each dataset (i.e., 'prefix', 'ratio', 'n_chunks', etc.).
"""
if torch.distributed.get_rank() != 0:
return
# Data ratio sum (for setting index training chunks).
data_ratio_sum = sum([d["ratio"] for d in indexed_dataset_infos])
# Training split size (split at document level).
train_fraction = float(extract_data_config(config).split.split(",")[0]) / 100
assert train_fraction > 0 and train_fraction <= 1
# Set n_chunks (including n_chunks_sampled for unambiguity).
log_retro_rank_0(" > compute n_chunks.")
for ds_index, ds_info in enumerate(indexed_dataset_infos):
db_paths = get_individual_db_paths(config.retro_project_dir, ds_info["prefix"])
# Update counts.
ds_info["n_docs"] = len(ds_info["dataset"].document_indices) - 1
ds_info["n_docs_train"] = int(train_fraction * ds_info["n_docs"])
ds_info["n_chunks"] = 0 # previously, 'n_chunks_valid'
ds_info["n_chunks_train"] = 0
ds_info["n_chunks_invalid"] = 0
for db_path in tqdm(
db_paths, "%d/%d, %s" % (ds_index, len(indexed_dataset_infos), ds_info["prefix"])
):
with h5py.File(db_path, "r") as f:
ds_info["n_chunks"] += len(f["chunks_valid"])
ds_info["n_chunks_invalid"] += len(f["chunks_invalid"])
ds_info["n_chunks_train"] += (
(np.copy(f["chunks_valid"][:, 0]) < ds_info["n_docs_train"]).sum().item()
)
ds_info["n_chunks_sampled"] = int(
config.retro_index_ntrain * ds_info["ratio"] / data_ratio_sum
)
# Verify counts.
assert ds_info["n_chunks_train"] <= ds_info["n_chunks"], "n_train (%d) > n_total (%d)." % (
ds_info["n_chunks_train"],
ds_info["n_chunks"],
)
assert ds_info["n_chunks_sampled"] <= ds_info["n_chunks_train"], (
"n_sampled (%d) > n_train (%d)."
% (ds_info["n_chunks_sampled"], ds_info["n_chunks_train"])
)
def merge_dbs(project_dir: str, indexed_dataset_infos: List[Dict], db_type: str) -> None:
"""Merge individual DBs into single DB.
Args:
project_dir (str): Retro project dir.
indexed_dataset_infos (List[Dict]): Preprocessing metadata for each dataset (i.e., 'prefix', 'ratio', 'n_chunks', etc.).
db_type (str): DB type (e.g., 'sampled', 'train', or 'valid').
"""
if torch.distributed.get_rank() != 0:
return
log_retro_rank_0(" > build %s chunk db." % db_type)
# Count chunks.
if db_type == "sampled":
n_chunks_key = "n_chunks_sampled"
n_docs_key = None
elif db_type == "train":
n_chunks_key = "n_chunks_train"
n_docs_key = "n_docs_train"
elif db_type == "valid":
n_docs_key = None
else:
raise Exception("handle db_type '%s'." % db_type)
if db_type == "valid":
n_chunks = sum(m["n_chunks"] - m["n_chunks_train"] for m in indexed_dataset_infos)
else:
n_chunks = sum(m[n_chunks_key] for m in indexed_dataset_infos)
n_docs = None if n_docs_key is None else sum(m[n_docs_key] for m in indexed_dataset_infos)
# DB path.
db_path = get_merged_db_path_map(project_dir)[db_type]
# Delete existing chunk db if incorrect size.
if os.path.exists(db_path):
try:
f = h5py.File(db_path)
n_alloc = len(f["chunks"]) # total allocated
n_written = f["n_written"][0].item() # total written
f.close()
if n_chunks != n_alloc or n_chunks != n_written:
os.remove(db_path)
except Exception as e:
if isinstance(e, OSError):
os.remove(db_path)
elif isinstance(e, KeyError):
f.close()
os.remove(db_path)
else:
raise e
# Build merged chunk db.
if not os.path.exists(db_path):
os.makedirs(os.path.dirname(db_path), exist_ok=True)
f = h5py.File(db_path, "w")
# Initialize output arrays.
merged_chunk_db: np.ndarray = f.create_dataset("chunks", (n_chunks, 5), dtype="uint32")
merged_doc_offsets: np.ndarray = (
None
if n_docs_key is None
else f.create_dataset("doc_offsets", (n_docs, 3), dtype="uint64")
)
n_written = f.create_dataset("n_written", (1,), dtype="uint64")
n_written[0] = 0
# Iterate indexed datasets & collect chunks.
chunk_start_index = 0
doc_start_index = 0
doc_start_offset = 0
for ds_idx, ds_info in enumerate(indexed_dataset_infos):
log_retro_rank_0(
" > merging dbs; '%s', dataset %d / %d ... '%s'."
% (db_type, ds_idx, len(indexed_dataset_infos), ds_info["prefix"]),
)
individual_chunk_db: np.ndarray = get_individual_chunk_db(project_dir, ds_idx, ds_info)
individual_doc_offsets: np.ndarray = (
None
if n_docs_key is None
else get_individual_doc_offsets(project_dir, ds_idx, ds_info)
)
if db_type == "valid":
individual_chunk_db = individual_chunk_db[ds_info["n_chunks_train"] :]
if n_docs_key is None:
individual_doc_offsets = None
else:
train_doc_offset = individual_doc_offsets[ds_info["n_docs_train"] - 1, 2]
individual_doc_offsets = np.copy(
individual_doc_offsets[ds_info["n_docs_train"] :]
)
individual_doc_offsets[:, 2] -= train_doc_offset
log_retro_rank_0("~~~")
log_retro_rank_0(individual_doc_offsets)
log_retro_rank_0(train_doc_offset)
raise Exception("test me.")
else:
individual_chunk_db = individual_chunk_db[: ds_info[n_chunks_key]]
individual_doc_offsets = (
None
if n_docs_key is None
else np.copy(individual_doc_offsets[: ds_info[n_docs_key]])
)
merged_chunk_db[
chunk_start_index : chunk_start_index + len(individual_chunk_db)
] = individual_chunk_db
chunk_start_index += len(individual_chunk_db)
n_written[0] = chunk_start_index
if n_docs_key is not None:
individual_doc_offsets[:, 2] += doc_start_offset
doc_end_index = doc_start_index + individual_doc_offsets.shape[0]
merged_doc_offsets[doc_start_index:doc_end_index] = individual_doc_offsets
doc_start_index = doc_end_index
doc_start_offset = individual_doc_offsets[-1, 2].item()
f.close()
def build_merged_dbs(project_dir: str, indexed_dataset_infos: List[Dict]) -> None:
"""Merge individual dataset components into single database.
This method merges databases for DB types:
- 'sampled': used for training the vector index.
- 'train': used for adding to the trained vector index.
- 'valid': can be used for validating/testing the vector index.
Args:
project_dir (str): Retro project dir.
indexed_dataset_infos (List[Dict]): Preprocessing metadata for each dataset (i.e., 'prefix', 'ratio', 'n_chunks', etc.).
"""
merge_dbs(project_dir, indexed_dataset_infos, "sampled")
merge_dbs(project_dir, indexed_dataset_infos, "train")
merge_dbs(project_dir, indexed_dataset_infos, "valid")
def build_db(config: RetroPreprocessingConfig) -> None:
"""Extract token chunks from each indexed dataset.
Iterate each document of each indexed dataset, extract that document's chunks, and save to a 'DB' (hdf5 file).
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
"""
project_dir = config.retro_project_dir
# Indexed dataset info.
if config.retro_task_validate is None:
indexed_dataset_infos = init_indexed_dataset_infos(config)
else:
indexed_dataset_infos = get_indexed_dataset_infos(config.retro_project_dir)
# Build individual dbs.
build_individual_dbs(config, indexed_dataset_infos)
# If validating, return here.
if config.retro_task_validate is not None:
return
# Single-process going forward.
if torch.distributed.get_rank() != 0:
return
# Update n_chunks & save indexed dataset infos.
if not os.path.exists(get_indexed_dataset_infos_path(project_dir)):
update_chunk_counts(config, indexed_dataset_infos)
save_indexed_dataset_infos(project_dir, indexed_dataset_infos)
indexed_dataset_infos = get_indexed_dataset_infos(project_dir)
# Builded merged dbs.
build_merged_dbs(project_dir, indexed_dataset_infos)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""A DBDataset is for iterating the chunks of the chunk database.
This dataset is used for both training a vector index, and adding vectors to a
trained index.
"""
from typing import List
import numpy as np
import torch
from tqdm import tqdm
from megatron.core.datasets.indexed_dataset import IndexedDataset
class DBDataset(torch.utils.data.Dataset):
"""Dataset for iterating chunks.
Args:
db_path (str): Path of HDF5-format chunk database.
indexed_datasets (List[IndexedDataset]): Indexed datasets used to build database.
chunks (np.ndarray): Array of chunk indexes, for indexing into indexed datasets. Format [dataset_idx, doc_id, start_idx, end_idx, bert_length].
chunk_length (int): Max GPT chunk length (e.g., 64).
eod_token_id (int): EOD token ID.
"""
def __init__(
self,
db_path: str,
indexed_datasets: List[IndexedDataset],
chunks: np.ndarray,
chunk_length: int,
eod_token_id: int,
):
assert chunks.shape[1] == 5, (
"expected 5 columns (dataset_idx, "
"doc_idx, token_start_idx, token_end_idx, bert_chunk_length); "
"found %d columns." % chunks.shape[1]
)
self.db_path = db_path
self.indexed_datasets = indexed_datasets
self.chunks = chunks
self.doc_chunk_map = None
self.max_chunk_length = chunk_length
self.eod_token_id = eod_token_id
def __len__(self) -> int:
"""Length of DB dataset.
Returns:
Number of chunks contained in the dataset.
"""
return self.chunks.shape[0]
def __getitem__(self, chunk_id: int) -> dict:
"""DB dataset sample.
Args:
chunk_id (int): Index of chunk within dataset.
Returns:
A dict containing:
- 'doc_id': Document index within indexed dataset.
- 'text': GPT token IDs.
"""
# Chunk start/end indexes.
indexed_dataset_id, doc_id, token_start_idx, token_end_idx, _ = [
value.item() for value in self.chunks[chunk_id]
]
chunk_length = token_end_idx - token_start_idx
indexed_dataset = self.indexed_datasets[indexed_dataset_id]
# Chunk token ids.
token_ids = indexed_dataset.get(doc_id, offset=token_start_idx, length=chunk_length)
# Extend chunks to max_chunk_length by padding with EOD tokens.
if chunk_length != self.max_chunk_length:
assert chunk_length < self.max_chunk_length, "invalid chunk len."
token_ids = token_ids.tolist()
token_ids += [self.eod_token_id] * (self.max_chunk_length - chunk_length)
return {
"doc_id": doc_id,
"text": np.array(token_ids, dtype=np.int64),
}
def load_doc_tuples(self) -> None:
"""Load the dataset & document ids.
Load the dataset id & document id of each chunk in the database, to
be used for causality filtering during querying.
"""
self.doc_tuples = np.zeros(shape=(len(self), 2), dtype="uint32")
block_size = int(1e6)
for start_idx in tqdm(
range(0, len(self), block_size),
"load doc tuples",
miniters=(len(self) // block_size) // 10,
disable=torch.distributed.get_rank() != 0,
):
end_idx = min(len(self), start_idx + block_size)
self.doc_tuples[start_idx:end_idx] = self.chunks[start_idx:end_idx, :2]
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Utilities for building a chunk database."""
import glob
import json
import os
from typing import Dict, List, Optional
import numpy as np
from megatron.core.datasets.indexed_dataset import IndexedDataset
from megatron.core.datasets.retro.config import RetroPreprocessingConfig
from megatron.core.datasets.retro.external_libs import h5py
from megatron.core.models.retro.utils import get_gpt_data_dir
from .dataset import DBDataset
def get_db_dir(project_dir: str) -> str:
"""Sub-directory for DB data.
Args:
project_dir (str): Path to Retro project dir.
Returns:
Path of the DB sub-directory within the project.
"""
return os.path.join(project_dir, "db")
def init_indexed_dataset_infos(config: RetroPreprocessingConfig) -> List[Dict]:
"""Gather meta-info about each indexed dataset.
The returned info array allows for easy access to the configuration, and
helps remove ambiguity.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
Returns:
List of processing metadata for each dataset, including:
- ratio: Data split weight.
- prefix: Relative path to dataset under DB sub-directory.
"""
data_dir = get_gpt_data_dir(config.retro_project_dir)
data_blend: List[str] = config.retro_gpt_data_path
assert len(data_blend) % 2 == 0, "currently, only blended dataset is supported."
# Dataset infos.
infos = []
for i in range(0, len(data_blend), 2):
ratio = float(data_blend[i])
prefix = data_blend[i + 1]
path = os.path.join(data_dir, prefix + ".bin")
assert os.path.exists(path), "couldn't find '%s'." % path
infos.append(
{"ratio": ratio, "prefix": prefix,}
)
# Load indexed datasets.
load_indexed_datasets(config.retro_project_dir, infos)
return infos
def get_indexed_dataset_infos_path(project_dir: str) -> str:
"""Path to indexed dataset meta-infos.
Args:
project_dir (str): Path to Retro project dir.
Returns:
Path to the `indexed_dataset_infos.json` file.
"""
return os.path.join(get_db_dir(project_dir), "indexed_dataset_infos.json")
def save_indexed_dataset_infos(project_dir: str, indexed_dataset_infos: List[Dict]) -> None:
"""Save dataset order & meta-info.
Args:
project_dir (str): Path to Retro project dir.
indexed_dataset_infos (List[Dict]): List of metadata for each dataset, with each entry containing:
- ratio: Data split weight.
- prefix: Relative path to dataset under DB sub-directory.
- n_docs: Number of documents.
- n_docs_train: Number of documents used for pretraining.
- n_chunks: Number of valid chunks.
- n_chunks_train: Number of valid chunks used for pretraining.
- n_chunks_invalid: Number of invalid chunks.
- n_chunks_sampled: Number of valid chunks used for vector index training.
"""
# Remove 'dataset' field.
clean_infos = []
for info in indexed_dataset_infos:
info = dict(info)
del info["dataset"]
clean_infos.append(info)
# Save.
with open(get_indexed_dataset_infos_path(project_dir), "w") as f:
json.dump(clean_infos, f, indent=4)
def load_indexed_datasets(project_dir: str, indexed_dataset_infos: List[Dict]) -> None:
"""Loaded indexed datasets into memory-mapped datasets.
Args:
project_dir (str): Path to Retro project dir.
indexed_dataset_infos (List[Dict]): List of metadata for each dataset (see `save_indexed_dataset_infos()` for more details.
"""
data_dir = get_gpt_data_dir(project_dir)
for info in indexed_dataset_infos:
info["dataset"] = IndexedDataset(os.path.join(data_dir, info["prefix"]), mmap=True)
def get_indexed_dataset_infos(project_dir: str) -> List[Dict]:
"""Load indexed dataset meta-infos.
Args:
project_dir (str): Path to Retro project dir.
Returns:
List of metadata for each dataset (see `save_indexed_dataset_infos()` for more details.
"""
# Load json.
path = get_indexed_dataset_infos_path(project_dir)
with open(path) as f:
infos = json.load(f)
# Load indexed datasets.
load_indexed_datasets(project_dir, infos)
return infos
def get_individual_db_dir(project_dir: str, prefix: str) -> str:
"""Individual DB's directory.
Args:
project_dir (str): Path to Retro project dir.
prefix (str): Unique relative path to dataset within project dir.
Returns:
Path to the given datasets's chunk database.
"""
return os.path.join(get_db_dir(project_dir), "individual", prefix)
def get_individual_db_paths(project_dir: str, prefix: str) -> List[str]:
"""Get paths of all database blocks of an individual dataset.
Args:
project_dir (str): Path to Retro project dir.
prefix (str): Unique relative path to dataset within project dir.
Returns:
Paths to each HDF5 chunk database files that comprises this datasets full chunk database.
"""
return sorted(glob.glob(get_individual_db_dir(project_dir, prefix) + "/*hdf5"))
def get_individual_chunk_db(project_dir: str, ds_id: int, ds_info: dict) -> np.ndarray:
"""Load individual dataset's chunk DB.
Args:
project_dir (str): Path to Retro project dir.
ds_id (int): Index of dataset within blended dataset.
ds_info (dict): Preprocessing metadata for dataset (see `save_indexed_dataset_infos()` for more detail).
Returns:
Array of chunk start/end indexes for this dataset, where the chunk indexes can be used for indexing into the corresponding indexed dataset.
"""
paths = get_individual_db_paths(project_dir, ds_info["prefix"])
# *Note*: convert to dataset, rather than copying to memory.
db = np.zeros((ds_info["n_chunks"], 5), dtype="uint32")
db[:, 0] = ds_id
start_idx = 0
for path in paths:
f = h5py.File(path, "r")
n_chunks_current = f["chunks_valid"].shape[0]
db[start_idx : (start_idx + n_chunks_current), 1:] = f["chunks_valid"]
start_idx += n_chunks_current
f.close()
assert start_idx == ds_info["n_chunks"]
return db
def get_individual_doc_offsets(project_dir: str, ds_id: int, ds_info: dict) -> np.ndarray:
"""Load individual dataset's document offsets.
Args:
project_dir (str): Path to Retro project dir.
ds_id (int): Index of dataset within blended dataset.
ds_info (dict): Preprocessing metadata for dataset (see `save_indexed_dataset_infos()` for more detail).
Returns:
Array of document offsets by chunk index for this dataset.
"""
paths = get_individual_db_paths(project_dir, ds_info["prefix"])
# *Note*: convert to dataset, rather than copying to memory.
doc_offsets = np.zeros((ds_info["n_docs"], 3), dtype="uint64")
doc_offsets[:, 0] = ds_id
start_idx = 0
start_offset = 0
for path in paths:
with h5py.File(path) as f:
current_doc_offsets = np.copy(f["doc_offsets"])
current_doc_offsets[:, 1] += start_offset
current_ndocs = current_doc_offsets.shape[0]
doc_offsets[start_idx : (start_idx + current_ndocs), 1:] = current_doc_offsets
start_idx += current_ndocs
start_offset = current_doc_offsets[-1, 1].item()
return doc_offsets
def get_merged_db_path_map(project_dir: str) -> dict:
"""Paths to merged datasets.
Args:
project_dir (str): Path to Retro project dir.
Returns:
A dict of chunk databases, one for each of:
- sampled: Chunks used for training the vector index.
- train: Chunks used for pretraining 'train' dataset.
- valid: Chunks used for pretraining 'valid' dataset.
"""
base_dir = get_db_dir(project_dir)
return {
"sampled": os.path.join(base_dir, "merged", "sampled.hdf5"),
"train": os.path.join(base_dir, "merged", "train.hdf5"),
"valid": os.path.join(base_dir, "merged", "valid.hdf5"),
}
def get_merged_dataset(
project_dir: str,
chunk_length: int,
eod_token_id: int,
db_type: str,
indexed_dataset_infos: Optional[List[Dict]] = None,
) -> DBDataset:
"""Get merged dataset.
Args:
project_dir (str): Path to Retro project dir.
chunk_length (int): GPT chunk length (e.g., 64).
eod_token_id (int): EOD token ID.
db_type (str): DB type (e.g., 'sampled', 'train', or 'valid').
indexed_dataset_infos (Optional[List[Dict]]): Optionally, pre-loaded list of dataset metadata (see `save_indexed_dataset_infos()` for more detail). If not provided, the indexed dataset infos will be loaded from disk.
Returns:
A DBDataset, which is a dataset that wraps the HDF5 chunk index array.
"""
if not indexed_dataset_infos:
indexed_dataset_infos = get_indexed_dataset_infos(project_dir)
# Load chunks.
db_path = get_merged_db_path_map(project_dir)[db_type]
f = h5py.File(db_path, "r")
chunks = f["chunks"]
# DB dataset.
indexed_datasets = [info["dataset"] for info in indexed_dataset_infos]
dataset = DBDataset(
db_path=db_path,
indexed_datasets=indexed_datasets,
chunks=chunks,
chunk_length=chunk_length,
eod_token_id=eod_token_id,
)
return dataset
def get_merged_sampled_dataset(
project_dir: str,
chunk_length: int,
eod_token_id: int,
indexed_dataset_infos: Optional[List[Dict]] = None,
) -> DBDataset:
"""Get sampled dataset (for training the vector index).
Args:
project_dir (str): Path to Retro project dir.
chunk_length (int): GPT chunk length (e.g., 64).
eod_token_id (int): EOD token ID.
indexed_dataset_infos (Optional[List[Dict]]): Optionally, pre-loaded list of dataset metadata (see `save_indexed_dataset_infos()` for more detail). If not provided, the indexed dataset infos will be loaded from disk.
Returns:
A DBDataset, which is a dataset that wraps the HDF5 chunk index array.
"""
return get_merged_dataset(
project_dir, chunk_length, eod_token_id, "sampled", indexed_dataset_infos
)
def get_merged_train_dataset(
project_dir: str,
chunk_length: int,
eod_token_id: int,
indexed_dataset_infos: Optional[List[Dict]] = None,
) -> DBDataset:
"""Get training dataset (for adding to the vector index).
Args:
project_dir (str): Path to Retro project dir.
chunk_length (int): GPT chunk length (e.g., 64).
eod_token_id (int): EOD token ID.
indexed_dataset_infos (Optional[List[Dict]]): Optionally, pre-loaded list of dataset metadata (see `save_indexed_dataset_infos()` for more detail). If not provided, the indexed dataset infos will be loaded from disk.
Returns:
A DBDataset, which is a dataset that wraps the HDF5 chunk index array.
"""
return get_merged_dataset(
project_dir, chunk_length, eod_token_id, "train", indexed_dataset_infos
)
def get_merged_valid_dataset(
project_dir: str,
chunk_length: int,
eod_token_id: int,
indexed_dataset_infos: Optional[List[Dict]] = None,
) -> DBDataset:
"""Get validation dataset (for testing the vector index).
Args:
project_dir (str): Path to Retro project dir.
chunk_length (int): GPT chunk length (e.g., 64).
eod_token_id (int): EOD token ID.
indexed_dataset_infos (Optional[List[Dict]]): Optionally, pre-loaded list of dataset metadata (see `save_indexed_dataset_infos()` for more detail). If not provided, the indexed dataset infos will be loaded from disk.
Returns:
A DBDataset, which is a dataset that wraps the HDF5 chunk index array.
"""
return get_merged_dataset(
project_dir, chunk_length, eod_token_id, "valid", indexed_dataset_infos
)
def get_merged_datasets(project_dir: str, chunk_length: int, eod_token_id: int) -> dict:
"""Get all merged datasets.
Args:
project_dir (str): Path to Retro project dir.
chunk_length (int): GPT chunk length (e.g., 64).
eod_token_id (int): EOD token ID.
Returns:
A dict mapping DB type ('sampled', 'train', or 'valid') to the corresponding DBDataset, which is a dataset that wraps the HDF5 chunk index array.
"""
fns = {
"sampled": get_merged_sampled_dataset,
"train": get_merged_train_dataset,
"valid": get_merged_valid_dataset,
}
datasets = {key: fn(project_dir, chunk_length, eod_token_id) for key, fn in fns.items()}
return datasets
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Required external libraries for Retro preprocessing."""
import importlib
required_libs = [
"faiss",
"h5py",
"transformers", # for huggingface bert
]
for lib in required_libs:
try:
globals()[lib] = importlib.import_module(lib)
except ImportError as e:
raise Exception(
f"Missing one or more packages required for Retro preprocessing: {required_libs}. Tried importing '{lib}'."
)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""
Exports:
- train_index: Train an index on representative vectors.
- add_to_index: Add vectors to a trained index.
- build_index: Wrapper function that calls above two functions.
"""
from .build import add_to_index, build_index, train_index
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Construct an index.
Constructing an index generally happens in two phases:
- index.train(): Train an index on a representative set of vectors.
- index.add(): Add vectors to an index, to be available for retrieval.
"""
import os
import shutil
import numpy as np
import torch
from tqdm import tqdm
from megatron.core.datasets.retro.config import RetroPreprocessingConfig
from megatron.core.datasets.retro.db.utils import (
get_merged_sampled_dataset,
get_merged_train_dataset,
)
from megatron.core.datasets.retro.external_libs import h5py
from megatron.core.datasets.retro.utils import GPTToTextDataset
from .factory import IndexFactory
from .utils import (
get_training_data_block_dir,
get_training_data_block_paths,
get_training_data_merged_path,
get_training_data_root_dir,
)
##################################################
# Train index.
##################################################
def get_empty_index_path(config: RetroPreprocessingConfig) -> str:
"""Path of empty index.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
Returns:
Path to the empty (trained, but without added samples) vector index.
"""
index = IndexFactory.get_index(config.retro_index_type)
empty_index_path = index.get_empty_index_path(config)
return empty_index_path
def get_block_nload(block_path: str, load_fraction: float) -> int:
"""Compute number of blocks to load.
This is computed by multiplying the total number of available blocks with the
fraction of blocks to load.
Args:
block_path (str): Path to HDF5 file containing block of data. File must contain key 'data'.
load_fraction (float): Fraction (0 < load_fraction <= 1) of block samples to load.
Returns:
Number of block samples to load.
"""
with h5py.File(block_path) as fi:
return int(load_fraction * fi["data"].shape[0])
def merge_embedding_blocks(config: RetroPreprocessingConfig) -> None:
"""Merge individual embedding blocks into a single binary mmap file.
The embeddings are initially stored in block-sized (e.g., ~100k embeddings per
block) HDF5 files. These individual block files must be merged into a single
file before training, to be based as a numpy mmap array to the index.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
"""
if torch.distributed.get_rank() != 0:
return
# Get block, merged paths.
load_fraction = config.retro_index_train_load_fraction
block_paths = get_training_data_block_paths(config)
bin_path = get_training_data_merged_path(config)
# Skip, if already built.
if os.path.exists(bin_path):
return
# Merge blocks.
with open(bin_path, "wb") as fo:
byte_offset = 0
for block_idx, block_path in enumerate(
tqdm(
block_paths,
"merge train embeddings",
miniters=len(block_paths) // 10,
disable=torch.distributed.get_rank() != 0,
)
):
with h5py.File(block_path) as fi:
nload = get_block_nload(block_path, load_fraction)
block = np.array(fi["data"][:nload], copy=False)
fo.write(block.tobytes())
byte_offset += block.size * block.itemsize
fo.seek(byte_offset)
def get_text_dataset_for_training(config: RetroPreprocessingConfig) -> GPTToTextDataset:
"""Convert GPT token chunk dataset to a text dataset for passing to the
embedder.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
Returns:
The text dataset consisting of tokens converted from sampled chunk database.
"""
gpt_dataset = get_merged_sampled_dataset(
project_dir=config.retro_project_dir,
chunk_length=config.retro_gpt_chunk_length,
eod_token_id=config.retro_tokenizers.gpt.eod,
)
text_dataset = GPTToTextDataset(gpt_dataset, config.retro_tokenizers.gpt)
return text_dataset
def embed_training_chunks(config: RetroPreprocessingConfig) -> None:
"""Embed DB chunks.
Store chunks in blocks on disk. These blocks will later be merged into
a single dataset for training the index.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
"""
merged_train_data_path = get_training_data_merged_path(config)
if os.path.exists(merged_train_data_path):
return
# Get training text dataset.
text_dataset = get_text_dataset_for_training(config)
# Embed dataset.
embedder = config.retro_bert_embedders.disk
embedder.embed_text_dataset("index", get_training_data_block_dir(config), text_dataset)
# Merge embeddings.
merge_embedding_blocks(config)
def train_on_embeddings(config: RetroPreprocessingConfig) -> None:
"""Train index on embedded DB chunks.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
"""
index = IndexFactory.get_index(config.retro_index_type)
index.train(config)
def remove_embeddings(config: RetroPreprocessingConfig) -> None:
"""Remove embeddings after training.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
"""
torch.distributed.barrier()
if torch.distributed.get_rank() != 0:
return
empty_index_path = get_empty_index_path(config)
assert os.path.isfile(empty_index_path)
shutil.rmtree(get_training_data_root_dir(config), ignore_errors=True)
def _train_index(config: RetroPreprocessingConfig) -> None:
"""Train index on DB chunks.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
"""
# Check if trained index already exists.
if not os.path.isfile(get_empty_index_path(config)):
# Embed training chunks.
embed_training_chunks(config)
# Train index on embeddings.
train_on_embeddings(config)
# Wait for (single-process) training to complete.
torch.distributed.barrier()
# Remove embeddings.
if config.retro_index_delete_training_embeddings:
remove_embeddings(config)
def train_index(config: RetroPreprocessingConfig) -> None:
"""Entry point for training the index.
We select whether to train a new index, or validate an existing index.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
"""
# Train new index.
if config.retro_task_validate is None:
_train_index(config)
# Validate existing trained index.
else:
from .validate import validate_training_embeddings
validate_training_embeddings(config)
##################################################
# Add to index.
##################################################
def get_text_dataset_for_adding(config: RetroPreprocessingConfig) -> GPTToTextDataset:
"""Convert GPT token chunk dataset to a text dataset for passing to the
embedder.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
Returns:
The text dataset that consists of tokens converted from the 'train' chunk database. These are the chunks used for retrieval by the pretraining 'train' dataset.
"""
gpt_dataset = get_merged_train_dataset(
project_dir=config.retro_project_dir,
chunk_length=config.retro_gpt_chunk_length,
eod_token_id=config.retro_tokenizers.gpt.eod,
)
text_dataset = GPTToTextDataset(gpt_dataset, config.retro_tokenizers.gpt)
return text_dataset
def _add_to_index(config: RetroPreprocessingConfig) -> str:
"""Add DB chunks to index.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
Returns:
Path to the populated index.
"""
# Get index.
index = IndexFactory.get_index(config.retro_index_type)
# Get text dataset.
text_dataset = get_text_dataset_for_adding(config)
# Add to index.
output_index_path = index.add(config, text_dataset)
return output_index_path
def add_to_index(config: RetroPreprocessingConfig) -> None:
"""Entry point for adding to the index.
We select whether to add to a new index, or validate an existing index.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
"""
# Add to new index.
if config.retro_task_validate is None:
_add_to_index(config)
# Validate existing encodings.
else:
from .validate import validate_added_encodings
validate_added_encodings(config)
##################################################
# Build index (train + add).
##################################################
def build_index(config: RetroPreprocessingConfig) -> None:
"""Build index.
Building index involves sequentially running stages above:
- Train index (on sampled training chunks).
- Add to index (on all training chunks).
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
"""
# Train index.
train_index(config)
# Add to index.
add_to_index(config)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""The IndexFactory constructs an index from an index type string."""
from megatron.core.datasets.retro.index.index import Index
from .indexes import FaissBaseIndex, FaissParallelAddIndex
class IndexFactory:
"""Get index.
Index type generally read from argument '--retro-index-ty'.
"""
@classmethod
def get_index_class(cls, index_type: str) -> type:
"""Get an index class, given a type string.
Args:
index_type (str): One of 'faiss-base' (naive Faiss index wrapper) or 'faiss-par-add' (Faiss index wrapper with near embarrassingly parallel index.add().
Returns:
An `Index` sub-type corresponding to the `index_type`.
"""
return {"faiss-base": FaissBaseIndex, "faiss-par-add": FaissParallelAddIndex,}[index_type]
@classmethod
def get_index(cls, index_type: str) -> Index:
"""Construct an index from an index type string.
Args:
index_type (str): One of 'faiss-base' (naive Faiss index wrapper) or 'faiss-par-add' (Faiss index wrapper with near embarrassingly parallel index.add().
Returns:
An `Index` instance corresponding to the `index_type`.
"""
index_class = cls.get_index_class(index_type)
index = index_class()
return index
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Base class for all vector indexes.
A vector index is a type of retrieval database that is queried using vectors,
and returns vectors that are 'similar' (e.g., by cosine distance) to the query
vector. The construction and usage of an index generally has the following
pattern:
- Train the index on representative vectors.
- Add vectors to the index (i.e., vectors available for retrieval)
- Query index with new vector, to retrieve similar vector indexes.
"""
import abc
import os
from typing import List, Tuple
import numpy as np
import torch
from megatron.core.datasets.retro.config import Embedder, RetroPreprocessingConfig
from megatron.core.datasets.retro.external_libs import faiss
from megatron.core.datasets.retro.utils import GPTToTextDataset
from .utils import get_index_dir
class Index(abc.ABC):
"""Abstract base class for indexes.
*Note* : While currently only Faiss-based classes are implemented, in the
future, this class will be extended with other types of indexes that have
different performance-accuracy trade-offs.
The primary methods to override are:
- train() : Train index on the sampled training chunks.
- add() : Add all training chunks to index.
"""
@classmethod
def make_object_verbose(cls, index: faiss.Index, verbose: bool) -> None:
"""Make index object verbose.
Args:
index (faiss.Index): Faiss object to set verbose.
verbose (bool): Sets whether index should log status updates during training and adding.
"""
assert isinstance(verbose, bool)
faiss.ParameterSpace().set_index_parameter(index, "verbose", verbose)
def get_empty_index_path(self, config: RetroPreprocessingConfig) -> str:
"""Get file path to empty index (i.e., trained, but unpopulated).
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
Returns:
File path to empty index (i.e., this index has had index.train() called, but not yet index.add()).
"""
return os.path.join(
get_index_dir(config), "empty_%.3f.faissindex" % config.retro_index_train_load_fraction,
)
def get_empty_index(self, config: RetroPreprocessingConfig) -> faiss.Index:
"""Get empty index (i.e., trained, but unpopulated).
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
Returns:
Empty Faiss index, loaded from storage.
"""
return faiss.read_index(self.get_empty_index_path(config))
def get_added_index_path(self, config: RetroPreprocessingConfig) -> str:
"""Get file path to index that has been populated with vectors.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
Returns:
File path to added index (i.e., this index has had both index.train() and index.add() called).
"""
return os.path.join(
get_index_dir(config),
"added_%.3f_%.3f.faissindex"
% (config.retro_index_train_load_fraction, config.retro_index_add_load_fraction,),
)
def get_added_index(self, config: RetroPreprocessingConfig) -> faiss.Index:
"""Get index that has been populated with vectors.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
Returns:
'Added' (i.e., populated) Faiss index, loaded from storage.
"""
return faiss.read_index(self.get_added_index_path(config))
@abc.abstractmethod
def train(self, config: RetroPreprocessingConfig) -> None:
"""Train index on a representative set of vectors.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
"""
@abc.abstractmethod
def add(self, config: RetroPreprocessingConfig, text_dataset: GPTToTextDataset) -> None:
"""Add vectors to index.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
text_dataset (GPTToTextDataset): Text dataset that will be embedded and added to the index.
"""
def embed_text_dataset_block(
self, embedder: Embedder, text_dataset: GPTToTextDataset, _range: Tuple[int, int]
) -> np.ndarray:
"""Embed a range of a text dataset.
Args:
embedder (Embedder): Embedder used for embedding a text dataset.
text_dataset (GPTToTextDataset): Text dataset that will be embedded.
_range (Tuple[int, int]): Start/end sample indices within text dataset used for embedding.
Returns:
An array of embeddings, with shape (len(text_dataset), dimension(embedder)).
"""
sub_dataset = torch.utils.data.Subset(text_dataset, range(*_range))
return embedder.embed_text_dataset(sub_dataset)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""
Exports:
- FaissBaseIndex: Unoptimized Faiss index wrapper
- FaissParallelAddIndex: Optimized index.add() for Faiss index.
"""
from .faiss_base import FaissBaseIndex
from .faiss_par_add import FaissParallelAddIndex
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""
This class implements a simple, un-optimized wrapper around a Faiss index, that
implements the Index interface (see ..index.py). While this class is
instantiable, it is meant to be extended with optimizations in classes that
inherit from this class (see FaissParAddIndex, for an example).
"""
import os
import numpy as np
import torch
from tqdm import tqdm
from megatron.core.datasets.retro.config import RetroPreprocessingConfig
from megatron.core.datasets.retro.external_libs import faiss
from megatron.core.datasets.retro.index.index import Index
from megatron.core.datasets.retro.index.utils import (
get_training_data_merged_path,
num_samples_to_block_ranges,
)
from megatron.core.datasets.retro.utils import GPTToTextDataset, log_retro_rank_0
class FaissBaseIndex(Index):
"""Base class for Faiss-base indexes.
This class wraps a Faiss index, and adds additional functionality for training
and adding codes. This base class performs a naive sequential code adding,
while the optimized FaissParallelAddIndex class performs a parallel
index.add().
"""
def _train(self, config: RetroPreprocessingConfig) -> None:
"""Train index (rank 0's method).
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
"""
assert torch.distributed.get_rank() == 0
# Set num threads (torch.distributed reset it to 1).
faiss.omp_set_num_threads(64)
empty_index_path = self.get_empty_index_path(config)
# Index already exists? -> return.
if os.path.isfile(empty_index_path):
return
# Load data.
merged_path = get_training_data_merged_path(config)
inp = np.memmap(merged_path, dtype="f4", mode="r",).reshape((-1, config.hidden_size))
# Init index.
index = faiss.index_factory(config.hidden_size, config.retro_index_str)
# Move to GPU.
log_retro_rank_0("> move faiss index to gpu.")
index_ivf = faiss.extract_index_ivf(index)
clustering_index = faiss.index_cpu_to_all_gpus(faiss.IndexFlatL2(index_ivf.d))
index_ivf.clustering_index = clustering_index
log_retro_rank_0("> finished moving to gpu.")
self.make_object_verbose(index, True)
self.make_object_verbose(index_ivf, True)
self.make_object_verbose(index_ivf.quantizer, True)
self.make_object_verbose(index_ivf.clustering_index, True)
# Train index.
index.train(inp)
# Save index.
faiss.write_index(index, empty_index_path)
def train(self, config: RetroPreprocessingConfig) -> None:
"""Train index.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
"""
# Single process only.
if torch.distributed.get_rank() == 0:
self._train(config)
torch.distributed.barrier()
def _add(self, config: RetroPreprocessingConfig, text_dataset: GPTToTextDataset) -> None:
"""Add to index (rank 0's method).
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
text_dataset (GPTToTextDataset): Text dataset that will be embedded and added to the index.
"""
assert torch.distributed.get_rank() == 0
dataset_sample_ranges = num_samples_to_block_ranges(len(text_dataset))
# Set num threads (torch.distributed reset it to 1).
faiss.omp_set_num_threads(64)
# Bert embedder.
embedder = config.bert_embedders.mem
# Empty/added index paths.
empty_index_path = self.get_empty_index_path()
added_index_path = self.get_added_index_path()
# Skip adding, if index exists.
if os.path.isfile(added_index_path):
return
# Read trained index.
index = faiss.read_index(empty_index_path)
# Iterate data blocks & add.
for sample_range in tqdm(dataset_sample_ranges, "faiss_base.add"):
# Embed text.
embeds = self.embed_text_dataset_block(embedder, text_dataset, sample_range)
# Add to index.
index.add(embeds)
# Write index.
faiss.write_index(index, added_index_path)
def add(self, config: RetroPreprocessingConfig, text_dataset: GPTToTextDataset) -> str:
"""Add to index.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
text_dataset (GPTToTextDataset): Text dataset that will be embedded and added to the index.
Returns:
File path to the populated index.
"""
# Single process only.
if torch.distributed.get_rank() == 0:
self._add(config, text_dataset)
# Wait for rank 0.
torch.distributed.barrier()
# Get output index path, for return.
return self.get_added_index_path(config)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Multi-process & multi-node version of Faiss's index.add().
This class inherits from FaissBaseIndex, and optimizes the 'add()' method by
making it multi-node and multi-process, with bit-wise equivalence to
FaissBaseIndex. This allows 'add()' to scale out to very large datasets, since
the vast majority of the computational effort is embarrassingly parallel.
"""
import os
import shutil
from typing import Tuple
import numpy as np
import psutil
import torch
from tqdm import tqdm
from megatron.core.datasets.retro.config import Embedder, RetroPreprocessingConfig
from megatron.core.datasets.retro.external_libs import faiss, h5py
from megatron.core.datasets.retro.index.utils import get_added_code_paths, get_added_codes_dir
from megatron.core.datasets.retro.utils import (
GPTToTextDataset,
get_blocks_by_rank,
log_retro_rank_0,
retro_makedir,
)
from .faiss_base import FaissBaseIndex
class FaissParallelAddIndex(FaissBaseIndex):
"""
This class parallelizes both 1) encoding vectors, and 2) adding codes to the
index. This class is more performant than naive use of Faiss, because most
of the computational work is in encoding the vectors, which is an
embarassingly parallel operation.
"""
def encode_block(
self, index: faiss.Index, embedder: Embedder, text_dataset: GPTToTextDataset, block: dict
) -> Tuple[np.ndarray, np.ndarray]:
"""Encode sub-dataset block, to be later added to index.
Encode the data subset, generally in blocks of 1M vectors each. For
each block, the empty/trained index is loaded, codes are computed
via index.sa_encode(), and the resulting codes are saved to disk.
Args:
index (faiss.Index): Faiss index object.
embedder (Embedder): Embedder used to embed text dataset.
text_dataset (GPTToTextDataset): Text dataset to be embedded and encoded.
block (dict): Range information specifying start/end indices within text dataset.
Returns:
A tuple of (embeddings, encodings) for the given block subset of the text dataset.
"""
# Embed block.
embeddings = self.embed_text_dataset_block(embedder, text_dataset, block["range"],)
# Encode block.
log_retro_rank_0("encode.")
codes = index.sa_encode(embeddings)
# Return embeddings for validation purposes.
return embeddings, codes
def save_block(self, config: RetroPreprocessingConfig, block: dict, codes: np.ndarray) -> None:
"""Save block of codes to disk.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
block (dict): Range information specifying the start/end indices within the encoded text dataset. Here, the 'path' item is used for writing the encodings to storage.
codes (np.ndarray): Block of encodings to be saved to storage.
"""
# Save neighbors.
log_retro_rank_0("save codes.")
retro_makedir(config, os.path.dirname(block["path"]))
with h5py.File(block["path"], "w") as f:
f.create_dataset("data", data=codes)
def encode(self, config: RetroPreprocessingConfig, text_dataset: GPTToTextDataset) -> None:
"""Encode text dataset, to be later added to index.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
text_dataset (GPTToTextDataset): Text dataset to be encoded by the index.
"""
codes_dir = get_added_codes_dir(config)
retro_makedir(config, codes_dir)
# Index.
index = self.get_empty_index(config)
# Bert embedder.
embedder = config.retro_bert_embedders.mem
# Missing code blocks.
def validate(f: h5py.File) -> None:
"""Validation method for validating loaded encodings.
Args:
f (h5py.File): File that contains encodings.
"""
assert len(f["data"].shape) == 2
blocks = get_blocks_by_rank(
codes_dir, len(text_dataset), config.retro_block_size, validate=validate,
)
# Encode each block.
for block_index, block in enumerate(blocks.missing):
if block is not None:
# Progress.
log_retro_rank_0(
"encode block %d / %d ... %s."
% (block_index, len(blocks.missing), block["path"],)
)
# Encode and save.
_, codes = self.encode_block(index, embedder, text_dataset, block)
self.save_block(config, block, codes)
# Synchronize progress across all ranks. (for easier observation)
log_retro_rank_0(" > waiting for other ranks to finish block.")
torch.distributed.barrier()
def add_codes(self, config: RetroPreprocessingConfig) -> None:
"""Read codes from disk, and add them to the index.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
"""
if torch.distributed.get_rank() != 0:
return
added_index_path = self.get_added_index_path(config)
if os.path.exists(added_index_path):
return
# Index.
log_retro_rank_0("read empty index.")
index = self.get_empty_index(config)
index_ivf = faiss.extract_index_ivf(index)
# Add codes.
log_retro_rank_0("add codes.")
code_paths = get_added_code_paths(config)
pbar = tqdm(code_paths)
for code_path in pbar:
pbar.set_description(
"add codes, mem %.3f gb, %.1f%%"
% (psutil.virtual_memory()[3] / 1024 ** 3, psutil.virtual_memory()[2],)
)
with h5py.File(code_path) as f:
nload = int(config.retro_index_add_load_fraction * f["data"].shape[0])
offset = int(os.path.basename(code_path).split("-")[0])
xids = np.arange(offset, offset + nload)
codes = np.copy(f["data"][:nload])
index_ivf.add_sa_codes(codes, xids)
# Update index's ntotal.
index.ntotal = index_ivf.ntotal
# Write index.
log_retro_rank_0("write added index.")
faiss.write_index(index, added_index_path)
def remove_codes(self, config: RetroPreprocessingConfig) -> None:
"""Remove added codes after adding to index.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
"""
if torch.distributed.get_rank() != 0:
return
assert os.path.isfile(self.get_added_index_path(config))
if config.retro_index_delete_added_codes:
raise Exception("remove?")
shutil.rmtree(get_added_codes_dir(config), ignore_errors=True)
def add(self, config: RetroPreprocessingConfig, text_dataset: GPTToTextDataset) -> None:
"""Add vectors to index.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
text_dataset (GPTToTextDataset): Text dataset that will be embedded and added to the index.
"""
# Encode chunks.
self.encode(config, text_dataset)
# Add codes to index.
self.add_codes(config)
# Wait for (single-process) adding to complete.
torch.distributed.barrier()
# Remove codes.
self.remove_codes(config)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Utilities for building an index."""
import glob
import os
from typing import List, Tuple
from megatron.core.datasets.retro.config import RetroPreprocessingConfig
from megatron.core.datasets.retro.utils import retro_makedir
def get_index_dir(config: RetroPreprocessingConfig) -> str:
"""Create sub-directory for this index.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
Returns:
Path to index sub-directory within Retro project.
"""
# Directory path.
index_dir_path = os.path.join(
config.retro_project_dir, "index", config.retro_index_type, config.retro_index_str,
)
# Make directory.
retro_makedir(config, index_dir_path)
return index_dir_path
def num_samples_to_block_ranges(
config: RetroPreprocessingConfig, num_samples: int
) -> List[Tuple[int, int]]:
"""Split a range (length num_samples) into sequence of block ranges
of size block_size.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
num_samples (int): Split `num_samples` into consecutive block ranges, where each block is size `config.retro_block_size`.
Returns:
A list of tuples where each item is the (start, end) index for a given block.
"""
block_size = config.retro_block_size
start_idxs = list(range(0, num_samples, block_size))
end_idxs = [min(num_samples, s + block_size) for s in start_idxs]
ranges = list(zip(start_idxs, end_idxs))
return ranges
def get_training_data_root_dir(config: RetroPreprocessingConfig) -> str:
"""Get root directory for embeddings (blocks and merged data).
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
Returns:
Path to the training data directory, which contains both training embedding blocks and the final merged training embeddings.
"""
return os.path.join(config.retro_project_dir, "index", "train_emb")
def get_training_data_block_dir(config: RetroPreprocessingConfig) -> str:
"""Get directory for of saved embedding blocks.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
Returns:
Path to the directory containing the training embedding blocks, which will be later merged into a single embedding array.
"""
return os.path.join(get_training_data_root_dir(config), "blocks")
def get_training_data_block_paths(config: RetroPreprocessingConfig) -> List[str]:
"""Get paths to saved embedding blocks.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
Returns:
Paths of all training embedding blocks.
"""
return sorted(glob.glob(get_training_data_block_dir(config) + "/*.hdf5"))
def get_training_data_merged_path(config: RetroPreprocessingConfig) -> str:
"""Get path to merged training embeddings.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
Returns:
Path to the merged training embedding binary file.
"""
return os.path.join(
get_training_data_root_dir(config),
"train_%.3f.bin" % config.retro_index_train_load_fraction,
)
def get_added_codes_dir(config: RetroPreprocessingConfig) -> str:
"""Get directory of saved encodings.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
Returns:
Path to the directory containing the vector encodings for adding to the index.
"""
return os.path.join(get_index_dir(config), "add_codes")
def get_added_code_paths(config: RetroPreprocessingConfig) -> List[str]:
"""Get paths to all saved encodings.
Args:
config (RetroPreprocessingConfig): Retro preprocessing config.
Returns:
Paths of all vector encoding blocks, for adding to the index.
"""
return sorted(glob.glob(get_added_codes_dir(config) + "/*.hdf5"))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment