Commit 3aca1415 authored by liangjing

Merge branch 'megatron-lm_dtk24.04' into 'main'

Megatron lm dtk24.04

See merge request !1
parents 0024a5c6 1005e9d3
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import os
from . import retro
......
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import json
import numpy as np
import os
import torch
import types
from megatron.global_vars import set_global_variables, set_retro_args
from megatron.initialize import (
initialize_megatron,
_initialize_distributed,
_set_random_seed,
_compile_dependencies,
)
from tools.retro.db.utils import (
get_indexed_dataset_infos as get_db_indexed_dataset_infos,
get_merged_train_dataset as get_db_dataset,
)
from tools.retro.main import add_retro_args
from tools.retro.query.retro_dataset import get_retro_datasets
from tools.retro.utils import get_args_path, get_bert_tokenizer, get_gpt_tokenizer
def shorten_str(s, n):
s = "\\n".join(s.splitlines())
return s if len(s) <= n else "%s ... %s" % (s[:n//2], s[-n//2:])
class retro:
args = None
##############################################
# initialize.
##############################################
@classmethod
def parse_dtype_str(cls, dtype_str):
return {
"torch.float16" : torch.float16,
"torch.float32" : torch.float32,
"torch.bfloat16" : torch.bfloat16,
}[dtype_str]
@classmethod
def init_megatron(cls, workdir):
'''Custom initialization of Megatron.'''
# Load args.
args_path = get_args_path(workdir)
assert os.path.exists(args_path), "args.json not found in workdir."
with open(args_path) as f:
cls.args = types.SimpleNamespace(**json.load(f))
cls.args.retro_workdir = workdir # just in case workdir moved
cls.args.rank = 0 # override env
cls.args.world_size = 1 # override env
cls.args.params_dtype = cls.parse_dtype_str(cls.args.params_dtype)
set_global_variables(cls.args)
set_retro_args(cls.args)
_initialize_distributed()
_set_random_seed(cls.args.seed, cls.args.data_parallel_random_init)
_compile_dependencies()
@classmethod
def init(cls, workdir):
'''Initialize Megatron, tokenizers, and datasets.'''
# Load args.
cls.init_megatron(workdir)
cls.tokenizers = types.SimpleNamespace(
gpt=get_gpt_tokenizer(),
bert=get_bert_tokenizer(),
)
# Load data.
cls.db_indexed_dataset_infos = get_db_indexed_dataset_infos()
cls.db_dataset = get_db_dataset()
pt_train_ds, pt_valid_ds, _ = get_retro_datasets(verify_sizes=False)
cls.pt_datasets = types.SimpleNamespace(
train=pt_train_ds,
valid=pt_valid_ds,
)
# Retrieve max saved neighbors.
for key in vars(cls.pt_datasets):
getattr(cls.pt_datasets, key).num_neighbors = \
cls.args.retro_query_num_neighbors_save
# Print usage.
cls.print_usage()
##############################################
# utils.
##############################################
@classmethod
def gpt_to_text(cls, token_ids):
'''GPT tokens to text.'''
return cls.tokenizers.gpt.detokenize(token_ids.tolist()
if isinstance(token_ids, np.ndarray)
else token_ids)
@classmethod
def text_to_bert(cls, text):
'''Text to Bert tokens.'''
return cls.tokenizers.bert.tokenize(text)
##############################################
# chunk db.
##############################################
@classmethod
def get_db_num_indexed_datasets(cls):
'''Number of indexed datasets within blendable dataset.'''
return len(cls.db_indexed_dataset_infos)
@classmethod
def get_db_indexed_dataset_infos(cls):
'''Dataset infos, including number of training & sampled sets.'''
return [(info["ratio"], info["name"])
for info in cls.db_indexed_dataset_infos]
@classmethod
def get_db_dataset(cls):
return cls.db_dataset
@classmethod
def get_db_num_chunks(cls):
'''Number of DB chunks.'''
return len(cls.get_db_dataset())
@classmethod
def get_db_chunk_gpt(cls, idx):
'''Get DB chunk as GPT token ids.'''
return cls.get_db_dataset()[idx]["text"].tolist()
@classmethod
def get_db_chunk_bert(cls, idx):
'''Get DB chunk as Bert token ids.'''
return cls.text_to_bert(cls.get_db_chunk_text(idx))
@classmethod
def get_db_chunk_text(cls, idx):
'''Get DB chunk as text.'''
return cls.gpt_to_text(cls.get_db_chunk_gpt(idx))
@classmethod
def get_db_chunk_and_continuation_text(cls, idx):
'''Get DB chunk along with continuation, as text.'''
# Modulus used here to match the original implementation (i.e., the last
# chunk's continuation wraps around to the first chunk).
return [
cls.get_db_chunk_text(idx),
cls.get_db_chunk_text((idx + 1) % len(cls.get_db_dataset())),
]
##############################################
# pretraining corpus.
##############################################
@classmethod
def get_pt_num_samples_and_chunks(cls, data_key):
'''Number of samples & chunks (e.g., 32*n_samples) in corpus.'''
assert hasattr(cls.pt_datasets, data_key), \
"pretraining set '%s' not found (choices: %s)." % (
data_key, ", ".join(vars(cls.pt_datasets).keys()))
chunk_dataset = getattr(cls.pt_datasets, data_key).chunk_dataset
return (
len(chunk_dataset.sample_dataset),
len(chunk_dataset),
)
@classmethod
def get_pt_num_samples(cls, data_key):
'''Number of pretraining samples.'''
return cls.get_pt_num_samples_and_chunks(data_key)[0]
@classmethod
def get_pt_num_chunks(cls, data_key):
'''Number of pretraining chunks (e.g., 32*n_samples).'''
return cls.get_pt_num_samples_and_chunks(data_key)[1]
@classmethod
def get_pt_dataset(cls, data_key):
return getattr(cls.pt_datasets, data_key)
@classmethod
def get_pt_sample(cls, data_key, idx):
return getattr(cls.pt_datasets, data_key)[idx]
@classmethod
def get_neighbor_tokens(cls, sample_id, chunk_id, data_key="train"):
try:
sample = cls.get_pt_sample(data_key, sample_id)
sample_token_ids = sample["text"]
chunk_length = cls.args.retro_gpt_chunk_length
chunk_start_idx = chunk_id * chunk_length
chunk_end_idx = min(sample_token_ids.shape[0],
chunk_start_idx + chunk_length)
chunk_token_ids = sample_token_ids[chunk_start_idx:chunk_end_idx]
neighbor_token_ids = sample["neighbor_tokens"][chunk_id]
return {
"chunk_tokens" : chunk_token_ids,
"neighbor_tokens" : neighbor_token_ids,
}
except:
return None
@classmethod
def print_neighbor_texts(cls, sample_id, chunk_id, data_key="train"):
tokens = cls.get_neighbor_tokens(sample_id, chunk_id, data_key)
print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
try:
print("PRETRAINING CHUNK:")
print(" - %s" % shorten_str(cls.gpt_to_text(tokens["chunk_tokens"]), 150))
print("NEIGHBOR_CHUNKS:")
for token_ids in tokens["neighbor_tokens"]:
print(" - %s" % shorten_str(cls.gpt_to_text(token_ids), 150))
except:
print("<no neighbors for sample %d>" % sample_id)
##############################################
# usage.
##############################################
@classmethod
def print_usage(cls):
'''Print usage.'''
print()
print("+++++++++++++++++++++++++++++++++++++++++++++++++++")
print("examples ... [ *note*: 'db' = chunk db; 'pt' = pretraining corpus. ]")
print("+++++++++++++++++++++++++++++++++++++++++++++++++++")
print()
print("~~~~ indexed datasets ~~~~")
print("retro.get_db_num_indexed_datasets() : %s" %
cls.get_db_num_indexed_datasets())
print("retro.get_db_indexed_dataset_infos() :")
for i, (ratio,prefix) in enumerate(cls.get_db_indexed_dataset_infos()):
print(" %s(%f, %s)%s" % (
"[" if i == 0 else " ",
ratio,
prefix,
"]" if i == len(cls.db_indexed_dataset_infos) - 1 else ",",
))
print()
print("~~~~ counts ~~~~")
print("retro.get_db_num_chunks : %d." % cls.get_db_num_chunks())
print()
for sq_key in ("sample", "chunk"):
for data_key in ("train", "valid"): # test?
print("retro.get_pt_num_%ss('%s') : %d." % (
sq_key, data_key,
getattr(cls, f"get_pt_num_{sq_key}s")(data_key)))
print()
print("~~~~ tokens, text ~~~~")
print("retro.get_db_chunk_gpt(chunk_id) : %s" %
shorten_str(str(retro.get_db_chunk_gpt(0)), 50))
print("retro.get_db_chunk_bert(chunk_id) : %s" %
shorten_str(str(retro.get_db_chunk_bert(0)), 50))
print("retro.get_db_chunk_text(chunk_id) : %s" %
shorten_str(retro.get_db_chunk_text(0).strip(), 50))
print("retro.get_db_chunk_and_continuation_text(chunk_id) :")
for i, t in enumerate(retro.get_db_chunk_and_continuation_text(0)):
print(" %s'%s'%s" % (
"[" if i == 0 else " ",
shorten_str(t.strip().replace("\n", " "), 50),
"]" if i == 1 else ",",
))
sample = cls.get_pt_sample("train", 0)
sample_chunk_id = sample["neighbor_tokens"].shape[0] // 2
sample_neighbor_id = 0
print()
print("retro.get_pt_sample('train', sample_id) :")
print(" {")
for k, v in sample.items():
print(" '%s' : %s" % (k, shorten_str(str(v), 50)))
print(" }")
print()
print("(e.g., sample = retro.get_pt_sample(...))")
print()
print(" sample['text'].shape : %s" % str(sample["text"].shape))
print(" sample['neighbor_tokens'].shape : %s" % str(sample["neighbor_tokens"].shape))
print(" sample['text'] : %s" % shorten_str(str(sample["text"]), 50))
print(" sample['neighbor_tokens'][17][1] : %s" % shorten_str(str(sample["neighbor_tokens"][sample_chunk_id][sample_neighbor_id]), 50))
print(" retro.gpt_to_text(sample['text']) : %s" % shorten_str(cls.gpt_to_text(sample["text"]), 50))
print(" retro.gpt_to_text(sample['neighbor_tokens']) : %s" % shorten_str(cls.gpt_to_text(sample["neighbor_tokens"][sample_chunk_id][sample_neighbor_id]), 50))
print("+++++++++++++++++++++++++++++++++++++++++++++++++++")
......@@ -24,11 +24,13 @@ from tools.retro.external_libs import h5py
from tools.retro.utils import get_gpt_tokenizer, get_bert_tokenizer
from .utils import (
get_individual_db,
get_indexed_dataset_infos,
get_indexed_dataset_infos_path,
get_individual_db_dir,
get_individual_chunk_db,
get_individual_doc_offsets,
get_merged_dataset,
get_merged_db_path_map,
get_train_doc_chunk_map_dir,
save_indexed_dataset_infos,
)
......@@ -52,7 +54,7 @@ def init_indexed_dataset_infos():
prefix = args.data_path[i + 1]
path = prefix + ".bin"
name = os.path.basename(prefix)
assert os.path.exists(path)
assert os.path.exists(path), "couldn't find '%s'." % path
infos.append({
"ratio" : ratio,
"prefix" : prefix,
......@@ -114,6 +116,7 @@ def build_partial_db(
# Iterate documents & parse chunks.
chunk_db_valid = []
chunk_db_invalid = []
doc_size_map = {}
for doc_id in pbar:
# Progress description.
......@@ -130,7 +133,7 @@ def build_partial_db(
# Remove EOD token.
doc = indexed_dataset.get(doc_id)
if doc[-1].item() == tokenizers.gpt.eod_id:
if doc[-1].item() == tokenizers.gpt.eod:
doc = doc[:-1]
doc_len = len(doc)
......@@ -140,6 +143,7 @@ def build_partial_db(
for s in chunk_start_idxs]
# Re-tokenize each chunk to Bert/Wordpiece (empty bert -> 'invalid').
doc_size_map[doc_id] = 0
for i, chunk_start_idx in enumerate(chunk_start_idxs):
# Re-tokenize.
......@@ -149,13 +153,15 @@ def build_partial_db(
offset=chunk_start_idx,
length=chunk_end_idx - chunk_start_idx,
)
text = tokenizers.gpt.detokenize(gpt_token_ids)
text = tokenizers.gpt.detokenize(gpt_token_ids.tolist())
bert_token_ids = tokenizers.bert.tokenize(text)
# 'Valid' for non-empty Bert chunks; 'invalid' otherwise.
_chunk_db = chunk_db_invalid \
if len(bert_token_ids) == 0 else \
chunk_db_valid
if len(bert_token_ids) == 0:
_chunk_db = chunk_db_invalid
else:
_chunk_db = chunk_db_valid
doc_size_map[doc_id] += 1
_chunk_db.append((
doc_id,
chunk_start_idx,
......@@ -163,7 +169,7 @@ def build_partial_db(
len(bert_token_ids),
))
return proc_id, chunk_db_valid, chunk_db_invalid
return proc_id, chunk_db_valid, chunk_db_invalid, doc_size_map
def build_individual_db(dataset_idx, n_datasets, dataset_info, tokenizers):
......@@ -181,9 +187,10 @@ def build_individual_db(dataset_idx, n_datasets, dataset_info, tokenizers):
# Missing db blocks.
n_missing_world, missing_db_blocks = get_missing_blocks_by_rank(
db_dir,
len(indexed_dataset.doc_idx) - 1,
len(indexed_dataset),
args.retro_doc_block_size,
validate=lambda f : f["chunks_valid"].shape[1] == 4)
validate=lambda f : f["chunks_valid"].shape == (0,) \
or f["chunks_valid"].shape[1] == 4)
# Prevent missing-path-write race condition.
torch.distributed.barrier()
......@@ -209,6 +216,8 @@ def build_individual_db(dataset_idx, n_datasets, dataset_info, tokenizers):
if block is not None:
db_path = block["path"]
# Build partial dbs.
print_rank_0(' > build partial dbs.')
futures = []
......@@ -240,15 +249,27 @@ def build_individual_db(dataset_idx, n_datasets, dataset_info, tokenizers):
# Convert to numpy.
print_rank_0(' > converting chunk db to numpy.')
chunk_db_valid = np.array(chunk_db_valid)
chunk_db_invalid = np.array(chunk_db_invalid)
chunk_db_valid = np.array(chunk_db_valid, dtype="uint32")
chunk_db_invalid = np.array(chunk_db_invalid, dtype="uint32")
# Document offsets.
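# (Each row of the resulting 'doc_offsets' array is (doc_id, cumulative count
# of valid chunks through that document); this 2-column layout is what
# get_individual_doc_offsets() reads back from the saved HDF5 files.)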
doc_sizes = [(d, s)
for partial_chunk_db in partial_chunk_dbs
for d, s in partial_chunk_db[3].items()]
doc_sizes.sort(key = lambda item : item[0])
doc_offsets = np.cumsum([item[1] for item in doc_sizes]) \
.astype("uint64")
doc_offsets = np.stack((
np.array([item[0] for item in doc_sizes], dtype="uint64"),
doc_offsets), axis=1)
# Save DB.
print_rank_0(" > saving individual db.")
f = h5py.File(block["path"], "w")
dset = f.create_dataset("chunks_valid", data=chunk_db_valid)
dset = f.create_dataset("chunks_invalid", data=chunk_db_invalid)
f.close()
with h5py.File(db_path, "w") as f:
dset = f.create_dataset("chunks_valid", data=chunk_db_valid)
dset = f.create_dataset("chunks_invalid",
data=chunk_db_invalid)
dset = f.create_dataset("doc_offsets", data=doc_offsets)
# Wait for all ranks to finish block.
print_rank_0(" > waiting for all ranks to finish block.")
......@@ -292,14 +313,16 @@ def update_chunk_counts(indexed_dataset_infos):
if torch.distributed.get_rank() != 0:
return
# Data ratio sum (for setting index training chunks).
data_ratio_sum = sum([ d["ratio"] for d in indexed_dataset_infos ])
# Training split size (split at document level).
train_fraction = float(args.split.split(",")[0]) / 100
assert train_fraction > 0 and train_fraction <= 1
# Set n_chunks (including n_chunks_sampled for unambiguity).
print_rank_0(" > compute n_chunks.")
for ds_index, ds_info in \
enumerate(tqdm(indexed_dataset_infos, "count_chunks")):
for ds_index, ds_info in enumerate(indexed_dataset_infos):
db_dir = ds_info["db_dir"]
db_paths = sorted(glob.glob(db_dir + "/*.hdf5"))
......@@ -310,16 +333,17 @@ def update_chunk_counts(indexed_dataset_infos):
ds_info["n_chunks"] = 0 # previously, 'n_chunks_valid'
ds_info["n_chunks_train"] = 0
ds_info["n_chunks_invalid"] = 0
for db_path in db_paths:
with h5py.File(db_path, "r") as f:
for db_path in tqdm(db_paths, "%d/%d, %s" % (
ds_index, len(indexed_dataset_infos), ds_info["name"])):
with h5py.File(db_path, "r") as f:
ds_info["n_chunks"] += len(f["chunks_valid"])
ds_info["n_chunks_invalid"] += len(f["chunks_invalid"])
ds_info["n_chunks_train"] += \
(np.copy(f["chunks_valid"][:, 0]) < ds_info["n_docs_train"]) \
.sum().item()
ds_info["n_chunks_sampled"] = \
int(round(args.retro_nchunks_sampled * ds_info["ratio"]))
ds_info["n_chunks_sampled"] = int(args.retro_index_ntrain *
ds_info["ratio"] / data_ratio_sum)
# Verify counts.
assert ds_info["n_chunks_train"] <= ds_info["n_chunks"], \
......@@ -339,15 +363,14 @@ def merge_dbs(indexed_dataset_infos, db_type):
print(" > build %s chunk db." % db_type)
# Count chunks.
if db_type == "full":
raise Exception("deprecated; use 'train' or 'sampled'.")
n_chunks_key = "n_chunks"
elif db_type == "sampled":
if db_type == "sampled":
n_chunks_key = "n_chunks_sampled"
n_docs_key = None
elif db_type == "train":
n_chunks_key = "n_chunks_train"
n_docs_key = "n_docs_train"
elif db_type == "valid":
pass
n_docs_key = None
else:
raise Exception("handle db_type '%s'." % db_type)
......@@ -356,6 +379,8 @@ def merge_dbs(indexed_dataset_infos, db_type):
for m in indexed_dataset_infos)
else:
n_chunks = sum(m[n_chunks_key] for m in indexed_dataset_infos)
n_docs = None if n_docs_key is None else \
sum(m[n_docs_key] for m in indexed_dataset_infos)
# DB path.
db_path = get_merged_db_path_map()[db_type]
......@@ -375,10 +400,10 @@ def merge_dbs(indexed_dataset_infos, db_type):
except Exception as e:
if isinstance(e, OSError):
os.remove(full_db_path)
os.remove(db_path)
elif isinstance(e, KeyError):
f.close()
os.remove(full_db_path)
os.remove(db_path)
else:
raise e
......@@ -389,121 +414,60 @@ def merge_dbs(indexed_dataset_infos, db_type):
f = h5py.File(db_path, "w")
# Initialize output arrays.
merged_db = f.create_dataset("chunks", (n_chunks, 5), dtype="i8")
merged_chunk_db = \
f.create_dataset("chunks", (n_chunks, 5), dtype="uint32")
merged_doc_offsets = None if n_docs_key is None else \
f.create_dataset("doc_offsets", (n_docs, 3), dtype="uint64")
n_written = f.create_dataset("n_written", (1,), dtype="uint64")
n_written[0] = 0
# Iterate indexed datasets & collect chunks.
start_index = 0
chunk_start_index = 0
doc_start_index = 0
doc_start_offset = 0
for ds_idx, ds_info in enumerate(indexed_dataset_infos):
print(" > merging dbs; '%s', dataset %d / %d ... '%s'." %
(db_type, ds_idx, len(indexed_dataset_infos), ds_info["name"]))
individual_db = get_individual_db(ds_idx, ds_info)
individual_chunk_db = get_individual_chunk_db(ds_idx, ds_info)
individual_doc_offsets = None if n_docs_key is None else \
get_individual_doc_offsets(ds_idx, ds_info)
if db_type == "valid":
individual_db = individual_db[ds_info["n_chunks_train"]:]
individual_chunk_db = \
individual_chunk_db[ds_info["n_chunks_train"]:]
if n_docs_key is None:
individual_doc_offsets = None
else:
train_doc_offset = \
individual_doc_offsets[ds_info["n_docs_train"] - 1, 2]
individual_doc_offsets = \
np.copy(individual_doc_offsets[ds_info["n_docs_train"]:])
individual_doc_offsets[:, 2] -= train_doc_offset
print("~~~")
print(individual_doc_offsets)
print(train_doc_offset)
raise Exception("test me.")
else:
individual_db = individual_db[:ds_info[n_chunks_key]]
merged_db[start_index:start_index+len(individual_db)] = individual_db
start_index += len(individual_db)
n_written[0] = start_index
individual_chunk_db = \
individual_chunk_db[:ds_info[n_chunks_key]]
individual_doc_offsets = None if n_docs_key is None else \
np.copy(individual_doc_offsets[:ds_info[n_docs_key]])
merged_chunk_db[chunk_start_index:chunk_start_index+len(individual_chunk_db)] = individual_chunk_db
chunk_start_index += len(individual_chunk_db)
n_written[0] = chunk_start_index
if n_docs_key is not None:
individual_doc_offsets[:, 2] += doc_start_offset
doc_end_index = doc_start_index + individual_doc_offsets.shape[0]
merged_doc_offsets[doc_start_index:doc_end_index] = \
individual_doc_offsets
doc_start_index = doc_end_index
doc_start_offset = individual_doc_offsets[-1, 2].item()
f.close()
def get_partial_banned_chunk_map(proc_id, db_path, chunk_range_info):
'''Build partial mapping of {(dataset_id,doc_id):[chunk_ids]}.
In this method, only chunks within the range (start_chunk_id, end_chunk_id]
are processed.'''
start_chunk_id = chunk_range_info["start"]
end_chunk_id = chunk_range_info["end"]
output_path = chunk_range_info["path"]
# Skip, if output file exists.
if os.path.exists(output_path):
return
# Chunk subset.
with h5py.File(db_path) as f:
sub_chunk_db = np.copy(f["chunks"][start_chunk_id:end_chunk_id, :2])
# Map docs to chunks.
banned_chunk_map = defaultdict(list)
for rel_chunk_id, (dataset_id, doc_id) in enumerate(tqdm(
sub_chunk_db,
"map banned docs, proc %d" % proc_id,
total=sub_chunk_db.shape[0],
)):
chunk_id = start_chunk_id + rel_chunk_id
banned_chunk_map["%d,%d" % (dataset_id.item(), doc_id.item())] \
.append(chunk_id)
# Save output.
with open(output_path, "w") as f:
json.dump(banned_chunk_map, f)
def build_doc_chunk_map(indexed_dataset_infos, db_type):
'''Build mapping of {(dataset_id,doc_id):[chunk_ids]}.'''
if torch.distributed.get_rank() != 0:
return
print(" > build %s doc-chunk map." % db_type)
n_procs = 128
# Get dataset.
db_dataset = get_merged_dataset(db_type, indexed_dataset_infos)
# Sub-ranges for parallel processing.
n_chunks = db_dataset.chunks.shape[0]
n_chunks_per_proc = max(1, int(np.ceil(n_chunks / n_procs)))
chunk_id_starts = list(range(0, n_chunks, n_chunks_per_proc))
chunk_id_ranges = [(s, min(n_chunks, s + n_chunks_per_proc))
for s in chunk_id_starts]
# Wrap range info with output path.
n_digits = int(np.ceil(np.log(n_chunks) / np.log(10)) + 1)
output_dirname = get_train_doc_chunk_map_dir()
chunk_range_infos = [{
"start" : start_id,
"end" : end_id,
"path" : os.path.join(output_dirname, "%s-%s.json" % (
str(start_id).zfill(n_digits),
str(end_id).zfill(n_digits),
)),
} for start_id, end_id in chunk_id_ranges ]
# Build doc-chunk map.
print_rank_0("build doc-chunk-map.")
with ProcessPoolExecutor(max_workers=n_procs) as executor:
# Build partial chunk maps.
futures = []
for proc_id, chunk_range_info in enumerate(chunk_range_infos):
if os.path.exists(chunk_range_info["path"]):
continue
# Submit job.
futures.append(executor.submit(
get_partial_banned_chunk_map,
proc_id,
db_dataset.db_path,
chunk_range_info,
))
# Wait for processes to finish.
banned_chunk_paths = []
for finished_idx, future in enumerate(as_completed(futures)):
print("finished %d / %d." % (finished_idx, n_procs))
future.result()
def build_db():
'''Extract token chunks from each indexed dataset.
......@@ -521,14 +485,13 @@ def build_db():
if torch.distributed.get_rank() != 0:
return
# Update n_chunks.
update_chunk_counts(indexed_dataset_infos)
# Update n_chunks & save indexed dataset infos.
if not os.path.exists(get_indexed_dataset_infos_path()):
update_chunk_counts(indexed_dataset_infos)
save_indexed_dataset_infos(indexed_dataset_infos)
indexed_dataset_infos = get_indexed_dataset_infos()
# Merge dbs.
merge_dbs(indexed_dataset_infos, "sampled")
merge_dbs(indexed_dataset_infos, "train")
merge_dbs(indexed_dataset_infos, "valid")
build_doc_chunk_map(indexed_dataset_infos, "train")
# Save (fully annotated) indexed dataset infos.
save_indexed_dataset_infos(indexed_dataset_infos)
......@@ -3,6 +3,7 @@
import json
import numpy as np
import torch
from tqdm import tqdm
from megatron import get_args, print_rank_0
from tools.retro.external_libs import h5py
......@@ -27,9 +28,10 @@ class DBDataset(torch.utils.data.Dataset):
self.db_path = db_path
self.indexed_datasets = indexed_datasets
self.chunks = chunks
self.doc_chunk_map = None
self.max_chunk_length = max_chunk_length
self.eod_token_id = get_gpt_tokenizer().eod_id
self.eod_token_id = get_gpt_tokenizer().eod
def __len__(self):
return self.chunks.shape[0]
......@@ -58,3 +60,15 @@ class DBDataset(torch.utils.data.Dataset):
"doc_id" : doc_id,
"text" : np.array(token_ids, dtype=np.int64),
}
def load_doc_tuples(self):
'''Load the dataset & document ids.
Load the dataset id & document id of each chunk in the database, to
be used for causality filtering during querying.
'''
self.doc_tuples = np.zeros(shape=(len(self), 2), dtype="uint32")
block_size = int(1e6)
for start_idx in tqdm(range(0, len(self), block_size)):
end_idx = min(len(self), start_idx + block_size)
self.doc_tuples[start_idx:end_idx]=self.chunks[start_idx:end_idx,:2]
......@@ -57,14 +57,14 @@ def get_indexed_dataset_infos():
def get_individual_db_dir(name):
'''Individual DB's directory.'''
return os.path.join(get_base_db_workdir(), "individual", name, "db")
return os.path.join(get_base_db_workdir(), "individual", name)
def get_individual_db(ds_id, ds_info):
def get_individual_chunk_db(ds_id, ds_info):
'''Load individual dataset's chunk DB.'''
db_paths = sorted(glob.glob(ds_info["db_dir"] + "/*hdf5"))
# *Note*: convert to dataset, rather than copying to memory.
db = np.zeros((ds_info["n_chunks"], 5), dtype="i8")
db = np.zeros((ds_info["n_chunks"], 5), dtype="uint32")
db[:, 0] = ds_id
start_idx = 0
for db_path in db_paths:
......@@ -79,6 +79,27 @@ def get_individual_db(ds_id, ds_info):
return db
def get_individual_doc_offsets(ds_id, ds_info):
'''Load individual dataset's document offsets.'''
paths = sorted(glob.glob(ds_info["db_dir"] + "/*hdf5"))
# *Note*: convert to dataset, rather than copying to memory.
doc_offsets = np.zeros((ds_info["n_docs"], 3), dtype="uint64")
doc_offsets[:, 0] = ds_id
start_idx = 0
start_offset = 0
for path in paths:
with h5py.File(path) as f:
current_doc_offsets = np.copy(f["doc_offsets"])
current_doc_offsets[:, 1] += start_offset
current_ndocs = current_doc_offsets.shape[0]
doc_offsets[start_idx:(start_idx+current_ndocs), 1:] = \
current_doc_offsets
start_idx += current_ndocs
start_offset = current_doc_offsets[-1, 1].item()
return doc_offsets
def get_merged_db_path_map():
'''Paths to merged datasets.'''
base_dir = get_base_db_workdir()
......@@ -120,28 +141,3 @@ def get_merged_train_dataset(indexed_dataset_infos=None):
def get_merged_valid_dataset(indexed_dataset_infos=None):
return get_merged_dataset("valid", indexed_dataset_infos)
def get_train_doc_chunk_map_dir():
dirname = os.path.join(get_base_db_workdir(), "merged", "train_doc_chunk_map")
os.makedirs(dirname, exist_ok=True)
return dirname
def get_train_doc_chunk_map():
paths = sorted(glob.glob(get_train_doc_chunk_map_dir() + "/*.json"))
doc_map = defaultdict(set)
for path in tqdm(paths, "load train doc maps"):
# Read file.
with open(path) as f:
crnt_doc_map = json.load(f)
# Add to doc map.
for key, chunk_ids in crnt_doc_map.items():
key = tuple(int(i) for i in key.split(","))
doc_map[key].update(chunk_ids)
return doc_map
#!/bin/bash
# Small English Wikipedia dataset (~2M chunks).
get_wiki_tiny_config() {
RETRO_INDEX_STR="IVF4096_HNSW4,Flat"
RETRO_GPT_TRAIN_SAMPLES=31250
LR_DECAY_SAMPLES=2
LR_WARMUP_SAMPLES=1
RETRO_GPT_EVAL_INTERVAL=2000
RETRO_GPT_EVAL_ITERS=100
RETRO_EF_SEARCH=4
RETRO_NPROBE=64
DATALOADER_TYPE=cyclic
}
# English Wikipedia dataset (~67M chunks).
get_wiki_config() {
RETRO_INDEX_STR="IVF262144_HNSW32,Flat"
RETRO_GPT_TRAIN_SAMPLES=2037248
LR_DECAY_SAMPLES=2
LR_WARMUP_SAMPLES=1
RETRO_GPT_EVAL_INTERVAL=2000
RETRO_GPT_EVAL_ITERS=100
RETRO_EF_SEARCH=16
RETRO_NPROBE=4096
DATALOADER_TYPE=cyclic
}
# Full corpus (~5B chunks).
get_corpus_config() {
RETRO_INDEX_STR="OPQ32_256,IVF4194304_HNSW32,PQ32"
RETRO_GPT_TRAIN_SAMPLES=192000000
LR_DECAY_SAMPLES=166400000
LR_WARMUP_SAMPLES=162761
RETRO_GPT_EVAL_INTERVAL=2000
RETRO_GPT_EVAL_ITERS=50
RETRO_EF_SEARCH=32
RETRO_NPROBE=4096
DATALOADER_TYPE=single
}
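# As a hedged sketch, a config for a custom dataset would follow the same
# pattern; the function name and values below are illustrative placeholders,
# not tuned settings.
get_custom_corpus_config() {
    RETRO_INDEX_STR="IVF262144_HNSW32,Flat"
    RETRO_GPT_TRAIN_SAMPLES=2000000
    LR_DECAY_SAMPLES=2
    LR_WARMUP_SAMPLES=1
    RETRO_GPT_EVAL_INTERVAL=2000
    RETRO_GPT_EVAL_ITERS=100
    RETRO_EF_SEARCH=16
    RETRO_NPROBE=4096
    DATALOADER_TYPE=cyclic
}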
#!/bin/bash
# Build preprocessing command for Retro.
set -u
DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
################ Required environment variables. ################
# Required environment variables:
# - REPO_DIR : Root directory of Megatron codebase.
# - RETRO_WORKDIR : Root directory of this Retro project's processed data. (For
# example, this project directory might be for a blended dataset, while
# another project directory might be for just a Wikipedia dataset, and
# another for just Book Corpus data, etc.) This project directory will
# contain a complete set of processed data, including the retrieval
# database, search index, and pretraining neighbors.
# - RETRO_TASKS : One of 'build', 'db-build', 'index-build', or
# 'pretraining-query-neighbors'. See 'Retro tasks' below for task
# descriptions.
# - DATA_BLEND_SCRIPT : Path to blended dataset definition file.
# - GPT_VOCAB_FILE : GPT vocab file.
# - GPT_MERGE_FILE : GPT merge file.
# - GPT_TOKENIZER : GPT tokenizer type (e.g., GPT2BPETokenizer)
# - BERT_LOAD_PATH : Bert checkpoint directory.
# - BERT_VOCAB_FILE : Bert vocab file.
# - BERT_TOKENIZER : Bert tokenizer type (e.g., BertWordPieceLowerCase,
# BertWordPieceCase).
# - BERT_EMBEDDER_TYPE : One of 'megatron' or 'huggingface'.
# - EXTRA_ARGS : Extra arguments (else, leave empty).
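# A hedged example of setting these variables (all paths below are
# placeholders; the tokenizer/embedder values are the examples named above):
#
#   export REPO_DIR="/path/to/megatron-lm"
#   export RETRO_WORKDIR="/path/to/retro/workdir"
#   export RETRO_TASKS="build"
#   export DATA_BLEND_SCRIPT="/path/to/data_blend.sh"
#   export GPT_VOCAB_FILE="/path/to/gpt2-vocab.json"
#   export GPT_MERGE_FILE="/path/to/gpt2-merges.txt"
#   export GPT_TOKENIZER="GPT2BPETokenizer"
#   export BERT_LOAD_PATH="/path/to/bert/checkpoint"
#   export BERT_VOCAB_FILE="/path/to/bert-vocab.txt"
#   export BERT_TOKENIZER="BertWordPieceLowerCase"
#   export BERT_EMBEDDER_TYPE="megatron"
#   export EXTRA_ARGS=""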
################ Data blend. ################
. ${DATA_BLEND_SCRIPT}
DATA_PATH=${DATA_BLEND}
################ Retro setup. ################
RETRO_GPT_SEQ_LENGTH=2048
RETRO_GPT_CHUNK_LENGTH=64
RETRO_GPT_MICRO_BATCH_SIZE=1 # *8
RETRO_GPT_GLOBAL_BATCH_SIZE=256
RETRO_NCHUNKS_SAMPLED=300000000
################ Retro tasks. ################
# The '--retro-tasks' argument is a comma-separated list of tasks to run, in
# sequential order. For a quick start, simply set this to 'build' to run the
# entire preprocessing pipeline. For finer control, you may specify the list of
# tasks to run. This is desirable for tuning computational resources. For
# example, training the search index is relatively fast and utilizes GPUs,
# while querying the search index is relatively slow, CPU-only, and memory
# intensive (i.e., multiple populated search indexes are loaded simultaneously).
# *Note* : Once the task(s) below have been completed -- by running either
# 1) 'build', or 2) the sequential combination of 'db-build', 'index-build',
# and 'pretraining-query-neighbors' -- we are ready to pretrain Retro by
# calling pretrain_retro.py.
# ---- Option #1 : Run entire pipeline. ----
# RETRO_TASKS="build" # (*note*: default tasks)
# ---- Option #2 : Run specific stages. ----
# *Note*: Run the following stages in the given order. Optionally, tune your
# cluster setup for each stage, as described above.
# RETRO_TASKS="db-build" # ....................... run 1st
# RETRO_TASKS="index-build" # .................... run 2nd
# RETRO_TASKS="pretraining-query-neighbors" # .... run 3rd
################ Megatron args. ################
MEGATRON_ARGS=" \
--seed 1234 \
--distributed-timeout-minutes 600 \
--tokenizer-type ${BERT_TOKENIZER} \
--tensor-model-parallel-size 1 \
--pipeline-model-parallel-size 1 \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
--micro-batch-size ${RETRO_GPT_MICRO_BATCH_SIZE} \
--global-batch-size ${RETRO_GPT_GLOBAL_BATCH_SIZE} \
--seq-length 512 \
--max-position-embeddings 512 \
--train-samples ${RETRO_GPT_TRAIN_SAMPLES} \
--load ${BERT_LOAD_PATH} \
--exit-on-missing-checkpoint \
--no-load-optim \
--data-path ${DATA_PATH} \
--vocab-file ${BERT_VOCAB_FILE} \
--data-impl mmap \
--split 98,2,0 \
--distributed-backend nccl \
--lr 0.0001 \
--lr-decay-style linear \
--min-lr 1.0e-5 \
--lr-decay-samples ${LR_DECAY_SAMPLES} \
--lr-warmup-samples ${LR_WARMUP_SAMPLES} \
--weight-decay 1e-2 \
--clip-grad 1.0 \
--eval-interval ${RETRO_GPT_EVAL_INTERVAL} \
--eval-iters ${RETRO_GPT_EVAL_ITERS} \
--fp16 \
--DDP-impl local \
--dataloader-type ${DATALOADER_TYPE} \
--no-data-sharding \
--no-gradient-accumulation-fusion \
--no-async-tensor-model-parallel-allreduce \
"
################ Retro args. ################
RETRO_ARGS=" \
--bert-embedder-type ${BERT_EMBEDDER_TYPE} \
--output-bert-embeddings \
\
--retro-gpt-vocab-file ${GPT_VOCAB_FILE} \
--retro-gpt-merge-file ${GPT_MERGE_FILE} \
--retro-gpt-tokenizer-type ${GPT_TOKENIZER} \
--retro-gpt-seq-length ${RETRO_GPT_SEQ_LENGTH} \
--retro-gpt-chunk-length ${RETRO_GPT_CHUNK_LENGTH} \
--retro-bert-vocab-file ${BERT_VOCAB_FILE} \
--retro-bert-tokenizer-type ${BERT_TOKENIZER} \
\
--retro-tasks ${RETRO_TASKS} \
--retro-index-str ${RETRO_INDEX_STR} \
--retro-ef-search ${RETRO_EF_SEARCH} \
--retro-nprobe ${RETRO_NPROBE} \
\
--retro-workdir ${RETRO_WORKDIR} \
--retro-nchunks-sampled ${RETRO_NCHUNKS_SAMPLED} \
\
--retro-return-doc-ids \
"
################ Command. ################
RETRO_PREPROCESS_CMD=" \
./tools/retro/main.py \
${MEGATRON_ARGS} \
${RETRO_ARGS} \
${EXTRA_ARGS} \
"
#!/bin/bash
set -u
unset NCCL_DEBUG
NPROCS=8 # NPROCS must be <= number of GPUs.
######## Megatron, Retro dirs. ########
set_current_dir() {
DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
}
REPO_DIR="<path/to/megatron/repo>"
RETRO_WORKDIR="<path/to/retro/data/directory>"
################ Dataset configs. ################
# This script contains methods to customize arguments to specific dataset
# types. Customize this script as needed for your datasets.
set_current_dir
. $DIR/get_dataset_configs.sh
######## Task (e.g., db, index, query). ########
################ Environment variables. ################
# *Note*: See 'Required environment variables' in 'get_preprocess_cmd.sh' for
# a description of the required environment variables. These variables can be
# set however a user would like. In our setup, we use another bash script
# (location defined by $RETRO_ENV_VARS) that sets all the environment variables
# at once.
. $RETRO_ENV_VARS
RETRO_TASKS="db-build"
# RETRO_TASKS="index-train"
# RETRO_TASKS="index-add"
# RETRO_TASKS="query-pretraining-neighbors"
######## Environment vars. ########
set_current_dir
. ${DIR}/get_preprocess_cmd.sh
######## Data. ########
echo "~~~~~~~~~~~~~~~~~~~~~~~~~~"
echo "DIR = '$DIR'."
echo "RETRO_PREPROCESS_CMD = '$RETRO_PREPROCESS_CMD'."
echo "~~~~~~~~~~~~~~~~~~~~~~~~~~"
DATA_BLEND="<see --data-path in arguments.py>"
######## Index. ########
RETRO_INDEX_STR="OPQ32_64,IVF65536_HNSW8,PQ32"
RETRO_INDEX_NTRAIN=1000000
RETRO_INDEX_TRAIN_LOAD_FRACTION=0.97
RETRO_INDEX_ADD_LOAD_FRACTION=0.95
######## GPT. ########
RETRO_GPT_SEED=1234
RETRO_GPT_SPLIT="98,2,0"
RETRO_GPT_DATA_PATH=${DATA_BLEND}
RETRO_GPT_DATA_IMPL=mmap
RETRO_GPT_DATALOADER_TYPE=single
RETRO_GPT_EVAL_INTERVAL=2000
RETRO_GPT_EVAL_ITERS=50
RETRO_GPT_TRAIN_SAMPLES=200000
RETRO_GPT_LR_DECAY_SAMPLES=175000
RETRO_GPT_LR_WARMUP_SAMPLES=10000
RETRO_GPT_SEQ_LENGTH=512
RETRO_GPT_GLOBAL_BATCH_SIZE=256
RETRO_GPT_CHUNK_LENGTH=64
######## Query. ########
RETRO_QUERY_NUM_NEIGHBORS_QUERY=200
RETRO_QUERY_NUM_NEIGHBORS_SAVE=20
RETRO_QUERY_EF_SEARCH=32
RETRO_QUERY_NPROBE=4096
######## Args. ########
ARGS=" \
--distributed-timeout-minutes 600 \
--tensor-model-parallel-size 1 \
--pipeline-model-parallel-size 1 \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
--micro-batch-size 1 \
--global-batch-size ${RETRO_GPT_GLOBAL_BATCH_SIZE} \
--seq-length 512 \
--max-position-embeddings 512 \
--load <path/to/bert/checkpoint> \
--exit-on-missing-checkpoint \
--no-load-optim \
--data-path ${RETRO_GPT_DATA_PATH} \
--tokenizer-type BertWordPieceLowerCase \
--vocab-file <path/to/bert/vocab> \
--data-impl ${RETRO_GPT_DATA_IMPL} \
--split ${RETRO_GPT_SPLIT} \
--distributed-backend nccl \
--lr 0.0001 \
--lr-decay-style linear \
--min-lr 1.0e-5 \
--train-samples ${RETRO_GPT_TRAIN_SAMPLES} \
--lr-decay-samples ${RETRO_GPT_LR_DECAY_SAMPLES} \
--lr-warmup-samples ${RETRO_GPT_LR_WARMUP_SAMPLES} \
--weight-decay 1e-2 \
--clip-grad 1.0 \
--eval-interval ${RETRO_GPT_EVAL_INTERVAL} \
--eval-iters ${RETRO_GPT_EVAL_ITERS} \
--fp16 \
--DDP-impl local \
--dataloader-type ${RETRO_GPT_DATALOADER_TYPE} \
--no-data-sharding \
--no-gradient-accumulation-fusion \
--no-async-tensor-model-parallel-allreduce \
--bert-embedder-type megatron \
--output-bert-embeddings \
\
--retro-workdir ${RETRO_WORKDIR} \
--retro-tasks ${RETRO_TASKS} \
--retro-return-doc-ids \
--retro-bert-vocab-file <path/to/bert/vocab> \
--retro-bert-tokenizer-type BertWordPieceLowerCase \
--retro-gpt-seed ${RETRO_GPT_SEED} \
--retro-gpt-tokenizer-type GPTSentencePieceTokenizer \
--retro-gpt-tokenizer-model <path/to/gpt/tokenizer/model> \
--retro-gpt-seq-length ${RETRO_GPT_SEQ_LENGTH} \
--retro-gpt-chunk-length ${RETRO_GPT_CHUNK_LENGTH} \
--retro-gpt-global-batch-size ${RETRO_GPT_GLOBAL_BATCH_SIZE} \
--retro-gpt-eval-interval ${RETRO_GPT_EVAL_INTERVAL} \
--retro-gpt-eval-iters ${RETRO_GPT_EVAL_ITERS} \
--retro-gpt-split ${RETRO_GPT_SPLIT} \
--retro-gpt-data-impl ${RETRO_GPT_DATA_IMPL} \
--retro-gpt-data-path ${RETRO_GPT_DATA_PATH} \
--retro-index-str ${RETRO_INDEX_STR} \
--retro-index-ntrain ${RETRO_INDEX_NTRAIN} \
--retro-index-train-load-fraction ${RETRO_INDEX_TRAIN_LOAD_FRACTION} \
--retro-index-add-load-fraction ${RETRO_INDEX_ADD_LOAD_FRACTION} \
--retro-index-no-delete-training-embeddings \
--retro-index-no-delete-added-codes \
--retro-query-num-neighbors-query ${RETRO_QUERY_NUM_NEIGHBORS_QUERY} \
--retro-query-num-neighbors-save ${RETRO_QUERY_NUM_NEIGHBORS_SAVE} \
--retro-query-ef-search ${RETRO_QUERY_EF_SEARCH} \
--retro-query-nprobe ${RETRO_QUERY_NPROBE} \
"
######## Command. ########
FULL_CMD="\
pwd && cd ${REPO_DIR} && pwd && \
NPROCS=8 # Number of GPUs.
CMD="\
cd ${REPO_DIR} && pwd && \
export PYTHONPATH=$PYTHONPATH:${REPO_DIR} && \
python -m torch.distributed.launch \
python -m torch.distributed.run \
--nproc_per_node ${NPROCS} \
--nnodes 1 \
--node_rank ${NODE_RANK} \
--master_addr ${MASTER_ADDR} \
--master_port 6000 \
$RETRO_PREPROCESS_CMD \
tools/retro/main.py ${ARGS} \
"
echo "~~~~~~~~~~~~~~~~~~~~~~~~~~"
echo "FULL_CMD = '$FULL_CMD'."
echo "CMD = '$CMD'."
echo "~~~~~~~~~~~~~~~~~~~~~~~~~~"
eval $FULL_CMD
eval $CMD
#!/bin/bash
##################################################
# Example script for pretraining Retro.
##################################################
set -u
unset NCCL_DEBUG
export CUDA_DEVICE_MAX_CONNECTIONS=1
NPROCS=8 # NPROCS must be <= number of GPUs.
######## GPT or Retro? ########
# 0 : GPT.
# 1 : Retro.
ADD_RETRIEVER=1
################ Dataset configs. ################
# This script contains methods to customize arguments to specific dataset
# types. Customize this script as needed for your datasets.
DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
. $DIR/get_dataset_configs.sh
######## Megatron, Retro dirs. ########
################ Environment variables. ################
# *Note*: See 'Required environment variables' in 'get_preprocess_cmd.sh' for
# a description of the required environment variables. These variables can be
# set however a user would like. In our setup, we use another bash script
# (location defined by $RETRO_ENV_VARS) that sets all the environment variables
# at once.
. $RETRO_ENV_VARS
REPO_DIR="<path/to/megatron/repo>"
RETRO_WORKDIR="<path/to/retro/data/directory>"
################ Data blend. ################
. ${DATA_BLEND_SCRIPT}
DATA_PATH=${DATA_BLEND}
######## Data. ########
######## Retro setup. ########
RETRO_ADD_RETRIEVER=1
RETRO_CYCLIC_TRAIN_ITERS=750000
RETRO_NUM_NEIGHBORS=2
DATA_BLEND="<see --data-path in arguments.py>"
######## Args. ########
######## Arguments. ########
CHECKPOINT_DIR=${RETRO_WORKDIR}/checkpoints/${RETRO_ADD_RETRIEVER}
TENSORBOARD_DIR="${CHECKPOINT_DIR}/tensorboard"
mkdir -p ${TENSORBOARD_DIR}
ARGS=" \
--save-interval 1000 \
--save ${CHECKPOINT_DIR} \
--load ${CHECKPOINT_DIR} \
--tensorboard-dir ${TENSORBOARD_DIR} \
--log-interval 5 \
--log-interval 1 \
--use-flash-attn \
--apply-layernorm-1p \
--untie-embeddings-and-output-weights \
--disable-bias-linear \
--no-position-embedding \
--use-rotary-position-embeddings \
--rotary-percent 0.5 \
--swiglu \
--attention-dropout 0.0 \
--hidden-dropout 0.0 \
--exit-duration-in-mins 220 \
--tensor-model-parallel-size 1 \
--pipeline-model-parallel-size 1 \
--num-layers 12 \
--hidden-size 768 \
--num-attention-heads 12 \
--seq-length 2048 \
--max-position-embeddings 2048 \
--micro-batch-size 4 \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
--seq-length 512 \
--max-position-embeddings 512 \
--micro-batch-size 16 \
--global-batch-size 256 \
--train-samples ${RETRO_GPT_TRAIN_SAMPLES} \
--lr-decay-samples ${LR_DECAY_SAMPLES} \
--lr-warmup-samples ${LR_WARMUP_SAMPLES} \
--lr 6.0e-4 \
--min-lr 6.0e-5 \
--train-samples 200000 \
--lr-decay-samples 175000 \
--lr-warmup-samples 10000 \
--lr 2.5e-5 \
--min-lr 2.5e-6 \
--lr-decay-style cosine \
--eval-interval ${RETRO_GPT_EVAL_INTERVAL} \
--eval-iters ${RETRO_GPT_EVAL_ITERS} \
--data-path ${DATA_PATH} \
--vocab-file ${GPT_VOCAB_FILE} \
--merge-file ${GPT_MERGE_FILE} \
--eval-iters 50 \
--eval-interval 2000 \
--tokenizer-type GPTSentencePieceTokenizer \
--tokenizer-model <path/to/gpt/tokenizer/model> \
--data-path ${DATA_BLEND} \
--split 98,2,0 \
--clip-grad 1.0 \
--weight-decay 0.1 \
--adam-beta1 0.9 \
--adam-beta2 0.95 \
--init-method-std 0.023 \
--init-method-std 0.007 \
--log-params-norm \
--log-num-zeros-in-grad \
--fp16 \
--bf16 \
--DDP-impl local \
--dataloader-type ${DATALOADER_TYPE} \
--no-data-sharding \
--no-gradient-accumulation-fusion \
"
if [ "$RETRO_ADD_RETRIEVER" = "0" ]; then
######## Retro. ########
if [ "$ADD_RETRIEVER" = "0" ]; then
SCRIPT=pretrain_gpt.py
else
ARGS="${ARGS} \
--retro-add-retriever \
--retro-workdir ${RETRO_WORKDIR} \
--retro-cyclic-train-iters ${RETRO_CYCLIC_TRAIN_ITERS} \
--retro-num-neighbors ${RETRO_NUM_NEIGHBORS} \
--retro-add-retriever \
"
SCRIPT=pretrain_retro.py
fi
echo "~~~~~~~~~~~~~~~~~~~~~~~~~~"
echo "ARGS = '$ARGS'."
echo "~~~~~~~~~~~~~~~~~~~~~~~~~~"
######## Command. ########
python -m torch.distributed.launch \
NPROCS=8
CMD="\
pwd && cd ${REPO_DIR} && pwd && \
export PYTHONPATH=$PYTHONPATH:${REPO_DIR} && \
python -m torch.distributed.run \
--nproc_per_node ${NPROCS} \
--nnodes 1 \
--node_rank 0 \
--master_addr localhost \
--node_rank ${NODE_RANK} \
--master_addr ${MASTER_ADDR} \
--master_port 6000 \
${SCRIPT} \
${ARGS} \
${SCRIPT} ${ARGS} \
"
echo "~~~~~~~~~~~~~~~~~~~~~~~~~~"
echo "CMD = '$CMD'."
echo "~~~~~~~~~~~~~~~~~~~~~~~~~~"
eval $CMD
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from .index import Index
from .build import add_to_index, build_index, train_index
# from .index import Index
......@@ -18,8 +18,10 @@ from tools.retro.index.factory import IndexFactory
from tools.retro.utils import GPTToTextDataset
from .utils import (
get_training_data_dir,
get_training_data_merged,
get_training_data_block_dir,
get_training_data_block_paths,
get_training_data_merged_path,
get_training_data_root_dir,
)
......@@ -36,6 +38,43 @@ def get_empty_index_path():
return empty_index_path
def get_block_nload(block_path, load_fraction):
with h5py.File(block_path) as fi:
return int(load_fraction * fi["data"].shape[0])
def merge_embedding_blocks():
if torch.distributed.get_rank() != 0:
return
args = get_retro_args()
# Get block, merged paths.
load_fraction = args.retro_index_train_load_fraction
block_paths = get_training_data_block_paths()
bin_path = get_training_data_merged_path()
# Skip, if already built.
if os.path.exists(bin_path):
return
# Merge blocks.
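# (Each block's "data" array holds float32 embeddings; concatenating the raw
# bytes produces a flat 'f4' file that FaissBaseIndex._train later
# memory-maps and reshapes to (-1, hidden_size).)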
with open(bin_path, "wb") as fo:
byte_offset = 0
for block_idx, block_path in \
enumerate(tqdm(block_paths, "merge train embeddings")):
with h5py.File(block_path) as fi:
nload = get_block_nload(block_path, load_fraction)
block = np.array(fi["data"][:nload], copy = False)
fo.write(block.tobytes())
byte_offset += block.size * block.itemsize
fo.seek(byte_offset)
def embed_db():
'''Embed DB chunks.
......@@ -45,6 +84,10 @@ def embed_db():
args = get_retro_args()
merged_train_data_path = get_training_data_merged_path()
if os.path.exists(merged_train_data_path):
return
# Get db dataset.
gpt_dataset = get_merged_sampled_dataset()
text_dataset = GPTToTextDataset(gpt_dataset)
......@@ -54,14 +97,19 @@ def embed_db():
args.retro_bert_max_chunk_length,
args.retro_block_size,
args.bert_embedder_type)
embedder.embed_text_dataset("index", get_training_data_dir(), text_dataset)
embedder.embed_text_dataset("index",
get_training_data_block_dir(),
text_dataset)
# Merge embeddings.
merge_embedding_blocks()
def train_on_embeddings():
'''Train index on embedded DB chunks.'''
args = get_retro_args()
index = IndexFactory.get_index(args.retro_index_type)
index.train(get_training_data_merged)
index.train()
def remove_embeddings():
......@@ -71,7 +119,7 @@ def remove_embeddings():
return
empty_index_path = get_empty_index_path()
assert os.path.isfile(empty_index_path)
shutil.rmtree(get_training_data_dir(), ignore_errors=True)
shutil.rmtree(get_training_data_root_dir(), ignore_errors=True)
def train_index():
......@@ -92,7 +140,7 @@ def train_index():
torch.distributed.barrier()
# Remove embeddings.
if args.retro_delete_index_training_embeddings:
if args.retro_index_delete_training_embeddings:
remove_embeddings()
......
......@@ -5,6 +5,7 @@ import numpy as np
import os
import torch
from megatron import get_retro_args
from tools.retro.external_libs import faiss
from .utils import get_index_dir
......@@ -30,13 +31,24 @@ class Index(abc.ABC):
faiss.ParameterSpace().set_index_parameter(index, "verbose", v)
def get_empty_index_path(self):
return os.path.join(get_index_dir(), "empty.faissindex")
args = get_retro_args()
return os.path.join(
get_index_dir(),
"empty_%.3f.faissindex" % args.retro_index_train_load_fraction,
)
def get_empty_index(self):
return faiss.read_index(self.get_empty_index_path())
def get_added_index_path(self):
return os.path.join(get_index_dir(), "added.faissindex")
args = get_retro_args()
return os.path.join(
get_index_dir(),
"added_%.3f_%.3f.faissindex" % (
args.retro_index_train_load_fraction,
args.retro_index_add_load_fraction,
),
)
def get_added_index(self):
return faiss.read_index(self.get_added_index_path())
......
......@@ -8,6 +8,7 @@ inherit from this class (see FaissParAddIndex, for an example).
"""
from datetime import timedelta
import numpy as np
import os
import torch
from tqdm import tqdm
......@@ -15,13 +16,16 @@ from tqdm import tqdm
from megatron import get_retro_args, print_rank_0
from tools.bert_embedding import BertEmbedder
from tools.retro.external_libs import faiss
from tools.retro.index import Index
from tools.retro.index.utils import num_samples_to_block_ranges
from tools.retro.index.index import Index
from tools.retro.index.utils import (
get_training_data_merged_path,
num_samples_to_block_ranges,
)
class FaissBaseIndex(Index):
def _train(self, input_data_loader):
def _train(self):
'''Train index (rank 0's method).'''
args = get_retro_args()
......@@ -40,17 +44,24 @@ class FaissBaseIndex(Index):
return
# Load data.
inp = input_data_loader()
merged_path = get_training_data_merged_path()
inp = np.memmap(
merged_path,
dtype = "f4",
mode = "r",
).reshape((-1, args.hidden_size))
# Init index.
index = faiss.index_factory(args.retro_index_nfeats,
args.retro_index_str)
# Move to GPU.
print("> move faiss index to gpu.")
index_ivf = faiss.extract_index_ivf(index)
clustering_index = \
faiss.index_cpu_to_all_gpus(faiss.IndexFlatL2(index_ivf.d))
index_ivf.clustering_index = clustering_index
print("> finished moving to gpu.")
self.c_verbose(index, True)
self.c_verbose(index_ivf, True)
self.c_verbose(index_ivf.quantizer, True)
......@@ -62,12 +73,12 @@ class FaissBaseIndex(Index):
# Save index.
faiss.write_index(index, empty_index_path)
def train(self, input_data_loader):
def train(self):
'''Train index.'''
# Single process only.
if torch.distributed.get_rank() == 0:
self._train(input_data_loader)
self._train()
torch.distributed.barrier()
......
......@@ -10,6 +10,7 @@ the vast majority of the computational effort is embarrassingly parallel.
import numpy as np
import os
import psutil
import shutil
import torch
from tqdm import tqdm
......@@ -104,6 +105,8 @@ class FaissParallelAddIndex(FaissBaseIndex):
if os.path.exists(added_index_path):
return
args = get_retro_args()
# Index.
print_rank_0("read empty index.")
index = self.get_empty_index()
......@@ -112,10 +115,19 @@ class FaissParallelAddIndex(FaissBaseIndex):
# Add codes.
print_rank_0("add codes.")
code_paths = get_added_code_paths()
for code_path in tqdm(code_paths, "add codes"):
pbar = tqdm(code_paths)
for code_path in pbar:
pbar.set_description("add codes, mem %.3f gb, %.1f%%" % (
psutil.virtual_memory()[3] / 1024**3,
psutil.virtual_memory()[2],
))
with h5py.File(code_path) as f:
codes = np.copy(f["data"])
index_ivf.add_sa_codes(codes)
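# (Assumes each code-block filename starts with the block's starting index,
# e.g. "<start>-<end>.hdf5", so 'xids' gives the loaded codes explicit ids
# matching their position in the encoded dataset.)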
nload = int(args.retro_index_add_load_fraction*f["data"].shape[0])
offset = int(os.path.basename(code_path).split("-")[0])
xids = np.arange(offset, offset + nload)
codes = np.copy(f["data"][:nload])
index_ivf.add_sa_codes(codes, xids)
# Update index's ntotal.
index.ntotal = index_ivf.ntotal
......@@ -129,18 +141,19 @@ class FaissParallelAddIndex(FaissBaseIndex):
if torch.distributed.get_rank() != 0:
return
assert os.path.isfile(self.get_added_index_path())
shutil.rmtree(get_added_codes_dir(), ignore_errors=True)
def add(self, text_dataset):
args = get_retro_args()
if args.retro_index_delete_added_codes:
raise Exception("remove?")
shutil.rmtree(get_added_codes_dir(), ignore_errors=True)
# Check if index already exists.
if not os.path.isfile(self.get_added_index_path()):
def add(self, text_dataset):
# Encode chunks.
self.encode(text_dataset)
# Encode chunks.
self.encode(text_dataset)
# Add codes to index.
self.add_codes()
# Add codes to index.
self.add_codes()
# Wait for (single-process) adding to complete.
torch.distributed.barrier()
......
......@@ -45,128 +45,28 @@ def num_samples_to_block_ranges(num_samples):
return ranges
def get_training_data_dir():
return os.path.join(get_index_dir(), "train_tmp")
def get_training_data_paths():
return sorted(glob.glob(get_training_data_dir() + "/*.hdf5"))
def get_added_codes_dir():
return os.path.join(get_index_dir(), "add_tmp")
def get_added_code_paths():
return sorted(glob.glob(get_added_codes_dir() + "/*.hdf5"))
def get_training_data_group_infos():
def get_training_data_root_dir():
args = get_retro_args()
return os.path.join(args.retro_workdir, "index", "train_emb")
block_paths = get_training_data_paths()
max_group_size = args.retro_index_train_block_size
groups = []
group = []
group_size = 0
for block_path in block_paths:
with h5py.File(block_path) as f:
block_size = f["data"].shape[0]
group.append(block_path)
group_size += block_size
def get_training_data_block_dir():
return os.path.join(get_training_data_root_dir(), "blocks")
if group_size >= max_group_size:
groups.append({
"paths" : group,
"size" : group_size,
})
group = []
group_size = 0
if group:
groups.append({
"paths" : group,
"size" : group_size,
})
return groups
def get_training_data_block_paths():
return sorted(glob.glob(get_training_data_block_dir() + "/*.hdf5"))
def load_training_block(path, load_fraction):
with h5py.File(path) as f:
n_load = int(load_fraction * f["data"].shape[0])
return np.copy(f["data"][:n_load])
def load_training_group(executor, group_info, load_fraction):
# Launch threads to load block data.
futures = []
for path in group_info["paths"]:
futures.append(executor.submit(load_training_block, path, load_fraction))
# Collect block data.
block_datas = []
for future in futures:
block_datas.append(future.result())
# Concatenate blocks.
group_data = np.concatenate(block_datas, axis=0)
# Garbage collect.
for d in block_datas:
del d
gc.collect()
return group_data
def get_training_data_merged_path():
args = get_retro_args()
return os.path.join(get_training_data_root_dir(),
"train_%.3f.bin" % args.retro_index_train_load_fraction)
def get_training_data_merged():
'''Merge embeddings into single dataset.'''
def get_added_codes_dir():
return os.path.join(get_index_dir(), "add_codes")
args = get_retro_args()
# Setup.
ds_infos = get_indexed_dataset_infos()
n_chunks_sampled = sum(d["n_chunks_sampled"] for d in ds_infos)
load_fraction = args.retro_index_train_load_fraction
# Initialize merged data.
print("allocate training data array.")
t = time.time()
data = np.empty((n_chunks_sampled, args.retro_index_nfeats), dtype="f4")
print(" time : %.3f sec." % (time.time() - t))
# Data groups (minimizing fragmentation).
group_infos = get_training_data_group_infos()
# Load data blocks.
n_threads = max(len(group["paths"]) for group in group_infos)
with concurrent.futures.ThreadPoolExecutor(max_workers=n_threads) as executor:
# Load data blocks.
print("load training data blocks.")
start_idx = 0
pbar = tqdm(group_infos)
for group_info in pbar:
pbar.set_description("mem %.0f gb, %.1f%%" % (
psutil.virtual_memory()[3] / 1024**3,
psutil.virtual_memory()[2],
))
# Load group data.
group_data = load_training_group(executor, group_info, load_fraction)
data[start_idx:(start_idx+len(group_data))] = group_data
start_idx += len(group_data)
# Garbage collect.
del group_data
gc.collect()
# Handle load ratio <1.
data = data[:start_idx]
print("> training block data.shape = %s." % str(data.shape))
return data
def get_added_code_paths():
return sorted(glob.glob(get_added_codes_dir() + "/*.hdf5"))
......@@ -15,8 +15,8 @@ import torch
from megatron import get_args, initialize_megatron, print_rank_0
from megatron.global_vars import set_retro_args
from tools.retro.db import build_db
from tools.retro.index.build import add_to_index, build_index, train_index
from tools.retro.pretraining.query import query_pretraining_neighbors
from tools.retro.index import add_to_index, build_index, train_index
from tools.retro.query import query_pretraining_neighbors
from tools.retro.utils import get_args_path
......@@ -31,16 +31,69 @@ def add_retro_args(parser):
group = parser.add_argument_group(title="Retro preprocessing.")
group.add_argument("--retro-gpt-vocab-file", required=True,
help="GPT vocab file.")
group.add_argument("--retro-gpt-merge-file", required=True,
help="GPT merge file.")
# Basic args.
group.add_argument("--retro-tasks", default="build",
help="Comma-separated list of tasks to run. Run entire "
"preprocesing pipeline by using '--retro-tasks build'. "
"Alternatively, run individual stages with tasks (in "
"this order) 'db-build', 'index-build', or "
"'query-pretraining-neighbors'. For example, "
"'--retro-tasks db-build,index-build,"
"query-pretraining-neighbors' is equivalent to "
"'--retro-tasks build'; or the argument can contain "
"a subset of these tasks. Stages must always be run "
"in the correct order (listed above).")
group.add_argument("--retro-block-size", type=int, default=100000,
help="Number of chunks to process at a time when "
"generating Bert embeddings and querying the search "
"index. Partial results for each block are generally "
"saved to disk in separate files.")
group.add_argument("--retro-doc-block-size", type=int, default=100000,
help="Number of documents to processe at time when "
"processing token datasets into chunk databases. The "
"partial chunk database for each block is saved into "
"a separate file.")
# GPT args.
group.add_argument('--retro-gpt-seed', type=int, default=1234,
help='Random seed used for python, numpy, '
'pytorch, and cuda.')
group.add_argument('--retro-gpt-data-impl', type=str, default='infer',
choices=['lazy', 'cached', 'mmap', 'infer'],
help='Implementation of indexed datasets.')
group.add_argument('--retro-gpt-data-path', nargs='*', required=True,
help='Path to the training dataset. Accepted format:'
'1) a single data path, 2) multiple datasets in the'
'form: dataset1-weight dataset1-path dataset2-weight '
'dataset2-path ... It is used with --split when a '
'single dataset used for all three: train, valid '
'and test. It is exclusive to the other '
'--*-data-path args')
group.add_argument('--retro-gpt-split', type=str, default='969,30,1',
help='Comma-separated list of proportions for training,'
' validation, and test split. For example the split '
'`90,5,5` will use 90%% of data for training, 5%% for '
'validation and 5%% for test.')
group.add_argument('--retro-gpt-mmap-warmup', action='store_true',
help='Warm up mmap files.')
group.add_argument("--retro-gpt-eval-interval", type=int, required=True,
help="GPT evaluation interval.")
group.add_argument("--retro-gpt-eval-iters", type=int, required=True,
help="GPT evaluation iterations.")
group.add_argument("--retro-gpt-tokenizer-type", required=True,
help="GPT tokenizer type.")
group.add_argument("--retro-gpt-seq-length", type=int, default=2048,
group.add_argument("--retro-gpt-vocab-file", help="GPT vocab file.")
group.add_argument("--retro-gpt-merge-file", help="GPT merge file.")
group.add_argument("--retro-gpt-tokenizer-model",
help="GPT tokenizer model file.")
group.add_argument("--retro-gpt-seq-length", type=int, required=True,
help="GPT sequence length.")
group.add_argument("--retro-gpt-global-batch-size", type=int, required=True,
help="GPT global batch size.")
group.add_argument("--retro-gpt-chunk-length", type=int, default=64,
help="GPT chunk length.")
# Bert args.
group.add_argument("--retro-bert-vocab-file", required=True,
help="Bert vocab file.")
group.add_argument("--retro-bert-tokenizer-type", required=True,
......@@ -52,17 +105,8 @@ def add_retro_args(parser):
help="Maximum sequence length for Bert embeddings. "
"(Named 'chunk' here in reference to these Bert "
"sequences being converted from GPT chunks.)")
group.add_argument("--retro-tasks", default="build",
help="Comma-separated list of tasks to run. Run entire "
"preprocesing pipeline by using '--retro-tasks build'. "
"Alternatively, run individual stages with tasks (in "
"this order) 'db-build', 'index-build', or "
"'pretraining-query-neighbors'. For example, "
"'--retro-tasks db-build,index-build,"
"pretraining-query-neighbors' is equivalent to "
"'--retro-tasks build'; or the argument can contain "
"a subset of these tasks. Stages must always be run "
"in the correct order (listed above).")
# Index args.
group.add_argument("--retro-index-nfeats", "-f", type=int, default=1024,
help="Dimension of Bert embeddings. Bert-large is "
"commonly used, so this value defaults to 1024.")
......@@ -78,34 +122,10 @@ def add_retro_args(parser):
"faiss.index_factory(). For example, "
"'IVF262144_HNSW32,Flat' or "
"'OPQ32_256,IVF4194304_HNSW32,PQ32'.")
group.add_argument("--retro-ef-search", type=int, default=256,
help="Index ef-search parameter for HNSW during "
"querying.")
group.add_argument("--retro-nprobe", type=int, default=65536,
help="Index nprobe parameter for IVF during "
"querying.")
group.add_argument("--retro-nchunks-sampled", type=int, required=True,
group.add_argument("--retro-index-ntrain", type=int, required=True,
help="Number of database chunks to use for training "
"the index. This value must be less or equal to the "
"total number of chunks in the database.")
group.add_argument("--retro-doc-block-size", type=int, default=100000,
help="Number of documents to processe at time when "
"processing token datasets into chunk databases. The "
"partial chunk database for each block is saved into "
"a separate file.")
group.add_argument("--retro-block-size", type=int, default=100000,
help="Number of chunks to process at a time when "
"generating Bert embeddings and querying the search "
"index. Partial results for each block are generally "
"saved to disk in separate files.")
group.add_argument("--retro-index-train-block-size",
type=int, default=3750000,
help="As a memory fragmentation optimization, when "
"loading training data for training the search index, "
"enough data blocks loaded at a time until they reach "
"retro_index_train_block_size, and then this "
"data block is copied into the full training data "
"array.")
group.add_argument("--retro-index-train-load-fraction",
type=float, default=1.,
help="Fraction of sampled chunks to use for training "
......@@ -113,19 +133,36 @@ def add_retro_args(parser):
"use too much memory; lowering the load fraction is "
"less costly than re-embedding a new sampled dataset "
"from scratch.")
group.add_argument("--retro-num-neighbors-query", type=int, default=2000,
group.add_argument("--retro-index-add-load-fraction",
type=float, default=1.,
help="Fraction of database chunks to use for adding to "
"the index. Useful when our total index size would "
"use too much memory; lowering the load fraction is "
"less costly than re-designing our token datasets.")
group.add_argument("--retro-index-no-delete-training-embeddings",
action='store_false',
dest="retro_index_delete_training_embeddings",
help="Skip deleting training embeddings for the search "
"index. Useful for debugging.")
group.add_argument("--retro-index-no-delete-added-codes",
action='store_false',
dest="retro_index_delete_added_codes",
help="Skip deleting added codes for the search "
"index. Useful for debugging.")
# Query args.
group.add_argument("--retro-query-ef-search", type=int, default=256,
help="Index ef-search parameter for HNSW during querying.")
group.add_argument("--retro-query-nprobe", type=int, default=65536,
help="Index nprobe parameter for IVF during querying.")
group.add_argument("--retro-query-num-neighbors-query", type=int, default=200,
help="Number of neighbors to retrieve when calling "
"index.search().")
group.add_argument("--retro-num-neighbors-target", type=int, default=200,
group.add_argument("--retro-query-num-neighbors-save", type=int, default=20,
help="Number of neighbors to save to disk after "
"the index's returned neighbors. If longer than target "
"value, neighbors truncated; and if shorter than target "
"value, neighbors are padded with -1's.")
group.add_argument("--retro-no-delete-index-training-embeddings",
action='store_false',
dest="retro_delete_index_training_embeddings",
help="Skip deleting training embeddings for the search "
"index. Useful for debugging.")
# Enforce argument naming convention.
for action in group._group_actions:
......@@ -140,10 +177,16 @@ def add_retro_args(parser):
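As a hedged, self-contained sketch of the blended-dataset format accepted by --retro-gpt-data-path and the proportion string accepted by --retro-gpt-split (the weights and dataset paths below are hypothetical, not taken from this repo):

# Sketch only: parse a weight/path list and a split proportion string.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--retro-gpt-data-path', nargs='*')
parser.add_argument('--retro-gpt-split', type=str, default='969,30,1')
args = parser.parse_args([
    '--retro-gpt-data-path',
    '0.3', '/data/wiki_text_document',         # weight, path (hypothetical)
    '0.7', '/data/commoncrawl_text_document',  # weight, path (hypothetical)
    '--retro-gpt-split', '98,2,0',
])

# The multi-dataset form alternates weight and path entries.
weights = [float(w) for w in args.retro_gpt_data_path[0::2]]
prefixes = args.retro_gpt_data_path[1::2]

# Split proportions are interpreted relative to their sum.
splits = [float(s) for s in args.retro_gpt_split.split(',')]
fractions = [s / sum(splits) for s in splits]

print(weights, prefixes, fractions)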
def save_args(args):
'''Save copy of args within retro workdir.'''
def default_dump(obj):
if isinstance(obj, torch.dtype):
return str(obj)
else:
raise Exception("specialize for <%s>." % type(obj).__name__)
if torch.distributed.get_rank() == 0:
args_path = get_args_path(args.retro_workdir)
with open(args_path, "w") as f:
json.dump(vars(args), f, indent=4, default=default_dump)
torch.distributed.barrier()
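For illustration, the default_dump hook above lets json serialize torch.dtype values (which json cannot encode natively) by stringifying them; a standalone sketch:

# Standalone sketch of serializing a torch.dtype through a json `default` hook,
# mirroring default_dump in save_args above.
import json
import torch

def default_dump(obj):
    if isinstance(obj, torch.dtype):
        return str(obj)
    raise Exception("specialize for <%s>." % type(obj).__name__)

print(json.dumps({"params_dtype": torch.float16}, default=default_dump))
# -> {"params_dtype": "torch.float16"}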
......@@ -188,7 +231,7 @@ if __name__ == "__main__":
add_to_index() # add only
# Pretraining.
elif task == "pretraining-query-neighbors":
elif task == "query-pretraining-neighbors":
query_pretraining_neighbors()
else:
......
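For context, here is a minimal sketch of how a comma-separated --retro-tasks string could drive the stage dispatch shown above; the handler functions are placeholders, not the repo's real entry points:

# Sketch only: dispatch a comma-separated task string over the three
# preprocessing stages, in order. Handlers below are placeholders.
def build_chunk_db(): print("running db-build")
def build_search_index(): print("running index-build")
def query_neighbors(): print("running query-pretraining-neighbors")

def run_retro_tasks(retro_tasks="build"):
    tasks = retro_tasks.split(",")
    if tasks == ["build"]:  # 'build' expands to the full pipeline.
        tasks = ["db-build", "index-build", "query-pretraining-neighbors"]
    for task in tasks:
        if task == "db-build":
            build_chunk_db()
        elif task == "index-build":
            build_search_index()
        elif task == "query-pretraining-neighbors":
            query_neighbors()
        else:
            raise Exception("unknown task '%s'." % task)

run_retro_tasks("db-build,index-build")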
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import os
from megatron import get_retro_args
def get_pretraining_workdir():
args = get_retro_args()
return os.path.join(args.retro_workdir, "pretraining")
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from .query import query_pretraining_neighbors
......@@ -4,15 +4,16 @@ import os
import torch
from megatron import get_retro_args, print_rank_0
from megatron.data.gpt_dataset import build_train_valid_test_datasets \
as build_gpt_train_valid_test_datasets
from megatron.training import (
build_train_valid_test_data_loaders,
build_train_valid_test_datasets as build_pretraining_train_valid_test_datasets,
update_train_iters,
)
from tools.retro.db.utils import get_indexed_dataset_infos
from tools.retro.utils import get_num_chunks_per_sample
from .utils import get_neighbor_dirname, get_query_workdir
class ChunkDataset(torch.utils.data.Dataset):
......@@ -86,14 +87,14 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
print_rank_0('> building train, validation, and test datasets '
'for GPT ...')
train_ds, valid_ds, test_ds = build_gpt_train_valid_test_datasets(
data_prefix=args.retro_gpt_data_path,
data_impl=args.retro_gpt_data_impl,
splits_string=args.retro_gpt_split,
train_valid_test_num_samples=train_val_test_num_samples,
seq_length=args.retro_gpt_seq_length,
seed=args.retro_gpt_seed,
skip_warmup=(not args.retro_gpt_mmap_warmup),
return_doc_ids=args.retro_return_doc_ids)
print_rank_0("> finished creating pretrained GPT datasets ...")
......@@ -115,28 +116,23 @@ def get_chunk_dataset_map():
verify_indexed_dataset_order()
# Datasets.
print_rank_0(" > data loader.")
train_data_loader, valid_data_loader, test_data_loader \
= build_train_valid_test_data_loaders(
train_valid_test_datasets_provider)
data_loader_map = {
"train" : train_data_loader,
"valid" : valid_data_loader,
"test" : test_data_loader,
print_rank_0(" > datasets.")
train_ds, valid_ds, test_ds = build_pretraining_train_valid_test_datasets(
train_valid_test_datasets_provider)
sample_dataset_map = {
"train" : train_ds,
"valid" : valid_ds,
"test" : test_ds,
}
# Info dict.
chunk_dataset_map = {
key : {
"neighbor_dir" : os.path.join(
workdir,
os.path.basename(loader.dataset.datasets[0].index_prefix),
),
"data" : ChunkDataset(loader.dataset, args.retro_gpt_chunk_length),
"neighbor_dir" : get_neighbor_dirname(key, sample_ds),
"data" : ChunkDataset(sample_ds, args.retro_gpt_chunk_length),
}
for key, sample_ds in sample_dataset_map.items() if sample_ds
}
return chunk_dataset_map
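Each ChunkDataset above slices a GPT sample of retro_gpt_seq_length tokens into consecutive chunks of retro_gpt_chunk_length tokens, so a global chunk id maps to a (sample id, token range) pair by integer arithmetic. A minimal sketch of that mapping, with illustrative lengths only:

# Illustration of the chunk-id <-> sample-id arithmetic used by the query
# stage (e.g. sample_id = chunk_id // n_chunks_per_sample). Lengths are
# illustrative, not required values.
seq_length = 2048
chunk_length = 64
n_chunks_per_sample = seq_length // chunk_length  # 32

def chunk_to_sample(chunk_id):
    sample_id = chunk_id // n_chunks_per_sample
    chunk_offset = chunk_id % n_chunks_per_sample
    token_start = chunk_offset * chunk_length
    return sample_id, (token_start, token_start + chunk_length)

print(chunk_to_sample(0))   # (0, (0, 64))
print(chunk_to_sample(33))  # (1, (64, 128))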
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from collections import defaultdict
import numpy as np
import os
import psutil
import time
import torch
from tqdm import tqdm
from megatron import get_retro_args, print_rank_0
from tools.bert_embedding import BertEmbedder
from tools.bert_embedding.utils import get_missing_blocks_by_rank
from tools.retro.db.utils import \
get_merged_train_dataset as get_db_merged_train_dataset
from tools.retro.external_libs import faiss, h5py
from tools.retro.index.factory import IndexFactory
from tools.retro.index.utils import get_index_dir
from tools.retro.utils import GPTToTextDataset
from .chunk_dataset import get_chunk_dataset_map as get_query_dataset_map
def get_index(ondisk=False):
'''Read index from disk.'''
args = get_retro_args()
# Load index.
index_wrapper = IndexFactory.get_index(args.retro_index_type)
index_dir = get_index_dir()
......@@ -42,9 +36,9 @@ def get_index(chunk_db_dataset, ondisk=False):
# Search parameters.
faiss.ParameterSpace().set_index_parameter(index, "efSearch",
args.retro_query_ef_search)
faiss.ParameterSpace().set_index_parameter(index, "nprobe",
args.retro_query_nprobe)
return index
......@@ -58,8 +52,9 @@ def embed_block(gpt_dataset, block, embedder):
return embedder.embed_text_dataset(text_block_dataset)
def query_embeddings(db_dataset, index,
embeddings, chunk_id_range,
sample_map, n_chunks_per_sample,
verbose=True):
'''Query neighbors of a block of embeddings.'''
......@@ -70,24 +65,13 @@ def query_embeddings(index, banned_chunk_map, chunk_id_range,
t = time.time()
assert index.ntotal > 0, "check we don't accidentally have an empty index."
_, query_neighbor_ids = \
index.search(embeddings, args.retro_query_num_neighbors_query)
if verbose: print_rank_0(" time : %.3f sec." % (time.time() - t))
# Filter banned neighbor ids.
if verbose: print_rank_0("filter banned neighbor ids.")
filtered_neighbor_ids = np.full(
shape=(len(query_neighbor_ids), args.retro_query_num_neighbors_save),
fill_value=-1,
dtype="int64",
)
......@@ -95,24 +79,30 @@ def query_embeddings(index, banned_chunk_map, chunk_id_range,
for chunk_id in range(min_chunk_id, max_chunk_id):
sample_id = chunk_id // n_chunks_per_sample
sample = sample_map[sample_id]
sample_dataset_idx = sample["dataset_idx"].item()
sample_doc_ids = sample["doc_ids"].tolist()
sample_doc_tuples = [(sample_dataset_idx, d) for d in sample_doc_ids]
# Get valid neighbors (!= -1).
query_row = [ i for i in query_neighbor_ids[chunk_id-min_chunk_id]
if i >= 0 ]
# Filter row.
filtered_row = [ i for i in query_row
if tuple(db_dataset.doc_tuples[i].tolist())
not in sample_doc_tuples ]
filtered_row = filtered_row[:args.retro_query_num_neighbors_save]
filtered_row += \
[-1] * (args.retro_query_num_neighbors_save - len(filtered_row))
filtered_neighbor_ids[chunk_id-min_chunk_id] = filtered_row
return query_neighbor_ids, filtered_neighbor_ids
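The truncate-then-pad step above guarantees that every chunk's saved neighbor row has exactly retro_query_num_neighbors_save entries, with -1 marking empty slots. A standalone sketch of that invariant (values illustrative only):

# Standalone illustration of the fixed-width neighbor row: truncate to the
# save count, then pad with -1 sentinels.
def to_fixed_width(neighbor_ids, num_neighbors_save):
    row = list(neighbor_ids)[:num_neighbors_save]
    row += [-1] * (num_neighbors_save - len(row))
    return row

print(to_fixed_width([7, 3, 9, 2], 3))  # [7, 3, 9]
print(to_fixed_width([7, 3], 5))        # [7, 3, -1, -1, -1]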
def query_embedding_block(db_dataset, index,
embeddings, chunk_id_range,
sample_map, n_chunks_per_sample):
query_neighbor_ids = []
filtered_neighbor_ids = []
......@@ -131,8 +121,9 @@ def query_embedding_block(index, banned_chunk_map, chunk_id_range,
chunk_id_range[0] + partial_end_idx,
)
partial_query_neighbor_ids, partial_filtered_neighbor_ids = \
query_embeddings(db_dataset, index,
partial_embeddings, partial_chunk_id_range,
sample_map, n_chunks_per_sample,
verbose=False)
query_neighbor_ids.append(partial_query_neighbor_ids)
filtered_neighbor_ids.append(partial_filtered_neighbor_ids)
......@@ -144,26 +135,33 @@ def query_embedding_block(index, banned_chunk_map, chunk_id_range,
return query_neighbor_ids, filtered_neighbor_ids
def query_block_neighbors(db_dataset, query_dataset,
index, embedder,
block):
'''Query neighbors of a dataset block (i.e., range).'''
args = get_retro_args()
n_chunks_per_sample = query_dataset.n_chunks_per_sample
# Sample map.
sample_ids = sorted(list(set(chunk_id // n_chunks_per_sample
for chunk_id in range(*block["range"]))))
sample_map = {}
for i in sample_ids:
sample = query_dataset.sample_dataset[i]
sample_map[i] = {
"dataset_idx" : sample["dataset_idx"],
"doc_ids" : sample["doc_ids"],
}
# Embed block.
embeddings = embed_block(query_dataset, block, embedder)
# Query embeddings.
_, filtered_neighbor_ids = query_embedding_block(
db_dataset, index,
embeddings, block["range"],
sample_map, n_chunks_per_sample)
# Save neighbors.
print_rank_0("save neighbors.")
......@@ -173,22 +171,22 @@ def query_block_neighbors(index, banned_chunk_map, chunk_dataset,
f.close()
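The block's actual save code is elided in this hunk; below is a hedged sketch (placeholder path and array) of writing the fixed-width neighbor array to an HDF5 block file, using the "neighbors" dataset name that the validate() check in query_dataset_neighbors expects:

# Hedged sketch: persist a block's filtered neighbor ids to an HDF5 file.
# The path and array are placeholders, not the tool's actual save code.
import numpy as np
import h5py

filtered_neighbor_ids = np.full((1000, 20), -1, dtype="int64")  # placeholder
with h5py.File("/tmp/neighbors_block_0000.hdf5", "w") as f:
    f.create_dataset("neighbors", data=filtered_neighbor_ids)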
def query_dataset_neighbors(db_dataset, query_dataset,
prefix, neighbor_dir,
index, embedder):
'''Query neighbors of each chunk within a dataset.'''
args = get_retro_args()
def validate(f):
assert f["neighbors"].shape[1] == args.retro_num_neighbors_target, \
assert f["neighbors"].shape[1] == args.retro_query_num_neighbors_save, \
"neighbors.shape == %s; num_neighbors_target == %d." % (
str(f["neighbors"].shape),
args.retro_num_neighbors_target,
)
n_missing_blocks, missing_neighbor_blocks = get_missing_blocks_by_rank(
neighbor_dir,
len(chunk_dataset),
len(query_dataset),
args.retro_block_size,
validate=validate,
)
......@@ -199,16 +197,19 @@ def query_dataset_neighbors(index, banned_chunk_map,
if block is not None:
# Progress.
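# psutil.virtual_memory() tuple: index 3 is bytes used, index 2 is percent used.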
print_rank_0("query '%s' block %d / %d ... %s." % (
print_rank_0("query '%s' block %d / %d ... %s ... mem %.3f gb, %.1f%%." % (
prefix,
block_index,
len(missing_neighbor_blocks),
block["path"],
os.path.basename(block["path"]),
psutil.virtual_memory()[3] / 1024**3,
psutil.virtual_memory()[2],
))
# Query block neighbors.
query_block_neighbors(db_dataset, query_dataset,
index, embedder,
block)
# Synchronize progress across all ranks. (for easier observation)
print_rank_0(" > waiting for other ranks to finish block.")
......@@ -225,17 +226,16 @@ def query_pretraining_neighbors():
# Load chunk db dataset.
print_rank_0("load chunk db dataset.")
db_dataset = get_db_merged_train_dataset()
db_dataset.load_doc_tuples()
# Load index.
print_rank_0(" > get index.")
index = get_index()
# Load datasets.
print_rank_0(" > get dataset map.")
query_dataset_map = get_query_dataset_map()
# Bert embedder.
embedder = BertEmbedder(args.retro_bert_batch_size,
......@@ -244,9 +244,9 @@ def query_pretraining_neighbors():
# Query each (i.e., train, valid, test) dataset.
print_rank_0(" > query.")
for prefix, info in query_dataset_map.items():
print_rank_0(" > query '%s' dataset ... %d samples." %
(prefix, len(info["data"])))
query_dataset_neighbors(db_dataset, info["data"],
prefix, info["neighbor_dir"],
index, embedder)