"vscode:/vscode.git/clone" did not exist on "c76dcce59f1ad463e853871b88ec2d556f8a3939"
Commit 3aca1415 authored by liangjing's avatar liangjing
Browse files

Merge branch 'megatron-lm_dtk24.04' into 'main'

Megatron lm dtk24.04

See merge request !1
parents 0024a5c6 1005e9d3
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import os
from . import retro
......
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

import json
import numpy as np
import os
import torch
import types

from megatron.global_vars import set_global_variables, set_retro_args
from megatron.initialize import (
    initialize_megatron,
    _initialize_distributed,
    _set_random_seed,
    _compile_dependencies,
)
from tools.retro.db.utils import (
    get_indexed_dataset_infos as get_db_indexed_dataset_infos,
    get_merged_train_dataset as get_db_dataset,
)
from tools.retro.main import add_retro_args
from tools.retro.query.retro_dataset import get_retro_datasets
from tools.retro.utils import get_args_path, get_bert_tokenizer, get_gpt_tokenizer


def shorten_str(s, n):
    s = "\\n".join(s.splitlines())
    return s if len(s) <= n else "%s ... %s" % (s[:n//2], s[-n//2:])
class retro:

    args = None

    ##############################################
    # initialize.
    ##############################################

    @classmethod
    def parse_dtype_str(cls, dtype_str):
        return {
            "torch.float16" : torch.float16,
            "torch.float32" : torch.float32,
            "torch.bfloat16" : torch.bfloat16,
        }[dtype_str]

    @classmethod
    def init_megatron(cls, workdir):
        '''Custom initialization of Megatron.'''

        # Load args.
        args_path = get_args_path(workdir)
        assert os.path.exists(args_path), "args.json not found in workdir."
        with open(args_path) as f:
            cls.args = types.SimpleNamespace(**json.load(f))
            cls.args.retro_workdir = workdir # just in case workdir moved
            cls.args.rank = 0 # override env
            cls.args.world_size = 1 # override env
            cls.args.params_dtype = cls.parse_dtype_str(cls.args.params_dtype)

        set_global_variables(cls.args)
        set_retro_args(cls.args)
        _initialize_distributed()
        _set_random_seed(cls.args.seed, cls.args.data_parallel_random_init)
        _compile_dependencies()

    @classmethod
    def init(cls, workdir):
        '''Initialize Megatron, tokenizers, and datasets.'''

        # Load args.
        cls.init_megatron(workdir)

        cls.tokenizers = types.SimpleNamespace(
            gpt=get_gpt_tokenizer(),
            bert=get_bert_tokenizer(),
        )

        # Load data.
        cls.db_indexed_dataset_infos = get_db_indexed_dataset_infos()
        cls.db_dataset = get_db_dataset()
        pt_train_ds, pt_valid_ds, _ = get_retro_datasets(verify_sizes=False)
        cls.pt_datasets = types.SimpleNamespace(
            train=pt_train_ds,
            valid=pt_valid_ds,
        )

        # Retrieve max saved neighbors.
        for key in vars(cls.pt_datasets):
            getattr(cls.pt_datasets, key).num_neighbors = \
                cls.args.retro_query_num_neighbors_save

        # Print usage.
        cls.print_usage()

    ##############################################
    # utils.
    ##############################################

    @classmethod
    def gpt_to_text(cls, token_ids):
        '''GPT tokens to text.'''
        return cls.tokenizers.gpt.detokenize(token_ids.tolist()
                                             if isinstance(token_ids, np.ndarray)
                                             else token_ids)

    @classmethod
    def text_to_bert(cls, text):
        '''Text to Bert tokens.'''
        return cls.tokenizers.bert.tokenize(text)

    ##############################################
    # chunk db.
    ##############################################

    @classmethod
    def get_db_num_indexed_datasets(cls):
        '''Number of indexed datasets within blendable dataset.'''
        return len(cls.db_indexed_dataset_infos)

    @classmethod
    def get_db_indexed_dataset_infos(cls):
        '''Dataset infos, including number of training & sampled sets.'''
        return [(info["ratio"], info["name"])
                for info in cls.db_indexed_dataset_infos]

    @classmethod
    def get_db_dataset(cls):
        return cls.db_dataset

    @classmethod
    def get_db_num_chunks(cls):
        '''Number of DB chunks.'''
        return len(cls.get_db_dataset())

    @classmethod
    def get_db_chunk_gpt(cls, idx):
        '''Get DB chunk as GPT token ids.'''
        return cls.get_db_dataset()[idx]["text"].tolist()

    @classmethod
    def get_db_chunk_bert(cls, idx):
        '''Get DB chunk as Bert token ids.'''
        return cls.text_to_bert(cls.get_db_chunk_text(idx))

    @classmethod
    def get_db_chunk_text(cls, idx):
        '''Get DB chunk as text.'''
        return cls.gpt_to_text(cls.get_db_chunk_gpt(idx))

    @classmethod
    def get_db_chunk_and_continuation_text(cls, idx):
        '''Get DB chunk along with continuation, as text.'''

        # Modulus used here to match original implementation (i.e., last
        # chunk's continuation wraps around to first chunk).
        return [
            cls.get_db_chunk_text(idx),
            cls.get_db_chunk_text((idx + 1) % len(cls.get_db_dataset())),
        ]
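
    # Illustrative note (not in the original source): with a DB of N chunks,
    # the wrap-around above means the "continuation" of the last chunk is
    # chunk 0 again, e.g.:
    #
    #   retro.get_db_chunk_and_continuation_text(N - 1)
    #   # -> [ text of chunk N-1, text of chunk 0 ]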

    ##############################################
    # pretraining corpus.
    ##############################################

    @classmethod
    def get_pt_num_samples_and_chunks(cls, data_key):
        '''Number of samples & chunks (e.g., 32*n_samples) in corpus.'''
        assert hasattr(cls.pt_datasets, data_key), \
            "pretraining set '%s' not found (choices: %s)." % (
                data_key, ", ".join(vars(cls.pt_datasets).keys()))
        chunk_dataset = getattr(cls.pt_datasets, data_key).chunk_dataset
        return (
            len(chunk_dataset.sample_dataset),
            len(chunk_dataset),
        )

    @classmethod
    def get_pt_num_samples(cls, data_key):
        '''Number of pretraining samples.'''
        return cls.get_pt_num_samples_and_chunks(data_key)[0]

    @classmethod
    def get_pt_num_chunks(cls, data_key):
        '''Number of pretraining chunks (e.g., 32*n_samples).'''
        return cls.get_pt_num_samples_and_chunks(data_key)[1]
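
    # Note on the "32*n_samples" figure above (not in the original source):
    # each pretraining sample is split into seq_length // chunk_length chunks,
    # so with the GPT sequence length 2048 and chunk length 64 used by the
    # older script in this diff, that is 2048 // 64 = 32 chunks per sample;
    # other configurations will differ.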

    @classmethod
    def get_pt_dataset(cls, data_key):
        return getattr(cls.pt_datasets, data_key)

    @classmethod
    def get_pt_sample(cls, data_key, idx):
        return getattr(cls.pt_datasets, data_key)[idx]

    @classmethod
    def get_neighbor_tokens(cls, sample_id, chunk_id, data_key="train"):
        try:
            sample = cls.get_pt_sample(data_key, sample_id)
            sample_token_ids = sample["text"]
            chunk_length = cls.args.retro_gpt_chunk_length
            chunk_start_idx = chunk_id * chunk_length
            chunk_end_idx = min(sample_token_ids.shape[0],
                                chunk_start_idx + chunk_length)
            chunk_token_ids = sample_token_ids[chunk_start_idx:chunk_end_idx]
            neighbor_token_ids = sample["neighbor_tokens"][chunk_id]
            return {
                "chunk_tokens" : chunk_token_ids,
                "neighbor_tokens" : neighbor_token_ids,
            }
        except:
            return None

    @classmethod
    def print_neighbor_texts(cls, sample_id, chunk_id, data_key="train"):
        tokens = cls.get_neighbor_tokens(sample_id, chunk_id, data_key)
        print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
        try:
            print("PRETRAINING CHUNK:")
            print("  - %s" % shorten_str(cls.gpt_to_text(tokens["chunk_tokens"]), 150))
            print("NEIGHBOR_CHUNKS:")
            for token_ids in tokens["neighbor_tokens"]:
                print("  - %s" % shorten_str(cls.gpt_to_text(token_ids), 150))
        except:
            print("<no neighbors for sample %d>" % sample_id)

    ##############################################
    # usage.
    ##############################################

    @classmethod
    def print_usage(cls):
        '''Print usage.'''

        print()
        print("+++++++++++++++++++++++++++++++++++++++++++++++++++")
        print("examples ... [ *note*: 'db' = chunk db; 'pt' = pretraining corpus. ]")
        print("+++++++++++++++++++++++++++++++++++++++++++++++++++")

        print()
        print("~~~~ indexed datasets ~~~~")
        print("retro.get_db_num_indexed_datasets() : %s" %
              cls.get_db_num_indexed_datasets())
        print("retro.get_db_indexed_dataset_infos() :")
        for i, (ratio, prefix) in enumerate(cls.get_db_indexed_dataset_infos()):
            print("  %s(%f, %s)%s" % (
                "[" if i == 0 else " ",
                ratio,
                prefix,
                "]" if i == len(cls.db_indexed_dataset_infos) - 1 else ",",
            ))

        print()
        print("~~~~ counts ~~~~")
        print("retro.get_db_num_chunks : %d." % cls.get_db_num_chunks())

        print()
        for sq_key in ("sample", "chunk"):
            for data_key in ("train", "valid"): # test?
                print("retro.get_pt_num_%ss('%s') : %d." % (
                    sq_key, data_key,
                    getattr(cls, f"get_pt_num_{sq_key}s")(data_key)))

        print()
        print("~~~~ tokens, text ~~~~")
        print("retro.get_db_chunk_gpt(chunk_id) : %s" %
              shorten_str(str(retro.get_db_chunk_gpt(0)), 50))
        print("retro.get_db_chunk_bert(chunk_id) : %s" %
              shorten_str(str(retro.get_db_chunk_bert(0)), 50))
        print("retro.get_db_chunk_text(chunk_id) : %s" %
              shorten_str(retro.get_db_chunk_text(0).strip(), 50))
        print("retro.get_db_chunk_and_continuation_text(chunk_id) :")
        for i, t in enumerate(retro.get_db_chunk_and_continuation_text(0)):
            print("  %s'%s'%s" % (
                "[" if i == 0 else " ",
                shorten_str(t.strip().replace("\n", " "), 50),
                "]" if i == 1 else ",",
            ))

        sample = cls.get_pt_sample("train", 0)
        sample_chunk_id = sample["neighbor_tokens"].shape[0] // 2
        sample_neighbor_id = 0

        print()
        print("retro.get_pt_sample('train', sample_id) :")
        print("  {")
        for k, v in sample.items():
            print("    '%s' : %s" % (k, shorten_str(str(v), 50)))
        print("  }")

        print()
        print("(e.g., sample = retro.get_pt_sample(...))")
        print()
        print("  sample['text'].shape : %s" % str(sample["text"].shape))
        print("  sample['neighbor_tokens'].shape : %s" % str(sample["neighbor_tokens"].shape))
        print("  sample['text'] : %s" % shorten_str(str(sample["text"]), 50))
        print("  sample['neighbor_tokens'][17][1] : %s" % shorten_str(str(sample["neighbor_tokens"][sample_chunk_id][sample_neighbor_id]), 50))
        print("  retro.gpt_to_text(sample['text']) : %s" % shorten_str(cls.gpt_to_text(sample["text"]), 50))
        print("  retro.gpt_to_text(sample['neighbor_tokens']) : %s" % shorten_str(cls.gpt_to_text(sample["neighbor_tokens"][sample_chunk_id][sample_neighbor_id]), 50))
        print("+++++++++++++++++++++++++++++++++++++++++++++++++++")
@@ -24,11 +24,13 @@ from tools.retro.external_libs import h5py
 from tools.retro.utils import get_gpt_tokenizer, get_bert_tokenizer
 from .utils import (
-    get_individual_db,
+    get_indexed_dataset_infos,
+    get_indexed_dataset_infos_path,
     get_individual_db_dir,
+    get_individual_chunk_db,
+    get_individual_doc_offsets,
     get_merged_dataset,
     get_merged_db_path_map,
-    get_train_doc_chunk_map_dir,
     save_indexed_dataset_infos,
 )
@@ -52,7 +54,7 @@ def init_indexed_dataset_infos():
         prefix = args.data_path[i + 1]
         path = prefix + ".bin"
         name = os.path.basename(prefix)
-        assert os.path.exists(path)
+        assert os.path.exists(path), "couldn't find '%s'." % path
         infos.append({
             "ratio" : ratio,
             "prefix" : prefix,
@@ -114,6 +116,7 @@ def build_partial_db(
     # Iterate documents & parse chunks.
     chunk_db_valid = []
     chunk_db_invalid = []
+    doc_size_map = {}
     for doc_id in pbar:

         # Progress description.
@@ -130,7 +133,7 @@
         # Remove EOD token.
         doc = indexed_dataset.get(doc_id)
-        if doc[-1].item() == tokenizers.gpt.eod_id:
+        if doc[-1].item() == tokenizers.gpt.eod:
             doc = doc[:-1]
         doc_len = len(doc)
@@ -140,6 +143,7 @@
                            for s in chunk_start_idxs]

         # Re-tokenize each chunk to Bert/Wordpiece (empty bert -> 'invalid').
+        doc_size_map[doc_id] = 0
         for i, chunk_start_idx in enumerate(chunk_start_idxs):

             # Re-tokenize.
@@ -149,13 +153,15 @@
                 offset=chunk_start_idx,
                 length=chunk_end_idx - chunk_start_idx,
             )
-            text = tokenizers.gpt.detokenize(gpt_token_ids)
+            text = tokenizers.gpt.detokenize(gpt_token_ids.tolist())
             bert_token_ids = tokenizers.bert.tokenize(text)

             # 'Valid' for non-empty Bert chunks; 'invalid' otherwise.
-            _chunk_db = chunk_db_invalid \
-                if len(bert_token_ids) == 0 else \
-                chunk_db_valid
+            if len(bert_token_ids) == 0:
+                _chunk_db = chunk_db_invalid
+            else:
+                _chunk_db = chunk_db_valid
+                doc_size_map[doc_id] += 1
             _chunk_db.append((
                 doc_id,
                 chunk_start_idx,
@@ -163,7 +169,7 @@
                 len(bert_token_ids),
             ))

-    return proc_id, chunk_db_valid, chunk_db_invalid
+    return proc_id, chunk_db_valid, chunk_db_invalid, doc_size_map


def build_individual_db(dataset_idx, n_datasets, dataset_info, tokenizers):
@@ -181,9 +187,10 @@ def build_individual_db(dataset_idx, n_datasets, dataset_info, tokenizers):
     # Missing db blocks.
     n_missing_world, missing_db_blocks = get_missing_blocks_by_rank(
         db_dir,
-        len(indexed_dataset.doc_idx) - 1,
+        len(indexed_dataset),
         args.retro_doc_block_size,
-        validate=lambda f : f["chunks_valid"].shape[1] == 4)
+        validate=lambda f : f["chunks_valid"].shape == (0,) \
+            or f["chunks_valid"].shape[1] == 4)

     # Prevent missing-path-write race condition.
     torch.distributed.barrier()
@@ -209,6 +216,8 @@
         if block is not None:

+            db_path = block["path"]
+
             # Build partial dbs.
             print_rank_0(' > build partial dbs.')
             futures = []
@@ -240,15 +249,27 @@ def build_individual_db(dataset_idx, n_datasets, dataset_info, tokenizers):
            # Convert to numpy.
            print_rank_0(' > converting chunk db to numpy.')
-           chunk_db_valid = np.array(chunk_db_valid)
-           chunk_db_invalid = np.array(chunk_db_invalid)
+           chunk_db_valid = np.array(chunk_db_valid, dtype="uint32")
+           chunk_db_invalid = np.array(chunk_db_invalid, dtype="uint32")
+
+           # Document offsets.
+           doc_sizes = [(d, s)
+                        for partial_chunk_db in partial_chunk_dbs
+                        for d, s in partial_chunk_db[3].items()]
+           doc_sizes.sort(key = lambda item : item[0])
+           doc_offsets = np.cumsum([item[1] for item in doc_sizes]) \
+                           .astype("uint64")
+           doc_offsets = np.stack((
+               np.array([item[0] for item in doc_sizes], dtype="uint64"),
+               doc_offsets), axis=1)

            # Save DB.
            print_rank_0(" > saving individual db.")
-           f = h5py.File(block["path"], "w")
-           dset = f.create_dataset("chunks_valid", data=chunk_db_valid)
-           dset = f.create_dataset("chunks_invalid", data=chunk_db_invalid)
-           f.close()
+           with h5py.File(db_path, "w") as f:
+               dset = f.create_dataset("chunks_valid", data=chunk_db_valid)
+               dset = f.create_dataset("chunks_invalid",
+                                       data=chunk_db_invalid)
+               dset = f.create_dataset("doc_offsets", data=doc_offsets)

        # Wait for all ranks to finish block.
        print_rank_0(" > waiting for all ranks to finish block.")
@@ -292,14 +313,16 @@ def update_chunk_counts(indexed_dataset_infos):
     if torch.distributed.get_rank() != 0:
         return

+    # Data ratio sum (for setting index training chunks).
+    data_ratio_sum = sum([ d["ratio"] for d in indexed_dataset_infos ])
+
     # Training split size (split at document level).
     train_fraction = float(args.split.split(",")[0]) / 100
     assert train_fraction > 0 and train_fraction <= 1

     # Set n_chunks (including n_chunks_sampled for unambiguity).
     print_rank_0(" > compute n_chunks.")
-    for ds_index, ds_info in \
-            enumerate(tqdm(indexed_dataset_infos, "count_chunks")):
+    for ds_index, ds_info in enumerate(indexed_dataset_infos):
         db_dir = ds_info["db_dir"]
         db_paths = sorted(glob.glob(db_dir + "/*.hdf5"))
@@ -310,16 +333,17 @@ def update_chunk_counts(indexed_dataset_infos):
         ds_info["n_chunks"] = 0 # previously, 'n_chunks_valid'
         ds_info["n_chunks_train"] = 0
         ds_info["n_chunks_invalid"] = 0
-        for db_path in db_paths:
-            with h5py.File(db_path, "r") as f:
+        for db_path in tqdm(db_paths, "%d/%d, %s" % (
+                ds_index, len(indexed_dataset_infos), ds_info["name"])):
+            with h5py.File(db_path, "r") as f:
                 ds_info["n_chunks"] += len(f["chunks_valid"])
                 ds_info["n_chunks_invalid"] += len(f["chunks_invalid"])
                 ds_info["n_chunks_train"] += \
                     (np.copy(f["chunks_valid"][:, 0]) < ds_info["n_docs_train"]) \
                     .sum().item()

-        ds_info["n_chunks_sampled"] = \
-            int(round(args.retro_nchunks_sampled * ds_info["ratio"]))
+        ds_info["n_chunks_sampled"] = int(args.retro_index_ntrain *
+                                          ds_info["ratio"] / data_ratio_sum)

         # Verify counts.
         assert ds_info["n_chunks_train"] <= ds_info["n_chunks"], \
@@ -339,15 +363,14 @@ def merge_dbs(indexed_dataset_infos, db_type):
     print(" > build %s chunk db." % db_type)

     # Count chunks.
-    if db_type == "full":
-        raise Exception("deprecated; use 'train' or 'sampled'.")
-        n_chunks_key = "n_chunks"
-    elif db_type == "sampled":
+    if db_type == "sampled":
         n_chunks_key = "n_chunks_sampled"
+        n_docs_key = None
     elif db_type == "train":
         n_chunks_key = "n_chunks_train"
+        n_docs_key = "n_docs_train"
     elif db_type == "valid":
-        pass
+        n_docs_key = None
     else:
         raise Exception("handle db_type '%s'." % db_type)
@@ -356,6 +379,8 @@
         for m in indexed_dataset_infos)
     else:
         n_chunks = sum(m[n_chunks_key] for m in indexed_dataset_infos)
+    n_docs = None if n_docs_key is None else \
+        sum(m[n_docs_key] for m in indexed_dataset_infos)

     # DB path.
     db_path = get_merged_db_path_map()[db_type]
@@ -375,10 +400,10 @@
     except Exception as e:
         if isinstance(e, OSError):
-            os.remove(full_db_path)
+            os.remove(db_path)
         elif isinstance(e, KeyError):
             f.close()
-            os.remove(full_db_path)
+            os.remove(db_path)
         else:
             raise e
...@@ -389,121 +414,60 @@ def merge_dbs(indexed_dataset_infos, db_type): ...@@ -389,121 +414,60 @@ def merge_dbs(indexed_dataset_infos, db_type):
f = h5py.File(db_path, "w") f = h5py.File(db_path, "w")
# Initialize output arrays. # Initialize output arrays.
merged_db = f.create_dataset("chunks", (n_chunks, 5), dtype="i8") merged_chunk_db = \
f.create_dataset("chunks", (n_chunks, 5), dtype="uint32")
merged_doc_offsets = None if n_docs_key is None else \
f.create_dataset("doc_offsets", (n_docs, 3), dtype="uint64")
n_written = f.create_dataset("n_written", (1,), dtype="uint64") n_written = f.create_dataset("n_written", (1,), dtype="uint64")
n_written[0] = 0 n_written[0] = 0
# Iterate indexed datasets & collect chunks. # Iterate indexed datasets & collect chunks.
start_index = 0 chunk_start_index = 0
doc_start_index = 0
doc_start_offset = 0
for ds_idx, ds_info in enumerate(indexed_dataset_infos): for ds_idx, ds_info in enumerate(indexed_dataset_infos):
print(" > merging dbs; '%s', dataset %d / %d ... '%s'." % print(" > merging dbs; '%s', dataset %d / %d ... '%s'." %
(db_type, ds_idx, len(indexed_dataset_infos), ds_info["name"])) (db_type, ds_idx, len(indexed_dataset_infos), ds_info["name"]))
individual_db = get_individual_db(ds_idx, ds_info) individual_chunk_db = get_individual_chunk_db(ds_idx, ds_info)
individual_doc_offsets = None if n_docs_key is None else \
get_individual_doc_offsets(ds_idx, ds_info)
if db_type == "valid": if db_type == "valid":
individual_db = individual_db[ds_info["n_chunks_train"]:] individual_chunk_db = \
individual_chunk_db[ds_info["n_chunks_train"]:]
if n_docs_key is None:
individual_doc_offsets = None
else:
train_doc_offset = \
individual_doc_offsets[ds_info["n_docs_train"] - 1, 2]
individual_doc_offsets = \
np.copy(individual_doc_offsets[ds_info["n_docs_train"]:])
individual_doc_offsets[:, 2] -= train_doc_offset
print("~~~")
print(individual_doc_offsets)
print(train_doc_offset)
raise Exception("test me.")
else: else:
individual_db = individual_db[:ds_info[n_chunks_key]] individual_chunk_db = \
individual_chunk_db[:ds_info[n_chunks_key]]
merged_db[start_index:start_index+len(individual_db)] = individual_db individual_doc_offsets = None if n_docs_key is None else \
start_index += len(individual_db) np.copy(individual_doc_offsets[:ds_info[n_docs_key]])
n_written[0] = start_index
merged_chunk_db[chunk_start_index:chunk_start_index+len(individual_chunk_db)] = individual_chunk_db
chunk_start_index += len(individual_chunk_db)
n_written[0] = chunk_start_index
if n_docs_key is not None:
individual_doc_offsets[:, 2] += doc_start_offset
doc_end_index = doc_start_index + individual_doc_offsets.shape[0]
merged_doc_offsets[doc_start_index:doc_end_index] = \
individual_doc_offsets
doc_start_index = doc_end_index
doc_start_offset = individual_doc_offsets[-1, 2].item()
f.close() f.close()
def get_partial_banned_chunk_map(proc_id, db_path, chunk_range_info):
'''Build partial mapping of {(dataset_id,doc_id):[chunk_ids]}.
In this method, only chunks within the range (start_chunk_id, end_chunk_id]
are processed.'''
start_chunk_id = chunk_range_info["start"]
end_chunk_id = chunk_range_info["end"]
output_path = chunk_range_info["path"]
# Skip, if output file exists.
if os.path.exists(output_path):
return
# Chunk subset.
with h5py.File(db_path) as f:
sub_chunk_db = np.copy(f["chunks"][start_chunk_id:end_chunk_id, :2])
# Map docs to chunks.
banned_chunk_map = defaultdict(list)
for rel_chunk_id, (dataset_id, doc_id) in enumerate(tqdm(
sub_chunk_db,
"map banned docs, proc %d" % proc_id,
total=sub_chunk_db.shape[0],
)):
chunk_id = start_chunk_id + rel_chunk_id
banned_chunk_map["%d,%d" % (dataset_id.item(), doc_id.item())] \
.append(chunk_id)
# Save output.
with open(output_path, "w") as f:
json.dump(banned_chunk_map, f)
def build_doc_chunk_map(indexed_dataset_infos, db_type):
'''Build mapping of {(dataset_id,doc_id):[chunk_ids]}.'''
if torch.distributed.get_rank() != 0:
return
print(" > build %s doc-chunk map." % db_type)
n_procs = 128
# Get dataset.
db_dataset = get_merged_dataset(db_type, indexed_dataset_infos)
# Sub-ranges for parallel processing.
n_chunks = db_dataset.chunks.shape[0]
n_chunks_per_proc = max(1, int(np.ceil(n_chunks / n_procs)))
chunk_id_starts = list(range(0, n_chunks, n_chunks_per_proc))
chunk_id_ranges = [(s, min(n_chunks, s + n_chunks_per_proc))
for s in chunk_id_starts]
# Wrap range info with output path.
n_digits = int(np.ceil(np.log(n_chunks) / np.log(10)) + 1)
output_dirname = get_train_doc_chunk_map_dir()
chunk_range_infos = [{
"start" : start_id,
"end" : end_id,
"path" : os.path.join(output_dirname, "%s-%s.json" % (
str(start_id).zfill(n_digits),
str(end_id).zfill(n_digits),
)),
} for start_id, end_id in chunk_id_ranges ]
# Build doc-chunk map.
print_rank_0("build doc-chunk-map.")
with ProcessPoolExecutor(max_workers=n_procs) as executor:
# Build partial chunk maps.
futures = []
for proc_id, chunk_range_info in enumerate(chunk_range_infos):
if os.path.exists(chunk_range_info["path"]):
continue
# Submit job.
futures.append(executor.submit(
get_partial_banned_chunk_map,
proc_id,
db_dataset.db_path,
chunk_range_info,
))
# Wait for processes to finish.
banned_chunk_paths = []
for finished_idx, future in enumerate(as_completed(futures)):
print("finished %d / %d." % (finished_idx, n_procs))
future.result()

def build_db():
    '''Extract token chunks from each indexed dataset.
@@ -521,14 +485,13 @@ def build_db():
     if torch.distributed.get_rank() != 0:
         return

-    # Update n_chunks.
-    update_chunk_counts(indexed_dataset_infos)
+    # Update n_chunks & save indexed dataset infos.
+    if not os.path.exists(get_indexed_dataset_infos_path()):
+        update_chunk_counts(indexed_dataset_infos)
+        save_indexed_dataset_infos(indexed_dataset_infos)
+    indexed_dataset_infos = get_indexed_dataset_infos()

     # Merge dbs.
     merge_dbs(indexed_dataset_infos, "sampled")
     merge_dbs(indexed_dataset_infos, "train")
     merge_dbs(indexed_dataset_infos, "valid")
-    build_doc_chunk_map(indexed_dataset_infos, "train")
-
-    # Save (fully annotated) indexed dataset infos.
-    save_indexed_dataset_infos(indexed_dataset_infos)
@@ -3,6 +3,7 @@
 import json
 import numpy as np
 import torch
+from tqdm import tqdm

 from megatron import get_args, print_rank_0
 from tools.retro.external_libs import h5py
@@ -27,9 +28,10 @@ class DBDataset(torch.utils.data.Dataset):
         self.db_path = db_path
         self.indexed_datasets = indexed_datasets
         self.chunks = chunks
+        self.doc_chunk_map = None
         self.max_chunk_length = max_chunk_length
-        self.eod_token_id = get_gpt_tokenizer().eod_id
+        self.eod_token_id = get_gpt_tokenizer().eod

     def __len__(self):
         return self.chunks.shape[0]
@@ -58,3 +60,15 @@
             "doc_id" : doc_id,
             "text" : np.array(token_ids, dtype=np.int64),
         }
+
+    def load_doc_tuples(self):
+        '''Load the dataset & document ids.
+
+        Load the dataset id & document id of each chunk in the database, to
+        be used for causality filtering during querying.
+        '''
+        self.doc_tuples = np.zeros(shape=(len(self), 2), dtype="uint32")
+        block_size = int(1e6)
+        for start_idx in tqdm(range(0, len(self), block_size)):
+            end_idx = min(len(self), start_idx + block_size)
+            self.doc_tuples[start_idx:end_idx] = self.chunks[start_idx:end_idx, :2]
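
# A hypothetical sketch (not part of the diff above) of how doc_tuples could
# support the causality filtering mentioned in the docstring: drop retrieved
# neighbor chunks that come from the same (dataset_id, doc_id) as the query
# chunk, so a chunk never retrieves text from its own source document. The
# function and variable names here are illustrative only.
#
#   def filter_same_document(db_dataset, query_chunk_id, neighbor_chunk_ids):
#       query_doc = tuple(db_dataset.doc_tuples[query_chunk_id])
#       return [nid for nid in neighbor_chunk_ids
#               if tuple(db_dataset.doc_tuples[nid]) != query_doc]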
@@ -57,14 +57,14 @@ def get_indexed_dataset_infos():
 def get_individual_db_dir(name):
     '''Individual DB's directory.'''
-    return os.path.join(get_base_db_workdir(), "individual", name, "db")
+    return os.path.join(get_base_db_workdir(), "individual", name)

-def get_individual_db(ds_id, ds_info):
+def get_individual_chunk_db(ds_id, ds_info):
     '''Load individual dataset's chunk DB.'''
     db_paths = sorted(glob.glob(ds_info["db_dir"] + "/*hdf5"))
     # *Note*: convert to dataset, rather than copying to memory.
-    db = np.zeros((ds_info["n_chunks"], 5), dtype="i8")
+    db = np.zeros((ds_info["n_chunks"], 5), dtype="uint32")
     db[:, 0] = ds_id
     start_idx = 0
     for db_path in db_paths:
@@ -79,6 +79,27 @@ def get_individual_db(ds_id, ds_info):
     return db

+
+def get_individual_doc_offsets(ds_id, ds_info):
+    '''Load individual dataset's chunk DB.'''
+    paths = sorted(glob.glob(ds_info["db_dir"] + "/*hdf5"))
+    # *Note*: convert to dataset, rather than copying to memory.
+    doc_offsets = np.zeros((ds_info["n_docs"], 3), dtype="uint64")
+    doc_offsets[:, 0] = ds_id
+    start_idx = 0
+    start_offset = 0
+    for path in paths:
+        with h5py.File(path) as f:
+            current_doc_offsets = np.copy(f["doc_offsets"])
+            current_doc_offsets[:, 1] += start_offset
+            current_ndocs = current_doc_offsets.shape[0]
+            doc_offsets[start_idx:(start_idx+current_ndocs), 1:] = \
+                current_doc_offsets
+            start_idx += current_ndocs
+            start_offset = current_doc_offsets[-1, 1].item()
+
+    return doc_offsets
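
# Layout note (not part of the diff; values illustrative): each row of the
# doc_offsets array returned above is (dataset_id, document_id, cumulative
# count of valid chunks through that document). For example, for documents
# holding 3, 5 and 2 valid chunks:
#
#   [[ds_id, doc_0,  3],
#    [ds_id, doc_1,  8],
#    [ds_id, doc_2, 10]]
#
# so, within this dataset's chunk numbering, doc_1's chunks are the half-open
# range [3, 8).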
def get_merged_db_path_map(): def get_merged_db_path_map():
'''Paths to merged datasets.''' '''Paths to merged datasets.'''
base_dir = get_base_db_workdir() base_dir = get_base_db_workdir()
...@@ -120,28 +141,3 @@ def get_merged_train_dataset(indexed_dataset_infos=None): ...@@ -120,28 +141,3 @@ def get_merged_train_dataset(indexed_dataset_infos=None):
def get_merged_valid_dataset(indexed_dataset_infos=None): def get_merged_valid_dataset(indexed_dataset_infos=None):
return get_merged_dataset("valid", indexed_dataset_infos) return get_merged_dataset("valid", indexed_dataset_infos)
def get_train_doc_chunk_map_dir():
dirname = os.path.join(get_base_db_workdir(), "merged", "train_doc_chunk_map")
os.makedirs(dirname, exist_ok=True)
return dirname
def get_train_doc_chunk_map():
paths = sorted(glob.glob(get_train_doc_chunk_map_dir() + "/*.json"))
doc_map = defaultdict(set)
for path in tqdm(paths, "load train doc maps"):
# Read file.
with open(path) as f:
crnt_doc_map = json.load(f)
# Add to doc map.
for key, chunk_ids in crnt_doc_map.items():
key = tuple(int(i) for i in key.split(","))
doc_map[key].update(chunk_ids)
return doc_map
#!/bin/bash
# Small English Wikipedia dataset (~2M chunks).
get_wiki_tiny_config() {
RETRO_INDEX_STR="IVF4096_HNSW4,Flat"
RETRO_GPT_TRAIN_SAMPLES=31250
LR_DECAY_SAMPLES=2
LR_WARMUP_SAMPLES=1
RETRO_GPT_EVAL_INTERVAL=2000
RETRO_GPT_EVAL_ITERS=100
RETRO_EF_SEARCH=4
RETRO_NPROBE=64
DATALOADER_TYPE=cyclic
}
# English Wikipedia dataset (~67M chunks).
get_wiki_config() {
RETRO_INDEX_STR="IVF262144_HNSW32,Flat"
RETRO_GPT_TRAIN_SAMPLES=2037248
LR_DECAY_SAMPLES=2
LR_WARMUP_SAMPLES=1
RETRO_GPT_EVAL_INTERVAL=2000
RETRO_GPT_EVAL_ITERS=100
RETRO_EF_SEARCH=16
RETRO_NPROBE=4096
DATALOADER_TYPE=cyclic
}
# Full corpus (~5B chunks).
get_corpus_config() {
RETRO_INDEX_STR="OPQ32_256,IVF4194304_HNSW32,PQ32"
RETRO_GPT_TRAIN_SAMPLES=192000000
LR_DECAY_SAMPLES=166400000
LR_WARMUP_SAMPLES=162761
RETRO_GPT_EVAL_INTERVAL=2000
RETRO_GPT_EVAL_ITERS=50
RETRO_EF_SEARCH=32
RETRO_NPROBE=4096
DATALOADER_TYPE=single
}
#!/bin/bash
# Build preprocessing command for Retro.
set -u
DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
################ Required environment variables. ################
# Required environment variables:
# - REPO_DIR : Root directory of Megatron codebase.
# - RETRO_WORKDIR : Root directory of this Retro project's processed data. (For
# example, this project directory might be for a blended dataset, while
# another project directory might be for just a Wikipedia dataset, and
# another for just Book Corpus data, etc.) This project directory will
# contain a complete set of processed data, including the retrieval
# database, search index, and pretraining neighbors.
# - RETRO_TASKS : One of 'build', 'db-build', 'index-build', or
# 'pretraining-query-neighbors'. See 'Retro tasks' below for task
# descriptions.
# - DATA_BLEND_SCRIPT : Path to blended dataset definition file.
# - GPT_VOCAB_FILE : GPT vocab file.
# - GPT_MERGE_FILE : GPT merge file.
# - GPT_TOKENIZER : GPT tokenizer type (e.g., GPT2BPETokenizer)
# - BERT_LOAD_PATH : Bert checkpoint directory.
# - BERT_VOCAB_FILE : Bert vocab file.
# - BERT_TOKENIZER : Bert tokenizer type (e.g., BertWordPieceLowerCase,
# BertWordPieceCase).
# - BERT_EMBEDDER_TYPE : One of 'megatron' or 'huggingface'.
# - EXTRA_ARGS : Extra arguments (else, leave empty).
################ Data blend. ################
. ${DATA_BLEND_SCRIPT}
DATA_PATH=${DATA_BLEND}
################ Retro setup. ################
RETRO_GPT_SEQ_LENGTH=2048
RETRO_GPT_CHUNK_LENGTH=64
RETRO_GPT_MICRO_BATCH_SIZE=1 # *8
RETRO_GPT_GLOBAL_BATCH_SIZE=256
RETRO_NCHUNKS_SAMPLED=300000000
################ Retro tasks. ################
# The '--retro-tasks' argument is a comma-separated list of tasks to run, in
# sequential order. For a quick start, simply set this to 'build' to run the
# entire preprocessing pipeline. For finer control, you may specify the list of
# tasks to run. This is desirable for tuning computational resources. For
# example, training the search index is relatively fast and utilizes GPUs,
# while querying the search index is relatively slow, CPU-only, and memory
# intensive (i.e., multiple populated search indexes are loaded simultaneously).
# *Note* : Once the task(s) below have been completed -- by running either
# 1) 'build', or 2) the sequential combination of 'db-build', 'index-build',
# and 'pretraining-query-neighbors' -- we are ready to pretrain Retro by
# calling pretrain_retro.py.
# ---- Option #1 : Run entire pipeline. ----
# RETRO_TASKS="build" # (*note*: default tasks)
# ---- Option #2 : Run specific stages. ----
# *Note*: Run the following stages in the given order. Optionally, tune your
# cluster setup for each stage, as described above.
# RETRO_TASKS="db-build" # ....................... run 1st
# RETRO_TASKS="index-build" # .................... run 2nd
# RETRO_TASKS="pretraining-query-neighbors" # .... run 3rd
################ Megatron args. ################
MEGATRON_ARGS=" \
--seed 1234 \
--distributed-timeout-minutes 600 \
--tokenizer-type ${BERT_TOKENIZER} \
--tensor-model-parallel-size 1 \
--pipeline-model-parallel-size 1 \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
--micro-batch-size ${RETRO_GPT_MICRO_BATCH_SIZE} \
--global-batch-size ${RETRO_GPT_GLOBAL_BATCH_SIZE} \
--seq-length 512 \
--max-position-embeddings 512 \
--train-samples ${RETRO_GPT_TRAIN_SAMPLES} \
--load ${BERT_LOAD_PATH} \
--exit-on-missing-checkpoint \
--no-load-optim \
--data-path ${DATA_PATH} \
--vocab-file ${BERT_VOCAB_FILE} \
--data-impl mmap \
--split 98,2,0 \
--distributed-backend nccl \
--lr 0.0001 \
--lr-decay-style linear \
--min-lr 1.0e-5 \
--lr-decay-samples ${LR_DECAY_SAMPLES} \
--lr-warmup-samples ${LR_WARMUP_SAMPLES} \
--weight-decay 1e-2 \
--clip-grad 1.0 \
--eval-interval ${RETRO_GPT_EVAL_INTERVAL} \
--eval-iters ${RETRO_GPT_EVAL_ITERS} \
--fp16 \
--DDP-impl local \
--dataloader-type ${DATALOADER_TYPE} \
--no-data-sharding \
--no-gradient-accumulation-fusion \
--no-async-tensor-model-parallel-allreduce \
"
################ Retro args. ################
RETRO_ARGS=" \
--bert-embedder-type ${BERT_EMBEDDER_TYPE} \
--output-bert-embeddings \
\
--retro-gpt-vocab-file ${GPT_VOCAB_FILE} \
--retro-gpt-merge-file ${GPT_MERGE_FILE} \
--retro-gpt-tokenizer-type ${GPT_TOKENIZER} \
--retro-gpt-seq-length ${RETRO_GPT_SEQ_LENGTH} \
--retro-gpt-chunk-length ${RETRO_GPT_CHUNK_LENGTH} \
--retro-bert-vocab-file ${BERT_VOCAB_FILE} \
--retro-bert-tokenizer-type ${BERT_TOKENIZER} \
\
--retro-tasks ${RETRO_TASKS} \
--retro-index-str ${RETRO_INDEX_STR} \
--retro-ef-search ${RETRO_EF_SEARCH} \
--retro-nprobe ${RETRO_NPROBE} \
\
--retro-workdir ${RETRO_WORKDIR} \
--retro-nchunks-sampled ${RETRO_NCHUNKS_SAMPLED} \
\
--retro-return-doc-ids \
"
################ Command. ################
RETRO_PREPROCESS_CMD=" \
./tools/retro/main.py \
${MEGATRON_ARGS} \
${RETRO_ARGS} \
${EXTRA_ARGS} \
"
#!/bin/bash #!/bin/bash
set -u set -u
unset NCCL_DEBUG unset NCCL_DEBUG
NPROCS=8 # NPROCS must be <= number of GPUs. ######## Megatron, Retro dirs. ########
set_current_dir() { REPO_DIR="<path/to/megatron/repo>"
DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) RETRO_WORKDIR="<path/to/retro/data/directory>"
}
################ Dataset configs. ################ ######## Task (e.g., db, index, query). ########
# This script contains methods to customize arguments to specific dataset
# types. Customize this script as needed for your datasets.
set_current_dir
. $DIR/get_dataset_configs.sh
################ Environment variables. ################ RETRO_TASKS="db-build"
# *Note*: See 'Required environment variables' in 'get_preprocess_cmd.sh' for # RETRO_TASKS="index-train"
# a description of the required environment variables. These variables can be # RETRO_TASKS="index-add"
# set however a user would like. In our setup, we use another bash script # RETRO_TASKS="query-pretraining-neighbors"
# (location defined by $RETRO_ENV_VARS) that sets all the environment variables
# at once.
. $RETRO_ENV_VARS
######## Environment vars. ######## ######## Data. ########
set_current_dir
. ${DIR}/get_preprocess_cmd.sh
echo "~~~~~~~~~~~~~~~~~~~~~~~~~~" DATA_BLEND="<see --data-path in arguments.py>"
echo "DIR = '$DIR'."
echo "RETRO_PREPROCESS_CMD = '$RETRO_PREPROCESS_CMD'." ######## Index. ########
echo "~~~~~~~~~~~~~~~~~~~~~~~~~~"
RETRO_INDEX_STR="OPQ32_64,IVF65536_HNSW8,PQ32"
RETRO_INDEX_NTRAIN=1000000
RETRO_INDEX_TRAIN_LOAD_FRACTION=0.97
RETRO_INDEX_ADD_LOAD_FRACTION=0.95
######## GPT. ########
RETRO_GPT_SEED=1234
RETRO_GPT_SPLIT="98,2,0"
RETRO_GPT_DATA_PATH=${DATA_BLEND}
RETRO_GPT_DATA_IMPL=mmap
RETRO_GPT_DATALOADER_TYPE=single
RETRO_GPT_EVAL_INTERVAL=2000
RETRO_GPT_EVAL_ITERS=50
RETRO_GPT_TRAIN_SAMPLES=200000
RETRO_GPT_LR_DECAY_SAMPLES=175000
RETRO_GPT_LR_WARMUP_SAMPLES=10000
RETRO_GPT_SEQ_LENGTH=512
RETRO_GPT_GLOBAL_BATCH_SIZE=256
RETRO_GPT_CHUNK_LENGTH=64
######## Query. ########
RETRO_QUERY_NUM_NEIGHBORS_QUERY=200 RETRO_QUERY_NUM_NEIGHBORS_SAVE=20
RETRO_QUERY_EF_SEARCH=32
RETRO_QUERY_NPROBE=4096
######## Args. ########
ARGS=" \
--distributed-timeout-minutes 600 \
--tensor-model-parallel-size 1 \
--pipeline-model-parallel-size 1 \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
--micro-batch-size 1 \
--global-batch-size ${RETRO_GPT_GLOBAL_BATCH_SIZE} \
--seq-length 512 \
--max-position-embeddings 512 \
--load <path/to/bert/checkpoint> \
--exit-on-missing-checkpoint \
--no-load-optim \
--data-path ${RETRO_GPT_DATA_PATH} \
--tokenizer-type BertWordPieceLowerCase \
--vocab-file <path/to/bert/vocab> \
--data-impl ${RETRO_GPT_DATA_IMPL} \
--split ${RETRO_GPT_SPLIT} \
--distributed-backend nccl \
--lr 0.0001 \
--lr-decay-style linear \
--min-lr 1.0e-5 \
--train-samples ${RETRO_GPT_TRAIN_SAMPLES} \
--lr-decay-samples ${RETRO_GPT_LR_DECAY_SAMPLES} \
--lr-warmup-samples ${RETRO_GPT_LR_WARMUP_SAMPLES} \
--weight-decay 1e-2 \
--clip-grad 1.0 \
--eval-interval ${RETRO_GPT_EVAL_INTERVAL} \
--eval-iters ${RETRO_GPT_EVAL_ITERS} \
--fp16 \
--DDP-impl local \
--dataloader-type ${RETRO_GPT_DATALOADER_TYPE} \
--no-data-sharding \
--no-gradient-accumulation-fusion \
--no-async-tensor-model-parallel-allreduce \
--bert-embedder-type megatron \
--output-bert-embeddings \
\
--retro-workdir ${RETRO_WORKDIR} \
--retro-tasks ${RETRO_TASKS} \
--retro-return-doc-ids \
--retro-bert-vocab-file <path/to/bert/vocab> \
--retro-bert-tokenizer-type BertWordPieceLowerCase \
--retro-gpt-seed ${RETRO_GPT_SEED} \
--retro-gpt-tokenizer-type GPTSentencePieceTokenizer \
--retro-gpt-tokenizer-model <path/to/gpt/tokenizer/model> \
--retro-gpt-seq-length ${RETRO_GPT_SEQ_LENGTH} \
--retro-gpt-chunk-length ${RETRO_GPT_CHUNK_LENGTH} \
--retro-gpt-global-batch-size ${RETRO_GPT_GLOBAL_BATCH_SIZE} \
--retro-gpt-eval-interval ${RETRO_GPT_EVAL_INTERVAL} \
--retro-gpt-eval-iters ${RETRO_GPT_EVAL_ITERS} \
--retro-gpt-split ${RETRO_GPT_SPLIT} \
--retro-gpt-data-impl ${RETRO_GPT_DATA_IMPL} \
--retro-gpt-data-path ${RETRO_GPT_DATA_PATH} \
--retro-index-str ${RETRO_INDEX_STR} \
--retro-index-ntrain ${RETRO_INDEX_NTRAIN} \
--retro-index-train-load-fraction ${RETRO_INDEX_TRAIN_LOAD_FRACTION} \
--retro-index-add-load-fraction ${RETRO_INDEX_ADD_LOAD_FRACTION} \
--retro-index-no-delete-training-embeddings \
--retro-index-no-delete-added-codes \
--retro-query-num-neighbors-query ${RETRO_QUERY_NUM_NEIGHBORS_QUERY} \
--retro-query-num-neighbors-save ${RETRO_QUERY_NUM_NEIGHBORS_SAVE} \
--retro-query-ef-search ${RETRO_QUERY_EF_SEARCH} \
--retro-query-nprobe ${RETRO_QUERY_NPROBE} \
"
######## Command. ######## ######## Command. ########
FULL_CMD="\
pwd && cd ${REPO_DIR} && pwd && \ NPROCS=8 # Number of GPUs.
CMD="\
cd ${REPO_DIR} && pwd && \
export PYTHONPATH=$PYTHONPATH:${REPO_DIR} && \ export PYTHONPATH=$PYTHONPATH:${REPO_DIR} && \
python -m torch.distributed.launch \ python -m torch.distributed.run \
--nproc_per_node ${NPROCS} \ --nproc_per_node ${NPROCS} \
--nnodes 1 \ --nnodes 1 \
--node_rank ${NODE_RANK} \ --node_rank ${NODE_RANK} \
--master_addr ${MASTER_ADDR} \ --master_addr ${MASTER_ADDR} \
--master_port 6000 \ --master_port 6000 \
$RETRO_PREPROCESS_CMD \ tools/retro/main.py ${ARGS} \
" "
echo "~~~~~~~~~~~~~~~~~~~~~~~~~~" echo "~~~~~~~~~~~~~~~~~~~~~~~~~~"
echo "FULL_CMD = '$FULL_CMD'." echo "CMD = '$CMD'."
echo "~~~~~~~~~~~~~~~~~~~~~~~~~~" echo "~~~~~~~~~~~~~~~~~~~~~~~~~~"
eval $FULL_CMD eval $CMD
#!/bin/bash #!/bin/bash
##################################################
# Example script for pretraining Retro.
##################################################
set -u set -u
unset NCCL_DEBUG unset NCCL_DEBUG
export CUDA_DEVICE_MAX_CONNECTIONS=1 export CUDA_DEVICE_MAX_CONNECTIONS=1
NPROCS=8 # NPROCS must be <= number of GPUs. ######## GPT or Retro?. ########
# 0 : GPT.
# 1 : Retro
ADD_RETRIEVER=1
################ Dataset configs. ################ ######## Megatron, Retro dirs. ########
# This script contains methods to customize arguments to specific dataset
# types. Customize this script as needed for your datasets.
DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
. $DIR/get_dataset_configs.sh
################ Environment variables. ################ REPO_DIR="<path/to/megatron/repo>"
# *Note*: See 'Required environment variables' in 'get_preprocess_cmd.sh' for RETRO_WORKDIR="<path/to/retro/data/directory>"
# a description of the required environment variables. These variables can be
# set however a user would like. In our setup, we use another bash script
# (location defined by $RETRO_ENV_VARS) that sets all the environment variables
# at once.
. $RETRO_ENV_VARS
################ Data blend. ################ ######## Data. ########
. ${DATA_BLEND_SCRIPT}
DATA_PATH=${DATA_BLEND}
######## Retro setup. ######## DATA_BLEND="<see --data-path in arguments.py>"
RETRO_ADD_RETRIEVER=1
RETRO_CYCLIC_TRAIN_ITERS=750000 ######## Args. ########
RETRO_NUM_NEIGHBORS=2
######## Arguments. ########
CHECKPOINT_DIR=${RETRO_WORKDIR}/checkpoints/${RETRO_ADD_RETRIEVER}
TENSORBOARD_DIR="${CHECKPOINT_DIR}/tensorboard"
mkdir -p ${TENSORBOARD_DIR}
ARGS=" \ ARGS=" \
--save-interval 1000 \ --log-interval 1 \
--save ${CHECKPOINT_DIR} \ --use-flash-attn \
--load ${CHECKPOINT_DIR} \ --apply-layernorm-1p \
--tensorboard-dir ${TENSORBOARD_DIR} \ --untie-embeddings-and-output-weights \
--log-interval 5 \ --disable-bias-linear \
--no-position-embedding \
--use-rotary-position-embeddings \
--rotary-percent 0.5 \
--swiglu \
--attention-dropout 0.0 \
--hidden-dropout 0.0 \
--exit-duration-in-mins 220 \
--tensor-model-parallel-size 1 \ --tensor-model-parallel-size 1 \
--pipeline-model-parallel-size 1 \ --pipeline-model-parallel-size 1 \
--num-layers 12 \ --num-layers 24 \
--hidden-size 768 \ --hidden-size 1024 \
--num-attention-heads 12 \ --num-attention-heads 16 \
--seq-length 2048 \ --seq-length 512 \
--max-position-embeddings 2048 \ --max-position-embeddings 512 \
--micro-batch-size 4 \ --micro-batch-size 16 \
--global-batch-size 256 \ --global-batch-size 256 \
--train-samples ${RETRO_GPT_TRAIN_SAMPLES} \ --train-samples 200000 \
--lr-decay-samples ${LR_DECAY_SAMPLES} \ --lr-decay-samples 175000 \
--lr-warmup-samples ${LR_WARMUP_SAMPLES} \ --lr-warmup-samples 10000 \
--lr 6.0e-4 \ --lr 2.5e-5 \
--min-lr 6.0e-5 \ --min-lr 2.5e-6 \
--lr-decay-style cosine \ --lr-decay-style cosine \
--eval-interval ${RETRO_GPT_EVAL_INTERVAL} \ --eval-iters 50 \
--eval-iters ${RETRO_GPT_EVAL_ITERS} \ --eval-interval 2000 \
--data-path ${DATA_PATH} \ --tokenizer-type GPTSentencePieceTokenizer \
--vocab-file ${GPT_VOCAB_FILE} \ --tokenizer-model <path/to/gpt/tokenizer/model> \
--merge-file ${GPT_MERGE_FILE} \ --data-path ${DATA_BLEND} \
--split 98,2,0 \ --split 98,2,0 \
--clip-grad 1.0 \ --clip-grad 1.0 \
--weight-decay 0.1 \ --weight-decay 0.1 \
--adam-beta1 0.9 \ --adam-beta1 0.9 \
--adam-beta2 0.95 \ --adam-beta2 0.95 \
--init-method-std 0.023 \ --init-method-std 0.007 \
--log-params-norm \ --log-params-norm \
--log-num-zeros-in-grad \ --log-num-zeros-in-grad \
--fp16 \ --bf16 \
--DDP-impl local \ --DDP-impl local \
--dataloader-type ${DATALOADER_TYPE} \
--no-data-sharding \
--no-gradient-accumulation-fusion \
" "
if [ "$RETRO_ADD_RETRIEVER" = "0" ]; then ######## Retro. ########
if [ "$ADD_RETRIEVER" = "0" ]; then
SCRIPT=pretrain_gpt.py SCRIPT=pretrain_gpt.py
else else
ARGS="${ARGS} \ ARGS="${ARGS} \
--retro-add-retriever \
--retro-workdir ${RETRO_WORKDIR} \ --retro-workdir ${RETRO_WORKDIR} \
--retro-cyclic-train-iters ${RETRO_CYCLIC_TRAIN_ITERS} \ --retro-add-retriever \
--retro-num-neighbors ${RETRO_NUM_NEIGHBORS} \
" "
SCRIPT=pretrain_retro.py SCRIPT=pretrain_retro.py
fi fi
echo "~~~~~~~~~~~~~~~~~~~~~~~~~~" ######## Command. ########
echo "ARGS = '$ARGS'."
echo "~~~~~~~~~~~~~~~~~~~~~~~~~~"
python -m torch.distributed.launch \ NPROCS=8
CMD="\
pwd && cd ${REPO_DIR} && pwd && \
export PYTHONPATH=$PYTHONPATH:${REPO_DIR} && \
python -m torch.distributed.run \
--nproc_per_node ${NPROCS} \ --nproc_per_node ${NPROCS} \
--nnodes 1 \ --nnodes 1 \
--node_rank 0 \ --node_rank ${NODE_RANK} \
--master_addr localhost \ --master_addr ${MASTER_ADDR} \
--master_port 6000 \ --master_port 6000 \
${SCRIPT} \ ${SCRIPT} ${ARGS} \
${ARGS} \ "
echo "~~~~~~~~~~~~~~~~~~~~~~~~~~"
echo "CMD = '$CMD'."
echo "~~~~~~~~~~~~~~~~~~~~~~~~~~"
eval $CMD
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

-from .index import Index
+from .build import add_to_index, build_index, train_index
+# from .index import Index
@@ -18,8 +18,10 @@ from tools.retro.index.factory import IndexFactory
 from tools.retro.utils import GPTToTextDataset
 from .utils import (
-    get_training_data_dir,
-    get_training_data_merged,
+    get_training_data_block_dir,
+    get_training_data_block_paths,
+    get_training_data_merged_path,
+    get_training_data_root_dir,
 )
@@ -36,6 +38,43 @@ def get_empty_index_path():
     return empty_index_path

+
+def get_block_nload(block_path, load_fraction):
+    with h5py.File(block_path) as fi:
+        return int(load_fraction * fi["data"].shape[0])
+
+
+def merge_embedding_blocks():
+
+    if torch.distributed.get_rank() != 0:
+        return
+
+    args = get_retro_args()
+
+    # Get block, merged paths.
+    load_fraction = args.retro_index_train_load_fraction
+    block_paths = get_training_data_block_paths()
+    bin_path = get_training_data_merged_path()
+
+    # Skip, if already built.
+    if os.path.exists(bin_path):
+        return
+
+    # Merge blocks.
+    with open(bin_path, "wb") as fo:
+        byte_offset = 0
+        for block_idx, block_path in \
+            enumerate(tqdm(block_paths, "merge train embeddings")):
+            with h5py.File(block_path) as fi:
+
+                nload = get_block_nload(block_path, load_fraction)
+                block = np.array(fi["data"][:nload], copy = False)
+
+                fo.write(block.tobytes())
+
+                byte_offset += block.size * block.itemsize
+                fo.seek(byte_offset)
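
# Note (not part of the diff): the merged training file written above is raw
# float32 bytes, so it can be read back without copying via np.memmap, which is
# how FaissBaseIndex._train, later in this diff, consumes it:
#
#   inp = np.memmap(get_training_data_merged_path(), dtype="f4", mode="r") \
#             .reshape((-1, args.hidden_size))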
def embed_db(): def embed_db():
'''Embed DB chunks. '''Embed DB chunks.
...@@ -45,6 +84,10 @@ def embed_db(): ...@@ -45,6 +84,10 @@ def embed_db():
args = get_retro_args() args = get_retro_args()
merged_train_data_path = get_training_data_merged_path()
if os.path.exists(merged_train_data_path):
return
# Get db dataset. # Get db dataset.
gpt_dataset = get_merged_sampled_dataset() gpt_dataset = get_merged_sampled_dataset()
text_dataset = GPTToTextDataset(gpt_dataset) text_dataset = GPTToTextDataset(gpt_dataset)
...@@ -54,14 +97,19 @@ def embed_db(): ...@@ -54,14 +97,19 @@ def embed_db():
args.retro_bert_max_chunk_length, args.retro_bert_max_chunk_length,
args.retro_block_size, args.retro_block_size,
args.bert_embedder_type) args.bert_embedder_type)
embedder.embed_text_dataset("index", get_training_data_dir(), text_dataset) embedder.embed_text_dataset("index",
get_training_data_block_dir(),
text_dataset)
# Merge embeddings.
merge_embedding_blocks()
def train_on_embeddings(): def train_on_embeddings():
'''Train index on embedded DB chunks.''' '''Train index on embedded DB chunks.'''
args = get_retro_args() args = get_retro_args()
index = IndexFactory.get_index(args.retro_index_type) index = IndexFactory.get_index(args.retro_index_type)
index.train(get_training_data_merged) index.train()
def remove_embeddings(): def remove_embeddings():
...@@ -71,7 +119,7 @@ def remove_embeddings(): ...@@ -71,7 +119,7 @@ def remove_embeddings():
return return
empty_index_path = get_empty_index_path() empty_index_path = get_empty_index_path()
assert os.path.isfile(empty_index_path) assert os.path.isfile(empty_index_path)
shutil.rmtree(get_training_data_dir(), ignore_errors=True) shutil.rmtree(get_training_data_root_dir(), ignore_errors=True)
def train_index(): def train_index():
...@@ -92,7 +140,7 @@ def train_index(): ...@@ -92,7 +140,7 @@ def train_index():
torch.distributed.barrier() torch.distributed.barrier()
# Remove embeddings. # Remove embeddings.
if args.retro_delete_index_training_embeddings: if args.retro_index_delete_training_embeddings:
remove_embeddings() remove_embeddings()
......
...@@ -5,6 +5,7 @@ import numpy as np ...@@ -5,6 +5,7 @@ import numpy as np
import os import os
import torch import torch
from megatron import get_retro_args
from tools.retro.external_libs import faiss from tools.retro.external_libs import faiss
from .utils import get_index_dir from .utils import get_index_dir
...@@ -30,13 +31,24 @@ class Index(abc.ABC): ...@@ -30,13 +31,24 @@ class Index(abc.ABC):
faiss.ParameterSpace().set_index_parameter(index, "verbose", v) faiss.ParameterSpace().set_index_parameter(index, "verbose", v)
def get_empty_index_path(self): def get_empty_index_path(self):
return os.path.join(get_index_dir(), "empty.faissindex") args = get_retro_args()
return os.path.join(
get_index_dir(),
"empty_%.3f.faissindex" % args.retro_index_train_load_fraction,
)
def get_empty_index(self): def get_empty_index(self):
return faiss.read_index(self.get_empty_index_path()) return faiss.read_index(self.get_empty_index_path())
def get_added_index_path(self): def get_added_index_path(self):
return os.path.join(get_index_dir(), "added.faissindex") args = get_retro_args()
return os.path.join(
get_index_dir(),
"added_%.3f_%.3f.faissindex" % (
args.retro_index_train_load_fraction,
args.retro_index_add_load_fraction,
),
)
def get_added_index(self): def get_added_index(self):
return faiss.read_index(self.get_added_index_path()) return faiss.read_index(self.get_added_index_path())
......
...@@ -8,6 +8,7 @@ inherit from this class (see FaissParAddIndex, for an example). ...@@ -8,6 +8,7 @@ inherit from this class (see FaissParAddIndex, for an example).
""" """
from datetime import timedelta from datetime import timedelta
import numpy as np
import os import os
import torch import torch
from tqdm import tqdm from tqdm import tqdm
...@@ -15,13 +16,16 @@ from tqdm import tqdm ...@@ -15,13 +16,16 @@ from tqdm import tqdm
from megatron import get_retro_args, print_rank_0 from megatron import get_retro_args, print_rank_0
from tools.bert_embedding import BertEmbedder from tools.bert_embedding import BertEmbedder
from tools.retro.external_libs import faiss from tools.retro.external_libs import faiss
from tools.retro.index import Index from tools.retro.index.index import Index
from tools.retro.index.utils import num_samples_to_block_ranges from tools.retro.index.utils import (
get_training_data_merged_path,
num_samples_to_block_ranges,
)
class FaissBaseIndex(Index): class FaissBaseIndex(Index):
def _train(self, input_data_loader): def _train(self):
'''Train index (rank 0's method).''' '''Train index (rank 0's method).'''
args = get_retro_args() args = get_retro_args()
...@@ -40,17 +44,24 @@ class FaissBaseIndex(Index): ...@@ -40,17 +44,24 @@ class FaissBaseIndex(Index):
return return
# Load data. # Load data.
inp = input_data_loader() merged_path = get_training_data_merged_path()
inp = np.memmap(
merged_path,
dtype = "f4",
mode = "r",
).reshape((-1, args.hidden_size))
# Init index. # Init index.
index = faiss.index_factory(args.retro_index_nfeats, index = faiss.index_factory(args.retro_index_nfeats,
args.retro_index_str) args.retro_index_str)
# Move to GPU. # Move to GPU.
print("> move faiss index to gpu.")
index_ivf = faiss.extract_index_ivf(index) index_ivf = faiss.extract_index_ivf(index)
clustering_index = \ clustering_index = \
faiss.index_cpu_to_all_gpus(faiss.IndexFlatL2(index_ivf.d)) faiss.index_cpu_to_all_gpus(faiss.IndexFlatL2(index_ivf.d))
index_ivf.clustering_index = clustering_index index_ivf.clustering_index = clustering_index
print("> finished moving to gpu.")
self.c_verbose(index, True) self.c_verbose(index, True)
self.c_verbose(index_ivf, True) self.c_verbose(index_ivf, True)
self.c_verbose(index_ivf.quantizer, True) self.c_verbose(index_ivf.quantizer, True)
...@@ -62,12 +73,12 @@ class FaissBaseIndex(Index): ...@@ -62,12 +73,12 @@ class FaissBaseIndex(Index):
# Save index. # Save index.
faiss.write_index(index, empty_index_path) faiss.write_index(index, empty_index_path)
def train(self, input_data_loader): def train(self):
'''Train index.''' '''Train index.'''
# Single process only. # Single process only.
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
self._train(input_data_loader) self._train()
torch.distributed.barrier() torch.distributed.barrier()
......
...@@ -10,6 +10,7 @@ the vast majority of the computational effort is embarrassingly parallel. ...@@ -10,6 +10,7 @@ the vast majority of the computational effort is embarrassingly parallel.
import numpy as np import numpy as np
import os import os
import psutil
import shutil import shutil
import torch import torch
from tqdm import tqdm from tqdm import tqdm
...@@ -104,6 +105,8 @@ class FaissParallelAddIndex(FaissBaseIndex): ...@@ -104,6 +105,8 @@ class FaissParallelAddIndex(FaissBaseIndex):
if os.path.exists(added_index_path): if os.path.exists(added_index_path):
return return
args = get_retro_args()
# Index. # Index.
print_rank_0("read empty index.") print_rank_0("read empty index.")
index = self.get_empty_index() index = self.get_empty_index()
...@@ -112,10 +115,19 @@ class FaissParallelAddIndex(FaissBaseIndex): ...@@ -112,10 +115,19 @@ class FaissParallelAddIndex(FaissBaseIndex):
# Add codes. # Add codes.
print_rank_0("add codes.") print_rank_0("add codes.")
code_paths = get_added_code_paths() code_paths = get_added_code_paths()
for code_path in tqdm(code_paths, "add codes"): pbar = tqdm(code_paths)
for code_path in pbar:
pbar.set_description("add codes, mem %.3f gb, %.1f%%" % (
psutil.virtual_memory()[3] / 1024**3,
psutil.virtual_memory()[2],
))
with h5py.File(code_path) as f: with h5py.File(code_path) as f:
codes = np.copy(f["data"])
index_ivf.add_sa_codes(codes) nload = int(args.retro_index_add_load_fraction*f["data"].shape[0])
offset = int(os.path.basename(code_path).split("-")[0])
xids = np.arange(offset, offset + nload)
codes = np.copy(f["data"][:nload])
index_ivf.add_sa_codes(codes, xids)
# Update index's ntotal. # Update index's ntotal.
index.ntotal = index_ivf.ntotal index.ntotal = index_ivf.ntotal
...@@ -129,18 +141,19 @@ class FaissParallelAddIndex(FaissBaseIndex):
        if torch.distributed.get_rank() != 0:
            return

        assert os.path.isfile(self.get_added_index_path())

-       shutil.rmtree(get_added_codes_dir(), ignore_errors=True)
+       args = get_retro_args()
+       if args.retro_index_delete_added_codes:
+           raise Exception("remove?")
+           shutil.rmtree(get_added_codes_dir(), ignore_errors=True)

    def add(self, text_dataset):

-       # Check if index already exists.
-       if not os.path.isfile(self.get_added_index_path()):

        # Encode chunks.
        self.encode(text_dataset)

        # Add codes to index.
        self.add_codes()

        # Wait for (single-process) adding to complete.
        torch.distributed.barrier()
...
...@@ -45,128 +45,28 @@ def num_samples_to_block_ranges(num_samples):
    return ranges

-def get_training_data_dir():
-    return os.path.join(get_index_dir(), "train_tmp")
-
-def get_training_data_paths():
-    return sorted(glob.glob(get_training_data_dir() + "/*.hdf5"))
-
-def get_added_codes_dir():
-    return os.path.join(get_index_dir(), "add_tmp")
-
-def get_added_code_paths():
-    return sorted(glob.glob(get_added_codes_dir() + "/*.hdf5"))
-
-def get_training_data_group_infos():
-    args = get_retro_args()
-    block_paths = get_training_data_paths()
-    max_group_size = args.retro_index_train_block_size
-    groups = []
-    group = []
-    group_size = 0
-    for block_path in block_paths:
-        with h5py.File(block_path) as f:
-            block_size = f["data"].shape[0]
-        group.append(block_path)
-        group_size += block_size
-        if group_size >= max_group_size:
-            groups.append({
-                "paths" : group,
-                "size" : group_size,
-            })
-            group = []
-            group_size = 0
-    if group:
-        groups.append({
-            "paths" : group,
-            "size" : group_size,
-        })
-    return groups
-
-def load_training_block(path, load_fraction):
-    with h5py.File(path) as f:
-        n_load = int(load_fraction * f["data"].shape[0])
-        return np.copy(f["data"][:n_load])
-
-def load_training_group(executor, group_info, load_fraction):
-    # Launch threads to load block data.
-    futures = []
-    for path in group_info["paths"]:
-        futures.append(executor.submit(load_training_block, path, load_fraction))
-    # Collect block data.
-    block_datas = []
-    for future in futures:
-        block_datas.append(future.result())
-    # Concatenate blocks.
-    group_data = np.concatenate(block_datas, axis=0)
-    # Garbage collect.
-    for d in block_datas:
-        del d
-    gc.collect()
-    return group_data
-
-def get_training_data_merged():
-    '''Merge embeddings into single dataset.'''
-    args = get_retro_args()
-    # Setup.
-    ds_infos = get_indexed_dataset_infos()
-    n_chunks_sampled = sum(d["n_chunks_sampled"] for d in ds_infos)
-    load_fraction = args.retro_index_train_load_fraction
-    # Initialize merged data.
-    print("allocate training data array.")
-    t = time.time()
-    data = np.empty((n_chunks_sampled, args.retro_index_nfeats), dtype="f4")
-    print(" time : %.3f sec." % (time.time() - t))
-    # Data groups (minimizing fragmentation).
-    group_infos = get_training_data_group_infos()
-    # Load data blocks.
-    n_threads = max(len(group["paths"]) for group in group_infos)
-    with concurrent.futures.ThreadPoolExecutor(max_workers=n_threads) as executor:
-        # Load data blocks.
-        print("load training data blocks.")
-        start_idx = 0
-        pbar = tqdm(group_infos)
-        for group_info in pbar:
-            pbar.set_description("mem %.0f gb, %.1f%%" % (
-                psutil.virtual_memory()[3] / 1024**3,
-                psutil.virtual_memory()[2],
-            ))
-            # Load group data.
-            group_data = load_training_group(executor, group_info, load_fraction)
-            data[start_idx:(start_idx+len(group_data))] = group_data
-            start_idx += len(group_data)
-            # Garbage collect.
-            del group_data
-            gc.collect()
-    # Handle load ratio <1.
-    data = data[:start_idx]
-    print("> training block data.shape = %s." % str(data.shape))
-    return data
+def get_training_data_root_dir():
+    args = get_retro_args()
+    return os.path.join(args.retro_workdir, "index", "train_emb")
+
+def get_training_data_block_dir():
+    return os.path.join(get_training_data_root_dir(), "blocks")
+
+def get_training_data_block_paths():
+    return sorted(glob.glob(get_training_data_block_dir() + "/*.hdf5"))
+
+def get_training_data_merged_path():
+    args = get_retro_args()
+    return os.path.join(get_training_data_root_dir(),
+                        "train_%.3f.bin" % args.retro_index_train_load_fraction)
+
+def get_added_codes_dir():
+    return os.path.join(get_index_dir(), "add_codes")
+
+def get_added_code_paths():
+    return sorted(glob.glob(get_added_codes_dir() + "/*.hdf5"))
...@@ -15,8 +15,8 @@ import torch
from megatron import get_args, initialize_megatron, print_rank_0
from megatron.global_vars import set_retro_args
from tools.retro.db import build_db
-from tools.retro.index.build import add_to_index, build_index, train_index
-from tools.retro.pretraining.query import query_pretraining_neighbors
+from tools.retro.index import add_to_index, build_index, train_index
+from tools.retro.query import query_pretraining_neighbors
from tools.retro.utils import get_args_path
...@@ -31,16 +31,69 @@ def add_retro_args(parser):
    group = parser.add_argument_group(title="Retro preprocessing.")

-    group.add_argument("--retro-gpt-vocab-file", required=True,
-                       help="GPT vocab file.")
-    group.add_argument("--retro-gpt-merge-file", required=True,
-                       help="GPT merge file.")
+    # Basic args.
+    group.add_argument("--retro-tasks", default="build",
+                       help="Comma-separated list of tasks to run. Run entire "
+                       "preprocessing pipeline by using '--retro-tasks build'. "
+                       "Alternatively, run individual stages with tasks (in "
+                       "this order) 'db-build', 'index-build', or "
+                       "'query-pretraining-neighbors'. For example, "
+                       "'--retro-tasks db-build,index-build,"
+                       "query-pretraining-neighbors' is equivalent to "
+                       "'--retro-tasks build'; or the argument can contain "
+                       "a subset of these tasks. Stages must always be run "
+                       "in the correct order (listed above).")
+    group.add_argument("--retro-block-size", type=int, default=100000,
+                       help="Number of chunks to process at a time when "
+                       "generating Bert embeddings and querying the search "
+                       "index. Partial results for each block are generally "
+                       "saved to disk in separate files.")
+    group.add_argument("--retro-doc-block-size", type=int, default=100000,
+                       help="Number of documents to process at a time when "
+                       "processing token datasets into chunk databases. The "
+                       "partial chunk database for each block is saved into "
+                       "a separate file.")
+
+    # GPT args.
+    group.add_argument('--retro-gpt-seed', type=int, default=1234,
+                       help='Random seed used for python, numpy, '
+                       'pytorch, and cuda.')
+    group.add_argument('--retro-gpt-data-impl', type=str, default='infer',
+                       choices=['lazy', 'cached', 'mmap', 'infer'],
+                       help='Implementation of indexed datasets.')
+    group.add_argument('--retro-gpt-data-path', nargs='*', required=True,
+                       help='Path to the training dataset. Accepted format: '
+                       '1) a single data path, 2) multiple datasets in the '
+                       'form: dataset1-weight dataset1-path dataset2-weight '
+                       'dataset2-path ... It is used with --split when a '
+                       'single dataset is used for all three: train, valid '
+                       'and test. It is exclusive to the other '
+                       '--*-data-path args.')
+    group.add_argument('--retro-gpt-split', type=str, default='969,30,1',
+                       help='Comma-separated list of proportions for training, '
+                       'validation, and test split. For example the split '
+                       '`90,5,5` will use 90%% of data for training, 5%% for '
+                       'validation and 5%% for test.')
+    group.add_argument('--retro-gpt-mmap-warmup', action='store_true',
+                       help='Warm up mmap files.')
+    group.add_argument("--retro-gpt-eval-interval", type=int, required=True,
+                       help="GPT evaluation interval.")
+    group.add_argument("--retro-gpt-eval-iters", type=int, required=True,
+                       help="GPT evaluation iterations.")
    group.add_argument("--retro-gpt-tokenizer-type", required=True,
                       help="GPT tokenizer type.")
-    group.add_argument("--retro-gpt-seq-length", type=int, default=2048,
+    group.add_argument("--retro-gpt-vocab-file", help="GPT vocab file.")
+    group.add_argument("--retro-gpt-merge-file", help="GPT merge file.")
+    group.add_argument("--retro-gpt-tokenizer-model",
+                       help="GPT tokenizer model file.")
+    group.add_argument("--retro-gpt-seq-length", type=int, required=True,
                       help="GPT sequence length.")
+    group.add_argument("--retro-gpt-global-batch-size", type=int, required=True,
+                       help="GPT global batch size.")
    group.add_argument("--retro-gpt-chunk-length", type=int, default=64,
                       help="GPT chunk length.")

+    # Bert args.
    group.add_argument("--retro-bert-vocab-file", required=True,
                       help="Bert vocab file.")
    group.add_argument("--retro-bert-tokenizer-type", required=True,
...@@ -52,17 +105,8 @@ def add_retro_args(parser):
                       help="Maximum sequence length for Bert embeddings. "
                       "(Named 'chunk' here in reference to these Bert "
                       "sequences being converted from GPT chunks.)")
-    group.add_argument("--retro-tasks", default="build",
-                       help="Comma-separated list of tasks to run. Run entire "
-                       "preprocessing pipeline by using '--retro-tasks build'. "
-                       "Alternatively, run individual stages with tasks (in "
-                       "this order) 'db-build', 'index-build', or "
-                       "'pretraining-query-neighbors'. For example, "
-                       "'--retro-tasks db-build,index-build,"
-                       "pretraining-query-neighbors' is equivalent to "
-                       "'--retro-tasks build'; or the argument can contain "
-                       "a subset of these tasks. Stages must always be run "
-                       "in the correct order (listed above).")
+    # Index args.
    group.add_argument("--retro-index-nfeats", "-f", type=int, default=1024,
                       help="Dimension of Bert embeddings. Bert-large is "
                       "commonly used, so this value defaults to 1024.")
...@@ -78,34 +122,10 @@ def add_retro_args(parser):
                       "faiss.index_factory(). For example, "
                       "'IVF262144_HNSW32,Flat' or "
                       "'OPQ32_256,IVF4194304_HNSW32,PQ32'.")
-    group.add_argument("--retro-ef-search", type=int, default=256,
-                       help="Index ef-search parameter for HNSW during "
-                       "querying.")
-    group.add_argument("--retro-nprobe", type=int, default=65536,
-                       help="Index nprobe parameter for IVF during "
-                       "querying.")
-    group.add_argument("--retro-nchunks-sampled", type=int, required=True,
+    group.add_argument("--retro-index-ntrain", type=int, required=True,
                       help="Number of database chunks to use for training "
                       "the index. This value must be less than or equal to "
                       "the total number of chunks in the database.")
-    group.add_argument("--retro-doc-block-size", type=int, default=100000,
-                       help="Number of documents to process at a time when "
-                       "processing token datasets into chunk databases. The "
-                       "partial chunk database for each block is saved into "
-                       "a separate file.")
-    group.add_argument("--retro-block-size", type=int, default=100000,
-                       help="Number of chunks to process at a time when "
-                       "generating Bert embeddings and querying the search "
-                       "index. Partial results for each block are generally "
-                       "saved to disk in separate files.")
-    group.add_argument("--retro-index-train-block-size",
-                       type=int, default=3750000,
-                       help="As a memory fragmentation optimization, when "
-                       "loading training data for training the search index, "
-                       "enough data blocks are loaded at a time until they reach "
-                       "retro_index_train_block_size, and then this "
-                       "data block is copied into the full training data "
-                       "array.")
    group.add_argument("--retro-index-train-load-fraction",
                       type=float, default=1.,
                       help="Fraction of sampled chunks to use for training "
...@@ -113,19 +133,36 @@ def add_retro_args(parser):
                       "use too much memory; lowering the load fraction is "
                       "less costly than re-embedding a new sampled dataset "
                       "from scratch.")
+    group.add_argument("--retro-index-add-load-fraction",
+                       type=float, default=1.,
+                       help="Fraction of database chunks to use for adding to "
+                       "the index. Useful when our total index size would "
+                       "use too much memory; lowering the load fraction is "
+                       "less costly than re-designing our token datasets.")
+    group.add_argument("--retro-index-no-delete-training-embeddings",
+                       action='store_false',
+                       dest="retro_index_delete_training_embeddings",
+                       help="Skip deleting training embeddings for the search "
+                       "index. Useful for debugging.")
+    group.add_argument("--retro-index-no-delete-added-codes",
+                       action='store_false',
+                       dest="retro_index_delete_added_codes",
+                       help="Skip deleting added codes for the search "
+                       "index. Useful for debugging.")

+    # Query args.
+    group.add_argument("--retro-query-ef-search", type=int, default=256,
+                       help="Index ef-search parameter for HNSW during querying.")
+    group.add_argument("--retro-query-nprobe", type=int, default=65536,
+                       help="Index nprobe parameter for IVF during querying.")
-    group.add_argument("--retro-num-neighbors-query", type=int, default=2000,
+    group.add_argument("--retro-query-num-neighbors-query", type=int, default=200,
                       help="Number of neighbors to retrieve when calling "
                       "index.search().")
-    group.add_argument("--retro-num-neighbors-target", type=int, default=200,
+    group.add_argument("--retro-query-num-neighbors-save", type=int, default=20,
                       help="Number of neighbors to save to disk after "
                       "filtering the index's returned neighbors. If longer "
                       "than this target value, neighbors are truncated; if "
                       "shorter, they are padded with -1's.")
-    group.add_argument("--retro-no-delete-index-training-embeddings",
-                       action='store_false',
-                       dest="retro_delete_index_training_embeddings",
-                       help="Skip deleting training embeddings for the search "
-                       "index. Useful for debugging.")

    # Enforce argument naming convention.
    for action in group._group_actions:
...@@ -140,10 +177,16 @@ def add_retro_args(parser):
def save_args(args):
    '''Save copy of args within retro workdir.'''

+    def default_dump(obj):
+        if isinstance(obj, torch.dtype):
+            return str(obj)
+        else:
+            raise Exception("specialize for <%s>." % type(obj).__name__)

    if torch.distributed.get_rank() == 0:
        args_path = get_args_path(args.retro_workdir)
        with open(args_path, "w") as f:
-            json.dump(vars(args), f, indent=4, default=lambda o : "<skipped>")
+            json.dump(vars(args), f, indent=4, default=default_dump)
    torch.distributed.barrier()
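# Illustrative sketch (not from this repository) of the json `default=` hook
# used above: non-serializable values such as torch.dtype are written as
# strings instead of being silently skipped. The function and file names are
# placeholders.
import json
import torch

def dtype_aware_dump_sketch(obj_dict, path):
    def default_dump(obj):
        if isinstance(obj, torch.dtype):
            return str(obj)          # e.g. torch.float16 -> "torch.float16"
        raise TypeError("specialize serialization for <%s>." % type(obj).__name__)

    with open(path, "w") as f:
        json.dump(obj_dict, f, indent=4, default=default_dump)

# Example: dtype_aware_dump_sketch({"params_dtype": torch.bfloat16}, "args.json")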
...@@ -188,7 +231,7 @@ if __name__ == "__main__":
            add_to_index() # add only

        # Pretraining.
-        elif task == "pretraining-query-neighbors":
+        elif task == "query-pretraining-neighbors":
            query_pretraining_neighbors()

        else:
...
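# Illustrative sketch (not from this repository): how a comma-separated
# --retro-tasks string could map onto the pipeline stages imported at the top
# of main.py (build_db, build_index, query_pretraining_neighbors). The
# expansion of "build" into the three individual stages follows the help text
# above; the exact dispatch in main.py may differ.
def run_retro_tasks_sketch(retro_tasks_arg, build_db, build_index,
                           query_pretraining_neighbors):
    tasks = retro_tasks_arg.split(",")
    if tasks == ["build"]:
        tasks = ["db-build", "index-build", "query-pretraining-neighbors"]
    for task in tasks:
        if task == "db-build":
            build_db()
        elif task == "index-build":
            build_index()
        elif task == "query-pretraining-neighbors":
            query_pretraining_neighbors()
        else:
            raise Exception("unknown retro task '%s'." % task)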
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import os
from megatron import get_retro_args
def get_pretraining_workdir():
    args = get_retro_args()
    return os.path.join(args.retro_workdir, "pretraining")
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from .query import query_pretraining_neighbors
...@@ -4,15 +4,16 @@ import os
import torch

from megatron import get_retro_args, print_rank_0
-from megatron.data.gpt_dataset import build_train_valid_test_datasets
+from megatron.data.gpt_dataset import build_train_valid_test_datasets \
+    as build_gpt_train_valid_test_datasets
from megatron.training import (
-    build_train_valid_test_data_loaders,
+    build_train_valid_test_datasets as build_pretraining_train_valid_test_datasets,
    update_train_iters,
)
from tools.retro.db.utils import get_indexed_dataset_infos
from tools.retro.utils import get_num_chunks_per_sample
-from .utils import get_pretraining_workdir
+from .utils import get_neighbor_dirname, get_query_workdir


class ChunkDataset(torch.utils.data.Dataset):
...@@ -86,14 +87,14 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
    print_rank_0('> building train, validation, and test datasets '
                 'for GPT ...')
-    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
-        data_prefix=args.data_path,
-        data_impl=args.data_impl,
-        splits_string=args.split,
+    train_ds, valid_ds, test_ds = build_gpt_train_valid_test_datasets(
+        data_prefix=args.retro_gpt_data_path,
+        data_impl=args.retro_gpt_data_impl,
+        splits_string=args.retro_gpt_split,
        train_valid_test_num_samples=train_val_test_num_samples,
        seq_length=args.retro_gpt_seq_length,
-        seed=args.seed,
-        skip_warmup=(not args.mmap_warmup),
+        seed=args.retro_gpt_seed,
+        skip_warmup=(not args.retro_gpt_mmap_warmup),
        return_doc_ids=args.retro_return_doc_ids)
    print_rank_0("> finished creating pretrained GPT datasets ...")
...@@ -115,28 +116,23 @@ def get_chunk_dataset_map():
    verify_indexed_dataset_order()

    # Datasets.
-    print_rank_0(" > data loader.")
-    train_data_loader, valid_data_loader, test_data_loader \
-        = build_train_valid_test_data_loaders(
-            train_valid_test_datasets_provider)
-
-    data_loader_map = {
-        "train" : train_data_loader,
-        "valid" : valid_data_loader,
-        "test" : test_data_loader,
-    }
+    print_rank_0(" > datasets.")
+    train_ds, valid_ds, test_ds = build_pretraining_train_valid_test_datasets(
+        train_valid_test_datasets_provider)
+
+    sample_dataset_map = {
+        "train" : train_ds,
+        "valid" : valid_ds,
+        "test" : test_ds,
+    }

    # Info dict.
-    workdir = get_pretraining_workdir()
-    dataset_map = {
+    chunk_dataset_map = {
        key : {
-            "neighbor_dir" : os.path.join(
-                workdir,
-                os.path.basename(loader.dataset.datasets[0].index_prefix),
-            ),
-            "data" : ChunkDataset(loader.dataset, args.retro_gpt_chunk_length),
+            "neighbor_dir" : get_neighbor_dirname(key, sample_ds),
+            "data" : ChunkDataset(sample_ds, args.retro_gpt_chunk_length),
        }
-        for key, loader in data_loader_map.items() if loader
+        for key, sample_ds in sample_dataset_map.items() if sample_ds
    }

-    return dataset_map
+    return chunk_dataset_map
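# Illustrative sketch (not from this repository) of the chunk bookkeeping that
# ChunkDataset and the query code below rely on. It assumes the usual Retro
# convention that each GPT sample of `seq_length` tokens is split into
# consecutive chunks of `chunk_length` tokens, so that
# sample_id == chunk_id // n_chunks_per_sample.
def chunk_bookkeeping_sketch(seq_length=2048, chunk_length=64, chunk_id=100):
    assert seq_length % chunk_length == 0
    n_chunks_per_sample = seq_length // chunk_length   # e.g. 2048 // 64 == 32
    sample_id = chunk_id // n_chunks_per_sample        # which sample owns the chunk
    chunk_in_sample = chunk_id % n_chunks_per_sample   # position within the sample
    token_start = chunk_in_sample * chunk_length       # token slice of the sample
    token_end = token_start + chunk_length
    return sample_id, (token_start, token_end)

# Example: chunk_bookkeeping_sketch() -> (3, (256, 320))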
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

from collections import defaultdict
import numpy as np
import os
+import psutil
import time
import torch
from tqdm import tqdm

-from megatron import get_retro_args, mpu, print_rank_0
+from megatron import get_retro_args, print_rank_0
from tools.bert_embedding import BertEmbedder
from tools.bert_embedding.utils import get_missing_blocks_by_rank
-from tools.retro.db.utils import (
-    get_merged_train_dataset as get_db_merged_train_dataset,
-    get_train_doc_chunk_map,
-)
+from tools.retro.db.utils import \
+    get_merged_train_dataset as get_db_merged_train_dataset
from tools.retro.external_libs import faiss, h5py
from tools.retro.index.factory import IndexFactory
-from tools.retro.index.utils import get_index_dir, num_samples_to_block_ranges
+from tools.retro.index.utils import get_index_dir
from tools.retro.utils import GPTToTextDataset

-from .chunk_dataset import get_chunk_dataset_map
+from .chunk_dataset import get_chunk_dataset_map as get_query_dataset_map


-def get_index(chunk_db_dataset, ondisk=False):
+def get_index(ondisk=False):
    '''Read index from disk.'''

    args = get_retro_args()

-    # Chunk db block ranges.
-    n_db_chunks = len(chunk_db_dataset)
-    dataset_block_ranges = num_samples_to_block_ranges(n_db_chunks)

    # Load index.
    index_wrapper = IndexFactory.get_index(args.retro_index_type)
    index_dir = get_index_dir()
...@@ -42,9 +36,9 @@ def get_index(chunk_db_dataset, ondisk=False):
    # Search parameters.
    faiss.ParameterSpace().set_index_parameter(index, "efSearch",
-                                               args.retro_ef_search)
+                                               args.retro_query_ef_search)
    faiss.ParameterSpace().set_index_parameter(index, "nprobe",
-                                               args.retro_nprobe)
+                                               args.retro_query_nprobe)

    return index
...@@ -58,8 +52,9 @@ def embed_block(gpt_dataset, block, embedder):
    return embedder.embed_text_dataset(text_block_dataset)


-def query_embeddings(index, banned_chunk_map, chunk_id_range,
-                     embeddings, sample_map, n_chunks_per_sample,
+def query_embeddings(db_dataset, index,
+                     embeddings, chunk_id_range,
+                     sample_map, n_chunks_per_sample,
                     verbose=True):
    '''Query neighbors of a block of embeddings.'''
...@@ -70,24 +65,13 @@ def query_embeddings(index, banned_chunk_map, chunk_id_range,
    t = time.time()
    assert index.ntotal > 0, "check we don't accidentally have an empty index."
    _, query_neighbor_ids = \
-        index.search(embeddings, args.retro_num_neighbors_query)
+        index.search(embeddings, args.retro_query_num_neighbors_query)
    if verbose: print_rank_0(" time : %.3f sec." % (time.time() - t))

-    # Banned neighbor ids.
-    if verbose: print_rank_0("get banned neighbor ids.")
-    sample_banned_chunk_id_map = {}
-    for sample_id, sample in sample_map.items():
-        dataset_idx = sample["dataset_idx"].item()
-        doc_ids = sample["doc_ids"].tolist()
-        banned_chunk_ids = set()
-        for doc_id in doc_ids:
-            banned_chunk_ids.update(banned_chunk_map[(dataset_idx, doc_id)])
-        sample_banned_chunk_id_map[sample_id] = banned_chunk_ids

    # Filter banned neighbor ids.
    if verbose: print_rank_0("filter banned neighbor ids.")
    filtered_neighbor_ids = np.full(
-        shape=(len(query_neighbor_ids), args.retro_num_neighbors_target),
+        shape=(len(query_neighbor_ids), args.retro_query_num_neighbors_save),
        fill_value=-1,
        dtype="int64",
    )
...@@ -95,24 +79,30 @@ def query_embeddings(index, banned_chunk_map, chunk_id_range,
    for chunk_id in range(min_chunk_id, max_chunk_id):

        sample_id = chunk_id // n_chunks_per_sample
+        sample = sample_map[sample_id]
+        sample_dataset_idx = sample["dataset_idx"].item()
+        sample_doc_ids = sample["doc_ids"].tolist()
+        sample_doc_tuples = [(sample_dataset_idx, d) for d in sample_doc_ids]

        # Get valid neighbors (!= -1).
        query_row = [ i for i in query_neighbor_ids[chunk_id-min_chunk_id]
                      if i >= 0 ]

        # Filter row.
-        filtered_row = [i for i in query_row
-                        if i not in sample_banned_chunk_id_map[sample_id]]
-        filtered_row = filtered_row[:args.retro_num_neighbors_target]
+        filtered_row = [ i for i in query_row
+                         if tuple(db_dataset.doc_tuples[i].tolist())
+                         not in sample_doc_tuples ]
+        filtered_row = filtered_row[:args.retro_query_num_neighbors_save]
        filtered_row += \
-            [-1] * (args.retro_num_neighbors_target - len(filtered_row))
+            [-1] * (args.retro_query_num_neighbors_save - len(filtered_row))
        filtered_neighbor_ids[chunk_id-min_chunk_id] = filtered_row

    return query_neighbor_ids, filtered_neighbor_ids
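# Illustrative sketch (not from this repository) of the filter/truncate/pad
# step above: neighbors coming from the same source documents as the query
# sample are dropped, the row is cut to num_neighbors_save entries, and short
# rows are padded with -1. Argument names and the toy ids are placeholders.
import numpy as np

def filter_neighbor_row_sketch(query_row, neighbor_doc_tuples,
                               sample_doc_tuples, num_neighbors_save=20):
    # Keep valid ids whose (dataset_idx, doc_id) is not among the sample's docs.
    filtered = [i for i in query_row
                if i >= 0 and neighbor_doc_tuples[i] not in sample_doc_tuples]
    filtered = filtered[:num_neighbors_save]
    filtered += [-1] * (num_neighbors_save - len(filtered))
    return np.array(filtered, dtype="int64")

# Example (toy ids): filter_neighbor_row_sketch(
#     [5, 7, -1], {5: (0, 11), 7: (0, 99)}, {(0, 11)}, num_neighbors_save=4)
# -> array([ 7, -1, -1, -1])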
-def query_embedding_block(index, banned_chunk_map, chunk_id_range,
-                          embeddings, sample_map, n_chunks_per_sample):
+def query_embedding_block(db_dataset, index,
+                          embeddings, chunk_id_range,
+                          sample_map, n_chunks_per_sample):

    query_neighbor_ids = []
    filtered_neighbor_ids = []
...@@ -131,8 +121,9 @@ def query_embedding_block(index, banned_chunk_map, chunk_id_range,
            chunk_id_range[0] + partial_end_idx,
        )
        partial_query_neighbor_ids, partial_filtered_neighbor_ids = \
-            query_embeddings(index, banned_chunk_map, partial_chunk_id_range,
-                             partial_embeddings, sample_map, n_chunks_per_sample,
+            query_embeddings(db_dataset, index,
+                             partial_embeddings, partial_chunk_id_range,
+                             sample_map, n_chunks_per_sample,
                             verbose=False)
        query_neighbor_ids.append(partial_query_neighbor_ids)
        filtered_neighbor_ids.append(partial_filtered_neighbor_ids)
...@@ -144,26 +135,33 @@ def query_embedding_block(index, banned_chunk_map, chunk_id_range,
    return query_neighbor_ids, filtered_neighbor_ids
-def query_block_neighbors(index, banned_chunk_map, chunk_dataset,
-                          block, embedder):
+def query_block_neighbors(db_dataset, query_dataset,
+                          index, embedder,
+                          block):
    '''Query neighbors of a dataset block (i.e., range).'''

    args = get_retro_args()
-    n_chunks_per_sample = chunk_dataset.n_chunks_per_sample
+    n_chunks_per_sample = query_dataset.n_chunks_per_sample

    # Sample map.
    sample_ids = sorted(list(set(chunk_id // n_chunks_per_sample
                                 for chunk_id in range(*block["range"]))))
-    sample_map = {i:chunk_dataset.sample_dataset[i] for i in sample_ids}
+    sample_map = {}
+    for i in sample_ids:
+        sample = query_dataset.sample_dataset[i]
+        sample_map[i] = {
+            "dataset_idx" : sample["dataset_idx"],
+            "doc_ids" : sample["doc_ids"],
+        }

    # Embed block.
-    embeddings = embed_block(chunk_dataset, block, embedder)
+    embeddings = embed_block(query_dataset, block, embedder)

    # Query embeddings.
    _, filtered_neighbor_ids = query_embedding_block(
-        index, banned_chunk_map, block["range"],
-        embeddings, sample_map,
-        n_chunks_per_sample)
+        db_dataset, index,
+        embeddings, block["range"],
+        sample_map, n_chunks_per_sample)

    # Save neighbors.
    print_rank_0("save neighbors.")
...@@ -173,22 +171,22 @@ def query_block_neighbors(index, banned_chunk_map, chunk_dataset,
    f.close()
-def query_dataset_neighbors(index, banned_chunk_map,
-                            prefix, chunk_dataset, neighbor_dir,
-                            embedder):
+def query_dataset_neighbors(db_dataset, query_dataset,
+                            prefix, neighbor_dir,
+                            index, embedder):
    '''Query neighbors of each chunk within a dataset.'''

    args = get_retro_args()

    def validate(f):
-        assert f["neighbors"].shape[1] == args.retro_num_neighbors_target, \
+        assert f["neighbors"].shape[1] == args.retro_query_num_neighbors_save, \
            "neighbors.shape == %s; num_neighbors_target == %d." % (
                str(f["neighbors"].shape),
                args.retro_num_neighbors_target,
            )
    n_missing_blocks, missing_neighbor_blocks = get_missing_blocks_by_rank(
        neighbor_dir,
-        len(chunk_dataset),
+        len(query_dataset),
        args.retro_block_size,
        validate=validate,
    )
...@@ -199,16 +197,19 @@ def query_dataset_neighbors(index, banned_chunk_map,
        if block is not None:

            # Progress.
-            print_rank_0("query '%s' block %d / %d ... %s." % (
+            print_rank_0("query '%s' block %d / %d ... %s ... mem %.3f gb, %.1f%%." % (
                prefix,
                block_index,
                len(missing_neighbor_blocks),
-                block["path"],
+                os.path.basename(block["path"]),
+                psutil.virtual_memory()[3] / 1024**3,
+                psutil.virtual_memory()[2],
            ))

            # Query block neighbors.
-            query_block_neighbors(index, banned_chunk_map,
-                                  chunk_dataset, block, embedder)
+            query_block_neighbors(db_dataset, query_dataset,
+                                  index, embedder,
+                                  block)

        # Synchronize progress across all ranks. (for easier observation)
        print_rank_0(" > waiting for other ranks to finish block.")
...@@ -225,17 +226,16 @@ def query_pretraining_neighbors():
    # Load chunk db dataset.
    print_rank_0("load chunk db dataset.")
-    chunk_db_dataset = get_db_merged_train_dataset()
+    db_dataset = get_db_merged_train_dataset()
+    db_dataset.load_doc_tuples()

-    # Load index, banned chunk ids, datasets.
+    # Load index.
    print_rank_0(" > get index.")
-    index = get_index(chunk_db_dataset)
+    index = get_index()

-    print_rank_0(" > get banned doc-chunk id map.")
-    banned_chunk_map = get_train_doc_chunk_map()

+    # Load datasets.
    print_rank_0(" > get dataset map.")
-    chunk_dataset_map = get_chunk_dataset_map()
+    query_dataset_map = get_query_dataset_map()

    # Bert embedder.
    embedder = BertEmbedder(args.retro_bert_batch_size,
...@@ -244,9 +244,9 @@ def query_pretraining_neighbors():
    # Query each (i.e., train, valid, test) dataset.
    print_rank_0(" > query.")
-    for prefix, info in chunk_dataset_map.items():
+    for prefix, info in query_dataset_map.items():
        print_rank_0(" > query '%s' dataset ... %d samples." %
                     (prefix, len(info["data"])))
-        query_dataset_neighbors(index, banned_chunk_map,
-                                prefix, info["data"], info["neighbor_dir"],
-                                embedder)
+        query_dataset_neighbors(db_dataset, info["data"],
+                                prefix, info["neighbor_dir"],
+                                index, embedder)
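# Illustrative sketch (not from this repository) of consuming the neighbor
# blocks written by the query step above. Each block HDF5 file is assumed to
# hold a "neighbors" dataset of shape
# [n_chunks_in_block, retro_query_num_neighbors_save], padded with -1.
import h5py
import numpy as np

def load_neighbor_block_sketch(neighbor_block_path):
    with h5py.File(neighbor_block_path) as f:
        neighbors = np.copy(f["neighbors"])
    # Rows are per-chunk neighbor ids into the chunk database; -1 marks padding.
    valid_counts = (neighbors != -1).sum(axis=1)
    return neighbors, valid_counts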