"git@developer.sourcefind.cn:orangecat/ollama.git" did not exist on "14e71350c8761972aaf1c17b1dbcb13c6cc4f7a5"
Commit 762072e0 authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'lmcafee/retro' into 'main'

Retro

See merge request ADLR/megatron-lm!489
parents ef59b687 17a6044d
#!/bin/bash
##################################################
# Example script for pretraining Retro.
##################################################
set -u
unset NCCL_DEBUG
export CUDA_DEVICE_MAX_CONNECTIONS=1
NPROCS=8 # NPROCS must be <= number of GPUs.
################ Dataset configs. ################
# This script contains methods to customize arguments to specific dataset
# types. Customize this script as needed for your datasets.
DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
. $DIR/get_dataset_configs.sh
################ Environment variables. ################
# *Note*: See 'Required environment variables' in 'get_preprocess_cmd.sh' for
# a description of the required environment variables. These variables can be
# set however a user would like. In our setup, we use another bash script
# (location defined by $RETRO_ENV_VARS) that sets all the environment variables
# at once.
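# For illustration only: a minimal RETRO_ENV_VARS file might export the
# variables referenced below. The names come from this script; the values are
# placeholders and should match your setup and 'get_preprocess_cmd.sh'.
#
#   export RETRO_WORKDIR=/path/to/retro_workdir
#   export RETRO_GPT_TRAIN_SAMPLES=200000
#   export LR_DECAY_SAMPLES=175000
#   export LR_WARMUP_SAMPLES=10000
#   export RETRO_GPT_EVAL_INTERVAL=2000
#   export RETRO_GPT_EVAL_ITERS=50
#   export GPT_VOCAB_FILE=/path/to/gpt2-vocab.json
#   export GPT_MERGE_FILE=/path/to/gpt2-merges.txt
#   export DATALOADER_TYPE=cyclic
#   export DATA_BLEND_SCRIPT=/path/to/data_blend.sh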
. $RETRO_ENV_VARS
################ Data blend. ################
. ${DATA_BLEND_SCRIPT}
DATA_PATH=${DATA_BLEND}
######## Retro setup. ########
RETRO_ADD_RETRIEVER=1
RETRO_CYCLIC_TRAIN_ITERS=750000
RETRO_NUM_NEIGHBORS=2
######## Arguments. ########
CHECKPOINT_DIR=${RETRO_WORKDIR}/checkpoints/${RETRO_ADD_RETRIEVER}
TENSORBOARD_DIR="${CHECKPOINT_DIR}/tensorboard"
mkdir -p ${TENSORBOARD_DIR}
ARGS=" \
--save-interval 1000 \
--save ${CHECKPOINT_DIR} \
--load ${CHECKPOINT_DIR} \
--tensorboard-dir ${TENSORBOARD_DIR} \
--log-interval 5 \
--tensor-model-parallel-size 1 \
--pipeline-model-parallel-size 1 \
--num-layers 12 \
--hidden-size 768 \
--num-attention-heads 12 \
--seq-length 2048 \
--max-position-embeddings 2048 \
--micro-batch-size 4 \
--global-batch-size 256 \
--train-samples ${RETRO_GPT_TRAIN_SAMPLES} \
--lr-decay-samples ${LR_DECAY_SAMPLES} \
--lr-warmup-samples ${LR_WARMUP_SAMPLES} \
--lr 6.0e-4 \
--min-lr 6.0e-5 \
--lr-decay-style cosine \
--eval-interval ${RETRO_GPT_EVAL_INTERVAL} \
--eval-iters ${RETRO_GPT_EVAL_ITERS} \
--data-path ${DATA_PATH} \
--vocab-file ${GPT_VOCAB_FILE} \
--merge-file ${GPT_MERGE_FILE} \
--split 98,2,0 \
--clip-grad 1.0 \
--weight-decay 0.1 \
--adam-beta1 0.9 \
--adam-beta2 0.95 \
--init-method-std 0.023 \
--log-params-norm \
--log-num-zeros-in-grad \
--fp16 \
--DDP-impl local \
--dataloader-type ${DATALOADER_TYPE} \
--no-data-sharding \
--no-gradient-accumulation-fusion \
"
if [ "$RETRO_ADD_RETRIEVER" = "0" ]; then
SCRIPT=pretrain_gpt.py
else
ARGS="${ARGS} \
--retro-add-retriever \
--retro-workdir ${RETRO_WORKDIR} \
--retro-cyclic-train-iters ${RETRO_CYCLIC_TRAIN_ITERS} \
--retro-num-neighbors ${RETRO_NUM_NEIGHBORS} \
"
SCRIPT=pretrain_retro.py
fi
echo "~~~~~~~~~~~~~~~~~~~~~~~~~~"
echo "ARGS = '$ARGS'."
echo "~~~~~~~~~~~~~~~~~~~~~~~~~~"
python -m torch.distributed.launch \
--nproc_per_node ${NPROCS} \
--nnodes 1 \
--node_rank 0 \
--master_addr localhost \
--master_port 6000 \
${SCRIPT} \
${ARGS} \
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import importlib
required_libs = [
"faiss",
"h5py",
"transformers", # for huggingface bert
]
for lib in required_libs:
try:
globals()[lib] = importlib.import_module(lib)
except ImportError as e:
raise Exception(f"Missing one or more packages required for Retro preprocessing: {required_libs}. Tried importing '{lib}'.")
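# The loop above binds each successfully imported library as a module-level
# attribute of this module. That is why other Retro modules can write, for
# example:
#
#   from tools.retro.external_libs import faiss, h5py
#
# and fail early with a single clear error message if a dependency is missing.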
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from .index import Index
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import numpy as np
import os
import shutil
import torch
from tqdm import tqdm
from megatron import get_retro_args, print_rank_0
from tools.bert_embedding import DiskDataParallelBertEmbedder
from tools.retro.db.utils import (
get_indexed_dataset_infos,
get_merged_sampled_dataset,
get_merged_train_dataset,
)
from tools.retro.external_libs import h5py
from tools.retro.index.factory import IndexFactory
from tools.retro.utils import GPTToTextDataset
from .utils import (
get_training_data_dir,
get_training_data_merged,
)
##################################################
# Train index.
##################################################
def get_empty_index_path():
'''Path of empty index.'''
args = get_retro_args()
index = IndexFactory.get_index(args.retro_index_type)
empty_index_path = index.get_empty_index_path()
return empty_index_path
def embed_db():
'''Embed DB chunks.
Store chunks in blocks on disk. These blocks will later be merged into
a single dataset for training the index.
'''
args = get_retro_args()
# Get db dataset.
gpt_dataset = get_merged_sampled_dataset()
text_dataset = GPTToTextDataset(gpt_dataset)
# Embed dataset.
embedder = DiskDataParallelBertEmbedder(args.retro_bert_batch_size,
args.retro_bert_max_chunk_length,
args.retro_block_size,
args.bert_embedder_type)
embedder.embed_text_dataset("index", get_training_data_dir(), text_dataset)
def train_on_embeddings():
'''Train index on embedded DB chunks.'''
args = get_retro_args()
index = IndexFactory.get_index(args.retro_index_type)
index.train(get_training_data_merged)
def remove_embeddings():
'''Remove embeddings after training.'''
torch.distributed.barrier()
if torch.distributed.get_rank() != 0:
return
empty_index_path = get_empty_index_path()
assert os.path.isfile(empty_index_path)
shutil.rmtree(get_training_data_dir(), ignore_errors=True)
def train_index():
'''Train index on DB chunks.'''
args = get_retro_args()
# Check if trained index already exists.
if not os.path.isfile(get_empty_index_path()):
# Embed training chunks.
embed_db()
# Train index on embeddings.
train_on_embeddings()
# Wait for (single-process) training to complete.
torch.distributed.barrier()
# Remove embeddings.
if args.retro_delete_index_training_embeddings:
remove_embeddings()
##################################################
# Add to index.
##################################################
def add_to_index():
'''Add DB chunks to index.'''
args = get_retro_args()
# Get index.
index = IndexFactory.get_index(args.retro_index_type)
# Get text dataset.
gpt_dataset = get_merged_train_dataset()
text_dataset = GPTToTextDataset(gpt_dataset)
# Add to index.
output_index_path = index.add(text_dataset)
return output_index_path
##################################################
# Build index (train + add).
##################################################
def build_index():
'''Build index.
Building index involves sequentially running stages above:
- Train index (on sampled training chunks).
- Add to index (on all training chunks).
'''
# Train index.
train_index()
# Add to index.
add_to_index()
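# Relationship to '--retro-tasks' (see the preprocessing main script later in
# this change): the "index-build" task calls build_index() (train + add), while
# "index-train" and "index-add" call train_index() and add_to_index()
# individually.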
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from .indexes import FaissBaseIndex, FaissParallelAddIndex
class IndexFactory:
'''Get index.
Index type generally read from argument '--retro-index-type'.
'''
@classmethod
def get_index_class(cls, index_type):
return {
"faiss-base" : FaissBaseIndex,
"faiss-par-add" : FaissParallelAddIndex,
}[index_type]
@classmethod
def get_index(cls, index_type):
index_class = cls.get_index_class(index_type)
index = index_class()
return index
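# Hypothetical usage sketch (the index type normally comes from
# '--retro-index-type'; see index/build.py for the real call sites):
#
#   index = IndexFactory.get_index("faiss-par-add")
#   index.train(get_training_data_merged)
#   output_index_path = index.add(text_dataset)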
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import abc
import numpy as np
import os
import torch
from tools.retro.external_libs import faiss
from .utils import get_index_dir
class Index(abc.ABC):
'''Abstract base class for indexes.
*Note* : While currently only Faiss-based classes are implemented, in the
future, this class will be extended with other types of indexes that have
different performance-accuracy trade-offs.
The primary methods to override are:
- train() : Train index on the sampled training chunks.
- add() : Add all training chunks to index.
'''
@classmethod
def c_verbose(cls, index, v):
'''Make index object verbose.'''
assert isinstance(v, bool)
faiss.ParameterSpace().set_index_parameter(index, "verbose", v)
def get_empty_index_path(self):
return os.path.join(get_index_dir(), "empty.faissindex")
def get_empty_index(self):
return faiss.read_index(self.get_empty_index_path())
def get_added_index_path(self):
return os.path.join(get_index_dir(), "added.faissindex")
def get_added_index(self):
return faiss.read_index(self.get_added_index_path())
@abc.abstractmethod
def train(self, *args):
pass
@abc.abstractmethod
def add(self, *args):
pass
def embed_text_dataset_block(self, embedder, text_dataset, _range):
'''Embed a range of a text dataset.'''
sub_dataset = torch.utils.data.Subset(text_dataset, range(*_range))
return embedder.embed_text_dataset(sub_dataset)
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from .faiss_base import FaissBaseIndex
from .faiss_par_add import FaissParallelAddIndex
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
"""
This class implements a simple, un-optimized wrapper around a Faiss index, that
implements the Index interface (see ..index.py). While this class is
instantiable, it is meant to be extended with optimizations in classes that
inherit from this class (see FaissParAddIndex, for an example).
"""
from datetime import timedelta
import os
import torch
from tqdm import tqdm
from megatron import get_retro_args, print_rank_0
from tools.bert_embedding import BertEmbedder
from tools.retro.external_libs import faiss
from tools.retro.index import Index
from tools.retro.index.utils import num_samples_to_block_ranges
class FaissBaseIndex(Index):
def _train(self, input_data_loader):
'''Train index (rank 0's method).'''
args = get_retro_args()
assert torch.distributed.get_rank() == 0
# Set num threads (torch.distributed resets it to 1).
faiss.omp_set_num_threads(64)
empty_index_path = self.get_empty_index_path()
# Index already exists? -> return.
if os.path.isfile(empty_index_path):
return
# Load data.
inp = input_data_loader()
# Init index.
index = faiss.index_factory(args.retro_index_nfeats,
args.retro_index_str)
# Move to GPU.
index_ivf = faiss.extract_index_ivf(index)
clustering_index = \
faiss.index_cpu_to_all_gpus(faiss.IndexFlatL2(index_ivf.d))
index_ivf.clustering_index = clustering_index
self.c_verbose(index, True)
self.c_verbose(index_ivf, True)
self.c_verbose(index_ivf.quantizer, True)
self.c_verbose(index_ivf.clustering_index, True)
# Train index.
index.train(inp)
# Save index.
faiss.write_index(index, empty_index_path)
def train(self, input_data_loader):
'''Train index.'''
# Single process only.
if torch.distributed.get_rank() == 0:
self._train(input_data_loader)
torch.distributed.barrier()
def _add(self, text_dataset):
'''Add to index (rank 0's method).'''
assert torch.distributed.get_rank() == 0
args = get_retro_args()
dataset_sample_ranges = num_samples_to_block_ranges(len(text_dataset))
# Set num threads (torch.distributed resets it to 1).
faiss.omp_set_num_threads(64)
# Bert embedder.
embedder = BertEmbedder(args.retro_bert_batch_size,
args.retro_bert_max_chunk_length,
args.bert_embedder_type)
# Empty/added index paths.
empty_index_path = self.get_empty_index_path()
added_index_path = self.get_added_index_path()
# Skip adding, if index exists.
if os.path.isfile(added_index_path):
return
# Read trained index.
index = faiss.read_index(empty_index_path)
# Iterate data blocks & add.
for sample_range in tqdm(dataset_sample_ranges, "faiss_base.add"):
# Embed text.
embeds = self.embed_text_dataset_block(
embedder, text_dataset, sample_range)
# Add to index.
index.add(embeds)
# Write index.
faiss.write_index(index, added_index_path)
def add(self, text_dataset):
'''Add to index.'''
# Single process only.
if torch.distributed.get_rank() == 0:
self._add(text_dataset)
# Wait for rank 0.
torch.distributed.barrier()
# Get output index path, for return.
return self.get_added_index_path()
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
"""Multi-process & multi-node version of Faiss's index.add().
This class inherits from FaissBaseIndex, and optimizes the 'add()' method by
making it multi-node and multi-process, with bit-wise equivalence to
FaissBaseIndex. This allows 'add()' to scale out to very large datasets, since
the vast majority of the computational effort is embarrassingly parallel.
"""
import numpy as np
import os
import shutil
import torch
from tqdm import tqdm
from megatron import get_retro_args, print_rank_0
from tools.bert_embedding import BertEmbedder
from tools.bert_embedding.utils import get_missing_blocks_by_rank
from tools.retro.external_libs import faiss, h5py
from tools.retro.index.utils import get_added_codes_dir, get_added_code_paths
from .faiss_base import FaissBaseIndex
class FaissParallelAddIndex(FaissBaseIndex):
def encode_block(self, index, embedder, text_dataset, block):
'''Encode sub-dataset block, to be later added to index.
Encode the data subset, generally in blocks of 1M vectors each. For
each block, the empty/trained index is loaded, codes are computed
via index.sa_encode(), and the resulting codes are saved to disk.
'''
args = get_retro_args()
# Embed block.
embeddings = self.embed_text_dataset_block(
embedder,
text_dataset,
block["range"],
)
# Encode block.
print_rank_0("encode.")
codes = index.sa_encode(embeddings)
# Save codes.
print_rank_0("save codes.")
os.makedirs(os.path.dirname(block["path"]), exist_ok=True)
with h5py.File(block["path"], "w") as f:
f.create_dataset("data", data=codes)
def encode(self, text_dataset):
'''Encode text dataset, to be later added to index.'''
args = get_retro_args()
codes_dir = get_added_codes_dir()
# Index.
index = self.get_empty_index()
# Bert embedder.
embedder = BertEmbedder(args.retro_bert_batch_size,
args.retro_bert_max_chunk_length,
args.bert_embedder_type)
# Missing code blocks.
def validate(f):
assert len(f["data"].shape) == 2
n_missing_blocks, missing_code_blocks = get_missing_blocks_by_rank(
codes_dir,
len(text_dataset),
args.retro_block_size,
validate=validate,
)
# Encode each block.
for block_index, block in enumerate(missing_code_blocks):
if block is not None:
# Progress.
print_rank_0("encode block %d / %d ... %s." % (
block_index,
len(missing_code_blocks),
block["path"],
))
# Encode block.
self.encode_block(index, embedder, text_dataset, block)
# Synchronize progress across all ranks. (for easier observation)
print_rank_0(" > waiting for other ranks to finish block.")
torch.distributed.barrier()
def add_codes(self):
if torch.distributed.get_rank() != 0:
return
added_index_path = self.get_added_index_path()
if os.path.exists(added_index_path):
return
# Index.
print_rank_0("read empty index.")
index = self.get_empty_index()
index_ivf = faiss.extract_index_ivf(index)
# Add codes.
print_rank_0("add codes.")
code_paths = get_added_code_paths()
for code_path in tqdm(code_paths, "add codes"):
with h5py.File(code_path) as f:
codes = np.copy(f["data"])
index_ivf.add_sa_codes(codes)
# Update index's ntotal.
index.ntotal = index_ivf.ntotal
# Write index.
print_rank_0("write added index.")
faiss.write_index(index, added_index_path)
def remove_codes(self):
'''Remove added codes after adding to index.'''
if torch.distributed.get_rank() != 0:
return
assert os.path.isfile(self.get_added_index_path())
shutil.rmtree(get_added_codes_dir(), ignore_errors=True)
def add(self, text_dataset):
# Check if index already exists.
if not os.path.isfile(self.get_added_index_path()):
# Encode chunks.
self.encode(text_dataset)
# Add codes to index.
self.add_codes()
# Wait for (single-process) adding to complete.
torch.distributed.barrier()
# Remove codes.
self.remove_codes()
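# Add flow summary, for orientation: every rank encodes its assigned blocks in
# encode() (embarrassingly parallel; codes saved as HDF5), rank 0 then merges
# all saved codes into the index via add_codes(), and remove_codes() cleans up
# the temporary code files once the added index exists on disk.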
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import concurrent
import gc
import glob
import numpy as np
import os
import psutil
import time
import torch
from tqdm import tqdm
from megatron import get_retro_args, print_rank_0
from tools.retro.db.utils import get_indexed_dataset_infos
from tools.retro.external_libs import h5py
def get_index_dir():
"""Create sub-directory for this index."""
args = get_retro_args()
# Directory path.
index_dir_path = os.path.join(
args.retro_workdir,
"index",
args.retro_index_type,
args.retro_index_str,
)
# Make directory.
os.makedirs(index_dir_path, exist_ok=True)
return index_dir_path
def num_samples_to_block_ranges(num_samples):
'''Split a range (length num_samples) into a sequence of block ranges
of size block_size.'''
args = get_retro_args()
block_size = args.retro_block_size
start_idxs = list(range(0, num_samples, block_size))
end_idxs = [min(num_samples, s + block_size) for s in start_idxs]
ranges = list(zip(start_idxs, end_idxs))
return ranges
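# Worked example (values illustrative): with num_samples=250 and
# args.retro_block_size=100, the returned ranges are:
#
#   [(0, 100), (100, 200), (200, 250)]
#
# i.e., half-open [start, end) ranges that cover the dataset, with the final
# block truncated to the dataset length.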
def get_training_data_dir():
return os.path.join(get_index_dir(), "train_tmp")
def get_training_data_paths():
return sorted(glob.glob(get_training_data_dir() + "/*.hdf5"))
def get_added_codes_dir():
return os.path.join(get_index_dir(), "add_tmp")
def get_added_code_paths():
return sorted(glob.glob(get_added_codes_dir() + "/*.hdf5"))
def get_training_data_group_infos():
args = get_retro_args()
block_paths = get_training_data_paths()
max_group_size = args.retro_index_train_block_size
groups = []
group = []
group_size = 0
for block_path in block_paths:
with h5py.File(block_path) as f:
block_size = f["data"].shape[0]
group.append(block_path)
group_size += block_size
if group_size >= max_group_size:
groups.append({
"paths" : group,
"size" : group_size,
})
group = []
group_size = 0
if group:
groups.append({
"paths" : group,
"size" : group_size,
})
return groups
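# Illustrative grouping (sizes are the documented defaults, not guaranteed):
# with retro_index_train_block_size=3_750_000 and embedding blocks of
# retro_block_size=100_000 rows each, block paths accumulate into a group until
# the running size reaches 3_750_000 (about 38 blocks), then a new group
# starts; a trailing partial group is kept. Each group is later loaded and
# copied into the preallocated array in get_training_data_merged(), which
# limits memory fragmentation.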
def load_training_block(path, load_fraction):
with h5py.File(path) as f:
n_load = int(load_fraction * f["data"].shape[0])
return np.copy(f["data"][:n_load])
def load_training_group(executor, group_info, load_fraction):
# Launch threads to load block data.
futures = []
for path in group_info["paths"]:
futures.append(executor.submit(load_training_block, path, load_fraction))
# Collect block data.
block_datas = []
for future in futures:
block_datas.append(future.result())
# Concatenate blocks.
group_data = np.concatenate(block_datas, axis=0)
# Garbage collect.
for d in block_datas:
del d
gc.collect()
return group_data
def get_training_data_merged():
'''Merge embeddings into a single dataset.'''
args = get_retro_args()
# Setup.
ds_infos = get_indexed_dataset_infos()
n_chunks_sampled = sum(d["n_chunks_sampled"] for d in ds_infos)
load_fraction = args.retro_index_train_load_fraction
# Initialize merged data.
print("allocate training data array.")
t = time.time()
data = np.empty((n_chunks_sampled, args.retro_index_nfeats), dtype="f4")
print(" time : %.3f sec." % (time.time() - t))
# Data groups (minimizing fragmentation).
group_infos = get_training_data_group_infos()
# Load data blocks.
n_threads = max(len(group["paths"]) for group in group_infos)
with concurrent.futures.ThreadPoolExecutor(max_workers=n_threads) as executor:
# Load data blocks.
print("load training data blocks.")
start_idx = 0
pbar = tqdm(group_infos)
for group_info in pbar:
pbar.set_description("mem %.0f gb, %.1f%%" % (
psutil.virtual_memory()[3] / 1024**3,
psutil.virtual_memory()[2],
))
# Load group data.
group_data = load_training_group(executor, group_info, load_fraction)
data[start_idx:(start_idx+len(group_data))] = group_data
start_idx += len(group_data)
# Garbage collect.
del group_data
gc.collect()
# Handle load fraction < 1.
data = data[:start_idx]
print("> training block data.shape = %s." % str(data.shape))
return data
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
"""Preprocess data for Retro.
Stages (see argument '--retro-tasks'):
- Build chunk database (DB).
- Build index (train, add).
- Query pretraining neighbors.
"""
import json
import os
import torch
from megatron import get_args, initialize_megatron, print_rank_0
from megatron.global_vars import set_retro_args
from tools.retro.db import build_db
from tools.retro.index.build import add_to_index, build_index, train_index
from tools.retro.pretraining.query import query_pretraining_neighbors
from tools.retro.utils import get_args_path
def add_retro_args(parser):
"""Retro preprocesing arguments.
*Note* : Arguments prefixed with '--retro-gpt-*' or '--retro-bert-*' are
included and named as such to more easily handle managing both models
running at the same time. Megatron is not optimized to run two models at
once, so this naming convention makes it clearer.
"""
group = parser.add_argument_group(title="Retro preprocessing.")
group.add_argument("--retro-gpt-vocab-file", required=True,
help="GPT vocab file.")
group.add_argument("--retro-gpt-merge-file", required=True,
help="GPT merge file.")
group.add_argument("--retro-gpt-tokenizer-type", required=True,
help="GPT tokenizer type.")
group.add_argument("--retro-gpt-seq-length", type=int, default=2048,
help="GPT sequence length.")
group.add_argument("--retro-gpt-chunk-length", type=int, default=64,
help="GPT chunk length.")
group.add_argument("--retro-bert-vocab-file", required=True,
help="Bert vocab file.")
group.add_argument("--retro-bert-tokenizer-type", required=True,
help="Bert tokenizer type (for when using "
"'--bert-embedder-type megatron').")
group.add_argument("--retro-bert-batch-size", type=int, default=128,
help="Micro-batch size for processing Bert embeddings.")
group.add_argument("--retro-bert-max-chunk-length", type=int, default=256,
help="Maximum sequence length for Bert embeddings. "
"(Named 'chunk' here in reference to these Bert "
"sequences being converted from GPT chunks.)")
group.add_argument("--retro-tasks", default="build",
help="Comma-separated list of tasks to run. Run entire "
"preprocesing pipeline by using '--retro-tasks build'. "
"Alternatively, run individual stages with tasks (in "
"this order) 'db-build', 'index-build', or "
"'pretraining-query-neighbors'. For example, "
"'--retro-tasks db-build,index-build,"
"pretraining-query-neighbors' is equivalent to "
"'--retro-tasks build'; or the argument can contain "
"a subset of these tasks. Stages must always be run "
"in the correct order (listed above).")
group.add_argument("--retro-index-nfeats", "-f", type=int, default=1024,
help="Dimension of Bert embeddings. Bert-large is "
"commonly used, so this value defaults to 1024.")
group.add_argument("--retro-index-type", default="faiss-par-add",
choices=["faiss-base", "faiss-par-add"],
help="A 'faiss-base' index is a simple, un-optimized "
"wrapper around a Faiss index. A 'faiss-par-add' index "
"optimizes the 'add()' method by making it multi-node "
"and multi-process, but with bit-wise equivalent "
"results.")
group.add_argument("--retro-index-str", required=True,
help="Index string used for calling "
"faiss.index_factory(). For example, "
"'IVF262144_HNSW32,Flat' or "
"'OPQ32_256,IVF4194304_HNSW32,PQ32'.")
group.add_argument("--retro-ef-search", type=int, default=256,
help="Index ef-search parameter for HNSW during "
"querying.")
group.add_argument("--retro-nprobe", type=int, default=65536,
help="Index nprobe parameter for IVF during "
"querying.")
group.add_argument("--retro-nchunks-sampled", type=int, required=True,
help="Number of database chunks to use for training "
"the index. This value must be less or equal to the "
"total number of chunks in the database.")
group.add_argument("--retro-doc-block-size", type=int, default=100000,
help="Number of documents to processe at time when "
"processing token datasets into chunk databases. The "
"partial chunk database for each block is saved into "
"a separate file.")
group.add_argument("--retro-block-size", type=int, default=100000,
help="Number of chunks to process at a time when "
"generating Bert embeddings and querying the search "
"index. Partial results for each block are generally "
"saved to disk in separate files.")
group.add_argument("--retro-index-train-block-size",
type=int, default=3750000,
help="As a memory fragmentation optimization, when "
"loading training data for training the search index, "
"enough data blocks loaded at a time until they reach "
"retro_index_train_block_size, and then this "
"data block is copied into the full training data "
"array.")
group.add_argument("--retro-index-train-load-fraction",
type=float, default=1.,
help="Fraction of sampled chunks to use for training "
"the index. Useful when our total sampled embeddings "
"use too much memory; lowering the load fraction is "
"less costly than re-embedding a new sampled dataset "
"from scratch.")
group.add_argument("--retro-num-neighbors-query", type=int, default=2000,
help="Number of neighbors to retrieve when calling "
"index.search().")
group.add_argument("--retro-num-neighbors-target", type=int, default=200,
help="Number of neighbors to save to disk after "
"the index's returned neighbors. If longer than target "
"value, neighbors truncated; and if shorter than target "
"value, neighbors are padded with -1's.")
group.add_argument("--retro-no-delete-index-training-embeddings",
action='store_false',
dest="retro_delete_index_training_embeddings",
help="Skip deleting training embeddings for the search "
"index. Useful for debugging.")
# Enforce argument naming convention.
for action in group._group_actions:
prefix = action.dest.split("_")[0]
assert prefix == "retro", \
"Retro args must be prefixed with '--retro-*', for consistent " \
"styling. Please fix '%s'." % ", ".join(action.option_strings)
return parser
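# For illustration only: a minimal invocation might look like the following.
# The script path and all values below are placeholders; '--retro-workdir' and
# the data/tokenizer arguments also depend on Megatron's standard argument
# parser.
#
#   python tools/retro/main.py \
#       --retro-workdir /path/to/retro_workdir \
#       --retro-tasks build \
#       --retro-gpt-vocab-file gpt2-vocab.json \
#       --retro-gpt-merge-file gpt2-merges.txt \
#       --retro-gpt-tokenizer-type GPT2BPETokenizer \
#       --retro-bert-vocab-file bert-vocab.txt \
#       --retro-bert-tokenizer-type BertWordPieceLowerCase \
#       --retro-index-str OPQ32_256,IVF4194304_HNSW32,PQ32 \
#       --retro-nchunks-sampled 300000000 \
#       ...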
def save_args(args):
'''Save copy of args within retro workdir.'''
if torch.distributed.get_rank() == 0:
args_path = get_args_path(args.retro_workdir)
with open(args_path, "w") as f:
json.dump(vars(args), f, indent=4, default=lambda o : "<skipped>")
torch.distributed.barrier()
if __name__ == "__main__":
# Initialize Megatron.
initialize_megatron(extra_args_provider=add_retro_args)
# Split retro tasks.
args = get_args()
args.retro_tasks = args.retro_tasks.split(",")
# Save/set retro args.
os.makedirs(args.retro_workdir, exist_ok=True)
save_args(args)
set_retro_args(args)
# Select task to run.
for task in args.retro_tasks:
print_rank_0("start '%s'." % task)
# Run all stages.
if task == "build":
build_db()
torch.distributed.barrier()
build_index()
torch.distributed.barrier()
query_pretraining_neighbors()
# DB (i.e., chunk db).
elif task == "db-build":
build_db()
# Index.
elif task == "index-build":
build_index() # calls both train + add.
elif task == "index-train":
train_index() # train only
elif task == "index-add":
add_to_index() # add only
# Pretraining.
elif task == "pretraining-query-neighbors":
query_pretraining_neighbors()
else:
raise Exception("specialize for task '%s'." % task)
torch.distributed.barrier()
print_rank_0("end '%s'." % task)
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import os
import torch
from megatron import get_retro_args, print_rank_0
from megatron.data.gpt_dataset import build_train_valid_test_datasets
from megatron.training import (
build_train_valid_test_data_loaders,
update_train_iters,
)
from tools.retro.db.utils import get_indexed_dataset_infos
from tools.retro.utils import get_num_chunks_per_sample
from .utils import get_pretraining_workdir
class ChunkDataset(torch.utils.data.Dataset):
'''Pretraining chunk dataset wraps a standard GPT dataset.
This dataset conceptually divides each sample (e.g., length 2048)
into chunks (e.g., length 64) and restructures them into a list of
chunks (e.g., length num_samples * num_chunks_per_sample).
'''
def __init__(self, sample_dataset, chunk_length):
super().__init__()
self.sample_dataset = sample_dataset
self.chunk_length = chunk_length
self.n_chunks_per_sample = get_num_chunks_per_sample()
self.n_samples = len(sample_dataset)
self.n_chunks = self.n_samples * self.n_chunks_per_sample
def __len__(self):
return self.n_chunks
def __getitem__(self, idx):
# Convert global chunk index to global sample index & local chunk index.
sample_idx = idx // self.n_chunks_per_sample
chunk_idx = idx % self.n_chunks_per_sample
# Extract sample data.
sample = self.sample_dataset[sample_idx]
sample_token_ids = sample["text"]
sample_doc_ids = sample["doc_ids"]
# Chunk start/end token idxs.
token_start_idx = chunk_idx * self.chunk_length
token_end_idx = token_start_idx + self.chunk_length
chunk_token_ids = sample_token_ids[token_start_idx:token_end_idx]
# Sample.
return {
"doc_ids" : sample_doc_ids,
"text" : chunk_token_ids,
}
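# Index arithmetic, for illustration: with the default seq_length=2048 and
# chunk_length=64, n_chunks_per_sample = 2048 // 64 = 32, so global chunk
# index 70 maps to sample_idx = 70 // 32 = 2 and chunk_idx = 70 % 32 = 6,
# i.e., tokens [6*64, 7*64) of sample 2.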
def verify_indexed_dataset_order():
'''Verify pretraining order same as DB order.'''
args = get_retro_args()
# DB dataset prefixes.
db_indexed_dataset_infos = get_indexed_dataset_infos()
db_prefixes = [ info["prefix"] for info in db_indexed_dataset_infos ]
# Verify order & prefixes.
assert len(args.data_path) >= 2, "only blendable datasets are supported."
pretraining_prefixes = args.data_path[1:None:2]
if len(db_prefixes) != len(pretraining_prefixes):
raise Exception("inconsistent dataset count between db & pretraining.")
if db_prefixes != pretraining_prefixes:
raise Exception("inconsistent dataset order between db & pretraining.")
def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build train, valid, and test datasets."""
args = get_retro_args()
print_rank_0('> building train, validation, and test datasets '
'for GPT ...')
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
data_prefix=args.data_path,
data_impl=args.data_impl,
splits_string=args.split,
train_valid_test_num_samples=train_val_test_num_samples,
seq_length=args.retro_gpt_seq_length,
seed=args.seed,
skip_warmup=(not args.mmap_warmup),
return_doc_ids=args.retro_return_doc_ids)
print_rank_0("> finished creating pretrained GPT datasets ...")
return train_ds, valid_ds, test_ds
def get_chunk_dataset_map():
'''Get train, valid, test chunk datasets.'''
args = get_retro_args()
# Update train iters.
update_train_iters(args)
args.iteration = 0
args.consumed_train_samples = 0
# Verify indexed dataset order.
verify_indexed_dataset_order()
# Datasets.
print_rank_0(" > data loader.")
train_data_loader, valid_data_loader, test_data_loader \
= build_train_valid_test_data_loaders(
train_valid_test_datasets_provider)
data_loader_map = {
"train" : train_data_loader,
"valid" : valid_data_loader,
"test" : test_data_loader,
}
# Info dict.
workdir = get_pretraining_workdir()
dataset_map = {
key : {
"neighbor_dir" : os.path.join(
workdir,
os.path.basename(loader.dataset.datasets[0].index_prefix),
),
"data" : ChunkDataset(loader.dataset, args.retro_gpt_chunk_length),
}
for key, loader in data_loader_map.items() if loader
}
return dataset_map
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from collections import defaultdict
import numpy as np
import os
import time
import torch
from tqdm import tqdm
from megatron import get_retro_args, mpu, print_rank_0
from tools.bert_embedding import BertEmbedder
from tools.bert_embedding.utils import get_missing_blocks_by_rank
from tools.retro.db.utils import (
get_merged_train_dataset as get_db_merged_train_dataset,
get_train_doc_chunk_map,
)
from tools.retro.external_libs import faiss, h5py
from tools.retro.index.factory import IndexFactory
from tools.retro.index.utils import get_index_dir, num_samples_to_block_ranges
from tools.retro.utils import GPTToTextDataset
from .chunk_dataset import get_chunk_dataset_map
def get_index(chunk_db_dataset, ondisk=False):
'''Read index from disk.'''
args = get_retro_args()
# Chunk db block ranges.
n_db_chunks = len(chunk_db_dataset)
dataset_block_ranges = num_samples_to_block_ranges(n_db_chunks)
# Load index.
index_wrapper = IndexFactory.get_index(args.retro_index_type)
index_dir = get_index_dir()
added_index_path = index_wrapper.get_added_index_path()
if ondisk:
index = faiss.read_index(added_index_path, faiss.IO_FLAG_MMAP)
else:
index = faiss.read_index(added_index_path)
# Search parameters.
faiss.ParameterSpace().set_index_parameter(index, "efSearch",
args.retro_ef_search)
faiss.ParameterSpace().set_index_parameter(index, "nprobe",
args.retro_nprobe)
return index
def embed_block(gpt_dataset, block, embedder):
'''Embed block of chunks.'''
text_block_dataset = torch.utils.data.Subset(
GPTToTextDataset(gpt_dataset),
range(*block["range"]),
)
return embedder.embed_text_dataset(text_block_dataset)
def query_embeddings(index, banned_chunk_map, chunk_id_range,
embeddings, sample_map, n_chunks_per_sample,
verbose=True):
'''Query neighbors of a block of embeddings.'''
args = get_retro_args()
# Query neighbor ids.
if verbose: print_rank_0("search.")
t = time.time()
assert index.ntotal > 0, "check we don't accidentally have an empty index."
_, query_neighbor_ids = \
index.search(embeddings, args.retro_num_neighbors_query)
if verbose: print_rank_0(" time : %.3f sec." % (time.time() - t))
# Banned neighbor ids.
if verbose: print_rank_0("get banned neighbor ids.")
sample_banned_chunk_id_map = {}
for sample_id, sample in sample_map.items():
dataset_idx = sample["dataset_idx"].item()
doc_ids = sample["doc_ids"].tolist()
banned_chunk_ids = set()
for doc_id in doc_ids:
banned_chunk_ids.update(banned_chunk_map[(dataset_idx, doc_id)])
sample_banned_chunk_id_map[sample_id] = banned_chunk_ids
# Filter banned neighbor ids.
if verbose: print_rank_0("filter banned neighbor ids.")
filtered_neighbor_ids = np.full(
shape=(len(query_neighbor_ids), args.retro_num_neighbors_target),
fill_value=-1,
dtype="int64",
)
min_chunk_id, max_chunk_id = chunk_id_range
for chunk_id in range(min_chunk_id, max_chunk_id):
sample_id = chunk_id // n_chunks_per_sample
# Get valid neighbors (!= -1).
query_row = [ i for i in query_neighbor_ids[chunk_id-min_chunk_id]
if i >= 0 ]
# Filter row.
filtered_row = [i for i in query_row
if i not in sample_banned_chunk_id_map[sample_id]]
filtered_row = filtered_row[:args.retro_num_neighbors_target]
filtered_row += \
[-1] * (args.retro_num_neighbors_target - len(filtered_row))
filtered_neighbor_ids[chunk_id-min_chunk_id] = filtered_row
return query_neighbor_ids, filtered_neighbor_ids
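# Filtering example, for illustration: with retro_num_neighbors_target=3, a
# query row of [12, 7, 9, 4] and banned chunk ids {7} for the owning sample
# yields [12, 9, 4]; a row of [12, 7] with the same ban yields [12, -1, -1]
# (filter, truncate to the target length, then pad with -1's).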
def query_embedding_block(index, banned_chunk_map, chunk_id_range,
embeddings, sample_map, n_chunks_per_sample):
query_neighbor_ids = []
filtered_neighbor_ids = []
# Query in sub-blocks.
partial_block_size = 1000
for partial_start_idx in tqdm(
range(0, len(embeddings), partial_block_size),
"search",
):
partial_end_idx = min(len(embeddings),
partial_start_idx + partial_block_size)
partial_embeddings = embeddings[partial_start_idx:partial_end_idx]
partial_chunk_id_range = (
chunk_id_range[0] + partial_start_idx,
chunk_id_range[0] + partial_end_idx,
)
partial_query_neighbor_ids, partial_filtered_neighbor_ids = \
query_embeddings(index, banned_chunk_map, partial_chunk_id_range,
partial_embeddings, sample_map, n_chunks_per_sample,
verbose=False)
query_neighbor_ids.append(partial_query_neighbor_ids)
filtered_neighbor_ids.append(partial_filtered_neighbor_ids)
# Concatenate.
query_neighbor_ids = np.concatenate(query_neighbor_ids, axis=0)
filtered_neighbor_ids = np.concatenate(filtered_neighbor_ids, axis=0)
return query_neighbor_ids, filtered_neighbor_ids
def query_block_neighbors(index, banned_chunk_map, chunk_dataset,
block, embedder):
'''Query neighbors of a dataset block (i.e., range).'''
args = get_retro_args()
n_chunks_per_sample = chunk_dataset.n_chunks_per_sample
# Sample map.
sample_ids = sorted(list(set(chunk_id // n_chunks_per_sample
for chunk_id in range(*block["range"]))))
sample_map = {i:chunk_dataset.sample_dataset[i] for i in sample_ids}
# Embed block.
embeddings = embed_block(chunk_dataset, block, embedder)
# Query embeddings.
_, filtered_neighbor_ids = query_embedding_block(
index, banned_chunk_map, block["range"],
embeddings, sample_map,
n_chunks_per_sample)
# Save neighbors.
print_rank_0("save neighbors.")
os.makedirs(os.path.dirname(block["path"]), exist_ok=True)
f = h5py.File(block["path"], "w")
f.create_dataset("neighbors", data=filtered_neighbor_ids)
f.close()
def query_dataset_neighbors(index, banned_chunk_map,
prefix, chunk_dataset, neighbor_dir,
embedder):
'''Query neighbors of each chunk within a dataset.'''
args = get_retro_args()
def validate(f):
assert f["neighbors"].shape[1] == args.retro_num_neighbors_target, \
"neighbors.shape == %s; num_neighbors_target == %d." % (
str(f["neighbors"].shape),
args.retro_num_neighbors_target,
)
n_missing_blocks, missing_neighbor_blocks = get_missing_blocks_by_rank(
neighbor_dir,
len(chunk_dataset),
args.retro_block_size,
validate=validate,
)
# Query each block.
for block_index, block in enumerate(missing_neighbor_blocks):
if block is not None:
# Progress.
print_rank_0("query '%s' block %d / %d ... %s." % (
prefix,
block_index,
len(missing_neighbor_blocks),
block["path"],
))
# Query block neighbors.
query_block_neighbors(index, banned_chunk_map,
chunk_dataset, block, embedder)
# Synchronize progress across all ranks. (for easier observation)
print_rank_0(" > waiting for other ranks to finish block.")
torch.distributed.barrier()
def query_pretraining_neighbors():
'''Query pretraining datasets (train & valid).'''
args = get_retro_args()
# Num threads.
faiss.omp_set_num_threads(64)
# Load chunk db dataset.
print_rank_0("load chunk db dataset.")
chunk_db_dataset = get_db_merged_train_dataset()
# Load index, banned chunk ids, datasets.
print_rank_0(" > get index.")
index = get_index(chunk_db_dataset)
print_rank_0(" > get banned doc-chunk id map.")
banned_chunk_map = get_train_doc_chunk_map()
print_rank_0(" > get dataset map.")
chunk_dataset_map = get_chunk_dataset_map()
# Bert embedder.
embedder = BertEmbedder(args.retro_bert_batch_size,
args.retro_bert_max_chunk_length,
args.bert_embedder_type)
# Query each (i.e., train, valid, test) dataset.
print_rank_0(" > query.")
for prefix, info in chunk_dataset_map.items():
print_rank_0(" > query '%s' dataset ... %d samples." %
(prefix, len(info["data"])))
query_dataset_neighbors(index, banned_chunk_map,
prefix, info["data"], info["neighbor_dir"],
embedder)
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import numpy as np
import os
import torch
from megatron import get_args, get_retro_args
from tools.bert_embedding.utils import get_index_path_map
from tools.retro.db.utils import get_merged_train_dataset as get_db_dataset
from tools.retro.external_libs import h5py
from .chunk_dataset import get_chunk_dataset_map
class RetroDataset(torch.utils.data.Dataset):
'''Dataset of retro samples.
Each sample contains the original GPT sample, along with the token IDs
of each neighbor of each chunk within the sequence. Neighbor array has
shape (num_chunks_per_sample, num_neighbors, num_retrieved_tokens).
'''
def __init__(self,
num_neighbors,
num_retrieved_chunks,
block_size,
db_dataset,
chunk_dataset,
neighbor_path_map):
'''Note: chunk dataset wraps original GPT dataset (see
chunk_dataset.py).'''
super().__init__()
self.num_neighbors = num_neighbors
self.num_retrieved_chunks = num_retrieved_chunks
self.block_size = block_size
self.db_dataset = db_dataset
self.chunk_dataset = chunk_dataset
self.neighbor_path_map = neighbor_path_map
def __len__(self):
return len(self.chunk_dataset.sample_dataset)
def __getitem__(self, sample_idx):
n_chunks_per_sample = self.chunk_dataset.n_chunks_per_sample
# Get standard sample.
sample = self.chunk_dataset.sample_dataset[sample_idx]
# Sample idx to chunk idxs.
chunk_idxs = list(range(
sample_idx * n_chunks_per_sample,
(sample_idx + 1) * n_chunks_per_sample,
))
# Collect retrieved tokens.
all_retrieved_chunk_ids = []
all_retrieved_token_ids = []
for chunk_idx in chunk_idxs:
# Neighbor chunk ids.
neighbor_path = self.neighbor_path_map[chunk_idx]
with h5py.File(neighbor_path, "r") as f:
neighbor_chunk_ids = f["neighbors"] \
[chunk_idx % self.block_size, :self.num_neighbors].tolist()
# Retrieved (neighbor + continuation) token ids.
retrieved_chunk_ids = []
retrieved_token_ids = []
for neighbor_chunk_id in neighbor_chunk_ids:
current_chunk_ids = [
i % len(self.db_dataset)
for i in range(
neighbor_chunk_id,
neighbor_chunk_id + self.num_retrieved_chunks)]
current_token_ids = [self.db_dataset[ci]["text"]
for ci in current_chunk_ids]
retrieved_chunk_ids.append(current_chunk_ids)
retrieved_token_ids.append(current_token_ids)
# Collect retrieved tokens.
all_retrieved_chunk_ids.append(retrieved_chunk_ids)
all_retrieved_token_ids.append(retrieved_token_ids)
# Reshape retrieved tokens.
all_retrieved_chunk_ids = np.array(all_retrieved_chunk_ids) \
.reshape((n_chunks_per_sample, self.num_neighbors, -1))
all_retrieved_token_ids = np.array(all_retrieved_token_ids) \
.reshape((n_chunks_per_sample, self.num_neighbors, -1))
# Sample.
sample = {
**sample,
"neighbor_chunks" : all_retrieved_chunk_ids,
"neighbor_tokens" : all_retrieved_token_ids,
}
return sample
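# Shape example, for illustration: with 32 chunks per sample, num_neighbors=2,
# num_retrieved_chunks=2 (each neighbor plus its continuation chunk), and
# chunk_length=64, "neighbor_chunks" reshapes to (32, 2, 2) and
# "neighbor_tokens" to (32, 2, 128), matching the class docstring's
# (num_chunks_per_sample, num_neighbors, num_retrieved_tokens) layout.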
def get_retro_datasets():
'''Get train, valid, test retro datasets.'''
args = get_args()
retro_args = get_retro_args()
# DB dataset.
db_dataset = get_db_dataset()
# Retro datasets.
chunk_ds_info_map = get_chunk_dataset_map()
retro_dataset_map = {}
for data_key, chunk_ds_info in chunk_ds_info_map.items():
chunk_dataset = chunk_ds_info["data"]
neighbor_dir = chunk_ds_info["neighbor_dir"]
neighbor_path_map = get_index_path_map(neighbor_dir)
# Verify dataset prefixes.
sample_prefix = chunk_dataset.sample_dataset.datasets[0].index_prefix
neighbor_prefix = os.path.basename(neighbor_dir)
assert sample_prefix == neighbor_prefix, \
"inconsistent dataset source; '%s' vs. '%s'." % \
(sample_prefix, neighbor_prefix)
# Verify num chunks.
n_sample_chunks = len(chunk_dataset)
n_neighbor_chunks = len(neighbor_path_map.id_index_map)
if n_sample_chunks != n_neighbor_chunks:
print("neighbor_dir : %s" % neighbor_dir)
print("neighbor_path_map : %s" % neighbor_path_map)
raise Exception("num sampled chunks (%d) != num neighbor chunks (%d)"
% (n_sample_chunks, n_neighbor_chunks))
# Retro dataset.
retro_dataset_map[data_key] = RetroDataset(
num_neighbors=args.retro_num_neighbors,
num_retrieved_chunks=args.retro_num_retrieved_chunks,
block_size=retro_args.retro_block_size,
db_dataset=db_dataset,
chunk_dataset=chunk_dataset,
neighbor_path_map=neighbor_path_map,
)
# Extract datasets.
train_ds = retro_dataset_map.get("train", None)
valid_ds = retro_dataset_map.get("valid", None)
test_ds = retro_dataset_map.get("test", None)
return train_ds, valid_ds, test_ds
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import os
from megatron import get_retro_args
def get_pretraining_workdir():
args = get_retro_args()
return os.path.join(args.retro_workdir, "pretraining")
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import os
import torch
import types
from megatron import get_retro_args
from megatron.tokenizer.tokenizer import (
_BertWordPieceTokenizer,
_GPT2BPETokenizer,
)
def get_args_path(workdir):
'''Path to the argument copy stored within the retro workdir.'''
return os.path.join(workdir, "args.json")
def get_num_chunks_per_sample():
'''Compute seq_length // chunk_length.'''
args = get_retro_args()
sample_length = args.retro_gpt_seq_length
chunk_length = args.retro_gpt_chunk_length
assert sample_length % chunk_length == 0
return sample_length // chunk_length
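# Example: with the default retro_gpt_seq_length=2048 and
# retro_gpt_chunk_length=64, each sample contains 2048 // 64 = 32 chunks.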
def get_gpt_tokenizer():
'''GPT (BPE) tokenizer.'''
args = get_retro_args()
return _GPT2BPETokenizer(
vocab_file=args.retro_gpt_vocab_file,
merge_file=args.retro_gpt_merge_file,
)
def get_bert_tokenizer():
'''Bert (Wordpiece) tokenizer.'''
args = get_retro_args()
lower_case = {
"BertWordPieceLowerCase" : True,
"BertWordPieceCase" : False,
}[args.retro_bert_tokenizer_type]
return _BertWordPieceTokenizer(
vocab_file=args.retro_bert_vocab_file,
lower_case=lower_case,
)
class GPTToTextDataset(torch.utils.data.Dataset):
'''Dataset to convert GPT tokens to text.'''
def __init__(self, gpt_dataset):
super().__init__()
self.gpt_dataset = gpt_dataset
self.gpt_tokenizer = get_gpt_tokenizer()
def __len__(self):
return len(self.gpt_dataset)
def __getitem__(self, idx):
gpt_token_ids = self.gpt_dataset[idx]["text"].tolist()
text = self.gpt_tokenizer.detokenize(gpt_token_ids)
return {"text": text}