Commit 0024a5c6 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'main' of https://github.com/NVIDIA/Megatron-LM

parents b004456b 3db2063b
Pipeline #229 failed with stages
in 0 seconds
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import json
import time
import sys
if __name__ == '__main__':
url_filename = sys.argv[1]
data_filename = sys.argv[2]
output_filename = sys.argv[3]
urls = set()
with open(url_filename, 'r') as f:
for line in f:
myjson = json.loads(line)
for key in myjson:
this_urls = myjson[key]
for i in range(1, len(this_urls)):
urls.add(this_urls[i])
print('will be removing {} urls'.format(len(urls)), flush=True)
written_docs = 0
removed_docs = 0
removed_chars = 0
start_time = time.time()
with open(output_filename, 'wb') as fout:
with open(data_filename, 'r') as fin:
for line in fin:
try:
myjson = json.loads(line)
url = myjson['url']
if url in urls:
print('removing', myjson)
removed_docs += 1
removed_chars += len(myjson['text'])
continue
myjson = json.dumps(myjson, ensure_ascii=False)
fout.write(myjson.encode('utf-8'))
fout.write('\n'.encode('utf-8'))
written_docs += 1
if written_docs % 10000 == 0:
print(' [PROCESSED] time (s): {:.2f} | written: {} '
'| removed: {} (char: {})'.format(
time.time() - start_time,
written_docs, removed_docs, removed_chars))
except Exception as e:
print('[SKIPPING]', line, e)
print(' [PROCESSED] time (s): {:.2f} | written: {} '
'| removed: {} (char: {})'.format(
time.time() - start_time,
written_docs, removed_docs, removed_chars))
print('done :-)')
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Processing data for pretraining."""
import argparse
import json
import multiprocessing
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
os.path.pardir)))
import time
import torch
try:
import nltk
nltk_available = True
except ImportError:
nltk_available = False
from megatron.tokenizer import build_tokenizer
from megatron.data import indexed_dataset
# https://stackoverflow.com/questions/33139531/preserve-empty-lines-with-nltks-punkt-tokenizer
class CustomLanguageVars(nltk.tokenize.punkt.PunktLanguageVars):
_period_context_fmt = r"""
\S* # some word material
%(SentEndChars)s # a potential sentence ending
\s* # <-- THIS is what I changed
(?=(?P<after_tok>
%(NonWord)s # either other punctuation
|
(?P<next_tok>\S+) # <-- Normally you would have \s+ here
))"""
class IdentitySplitter(object):
def tokenize(self, *text):
return text
class Encoder(object):
def __init__(self, args):
self.args = args
def initializer(self):
# Use Encoder class as a container for global data
Encoder.tokenizer = build_tokenizer(self.args)
if self.args.split_sentences:
if not nltk_available:
print("NLTK is not available to split sentences.")
exit()
library = "tokenizers/punkt/{}.pickle".format(self.args.lang)
print("loading: " + library)
splitter = nltk.load(library)
if self.args.keep_newlines:
# this prevents punkt from eating newlines after sentences
Encoder.splitter = nltk.tokenize.punkt.PunktSentenceTokenizer(
train_text=splitter._params,
lang_vars=CustomLanguageVars())
else:
Encoder.splitter = splitter
else:
Encoder.splitter = IdentitySplitter()
def encode(self, json_line):
data = json.loads(json_line)
ids = {}
for key in self.args.json_keys:
text = data[key]
doc_ids = []
for sentence in Encoder.splitter.tokenize(text):
sentence_ids = Encoder.tokenizer.tokenize(sentence)
if len(sentence_ids) > 0:
doc_ids.append(sentence_ids)
if len(doc_ids) > 0 and self.args.append_eod:
doc_ids[-1].append(Encoder.tokenizer.eod)
ids[key] = doc_ids
return ids, len(json_line)
def get_args():
parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='input data')
group.add_argument('--input', type=str, required=True,
help='Path to input JSON')
group.add_argument('--json-keys', nargs='+', default=['text'],
help='space separate listed of keys to extract from json')
group.add_argument('--split-sentences', action='store_true',
help='Split documents into sentences.')
group.add_argument('--keep-newlines', action='store_true',
help='Keep newlines between sentences when splitting.')
group = parser.add_argument_group(title='tokenizer')
group.add_argument('--tokenizer-type', type=str, required=True,
choices=['BertWordPieceLowerCase','BertWordPieceCase',
'GPT2BPETokenizer', 'SentencePieceTokenizer', 'GPTSentencePieceTokenizer'],
help='What type of tokenizer to use.')
group.add_argument('--vocab-file', type=str, default=None,
help='Path to the vocab file')
group.add_argument('--merge-file', type=str, default=None,
help='Path to the BPE merge file (if necessary).')
group.add_argument('--append-eod', action='store_true',
help='Append an <eod> token to the end of a document.')
group.add_argument('--lang', type=str, default='english',
help='Language to use for NLTK-powered sentence splitting.')
group.add_argument('--tokenizer-model', type=str, default=None,
help='sentencepeice tokenizer model.')
group = parser.add_argument_group(title='output data')
group.add_argument('--output-prefix', type=str, required=True,
help='Path to binary output file without suffix')
group.add_argument('--dataset-impl', type=str, default='mmap',
choices=['lazy', 'cached', 'mmap'])
group = parser.add_argument_group(title='runtime')
group.add_argument('--workers', type=int, required=True,
help='Number of worker processes to launch')
group.add_argument('--chunk-size', type=int, required=True,
help='Chunk size assigned to each worker process')
group.add_argument('--log-interval', type=int, default=100,
help='Interval between progress updates')
args = parser.parse_args()
args.keep_empty = False
if args.tokenizer_type.lower().startswith('bert'):
if not args.split_sentences:
print("Bert tokenizer detected, are you sure you don't want to split sentences?")
# some default/dummy values for the tokenizer
args.rank = 0
args.make_vocab_size_divisible_by = 128
args.tensor_model_parallel_size = 1
args.vocab_extra_ids = 0
return args
def main():
args = get_args()
startup_start = time.time()
print("Opening", args.input)
fin = open(args.input, 'r', encoding='utf-8')
if nltk_available and args.split_sentences:
nltk.download("punkt", quiet=True)
encoder = Encoder(args)
tokenizer = build_tokenizer(args)
pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer)
encoded_docs = pool.imap(encoder.encode, fin, args.chunk_size)
#encoded_docs = map(encoder.encode, fin)
level = "document"
if args.split_sentences:
level = "sentence"
print(f"Vocab size: {tokenizer.vocab_size}")
print(f"Output prefix: {args.output_prefix}")
output_bin_files = {}
output_idx_files = {}
builders = {}
for key in args.json_keys:
output_bin_files[key] = "{}_{}_{}.bin".format(args.output_prefix,
key, level)
output_idx_files[key] = "{}_{}_{}.idx".format(args.output_prefix,
key, level)
builders[key] = indexed_dataset.make_builder(output_bin_files[key],
impl=args.dataset_impl,
vocab_size=tokenizer.vocab_size)
startup_end = time.time()
proc_start = time.time()
total_bytes_processed = 0
print("Time to startup:", startup_end - startup_start)
for i, (doc, bytes_processed) in enumerate(encoded_docs, start=1):
total_bytes_processed += bytes_processed
for key, sentences in doc.items():
if len(sentences) == 0:
continue
for sentence in sentences:
builders[key].add_item(torch.IntTensor(sentence))
builders[key].end_document()
if i % args.log_interval == 0:
current = time.time()
elapsed = current - proc_start
mbs = total_bytes_processed/elapsed/1024/1024
print(f"Processed {i} documents",
f"({i/elapsed} docs/s, {mbs} MB/s).",
file=sys.stderr)
print("Done! Now finalizing.")
for key in args.json_keys:
builders[key].finalize(output_idx_files[key])
if __name__ == '__main__':
main()
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Processing nmt data for finetuning."""
import argparse
import json
import multiprocessing
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
os.path.pardir)))
import time
import torch
from megatron.tokenizer import build_tokenizer
from megatron.data import indexed_dataset
class Encoder(object):
def __init__(self, args):
self.args = args
def initializer(self):
# Use Encoder class as a container for global data
Encoder.tokenizer = build_tokenizer(self.args)
def encode(self, text):
ids = {}
ids = Encoder.tokenizer.tokenize(text)
assert len(ids) > 0
return ids, len(text)
def get_args():
parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='input data')
group.add_argument('--input', type=str, required=True,
help='Path to input JSON')
group = parser.add_argument_group(title='tokenizer')
group.add_argument('--tokenizer-type', type=str, default='YTTMTokenizer',
choices=['BertWordPieceLowerCase','BertWordPieceCase',
'GPT2BPETokenizer', 'SentencePieceTokenizer'],
help='What type of tokenizer to use.')
group.add_argument('--vocab-file', type=str, default=None,
help='Path to the vocab file')
group.add_argument('--merge-file', type=str, default=None,
help='Path to the BPE merge file (if necessary).')
group = parser.add_argument_group(title='output data')
group.add_argument('--output-prefix', type=str, required=True,
help='Path to binary output file without suffix')
group.add_argument('--dataset-impl', type=str, default='mmap',
choices=['lazy', 'cached', 'mmap'])
group = parser.add_argument_group(title='runtime')
group.add_argument('--workers', type=int, default=1,
help='Number of worker processes to launch')
group.add_argument('--log-interval', type=int, default=100,
help='Interval between progress updates')
args = parser.parse_args()
args.keep_empty = False
# some default/dummy values for the tokenizer
args.rank = 0
args.make_vocab_size_divisible_by = 128
args.tensor_model_parallel_size = 1
args.vocab_extra_ids = 0
return args
def main():
args = get_args()
startup_start = time.time()
print("Opening", args.input)
fin = open(args.input, 'r', encoding='utf-8')
encoder = Encoder(args)
tokenizer = build_tokenizer(args)
pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer)
encoded_sentences = pool.imap(encoder.encode, fin, 25)
print(f"Vocab size: {tokenizer.vocab_size}")
print(f"Output prefix: {args.output_prefix}")
output_bin_file = "{}.bin".format(args.output_prefix)
output_idx_file = "{}.idx".format(args.output_prefix)
builder = indexed_dataset.make_builder(output_bin_file,
impl=args.dataset_impl,
vocab_size=tokenizer.vocab_size)
startup_end = time.time()
proc_start = time.time()
total_bytes_processed = 0
print("Time to startup:", startup_end - startup_start)
for i, (sentence, bytes_processed) in enumerate(encoded_sentences, start=1):
total_bytes_processed += bytes_processed
builder.add_item(torch.IntTensor(sentence))
# documents contain only one sentence.
builder.end_document()
if i % args.log_interval == 0:
current = time.time()
elapsed = current - proc_start
mbs = total_bytes_processed/elapsed/1024/1024
print(f"Processed {i} sentences",
f"({i/elapsed} sentences/s, {mbs} MB/s).",
file=sys.stderr)
builder.finalize(output_idx_file)
if __name__ == '__main__':
main()
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Processing large data for pretraining."""
import argparse
import math
import json
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
os.path.pardir)))
import time
import gzip
import glob
import torch
import numpy as np
import multiprocessing
try:
import nltk
nltk_available = True
except ImportError:
nltk_available = False
from megatron.tokenizer import build_tokenizer
from megatron.data import indexed_dataset
# https://stackoverflow.com/questions/33139531/preserve-empty-lines-with-nltks-punkt-tokenizer
class CustomLanguageVars(nltk.tokenize.punkt.PunktLanguageVars):
_period_context_fmt = r"""
\S* # some word material
%(SentEndChars)s # a potential sentence ending
\s* # <-- THIS is what I changed
(?=(?P<after_tok>
%(NonWord)s # either other punctuation
|
(?P<next_tok>\S+) # <-- Normally you would have \s+ here
))"""
class IdentitySplitter(object):
def tokenize(self, *text):
return text
class Encoder(object):
def __init__(self, args):
self.args = args
def initializer(self):
# Use Encoder class as a container for global data
Encoder.tokenizer = build_tokenizer(self.args)
if self.args.split_sentences:
if not nltk_available:
print("NLTK is not available to split sentences.")
exit()
library = "tokenizers/punkt/{}.pickle".format(self.args.lang)
splitter = nltk.load(library)
if self.args.keep_newlines:
# this prevents punkt from eating newlines after sentences
Encoder.splitter = nltk.tokenize.punkt.PunktSentenceTokenizer(
train_text = splitter._params,
lang_vars = CustomLanguageVars())
else:
Encoder.splitter = splitter
else:
Encoder.splitter = IdentitySplitter()
def split(self, json_line):
data = json.loads(json_line)
output = {}
for key in self.args.json_keys:
text = data[key]
max_len = 1000000
tokens_list = [Encoder.splitter.tokenize(text[i:i+max_len]) for i in range(0, len(text), max_len)]
output[key] = [tokens for partial in tokens_list for tokens in partial]
return json.dumps(output), len(json_line)
def encode(self, json_line):
data = json.loads(json_line)
ids = {}
lens = {}
for key in self.args.json_keys:
text = data[key]
if isinstance(text, list):
sentences = text
else:
sentences = [text]
doc_ids = []
sentence_lens = []
for sentence in sentences:
sentence_ids = Encoder.tokenizer.tokenize(sentence)
if len(sentence_ids) > 0:
doc_ids.extend(sentence_ids)
sentence_lens.append(len(sentence_ids))
if len(doc_ids) > 0 and self.args.append_eod:
doc_ids.append(Encoder.tokenizer.eod)
ids[key] = doc_ids
lens[key] = sentence_lens
return ids, lens, len(json_line)
class Partition(object):
def __init__(self, args, workers):
self.args = args
self.workers = workers
def print_processing_stats(self, count, proc_start, total_bytes_processed):
if count % self.args.log_interval == 0:
current = time.time()
elapsed = current - proc_start
mbs = total_bytes_processed/elapsed/1024/1024
print(f"Processed {count} documents",
f"({count/elapsed} docs/s, {mbs} MB/s).",
file=sys.stderr)
def split_sentences(self, file_name):
input_file_name, output_file_name = file_name
print("Opening", input_file_name)
fin = open(input_file_name, 'r', encoding='utf-8')
fout = open(output_file_name, 'w')
encoder = Encoder(self.args)
pool = multiprocessing.Pool(self.workers, initializer=encoder.initializer)
split_docs = pool.imap(encoder.split, fin, 32)
proc_start = time.time()
total_bytes_processed = 0
for i, (doc, bytes_processed) in enumerate(split_docs, start=1):
total_bytes_processed += bytes_processed
fout.write(doc + "\n")
self.print_processing_stats(i, proc_start, total_bytes_processed)
fin.close()
fout.close()
def process_json_file(self, file_name):
input_file_name, output_prefix = file_name
print("Opening", input_file_name)
fin = open(input_file_name, 'r', encoding='utf-8')
startup_start = time.time()
encoder = Encoder(self.args)
tokenizer = build_tokenizer(self.args)
pool = multiprocessing.Pool(self.workers, initializer=encoder.initializer)
encoded_docs = pool.imap(encoder.encode, fin, 32)
level = "document"
if self.args.split_sentences:
level = "sentence"
output_bin_files = {}
output_idx_files = {}
builders = {}
for key in self.args.json_keys:
output_bin_files[key] = "{}_{}_{}.bin".format(output_prefix,
key, level)
output_idx_files[key] = "{}_{}_{}.idx".format(output_prefix,
key, level)
builders[key] = indexed_dataset.make_builder(output_bin_files[key],
impl=self.args.dataset_impl,
vocab_size=tokenizer.vocab_size)
startup_end = time.time()
proc_start = time.time()
total_bytes_processed = 0
print("Time to startup:", startup_end - startup_start)
for i, (doc, sentence_lens, bytes_processed) in enumerate(encoded_docs, start=1):
total_bytes_processed += bytes_processed
for key in doc.keys():
builders[key].add_doc(doc[key], sentence_lens[key])
self.print_processing_stats(i, proc_start, total_bytes_processed)
fin.close()
builders[key].finalize(output_idx_files[key])
def get_args():
parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='input data')
group.add_argument('--input', type=str, required=True,
help='Path to input JSON')
group.add_argument('--json-keys', nargs='+', default=['text'],
help='space separate listed of keys to extract from json')
group.add_argument('--split-sentences', action='store_true',
help='Split documents into sentences.')
group.add_argument('--keep-newlines', action='store_true',
help='Keep newlines between sentences when splitting.')
group = parser.add_argument_group(title='tokenizer')
group.add_argument('--tokenizer-type', type=str, required=True,
choices=['BertWordPieceLowerCase','BertWordPieceCase',
'GPT2BPETokenizer', 'SentencePieceTokenizer', 'GPTSentencePieceTokenizer'],
help='What type of tokenizer to use.')
group.add_argument('--tokenizer-model', type=str, default=None,
help='YTTM tokenizer model.')
group.add_argument('--vocab-file', type=str, default=None,
help='Path to the vocab file')
group.add_argument('--merge-file', type=str, default=None,
help='Path to the BPE merge file (if necessary).')
group.add_argument('--append-eod', action='store_true',
help='Append an <eod> token to the end of a document.')
group.add_argument('--lang', type=str, default='english',
help='Language to use for NLTK-powered sentence splitting.')
group = parser.add_argument_group(title='output data')
group.add_argument('--output-prefix', type=str, required=True,
help='Path to binary output file without suffix')
group.add_argument('--dataset-impl', type=str, default='mmap',
choices=['lazy', 'cached', 'mmap'])
group = parser.add_argument_group(title='runtime')
group.add_argument('--workers', type=int, default=1,
help='Number of worker processes to launch')
group.add_argument('--partitions', type=int, default=1,
help='Number of file partitions')
group.add_argument('--log-interval', type=int, default=1000,
help='Interval between progress updates')
args = parser.parse_args()
args.keep_empty = False
if args.tokenizer_type.lower().startswith('bert') and not args.split_sentences:
print("Are you sure you don't want to split sentences?")
# some default/dummy values for the tokenizer
args.rank = 1
args.make_vocab_size_divisible_by = 128
args.tensor_model_parallel_size = 1
args.vocab_extra_ids = 0
return args
def get_file_name(args, file_id):
file_name, extension = os.path.splitext(args.input)
input_file_name = file_name + "_" + str(file_id) + extension
sentence_split_file = file_name + "_ss_" + str(file_id) + extension
output_prefix = args.output_prefix + "_" + str(file_id)
file_names = {
'partition': input_file_name,
'sentence_split': sentence_split_file,
'output_prefix': output_prefix}
return file_names
def check_files_exist(in_ss_out_names, key, num_partitions):
for i in range(num_partitions):
if not os.path.exists(in_ss_out_names[i][key]):
return False
return True
def main():
args = get_args()
if args.split_sentences:
if nltk_available:
nltk.download("punkt", quiet=True)
else:
raise Exception(
"nltk library required for sentence splitting is not available.")
in_ss_out_names = []
if args.partitions == 1:
file_name, extension = os.path.splitext(args.input)
sentence_split_file = file_name + "_ss" + extension
file_names = {
'partition': args.input,
'sentence_split': sentence_split_file,
'output_prefix': args.output_prefix}
in_ss_out_names.append(file_names)
else:
in_file_names = glob.glob(args.input)
# create .jsonl parition files
for idx in range(args.partitions):
in_ss_out_name = get_file_name(args, idx)
in_ss_out_names.append(in_ss_out_name)
# check to see if paritions were already created
partitions_present = check_files_exist(in_ss_out_names, 'partition', args.partitions)
# check to see if paritions with split sentences already created
split_sentences_present = check_files_exist(in_ss_out_names, 'sentence_split', args.partitions)
if not partitions_present and not split_sentences_present:
# populate .jsonl partition files from parent files
partitioned_input_files = []
for idx in range(args.partitions):
partitioned_input_file = open(in_ss_out_names[idx]['partition'], 'w')
partitioned_input_files.append(partitioned_input_file)
index = 0
for in_file_name in in_file_names:
# support for gzip files
if in_file_name.endswith(".gz"):
fin = gzip.open(in_file_name, 'rt')
else:
fin = open(in_file_name, 'r', encoding='utf-8')
for line in fin:
partitioned_input_files[index].write(line)
index = (index + 1)%args.partitions
fin.close()
for idx in range(args.partitions):
partitioned_input_files[idx].close()
assert args.workers % args.partitions == 0
partition = Partition(args, args.workers//args.partitions)
# check to see if paritions with split sentences already created
split_sentences_present = check_files_exist(in_ss_out_names, 'sentence_split', args.partitions)
# split sentences in partition files
if args.split_sentences and not split_sentences_present:
processes = []
for name in in_ss_out_names:
p = multiprocessing.Process(target=partition.split_sentences,
args=((name['partition'], name['sentence_split']),))
p.start()
processes.append(p)
for p in processes:
p.join()
if args.partitions == 1:
return
# encode partition files in parallel
processes = []
input_key = 'sentence_split' if args.split_sentences else 'partition'
for name in in_ss_out_names:
p = multiprocessing.Process(target=partition.process_json_file,
args=((name[input_key], name['output_prefix']),))
p.start()
processes.append(p)
for p in processes:
p.join()
# merge bin/idx partitions
level = "document"
if args.split_sentences:
level = "sentence"
output_bin_files = {}
output_idx_files = {}
builders = {}
tokenizer = build_tokenizer(args)
for key in args.json_keys:
output_bin_files[key] = "{}_{}_{}.bin".format(args.output_prefix,
key, level)
output_idx_files[key] = "{}_{}_{}.idx".format(args.output_prefix,
key, level)
builders[key] = indexed_dataset.make_builder(output_bin_files[key],
impl=args.dataset_impl,
vocab_size=tokenizer.vocab_size)
for name in in_ss_out_names:
parition_output_prefix = name['output_prefix']
full_partition_output_prefix = "{}_{}_{}".format(parition_output_prefix,
key, level)
builders[key].merge_file_(full_partition_output_prefix)
builders[key].finalize(output_idx_files[key])
if __name__ == '__main__':
main()
This directory contains a collection of tools for building the retrieval database and pretraining neighbors for Retro. This preprocessing pipeline is broken into 3 main stages:
1. **Build retrieval chunk database** : Used for retrieving neighbors and continuation chunks, which are then passed through the retrieval encoder.
2. **Build index for similarity search** : Train and build a search index for querying chunk neighbors.
3. **Query pretraining neighbors** : For matching pretraining samples to database chunks. Neighbors are generated separately for training, validation, and test datasets.
The following overview goes into more detail on the pipeline, code structure, usage, and pretraining.
<!-- ################ contents ################ -->
# Contents
* [Quick start](#quick-start)
* [Stages](#stages)
* [Code structure](#code-structure)
* [Arguments](#arguments)
<!-- * [Pretraining](#pretraining) -->
<!-- ################ quick start ################ -->
# Quick start
See `examples/get_preprocess_cmd.sh` for example arguments.
Key files:
- `main.py` : Entry point.
- `examples/get_preprocess_cmd.sh` : Build preprocessing command (for `main.py`).
- `examples/preprocess_data.sh` : Run preprocessing (calls `get_preprocess_cmd.sh`, `main.py`).
Use `--retro-tasks` to move through the preprocessing pipeline.
- Simplest setup (builds everything): `--retro-tasks build`
- Alternatively, for tuning compute resources, run stages independently:
- Build retrieval database: `--retro-tasks db-build`
- Build search index: `--retro-tasks index-build`
- Query neighbors: `--retro-tasks pretraining-query-neighbors`
Sample code flow:
- `main.py` : Entry point (e.g., using `--retro-tasks X`).
- `db/build.py` : Build retrieval database.
- `index/build.py` : Build search index. Calls the following two files:
- `index/train.py` : Train index on subset of database.
- `index/add.py` : Add database chunks to index.
- `pretraining/query.py` : Query pretraining samples for database neighbors (saved to disk and used during pretraining).
<!-- ################ stages ################ -->
# Stages
### Build retrieval chunk database
This *database* (stored as a 2-D array, NOT a relational database) consists of a list of chunks (traditionally length 64) extracted from the original GPT token dataset. This is simply a consecutive, non-overlapping chunking of the token dataset. Chunking only takes place within a document, and therefore the final chunk of each document has length: 1 <= chunk_length <= max_chunk_length.
We discard chunks that would convert to an empty Bert sequence (rare case, happens ~1/100,000 chunks in our case), since we use Bert embeddings for building our index. Thus, the total number of chunks in the database will be slightly less than a naive calculation.
### Build index for similarity search
To match pretraining chunks to database chunks, a search index must be built to perform this querying. We use Faiss (https://github.com/facebookresearch/faiss) for training and building this index. Generally, the index is trained on a subset of all chunks in the database (specified via `--retro-nchunks-sampled`). After training, all chunks are added into the index, to be available during querying.
Indexes only accept 1-D floating point vectors for training and adding, so each chunk must first be embedded before passing to the index for either training or adding. We use Bert embeddings for this purpose, and the embeddings are generated automatically within the pipeline.
### Query pretraining neighbors
To ensure fast Retro pretraining, the database neighbors for pretraining samples are pre-computed and saved to disk, for efficient access within the Retro dataset. In this stage, the pretraining datasets (training, validation, and test) are iterated, each sample is broken into chunks, and the chunks are used for querying the index. Similar to when building the index, each chunk is embedded (via Bert) before querying the index.
The saved neighbors are labeled with unique dataset properties (i.e., seed, sequence length, number of samples, etc.) to ensure the neighbors generated during preprocessing match the neighbors requested during pretraining.
<!-- ################ code structure ################ -->
# Code structure
### `tools/retro/main.py`
This is the main entry point for Retro preprocessing. Call `main.py --help` to see arguments. Additionally, some Retro arguments are in Megatron's core arguments, so also see `add_retro_args()` section of `megatron/arguments.py` for additional arguments. Two of the most important arguments to customize are `--retro-workdir` and `--retro-tasks`.
- **`--retro-workdir`** : Set the directory in which the preprocessing pipeline saves its datasets and configuration files. This argument should remain consistent for a full pass through the pipeline, and for pretraining.
- **`--retro-tasks`** : Set the stages of preprocessing to perform. As mentioned previously, the three high-level stages are: 1) build retrieval database, 2) build search index, and 3) query pretraining neighbors. `--retro-tasks` can be used to either run the full pipeline, or run each of these stages in isolation. The latter case is useful for tuning compute resources for each stage. For example, index training utilizes GPUs and requires relatively less time, while querying neighbors uses the CPU and is a relatively slow process. Example tasks include:
- **`--retro-tasks build`** : Run entire preprocessing pipeline.
- **`--retro-tasks db-build`** : Build retrieval database.
- **`--retro-tasks index-build`** : Train and build search index.
- **`--retro-tasks pretraining-query-neighbors`** : Query pretraining neighbors.
Multiple tasks can be specified by separating with commas (e.g., `--retro-tasks db-build,index-build`). Additionally, various 'miscellaneous' tasks are currently including, primarily for validating data for each stage; these task names can be seen in `main.py`.
### `tools/retro/examples`
Example scripts for setting arguments and launch Retro preprocessing. The key files here are:
- **`get_preprocess_cmd.sh`** : Sets up arguments and command for preprocessing. **Important note**: this script assumes a few environment variables are already set before it is called. Please see the `Environment vars.` section at the top of this file. Generally, environment variables must be set to determine the location of Retro workdirs, input datasets, and GPT and Bert model information.
- **`preprocess_data.sh`** : Calls `get_preprocess_cmd.sh` to get arguments, and then calls `main.py` to launch preprocessing.
- **`pretrain_model.sh`** : Example script for pretraining on Wikipedia data, after preprocessing is complete.
### `tools/retro/db`
Build the retrieval chunk database. The key files here are:
- **`build.py`** : Entry point for building the database. This code is responsible for iterating the input datasets (i.e., `--data-path`), parsing each dataset into consecutive chunks, checking for empty Bert (Wordpiece) conversions, and storing this information to disk. Two databases are created: 1) the retrieval database, and 2) a sampled database used for training the search index.
- **`dataset.py`** : Defines database class, for iterating or accessing chunks in the database. Each chunk contains its tokens, Bert conversion length, and dataset index.
Input data:
<!-- - Token datasets, as generated by `tools/preprocess_data.py`. Each dataset should include a `.bin` and `.idx` file. Multiple datasets can be specified by using a blended configuration (see `--data-path` in `megatron/arguments.py`). -->
- Token datasets, as loaded by `gpt_dataset.py`. Multiple datasets can be specified by using a blended configuration (see `--data-path` in `megatron/arguments.py`).
Output data:
- **`<RETRO_WORKDIR>/db/merged/train.hdf5`** : The main retrieval database. (*Database* here is used to denote a list of indexed chunks, rather than a *relational database*.) The chunks in this database are added to the search index, and are used for retrieval during pretraining. This file contains a single dataset `'chunks'`, which contains 5 columns:
- `dataset_idx` : Dataset index, from list of blended indexed datasets.
- `document_idx` : Document index within dataset.
- `chunk_start_idx` : Chunk's starting token index within document.
- `chunk_end_idx` : Chunk's ending token index (exclusive) within document.
- `bert_chunk_length` : Length of Bert token sequence, after converting from GPT.
- **`<RETRO_WORKDIR>/db/merged/sampled.hdf5`** : Subset of training database that is used for training the search index. This file has the same structure as detailed above. In general, this database is significanly smaller than the `train.hdf5` database, since the search index only needs a relatively small number of samples to understand the data's structure. After training, all chunks in the main database (`train.hdf5`) are *added* to the search index.
### `tools/retro/index`
Build the search index. The key files here are:
- `build.py` : Entry point for building the search index. First, the index is trained on the sampled chunk database (see above) by calling `train.py`, and then all chunks for the full database are added to the index by calling `add.py`. Note that training requires first embedding (using Bert) all chunks (a parallel operation), and then loading these embeddings and training the index (a sequential operation), so it's best to change one's compute setup after all chunks have been embedded and saved to disk.
- `indexes/faiss_base.py` : Wrapper class for building a Faiss index, following the standard `train()` and `add()` operations.
- `indexes/faiss_par_add.py` : Similar to above, except it uses an embarrassingly parallel (multi-node, multi-process) `add()` operation. Vectors are first added to separate index copies, and then merged together.
Input data:
- **`<RETRO_WORKDIR>/db/merged/sampled.hdf5`** : Chunks used for training the search index.
- **`<RETRO_WORKDIR>/db/merged/train.hdf5`** : Chunks used for adding to the *trained* search index.
Output data:
- **`<RETRO_WORKDIR>/index/<RETRO_INDEX_TYPE>/<RETRO_INDEX_STR>/added.faissindex`** : The final index, which has been trained and has had all database chunks added to it. This index is ready for querying neighbors. Here, `RETRO_INDEX_TYPE` and `RETRO_INDEX_STR` correspond to the same-name arguments `--retro-index-type` (e.g., `faiss-par-add`) and `--retro-index-str` (e.g., `OPQ32_256,IVF4194304_HNSW32,PQ32`).
- **`<RETRO_WORKDIR>/index/<RETRO_INDEX_TYPE>/<RETRO_INDEX_STR>/empty.faissindex`** : Generally can be discarded once `added.faissindex` has been built, but this file contains the *post-training*, *pre-adding* index. Useful for debugging or building other indexes.
### `tools/retro/pretraining`
Query the pretraining datasets (training, validation, test) for their neighbors within the database. Neighbors are queried during preprocessing -- rather than during pretraining -- because querying is a fairly slow operation, so it would be a bottleneck if performed during pretraining. Queried neighbors are tagged with their unique identifying information (e.g., `train_indexmap_27662746ns_2048sl_1234s`), so as to avoid incorrect references during pretraining. The key files here are:
- **`query.py`** : Entry point for querying. The pretraining datasets are iterated, and each chunk within each sample is queried using the search index. These neighbors are filtered by discarding any database chunks that fall within the same document as any chunk within a pretraining sample.
- **`chunk_dataset.py`** : This creates an iterable 'chunk' dataset form of a pretraining dataset. This is just a light wrapper, but makes it easier to deterministically iterate and assign IDs to each chunk in a sample dataset.
- **`retro_dataset.py`** : The Retro dataset used for pretraining (not used in preprocessing). Each sample returns the sample tokens, along with neighbor tokens for each chunk within the sample.
Input data:
- Token datasets, as loaded by `gpt_dataset.py`.
- **`<RETRO_WORKDIR>/index/<RETRO_INDEX_TYPE>/<RETRO_INDEX_STR>/added.faissindex`** : The trained index, with all database chunks added to it (see previous section for details).
Output data:
- **`<RETRO_WORKDIR>/{train,valid,test}_XXns_YYsl_ZZs/WW.hdf5`** : These directories/files contain the indexes of neighbors for each chunk within each sample of the pretraining datasets. Each directory (e.g., `train_indexmap_2047435ns_2048sl_1234s`) contains a list of HDF5 files (e.g., one file might be called `0075700000-0075800000.hdf5`). Each HDF5 file contains a consecutive subset of neighbor IDs for a given chunk, for indexing into the main retrieval database. All HDF5 files taken together within a given directory, represent the entire set of neighbors for a dataset. The size of these HDF5 files is determined by the argument `--retro-block-size`. The `XX`, `YY`, `ZZ`, `WW` notation above denotes the dataset properties that are used for uniquely tagging the neighbor files, to ensure compatibility during model pretraining. These neighbor files are ultimated used by `retro_dataset.py` during pretraining, for building Retro samples.
### `tools/retro/cli`
Inspect preprocessed data. To use the CLI, open a Python terminal via the `python` command, and then load a Retro workdir with the following:
```
from tools.retro.cli import retro
retro.init("/path/to/retro/workdir")
```
This initializes Megatron, and prepares the Retro data for inspection. See the printed usage for available functions. Several routines are included for viewing data in the retrieval database and viewing pretraining samples and neighbors. For example:
```python
retro.get_db_num_indexed_datasets() # 15
retro.get_db_chunk_text(92874113) # 'research project at ... and philosophy'
retro.get_pt_sample('train', 62005) # '[16084, 26158, 25387 ..., 6898, 9568]'
```
Most methods within the CLI are prefixed to denote the data being inspected:
- **'db'** : Retrieval database (i.e., chunk tokens, document IDs, and dataset IDs)
- **'pt'** : Pretraining datasets (i.e., sample tokens and neighbor tokens)
### `tools/retro/utils.py`
A collection of utility methods. Most importantly, this contains:
- **`def get_gpt_tokenizer()`** : Get the GPT tokenizer.
- **`def get_bert_tokenizer()`** : Get the Bert tokenizer.
- **`class GPTToTextDataset`** : Wrapper class that converts GPT (BPE) samples to raw text.
### `tools/bert_embedding`
Generate Bert embeddings. The main files here are:
- **`embed.py`** : Entry point for generating embeddings, and contains the two main embedding classes, `BertEmbedder` and `DiskDataParallelBertEmbedder` (more below). This file contains code for generating Megatron embeddings, while the file below contains code for Huggingface embeddings.
- **`huggingface.py`** : Used by `embed.py` when the embedder is configured (see below) to output Huggingface embeddings.
- **`dataset.py`** : Wrapper class for converting a raw-text dataset to Bert (Wordpiece) tokens.
The Bert embeddings can be configured along two axes. The first axis is the output type:
- **`class BertEmbedder`** : This class takes a raw-text dataset as input, generates its embeddings, and returns a Numpy array. The main functions are `embed_text_dataset` (accepts a raw-text dataset) and `embed_text` (accepts a string).
- **`class DiskDataParallelBertEmbedder`** : This class wraps `BertEmbedder`, and rather than returning a Numpy array, it saves the embeddings to disk. Additionally, this class automatically splits data across data parallel ranks (using interleaving), and also processes data in a specified `block_size` (e.g., 1,000,000).
The second axis is the type of embedding model to use, controlled by the argument `--bert-embedder-type`:
- **`--bert-embedder-type megatron`** : Use Megatron's Bert model. The specific model used is dependent on the loaded checkpoint, vocab file, and tokenizer.
- **`--bert-embedder-type huggingface`** : Use Huggingface's `bert-large-cased`. (*Note*: Huggingface's inclusion is likely to be deprecated; and there is no ability to configure cased/uncased.)
### Pretraining
- **`pretrain_retro.py`** : Launch script for pretraining Retro. Similar to `pretrain_gpt.py`, except this script handles loading neighbor tokens and setting up the neighbor attention mask.
<!-- - `megatron/data/gpt_dataset.py` : ? -->
- **`megatron/model/retro_transformer.py`** : Implementation of Retro model, including the main transformer, the retrieval encoder, and chunked cross-attention layers. Note that currently, `retro_transformer.py` contains several classes that are nearly identical to `transformer.py`, except for 1 or 2 lines, due to code changes that are yet to be integrated.
- **`tools/retro/pretraining/retro_dataset.py`** : The Retro dataset used for pretraining (not used in preprocessing). Each sample returns the sample tokens, along with neighbor tokens for each chunk within the sample.
<!-- ################ arguments ################ -->
# Arguments
See `tools/retro/main.py`'s `add_retro_args()` and `megatron/arguments.py`'s `_add_retro_args()` for details and descriptions. Here we list some particularly important arguments:
- `--retro-workdir` : Mentioned previously, this argument determines the directory in which a set of Retro data is stored (during preprocessing) and loaded (during pretraining). Any change in this directory during preprocessing may result in preprocessing starting over from scratch, and any change before pretraining will result in pretraining throwing an error.
- Preprocessing
- `--retro-gpt-chunk-length` : Retro chunk length (e.g., 64 in original paper).
- `--retro-tasks` : Comma-separated list of preprocessing tasks. Generally, the `build` task is the simplest way to run the preprocessing pipeline. For finer control, individual stages can be run by using tasks (in order): `db-build`, `index-build`, and `pretraining-query-neighbors`.
- `--retro-index-str` : Faiss index string that defines the index configuration. This will vary based on data size, compute/disk setup, and user needs. For example, this string looks something like `IVF262144_HNSW32,Flat` or `OPQ32_256,IVF4194304_HNSW32,PQ32`.
- Pretraining
- `--retro-add-retriever` : Must be used to select Retro model.
- `--retro-num-neighbors` : Number of neighbors to retrieve from the retrieval database (defaults to 2).
- `--retro-num-retrieved-chunks` : For each neighbor, the number consecutive chunks to retrieve, including the initial neighbor (defaults to 2).
<!-- ################ pretraining ################ -->
<!-- # Pretraining -->
<!-- - New retro args in arguments.py (add_retro_args). -->
<!-- - Most important arg is `--retro-add-retriever`. -->
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import json
import numpy as np
import os
import torch
import types
from megatron.global_vars import set_global_variables, set_retro_args
from megatron.initialize import (
initialize_megatron,
_initialize_distributed,
_set_random_seed,
)
from tools.retro.db.utils import (
get_indexed_dataset_infos as get_db_indexed_dataset_infos,
get_merged_train_dataset as get_db_dataset,
)
from tools.retro.external_libs import h5py
from tools.retro.main import add_retro_args
from tools.retro.pretraining.retro_dataset import get_retro_datasets
from tools.retro.utils import get_args_path, get_bert_tokenizer, get_gpt_tokenizer
def shorten_str(s, n):
s = "\\n".join(s.splitlines())
return s if len(s) <= n else "%s ... %s" % (s[:n//2], s[-n//2:])
class retro:
args = None
##############################################
# initialize.
##############################################
@classmethod
def init_megatron(cls, workdir):
'''Custom initialization of Megatron.'''
# Load args.
args_path = get_args_path(workdir)
assert os.path.exists(args_path), "args.json not found in workdir."
with open(args_path) as f:
cls.args = types.SimpleNamespace(**json.load(f))
cls.args.retro_workdir = workdir # just in case workdir moved
cls.args.rank = 0 # override env
cls.args.world_size = 1 # override env
set_global_variables(cls.args)
set_retro_args(cls.args)
_initialize_distributed()
_set_random_seed(cls.args.seed, cls.args.data_parallel_random_init)
@classmethod
def init(cls, workdir):
'''Initialize Megatron, tokenizers, and datasets.'''
# Load args.
cls.init_megatron(workdir)
cls.tokenizers = types.SimpleNamespace(
gpt=get_gpt_tokenizer(),
bert=get_bert_tokenizer(),
)
# Load data.
cls.db_indexed_dataset_infos = get_db_indexed_dataset_infos()
pt_train_ds, pt_valid_ds, _ = get_retro_datasets()
cls.pt_datasets = types.SimpleNamespace(
train=pt_train_ds,
valid=pt_valid_ds,
)
# Print usage.
cls.print_usage()
##############################################
# utils.
##############################################
@classmethod
def gpt_to_text(cls, token_ids):
'''GPT tokens to text.'''
return cls.tokenizers.gpt.detokenize(token_ids)
@classmethod
def text_to_bert(cls, text):
'''Text to Bert tokens.'''
return cls.tokenizers.bert.tokenize(text)
##############################################
# chunk db.
##############################################
@classmethod
def get_db_num_indexed_datasets(cls):
'''Number of indexed datasets within blendable dataset.'''
return len(cls.db_indexed_dataset_infos)
@classmethod
def get_db_indexed_dataset_infos(cls):
'''Dataset infos, including number of training & sampled sets.'''
return [(info["ratio"], info["name"])
for info in cls.db_indexed_dataset_infos]
@classmethod
def get_db_dataset(cls):
return cls.pt_datasets.train.db_dataset
@classmethod
def get_db_num_chunks(cls):
'''Number of DB chunks.'''
return len(cls.get_db_dataset())
@classmethod
def get_db_chunk_gpt(cls, idx):
'''Get DB chunk as GPT token ids.'''
return cls.get_db_dataset()[idx]["text"].tolist()
@classmethod
def get_db_chunk_bert(cls, idx):
'''Get DB chunk as Bert token ids.'''
return cls.text_to_bert(cls.get_db_chunk_text(idx))
@classmethod
def get_db_chunk_text(cls, idx):
'''Get DB chunk as text.'''
return cls.gpt_to_text(cls.get_db_chunk_gpt(idx))
@classmethod
def get_db_chunk_and_continuation_text(cls, idx):
'''Get DB chunk along with continuation, as text.'''
# Modulus used here to match original implementation (i.e., last
# chunks continuation wraps around to first chunk).
return [
cls.get_db_chunk_text(idx),
cls.get_db_chunk_text((idx + 1) % len(cls.get_db_dataset())),
]
##############################################
# pretraining corpus.
##############################################
@classmethod
def get_pt_num_samples_and_chunks(cls, data_key):
'''Number of samples & chunks (e.g., 32*n_samples) in corpus.'''
assert hasattr(cls.pt_datasets, data_key), \
"pretraining set '%s' not found (choices: %s)." % (
data_key, ", ".join(vars(cls.pt_datasets).keys()))
chunk_dataset = getattr(cls.pt_datasets, data_key).chunk_dataset
return (
len(chunk_dataset.sample_dataset),
len(chunk_dataset),
)
@classmethod
def get_pt_num_samples(cls, data_key):
'''Number of pretraining samples.'''
return cls.get_pt_num_samples_and_chunks(data_key)[0]
@classmethod
def get_pt_num_chunks(cls, data_key):
'''Number of pretraining chunks (e.g., 32*n_samples).'''
return cls.get_pt_num_samples_and_chunks(data_key)[1]
@classmethod
def get_pt_sample(cls, data_key, idx):
return getattr(cls.pt_datasets, data_key)[idx]
##############################################
# usage.
##############################################
@classmethod
def print_usage(cls):
'''Print usage.'''
print()
print("+++++++++++++++++++++++++++++++++++++++++++++++++++")
print("examples ... [ *note*: 'db' = chunk db; 'pt' = pretraining corpus. ]")
print("+++++++++++++++++++++++++++++++++++++++++++++++++++")
print()
print("~~~~ indexed datasets ~~~~")
print("retro.get_db_num_indexed_datasets() : %s" %
cls.get_db_num_indexed_datasets())
print("retro.get_db_indexed_dataset_infos() :")
for i, (ratio,prefix) in enumerate(cls.get_db_indexed_dataset_infos()):
print(" %s(%f, %s)%s" % (
"[" if i == 0 else " ",
ratio,
prefix,
"]" if i == len(cls.db_indexed_dataset_infos) - 1 else ",",
))
print()
print("~~~~ counts ~~~~")
print("retro.get_db_num_chunks : %d." % cls.get_db_num_chunks())
print()
for sq_key in ("sample", "chunk"):
for data_key in ("train", "valid"): # test?
print("retro.get_pt_num_%ss('%s') : %d." % (
sq_key, data_key,
getattr(cls, f"get_pt_num_{sq_key}s")(data_key)))
print()
print("~~~~ tokens, text ~~~~")
print("retro.get_db_chunk_gpt(chunk_id) : %s" %
shorten_str(str(retro.get_db_chunk_gpt(0)), 50))
print("retro.get_db_chunk_bert(chunk_id) : %s" %
shorten_str(str(retro.get_db_chunk_bert(0)), 50))
print("retro.get_db_chunk_text(chunk_id) : %s" %
shorten_str(retro.get_db_chunk_text(0).strip(), 50))
print("retro.get_db_chunk_and_continuation_text(chunk_id) :")
for i, t in enumerate(retro.get_db_chunk_and_continuation_text(0)):
print(" %s'%s'%s" % (
"[" if i == 0 else " ",
shorten_str(t.strip().replace("\n", " "), 50),
"]" if i == 1 else ",",
))
sample = cls.get_pt_sample("train", 0)
print()
print("retro.get_pt_sample('train', sample_id) :")
print(" {")
for k, v in sample.items():
print(" '%s' : %s" % (k, shorten_str(str(v), 50)))
print(" }")
print()
print("(e.g., sample = retro.get_pt_sample(...))")
print()
print(" sample['text'].shape : %s" % str(sample["text"].shape))
print(" sample['neighbor_tokens'].shape : %s" % str(sample["neighbor_tokens"].shape))
print(" sample['text'] : %s" % shorten_str(str(sample["text"]), 50))
print(" sample['neighbor_tokens'][17][1] : %s" % shorten_str(str(sample["neighbor_tokens"][17][1]), 50))
print(" retro.gpt_to_text(sample['text']) : %s" % shorten_str(cls.gpt_to_text(sample["text"]), 50))
print(" retro.gpt_to_text(sample['neighbor_tokens']) : %s" % shorten_str(cls.gpt_to_text(sample["neighbor_tokens"][17][1]), 50))
print("+++++++++++++++++++++++++++++++++++++++++++++++++++")
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from . import retro
if __name__ == "__main__":
retro.init(os.environ["RETRO_WORKDIR"])
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from .build import build_db
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from collections import defaultdict
from concurrent.futures import as_completed, ProcessPoolExecutor
from functools import reduce
import glob
import json
import numpy as np
import os
from pathlib import Path
import threading
import torch
from tqdm import tqdm
import types
from megatron import get_retro_args, print_rank_0
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
from megatron.tokenizer.tokenizer import (
_BertWordPieceTokenizer,
_GPT2BPETokenizer,
)
from tools.bert_embedding.utils import get_missing_blocks_by_rank
from tools.retro.external_libs import h5py
from tools.retro.utils import get_gpt_tokenizer, get_bert_tokenizer
from .utils import (
get_individual_db,
get_individual_db_dir,
get_merged_dataset,
get_merged_db_path_map,
get_train_doc_chunk_map_dir,
save_indexed_dataset_infos,
)
def init_indexed_dataset_infos():
'''Gather meta-info about each indexed dataset.
The returned info array allows for easy access to the configuration, and
helps remove ambiguity.
'''
args = get_retro_args()
assert len(args.data_path) % 2 == 0, \
"currently, only blendable dataset is supported."
# Dataset infos.
infos = []
for i in range(0, len(args.data_path), 2):
ratio = float(args.data_path[i])
prefix = args.data_path[i + 1]
path = prefix + ".bin"
name = os.path.basename(prefix)
assert os.path.exists(path)
infos.append({
"ratio" : ratio,
"prefix" : prefix,
"path" : path,
"name" : name,
"db_dir" : get_individual_db_dir(name),
"dataset" : make_indexed_dataset(prefix, "mmap", True),
})
return infos
def build_partial_db(
dataset_idx,
n_datasets,
indexed_dataset,
block_id,
n_blocks,
block,
proc_id,
n_procs,
tokenizers,
):
'''Process a document index range of the indexed dataset.
The chunk database is built in parallel blocks, since de-tokenizing &
re-tokenizing for Bert-length computation is expensive. This method
iterates each document and extracts sequential 'chunk-length' sequences
from each document.
'''
args = get_retro_args()
# Document start/end indexes.
doc_range = block["range"]
n_docs = doc_range[1] - doc_range[0]
n_docs_per_proc = int(np.ceil(n_docs / n_procs))
doc_start_id = doc_range[0] + proc_id * n_docs_per_proc
doc_end_id = min(doc_range[1], doc_start_id + n_docs_per_proc)
# Print progress.
progress_proc_ids = set(range(n_procs)) \
if torch.distributed.get_rank() == 0 else set()
if proc_id in progress_proc_ids:
print(" > building partial chunk db, proc %d / %d, docs %d:%d / %d."%(
proc_id,
n_procs,
doc_start_id,
doc_end_id,
n_docs,
))
# Progress bars (snapshot of overall progress).
doc_id_iter = range(doc_start_id, doc_end_id)
pbar = tqdm(doc_id_iter) \
if proc_id in progress_proc_ids else \
doc_id_iter
# Iterate documents & parse chunks.
chunk_db_valid = []
chunk_db_invalid = []
for doc_id in pbar:
# Progress description.
try:
pbar.set_description("ds %d / %d, block %d / %d, proc %d / %d." % (
dataset_idx,
n_datasets,
block_id,
n_blocks,
proc_id,
n_procs))
except:
pass
# Remove EOD token.
doc = indexed_dataset.get(doc_id)
if doc[-1].item() == tokenizers.gpt.eod_id:
doc = doc[:-1]
doc_len = len(doc)
# Chunk start/end indexes.
chunk_start_idxs = list(range(0, doc_len, args.retro_gpt_chunk_length))
chunk_end_idxs = [min(doc_len, s + args.retro_gpt_chunk_length)
for s in chunk_start_idxs]
# Re-tokenize each chunk to Bert/Wordpiece (empty bert -> 'invalid').
for i, chunk_start_idx in enumerate(chunk_start_idxs):
# Re-tokenize.
chunk_end_idx = chunk_end_idxs[i]
gpt_token_ids = indexed_dataset.get(
idx=doc_id,
offset=chunk_start_idx,
length=chunk_end_idx - chunk_start_idx,
)
text = tokenizers.gpt.detokenize(gpt_token_ids)
bert_token_ids = tokenizers.bert.tokenize(text)
# 'Valid' for non-empty Bert chunks; 'invalid' otherwise.
_chunk_db = chunk_db_invalid \
if len(bert_token_ids) == 0 else \
chunk_db_valid
_chunk_db.append((
doc_id,
chunk_start_idx,
chunk_end_idx,
len(bert_token_ids),
))
return proc_id, chunk_db_valid, chunk_db_invalid
def build_individual_db(dataset_idx, n_datasets, dataset_info, tokenizers):
'''Process a single indexed dataset & extract chunks.'''
args = get_retro_args()
# Make directory.
db_dir = dataset_info["db_dir"]
os.makedirs(db_dir, exist_ok=True)
# Indexed dataset.
indexed_dataset = dataset_info["dataset"]
# Missing db blocks.
n_missing_world, missing_db_blocks = get_missing_blocks_by_rank(
db_dir,
len(indexed_dataset.doc_idx) - 1,
args.retro_doc_block_size,
validate=lambda f : f["chunks_valid"].shape[1] == 4)
# Prevent missing-path-write race condition.
torch.distributed.barrier()
if not missing_db_blocks:
return
# Num processes.
if n_missing_world == 1:
n_procs = 128
elif n_missing_world <= 2:
n_procs = 64
elif n_missing_world <= 4:
n_procs = 32
elif n_missing_world <= 8:
n_procs = 16
else:
n_procs = 8
# Process documents in parallel.
with ProcessPoolExecutor(max_workers=n_procs) as executor:
for block_idx, block in enumerate(missing_db_blocks):
if block is not None:
# Build partial dbs.
print_rank_0(' > build partial dbs.')
futures = []
for proc_id in range(n_procs): # not true process id
futures.append(executor.submit(
build_partial_db,
dataset_idx,
n_datasets,
indexed_dataset,
block_idx,
len(missing_db_blocks),
block,
proc_id,
n_procs,
tokenizers,
))
partial_chunk_dbs = []
for future in as_completed(futures):
partial_chunk_dbs.append(future.result())
# Concatenate chunks.
partial_chunk_dbs.sort(key=lambda item:item[0]) # sort by proc_id
chunk_db_valid = [item
for partial_chunk_db in partial_chunk_dbs
for item in partial_chunk_db[1]]
chunk_db_invalid = [item
for partial_chunk_db in partial_chunk_dbs
for item in partial_chunk_db[2]]
# Convert to numpy.
print_rank_0(' > converting chunk db to numpy.')
chunk_db_valid = np.array(chunk_db_valid)
chunk_db_invalid = np.array(chunk_db_invalid)
# Save DB.
print_rank_0(" > saving individual db.")
f = h5py.File(block["path"], "w")
dset = f.create_dataset("chunks_valid", data=chunk_db_valid)
dset = f.create_dataset("chunks_invalid", data=chunk_db_invalid)
f.close()
# Wait for all ranks to finish block.
print_rank_0(" > waiting for all ranks to finish block.")
torch.distributed.barrier()
print_rank_0(" > finished saving individual db.")
def build_individual_dbs(indexed_dataset_infos):
'''Iterate each indexed dataset & process its chunks.'''
args = get_retro_args()
# Tokenizers.
tokenizers = types.SimpleNamespace(
gpt=get_gpt_tokenizer(),
bert=get_bert_tokenizer(),
)
# Build individual DBs.
print_rank_0(" > build individual chunk dbs.")
for ds_idx, ds_info in enumerate(indexed_dataset_infos):
# Progress.
print_rank_0(" > building individual db, dataset %d / %d ... '%s'." % (
ds_idx,
len(indexed_dataset_infos),
ds_info["name"],
))
# Process single dataset.
build_individual_db(ds_idx, len(indexed_dataset_infos),
ds_info, tokenizers)
def update_chunk_counts(indexed_dataset_infos):
'''Set n_chunks_train & n_chunks sampled for each individual DB.'''
args = get_retro_args()
if torch.distributed.get_rank() != 0:
return
# Training split size (split at document level).
train_fraction = float(args.split.split(",")[0]) / 100
assert train_fraction > 0 and train_fraction <= 1
# Set n_chunks (including n_chunks_sampled for unambiguity).
print_rank_0(" > compute n_chunks.")
for ds_index, ds_info in \
enumerate(tqdm(indexed_dataset_infos, "count_chunks")):
db_dir = ds_info["db_dir"]
db_paths = sorted(glob.glob(db_dir + "/*.hdf5"))
# Update counts.
ds_info["n_docs"] = len(ds_info["dataset"].doc_idx) - 1
ds_info["n_docs_train"] = int(train_fraction * ds_info["n_docs"])
ds_info["n_chunks"] = 0 # previously, 'n_chunks_valid'
ds_info["n_chunks_train"] = 0
ds_info["n_chunks_invalid"] = 0
for db_path in db_paths:
with h5py.File(db_path, "r") as f:
ds_info["n_chunks"] += len(f["chunks_valid"])
ds_info["n_chunks_invalid"] += len(f["chunks_invalid"])
ds_info["n_chunks_train"] += \
(np.copy(f["chunks_valid"][:, 0]) < ds_info["n_docs_train"]) \
.sum().item()
ds_info["n_chunks_sampled"] = \
int(round(args.retro_nchunks_sampled * ds_info["ratio"]))
# Verify counts.
assert ds_info["n_chunks_train"] <= ds_info["n_chunks"], \
"n_train (%d) > n_total (%d)." % (
ds_info["n_chunks_train"], ds_info["n_chunks"])
assert ds_info["n_chunks_sampled"] <= ds_info["n_chunks_train"], \
"n_sampled (%d) > n_train (%d)." % (
ds_info["n_chunks_sampled"], ds_info["n_chunks_train"])
def merge_dbs(indexed_dataset_infos, db_type):
'''Merge individual DBs into single DB.'''
if torch.distributed.get_rank() != 0:
return
print(" > build %s chunk db." % db_type)
# Count chunks.
if db_type == "full":
raise Exception("deprecated; use 'train' or 'sampled'.")
n_chunks_key = "n_chunks"
elif db_type == "sampled":
n_chunks_key = "n_chunks_sampled"
elif db_type == "train":
n_chunks_key = "n_chunks_train"
elif db_type == "valid":
pass
else:
raise Exception("handle db_type '%s'." % db_type)
if db_type == "valid":
n_chunks = sum(m["n_chunks"] - m["n_chunks_train"]
for m in indexed_dataset_infos)
else:
n_chunks = sum(m[n_chunks_key] for m in indexed_dataset_infos)
# DB path.
db_path = get_merged_db_path_map()[db_type]
# Delete existing chunk db if incorrect size.
if os.path.exists(db_path):
try:
f = h5py.File(db_path)
n_alloc = len(f["chunks"]) # total allocated
n_written = f["n_written"][0].item() # total written
f.close()
if n_chunks != n_alloc or n_chunks != n_written:
os.remove(db_path)
except Exception as e:
if isinstance(e, OSError):
os.remove(full_db_path)
elif isinstance(e, KeyError):
f.close()
os.remove(full_db_path)
else:
raise e
# Build merged chunk db.
if not os.path.exists(db_path):
os.makedirs(os.path.dirname(db_path), exist_ok=True)
f = h5py.File(db_path, "w")
# Initialize output arrays.
merged_db = f.create_dataset("chunks", (n_chunks, 5), dtype="i8")
n_written = f.create_dataset("n_written", (1,), dtype="uint64")
n_written[0] = 0
# Iterate indexed datasets & collect chunks.
start_index = 0
for ds_idx, ds_info in enumerate(indexed_dataset_infos):
print(" > merging dbs; '%s', dataset %d / %d ... '%s'." %
(db_type, ds_idx, len(indexed_dataset_infos), ds_info["name"]))
individual_db = get_individual_db(ds_idx, ds_info)
if db_type == "valid":
individual_db = individual_db[ds_info["n_chunks_train"]:]
else:
individual_db = individual_db[:ds_info[n_chunks_key]]
merged_db[start_index:start_index+len(individual_db)] = individual_db
start_index += len(individual_db)
n_written[0] = start_index
f.close()
def get_partial_banned_chunk_map(proc_id, db_path, chunk_range_info):
'''Build partial mapping of {(dataset_id,doc_id):[chunk_ids]}.
In this method, only chunks within the range (start_chunk_id, end_chunk_id]
are processed.'''
start_chunk_id = chunk_range_info["start"]
end_chunk_id = chunk_range_info["end"]
output_path = chunk_range_info["path"]
# Skip, if output file exists.
if os.path.exists(output_path):
return
# Chunk subset.
with h5py.File(db_path) as f:
sub_chunk_db = np.copy(f["chunks"][start_chunk_id:end_chunk_id, :2])
# Map docs to chunks.
banned_chunk_map = defaultdict(list)
for rel_chunk_id, (dataset_id, doc_id) in enumerate(tqdm(
sub_chunk_db,
"map banned docs, proc %d" % proc_id,
total=sub_chunk_db.shape[0],
)):
chunk_id = start_chunk_id + rel_chunk_id
banned_chunk_map["%d,%d" % (dataset_id.item(), doc_id.item())] \
.append(chunk_id)
# Save output.
with open(output_path, "w") as f:
json.dump(banned_chunk_map, f)
def build_doc_chunk_map(indexed_dataset_infos, db_type):
'''Build mapping of {(dataset_id,doc_id):[chunk_ids]}.'''
if torch.distributed.get_rank() != 0:
return
print(" > build %s doc-chunk map." % db_type)
n_procs = 128
# Get dataset.
db_dataset = get_merged_dataset(db_type, indexed_dataset_infos)
# Sub-ranges for parallel processing.
n_chunks = db_dataset.chunks.shape[0]
n_chunks_per_proc = max(1, int(np.ceil(n_chunks / n_procs)))
chunk_id_starts = list(range(0, n_chunks, n_chunks_per_proc))
chunk_id_ranges = [(s, min(n_chunks, s + n_chunks_per_proc))
for s in chunk_id_starts]
# Wrap range info with output path.
n_digits = int(np.ceil(np.log(n_chunks) / np.log(10)) + 1)
output_dirname = get_train_doc_chunk_map_dir()
chunk_range_infos = [{
"start" : start_id,
"end" : end_id,
"path" : os.path.join(output_dirname, "%s-%s.json" % (
str(start_id).zfill(n_digits),
str(end_id).zfill(n_digits),
)),
} for start_id, end_id in chunk_id_ranges ]
# Build doc-chunk map.
print_rank_0("build doc-chunk-map.")
with ProcessPoolExecutor(max_workers=n_procs) as executor:
# Build partial chunk maps.
futures = []
for proc_id, chunk_range_info in enumerate(chunk_range_infos):
if os.path.exists(chunk_range_info["path"]):
continue
# Submit job.
futures.append(executor.submit(
get_partial_banned_chunk_map,
proc_id,
db_dataset.db_path,
chunk_range_info,
))
# Wait for processes to finish.
banned_chunk_paths = []
for finished_idx, future in enumerate(as_completed(futures)):
print("finished %d / %d." % (finished_idx, n_procs))
future.result()
def build_db():
'''Extract token chunks from each indexed dataset.
Iterate each document of each indexed dataset, extract that document's
chunks, and save to a 'DB' (hdf5 file).
'''
# Indexed dataset info.
indexed_dataset_infos = init_indexed_dataset_infos()
# Build dbs.
build_individual_dbs(indexed_dataset_infos)
# Single-process going forward.
if torch.distributed.get_rank() != 0:
return
# Update n_chunks.
update_chunk_counts(indexed_dataset_infos)
# Merge dbs.
merge_dbs(indexed_dataset_infos, "sampled")
merge_dbs(indexed_dataset_infos, "train")
merge_dbs(indexed_dataset_infos, "valid")
build_doc_chunk_map(indexed_dataset_infos, "train")
# Save (fully annotated) indexed dataset infos.
save_indexed_dataset_infos(indexed_dataset_infos)
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import json
import numpy as np
import torch
from megatron import get_args, print_rank_0
from tools.retro.external_libs import h5py
from tools.retro.utils import get_gpt_tokenizer
class DBDataset(torch.utils.data.Dataset):
'''Dataset for iterating chunks.
Requires:
- List of indexed datasets
- Chunk index array, with format:
[dataset_idx, doc_id, start_idx, end_idx, bert_length])
'''
def __init__(self, db_path, indexed_datasets, chunks, max_chunk_length):
assert chunks.shape[1] == 5, "expected 5 columns (dataset_idx, " \
"doc_idx, token_start_idx, token_end_idx, bert_chunk_length); " \
"found %d columns." % chunks.shape[1]
self.db_path = db_path
self.indexed_datasets = indexed_datasets
self.chunks = chunks
self.max_chunk_length = max_chunk_length
self.eod_token_id = get_gpt_tokenizer().eod_id
def __len__(self):
return self.chunks.shape[0]
def __getitem__(self, chunk_id):
# Chunk start/end indexes.
indexed_dataset_id, doc_id, token_start_idx, token_end_idx, _ = \
[ value.item() for value in self.chunks[chunk_id] ]
chunk_length = token_end_idx - token_start_idx
indexed_dataset = self.indexed_datasets[indexed_dataset_id]
# Chunk token ids.
token_ids = indexed_dataset.get(doc_id,
offset=token_start_idx,
length=chunk_length)
# Extend chunks to max_chunk_length by padding with EOD tokens.
if chunk_length != self.max_chunk_length:
assert chunk_length < self.max_chunk_length, "invalid chunk len."
token_ids = token_ids.tolist()
token_ids += [self.eod_token_id] * \
(self.max_chunk_length - chunk_length)
return {
"doc_id" : doc_id,
"text" : np.array(token_ids, dtype=np.int64),
}
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from collections import defaultdict
import glob
import json
import numpy as np
import os
from tqdm import tqdm
from megatron import get_retro_args, print_rank_0
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
from tools.retro.external_libs import h5py
from .dataset import DBDataset
def get_base_db_workdir():
'''Sub-directory for DB data.'''
args = get_retro_args()
return os.path.join(args.retro_workdir, "db")
def get_indexed_dataset_infos_path():
'''Path to indexed dataset meta-infos.'''
return os.path.join(get_base_db_workdir(), "indexed_dataset_infos.json")
def save_indexed_dataset_infos(indexed_dataset_infos):
'''Save dataset order & meta-info.'''
# Remove 'dataset' field.
clean_infos = []
for info in indexed_dataset_infos:
info = dict(info)
del info["dataset"]
clean_infos.append(info)
# Save.
with open(get_indexed_dataset_infos_path(), "w") as f:
json.dump(clean_infos, f, indent=4)
def get_indexed_dataset_infos():
'''Load indexed dataset meta-infos.'''
# Load json.
path = get_indexed_dataset_infos_path()
with open(path) as f:
infos = json.load(f)
# Add indexed datasets.
for info in infos:
info["dataset"] = make_indexed_dataset(info["prefix"], "mmap", True)
return infos
def get_individual_db_dir(name):
'''Individual DB's directory.'''
return os.path.join(get_base_db_workdir(), "individual", name, "db")
def get_individual_db(ds_id, ds_info):
'''Load individual dataset's chunk DB.'''
db_paths = sorted(glob.glob(ds_info["db_dir"] + "/*hdf5"))
# *Note*: convert to dataset, rather than copying to memory.
db = np.zeros((ds_info["n_chunks"], 5), dtype="i8")
db[:, 0] = ds_id
start_idx = 0
for db_path in db_paths:
f = h5py.File(db_path, "r")
n_chunks_current = f["chunks_valid"].shape[0]
db[start_idx:(start_idx+n_chunks_current), 1:] = f["chunks_valid"]
start_idx += n_chunks_current
f.close()
assert start_idx == ds_info["n_chunks"]
return db
def get_merged_db_path_map():
'''Paths to merged datasets.'''
base_dir = get_base_db_workdir()
return {
"sampled" : os.path.join(base_dir, "merged", "sampled.hdf5"),
"train" : os.path.join(base_dir, "merged", "train.hdf5"),
"valid" : os.path.join(base_dir, "merged", "valid.hdf5"),
}
def get_merged_dataset(db_type, indexed_dataset_infos=None):
'''Get merged dataset.'''
args = get_retro_args()
if not indexed_dataset_infos:
indexed_dataset_infos = get_indexed_dataset_infos()
# Load chunks.
db_path = get_merged_db_path_map()[db_type]
f = h5py.File(db_path, "r")
chunks = f["chunks"]
# DB dataset.
indexed_datasets = [ info["dataset"] for info in indexed_dataset_infos ]
dataset = DBDataset(db_path, indexed_datasets, chunks,
args.retro_gpt_chunk_length)
return dataset
def get_merged_sampled_dataset(indexed_dataset_infos=None):
return get_merged_dataset("sampled", indexed_dataset_infos)
def get_merged_train_dataset(indexed_dataset_infos=None):
return get_merged_dataset("train", indexed_dataset_infos)
def get_merged_valid_dataset(indexed_dataset_infos=None):
return get_merged_dataset("valid", indexed_dataset_infos)
def get_train_doc_chunk_map_dir():
dirname = os.path.join(get_base_db_workdir(), "merged", "train_doc_chunk_map")
os.makedirs(dirname, exist_ok=True)
return dirname
def get_train_doc_chunk_map():
paths = sorted(glob.glob(get_train_doc_chunk_map_dir() + "/*.json"))
doc_map = defaultdict(set)
for path in tqdm(paths, "load train doc maps"):
# Read file.
with open(path) as f:
crnt_doc_map = json.load(f)
# Add to doc map.
for key, chunk_ids in crnt_doc_map.items():
key = tuple(int(i) for i in key.split(","))
doc_map[key].update(chunk_ids)
return doc_map
#!/bin/bash
# Small English Wikipedia dataset (~2M chunks).
get_wiki_tiny_config() {
RETRO_INDEX_STR="IVF4096_HNSW4,Flat"
RETRO_GPT_TRAIN_SAMPLES=31250
LR_DECAY_SAMPLES=2
LR_WARMUP_SAMPLES=1
RETRO_GPT_EVAL_INTERVAL=2000
RETRO_GPT_EVAL_ITERS=100
RETRO_EF_SEARCH=4
RETRO_NPROBE=64
DATALOADER_TYPE=cyclic
}
# English Wikipedia dataset (~67M chunks).
get_wiki_config() {
RETRO_INDEX_STR="IVF262144_HNSW32,Flat"
RETRO_GPT_TRAIN_SAMPLES=2037248
LR_DECAY_SAMPLES=2
LR_WARMUP_SAMPLES=1
RETRO_GPT_EVAL_INTERVAL=2000
RETRO_GPT_EVAL_ITERS=100
RETRO_EF_SEARCH=16
RETRO_NPROBE=4096
DATALOADER_TYPE=cyclic
}
# Full corpus (~5B chunks).
get_corpus_config() {
RETRO_INDEX_STR="OPQ32_256,IVF4194304_HNSW32,PQ32"
RETRO_GPT_TRAIN_SAMPLES=192000000
LR_DECAY_SAMPLES=166400000
LR_WARMUP_SAMPLES=162761
RETRO_GPT_EVAL_INTERVAL=2000
RETRO_GPT_EVAL_ITERS=50
RETRO_EF_SEARCH=32
RETRO_NPROBE=4096
DATALOADER_TYPE=single
}
#!/bin/bash
# Build preprocessing command for Retro.
set -u
DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
################ Required environment variables. ################
# Required environment variables:
# - REPO_DIR : Root directory of Megatron codebase.
# - RETRO_WORKDIR : Root directory of this Retro project's processed data. (For
# example, this project directory might be for a blended dataset, while
# another project directory might be for just a Wikipedia dataset, and
# another for just Book Corpus data, etc.) This project directory will
# contain a complete set of processed data, including the retrieval
# database, search index, and pretraining neighbors.
# - RETRO_TASKS : One of 'build', 'db-build', 'index-build', or
# 'pretraining-query-neighbors'. See 'Retro tasks' below for task
# descriptions.
# - DATA_BLEND_SCRIPT : Path to blended dataset definition file.
# - GPT_VOCAB_FILE : GPT vocab file.
# - GPT_MERGE_FILE : GPT merge file.
# - GPT_TOKENIZER : GPT tokenizer type (e.g., GPT2BPETokenizer)
# - BERT_LOAD_PATH : Bert checkpoint directory.
# - BERT_VOCAB_FILE : Bert vocab file.
# - BERT_TOKENIZER : Bert tokenizer type (e.g., BertWordPieceLowerCase,
# BertWordPieceCase).
# - BERT_EMBEDDER_TYPE : One of 'megatron' or 'huggingface'.
# - EXTRA_ARGS : Extra arguments (else, leave empty).
################ Data blend. ################
. ${DATA_BLEND_SCRIPT}
DATA_PATH=${DATA_BLEND}
################ Retro setup. ################
RETRO_GPT_SEQ_LENGTH=2048
RETRO_GPT_CHUNK_LENGTH=64
RETRO_GPT_MICRO_BATCH_SIZE=1 # *8
RETRO_GPT_GLOBAL_BATCH_SIZE=256
RETRO_NCHUNKS_SAMPLED=300000000
################ Retro tasks. ################
# The '--retro-tasks' argument is a comma-separated list of tasks to run, in
# sequential order. For a quick start, simply set this to 'build' to run the
# entire preprocessing pipeline. For finer control, you may specify the list of
# tasks to run. This is desirable for tuning computational resources. For
# example, training the search index is relatively fast and utilizes GPUs,
# while querying the search index is relatively slow, CPU-only, and memory
# intensive (i.e., multiple populated search indexes are loaded simultaneously).
# *Note* : Once the task(s) below have been completed -- by running either
# 1) 'build', or 2) the sequential combination of 'db-build', 'index-build',
# and 'pretraining-query-neighbors' -- we are ready to pretrain Retro by
# calling pretrain_retro.py.
# ---- Option #1 : Run entire pipeline. ----
# RETRO_TASKS="build" # (*note*: default tasks)
# ---- Option #2 : Run specific stages. ----
# *Note*: Run the following stages in the given order. Optionally, tune your
# cluster setup for each stage, as described above.
# RETRO_TASKS="db-build" # ....................... run 1st
# RETRO_TASKS="index-build" # .................... run 2nd
# RETRO_TASKS="pretraining-query-neighbors" # .... run 3rd
################ Megatron args. ################
MEGATRON_ARGS=" \
--seed 1234 \
--distributed-timeout-minutes 600 \
--tokenizer-type ${BERT_TOKENIZER} \
--tensor-model-parallel-size 1 \
--pipeline-model-parallel-size 1 \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
--micro-batch-size ${RETRO_GPT_MICRO_BATCH_SIZE} \
--global-batch-size ${RETRO_GPT_GLOBAL_BATCH_SIZE} \
--seq-length 512 \
--max-position-embeddings 512 \
--train-samples ${RETRO_GPT_TRAIN_SAMPLES} \
--load ${BERT_LOAD_PATH} \
--exit-on-missing-checkpoint \
--no-load-optim \
--data-path ${DATA_PATH} \
--vocab-file ${BERT_VOCAB_FILE} \
--data-impl mmap \
--split 98,2,0 \
--distributed-backend nccl \
--lr 0.0001 \
--lr-decay-style linear \
--min-lr 1.0e-5 \
--lr-decay-samples ${LR_DECAY_SAMPLES} \
--lr-warmup-samples ${LR_WARMUP_SAMPLES} \
--weight-decay 1e-2 \
--clip-grad 1.0 \
--eval-interval ${RETRO_GPT_EVAL_INTERVAL} \
--eval-iters ${RETRO_GPT_EVAL_ITERS} \
--fp16 \
--DDP-impl local \
--dataloader-type ${DATALOADER_TYPE} \
--no-data-sharding \
--no-gradient-accumulation-fusion \
--no-async-tensor-model-parallel-allreduce \
"
################ Retro args. ################
RETRO_ARGS=" \
--bert-embedder-type ${BERT_EMBEDDER_TYPE} \
--output-bert-embeddings \
\
--retro-gpt-vocab-file ${GPT_VOCAB_FILE} \
--retro-gpt-merge-file ${GPT_MERGE_FILE} \
--retro-gpt-tokenizer-type ${GPT_TOKENIZER} \
--retro-gpt-seq-length ${RETRO_GPT_SEQ_LENGTH} \
--retro-gpt-chunk-length ${RETRO_GPT_CHUNK_LENGTH} \
--retro-bert-vocab-file ${BERT_VOCAB_FILE} \
--retro-bert-tokenizer-type ${BERT_TOKENIZER} \
\
--retro-tasks ${RETRO_TASKS} \
--retro-index-str ${RETRO_INDEX_STR} \
--retro-ef-search ${RETRO_EF_SEARCH} \
--retro-nprobe ${RETRO_NPROBE} \
\
--retro-workdir ${RETRO_WORKDIR} \
--retro-nchunks-sampled ${RETRO_NCHUNKS_SAMPLED} \
\
--retro-return-doc-ids \
"
################ Command. ################
RETRO_PREPROCESS_CMD=" \
./tools/retro/main.py \
${MEGATRON_ARGS} \
${RETRO_ARGS} \
${EXTRA_ARGS} \
"
#!/bin/bash
set -u
unset NCCL_DEBUG
NPROCS=8 # NPROCS must be <= number of GPUs.
set_current_dir() {
DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
}
################ Dataset configs. ################
# This script contains methods to customize arguments to specific dataset
# types. Customize this script as needed for your datasets.
set_current_dir
. $DIR/get_dataset_configs.sh
################ Environment variables. ################
# *Note*: See 'Required environment variables' in 'get_preprocess_cmd.sh' for
# a description of the required environment variables. These variables can be
# set however a user would like. In our setup, we use another bash script
# (location defined by $RETRO_ENV_VARS) that sets all the environment variables
# at once.
. $RETRO_ENV_VARS
######## Environment vars. ########
set_current_dir
. ${DIR}/get_preprocess_cmd.sh
echo "~~~~~~~~~~~~~~~~~~~~~~~~~~"
echo "DIR = '$DIR'."
echo "RETRO_PREPROCESS_CMD = '$RETRO_PREPROCESS_CMD'."
echo "~~~~~~~~~~~~~~~~~~~~~~~~~~"
######## Command. ########
FULL_CMD="\
pwd && cd ${REPO_DIR} && pwd && \
export PYTHONPATH=$PYTHONPATH:${REPO_DIR} && \
python -m torch.distributed.launch \
--nproc_per_node ${NPROCS} \
--nnodes 1 \
--node_rank ${NODE_RANK} \
--master_addr ${MASTER_ADDR} \
--master_port 6000 \
$RETRO_PREPROCESS_CMD \
"
echo "~~~~~~~~~~~~~~~~~~~~~~~~~~"
echo "FULL_CMD = '$FULL_CMD'."
echo "~~~~~~~~~~~~~~~~~~~~~~~~~~"
eval $FULL_CMD
#!/bin/bash
##################################################
# Example script for pretraining Retro.
##################################################
set -u
unset NCCL_DEBUG
export CUDA_DEVICE_MAX_CONNECTIONS=1
NPROCS=8 # NPROCS must be <= number of GPUs.
################ Dataset configs. ################
# This script contains methods to customize arguments to specific dataset
# types. Customize this script as needed for your datasets.
DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
. $DIR/get_dataset_configs.sh
################ Environment variables. ################
# *Note*: See 'Required environment variables' in 'get_preprocess_cmd.sh' for
# a description of the required environment variables. These variables can be
# set however a user would like. In our setup, we use another bash script
# (location defined by $RETRO_ENV_VARS) that sets all the environment variables
# at once.
. $RETRO_ENV_VARS
################ Data blend. ################
. ${DATA_BLEND_SCRIPT}
DATA_PATH=${DATA_BLEND}
######## Retro setup. ########
RETRO_ADD_RETRIEVER=1
RETRO_CYCLIC_TRAIN_ITERS=750000
RETRO_NUM_NEIGHBORS=2
######## Arguments. ########
CHECKPOINT_DIR=${RETRO_WORKDIR}/checkpoints/${RETRO_ADD_RETRIEVER}
TENSORBOARD_DIR="${CHECKPOINT_DIR}/tensorboard"
mkdir -p ${TENSORBOARD_DIR}
ARGS=" \
--save-interval 1000 \
--save ${CHECKPOINT_DIR} \
--load ${CHECKPOINT_DIR} \
--tensorboard-dir ${TENSORBOARD_DIR} \
--log-interval 5 \
--tensor-model-parallel-size 1 \
--pipeline-model-parallel-size 1 \
--num-layers 12 \
--hidden-size 768 \
--num-attention-heads 12 \
--seq-length 2048 \
--max-position-embeddings 2048 \
--micro-batch-size 4 \
--global-batch-size 256 \
--train-samples ${RETRO_GPT_TRAIN_SAMPLES} \
--lr-decay-samples ${LR_DECAY_SAMPLES} \
--lr-warmup-samples ${LR_WARMUP_SAMPLES} \
--lr 6.0e-4 \
--min-lr 6.0e-5 \
--lr-decay-style cosine \
--eval-interval ${RETRO_GPT_EVAL_INTERVAL} \
--eval-iters ${RETRO_GPT_EVAL_ITERS} \
--data-path ${DATA_PATH} \
--vocab-file ${GPT_VOCAB_FILE} \
--merge-file ${GPT_MERGE_FILE} \
--split 98,2,0 \
--clip-grad 1.0 \
--weight-decay 0.1 \
--adam-beta1 0.9 \
--adam-beta2 0.95 \
--init-method-std 0.023 \
--log-params-norm \
--log-num-zeros-in-grad \
--fp16 \
--DDP-impl local \
--dataloader-type ${DATALOADER_TYPE} \
--no-data-sharding \
--no-gradient-accumulation-fusion \
"
if [ "$RETRO_ADD_RETRIEVER" = "0" ]; then
SCRIPT=pretrain_gpt.py
else
ARGS="${ARGS} \
--retro-add-retriever \
--retro-workdir ${RETRO_WORKDIR} \
--retro-cyclic-train-iters ${RETRO_CYCLIC_TRAIN_ITERS} \
--retro-num-neighbors ${RETRO_NUM_NEIGHBORS} \
"
SCRIPT=pretrain_retro.py
fi
echo "~~~~~~~~~~~~~~~~~~~~~~~~~~"
echo "ARGS = '$ARGS'."
echo "~~~~~~~~~~~~~~~~~~~~~~~~~~"
python -m torch.distributed.launch \
--nproc_per_node ${NPROCS} \
--nnodes 1 \
--node_rank 0 \
--master_addr localhost \
--master_port 6000 \
${SCRIPT} \
${ARGS} \
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import importlib
required_libs = [
"faiss",
"h5py",
"transformers", # for huggingface bert
]
for lib in required_libs:
try:
globals()[lib] = importlib.import_module(lib)
except ImportError as e:
raise Exception(f"Missing one or more packages required for Retro preprocessing: {required_libs}. Tried importing '{lib}'.")
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from .index import Index
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import numpy as np
import os
import shutil
import torch
from tqdm import tqdm
from megatron import get_retro_args, print_rank_0
from tools.bert_embedding import DiskDataParallelBertEmbedder
from tools.retro.db.utils import (
get_indexed_dataset_infos,
get_merged_sampled_dataset,
get_merged_train_dataset,
)
from tools.retro.external_libs import h5py
from tools.retro.index.factory import IndexFactory
from tools.retro.utils import GPTToTextDataset
from .utils import (
get_training_data_dir,
get_training_data_merged,
)
##################################################
# Train index.
##################################################
def get_empty_index_path():
'''Path of empty index.'''
args = get_retro_args()
index = IndexFactory.get_index(args.retro_index_type)
empty_index_path = index.get_empty_index_path()
return empty_index_path
def embed_db():
'''Embed DB chunks.
Store chunks in blocks on disk. These blocks will later be merged into
a single dataset for training the index.
'''
args = get_retro_args()
# Get db dataset.
gpt_dataset = get_merged_sampled_dataset()
text_dataset = GPTToTextDataset(gpt_dataset)
# Embed dataset.
embedder = DiskDataParallelBertEmbedder(args.retro_bert_batch_size,
args.retro_bert_max_chunk_length,
args.retro_block_size,
args.bert_embedder_type)
embedder.embed_text_dataset("index", get_training_data_dir(), text_dataset)
def train_on_embeddings():
'''Train index on embedded DB chunks.'''
args = get_retro_args()
index = IndexFactory.get_index(args.retro_index_type)
index.train(get_training_data_merged)
def remove_embeddings():
'''Remove embeddings after training.'''
torch.distributed.barrier()
if torch.distributed.get_rank() != 0:
return
empty_index_path = get_empty_index_path()
assert os.path.isfile(empty_index_path)
shutil.rmtree(get_training_data_dir(), ignore_errors=True)
def train_index():
'''Train index on DB chunks.'''
args = get_retro_args()
# Check if trained index already exists.
if not os.path.isfile(get_empty_index_path()):
# Embed training chunks.
embed_db()
# Train index on embeddings.
train_on_embeddings()
# Wait for (single-process) training to complete.
torch.distributed.barrier()
# Remove embeddings.
if args.retro_delete_index_training_embeddings:
remove_embeddings()
##################################################
# Add to index.
##################################################
def add_to_index():
'''Add DB chunks to index.'''
args = get_retro_args()
# Get index.
index = IndexFactory.get_index(args.retro_index_type)
# Get text dataset.
gpt_dataset = get_merged_train_dataset()
text_dataset = GPTToTextDataset(gpt_dataset)
# Add to index.
output_index_path = index.add(text_dataset)
return output_index_path
##################################################
# Build index (train + add).
##################################################
def build_index():
'''Build index.
Building index involves sequentially running stages above:
- Train index (on sampled training chunks).
- Add to index (on all training chunks).
'''
# Train index.
train_index()
# Add to index.
add_to_index()
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from .indexes import FaissBaseIndex, FaissParallelAddIndex
class IndexFactory:
'''Get index.
Index type generally read from argument '--retro-index-ty'.
'''
@classmethod
def get_index_class(cls, index_type):
return {
"faiss-base" : FaissBaseIndex,
"faiss-par-add" : FaissParallelAddIndex,
}[index_type]
@classmethod
def get_index(cls, index_type):
index_class = cls.get_index_class(index_type)
index = index_class()
return index
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import abc
import numpy as np
import os
import torch
from tools.retro.external_libs import faiss
from .utils import get_index_dir
class Index(abc.ABC):
'''Abstract base class for indexes.
*Note* : While currently only Faiss-based classes are implemented, in the
future, this class will be extended with other types of indexes that have
different performance-accuracy trade-offs.
The primary methods to override are:
- train() : Train index on the sampled training chunks.
- add() : Add all training chunks to index.
'''
@classmethod
def c_verbose(cls, index, v):
'''Make index object verbose.'''
assert isinstance(v, bool)
faiss.ParameterSpace().set_index_parameter(index, "verbose", v)
def get_empty_index_path(self):
return os.path.join(get_index_dir(), "empty.faissindex")
def get_empty_index(self):
return faiss.read_index(self.get_empty_index_path())
def get_added_index_path(self):
return os.path.join(get_index_dir(), "added.faissindex")
def get_added_index(self):
return faiss.read_index(self.get_added_index_path())
@abc.abstractmethod
def train(self, *args):
pass
@abc.abstractmethod
def add(self, *args):
pass
def embed_text_dataset_block(self, embedder, text_dataset, _range):
'''Embed a range of a text dataset.'''
sub_dataset = torch.utils.data.Subset(text_dataset, range(*_range))
return embedder.embed_text_dataset(sub_dataset)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment