Commit 762072e0 authored by Jared Casper

Merge branch 'lmcafee/retro' into 'main'

Retro

See merge request ADLR/megatron-lm!489
parents ef59b687 17a6044d
......@@ -122,7 +122,7 @@ def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
index = 0
with open(vocab_file, "r") as reader:
with open(vocab_file, "r", encoding = "utf-8") as reader:
while True:
token = convert_to_unicode(reader.readline())
if not token:
......
......@@ -143,6 +143,11 @@ def pretrain(train_valid_test_dataset_provider,
print_rank_0('training ...')
iteration = 0
if args.dataloader_type == 'cyclic' and args.retro_add_retriever:
args.train_iters = args.retro_cyclic_train_iters
print_rank_0("retro cyclic train iters : %d" % args.train_iters)
if args.do_train and args.train_iters > 0:
iteration = train(forward_step_func,
model, optimizer, opt_param_scheduler,
......@@ -751,7 +756,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
# Exiting based on iterations
if args.exit_interval and iteration % args.exit_interval == 0:
if not saved_checkpoint:
if args.save and not saved_checkpoint:
save_checkpoint_and_time(iteration, model, optimizer,
opt_param_scheduler)
torch.distributed.barrier()
......@@ -864,7 +869,8 @@ def cyclic_iter(iter):
for x in iter:
yield x
def build_train_valid_test_data_iterators(
def build_train_valid_test_data_loaders(
build_train_valid_test_datasets_provider):
"""XXX"""
args = get_args()
......@@ -931,6 +937,19 @@ def build_train_valid_test_data_iterators(
args.do_valid = flags[1].item()
args.do_test = flags[2].item()
return train_dataloader, valid_dataloader, test_dataloader
def build_train_valid_test_data_iterators(
build_train_valid_test_datasets_provider):
args = get_args()
# Build loaders.
train_dataloader, valid_dataloader, test_dataloader = \
build_train_valid_test_data_loaders(
build_train_valid_test_datasets_provider)
# Build iterators.
dl_type = args.dataloader_type
assert dl_type in ['single', 'cyclic']
......
......@@ -103,7 +103,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
skip_warmup=(not args.mmap_warmup),
train_data_prefix=args.train_data_path,
valid_data_prefix=args.valid_data_path,
test_data_prefix=args.test_data_path,)
test_data_prefix=args.test_data_path)
print_rank_0("> finished creating GPT datasets ...")
return train_ds, valid_ds, test_ds
......
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
"""Pretrain Retro."""
from functools import partial
import torch
from megatron import get_args, get_retro_args
from megatron import get_timers
from megatron import get_tokenizer
from megatron import print_rank_0
from megatron.core import mpu, tensor_parallel
from megatron.model import GPTModel, ModelType
from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids
from tools.retro.pretraining.retro_dataset import get_retro_datasets
from pretrain_gpt import (
loss_func,
model_provider,
train_valid_test_datasets_provider as standard_datasets_provider,
)
def get_batch(data_iterator):
"""Generate a batch"""
args = get_args()
retro_args = get_retro_args()
tokenizer = get_tokenizer()
# Items and their type.
keys = ['text']
datatype = torch.int64
if args.retro_add_retriever:
keys += 'neighbor_tokens',
# Broadcast data.
if data_iterator is not None:
data = next(data_iterator)
else:
data = None
data_b = tensor_parallel.broadcast_data(keys, data, datatype)
# Unpack.
tokens_ = data_b['text'].long()
labels = tokens_[:, 1:].contiguous()
tokens = tokens_[:, :-1].contiguous()
if args.retro_add_retriever:
# note: [bs * l * k, r]
# note: 2x == neighbor, continuation
neighbor_tokens = data_b['neighbor_tokens'] \
.view(-1, retro_args.retro_gpt_retrieved_length).long()
    # Get the masks and position ids.
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
tokens,
tokenizer.eod,
args.reset_position_ids,
args.reset_attention_mask,
args.eod_mask_loss)
if args.retro_add_retriever:
_, _, neighbor_position_ids = get_ltor_masks_and_position_ids(
neighbor_tokens,
tokenizer.eod,
args.reset_position_ids,
args.reset_attention_mask,
args.eod_mask_loss)
neighbor_attention_mask = None
return tokens, labels, loss_mask, attention_mask, position_ids, \
neighbor_tokens, neighbor_attention_mask, neighbor_position_ids
else:
return tokens, labels, loss_mask, attention_mask, position_ids
def forward_step(data_iterator, model):
"""Forward step."""
args = get_args()
timers = get_timers()
# Get the batch.
timers('batch-generator').start()
if args.retro_add_retriever:
tokens, labels, loss_mask, attention_mask, position_ids, \
neighbor_tokens, neighbor_attention_mask, neighbor_position_ids = \
get_batch(data_iterator)
else:
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
data_iterator)
neighbor_tokens, neighbor_attention_mask, neighbor_position_ids = \
None, None, None
timers('batch-generator').stop()
output_tensor = model(tokens, position_ids, attention_mask,
ret_input_ids=neighbor_tokens,
ret_position_ids=neighbor_position_ids,
ret_attn_mask=neighbor_attention_mask,
labels=labels)
return output_tensor, partial(loss_func, loss_mask)
def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build train, valid, and test datasets."""
args = get_args()
if args.retro_add_retriever:
return get_retro_datasets()
else:
return standard_datasets_provider(train_val_test_num_samples)
if __name__ == "__main__":
pretrain(train_valid_test_datasets_provider, model_provider,
ModelType.encoder_or_decoder,
forward_step, args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from .embed import BertEmbedder, DiskDataParallelBertEmbedder
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import numpy as np
import torch
from megatron import get_args, get_tokenizer
from megatron.data.bert_dataset import build_training_sample
class BertEmbeddingDataset(torch.utils.data.Dataset):
'''Dataset to convert a text dataset to Bert tokens.'''
def __init__(self, text_dataset, max_seq_length):
super().__init__()
args = get_args()
# Dataset, tokenizer.
self.text_dataset = text_dataset
self.bert_tokenizer = get_tokenizer()
# Params to store.
self.max_seq_length = max_seq_length
self.seed = args.seed
self.masked_lm_prob = args.mask_prob
# Vocab stuff.
self.vocab_id_list = list(self.bert_tokenizer.inv_vocab.keys())
self.vocab_id_to_token_dict = self.bert_tokenizer.inv_vocab
self.cls_id = self.bert_tokenizer.cls
self.sep_id = self.bert_tokenizer.sep
self.mask_id = self.bert_tokenizer.mask
self.pad_id = self.bert_tokenizer.pad
def __len__(self):
return len(self.text_dataset)
def __getitem__(self, idx):
# Text.
text_sample = self.text_dataset[idx]
text = text_sample["text"]
text = text.replace("<|endoftext|>", "")
# Bert/Wordpiece tokens (+truncate).
bert_token_ids = self.bert_tokenizer.tokenize(text)
bert_token_ids = bert_token_ids[:self.max_seq_length - 2] # cls+sep.
if not bert_token_ids:
bert_token_ids = [ self.bert_tokenizer.pad_id ] # hack when empty seq
# Note that this rng state should be numpy and not python since
# python randint is inclusive whereas the numpy one is exclusive.
        # We % 2**32 since numpy requires the seed to be between 0 and 2**32 - 1
np_rng = np.random.RandomState(seed=((self.seed + idx) % 2**32))
# Build sample.
sample = build_training_sample([bert_token_ids],
len(bert_token_ids),
len(bert_token_ids) + 2, # for cls+sep
self.vocab_id_list,
self.vocab_id_to_token_dict,
self.cls_id, self.sep_id,
self.mask_id, self.pad_id,
self.masked_lm_prob, np_rng,
binary_head=False)
sample["seq_length"] = len(sample["text"])
return sample
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from functools import partial
import numpy as np
import os
import time
import torch
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler, Subset
from torch.utils.data._utils.collate import default_collate
from tqdm import tqdm
from megatron import get_args, get_tokenizer, print_rank_0
from megatron import core
from megatron.model import BertModel, ModelType
from megatron.schedules import get_forward_backward_func
from megatron.training import setup_model_and_optimizer
from .dataset import BertEmbeddingDataset
from .external_libs import h5py
from .huggingface import HuggingfaceEmbedder
from .utils import get_missing_blocks_by_rank
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
print_rank_0(" > build Bert model.")
args = get_args()
num_tokentypes = 2 if args.bert_binary_head else 0
model = BertModel(
num_tokentypes=num_tokentypes,
add_binary_head=args.bert_binary_head,
parallel_output=True,
pre_process=pre_process,
post_process=post_process)
return model
def get_batch(data_iterator):
"""Build the batch."""
# Items and their type.
keys = ['text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask',
'seq_length']
datatype = torch.int64
# Broadcast data.
if data_iterator is not None:
data = next(data_iterator)
else:
data = None
data_b = core.tensor_parallel.broadcast_data(keys, data, datatype)
# Unpack.
tokens = data_b['text'].long()
types = data_b['types'].long()
sentence_order = data_b['is_random'].long()
loss_mask = data_b['loss_mask'].float()
lm_labels = data_b['labels'].long()
padding_mask = data_b['padding_mask'].long()
seq_lengths = data_b['seq_length'].long()
return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask, \
seq_lengths
def loss_func(loss_mask, sentence_order, seq_lengths,
output_tensor, non_loss_data):
"""Loss function. Sequence lengths returned here for progress print-outs."""
assert non_loss_data
return seq_lengths, output_tensor
def forward_step(data_iterator, model):
"""Forward step."""
args = get_args()
# Get the batch.
tokens, types, sentence_order, loss_mask, lm_labels, padding_mask, \
seq_lengths = get_batch(data_iterator)
if not args.bert_binary_head:
types = None
# Forward pass through the model.
output_tensor = model(tokens, padding_mask, tokentype_ids=types,
lm_labels=lm_labels)
return output_tensor, partial(loss_func, loss_mask, sentence_order,
seq_lengths)
def collate_batch(samples):
"""Collate samples of various lengths.
This collate function handles samples with various sequence lengths, by
padding 'text' arrays with pad_id, and other arrays with 0.
"""
n_samples = len(samples)
keys = list(samples[0].keys())
tokenizer = get_tokenizer()
# Max sample length across all samples.
max_length_map = { key:0 for key in keys }
for sample in samples:
for key in keys:
value_length = \
len(sample[key]) if isinstance(sample[key], np.ndarray) else None
max_length_map[key] = None \
if value_length is None else \
max(max_length_map[key], value_length)
# Pad samples.
padded_samples = []
for sample in samples:
padded_sample = {}
for key in keys:
padded_sample[key] = \
np.pad(
sample[key],
(0, max_length_map[key] - len(sample[key])),
mode="constant",
constant_values=tokenizer.pad_id if key == "text" else 0,
) \
if isinstance(sample[key], np.ndarray) else \
sample[key]
padded_samples.append(padded_sample)
# Build batch with padded samples.
batch = default_collate(padded_samples)
return batch
def get_data_loader(dataset, batch_size):
"""Build data loader over data subset.
Get a subset of the dataset (from start_idx -> end_idx), and wrap it in
a sequential sampler and data loader.
"""
args = get_args()
# Sequential & batch samplers.
batch_sampler = BatchSampler(
sampler=SequentialSampler(dataset),
batch_size=batch_size,
drop_last=False,
)
# Data loader.
data_loader = DataLoader(dataset,
batch_sampler=batch_sampler,
num_workers=args.num_workers,
pin_memory=True,
collate_fn=collate_batch)
return data_loader
def embed_data_loader(models, data_loader):
'''Iterate data loader and compute embeddings.'''
# Verify no model parallelism.
args = get_args()
assert args.tensor_model_parallel_size == 1 and \
args.pipeline_model_parallel_size == 1, \
"since we call forward_step directly, only tp == pp == 1 allowed."
# Data iterator.
data_iterator = iter(data_loader)
# Eval mode.
for m in models:
m.eval()
# Embed.
embeddings = []
for _ in tqdm(range(len(data_loader)), "mt embed"):
with torch.no_grad():
result = forward_step(data_iterator, models[0])
embeddings.append(result[0].detach().cpu().numpy())
# Concatenate embeddings.
embeddings = np.concatenate(embeddings, axis=0)
return embeddings
class BertEmbedder:
'''Compute Bert embeddings, from a text dataset.'''
def __init__(self, batch_size, max_bert_seq_length, embedder_type):
args = get_args()
assert args.output_bert_embeddings
self.models, optimizer, opt_param_scheduler = \
setup_model_and_optimizer(model_provider,
ModelType.encoder_or_decoder)
self.batch_size = batch_size
self.max_bert_seq_length = max_bert_seq_length
# Init Huggingface, if in use.
if embedder_type == "megatron":
self.huggingface_embedder = None
elif embedder_type == "huggingface":
self.huggingface_embedder = HuggingfaceEmbedder(batch_size,
max_bert_seq_length)
else:
raise Exception("specialize for embedder type '%s'." % embedder_type)
def embed_text_dataset(self, text_dataset):
'''Embed a text dataset.'''
# Huggingface.
if self.huggingface_embedder:
return self.huggingface_embedder.embed_text_dataset(text_dataset)
# Wrap in a BertEmbeddingDataset to tokenize samples.
bert_dataset = BertEmbeddingDataset(text_dataset,
self.max_bert_seq_length)
# Embed.
data_loader = get_data_loader(bert_dataset, self.batch_size)
embeddings = embed_data_loader(self.models, data_loader)
return embeddings
def embed_text(self, text):
'''Embed a single text string.
Primarily used for on-the-fly embeddings, particularly during
analysis or debugging. For large scale, use 'embed_text_dataset()'.
'''
class SingleTextDataset(torch.utils.data.Dataset):
'''Dataset that holds single string.'''
def __init__(self, text):
assert isinstance(text, str)
self.text = text
def __len__(self):
return 1
def __getitem__(self, i):
return {"text": self.text}
# Embed text.
text_ds = SingleTextDataset(text)
embed = self.embed_text_dataset(text_ds)[0]
return embed
class DiskDataParallelBertEmbedder:
'''Process embeddings in blocks & save to disk.'''
def __init__(self, batch_size, max_bert_seq_length, block_size,
embedder_type):
self.embedder = BertEmbedder(batch_size, max_bert_seq_length,
embedder_type)
self.block_size = block_size
def embed_text_blocks(self, name, workdir, text_dataset,
missing_embedding_blocks):
'''Process a text dataset in blocks.'''
# Iterate blocks.
for block_index, block_info in enumerate(missing_embedding_blocks):
# Missing block lists are extended with None to have equal-length
# lists. Skip the Nones.
if block_info is not None:
# Progress. (*note*: move world progress to here.)
print_rank_0("embed '%s' block %d / %d ... %s." % (
name,
block_index,
len(missing_embedding_blocks),
block_info["path"],
))
# Embed block.
sub_dataset = Subset(text_dataset, range(*block_info["range"]))
embeddings = self.embedder.embed_text_dataset(sub_dataset)
# Save embeddings.
f = h5py.File(block_info["path"], "w")
f.create_dataset("data", data=embeddings)
f.close()
# Synchronize progress across all ranks. (for easier observation)
print_rank_0(" > waiting for other ranks to finish block.")
torch.distributed.barrier()
def embed_text_dataset(self, name, workdir, text_dataset):
'''Embed a text dataset.'''
# Dataset workdir.
os.makedirs(workdir, exist_ok=True)
# Missing embedding blocks (stored on disk).
def validate(f):
assert f["data"].shape[1] == 1024
n_missing_world, missing_embedding_blocks = get_missing_blocks_by_rank(
workdir,
len(text_dataset),
self.block_size,
validate=validate)
# Prevent missing file race condition.
torch.distributed.barrier()
# Embed batches.
self.embed_text_blocks(name, workdir, text_dataset,
missing_embedding_blocks)
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import importlib
required_libs = [
"h5py",
"transformers", # for huggingface bert
]
for lib in required_libs:
try:
globals()[lib] = importlib.import_module(lib)
except ImportError as e:
raise Exception(f"Missing one or more packages required for Bert embedding: {required_libs}. Tried importing '{lib}'.")
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import numpy as np
import torch
from tqdm import tqdm
from .external_libs import transformers
class IterableTextDataset(torch.utils.data.IterableDataset):
'''Iterable over a text dataset.'''
def __init__(self, text_dataset):
self.text_dataset = text_dataset
def __iter__(self):
'''Remove 'endoftext' string.'''
for sample_idx in range(len(self.text_dataset)):
sample = self.text_dataset[sample_idx]
text = sample["text"].replace("<|endoftext|>", "")
yield text
class MyFeatureExtractionPipeline(transformers.FeatureExtractionPipeline):
def _forward(self, model_inputs):
# Embed inputs.
model_outputs = self.model(**model_inputs)
# Attention mask.
embeddings = model_outputs[0]
masks = torch.sum(model_inputs['attention_mask'], dim=1)
# Collect embeddings & check for nan.
outputs = []
for embedding, mask in zip(embeddings, masks):
output = torch.mean(embedding[1: mask - 1], dim=0)
# Nans due to empty input sequences; so only check first element.
if torch.isnan(output.view(-1)[0]).any():
output.zero_()
outputs.append(output)
# Sample.
data = {
"input" : model_inputs["input_ids"],
"output" : outputs,
}
return data
def postprocess(self, model_outputs):
# Return input for analysis.
return {
"input" : model_outputs["input"].numpy(),
"output" : model_outputs["output"].numpy(),
}
class HuggingfaceEmbedder:
def __init__(self, batch_size, max_seq_length):
# Model, tokenizer.
self.model = transformers.BertModel.from_pretrained("bert-large-cased")
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
"bert-large-cased", model_max_length=max_seq_length)
# Feature extraction pipeline.
self.pipe = MyFeatureExtractionPipeline(
model=self.model,
tokenizer=self.tokenizer,
device=torch.cuda.current_device(),
truncation=True,
max_length=max_seq_length,
)
self.batch_size = batch_size
def embed_text_dataset(self, text_dataset, verbose=True):
# Wrap dataset in iterable.
dataset = IterableTextDataset(text_dataset)
# Allocate output array.
n_samples = len(text_dataset)
embeddings = np.zeros((n_samples, 1024), dtype="f4")
start_idx = 0
# Wrap iterator in tqdm for verbose output.
_iter = self.pipe(dataset, batch_size=self.batch_size)
if verbose:
_iter = tqdm(_iter, "hf embed", total=n_samples)
# Embed dataset.
for idx, out_dict in enumerate(_iter):
inp = out_dict["input"]
out = out_dict["output"]
embeddings[start_idx] = out
start_idx += 1
return embeddings
def embed_text(self, text):
'''Embed a single text string.
Primarily used for on-the-fly embeddings, particularly during
analysis or debugging. For large scale, use 'embed_text_dataset()'.
'''
class SingleTextDataset(torch.utils.data.Dataset):
'''Dataset that holds single string.'''
def __init__(self, text):
assert isinstance(text, str)
self.text = text
def __len__(self):
return 1
def __getitem__(self, i):
return {"text": self.text}
# Embed text.
text_ds = SingleTextDataset(text)
embed = self.embed_text_dataset(text_ds, verbose=False)[0]
return embed
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from collections import defaultdict
import glob
import numpy as np
import os
import torch
from tqdm import tqdm
from megatron import print_rank_0
from megatron.core import parallel_state
from .external_libs import h5py
def save_data(data_map, *args):
'''Save map of numpy arrays to hdf5 file.'''
# Parse args.
if len(args) == 1:
path = args[0]
elif len(args) == 2:
dir_path, file_name = args
path = os.path.join(dir_path, file_name)
else:
raise Exception("specialize for len(args) == %d." % len(args))
# Save data.
if not os.path.isfile(path):
f = h5py.File(path, "w")
for k, v in data_map.items():
f.create_dataset(k, data=v)
f.close()
return path
def load_data(paths):
'''Load multiple hdf5 files to single numpy array.'''
# Read data shapes.
shape_map = defaultdict(lambda : (0, None))
for p in paths:
f = h5py.File(p, "r")
for k in f.keys():
shape = tuple(f[k].shape)
shape_map[k] = (shape_map[k][0] + shape[0], shape[1])
f.close()
# Allocate output array.
data_map = { k : np.empty(s, dtype="f4") for k, s in shape_map.items() }
start_map = { k : 0 for k in shape_map }
# Load files.
for pi, p in enumerate(tqdm(paths, "load data")):
f = h5py.File(p, "r")
for k in f.keys():
i0 = start_map[k]
i1 = i0 + len(f[k])
data_map[k][i0:i1] = f[k]
start_map[k] += len(f[k])
f.close()
return data_map
def get_missing_blocks(workdir, n_samples, block_size,
validate=lambda f : None):
    '''Divide range [0, n_samples) into a sequence of block ranges.
This is a core method within the concept of block processing. The idea
is to divide a range (size n_samples) into a sequence of blocks. Each
block corresponds to a file within 'workdir' with name
'{start_idx}-{end_idx}.hdf5'. This method checks for the existence of
these files, and returns a list of the ones that are missing.
'''
# Block ranges.
block_start_idxs = list(range(0, n_samples, block_size))
block_end_idxs = [ min(n_samples, i + block_size) for i in block_start_idxs ]
block_ranges = list(zip(block_start_idxs, block_end_idxs))
# All block files (existing + missing).
n_digits = int(np.ceil(np.log(n_samples) / np.log(10)) + 1)
all_blocks = [{
"range" : r,
"path" : os.path.join(
workdir,
"%s-%s.hdf5" % tuple([ str(i).zfill(n_digits) for i in r ]),
)
} for r in block_ranges]
all_block_path_set = set(block["path"] for block in all_blocks)
# Delete corrupt files.
if torch.distributed.get_rank() == 0:
existing_block_paths = [block["path"]
for block in all_blocks
if os.path.exists(block["path"])]
for index, path in enumerate(
tqdm(existing_block_paths, "validating block.")):
assert path in all_block_path_set, "unexpected filename, '%s'." % path
            try:
                f = h5py.File(path, "r")
            except:
                # Unreadable/corrupt file; delete it so the block is re-processed.
                os.remove(path)
                continue
            try:
                validate(f)
            except:
                # Validation failed; delete the block file so it is rebuilt.
                os.remove(path)
            finally:
                f.close()
# Wait for files to be deleted.
torch.distributed.barrier()
# Filter missing files.
missing_blocks = [block
for block in all_blocks
if not os.path.exists(block["path"])]
return missing_blocks
def get_missing_blocks_by_rank(workdir, n_samples, block_size,
validate=lambda f : None):
'''Divide missing blocks evenly across all ranks.
See 'get_missing_blocks()' above for description. The returned list of
missing blocks is split evenly across ranks via interleaving. This way,
each rank has a roughly equal number of blocks to process for a
downstream operation.
'''
missing_blocks = get_missing_blocks(workdir, n_samples, block_size,
validate)
# This rank's missing files.
data_parallel_rank = parallel_state.get_data_parallel_rank()
data_parallel_world_size = parallel_state.get_data_parallel_world_size()
rank_missing_blocks = missing_blocks[data_parallel_rank:len(missing_blocks):data_parallel_world_size]
# Extend rank's missing blocks (with None) such that all ranks have equal
# length lists. This allows for easier tracking of global progress.
n_missing_tensor = torch.cuda.LongTensor([len(rank_missing_blocks)])
torch.distributed.all_reduce(n_missing_tensor,
op=torch.distributed.ReduceOp.MAX)
max_n_missing = n_missing_tensor.item()
rank_missing_blocks += [None] * (max_n_missing - len(rank_missing_blocks))
return len(missing_blocks), rank_missing_blocks
class IdPathMap:
'''Maps indexes to the containing block path.
    This class optimizes the mapping of a large number of indexes to the
    path of each index's containing block. For example, with block_size 1M,
    this class stores 1/1M as many (long) path strings, saving memory.
'''
def __init__(self, paths):
self.paths = paths
self.path_index_map = {p:i for i,p in enumerate(paths)}
self.id_index_map = {}
def __str__(self):
return "%d paths; %d ids" % (len(self.paths), len(self.id_index_map))
def add(self, id, path):
'''Map index to a path.'''
self.id_index_map[id] = self.path_index_map[path]
def __contains__(self, idx):
'''Index added to this object?'''
return idx in self.id_index_map
def __getitem__(self, idx):
'''Get path from index.'''
return self.paths[self.id_index_map[idx]]
def path_to_range(path):
'''Parse start/end indexes from block path name (e.g., 00010-00011.hdf5 ->
(10, 11).'''
return tuple([
int(i) for i in os.path.splitext(
os.path.basename(path))[0].split("-")])
def get_index_path_map(_dir):
'''Map contained indexes to block file path (on disk).'''
paths = sorted(glob.glob(_dir + "/*.hdf5"))
# Build index-path map.
idx_path_map = IdPathMap(paths)
for path in paths:
start_idx, end_idx = path_to_range(path)
for idx in range(start_idx, end_idx):
idx_path_map.add(idx, path)
return idx_path_map
This directory contains a collection of tools for building the retrieval database and pretraining neighbors for Retro. This preprocessing pipeline is broken into 3 main stages:
1. **Build retrieval chunk database** : Used for retrieving neighbors and continuation chunks, which are then passed through the retrieval encoder.
2. **Build index for similarity search** : Train and build a search index for querying chunk neighbors.
3. **Query pretraining neighbors** : For matching pretraining samples to database chunks. Neighbors are generated separately for training, validation, and test datasets.
The following overview goes into more detail on the pipeline, code structure, usage, and pretraining.
<!-- ################ contents ################ -->
# Contents
* [Quick start](#quick-start)
* [Stages](#stages)
* [Code structure](#code-structure)
* [Arguments](#arguments)
<!-- * [Pretraining](#pretraining) -->
<!-- ################ quick start ################ -->
# Quick start
See `examples/get_preprocess_cmd.sh` for example arguments.
Key files:
- `main.py` : Entry point.
- `examples/get_preprocess_cmd.sh` : Build preprocessing command (for `main.py`).
- `examples/preprocess_data.sh` : Run preprocessing (calls `get_preprocess_cmd.sh`, `main.py`).
Use `--retro-tasks` to move through the preprocessing pipeline.
- Simplest setup (builds everything): `--retro-tasks build`
- Alternatively, for tuning compute resources, run stages independently:
- Build retrieval database: `--retro-tasks db-build`
- Build search index: `--retro-tasks index-build`
- Query neighbors: `--retro-tasks pretraining-query-neighbors`
Sample code flow:
- `main.py` : Entry point (e.g., using `--retro-tasks X`).
- `db/build.py` : Build retrieval database.
- `index/build.py` : Build search index. Calls the following two files:
- `index/train.py` : Train index on subset of database.
- `index/add.py` : Add database chunks to index.
- `pretraining/query.py` : Query pretraining samples for database neighbors (saved to disk and used during pretraining).
<!-- ################ stages ################ -->
# Stages
### Build retrieval chunk database
This *database* (stored as a 2-D array, NOT a relational database) consists of a list of chunks (traditionally length 64) extracted from the original GPT token dataset. This is simply a consecutive, non-overlapping chunking of the token dataset. Chunking only takes place within a document, and therefore the final chunk of each document has length: 1 <= chunk_length <= max_chunk_length.
We discard chunks that would convert to an empty Bert sequence (rare case, happens ~1/100,000 chunks in our case), since we use Bert embeddings for building our index. Thus, the total number of chunks in the database will be slightly less than a naive calculation.
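As a point of reference, the chunking rule can be sketched in a few lines of Python. This is a minimal, hypothetical `chunk_document` helper (the actual implementation, including the Bert re-tokenization check, lives in `tools/retro/db/build.py`):
```python
def chunk_document(doc_token_ids, chunk_length=64):
    """Split one document into consecutive, non-overlapping chunks.

    Every chunk has length `chunk_length` except possibly the last one,
    which has length 1 <= len <= chunk_length; chunks never cross
    document boundaries.
    """
    chunks = []
    for start in range(0, len(doc_token_ids), chunk_length):
        end = min(len(doc_token_ids), start + chunk_length)
        chunks.append(doc_token_ids[start:end])
    return chunks

print([len(c) for c in chunk_document(list(range(150)))])  # [64, 64, 22]
```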
### Build index for similarity search
To match pretraining chunks to database chunks, a search index must be built to perform this querying. We use Faiss (https://github.com/facebookresearch/faiss) for training and building this index. Generally, the index is trained on a subset of all chunks in the database (specified via `--retro-nchunks-sampled`). After training, all chunks are added into the index, to be available during querying.
Indexes only accept 1-D floating point vectors for training and adding, so each chunk must first be embedded before passing to the index for either training or adding. We use Bert embeddings for this purpose, and the embeddings are generated automatically within the pipeline.
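As a rough illustration of this train-then-add flow, the underlying Faiss calls look like the sketch below. The toy sizes and the small `IVF256,Flat` index string are assumptions chosen so the example runs quickly; the pipeline's actual wrapper classes live under `tools/retro/index/indexes/` and use the index string given by `--retro-index-str`.
```python
import faiss
import numpy as np

d = 1024  # Bert-large embedding size used throughout this pipeline.
index = faiss.index_factory(d, "IVF256,Flat")

sampled_embeddings = np.random.rand(10_000, d).astype("f4")  # stand-in for sampled.hdf5 embeddings
index.train(sampled_embeddings)  # train on the sampled subset only

all_embeddings = np.random.rand(50_000, d).astype("f4")  # stand-in for train.hdf5 embeddings
index.add(all_embeddings)  # add the full chunk database

faiss.write_index(index, "added.faissindex")
```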
### Query pretraining neighbors
To ensure fast Retro pretraining, the database neighbors for pretraining samples are pre-computed and saved to disk, for efficient access within the Retro dataset. In this stage, the pretraining datasets (training, validation, and test) are iterated, each sample is broken into chunks, and the chunks are used for querying the index. Similar to when building the index, each chunk is embedded (via Bert) before querying the index.
The saved neighbors are labeled with unique dataset properties (i.e., seed, sequence length, number of samples, etc.) to ensure the neighbors generated during preprocessing match the neighbors requested during pretraining.
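At its core, the query step reduces to a Faiss `search()` call per batch of chunk embeddings. A minimal sketch, with random placeholder embeddings standing in for the real Bert embeddings (the actual logic lives in `tools/retro/pretraining/query.py`):
```python
import faiss
import numpy as np

index = faiss.read_index("added.faissindex")

# One Bert embedding per pretraining chunk, e.g., a 2048-token sample in 64-token chunks.
chunk_embeddings = np.random.rand(32, 1024).astype("f4")

k = 2  # --retro-num-neighbors
_, neighbor_ids = index.search(chunk_embeddings, k)
# neighbor_ids has shape [n_chunks, k]; each id indexes into the merged chunk database.
```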
<!-- ################ code structure ################ -->
# Code structure
### `tools/retro/main.py`
This is the main entry point for Retro preprocessing. Call `main.py --help` to see arguments. Additionally, some Retro arguments live among Megatron's core arguments, so also see the `_add_retro_args()` section of `megatron/arguments.py`. Two of the most important arguments to customize are `--retro-workdir` and `--retro-tasks`.
- **`--retro-workdir`** : Set the directory in which the preprocessing pipeline saves its datasets and configuration files. This argument should remain consistent for a full pass through the pipeline, and for pretraining.
- **`--retro-tasks`** : Set the stages of preprocessing to perform. As mentioned previously, the three high-level stages are: 1) build retrieval database, 2) build search index, and 3) query pretraining neighbors. `--retro-tasks` can be used to either run the full pipeline, or run each of these stages in isolation. The latter case is useful for tuning compute resources for each stage. For example, index training utilizes GPUs and requires relatively less time, while querying neighbors uses the CPU and is a relatively slow process. Example tasks include:
- **`--retro-tasks build`** : Run entire preprocessing pipeline.
- **`--retro-tasks db-build`** : Build retrieval database.
- **`--retro-tasks index-build`** : Train and build search index.
- **`--retro-tasks pretraining-query-neighbors`** : Query pretraining neighbors.
Multiple tasks can be specified by separating with commas (e.g., `--retro-tasks db-build,index-build`). Additionally, various 'miscellaneous' tasks are currently included, primarily for validating the data of each stage; these task names can be seen in `main.py`.
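For illustration only, dispatching a comma-separated task string might look like the following sketch; the handler names here are placeholders, and the actual task names and implementations are in `tools/retro/main.py`:
```python
def run_retro_tasks(retro_tasks, handlers):
    """Run a comma-separated list of Retro preprocessing tasks in order."""
    for task in retro_tasks.split(","):
        if task not in handlers:
            raise ValueError("unknown retro task '%s'." % task)
        handlers[task]()

run_retro_tasks("db-build,index-build", {
    "db-build": lambda: print("build retrieval chunk database"),
    "index-build": lambda: print("train and build search index"),
    "pretraining-query-neighbors": lambda: print("query pretraining neighbors"),
})
```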
### `tools/retro/examples`
Example scripts for setting arguments and launching Retro preprocessing. The key files here are:
- **`get_preprocess_cmd.sh`** : Sets up arguments and command for preprocessing. **Important note**: this script assumes a few environment variables are already set before it is called. Please see the `Environment vars.` section at the top of this file. Generally, environment variables must be set to determine the location of Retro workdirs, input datasets, and GPT and Bert model information.
- **`preprocess_data.sh`** : Calls `get_preprocess_cmd.sh` to get arguments, and then calls `main.py` to launch preprocessing.
- **`pretrain_model.sh`** : Example script for pretraining on Wikipedia data, after preprocessing is complete.
### `tools/retro/db`
Build the retrieval chunk database. The key files here are:
- **`build.py`** : Entry point for building the database. This code is responsible for iterating the input datasets (i.e., `--data-path`), parsing each dataset into consecutive chunks, checking for empty Bert (Wordpiece) conversions, and storing this information to disk. Two databases are created: 1) the retrieval database, and 2) a sampled database used for training the search index.
- **`dataset.py`** : Defines database class, for iterating or accessing chunks in the database. Each chunk contains its tokens, Bert conversion length, and dataset index.
Input data:
<!-- - Token datasets, as generated by `tools/preprocess_data.py`. Each dataset should include a `.bin` and `.idx` file. Multiple datasets can be specified by using a blended configuration (see `--data-path` in `megatron/arguments.py`). -->
- Token datasets, as loaded by `gpt_dataset.py`. Multiple datasets can be specified by using a blended configuration (see `--data-path` in `megatron/arguments.py`).
Output data:
- **`<RETRO_WORKDIR>/db/merged/train.hdf5`** : The main retrieval database. (*Database* here is used to denote a list of indexed chunks, rather than a *relational database*.) The chunks in this database are added to the search index, and are used for retrieval during pretraining. This file contains a single dataset `'chunks'`, which contains 5 columns:
- `dataset_idx` : Dataset index, from list of blended indexed datasets.
- `document_idx` : Document index within dataset.
- `chunk_start_idx` : Chunk's starting token index within document.
- `chunk_end_idx` : Chunk's ending token index (exclusive) within document.
- `bert_chunk_length` : Length of Bert token sequence, after converting from GPT.
- **`<RETRO_WORKDIR>/db/merged/sampled.hdf5`** : Subset of training database that is used for training the search index. This file has the same structure as detailed above. In general, this database is significantly smaller than the `train.hdf5` database, since the search index only needs a relatively small number of samples to understand the data's structure. After training, all chunks in the main database (`train.hdf5`) are *added* to the search index. Both merged files can be inspected directly with `h5py`, as sketched below.
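A minimal `h5py` sketch for inspecting these merged files (the path is a placeholder for `<RETRO_WORKDIR>/db/merged/train.hdf5`; `sampled.hdf5` has the same layout):
```python
import h5py

with h5py.File("db/merged/train.hdf5", "r") as f:
    chunks = f["chunks"]  # int64 array of shape [n_chunks, 5]; column order as listed above
    dataset_idx, document_idx, chunk_start_idx, chunk_end_idx, bert_chunk_length = chunks[0]
    print(chunks.shape, chunks[0])
```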
### `tools/retro/index`
Build the search index. The key files here are:
- `build.py` : Entry point for building the search index. First, the index is trained on the sampled chunk database (see above) by calling `train.py`, and then all chunks for the full database are added to the index by calling `add.py`. Note that training requires first embedding (using Bert) all chunks (a parallel operation), and then loading these embeddings and training the index (a sequential operation), so it's best to change one's compute setup after all chunks have been embedded and saved to disk.
- `indexes/faiss_base.py` : Wrapper class for building a Faiss index, following the standard `train()` and `add()` operations.
- `indexes/faiss_par_add.py` : Similar to above, except it uses an embarrassingly parallel (multi-node, multi-process) `add()` operation. Vectors are first added to separate index copies, and then merged together. A rough sketch of this add-then-merge pattern follows this list.
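A rough sketch of that add-then-merge pattern, assuming an IVF-style index and Faiss's `merge_from()` (toy sizes; the real sharding, file naming, and rank synchronization are handled in `indexes/faiss_par_add.py`):
```python
import faiss
import numpy as np

d = 1024
trained = faiss.index_factory(d, "IVF256,Flat")
trained.train(np.random.rand(10_000, d).astype("f4"))

# Each worker adds its own slice of the database to a copy of the trained index ...
shards = []
for rank in range(2):
    shard = faiss.clone_index(trained)
    shard.add(np.random.rand(5_000, d).astype("f4"))
    shards.append(shard)

# ... and the per-worker copies are then merged into a single index.
merged = shards[0]
for shard in shards[1:]:
    merged.merge_from(shard, merged.ntotal)  # offset the incoming ids by the current total
print(merged.ntotal)  # 10000
```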
Input data:
- **`<RETRO_WORKDIR>/db/merged/sampled.hdf5`** : Chunks used for training the search index.
- **`<RETRO_WORKDIR>/db/merged/train.hdf5`** : Chunks used for adding to the *trained* search index.
Output data:
- **`<RETRO_WORKDIR>/index/<RETRO_INDEX_TYPE>/<RETRO_INDEX_STR>/added.faissindex`** : The final index, which has been trained and has had all database chunks added to it. This index is ready for querying neighbors. Here, `RETRO_INDEX_TYPE` and `RETRO_INDEX_STR` correspond to the same-name arguments `--retro-index-type` (e.g., `faiss-par-add`) and `--retro-index-str` (e.g., `OPQ32_256,IVF4194304_HNSW32,PQ32`).
- **`<RETRO_WORKDIR>/index/<RETRO_INDEX_TYPE>/<RETRO_INDEX_STR>/empty.faissindex`** : The *post-training*, *pre-adding* index. It can generally be discarded once `added.faissindex` has been built, but it is useful for debugging or for building other indexes.
### `tools/retro/pretraining`
Query the pretraining datasets (training, validation, test) for their neighbors within the database. Neighbors are queried during preprocessing -- rather than during pretraining -- because querying is a fairly slow operation, so it would be a bottleneck if performed during pretraining. Queried neighbors are tagged with their unique identifying information (e.g., `train_indexmap_27662746ns_2048sl_1234s`), so as to avoid incorrect references during pretraining. The key files here are:
- **`query.py`** : Entry point for querying. The pretraining datasets are iterated, and each chunk within each sample is queried using the search index. These neighbors are filtered by discarding any database chunks that fall within the same document as any chunk within a pretraining sample.
- **`chunk_dataset.py`** : This creates an iterable 'chunk' dataset form of a pretraining dataset. This is just a light wrapper, but makes it easier to deterministically iterate and assign IDs to each chunk in a sample dataset.
- **`retro_dataset.py`** : The Retro dataset used for pretraining (not used in preprocessing). Each sample returns the sample tokens, along with neighbor tokens for each chunk within the sample.
Input data:
- Token datasets, as loaded by `gpt_dataset.py`.
- **`<RETRO_WORKDIR>/index/<RETRO_INDEX_TYPE>/<RETRO_INDEX_STR>/added.faissindex`** : The trained index, with all database chunks added to it (see previous section for details).
Output data:
- **`<RETRO_WORKDIR>/{train,valid,test}_XXns_YYsl_ZZs/WW.hdf5`** : These directories/files contain the indexes of neighbors for each chunk within each sample of the pretraining datasets. Each directory (e.g., `train_indexmap_2047435ns_2048sl_1234s`) contains a list of HDF5 files (e.g., one file might be called `0075700000-0075800000.hdf5`). Each HDF5 file contains a consecutive subset of neighbor IDs for a given chunk, for indexing into the main retrieval database. All HDF5 files within a given directory, taken together, represent the entire set of neighbors for a dataset. The size of these HDF5 files is determined by the argument `--retro-block-size`. The `XX`, `YY`, `ZZ`, `WW` notation above denotes the dataset properties that are used for uniquely tagging the neighbor files, to ensure compatibility during model pretraining. These neighbor files are ultimately used by `retro_dataset.py` during pretraining, for building Retro samples. A minimal read sketch follows.
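A minimal read sketch for one neighbor block (the path is a placeholder following the naming above; the dataset key `"neighbors"` and its shape are assumptions for illustration, see `tools/retro/pretraining/query.py` for how these files are actually written):
```python
import h5py

with h5py.File("train_indexmap_XXns_YYsl_ZZs/0000000000-0000100000.hdf5", "r") as f:
    neighbor_ids = f["neighbors"][:]  # assumed shape: [n_chunks_in_block, n_neighbors]
    print(neighbor_ids.shape)
```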
### `tools/retro/cli`
Inspect preprocessed data. To use the CLI, open a Python terminal via the `python` command, and then load a Retro workdir with the following:
```
from tools.retro.cli import retro
retro.init("/path/to/retro/workdir")
```
This initializes Megatron, and prepares the Retro data for inspection. See the printed usage for available functions. Several routines are included for viewing data in the retrieval database and viewing pretraining samples and neighbors. For example:
```python
retro.get_db_num_indexed_datasets() # 15
retro.get_db_chunk_text(92874113) # 'research project at ... and philosophy'
retro.get_pt_sample('train', 62005) # '[16084, 26158, 25387 ..., 6898, 9568]'
```
Most methods within the CLI are prefixed to denote the data being inspected:
- **'db'** : Retrieval database (i.e., chunk tokens, document IDs, and dataset IDs)
- **'pt'** : Pretraining datasets (i.e., sample tokens and neighbor tokens)
### `tools/retro/utils.py`
A collection of utility methods. Most importantly, this contains:
- **`def get_gpt_tokenizer()`** : Get the GPT tokenizer.
- **`def get_bert_tokenizer()`** : Get the Bert tokenizer.
- **`class GPTToTextDataset`** : Wrapper class that converts GPT (BPE) samples to raw text (a minimal sketch of such a wrapper follows this list).
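A minimal sketch of what such a GPT-to-text wrapper can look like (illustrative only; see `tools/retro/utils.py` for the actual `GPTToTextDataset`):
```python
import torch

class GPTToTextDatasetSketch(torch.utils.data.Dataset):
    """Detokenize GPT token ids to raw text, so downstream datasets
    (e.g., BertEmbeddingDataset) can consume the 'text' field."""

    def __init__(self, gpt_dataset, gpt_tokenizer):
        super().__init__()
        self.gpt_dataset = gpt_dataset
        self.gpt_tokenizer = gpt_tokenizer

    def __len__(self):
        return len(self.gpt_dataset)

    def __getitem__(self, idx):
        gpt_token_ids = self.gpt_dataset[idx]["text"].tolist()
        return {"text": self.gpt_tokenizer.detokenize(gpt_token_ids)}
```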
### `tools/bert_embedding`
Generate Bert embeddings. The main files here are:
- **`embed.py`** : Entry point for generating embeddings, and contains the two main embedding classes, `BertEmbedder` and `DiskDataParallelBertEmbedder` (more below). This file contains code for generating Megatron embeddings, while the file below contains code for Huggingface embeddings.
- **`huggingface.py`** : Used by `embed.py` when the embedder is configured (see below) to output Huggingface embeddings.
- **`dataset.py`** : Wrapper class for converting a raw-text dataset to Bert (Wordpiece) tokens.
The Bert embeddings can be configured along two axes. The first axis is the output type:
- **`class BertEmbedder`** : This class takes a raw-text dataset as input, generates its embeddings, and returns a Numpy array. The main functions are `embed_text_dataset` (accepts a raw-text dataset) and `embed_text` (accepts a string).
- **`class DiskDataParallelBertEmbedder`** : This class wraps `BertEmbedder`, and rather than returning a Numpy array, it saves the embeddings to disk. Additionally, this class automatically splits data across data parallel ranks (using interleaving), and also processes data in a specified `block_size` (e.g., 1,000,000).
The second axis is the type of embedding model to use, controlled by the argument `--bert-embedder-type`:
- **`--bert-embedder-type megatron`** : Use Megatron's Bert model. The specific model used is dependent on the loaded checkpoint, vocab file, and tokenizer.
- **`--bert-embedder-type huggingface`** : Use Huggingface's `bert-large-cased`. (*Note*: Huggingface's inclusion is likely to be deprecated, and there is no ability to configure cased/uncased.) A short usage sketch of both embedder classes follows.
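A short usage sketch of the two classes, assuming Megatron has already been initialized with a Bert checkpoint and `--output-bert-embeddings`; `text_dataset` is a placeholder raw-text dataset whose samples look like `{"text": "..."}`:
```python
from tools.bert_embedding import BertEmbedder, DiskDataParallelBertEmbedder

# In-memory embeddings: returns a numpy array of shape [n_samples, 1024].
embedder = BertEmbedder(batch_size=128,
                        max_bert_seq_length=256,
                        embedder_type="megatron")
embeddings = embedder.embed_text_dataset(text_dataset)

# Disk-based variant: splits work across data-parallel ranks and writes
# block_size-sample hdf5 files under the given workdir.
disk_embedder = DiskDataParallelBertEmbedder(batch_size=128,
                                             max_bert_seq_length=256,
                                             block_size=1_000_000,
                                             embedder_type="megatron")
disk_embedder.embed_text_dataset("corpus", "/path/to/embedding/workdir", text_dataset)
```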
### Pretraining
- **`pretrain_retro.py`** : Launch script for pretraining Retro. Similar to `pretrain_gpt.py`, except this script handles loading neighbor tokens and setting up the neighbor attention mask.
<!-- - `megatron/data/gpt_dataset.py` : ? -->
- **`megatron/model/retro_transformer.py`** : Implementation of Retro model, including the main transformer, the retrieval encoder, and chunked cross-attention layers. Note that currently, `retro_transformer.py` contains several classes that are nearly identical to `transformer.py`, except for 1 or 2 lines, due to code changes that are yet to be integrated.
- **`tools/retro/pretraining/retro_dataset.py`** : The Retro dataset used for pretraining (not used in preprocessing). Each sample returns the sample tokens, along with neighbor tokens for each chunk within the sample.
<!-- ################ arguments ################ -->
# Arguments
See `tools/retro/main.py`'s `add_retro_args()` and `megatron/arguments.py`'s `_add_retro_args()` for details and descriptions. Here we list some particularly important arguments:
- `--retro-workdir` : Mentioned previously, this argument determines the directory in which a set of Retro data is stored (during preprocessing) and loaded (during pretraining). Any change in this directory during preprocessing may result in preprocessing starting over from scratch, and any change before pretraining will result in pretraining throwing an error.
- Preprocessing
- `--retro-gpt-chunk-length` : Retro chunk length (e.g., 64 in original paper).
- `--retro-tasks` : Comma-separated list of preprocessing tasks. Generally, the `build` task is the simplest way to run the preprocessing pipeline. For finer control, individual stages can be run by using tasks (in order): `db-build`, `index-build`, and `pretraining-query-neighbors`.
- `--retro-index-str` : Faiss index string that defines the index configuration. This will vary based on data size, compute/disk setup, and user needs. For example, this string looks something like `IVF262144_HNSW32,Flat` or `OPQ32_256,IVF4194304_HNSW32,PQ32`.
- Pretraining
  - `--retro-add-retriever` : Must be used to select the Retro model.
  - `--retro-num-neighbors` : Number of neighbors to retrieve from the retrieval database (defaults to 2).
  - `--retro-num-retrieved-chunks` : For each neighbor, the number of consecutive chunks to retrieve, including the initial neighbor (defaults to 2). A worked shape example follows.
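As a worked example of the resulting neighbor-token shapes, assuming a 2048-token sequence length together with the defaults above (64-token chunks, 2 neighbors, 2 retrieved chunks per neighbor):
```python
seq_length = 2048
chunk_length = 64             # --retro-gpt-chunk-length
num_neighbors = 2             # --retro-num-neighbors
num_retrieved_chunks = 2      # --retro-num-retrieved-chunks (neighbor + continuation)

chunks_per_sample = seq_length // chunk_length           # 32 chunks per sample
retrieved_length = num_retrieved_chunks * chunk_length   # 128 retrieved tokens per neighbor

# Per sample, neighbor_tokens is [chunks_per_sample, num_neighbors, retrieved_length];
# pretrain_retro.py flattens this to [batch * chunks * neighbors, retrieved_length]
# before the forward pass.
print(chunks_per_sample, num_neighbors, retrieved_length)  # 32 2 128
```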
<!-- ################ pretraining ################ -->
<!-- # Pretraining -->
<!-- - New retro args in arguments.py (add_retro_args). -->
<!-- - Most important arg is `--retro-add-retriever`. -->
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import json
import numpy as np
import os
import torch
import types
from megatron.global_vars import set_global_variables, set_retro_args
from megatron.initialize import (
initialize_megatron,
_initialize_distributed,
_set_random_seed,
)
from tools.retro.db.utils import (
get_indexed_dataset_infos as get_db_indexed_dataset_infos,
get_merged_train_dataset as get_db_dataset,
)
from tools.retro.external_libs import h5py
from tools.retro.main import add_retro_args
from tools.retro.pretraining.retro_dataset import get_retro_datasets
from tools.retro.utils import get_args_path, get_bert_tokenizer, get_gpt_tokenizer
def shorten_str(s, n):
s = "\\n".join(s.splitlines())
return s if len(s) <= n else "%s ... %s" % (s[:n//2], s[-n//2:])
class retro:
args = None
##############################################
# initialize.
##############################################
@classmethod
def init_megatron(cls, workdir):
'''Custom initialization of Megatron.'''
# Load args.
args_path = get_args_path(workdir)
assert os.path.exists(args_path), "args.json not found in workdir."
with open(args_path) as f:
cls.args = types.SimpleNamespace(**json.load(f))
cls.args.retro_workdir = workdir # just in case workdir moved
cls.args.rank = 0 # override env
cls.args.world_size = 1 # override env
set_global_variables(cls.args)
set_retro_args(cls.args)
_initialize_distributed()
_set_random_seed(cls.args.seed, cls.args.data_parallel_random_init)
@classmethod
def init(cls, workdir):
'''Initialize Megatron, tokenizers, and datasets.'''
# Load args.
cls.init_megatron(workdir)
cls.tokenizers = types.SimpleNamespace(
gpt=get_gpt_tokenizer(),
bert=get_bert_tokenizer(),
)
# Load data.
cls.db_indexed_dataset_infos = get_db_indexed_dataset_infos()
pt_train_ds, pt_valid_ds, _ = get_retro_datasets()
cls.pt_datasets = types.SimpleNamespace(
train=pt_train_ds,
valid=pt_valid_ds,
)
# Print usage.
cls.print_usage()
##############################################
# utils.
##############################################
@classmethod
def gpt_to_text(cls, token_ids):
'''GPT tokens to text.'''
return cls.tokenizers.gpt.detokenize(token_ids)
@classmethod
def text_to_bert(cls, text):
'''Text to Bert tokens.'''
return cls.tokenizers.bert.tokenize(text)
##############################################
# chunk db.
##############################################
@classmethod
def get_db_num_indexed_datasets(cls):
'''Number of indexed datasets within blendable dataset.'''
return len(cls.db_indexed_dataset_infos)
@classmethod
def get_db_indexed_dataset_infos(cls):
'''Dataset infos, including number of training & sampled sets.'''
return [(info["ratio"], info["name"])
for info in cls.db_indexed_dataset_infos]
@classmethod
def get_db_dataset(cls):
return cls.pt_datasets.train.db_dataset
@classmethod
def get_db_num_chunks(cls):
'''Number of DB chunks.'''
return len(cls.get_db_dataset())
@classmethod
def get_db_chunk_gpt(cls, idx):
'''Get DB chunk as GPT token ids.'''
return cls.get_db_dataset()[idx]["text"].tolist()
@classmethod
def get_db_chunk_bert(cls, idx):
'''Get DB chunk as Bert token ids.'''
return cls.text_to_bert(cls.get_db_chunk_text(idx))
@classmethod
def get_db_chunk_text(cls, idx):
'''Get DB chunk as text.'''
return cls.gpt_to_text(cls.get_db_chunk_gpt(idx))
@classmethod
def get_db_chunk_and_continuation_text(cls, idx):
'''Get DB chunk along with continuation, as text.'''
# Modulus used here to match original implementation (i.e., last
# chunks continuation wraps around to first chunk).
return [
cls.get_db_chunk_text(idx),
cls.get_db_chunk_text((idx + 1) % len(cls.get_db_dataset())),
]
##############################################
# pretraining corpus.
##############################################
@classmethod
def get_pt_num_samples_and_chunks(cls, data_key):
'''Number of samples & chunks (e.g., 32*n_samples) in corpus.'''
assert hasattr(cls.pt_datasets, data_key), \
"pretraining set '%s' not found (choices: %s)." % (
data_key, ", ".join(vars(cls.pt_datasets).keys()))
chunk_dataset = getattr(cls.pt_datasets, data_key).chunk_dataset
return (
len(chunk_dataset.sample_dataset),
len(chunk_dataset),
)
@classmethod
def get_pt_num_samples(cls, data_key):
'''Number of pretraining samples.'''
return cls.get_pt_num_samples_and_chunks(data_key)[0]
@classmethod
def get_pt_num_chunks(cls, data_key):
'''Number of pretraining chunks (e.g., 32*n_samples).'''
return cls.get_pt_num_samples_and_chunks(data_key)[1]
@classmethod
def get_pt_sample(cls, data_key, idx):
return getattr(cls.pt_datasets, data_key)[idx]
##############################################
# usage.
##############################################
@classmethod
def print_usage(cls):
'''Print usage.'''
print()
print("+++++++++++++++++++++++++++++++++++++++++++++++++++")
print("examples ... [ *note*: 'db' = chunk db; 'pt' = pretraining corpus. ]")
print("+++++++++++++++++++++++++++++++++++++++++++++++++++")
print()
print("~~~~ indexed datasets ~~~~")
print("retro.get_db_num_indexed_datasets() : %s" %
cls.get_db_num_indexed_datasets())
print("retro.get_db_indexed_dataset_infos() :")
for i, (ratio,prefix) in enumerate(cls.get_db_indexed_dataset_infos()):
print(" %s(%f, %s)%s" % (
"[" if i == 0 else " ",
ratio,
prefix,
"]" if i == len(cls.db_indexed_dataset_infos) - 1 else ",",
))
print()
print("~~~~ counts ~~~~")
print("retro.get_db_num_chunks : %d." % cls.get_db_num_chunks())
print()
for sq_key in ("sample", "chunk"):
for data_key in ("train", "valid"): # test?
print("retro.get_pt_num_%ss('%s') : %d." % (
sq_key, data_key,
getattr(cls, f"get_pt_num_{sq_key}s")(data_key)))
print()
print("~~~~ tokens, text ~~~~")
print("retro.get_db_chunk_gpt(chunk_id) : %s" %
shorten_str(str(retro.get_db_chunk_gpt(0)), 50))
print("retro.get_db_chunk_bert(chunk_id) : %s" %
shorten_str(str(retro.get_db_chunk_bert(0)), 50))
print("retro.get_db_chunk_text(chunk_id) : %s" %
shorten_str(retro.get_db_chunk_text(0).strip(), 50))
print("retro.get_db_chunk_and_continuation_text(chunk_id) :")
for i, t in enumerate(retro.get_db_chunk_and_continuation_text(0)):
print(" %s'%s'%s" % (
"[" if i == 0 else " ",
shorten_str(t.strip().replace("\n", " "), 50),
"]" if i == 1 else ",",
))
sample = cls.get_pt_sample("train", 0)
print()
print("retro.get_pt_sample('train', sample_id) :")
print(" {")
for k, v in sample.items():
print(" '%s' : %s" % (k, shorten_str(str(v), 50)))
print(" }")
print()
print("(e.g., sample = retro.get_pt_sample(...))")
print()
print(" sample['text'].shape : %s" % str(sample["text"].shape))
print(" sample['neighbor_tokens'].shape : %s" % str(sample["neighbor_tokens"].shape))
print(" sample['text'] : %s" % shorten_str(str(sample["text"]), 50))
print(" sample['neighbor_tokens'][17][1] : %s" % shorten_str(str(sample["neighbor_tokens"][17][1]), 50))
print(" retro.gpt_to_text(sample['text']) : %s" % shorten_str(cls.gpt_to_text(sample["text"]), 50))
print(" retro.gpt_to_text(sample['neighbor_tokens']) : %s" % shorten_str(cls.gpt_to_text(sample["neighbor_tokens"][17][1]), 50))
print("+++++++++++++++++++++++++++++++++++++++++++++++++++")
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import os

from . import retro
if __name__ == "__main__":
retro.init(os.environ["RETRO_WORKDIR"])
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from .build import build_db
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from collections import defaultdict
from concurrent.futures import as_completed, ProcessPoolExecutor
from functools import reduce
import glob
import json
import numpy as np
import os
from pathlib import Path
import threading
import torch
from tqdm import tqdm
import types
from megatron import get_retro_args, print_rank_0
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
from megatron.tokenizer.tokenizer import (
_BertWordPieceTokenizer,
_GPT2BPETokenizer,
)
from tools.bert_embedding.utils import get_missing_blocks_by_rank
from tools.retro.external_libs import h5py
from tools.retro.utils import get_gpt_tokenizer, get_bert_tokenizer
from .utils import (
get_individual_db,
get_individual_db_dir,
get_merged_dataset,
get_merged_db_path_map,
get_train_doc_chunk_map_dir,
save_indexed_dataset_infos,
)
def init_indexed_dataset_infos():
'''Gather meta-info about each indexed dataset.
The returned info array allows for easy access to the configuration, and
helps remove ambiguity.
'''
args = get_retro_args()
assert len(args.data_path) % 2 == 0, \
"currently, only blendable dataset is supported."
# Dataset infos.
infos = []
for i in range(0, len(args.data_path), 2):
ratio = float(args.data_path[i])
prefix = args.data_path[i + 1]
path = prefix + ".bin"
name = os.path.basename(prefix)
assert os.path.exists(path)
infos.append({
"ratio" : ratio,
"prefix" : prefix,
"path" : path,
"name" : name,
"db_dir" : get_individual_db_dir(name),
"dataset" : make_indexed_dataset(prefix, "mmap", True),
})
return infos
def build_partial_db(
dataset_idx,
n_datasets,
indexed_dataset,
block_id,
n_blocks,
block,
proc_id,
n_procs,
tokenizers,
):
'''Process a document index range of the indexed dataset.
The chunk database is built in parallel blocks, since de-tokenizing &
re-tokenizing for Bert-length computation is expensive. This method
iterates each document and extracts sequential 'chunk-length' sequences
from each document.
'''
args = get_retro_args()
# Document start/end indexes.
doc_range = block["range"]
n_docs = doc_range[1] - doc_range[0]
n_docs_per_proc = int(np.ceil(n_docs / n_procs))
doc_start_id = doc_range[0] + proc_id * n_docs_per_proc
doc_end_id = min(doc_range[1], doc_start_id + n_docs_per_proc)
# Print progress.
progress_proc_ids = set(range(n_procs)) \
if torch.distributed.get_rank() == 0 else set()
if proc_id in progress_proc_ids:
print(" > building partial chunk db, proc %d / %d, docs %d:%d / %d."%(
proc_id,
n_procs,
doc_start_id,
doc_end_id,
n_docs,
))
# Progress bars (snapshot of overall progress).
doc_id_iter = range(doc_start_id, doc_end_id)
pbar = tqdm(doc_id_iter) \
if proc_id in progress_proc_ids else \
doc_id_iter
# Iterate documents & parse chunks.
chunk_db_valid = []
chunk_db_invalid = []
for doc_id in pbar:
# Progress description.
try:
pbar.set_description("ds %d / %d, block %d / %d, proc %d / %d." % (
dataset_idx,
n_datasets,
block_id,
n_blocks,
proc_id,
n_procs))
except:
pass
# Remove EOD token.
doc = indexed_dataset.get(doc_id)
if doc[-1].item() == tokenizers.gpt.eod_id:
doc = doc[:-1]
doc_len = len(doc)
# Chunk start/end indexes.
chunk_start_idxs = list(range(0, doc_len, args.retro_gpt_chunk_length))
chunk_end_idxs = [min(doc_len, s + args.retro_gpt_chunk_length)
for s in chunk_start_idxs]
# Re-tokenize each chunk to Bert/Wordpiece (empty bert -> 'invalid').
for i, chunk_start_idx in enumerate(chunk_start_idxs):
# Re-tokenize.
chunk_end_idx = chunk_end_idxs[i]
gpt_token_ids = indexed_dataset.get(
idx=doc_id,
offset=chunk_start_idx,
length=chunk_end_idx - chunk_start_idx,
)
text = tokenizers.gpt.detokenize(gpt_token_ids)
bert_token_ids = tokenizers.bert.tokenize(text)
# 'Valid' for non-empty Bert chunks; 'invalid' otherwise.
_chunk_db = chunk_db_invalid \
if len(bert_token_ids) == 0 else \
chunk_db_valid
_chunk_db.append((
doc_id,
chunk_start_idx,
chunk_end_idx,
len(bert_token_ids),
))
return proc_id, chunk_db_valid, chunk_db_invalid
def build_individual_db(dataset_idx, n_datasets, dataset_info, tokenizers):
'''Process a single indexed dataset & extract chunks.'''
args = get_retro_args()
# Make directory.
db_dir = dataset_info["db_dir"]
os.makedirs(db_dir, exist_ok=True)
# Indexed dataset.
indexed_dataset = dataset_info["dataset"]
# Missing db blocks.
n_missing_world, missing_db_blocks = get_missing_blocks_by_rank(
db_dir,
len(indexed_dataset.doc_idx) - 1,
args.retro_doc_block_size,
validate=lambda f : f["chunks_valid"].shape[1] == 4)
# Prevent missing-path-write race condition.
torch.distributed.barrier()
if not missing_db_blocks:
return
# Num processes.
if n_missing_world == 1:
n_procs = 128
elif n_missing_world <= 2:
n_procs = 64
elif n_missing_world <= 4:
n_procs = 32
elif n_missing_world <= 8:
n_procs = 16
else:
n_procs = 8
# Process documents in parallel.
with ProcessPoolExecutor(max_workers=n_procs) as executor:
for block_idx, block in enumerate(missing_db_blocks):
if block is not None:
# Build partial dbs.
print_rank_0(' > build partial dbs.')
futures = []
for proc_id in range(n_procs): # not true process id
futures.append(executor.submit(
build_partial_db,
dataset_idx,
n_datasets,
indexed_dataset,
block_idx,
len(missing_db_blocks),
block,
proc_id,
n_procs,
tokenizers,
))
partial_chunk_dbs = []
for future in as_completed(futures):
partial_chunk_dbs.append(future.result())
# Concatenate chunks.
partial_chunk_dbs.sort(key=lambda item:item[0]) # sort by proc_id
chunk_db_valid = [item
for partial_chunk_db in partial_chunk_dbs
for item in partial_chunk_db[1]]
chunk_db_invalid = [item
for partial_chunk_db in partial_chunk_dbs
for item in partial_chunk_db[2]]
# Convert to numpy.
print_rank_0(' > converting chunk db to numpy.')
chunk_db_valid = np.array(chunk_db_valid)
chunk_db_invalid = np.array(chunk_db_invalid)
# Save DB.
print_rank_0(" > saving individual db.")
f = h5py.File(block["path"], "w")
dset = f.create_dataset("chunks_valid", data=chunk_db_valid)
dset = f.create_dataset("chunks_invalid", data=chunk_db_invalid)
f.close()
# Wait for all ranks to finish block.
print_rank_0(" > waiting for all ranks to finish block.")
torch.distributed.barrier()
print_rank_0(" > finished saving individual db.")
def build_individual_dbs(indexed_dataset_infos):
'''Iterate each indexed dataset & process its chunks.'''
args = get_retro_args()
# Tokenizers.
tokenizers = types.SimpleNamespace(
gpt=get_gpt_tokenizer(),
bert=get_bert_tokenizer(),
)
# Build individual DBs.
print_rank_0(" > build individual chunk dbs.")
for ds_idx, ds_info in enumerate(indexed_dataset_infos):
# Progress.
print_rank_0(" > building individual db, dataset %d / %d ... '%s'." % (
ds_idx,
len(indexed_dataset_infos),
ds_info["name"],
))
# Process single dataset.
build_individual_db(ds_idx, len(indexed_dataset_infos),
ds_info, tokenizers)
def update_chunk_counts(indexed_dataset_infos):
    '''Set n_chunks_train & n_chunks_sampled for each individual DB.'''
args = get_retro_args()
if torch.distributed.get_rank() != 0:
return
# Training split size (split at document level).
train_fraction = float(args.split.split(",")[0]) / 100
assert train_fraction > 0 and train_fraction <= 1
# Set n_chunks (including n_chunks_sampled for unambiguity).
print_rank_0(" > compute n_chunks.")
for ds_index, ds_info in \
enumerate(tqdm(indexed_dataset_infos, "count_chunks")):
db_dir = ds_info["db_dir"]
db_paths = sorted(glob.glob(db_dir + "/*.hdf5"))
# Update counts.
ds_info["n_docs"] = len(ds_info["dataset"].doc_idx) - 1
ds_info["n_docs_train"] = int(train_fraction * ds_info["n_docs"])
ds_info["n_chunks"] = 0 # previously, 'n_chunks_valid'
ds_info["n_chunks_train"] = 0
ds_info["n_chunks_invalid"] = 0
for db_path in db_paths:
with h5py.File(db_path, "r") as f:
ds_info["n_chunks"] += len(f["chunks_valid"])
ds_info["n_chunks_invalid"] += len(f["chunks_invalid"])
ds_info["n_chunks_train"] += \
(np.copy(f["chunks_valid"][:, 0]) < ds_info["n_docs_train"]) \
.sum().item()
ds_info["n_chunks_sampled"] = \
int(round(args.retro_nchunks_sampled * ds_info["ratio"]))
# Verify counts.
assert ds_info["n_chunks_train"] <= ds_info["n_chunks"], \
"n_train (%d) > n_total (%d)." % (
ds_info["n_chunks_train"], ds_info["n_chunks"])
assert ds_info["n_chunks_sampled"] <= ds_info["n_chunks_train"], \
"n_sampled (%d) > n_train (%d)." % (
ds_info["n_chunks_sampled"], ds_info["n_chunks_train"])
def merge_dbs(indexed_dataset_infos, db_type):
'''Merge individual DBs into single DB.'''
if torch.distributed.get_rank() != 0:
return
print(" > build %s chunk db." % db_type)
# Count chunks.
if db_type == "full":
raise Exception("deprecated; use 'train' or 'sampled'.")
n_chunks_key = "n_chunks"
elif db_type == "sampled":
n_chunks_key = "n_chunks_sampled"
elif db_type == "train":
n_chunks_key = "n_chunks_train"
elif db_type == "valid":
pass
else:
raise Exception("handle db_type '%s'." % db_type)
if db_type == "valid":
n_chunks = sum(m["n_chunks"] - m["n_chunks_train"]
for m in indexed_dataset_infos)
else:
n_chunks = sum(m[n_chunks_key] for m in indexed_dataset_infos)
# DB path.
db_path = get_merged_db_path_map()[db_type]
# Delete existing chunk db if incorrect size.
if os.path.exists(db_path):
        try:
            f = h5py.File(db_path, "r")
            n_alloc = len(f["chunks"])           # total allocated
            n_written = f["n_written"][0].item() # total written
            f.close()
            if n_chunks != n_alloc or n_chunks != n_written:
                os.remove(db_path)
        except OSError:
            os.remove(db_path)
        except KeyError:
            f.close()
            os.remove(db_path)
# Build merged chunk db.
if not os.path.exists(db_path):
os.makedirs(os.path.dirname(db_path), exist_ok=True)
f = h5py.File(db_path, "w")
# Initialize output arrays.
merged_db = f.create_dataset("chunks", (n_chunks, 5), dtype="i8")
n_written = f.create_dataset("n_written", (1,), dtype="uint64")
n_written[0] = 0
# Iterate indexed datasets & collect chunks.
start_index = 0
for ds_idx, ds_info in enumerate(indexed_dataset_infos):
print(" > merging dbs; '%s', dataset %d / %d ... '%s'." %
(db_type, ds_idx, len(indexed_dataset_infos), ds_info["name"]))
individual_db = get_individual_db(ds_idx, ds_info)
if db_type == "valid":
individual_db = individual_db[ds_info["n_chunks_train"]:]
else:
individual_db = individual_db[:ds_info[n_chunks_key]]
merged_db[start_index:start_index+len(individual_db)] = individual_db
start_index += len(individual_db)
n_written[0] = start_index
f.close()
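# Minimal inspection sketch (assumption: merge_dbs has already written the
# merged 'train' DB; the path below is a placeholder). The column layout
# matches the (n_chunks, 5) dataset created above and documented in DBDataset.
def _example_inspect_merged_db(db_path="/path/to/retro_workdir/db/merged/train.hdf5"):
    import h5py  # the pipeline imports this via tools.retro.external_libs
    with h5py.File(db_path, "r") as f:
        chunks = f["chunks"]           # int64, columns: dataset_idx, doc_id,
                                       # token_start_idx, token_end_idx, bert_len
        n_written = f["n_written"][0]  # rows actually filled
        print(n_written, chunks[:3])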
def get_partial_banned_chunk_map(proc_id, db_path, chunk_range_info):
    '''Build partial mapping of {(dataset_id,doc_id):[chunk_ids]}.
    Only chunks within the range [start_chunk_id, end_chunk_id) are processed
    by this call.'''
start_chunk_id = chunk_range_info["start"]
end_chunk_id = chunk_range_info["end"]
output_path = chunk_range_info["path"]
# Skip, if output file exists.
if os.path.exists(output_path):
return
# Chunk subset.
with h5py.File(db_path) as f:
sub_chunk_db = np.copy(f["chunks"][start_chunk_id:end_chunk_id, :2])
# Map docs to chunks.
banned_chunk_map = defaultdict(list)
for rel_chunk_id, (dataset_id, doc_id) in enumerate(tqdm(
sub_chunk_db,
"map banned docs, proc %d" % proc_id,
total=sub_chunk_db.shape[0],
)):
chunk_id = start_chunk_id + rel_chunk_id
banned_chunk_map["%d,%d" % (dataset_id.item(), doc_id.item())] \
.append(chunk_id)
# Save output.
with open(output_path, "w") as f:
json.dump(banned_chunk_map, f)
def build_doc_chunk_map(indexed_dataset_infos, db_type):
'''Build mapping of {(dataset_id,doc_id):[chunk_ids]}.'''
if torch.distributed.get_rank() != 0:
return
print(" > build %s doc-chunk map." % db_type)
n_procs = 128
# Get dataset.
db_dataset = get_merged_dataset(db_type, indexed_dataset_infos)
# Sub-ranges for parallel processing.
n_chunks = db_dataset.chunks.shape[0]
n_chunks_per_proc = max(1, int(np.ceil(n_chunks / n_procs)))
chunk_id_starts = list(range(0, n_chunks, n_chunks_per_proc))
chunk_id_ranges = [(s, min(n_chunks, s + n_chunks_per_proc))
for s in chunk_id_starts]
# Wrap range info with output path.
n_digits = int(np.ceil(np.log(n_chunks) / np.log(10)) + 1)
output_dirname = get_train_doc_chunk_map_dir()
chunk_range_infos = [{
"start" : start_id,
"end" : end_id,
"path" : os.path.join(output_dirname, "%s-%s.json" % (
str(start_id).zfill(n_digits),
str(end_id).zfill(n_digits),
)),
} for start_id, end_id in chunk_id_ranges ]
# Build doc-chunk map.
print_rank_0("build doc-chunk-map.")
with ProcessPoolExecutor(max_workers=n_procs) as executor:
# Build partial chunk maps.
futures = []
for proc_id, chunk_range_info in enumerate(chunk_range_infos):
if os.path.exists(chunk_range_info["path"]):
continue
# Submit job.
futures.append(executor.submit(
get_partial_banned_chunk_map,
proc_id,
db_dataset.db_path,
chunk_range_info,
))
        # Wait for processes to finish.
        for finished_idx, future in enumerate(as_completed(futures)):
            print("finished %d / %d." % (finished_idx + 1, len(futures)))
            future.result()
def build_db():
'''Extract token chunks from each indexed dataset.
Iterate each document of each indexed dataset, extract that document's
chunks, and save to a 'DB' (hdf5 file).
'''
# Indexed dataset info.
indexed_dataset_infos = init_indexed_dataset_infos()
# Build dbs.
build_individual_dbs(indexed_dataset_infos)
# Single-process going forward.
if torch.distributed.get_rank() != 0:
return
# Update n_chunks.
update_chunk_counts(indexed_dataset_infos)
# Merge dbs.
merge_dbs(indexed_dataset_infos, "sampled")
merge_dbs(indexed_dataset_infos, "train")
merge_dbs(indexed_dataset_infos, "valid")
build_doc_chunk_map(indexed_dataset_infos, "train")
# Save (fully annotated) indexed dataset infos.
save_indexed_dataset_infos(indexed_dataset_infos)
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import json
import numpy as np
import torch
from megatron import get_args, print_rank_0
from tools.retro.external_libs import h5py
from tools.retro.utils import get_gpt_tokenizer
class DBDataset(torch.utils.data.Dataset):
'''Dataset for iterating chunks.
Requires:
- List of indexed datasets
- Chunk index array, with format:
        [dataset_idx, doc_id, start_idx, end_idx, bert_length]
'''
def __init__(self, db_path, indexed_datasets, chunks, max_chunk_length):
assert chunks.shape[1] == 5, "expected 5 columns (dataset_idx, " \
"doc_idx, token_start_idx, token_end_idx, bert_chunk_length); " \
"found %d columns." % chunks.shape[1]
self.db_path = db_path
self.indexed_datasets = indexed_datasets
self.chunks = chunks
self.max_chunk_length = max_chunk_length
self.eod_token_id = get_gpt_tokenizer().eod_id
def __len__(self):
return self.chunks.shape[0]
def __getitem__(self, chunk_id):
# Chunk start/end indexes.
indexed_dataset_id, doc_id, token_start_idx, token_end_idx, _ = \
[ value.item() for value in self.chunks[chunk_id] ]
chunk_length = token_end_idx - token_start_idx
indexed_dataset = self.indexed_datasets[indexed_dataset_id]
# Chunk token ids.
token_ids = indexed_dataset.get(doc_id,
offset=token_start_idx,
length=chunk_length)
# Extend chunks to max_chunk_length by padding with EOD tokens.
if chunk_length != self.max_chunk_length:
assert chunk_length < self.max_chunk_length, "invalid chunk len."
token_ids = token_ids.tolist()
token_ids += [self.eod_token_id] * \
(self.max_chunk_length - chunk_length)
return {
"doc_id" : doc_id,
"text" : np.array(token_ids, dtype=np.int64),
}
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from collections import defaultdict
import glob
import json
import numpy as np
import os
from tqdm import tqdm
from megatron import get_retro_args, print_rank_0
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
from tools.retro.external_libs import h5py
from .dataset import DBDataset
def get_base_db_workdir():
'''Sub-directory for DB data.'''
args = get_retro_args()
return os.path.join(args.retro_workdir, "db")
def get_indexed_dataset_infos_path():
'''Path to indexed dataset meta-infos.'''
return os.path.join(get_base_db_workdir(), "indexed_dataset_infos.json")
def save_indexed_dataset_infos(indexed_dataset_infos):
'''Save dataset order & meta-info.'''
# Remove 'dataset' field.
clean_infos = []
for info in indexed_dataset_infos:
info = dict(info)
del info["dataset"]
clean_infos.append(info)
# Save.
with open(get_indexed_dataset_infos_path(), "w") as f:
json.dump(clean_infos, f, indent=4)
def get_indexed_dataset_infos():
'''Load indexed dataset meta-infos.'''
# Load json.
path = get_indexed_dataset_infos_path()
with open(path) as f:
infos = json.load(f)
# Add indexed datasets.
for info in infos:
info["dataset"] = make_indexed_dataset(info["prefix"], "mmap", True)
return infos
def get_individual_db_dir(name):
'''Individual DB's directory.'''
return os.path.join(get_base_db_workdir(), "individual", name, "db")
def get_individual_db(ds_id, ds_info):
'''Load individual dataset's chunk DB.'''
db_paths = sorted(glob.glob(ds_info["db_dir"] + "/*hdf5"))
# *Note*: convert to dataset, rather than copying to memory.
db = np.zeros((ds_info["n_chunks"], 5), dtype="i8")
db[:, 0] = ds_id
start_idx = 0
for db_path in db_paths:
f = h5py.File(db_path, "r")
n_chunks_current = f["chunks_valid"].shape[0]
db[start_idx:(start_idx+n_chunks_current), 1:] = f["chunks_valid"]
start_idx += n_chunks_current
f.close()
assert start_idx == ds_info["n_chunks"]
return db
def get_merged_db_path_map():
'''Paths to merged datasets.'''
base_dir = get_base_db_workdir()
return {
"sampled" : os.path.join(base_dir, "merged", "sampled.hdf5"),
"train" : os.path.join(base_dir, "merged", "train.hdf5"),
"valid" : os.path.join(base_dir, "merged", "valid.hdf5"),
}
def get_merged_dataset(db_type, indexed_dataset_infos=None):
'''Get merged dataset.'''
args = get_retro_args()
if not indexed_dataset_infos:
indexed_dataset_infos = get_indexed_dataset_infos()
# Load chunks.
db_path = get_merged_db_path_map()[db_type]
f = h5py.File(db_path, "r")
chunks = f["chunks"]
# DB dataset.
indexed_datasets = [ info["dataset"] for info in indexed_dataset_infos ]
dataset = DBDataset(db_path, indexed_datasets, chunks,
args.retro_gpt_chunk_length)
return dataset
def get_merged_sampled_dataset(indexed_dataset_infos=None):
return get_merged_dataset("sampled", indexed_dataset_infos)
def get_merged_train_dataset(indexed_dataset_infos=None):
return get_merged_dataset("train", indexed_dataset_infos)
def get_merged_valid_dataset(indexed_dataset_infos=None):
return get_merged_dataset("valid", indexed_dataset_infos)
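# Illustrative usage sketch (assumption: Megatron/Retro args are initialized
# and build_db() has already produced the merged 'train' DB and
# indexed_dataset_infos.json). The returned fields match DBDataset.__getitem__.
def _example_read_train_chunk():
    train_db = get_merged_train_dataset()
    sample = train_db[0]                     # {'doc_id': ..., 'text': ndarray}
    print(len(train_db), sample["doc_id"], sample["text"].shape)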
def get_train_doc_chunk_map_dir():
dirname = os.path.join(get_base_db_workdir(), "merged", "train_doc_chunk_map")
os.makedirs(dirname, exist_ok=True)
return dirname
def get_train_doc_chunk_map():
paths = sorted(glob.glob(get_train_doc_chunk_map_dir() + "/*.json"))
doc_map = defaultdict(set)
for path in tqdm(paths, "load train doc maps"):
# Read file.
with open(path) as f:
crnt_doc_map = json.load(f)
# Add to doc map.
for key, chunk_ids in crnt_doc_map.items():
key = tuple(int(i) for i in key.split(","))
doc_map[key].update(chunk_ids)
return doc_map
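# Illustrative lookup sketch: keys are (dataset_id, doc_id) tuples and values
# are sets of chunk ids; the ids below are made up.
def _example_lookup_doc_chunks():
    doc_map = get_train_doc_chunk_map()
    chunk_ids = doc_map.get((0, 12345), set())  # dataset 0, document 12345
    print(sorted(chunk_ids))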
#!/bin/bash
# Small English Wikipedia dataset (~2M chunks).
get_wiki_tiny_config() {
RETRO_INDEX_STR="IVF4096_HNSW4,Flat"
RETRO_GPT_TRAIN_SAMPLES=31250
LR_DECAY_SAMPLES=2
LR_WARMUP_SAMPLES=1
RETRO_GPT_EVAL_INTERVAL=2000
RETRO_GPT_EVAL_ITERS=100
RETRO_EF_SEARCH=4
RETRO_NPROBE=64
DATALOADER_TYPE=cyclic
}
# English Wikipedia dataset (~67M chunks).
get_wiki_config() {
RETRO_INDEX_STR="IVF262144_HNSW32,Flat"
RETRO_GPT_TRAIN_SAMPLES=2037248
LR_DECAY_SAMPLES=2
LR_WARMUP_SAMPLES=1
RETRO_GPT_EVAL_INTERVAL=2000
RETRO_GPT_EVAL_ITERS=100
RETRO_EF_SEARCH=16
RETRO_NPROBE=4096
DATALOADER_TYPE=cyclic
}
# Full corpus (~5B chunks).
get_corpus_config() {
RETRO_INDEX_STR="OPQ32_256,IVF4194304_HNSW32,PQ32"
RETRO_GPT_TRAIN_SAMPLES=192000000
LR_DECAY_SAMPLES=166400000
LR_WARMUP_SAMPLES=162761
RETRO_GPT_EVAL_INTERVAL=2000
RETRO_GPT_EVAL_ITERS=50
RETRO_EF_SEARCH=32
RETRO_NPROBE=4096
DATALOADER_TYPE=single
}
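# Illustrative usage (assumption: a caller sources this file and then invokes
# exactly one config function before building the preprocessing command):
#   . ./get_dataset_configs.sh
#   get_wiki_tiny_config   # or get_wiki_config / get_corpus_config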
#!/bin/bash
# Build preprocessing command for Retro.
set -u
DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
################ Required environment variables. ################
# Required environment variables:
# - REPO_DIR : Root directory of Megatron codebase.
# - RETRO_WORKDIR : Root directory of this Retro project's processed data. (For
# example, this project directory might be for a blended dataset, while
# another project directory might be for just a Wikipedia dataset, and
# another for just Book Corpus data, etc.) This project directory will
# contain a complete set of processed data, including the retrieval
# database, search index, and pretraining neighbors.
# - RETRO_TASKS : One of 'build', 'db-build', 'index-build', or
# 'pretraining-query-neighbors'. See 'Retro tasks' below for task
# descriptions.
# - DATA_BLEND_SCRIPT : Path to blended dataset definition file.
# - GPT_VOCAB_FILE : GPT vocab file.
# - GPT_MERGE_FILE : GPT merge file.
# - GPT_TOKENIZER : GPT tokenizer type (e.g., GPT2BPETokenizer)
# - BERT_LOAD_PATH : Bert checkpoint directory.
# - BERT_VOCAB_FILE : Bert vocab file.
# - BERT_TOKENIZER : Bert tokenizer type (e.g., BertWordPieceLowerCase,
# BertWordPieceCase).
# - BERT_EMBEDDER_TYPE : One of 'megatron' or 'huggingface'.
# - EXTRA_ARGS : Extra arguments (else, leave empty).
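# Example settings (illustrative only; every path below is a placeholder and
# the values merely echo the examples listed above):
#   export REPO_DIR=/path/to/megatron-lm
#   export RETRO_WORKDIR=/path/to/retro/workdir
#   export RETRO_TASKS=build
#   export DATA_BLEND_SCRIPT=/path/to/data_blend.sh
#   export GPT_VOCAB_FILE=/path/to/gpt2-vocab.json
#   export GPT_MERGE_FILE=/path/to/gpt2-merges.txt
#   export GPT_TOKENIZER=GPT2BPETokenizer
#   export BERT_LOAD_PATH=/path/to/bert/checkpoint
#   export BERT_VOCAB_FILE=/path/to/bert-vocab.txt
#   export BERT_TOKENIZER=BertWordPieceLowerCase
#   export BERT_EMBEDDER_TYPE=megatron
#   export EXTRA_ARGS=""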
################ Data blend. ################
. ${DATA_BLEND_SCRIPT}
DATA_PATH=${DATA_BLEND}
################ Retro setup. ################
RETRO_GPT_SEQ_LENGTH=2048
RETRO_GPT_CHUNK_LENGTH=64
RETRO_GPT_MICRO_BATCH_SIZE=1 # *8
RETRO_GPT_GLOBAL_BATCH_SIZE=256
RETRO_NCHUNKS_SAMPLED=300000000
################ Retro tasks. ################
# The '--retro-tasks' argument is a comma-separated list of tasks to run, in
# sequential order. For a quick start, simply set this to 'build' to run the
# entire preprocessing pipeline. For finer control, you may specify the list of
# tasks to run. This is desirable for tuning computational resources. For
# example, training the search index is relatively fast and utilizes GPUs,
# while querying the search index is relatively slow, CPU-only, and memory
# intensive (i.e., multiple populated search indexes are loaded simultaneously).
# *Note* : Once the task(s) below have been completed -- by running either
# 1) 'build', or 2) the sequential combination of 'db-build', 'index-build',
# and 'pretraining-query-neighbors' -- we are ready to pretrain Retro by
# calling pretrain_retro.py.
# ---- Option #1 : Run entire pipeline. ----
# RETRO_TASKS="build" # (*note*: default tasks)
# ---- Option #2 : Run specific stages. ----
# *Note*: Run the following stages in the given order. Optionally, tune your
# cluster setup for each stage, as described above.
# RETRO_TASKS="db-build" # ....................... run 1st
# RETRO_TASKS="index-build" # .................... run 2nd
# RETRO_TASKS="pretraining-query-neighbors" # .... run 3rd
################ Megatron args. ################
MEGATRON_ARGS=" \
--seed 1234 \
--distributed-timeout-minutes 600 \
--tokenizer-type ${BERT_TOKENIZER} \
--tensor-model-parallel-size 1 \
--pipeline-model-parallel-size 1 \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
--micro-batch-size ${RETRO_GPT_MICRO_BATCH_SIZE} \
--global-batch-size ${RETRO_GPT_GLOBAL_BATCH_SIZE} \
--seq-length 512 \
--max-position-embeddings 512 \
--train-samples ${RETRO_GPT_TRAIN_SAMPLES} \
--load ${BERT_LOAD_PATH} \
--exit-on-missing-checkpoint \
--no-load-optim \
--data-path ${DATA_PATH} \
--vocab-file ${BERT_VOCAB_FILE} \
--data-impl mmap \
--split 98,2,0 \
--distributed-backend nccl \
--lr 0.0001 \
--lr-decay-style linear \
--min-lr 1.0e-5 \
--lr-decay-samples ${LR_DECAY_SAMPLES} \
--lr-warmup-samples ${LR_WARMUP_SAMPLES} \
--weight-decay 1e-2 \
--clip-grad 1.0 \
--eval-interval ${RETRO_GPT_EVAL_INTERVAL} \
--eval-iters ${RETRO_GPT_EVAL_ITERS} \
--fp16 \
--DDP-impl local \
--dataloader-type ${DATALOADER_TYPE} \
--no-data-sharding \
--no-gradient-accumulation-fusion \
--no-async-tensor-model-parallel-allreduce \
"
################ Retro args. ################
RETRO_ARGS=" \
--bert-embedder-type ${BERT_EMBEDDER_TYPE} \
--output-bert-embeddings \
\
--retro-gpt-vocab-file ${GPT_VOCAB_FILE} \
--retro-gpt-merge-file ${GPT_MERGE_FILE} \
--retro-gpt-tokenizer-type ${GPT_TOKENIZER} \
--retro-gpt-seq-length ${RETRO_GPT_SEQ_LENGTH} \
--retro-gpt-chunk-length ${RETRO_GPT_CHUNK_LENGTH} \
--retro-bert-vocab-file ${BERT_VOCAB_FILE} \
--retro-bert-tokenizer-type ${BERT_TOKENIZER} \
\
--retro-tasks ${RETRO_TASKS} \
--retro-index-str ${RETRO_INDEX_STR} \
--retro-ef-search ${RETRO_EF_SEARCH} \
--retro-nprobe ${RETRO_NPROBE} \
\
--retro-workdir ${RETRO_WORKDIR} \
--retro-nchunks-sampled ${RETRO_NCHUNKS_SAMPLED} \
\
--retro-return-doc-ids \
"
################ Command. ################
RETRO_PREPROCESS_CMD=" \
./tools/retro/main.py \
${MEGATRON_ARGS} \
${RETRO_ARGS} \
${EXTRA_ARGS} \
"
#!/bin/bash
set -u
unset NCCL_DEBUG
NPROCS=8 # NPROCS must be <= number of GPUs.
set_current_dir() {
DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
}
################ Dataset configs. ################
# This script contains methods to customize arguments to specific dataset
# types. Customize this script as needed for your datasets.
set_current_dir
. $DIR/get_dataset_configs.sh
################ Environment variables. ################
# *Note*: See 'Required environment variables' in 'get_preprocess_cmd.sh' for
# a description of the required environment variables. These variables can be
# set however a user would like. In our setup, we use another bash script
# (location defined by $RETRO_ENV_VARS) that sets all the environment variables
# at once.
. $RETRO_ENV_VARS
######## Environment vars. ########
set_current_dir
. ${DIR}/get_preprocess_cmd.sh
echo "~~~~~~~~~~~~~~~~~~~~~~~~~~"
echo "DIR = '$DIR'."
echo "RETRO_PREPROCESS_CMD = '$RETRO_PREPROCESS_CMD'."
echo "~~~~~~~~~~~~~~~~~~~~~~~~~~"
######## Command. ########
FULL_CMD="\
pwd && cd ${REPO_DIR} && pwd && \
export PYTHONPATH=$PYTHONPATH:${REPO_DIR} && \
python -m torch.distributed.launch \
--nproc_per_node ${NPROCS} \
--nnodes 1 \
--node_rank ${NODE_RANK} \
--master_addr ${MASTER_ADDR} \
--master_port 6000 \
$RETRO_PREPROCESS_CMD \
"
echo "~~~~~~~~~~~~~~~~~~~~~~~~~~"
echo "FULL_CMD = '$FULL_CMD'."
echo "~~~~~~~~~~~~~~~~~~~~~~~~~~"
eval $FULL_CMD