Commit 32bb4edc authored by Neel Kant

Prune changes to only be related to ICT

parent 674814a5
from collections import defaultdict
import time
import pickle
import faiss
from faiss import index_factory, index_cpu_to_gpu
import numpy as np
from megatron import get_args
PCAS = [
'PCA', 'PCAR', 'PCAW', 'PCAWR'
]
# PCA to 64 dim gets "first missing" ~ 95% and "mixed" ~ 5% for all
# however, this is pretty hard since the embeds and queries are totally random; it would be better to test according to a realistic distribution
# update: using a realistic mean and covariance helps, but then adjusting for inner product makes it unusable again
# CONCLUSION: PCA should not be used for MIPS
QUANTIZERS = [
'IVF4096_SQ16', # 'IMI2x9',
'HNSW32_SQ16', # 'IVF4096_HNSW32'
]
# IMI2x9 or any other MultiIndex doesn't support inner product so it's unusable
# IVF4096_HNSW32 doesn't support inner product either
ENCODINGS = [
'Flat',
'PQ16np', # PQ16, PQ16x12(np)
'SQ4', 'SQ8', 'SQ6', 'SQfp16',
# 'LSH', 'LSHrt', 'LSHr', 'LSHt'
]
# PQ16 is pretty slow for creating and adding - ~96s for 1e5, 105s for 1e6
# PQ16np is a bit faster but is pretty inaccurate - misses top-1 result 2/3 of time (1e6 embeds)
# PQ16x12(np) gets real slow. Uses 4096 centroids.
# SQfp16 is solid.
# LSH is inaccurate - pretty much always missing the top-1 result (1e6 embeds)
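# Note: the spec strings above are the arguments consumed by faiss.index_factory in the
# tests below; for example (illustrative dimension of 128):
#     index = faiss.index_factory(128, 'SQfp16', faiss.METRIC_INNER_PRODUCT)
# builds a float16 scalar-quantized flat index that supports maximum inner product search.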
def latest(times):
return times[-1] - times[-2]
def get_embed_mean_and_cov():
embed_data = pickle.load(open('/home/dcg-adlr-nkant-data.cosmos1202/hash_data/normed4096_whitened.pkl', 'rb'))
embed_mean = embed_data['embed_mean']
whitener = embed_data['embed_whitener']
embed_cov = whitener.dot(whitener.transpose())
return embed_mean, embed_cov
def get_embeds_and_queries(mean, cov, num_embeds, num_queries):
embeds = np.random.multivariate_normal(mean, cov, num_embeds).astype('float32')
queries = np.random.multivariate_normal(mean, cov, num_queries).astype('float32')
return embeds, queries
def get_random_embeds_and_queries(d, num_embeds, num_queries):
embeds = np.random.rand(num_embeds, d).astype('float32')
queries = np.random.rand(num_queries, d).astype('float32')
return embeds, queries
def print_timing_stats(name, create_and_add, search):
print('{:20s} Create and add embeds: {:10.4f}s | Search embeds: {:10.4f}s'.format(name, create_and_add, search))
def print_accuracy_stats(name, gold_indices, estimated_indices):
gold_indices, estimated_indices = list(gold_indices), list(estimated_indices)
results = defaultdict(int)
for gold, estimated in zip(gold_indices, estimated_indices):
if gold[0] not in estimated:
results['first_missing'] += 1
elif np.array_equal(gold, estimated):
results['all_equal'] += 1
else:
results['mixed'] += 1
result_strs = ['first_missing', 'all_equal', 'mixed']
print('{:20s} First missing: {:4d} | All equal: {:4d} | Mixed: {:4d}'.format(name, *[results[s] for s in result_strs]))
def create_and_test_gold(d, k, embeds, queries):
times = [time.time()]
res = faiss.StandardGpuResources()
gold_idx = index_cpu_to_gpu(res, 0, index_factory(d, 'Flat'))
gold_idx.add(embeds)
times.append(time.time())
create_and_add = latest(times)
distances, indices = gold_idx.search(queries, k)
times.append(time.time())
print_timing_stats('Flat', create_and_add, latest(times))
print('-' * 100)
return distances, indices
def test_pca(d, k, embeds, queries, pca_dim):
distances, indices = create_and_test_gold(d, k, embeds, queries)
times = [time.time()]
all_pca_indices = []
for s in PCAS:
pca_idx = index_factory(d, s + "{},Flat".format(pca_dim), faiss.METRIC_INNER_PRODUCT)
pca_idx.train(embeds)
pca_idx.add(embeds)
times.append(time.time())
create_and_add = latest(times)
pca_distances, pca_indices = pca_idx.search(queries, k)
all_pca_indices.append(pca_indices)
times.append(time.time())
print_timing_stats(s, create_and_add, latest(times))
print('\n')
for s, pca_indices in zip(PCAS, all_pca_indices):
print_accuracy_stats(s, indices, pca_indices)
def test_quantizers(d, k, embeds, queries):
distances, indices = create_and_test_gold(d, k, embeds, queries)
times = [time.time()]
for s in QUANTIZERS:
if 'HNSW' in s:
quant_idx = index_factory(d, s, faiss.METRIC_INNER_PRODUCT)
else:
quant_idx = index_factory(d, "Flat," + s, faiss.METRIC_INNER_PRODUCT)
quant_idx.train(embeds)
quant_idx.add(embeds)
times.append(time.time())
create_and_add = latest(times)
quant_distances, quant_indices = quant_idx.search(queries, k)
times.append(time.time())
print_timing_stats(s, create_and_add, latest(times))
def test_encodings(d, k, embeds, queries):
distances, indices = create_and_test_gold(d, k, embeds, queries)
times = [time.time()]
all_encode_indices = []
for s in ENCODINGS:
encode_idx = index_factory(d, s, faiss.METRIC_INNER_PRODUCT)
encode_idx.train(embeds)
encode_idx.add(embeds)
times.append(time.time())
create_and_add = latest(times)
_, encode_indices = encode_idx.search(queries, k)
all_encode_indices.append(encode_indices)
times.append(time.time())
print_timing_stats(s, create_and_add, latest(times))
print('\n')
for s, encode_indices in zip(ENCODINGS, all_encode_indices):
print_accuracy_stats(s, indices, encode_indices)
def run_all_tests():
mean, cov = get_embed_mean_and_cov()
embeds, queries = get_embeds_and_queries(mean, cov, int(1e6), 256)
d = 128
k = 10
test_pca(d, k, embeds, queries, 96)
test_quantizers(d, k, embeds, queries)
test_encodings(d, k, embeds, queries)
if __name__ == "__main__":
run_all_tests()
import lucene
import sys
from java.nio.file import Paths
from org.apache.lucene.analysis.standard import StandardAnalyzer
from org.apache.lucene.document import Document, Field, FieldType
from org.apache.lucene.index import IndexWriter, IndexWriterConfig, IndexOptions, DirectoryReader
from org.apache.lucene.store import SimpleFSDirectory
from org.apache.lucene.search import IndexSearcher
from org.apache.lucene.queryparser.classic import QueryParser
from org.apache.lucene.search.similarities import BM25Similarity
from org.apache.lucene.util import Version
import torch
import torch.distributed as dist
from indexer import get_ict_dataset, get_one_epoch_dataloader
from megatron.initialize import initialize_megatron
from pretrain_bert_ict import get_batch
def setup():
initialize_megatron(extra_args_provider=None,
args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
lucene.initVM(vmargs=['-Djava.awt.headless=true'])
def run(embed_all=False):
dset = get_ict_dataset(use_titles=False, query_in_block_prob=0.1)
dataloader = iter(get_one_epoch_dataloader(dset))
index_dir = SimpleFSDirectory(Paths.get("full_wiki_index/"))
analyzer = StandardAnalyzer()
analyzer.setMaxTokenLength(1024)
config = IndexWriterConfig(analyzer)
config.setOpenMode(IndexWriterConfig.OpenMode.CREATE)
writer = IndexWriter(index_dir, config)
# field for document ID
t1 = FieldType()
t1.setStored(True)
t1.setTokenized(False)
# field for document text
t2 = FieldType()
t2.setStored(True)
t2.setTokenized(True)
t2.setIndexOptions(IndexOptions.DOCS_AND_FREQS_AND_POSITIONS)
correct = total = 0
round_correct = torch.zeros(1).cuda()
round_total = torch.zeros(1).cuda()
for round in range(100000):
with torch.no_grad():
try:
query_tokens, query_pad_mask, \
block_tokens, block_pad_mask, block_index_data = get_batch(dataloader)
except:
break
# query_tokens = query_tokens.detach().cpu().numpy()
block_tokens = block_tokens.detach().cpu().numpy()
# query_strs = [dset.decode_tokens(query_tokens[i].tolist(), hardcore=True) for i in range(query_tokens.shape[0])]
block_strs = [dset.decode_tokens(block_tokens[i].tolist(), hardcore=True) for i in range(block_tokens.shape[0])]
def add_document(text, writer, doc_id):
doc = Document()
doc.add(Field("text", text, t2))
doc.add(Field("doc_id", doc_id, t1))
writer.addDocument(doc)
# add documents to index writer
for i in range(len(block_strs)):
add_document(block_strs[i], writer, i)
# write and finalize the index
writer.commit()
# define BM25 searcher
# searcher = IndexSearcher(DirectoryReader.open(index_dir))
# searcher.setSimilarity(BM25Similarity())
# # feed queries and get scores for everything in the index
# hits_list = []
# for s in query_strs:
# query = QueryParser("text", analyzer).parse(s)
# hits = searcher.search(query, 1).scoreDocs
# hits_list.append(hits)
# for (i, hits) in enumerate(hits_list):
# doc_ids = [int(searcher.doc(hit.doc)['doc_id']) for hit in hits]
# correct += int(i in doc_ids)
# total += 1
# dist.all_reduce(round_correct)
# dist.all_reduce(round_total)
# correct += int(round_correct.item())
# total += int(round_total.item())
# round_correct -= round_correct
# round_total -= round_total
# print("Correct: {:8d} | Total: {:8d} | Fraction: {:6.5f}".format(correct, total, correct / total))
if round % 10 == 0:
print(round)
writer.close()
# Plan
# overall accuracy test:
# have index with all blocks. For BERT these are token ids, for BM25 these are tokens
#
# 1. run batch size 4096 BM25 self similarity test. For this I can just detokenize out of the dataset.
# I get the retrieval scores in the forward_step and log the results.
# 2. Create a BM25 index over all of wikipedia, have it ready for use in megatron QA.
#
# Create an index with the block embeddings with block ids
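# A minimal sketch of the BM25 self-retrieval check described in the plan above. It follows
# the commented-out code in run(): query_strs, analyzer and index_dir are assumed to be
# built there, and query i counts as correct if its own block is among the top-k hits.
def bm25_self_retrieval_fraction(query_strs, analyzer, index_dir, k=1):
    searcher = IndexSearcher(DirectoryReader.open(index_dir))
    searcher.setSimilarity(BM25Similarity())
    correct = total = 0
    for i, s in enumerate(query_strs):
        # raw block text may need Lucene special characters escaped before parsing
        query = QueryParser("text", analyzer).parse(s)
        hits = searcher.search(query, k).scoreDocs
        doc_ids = [int(searcher.doc(hit.doc).get('doc_id')) for hit in hits]
        correct += int(i in doc_ids)
        total += 1
    return correct / max(total, 1)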
if __name__ == "__main__":
setup()
run()
import os
import sys
import time
import torch
import torch.distributed as dist
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron import get_args, get_adlr_autoresume, print_rank_0
from megatron import mpu
from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
from megatron.data.bert_dataset import get_indexed_dataset_
from megatron.data.realm_dataset import ICTDataset
from megatron.data.realm_index import detach, BlockData, FaissMIPSIndex
from megatron.data.samplers import DistributedBatchSampler
from megatron.initialize import initialize_megatron
from megatron.model import REALMRetriever
from megatron.global_vars import set_global_variables
from megatron.mpu.initialize import get_index_ready, get_index_group, get_train_group, get_data_parallel_group, get_gloo_comm_group
from megatron.mpu.initialize import set_data_parallel_group, set_model_parallel_group, init_realm_groups
from megatron.initialize import init_distributed, _init_autoresume, _set_random_seed, _write_args_to_tensorboard
from megatron.training import get_model
from megatron.utils import check_adlr_autoresume_termination
from pretrain_bert_ict import get_batch, model_provider
INDEX_READY = None
def pprint(*args):
print(*args, flush=True)
def initialize_and_run_async_megatron(extra_args_provider=None, args_defaults={},
ignore_unknown_args=False, allow_no_cuda=False):
if not allow_no_cuda:
# Make sure cuda is available.
assert torch.cuda.is_available(), 'Megatron requires CUDA.'
# Parse args, build tokenizer, and set adlr-autoresume,
# tensorboard-writer, and timers.
set_global_variables(extra_args_provider=extra_args_provider,
args_defaults=args_defaults,
ignore_unknown_args=ignore_unknown_args)
# instead of _initialize_distributed()
init_distributed()
setup_realm_groups_and_vars()
global INDEX_READY
INDEX_READY = get_index_ready()
pprint('finished setting up groups')
# Autoresume
_init_autoresume()
pprint('finished setting up autoresume')
# Random seeds for reproducibility.
args = get_args()
if args.rank == 0:
pprint('> setting random seeds to {} ...'.format(args.seed))
_set_random_seed(args.seed)
# Write arguments to tensorboard.
_write_args_to_tensorboard()
pprint('finished writing args to tensorboard')
torch.distributed.barrier()
if args.rank < args.max_training_rank:
torch.distributed.barrier(get_data_parallel_group())
pprint("All trainers ready.")
return
else:
runner = AsyncIndexBuilder(args.rank)
torch.distributed.barrier(get_data_parallel_group())
pprint("All indexers ready.")
runner.run_async()
def setup_realm_groups_and_vars():
args = get_args()
world_size = dist.get_world_size()
max_training_rank = args.max_training_rank
# assuming no model parallelism right now
set_model_parallel_group(dist.new_group([args.rank]))
init_realm_groups(max_training_rank, world_size)
if args.rank < max_training_rank:
set_data_parallel_group(get_train_group())
else:
set_data_parallel_group(get_index_group())
class IndexBuilder(object):
def __init__(self):
args = get_args()
self.debug = args.debug
self.rank = args.rank
self.model = None
self.dataloader = None
self.block_data = None
self.load_attributes()
self.is_main_builder = args.rank == 0
def load_attributes(self):
self.model = load_ict_checkpoint(only_block_model=True, no_grad=True, from_realm_chkpt=False)
self.model.eval()
self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset()))
self.block_data = BlockData()
def build_and_save_index(self):
i = 1
total = 0
while True:
with torch.no_grad():
try:
query_tokens, query_pad_mask, \
block_tokens, block_pad_mask, block_index_data = get_batch(self.dataloader)
except:
break
block_index_data = detach(block_index_data)
block_indices = block_index_data[:, 3]
block_meta = block_index_data[:, :3]
block_logits = detach(self.model(None, None, block_tokens, block_pad_mask, only_block=True))
self.block_data.add_block_data(block_indices, block_logits, block_meta)
total += block_indices.size
i += 1
if i % 1000 == 0:
print('Batch {:10d} | Total {:10d}'.format(i, total), flush=True)
if self.debug:
break
self.block_data.save_shard(self.rank)
torch.distributed.barrier(get_data_parallel_group())
del self.model
if self.is_main_builder:
self.block_data.consolidate_shards_and_save(ignore_shard=self.rank)
self.block_data.clear()
class AsyncIndexBuilder(IndexBuilder):
def __init__(self, rank):
self.rank = rank
args = get_args()
self.is_main_builder = self.rank == args.max_training_rank
self.main_builder_idx = args.max_training_rank
self.debug = args.debug
self.model = None
self.dataloader = None
self.block_data = None
self.load_attributes()
global INDEX_READY
INDEX_READY = get_index_ready()
def run_async(self):
global INDEX_READY
# synchronize for start
dist.broadcast(INDEX_READY, 0, group=get_gloo_comm_group())
while True:
print("Starting (again!)", flush=True)
self.build_and_save_index()
self.send_index_ready_signal()
while INDEX_READY == 1:
print("Waiting for new model checkpoint.", flush=True)
time.sleep(5)
self.load_attributes()
def load_attributes(self):
try:
self.model = load_ict_checkpoint(only_block_model=True, no_grad=True, from_realm_chkpt=True)
except:
print(">>>>> No realm chkpt available", flush=True)
self.model = load_ict_checkpoint(only_block_model=True, no_grad=True, from_realm_chkpt=False)
self.model.eval()
self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset()))
self.block_data = BlockData()
def send_index_ready_signal(self):
global INDEX_READY
if self.is_main_builder:
INDEX_READY = 1 - INDEX_READY
print("Switched INDEX_READY", flush=True)
torch.cuda.synchronize()
# send handle
dist.broadcast(INDEX_READY, self.main_builder_idx, group=get_gloo_comm_group(), async_op=True)
# recv handle
dist.broadcast(INDEX_READY, 0, group=get_gloo_comm_group())
torch.distributed.barrier(get_data_parallel_group())
def load_ict_checkpoint(only_query_model=False, only_block_model=False, no_grad=False, from_realm_chkpt=False):
args = get_args()
model = get_model(lambda: model_provider(only_query_model, only_block_model))
if isinstance(model, torchDDP):
model = model.module
load_path = args.load if from_realm_chkpt else args.ict_load
tracker_filename = get_checkpoint_tracker_filename(load_path)
with open(tracker_filename, 'r') as f:
iteration = int(f.read().strip())
# assert iteration > 0
checkpoint_name = get_checkpoint_name(load_path, iteration, False)
if mpu.get_data_parallel_rank() == 0:
print('global rank {} is loading checkpoint {}'.format(
torch.distributed.get_rank(), checkpoint_name))
state_dict = torch.load(checkpoint_name, map_location='cpu')
ict_state_dict = state_dict['model']
if from_realm_chkpt:
print(">>>> Attempting to get ict state dict from realm", flush=True)
ict_state_dict = ict_state_dict['retriever']['ict_model']
if only_query_model:
ict_state_dict.pop('context_model')
if only_block_model:
ict_state_dict.pop('question_model')
if no_grad:
with torch.no_grad():
model.load_state_dict(ict_state_dict)
else:
model.load_state_dict(ict_state_dict)
torch.distributed.barrier(get_data_parallel_group())
if mpu.get_data_parallel_rank() == 0:
print(' successfully loaded {}'.format(checkpoint_name))
return model
def get_ict_dataset(use_titles=True, query_in_block_prob=1):
args = get_args()
block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True)
kwargs = dict(
name='full',
block_dataset=block_dataset,
title_dataset=titles_dataset,
data_prefix=args.data_path,
num_epochs=1,
max_num_samples=None,
max_seq_length=args.seq_length,
short_seq_prob=0.0001, # doesn't matter
seed=1,
query_in_block_prob=query_in_block_prob,
use_titles=use_titles
)
dataset = ICTDataset(**kwargs)
return dataset
def get_one_epoch_dataloader(dataset, batch_size=None):
args = get_args()
world_size = mpu.get_data_parallel_world_size()
rank = mpu.get_data_parallel_rank()
if batch_size is None:
batch_size = args.batch_size
global_batch_size = batch_size * world_size
num_workers = args.num_workers
sampler = torch.utils.data.SequentialSampler(dataset)
batch_sampler = DistributedBatchSampler(sampler,
batch_size=global_batch_size,
drop_last=True,
rank=rank,
world_size=world_size)
return torch.utils.data.DataLoader(dataset,
batch_sampler=batch_sampler,
num_workers=num_workers,
pin_memory=True)
if __name__ == "__main__":
initialize_megatron(extra_args_provider=None,
args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
index_builder = IndexBuilder()
index_builder.build_and_save_index()
@@ -195,7 +195,6 @@ def _add_training_args(parser):
'by this value.')
group.add_argument('--tensorboard-dir', type=str, default=None,
help='Write TensorBoard logs to this directory.')
group.add_argument('--max-training-rank', type=int, default=None)
return parser
@@ -343,14 +342,6 @@ def _add_data_args(parser):
help='Path to combined dataset to split.')
group.add_argument('--titles-data-path', type=str, default=None,
help='Path to titles dataset used for ICT')
group.add_argument('--block-data-path', type=str, default=None,
help='Path to pickled BlockData data structure')
group.add_argument('--block-index-path', type=str, default=None,
help='Path to pickled data structure for efficient block indexing')
group.add_argument('--block-top-k', type=int, default=5,
help='Number of blocks to use as top-k during retrieval')
group.add_argument('--async-indexer', action='store_true',
help='Whether the indexer job is running asynchronously with a trainer job')
group.add_argument('--split', type=str, default='969, 30, 1',
help='Comma-separated list of proportions for training,'
' validation, and test split. For example the split '
@@ -388,7 +379,6 @@ def _add_data_args(parser):
help='Mask loss for the end of document tokens.')
group.add_argument('--query-in-block-prob', type=float, default=0.1,
help='Probability of keeping query in block for ICT dataset')
group.add_argument('--faiss-use-gpu', action='store_true')
return parser
@@ -24,7 +24,6 @@ import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron import mpu
from megatron.mpu.initialize import get_train_group, get_data_parallel_group
from megatron import get_args
from megatron import print_rank_0
@@ -45,7 +44,7 @@ def check_checkpoint_args(checkpoint_args):
_compare('num_layers')
_compare('hidden_size')
_compare('num_attention_heads')
# _compare('max_position_embeddings')
_compare('max_position_embeddings')
_compare('make_vocab_size_divisible_by')
_compare('padded_vocab_size')
_compare('tokenizer_type')
@@ -119,14 +118,14 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
print(' successfully saved {}'.format(checkpoint_name))
# Wait so everyone is done (necessary)
torch.distributed.barrier(get_data_parallel_group())
torch.distributed.barrier()
# And update the latest iteration
if torch.distributed.get_rank() == 0:
tracker_filename = get_checkpoint_tracker_filename(args.save)
with open(tracker_filename, 'w') as f:
f.write(str(iteration))
# Wait so everyone is done (not necessary)
torch.distributed.barrier(get_data_parallel_group())
torch.distributed.barrier()
def load_checkpoint(model, optimizer, lr_scheduler):
@@ -243,7 +242,7 @@ def load_checkpoint(model, optimizer, lr_scheduler):
'exiting ...'.format(checkpoint_name))
sys.exit()
# torch.distributed.barrier()
torch.distributed.barrier()
if mpu.get_data_parallel_rank() == 0:
print(' successfully loaded {}'.format(checkpoint_name))
@@ -25,6 +25,7 @@ from torch.utils.data import Dataset
from megatron import get_tokenizer, get_args
from megatron import mpu
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
from megatron.data.dataset_utils import build_training_sample
from megatron import print_rank_0
@@ -61,8 +62,6 @@ class BertDataset(Dataset):
self.sep_id = tokenizer.sep
self.mask_id = tokenizer.mask
self.pad_id = tokenizer.pad
from megatron.data.dataset_utils import build_training_sample
self.build_sample_fn = build_training_sample
def __len__(self):
return self.samples_mapping.shape[0]
@@ -73,13 +72,13 @@
# Note that this rng state should be numpy and not python since
# python randint is inclusive whereas the numpy one is exclusive.
np_rng = np.random.RandomState(seed=(self.seed + idx))
return self.build_sample_fn(sample, seq_length,
self.max_seq_length, # needed for padding
self.vocab_id_list,
self.vocab_id_to_token_dict,
self.cls_id, self.sep_id,
self.mask_id, self.pad_id,
self.masked_lm_prob, np_rng)
return build_training_sample(sample, seq_length,
self.max_seq_length, # needed for padding
self.vocab_id_list,
self.vocab_id_to_token_dict,
self.cls_id, self.sep_id,
self.mask_id, self.pad_id,
self.masked_lm_prob, np_rng)
def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
@@ -25,7 +25,7 @@ import numpy as np
from megatron import print_rank_0, get_args
from megatron.data.bert_dataset import get_indexed_dataset_, get_train_valid_test_split_, BertDataset
DATASET_TYPES = ['standard_bert', 'ict', 'realm']
DATASET_TYPES = ['standard_bert', 'ict']
def compile_helper():
"""Compile helper function ar runtime. Make sure this
@@ -388,7 +388,7 @@ def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
padding_length = max_seq_length - num_tokens
assert padding_length >= 0
assert len(tokentypes) == num_tokens
assert len(masked_positions) == len(masked_labels), (len(masked_positions), len(masked_labels))
assert len(masked_positions) == len(masked_labels)
# Tokens and token types.
filler = [pad_id] * padding_length
@@ -456,7 +456,6 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
def build_dataset(index, name):
from megatron.data.realm_dataset import ICTDataset
from megatron.data.realm_dataset import REALMDataset
dataset = None
if splits[index + 1] > splits[index]:
# Get the pointer to the original doc-idx so we can set it later.
@@ -486,13 +485,6 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
query_in_block_prob=args.query_in_block_prob,
**kwargs
)
elif dataset_type == 'realm':
dataset = REALMDataset(
block_dataset=indexed_dataset,
title_dataset=title_dataset,
masked_lm_prob=masked_lm_prob,
**kwargs
)
else:
dataset = BertDataset(
indexed_dataset=indexed_dataset,
import itertools
import random
import os
import time
import numpy as np
import torch
from torch.utils.data import Dataset
from megatron import get_tokenizer
from megatron import print_rank_0
from megatron import mpu
from megatron.data import helpers
class InverseClozeDataset(Dataset):
"""Dataset containing sentences and their blocks for an inverse cloze task."""
def __init__(self, name, block_dataset, title_dataset, data_prefix,
num_epochs, max_num_samples, max_seq_length,
query_in_block_prob, short_seq_prob, seed):
self.name = name
self.seed = seed
self.max_seq_length = max_seq_length
self.query_in_block_prob = query_in_block_prob
self.block_dataset = block_dataset
self.title_dataset = title_dataset
self.short_seq_prob = short_seq_prob
self.rng = random.Random(self.seed)
self.samples_mapping = self.get_samples_mapping(
data_prefix, num_epochs, max_num_samples)
self.tokenizer = get_tokenizer()
self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
self.vocab_id_to_token_list = self.tokenizer.inv_vocab
self.cls_id = self.tokenizer.cls
self.sep_id = self.tokenizer.sep
self.mask_id = self.tokenizer.mask
self.pad_id = self.tokenizer.pad
def __len__(self):
return self.samples_mapping.shape[0]
def __getitem__(self, idx):
start_idx, end_idx, doc_idx, block_idx = self.samples_mapping[idx]
title = list(self.title_dataset[int(doc_idx)])
block = [list(self.block_dataset[i]) for i in range(start_idx, end_idx)]
assert len(block) > 1
# avoid selecting the first or last sentence to be the query.
if len(block) == 2:
rand_sent_idx = int(self.rng.random() > 0.5)
else:
rand_sent_idx = self.rng.randint(1, len(block) - 2)
# keep the query in the context 10% of the time.
if self.rng.random() < self.query_in_block_prob:
query = block[rand_sent_idx].copy()
else:
query = block.pop(rand_sent_idx)
# still need to truncate because blocks are concluded when
# the sentence lengths have exceeded max_seq_length.
query = query[:self.max_seq_length - 2]
block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]
query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
sample = {
'query_tokens': np.array(query_tokens),
'query_pad_mask': np.array(query_pad_mask),
'block_tokens': np.array(block_tokens),
'block_pad_mask': np.array(block_pad_mask),
'block_data': np.array([start_idx, end_idx, doc_idx, block_idx]).astype(np.int64)
}
return sample
def encode_text(self, text):
return self.tokenizer.tokenize(text)
def decode_tokens(self, token_ids):
tokens = self.tokenizer.tokenizer.convert_ids_to_tokens(token_ids)
return ' '.join(token for token in tokens if token != '[PAD]')
def get_block(self, start_idx, end_idx, doc_idx):
"""Get the IDs for an evidence block plus the title of the corresponding document"""
block = [list(self.block_dataset[i]) for i in range(start_idx, end_idx)]
title = list(self.title_dataset[int(doc_idx)])
block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]
block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
return (block_tokens, block_pad_mask)
def concat_and_pad_tokens(self, tokens, title=None):
"""concat with special tokens and pad sequence to self.max_seq_length"""
tokens = [self.cls_id] + tokens + [self.sep_id]
if title is not None:
tokens += title + [self.sep_id]
assert len(tokens) <= self.max_seq_length, len(tokens)
num_pad = self.max_seq_length - len(tokens)
pad_mask = [1] * len(tokens) + [0] * num_pad
tokens += [self.pad_id] * num_pad
return tokens, pad_mask
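# Worked example for concat_and_pad_tokens (assuming the title branch above), with
# max_seq_length = 8, tokens = [A, B] and title = [T]:
#   tokens   -> [CLS, A, B, SEP, T, SEP, PAD, PAD]
#   pad_mask -> [1,   1, 1, 1,   1, 1,   0,   0]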
def get_samples_mapping(self, data_prefix, num_epochs, max_num_samples):
if not num_epochs:
if not max_num_samples:
raise ValueError("Need to specify either max_num_samples "
"or num_epochs")
num_epochs = np.iinfo(np.int32).max - 1
if not max_num_samples:
max_num_samples = np.iinfo(np.int64).max - 1
# Filename of the index mapping
indexmap_filename = data_prefix
indexmap_filename += '_{}_indexmap'.format(self.name)
if num_epochs != (np.iinfo(np.int32).max - 1):
indexmap_filename += '_{}ep'.format(num_epochs)
if max_num_samples != (np.iinfo(np.int64).max - 1):
indexmap_filename += '_{}mns'.format(max_num_samples)
indexmap_filename += '_{}msl'.format(self.max_seq_length)
indexmap_filename += '_{}s'.format(self.seed)
indexmap_filename += '.npy'
# Build the indexed mapping if not exist.
if torch.distributed.get_rank() == 0 and \
not os.path.isfile(indexmap_filename):
print(' > WARNING: could not find index map file {}, building '
'the indices on rank 0 ...'.format(indexmap_filename))
# Make sure the types match the helpers input types.
assert self.block_dataset.doc_idx.dtype == np.int64
assert self.block_dataset.sizes.dtype == np.int32
# Build samples mapping
verbose = torch.distributed.get_rank() == 0
start_time = time.time()
print_rank_0(' > building samples index mapping for {} ...'.format(
self.name))
samples_mapping = helpers.build_blocks_mapping(
self.block_dataset.doc_idx,
self.block_dataset.sizes,
self.title_dataset.sizes,
num_epochs,
max_num_samples,
self.max_seq_length-3, # account for added tokens
self.seed,
verbose)
print_rank_0(' > done building samples index mapping')
np.save(indexmap_filename, samples_mapping, allow_pickle=True)
print_rank_0(' > saved the index mapping in {}'.format(
indexmap_filename))
# Make sure all the ranks have built the mapping
print_rank_0(' > elapsed time to build and save samples mapping '
'(seconds): {:4f}'.format(
time.time() - start_time))
# This should be a barrier but nccl barrier assumes
# device_index=rank which is not the case for model
# parallel case
counts = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
assert counts[0].item() == torch.distributed.get_world_size(
group=mpu.get_data_parallel_group())
# Load indexed dataset.
print_rank_0(' > loading indexed mapping from {}'.format(
indexmap_filename))
start_time = time.time()
samples_mapping = np.load(indexmap_filename, allow_pickle=True)
print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
time.time() - start_time))
print_rank_0(' total number of samples: {}'.format(
samples_mapping.shape[0]))
return samples_mapping
import argparse
import itertools
import json
import multiprocessing
import nltk
import sys
import time
import torch
sys.path.insert(0, '../')
sys.path.insert(0, '../../')
from tokenizer.bert_tokenization import FullTokenizer
from data.indexed_dataset import make_builder
class CustomLanguageVars(nltk.tokenize.punkt.PunktLanguageVars):
_period_context_fmt = r"""
\S* # some word material
%(SentEndChars)s # a potential sentence ending
\s* # <-- THIS is what I changed
(?=(?P<after_tok>
%(NonWord)s # either other punctuation
|
(?P<next_tok>\S+) # <-- Normally you would have \s+ here
))"""
class Encoder(object):
splitter = None
tokenizer = None
def __init__(self, args):
self.args = args
def initializer(self):
# Use Encoder class as a container for global data
Encoder.tokenizer = FullTokenizer(self.args.vocab, do_lower_case=True)
splitter = nltk.load("tokenizers/punkt/english.pickle")
if self.args.keep_newlines:
# this prevents punkt from eating newlines after sentences
Encoder.splitter = nltk.tokenize.punkt.PunktSentenceTokenizer(
train_text = splitter._params,
lang_vars = CustomLanguageVars())
else:
Encoder.splitter = splitter
def encode(self, json_line):
text = json.loads(json_line)[self.args.json_key]
if not text:
text = "no text"
doc_ids = []
for sentence in Encoder.splitter.tokenize(text):
tokens = Encoder.tokenizer.tokenize(sentence)
ids = Encoder.tokenizer.convert_tokens_to_ids(tokens)
if len(ids) > 0:
doc_ids.append(ids)
else:
print("no ids!", flush=True)
tokens = Encoder.tokenizer.tokenize("no text")
ids = Encoder.tokenizer.convert_tokens_to_ids(tokens)
doc_ids.append(ids)
if self.args.flatten and len(doc_ids) > 1:
doc_ids = [list(itertools.chain(*doc_ids))]
return doc_ids, len(json_line)
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--input', type=str, help='Path to input JSON')
parser.add_argument('--vocab', type=str, help='Path to vocab.txt')
parser.add_argument('--flatten', action='store_true', help='Flatten all sentences of a document into a single sequence of token ids')
parser.add_argument('--json-key', type=str, default='text',
help='Key to extract from json')
parser.add_argument('--output-prefix', type=str, help='Path to binary output file without suffix')
parser.add_argument('--workers', type=int, default=20,
help='Number of worker processes to launch')
parser.add_argument('--log-interval', type=int, default=100,
help='Interval between progress updates')
parser.add_argument('--keep-newlines', action='store_true',
help='Keep newlines between sentences.')
parser.add_argument('--dataset-impl', type=str, default='mmap',
choices=['lazy', 'cached', 'mmap'])
args = parser.parse_args()
args.keep_empty = False
startup_start = time.time()
print("Opening", args.input)
fin = open(args.input, 'r', encoding='utf-8')
nltk.download("punkt", quiet=True)
encoder = Encoder(args)
tokenizer = FullTokenizer(args.vocab, do_lower_case=True)
pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer)
encoded_docs = pool.imap(encoder.encode, fin, 25)
print(f"Vocab size: {tokenizer.vocab_size()}")
output_bin_file = "{}.bin".format(args.output_prefix)
output_idx_file = "{}.idx".format(args.output_prefix)
builder = make_builder(output_bin_file,
impl=args.dataset_impl,
vocab_size=tokenizer.vocab_size())
startup_end = time.time()
proc_start = time.time()
total_bytes_processed = 0
print("Time to startup:", startup_end - startup_start)
for i, (doc, bytes_processed) in enumerate(encoded_docs, start=1):
total_bytes_processed += bytes_processed
for sentence in doc:
#print(sentence)
#print(tokenizer.convert_ids_to_tokens(sentence))
builder.add_item(torch.IntTensor(sentence))
builder.end_document()
if i % args.log_interval == 0:
current = time.time()
elapsed = current - proc_start
mbs = total_bytes_processed/elapsed/1024/1024
print(f"Processed {i} documents",
f"({i/elapsed} docs/s, {mbs} MB/s).",
file=sys.stderr)
builder.finalize(output_idx_file)
if __name__ == '__main__':
main()
@@ -5,64 +5,6 @@ import numpy as np
from torch.utils.data import Dataset
from megatron import get_tokenizer
from megatron.data.realm_dataset_utils import build_realm_training_sample, get_block_samples_mapping, join_str_list
class REALMDataset(Dataset):
"""Dataset containing simple masked sentences for masked language modeling.
The dataset should yield sentences just like the regular BertDataset
However, this dataset also needs to be able to return a set of blocks
given their start and end indices.
Presumably
"""
def __init__(self, name, block_dataset, title_dataset, data_prefix,
num_epochs, max_num_samples, masked_lm_prob,
max_seq_length, short_seq_prob, seed):
self.name = name
self.seed = seed
self.max_seq_length = max_seq_length
self.masked_lm_prob = masked_lm_prob
self.block_dataset = block_dataset
self.title_dataset = title_dataset
self.short_seq_prob = short_seq_prob
self.rng = random.Random(self.seed)
self.samples_mapping = get_block_samples_mapping(
block_dataset, title_dataset, data_prefix, num_epochs,
max_num_samples, max_seq_length, seed, name)
self.tokenizer = get_tokenizer()
self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
self.vocab_id_to_token_list = self.tokenizer.inv_vocab
self.cls_id = self.tokenizer.cls
self.sep_id = self.tokenizer.sep
self.mask_id = self.tokenizer.mask
self.pad_id = self.tokenizer.pad
def __len__(self):
return self.samples_mapping.shape[0]
def __getitem__(self, idx):
start_idx, end_idx, doc_idx, block_idx = self.samples_mapping[idx]
block = [list(self.block_dataset[i]) for i in range(start_idx, end_idx)]
assert len(block) > 1
np_rng = np.random.RandomState(seed=(self.seed + idx))
sample = build_realm_training_sample(block,
self.max_seq_length,
self.vocab_id_list,
self.vocab_id_to_token_list,
self.cls_id,
self.sep_id,
self.mask_id,
self.pad_id,
self.masked_lm_prob,
np_rng)
sample.update({'query_block_indices': np.array([block_idx]).astype(np.int64)})
return sample
class ICTDataset(Dataset):
@@ -95,6 +37,7 @@ class ICTDataset(Dataset):
return self.samples_mapping.shape[0]
def __getitem__(self, idx):
"""Get an ICT example of a pseudo-query and the block of text from which it was extracted"""
start_idx, end_idx, doc_idx, block_idx = self.samples_mapping[idx]
if self.use_titles:
title = list(self.title_dataset[int(doc_idx)])
@@ -107,7 +50,7 @@ class ICTDataset(Dataset):
rand_sent_idx = self.rng.randint(0, len(block) - 1)
# keep the query in the context 10% of the time.
# keep the query in the context query_in_block_prob fraction of the time.
if self.rng.random() < self.query_in_block_prob:
query = block[rand_sent_idx].copy()
else:
@@ -134,30 +77,12 @@ class ICTDataset(Dataset):
def encode_text(self, text):
return self.tokenizer.tokenize(text)
def decode_tokens(self, token_ids, hardcore=False):
def decode_tokens(self, token_ids):
"""Utility function to help with debugging mostly"""
tokens = self.tokenizer.tokenizer.convert_ids_to_tokens(token_ids)
exclude_list = ['[PAD]', '[CLS]']
if hardcore:
extra_exclude = ['[SEP]']
exclude_list.extend(extra_exclude)
non_pads = [t for t in tokens if t not in exclude_list]
joined_strs = join_str_list(non_pads)
if hardcore:
escape_chars = ['+', '-', '&', '!', '(', ')', '{', '}', '[', ']', '^', '"', '~', '*', '?', ':', '/']
skip_me = False
joined_strs = list(joined_strs)
joined_strs = [s for s in joined_strs if s != '\\']
for i, c in enumerate(joined_strs):
if skip_me:
skip_me = False
continue
if c in escape_chars:
joined_strs.insert(i, '\\')
skip_me = True
joined_strs = ''.join(joined_strs)
if len(joined_strs) < 3:
joined_strs += 'text here'
return joined_strs
def get_block(self, start_idx, end_idx, doc_idx):
"""Get the IDs for an evidence block plus the title of the corresponding document"""
@@ -170,13 +95,14 @@ class ICTDataset(Dataset):
return (block_tokens, block_pad_mask)
def get_null_block(self):
"""Get empty block and title - used in REALM pretraining"""
block, title = [], []
block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
return (block_tokens, block_pad_mask)
def concat_and_pad_tokens(self, tokens, title=None):
"""concat with special tokens and pad sequence to self.max_seq_length"""
"""Concat with special tokens and pad sequence to self.max_seq_length"""
if title is None:
tokens = [self.cls_id] + tokens + [self.sep_id]
else:
import itertools
import os
import random
import time
import numpy as np
import spacy
import torch
from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy
from megatron import get_tokenizer, print_rank_0, mpu
SPACY_NER = spacy.load('en_core_web_lg')
def build_realm_training_sample(sample, max_seq_length,
vocab_id_list, vocab_id_to_token_dict,
cls_id, sep_id, mask_id, pad_id,
masked_lm_prob, np_rng):
tokens = list(itertools.chain(*sample))[:max_seq_length - 2]
tokens, tokentypes = create_single_tokens_and_tokentypes(tokens, cls_id, sep_id)
try:
masked_tokens, masked_positions, masked_labels = salient_span_mask(tokens, mask_id)
except TypeError:
# this means the above returned None, and None isn't iterable.
# TODO: consider coding style.
max_predictions_per_seq = masked_lm_prob * max_seq_length
masked_tokens, masked_positions, masked_labels, _ = create_masked_lm_predictions(
tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng)
tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np \
= pad_and_convert_to_numpy(masked_tokens, tokentypes, masked_positions,
masked_labels, pad_id, max_seq_length)
train_sample = {
'tokens': tokens_np,
'labels': labels_np,
'loss_mask': loss_mask_np,
'pad_mask': padding_mask_np
}
return train_sample
def create_single_tokens_and_tokentypes(_tokens, cls_id, sep_id):
tokens = []
tokens.append(cls_id)
tokens.extend(list(_tokens))
tokens.append(sep_id)
tokentypes = [0] * len(tokens)
return tokens, tokentypes
from megatron import print_rank_0, mpu
def join_str_list(str_list):
@@ -63,69 +18,6 @@ def join_str_list(str_list):
return result
def id_to_str_pos_map(token_ids, tokenizer):
"""Given a list of ids, return a list of integers which correspond to the starting index
of the corresponding token in the original string (with spaces, without artifacts e.g. ##)"""
token_strs = tokenizer.tokenizer.convert_ids_to_tokens(token_ids)
pos_map = [0]
for i in range(len(token_strs) - 1):
len_prev = len(token_strs[i])
# do not add the length of the "##"
if token_strs[i].startswith("##"):
len_prev -= 2
# add the length of the space if needed
if token_strs[i + 1].startswith("##"):
pos_map.append(pos_map[-1] + len_prev)
else:
pos_map.append(pos_map[-1] + len_prev + 1)
# make sure total size is correct
offset = -2 if token_strs[-1].startswith("##") else 0
total_len = pos_map[-1] + len(token_strs[-1]) + offset
assert total_len == len(join_str_list(token_strs)) - 1, (total_len, len(join_str_list(token_strs)))
return pos_map
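# Worked example (assuming join_str_list joins "##" continuations without a space):
#   token_strs = ['the', 'quick', '##er']  ->  joined string 'the quicker'
#   pos_map    = [0, 4, 9]  # 'the' starts at char 0, 'quick' at 4, the 'er' piece at 9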
def salient_span_mask(tokens, mask_id):
"""Creates the predictions for the masked LM objective.
Note: Tokens here are vocab ids and not text tokens."""
tokenizer = get_tokenizer()
tokens_str = join_str_list(tokenizer.tokenizer.convert_ids_to_tokens(tokens))
# need to get all named entities
entities = SPACY_NER(tokens_str).ents
entities = [e for e in entities if e.text != "CLS"]
if len(entities) == 0:
return None
entity_idx = np.random.randint(0, len(entities))
selected_entity = entities[entity_idx]
token_pos_map = id_to_str_pos_map(tokens, tokenizer)
mask_start = mask_end = 0
set_mask_start = False
while mask_end < len(token_pos_map) and token_pos_map[mask_end] < selected_entity.end_char:
if token_pos_map[mask_start] > selected_entity.start_char:
set_mask_start = True
if not set_mask_start:
mask_start += 1
mask_end += 1
masked_positions = list(range(mask_start - 1, mask_end))
labels = []
output_tokens = tokens.copy()
for id_idx in masked_positions:
labels.append(tokens[id_idx])
output_tokens[id_idx] = mask_id
#print("-" * 100 + '\n',
# "TOKEN STR\n", tokens_str + '\n',
# "SELECTED ENTITY\n", selected_entity.text + '\n',
# "OUTPUT\n", join_str_list(tokenizer.tokenizer.convert_ids_to_tokens(output_tokens)), flush=True)
return output_tokens, masked_positions, labels
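# Illustration of the intended behavior: if the detokenized block contains
# "the capital of france is paris", spaCy NER may tag "paris" as an entity; every
# wordpiece covering that span is replaced with mask_id and the original ids are
# returned as labels. When no entity is found, the caller falls back to random masking.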
def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epochs,
max_num_samples, max_seq_length, seed, name):
if not num_epochs:
from collections import defaultdict
import os
import pickle
import shutil
import faiss
import numpy as np
import torch
from megatron import get_args, mpu
def detach(tensor):
return tensor.detach().cpu().numpy()
class BlockData(object):
def __init__(self):
self.embed_data = dict()
self.meta_data = dict()
self.temp_dir_name = 'temp_block_data'
def state(self):
return {
'embed_data': self.embed_data,
'meta_data': self.meta_data
}
def clear(self):
"""Clear the data structures to save memory"""
self.embed_data = dict()
self.meta_data = dict()
@classmethod
def load_from_file(cls, fname):
print("\n> Unpickling block data", flush=True)
state_dict = pickle.load(open(fname, 'rb'))
print(">> Finished unpickling block data\n", flush=True)
new_index = cls()
new_index.embed_data = state_dict['embed_data']
new_index.meta_data = state_dict['meta_data']
return new_index
def add_block_data(self, block_indices, block_embeds, block_metas, allow_overwrite=False):
for idx, embed, meta in zip(block_indices, block_embeds, block_metas):
if not allow_overwrite and idx in self.embed_data:
raise ValueError("Unexpectedly tried to overwrite block data")
self.embed_data[idx] = np.float16(embed)
self.meta_data[idx] = meta
def save_shard(self, rank):
if not os.path.isdir(self.temp_dir_name):
os.mkdir(self.temp_dir_name)
# save the data for each shard
with open('{}/{}.pkl'.format(self.temp_dir_name, rank), 'wb') as data_file:
pickle.dump(self.state(), data_file)
def consolidate_shards_and_save(self, ignore_shard=0):
"""Combine all the shards made using self.save_shard()"""
fnames = os.listdir(self.temp_dir_name)
for fname in fnames:
with open('{}/{}'.format(self.temp_dir_name, fname), 'rb') as f:
data = pickle.load(f)
old_size = len(self.embed_data)
shard_size = len(data['embed_data'])
self.embed_data.update(data['embed_data'])
self.meta_data.update(data['meta_data'])
# assert (len(self.embed_data) == old_size + shard_size) or (str(ignore_shard) in fname)
args = get_args()
with open(args.block_data_path, 'wb') as final_file:
pickle.dump(self.state(), final_file)
shutil.rmtree(self.temp_dir_name, ignore_errors=True)
class FaissMIPSIndex(object):
def __init__(self, index_type, embed_size, use_gpu=False):
self.index_type = index_type
self.embed_size = embed_size
self.use_gpu = use_gpu
self.id_map = dict()
# alsh
self.m = 5
self.u = 0.99
self.max_norm = None
self.block_mips_index = None
self._set_block_index()
def _set_block_index(self):
INDEX_TYPES = ['flat_ip']
if self.index_type not in INDEX_TYPES:
raise ValueError("Invalid index type specified")
print("\n> Building index", flush=True)
self.block_mips_index = faiss.index_factory(self.embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)
if not self.use_gpu:
self.block_mips_index = faiss.IndexIDMap(self.block_mips_index)
print(">> Finished building index", flush=True)
if self.use_gpu:
res = faiss.StandardGpuResources()
# self.block_mips_index = faiss.index_cpu_to_gpu(res, device, self.block_mips_index)
config = faiss.GpuIndexFlatConfig()
config.device = torch.cuda.current_device()
config.useFloat16 = True
self.block_mips_index = faiss.GpuIndexFlat(res, self.block_mips_index, config)
print(">>> Loaded Faiss index on GPU {}\n".format(self.block_mips_index.getDevice()), flush=True)
def reset_index(self):
self._set_block_index()
def add_block_embed_data(self, all_block_data, clear_block_data=False):
"""Add the embedding of each block to the underlying FAISS index"""
block_indices, block_embeds = zip(*all_block_data.embed_data.items())
if self.use_gpu:
for i, idx in enumerate(block_indices):
self.id_map[i] = idx
if clear_block_data:
all_block_data.clear()
if self.use_gpu:
self.block_mips_index.add(np.float32(np.array(block_embeds)))
else:
self.block_mips_index.add_with_ids(np.float32(np.array(block_embeds)), np.array(block_indices))
def search_mips_index(self, query_embeds, top_k, reconstruct=True):
"""Get the top-k blocks by the index distance metric.
:param reconstruct: if True: return a [num_queries x k x embed_dim] array of blocks
if False: return [num_queries x k] array of distances, and another for indices
"""
if self.index_type == 'flat_l2':
query_embeds = self.alsh_query_preprocess_fn(query_embeds)
query_embeds = np.float32(detach(query_embeds))
# query_embeds = query_embeds.float()
with torch.no_grad():
if reconstruct:
top_k_block_embeds = self.block_mips_index.search_and_reconstruct(query_embeds, top_k)
return top_k_block_embeds
else:
distances, block_indices = self.block_mips_index.search(query_embeds, top_k)
if self.use_gpu:
fresh_indices = np.zeros(block_indices.shape)
for i in range(block_indices.shape[0]):
for j in range(block_indices.shape[1]):
fresh_indices[i, j] = self.id_map[block_indices[i, j]]
block_indices = fresh_indices
return distances, block_indices
# functions below are for ALSH, which currently isn't being used
def get_norm_powers_and_halves_array(self, embeds):
norm = np.linalg.norm(embeds, axis=1)
norm_powers = [np.multiply(norm, norm)] # squared L2 norms of all
for i in range(self.m - 1):
norm_powers.append(np.multiply(norm_powers[-1], norm_powers[-1]))
# [num_blocks x self.m]
norm_powers = np.transpose(np.array(norm_powers))
halves_array = 0.5 * np.ones(norm_powers.shape)
return norm_powers, halves_array
def alsh_block_preprocess_fn(self, block_embeds):
block_embeds = np.array(block_embeds)
if self.max_norm is None:
self.max_norm = max(np.linalg.norm(block_embeds, axis=1))
if self.max_norm > 1:
block_embeds = self.u / self.max_norm * block_embeds
norm_powers, halves_array = self.get_norm_powers_and_halves_array(block_embeds)
# P'(S(x)) for all x in block_embeds
return np.float32(np.concatenate((block_embeds, norm_powers, halves_array), axis=1))
def alsh_query_preprocess_fn(self, query_embeds):
max_norm = max(np.linalg.norm(query_embeds, axis=1))
if max_norm > 1:
query_embeds = self.u / max_norm * query_embeds
norm_powers, halves_array = self.get_norm_powers_and_halves_array(query_embeds)
# Q'(S(x)) for all x in query_embeds
return np.float32(np.concatenate((query_embeds, halves_array, norm_powers), axis=1))
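# Minimal usage sketch for the index above (illustrative sizes; embed_size=128 and
# top_k=5 are assumptions, not values from this file):
#   block_data = BlockData.load_from_file(args.block_data_path)
#   mips_index = FaissMIPSIndex(index_type='flat_ip', embed_size=128, use_gpu=False)
#   mips_index.add_block_embed_data(block_data, clear_block_data=True)
#   query_embeds = torch.randn(4, 128)  # [num_queries x embed_size] torch tensor
#   distances, block_indices = mips_index.search_mips_index(query_embeds, top_k=5, reconstruct=False)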
# This was the original hashing scheme, not used anymore
class RandProjectionLSHIndex(object):
"""Class for holding hashed data"""
def __init__(self, embed_size, num_buckets, whiten=True, seed=0):
np.random.seed(seed)
self.hash_data = defaultdict(list)
hash_matrix = 2 * np.random.rand(embed_size, int(num_buckets / 2)) - 1
self.hash_matrix = hash_matrix / np.linalg.norm(hash_matrix, axis=0).reshape(1, -1)
self.embed_mean = None
self.embed_whitener = None
self.whiten = whiten
def state(self):
state = {
'hash_data': self.hash_data,
'hash_matrix': self.hash_matrix,
'embed_mean': self.embed_mean,
'embed_whitener': self.embed_whitener,
}
return state
def save_to_file(self):
args = get_args()
with open(args.block_index_path, 'wb') as index_file:
pickle.dump(self.state(), index_file)
@classmethod
def load_from_file(cls, fname):
print(" > Unpickling block hash data")
state_dict = pickle.load(open(fname, 'rb'))
print(" > Finished unpickling")
hash_matrix = state_dict['hash_matrix']
new_index = cls(hash_matrix.shape[0], hash_matrix.shape[1] * 2)
new_index.hash_data = state_dict['hash_data']
new_index.embed_mean = state_dict.get('embed_mean')
new_index.embed_whitener = state_dict.get('embed_whitener')
new_index.hash_matrix = hash_matrix
return new_index
def get_block_bucket(self, hash):
return self.hash_data[hash]
def hash_embeds(self, embeds, write_block_data=None):
"""Hash a tensor of embeddings using a random projection matrix"""
embed_scores_pos = torch.matmul(embeds, torch.cuda.FloatTensor(self.hash_matrix).type(embeds.dtype))
embed_scores = torch.cat((embed_scores_pos, -embed_scores_pos), axis=1)
embed_hashes = detach(torch.argmax(embed_scores, axis=1))
if write_block_data is not None:
for hash, indices in zip(embed_hashes, write_block_data):
self.hash_data[hash].append(indices)
return embed_hashes
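# Note: with hash_matrix of shape [embed_size, num_buckets / 2], concatenating the
# projection scores with their negations gives num_buckets scores per embedding, so the
# argmax assigns each embedding to the bucket whose signed projection direction it is
# most aligned with.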
def hash_whitened_block_embeds(self, block_data):
"""Transform all block embeds to have zero mean and unit covariance
when treated as samples from a distribution"""
block_idx, all_embeds = zip(*block_data.embed_data.items())
arr_embeds = np.transpose(np.array(all_embeds))
mean = np.mean(arr_embeds, axis=1).reshape(-1, 1)
centered = arr_embeds - mean
inv_cov = np.linalg.inv(np.cov(arr_embeds))
whitener = np.transpose(np.linalg.cholesky(inv_cov))
whitened = np.float16(np.transpose(whitener.dot(centered)))
self.embed_mean = mean.reshape(-1)
self.embed_whitener = whitener
self.hash_data = defaultdict(list)
batch_size = 16384
i = 0
args = get_args()
with torch.no_grad():
while True:
if args.debug:
print(i, flush=True)
batch_slice = slice(i * batch_size, (i + 1) * batch_size)
batch_embed = torch.cuda.HalfTensor(whitened[batch_slice])
batch_meta = [block_data.meta_data[idx] for idx in block_idx[batch_slice]]
if len(batch_meta) == 0:
break
self.hash_embeds(batch_embed, batch_meta)
i += 1
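# Note on the whitening above: with Sigma the sample covariance of the embeddings and
# Sigma^{-1} = L L^T its Cholesky factorization, the whitener W = L^T satisfies
# Cov(W (x - mean)) = W Sigma W^T = L^T (L^{-T} L^{-1}) L = I, so the whitened
# embeddings have (approximately) zero mean and identity covariance.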
def exact_mips_equals(self, query_embeds, all_block_data, norm_blocks):
"""For each query, determine whether the mips block is in the correct hash bucket"""
shuffled_block_idx, block_embeds = zip(*all_block_data.items())
if norm_blocks:
block_embeds = block_embeds / np.linalg.norm(block_embeds, axis=1).reshape(-1, 1)
with torch.no_grad():
query_hashes = self.hash_embeds(query_embeds)
# [num_query x num_blocks]
inner_products = torch.matmul(torch.cuda.HalfTensor(query_embeds),
torch.cuda.HalfTensor(np.transpose(np.array(block_embeds))))
max_inner_product_idxes = detach(torch.argmax(inner_products, axis=1))
best_blocks = np.array([all_block_data[shuffled_block_idx[idx]] for idx in max_inner_product_idxes])
best_block_hashes = self.hash_embeds(best_blocks)
print('Query hashes: ', query_hashes)
print('Block hashes: ', best_block_hashes)
equal_arr = np.equal(query_hashes, best_block_hashes).astype(int)
# array of zeros and ones which can be used for counting success
return equal_arr
def exact_mips_test(self, num_queries, all_block_data, norm_blocks):
if self.whiten:
if self.embed_mean is None:
self.hash_whitened_block_embeds(all_block_data)
embed_size = self.hash_matrix.shape[0]
query_embeds = np.random.multivariate_normal(np.zeros(embed_size), np.eye(embed_size), num_queries)
query_embeds = query_embeds / np.linalg.norm(query_embeds, axis=1).reshape(-1, 1)
else:
block_idx, all_embeds = zip(*all_block_data.items())
arr_embeds = np.transpose(np.array(all_embeds))
mean = np.mean(arr_embeds, axis=1).reshape(-1, 1)
cov = np.cov(arr_embeds)
query_embeds = np.random.multivariate_normal(mean, cov, num_queries)
equal_arr = self.exact_mips_equals(query_embeds, all_block_data, norm_blocks)
print("Num correct: ", sum(equal_arr), " Fraction correct: ", sum(equal_arr) / equal_arr.size)
print(equal_arr)
@@ -19,7 +19,7 @@ import math
import torch
from .samplers import DistributedBatchSampler
from .datasets import json_dataset, csv_dataset, split_ds, ConcatDataset, SplitDataset, bert_sentencepair_dataset, GPT2Dataset, InverseClozeDataset
from .datasets import json_dataset, csv_dataset, split_ds, ConcatDataset, SplitDataset, bert_sentencepair_dataset, GPT2Dataset
from .lazy_loader import exists_lazy, make_lazy, lazy_array_loader
from .tokenization import Tokenization, CommandToken, Tokenizer, CharacterLevelTokenizer, BertWordPieceTokenizer, GPT2BPETokenizer, make_tokenizer
from . import corpora
@@ -126,10 +126,7 @@ def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=N
ds = split_ds(ds, split)
if 'bert' in ds_type.lower():
presplit_sentences = kwargs['presplit_sentences'] if 'presplit_sentences' in kwargs else False
if 'ict' in ds_type.lower():
dstype = InverseClozeDataset
else:
dstype = bert_sentencepair_dataset
dstype = bert_sentencepair_dataset
ds = [dstype(d, max_seq_len=seq_length, presplit_sentences=presplit_sentences)
if d is not None else None for d in ds]
elif ds_type.lower() == 'gpt2':
@@ -137,10 +134,7 @@ def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=N
else:
if 'bert' in ds_type.lower():
presplit_sentences = kwargs['presplit_sentences'] if 'presplit_sentences' in kwargs else False
if 'ict' in ds_type.lower():
dstype = InverseClozeDataset
else:
dstype = bert_sentencepair_dataset
dstype = bert_sentencepair_dataset
ds = dstype(ds, max_seq_len=seq_length, presplit_sentences=presplit_sentences)
elif ds_type.lower() == 'gpt2':
ds = GPT2Dataset(ds, max_seq_len=seq_length)
@@ -46,9 +46,11 @@ class DataConfig:
def make_data_loader(dataset, batch_size, args):
if args.shuffle:
shuffle = args.shuffle
if shuffle:
sampler = data_utils.samplers.RandomSampler(
dataset, replacement=True, num_samples=batch_size*args.train_iters)
dataset, replacement=True, num_samples=batch_size * args.train_iters)
else:
sampler = torch.utils.data.SequentialSampler(dataset)
world_size = torch.distributed.get_world_size(
@@ -18,7 +18,6 @@ import os
import time
from operator import itemgetter
from bisect import bisect_right
import itertools
import json
import csv
import math
@@ -337,6 +336,7 @@ class json_dataset(data.Dataset):
all_strs (list): list of all strings from the dataset
all_labels (list): list of all labels from the dataset (if they have it)
"""
def __init__(self, path, tokenizer=None, preprocess_fn=None, binarize_sent=False,
text_key='sentence', label_key='label', loose_json=False, **kwargs):
self.is_lazy = False
@@ -354,6 +354,9 @@ class json_dataset(data.Dataset):
self.X.append(s)
self.Y.append(j[label_key])
if binarize_sent:
self.Y = binarize_labels(self.Y, hard=binarize_sent)
def SetTokenizer(self, tokenizer):
if tokenizer is None:
self.using_tokenizer = False
@@ -642,8 +645,10 @@ class bert_sentencepair_dataset(data.Dataset):
np_rng = np.random.RandomState(seed=[rng.randint(0, 2**32 - 1) for _ in range(16)])
# get seq length
target_seq_length = self.max_seq_len
short_seq = False
if rng.random() < self.short_seq_prob:
target_seq_length = rng.randint(2, target_seq_length)
short_seq = True
# get sentence pair and label
is_random_next = None
@@ -817,7 +822,7 @@ class bert_sentencepair_dataset(data.Dataset):
def mask_token(self, idx, tokens, types, vocab_words, rng):
"""
helper function to mask `idx` token from `tokens` according to
section 3.1.1 of https://arxiv.org/pdf/1810.04805.pdf
section 3.3.1 of https://arxiv.org/pdf/1810.04805.pdf
"""
label = tokens[idx]
if rng.random() < 0.8:
@@ -876,185 +881,3 @@ class bert_sentencepair_dataset(data.Dataset):
mask_labels[idx] = label
return (output_tokens, output_types), mask, mask_labels, pad_mask
class InverseClozeDataset(data.Dataset):
"""
Dataset containing sentences and various 'blocks' for an inverse cloze task.
Arguments:
ds (Dataset or array-like): data corpus to use for training
max_seq_len (int): maximum sequence length to use for an input sentence
short_seq_prob (float): Proportion of input sentences purposefully shorter than max_seq_len
dataset_size (int): number of input sentences in the dataset.
"""
def __init__(self,
ds,
max_seq_len=512,
short_seq_prob=.01,
dataset_size=None,
presplit_sentences=False,
weighted=True,
**kwargs):
self.ds = ds
self.ds_len = len(self.ds)
self.tokenizer = self.ds.GetTokenizer()
self.vocab_words = list(self.tokenizer.text_token_vocab.values())
self.ds.SetTokenizer(None)
self.max_seq_len = max_seq_len
self.short_seq_prob = short_seq_prob
self.dataset_size = dataset_size
if self.dataset_size is None:
# TODO: this default is only a rough placeholder; it does not equal the true number of (sentence, context) pairs
self.dataset_size = self.ds_len * (self.ds_len-1)
self.presplit_sentences = presplit_sentences
if not self.presplit_sentences:
nltk.download('punkt', download_dir="./nltk")
self.weighted = weighted
if self.weighted:
if hasattr(self.ds, 'is_lazy') and self.ds.is_lazy:
lens = np.array(self.ds.lens)
else:
lens = np.array([len(d['text']) if isinstance(d, dict) else len(d) for d in self.ds])
self.total_len = np.sum(lens)
self.weighting = list(accumulate(lens))
else:
self.weighting = None
def get_weighted_samples(self, np_rng):
if self.weighting is not None:
idx = np_rng.randint(self.total_len)
return bisect_right(self.weighting, idx)
else:
return np_rng.randint(self.ds_len)  # np randint's upper bound is exclusive, so use ds_len to include the last document
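# Illustrative sketch (editor's addition, values are hypothetical): how the weighted branch behaves.
# Document lengths [3, 5, 2] give weighting = [3, 8, 10]; a uniform draw over the 10 total
# characters is mapped back to a document index with bisect_right, so longer documents are
# proportionally more likely to be picked:
#   bisect_right([3, 8, 10], 2) -> 0   (draws 0-2 hit document 0)
#   bisect_right([3, 8, 10], 7) -> 1   (draws 3-7 hit document 1)
#   bisect_right([3, 8, 10], 9) -> 2   (draws 8-9 hit document 2)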
def __len__(self):
return self.dataset_size
def __getitem__(self, idx):
# get rng state corresponding to index (allows deterministic random pair)
rng = random.Random(idx + 1000)
np_rng = np.random.RandomState(seed=[rng.randint(0, 2**32-1) for _ in range(16)])
# get seq length. Save 2 tokens for beginning and end
target_seq_length = self.max_seq_len - 2
if rng.random() < self.short_seq_prob:
target_seq_length = rng.randint(5, target_seq_length)
input_data, context_data = self.get_input_and_context(target_seq_length, rng, np_rng)
input_tokens, input_token_types, input_pad_mask = input_data
context_tokens, context_token_types, context_pad_mask = context_data
sample = {
'input_text': np.array(input_tokens),
'query_types': np.array(input_token_types),
'input_pad_mask': np.array(input_pad_mask),
'context_text': np.array(context_tokens),
'block_types': np.array(context_token_types),
'context_pad_mask': np.array(context_pad_mask)
}
return sample
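# Added commentary (not part of this commit): after concat_and_pad_tokens every array in
# `sample` has length self.max_seq_len; 'input_text'/'query_types' hold the query sentence
# and its token type ids, 'context_text'/'block_types' hold the evidence block, and the
# *_pad_mask arrays mark padded positions with 1 and real tokens with 0.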
def get_sentence_split_doc(self, idx):
"""fetch document at index idx and split into sentences"""
document = self.ds[idx]
if isinstance(document, dict):
document = document['text']
lines = document.split('\n')
if self.presplit_sentences:
return [line for line in lines if line]
rtn = []
for line in lines:
if line != '':
rtn.extend(tokenize.sent_tokenize(line))
return rtn
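# Illustrative example (editor's addition, text is hypothetical): with presplit_sentences=False,
# a document like "Hello there. General Kenobi!\nYou are bold." is split on newlines and then
# sentence-tokenized with nltk, yielding ['Hello there.', 'General Kenobi!', 'You are bold.'];
# with presplit_sentences=True each non-empty line is already treated as one sentence.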
def sentence_tokenize(self, sent, sentence_num=0, beginning=False, ending=False):
"""tokenize sentence and get token types"""
tokens = self.tokenizer.EncodeAsIds(sent).tokenization
str_type = 'str' + str(sentence_num)
token_types = [self.tokenizer.get_type(str_type).Id]*len(tokens)
return tokens, token_types
def get_input_and_context(self, target_seq_length, rng, np_rng):
"""fetches a sentence and its surrounding context"""
num_tries = 0
while num_tries < 20:
num_tries += 1
doc = None
while doc is None:
doc_idx = self.get_weighted_samples(np_rng)
# doc is a list of sentences
doc = self.get_sentence_split_doc(doc_idx)
if not doc:
doc = None
# set up and tokenize the entire selected document
num_sentences = len(doc)
padless_max_len = self.max_seq_len - 2
# select a random sentence from the document as input
# TODO: consider adding multiple input sentences.
input_sentence_idx = rng.randint(0, num_sentences - 1)
tokens, token_types = self.sentence_tokenize(doc[input_sentence_idx], 0)
input_tokens, input_token_types = tokens[:target_seq_length], token_types[:target_seq_length]
if not input_tokens:
continue
context_tokens, context_token_types = [], []
# 10% of the time, the input sentence is left in the context.
# The other 90% of the time, remove it.
if rng.random() < 0.1:
context_tokens = input_tokens.copy()
context_token_types = input_token_types.copy()
# parameters for examining sentences to add to the context
view_preceding = True
view_radius = 1
while len(context_tokens) < padless_max_len:
# keep adding sentences while the context can accommodate more.
if view_preceding:
examine_idx = input_sentence_idx - view_radius
if examine_idx >= 0:
new_tokens, new_token_types = self.sentence_tokenize(doc[examine_idx], 0)
context_tokens = new_tokens + context_tokens
context_token_types = new_token_types + context_token_types
else:
examine_idx = input_sentence_idx + view_radius
if examine_idx < num_sentences:
new_tokens, new_token_types = self.sentence_tokenize(doc[examine_idx], 0)
context_tokens += new_tokens
context_token_types += new_token_types
view_radius += 1
view_preceding = not view_preceding
if view_radius > num_sentences:
break
# assemble the tokens and token types of the context
context_tokens = context_tokens[:padless_max_len]
context_token_types = context_token_types[:padless_max_len]
if not context_tokens:
continue
# concatenate 'CLS' and 'SEP' tokens and add extra token types
input_tokens, input_token_types, input_pad_mask = self.concat_and_pad_tokens(
input_tokens, input_token_types)
context_tokens, context_token_types, context_pad_mask = self.concat_and_pad_tokens(
context_tokens, context_token_types)
return (input_tokens, input_token_types, input_pad_mask), \
(context_tokens, context_token_types, context_pad_mask)
else:
raise RuntimeError("Could not get a valid data point from InverseClozeDataset")
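# Added commentary (not part of this commit): the loop above grows the context outward from the
# chosen sentence, alternating between the preceding and following side with an expanding radius.
# For sentences [s0, s1, s2, s3, s4] and input_sentence_idx == 2 it would consider s1, s3, s0, s4
# in that order, stopping once padless_max_len tokens are collected or the document is exhausted.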
def concat_and_pad_tokens(self, tokens, token_types):
"""concat with special tokens and pad sequence to self.max_seq_len"""
tokens = [self.tokenizer.get_command('ENC').Id] + tokens + [self.tokenizer.get_command('sep').Id]
token_types = [token_types[0]] + token_types + [token_types[0]]
assert len(tokens) <= self.max_seq_len
num_pad = max(0, self.max_seq_len - len(tokens))
pad_mask = [0] * len(tokens) + [1] * num_pad
tokens += [self.tokenizer.get_command('pad').Id] * num_pad
token_types += [token_types[0]] * num_pad
return tokens, token_types, pad_mask
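# Worked example (editor's addition, token ids are hypothetical): with max_seq_len == 8 and
# tokens == [11, 12, 13] all of type 4, concat_and_pad_tokens returns
#   tokens      -> [ENC, 11, 12, 13, SEP, PAD, PAD, PAD]
#   token_types -> [4, 4, 4, 4, 4, 4, 4, 4]
#   pad_mask    -> [0, 0, 0, 0, 0, 1, 1, 1]
# i.e. real positions are marked 0 and padding positions are marked 1.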
......@@ -164,14 +164,14 @@ class _Timer:
def start(self):
"""Start the timer."""
assert not self.started_, 'timer has already been started'
# torch.cuda.synchronize()
torch.cuda.synchronize()
self.start_time = time.time()
self.started_ = True
def stop(self):
"""Stop the timer."""
assert self.started_, 'timer is not started'
# torch.cuda.synchronize()
torch.cuda.synchronize()
self.elapsed_ += (time.time() - self.start_time)
self.started_ = False
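# Added commentary (not part of this commit): CUDA kernels launch asynchronously, so without
# torch.cuda.synchronize() the wall-clock reading would mostly capture host-side launch overhead;
# synchronizing before reading time.time() makes the elapsed time reflect completed GPU work.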
......
......@@ -15,7 +15,6 @@
"""Megatron initialization."""
import datetime
import random
import os
......@@ -62,7 +61,8 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
_write_args_to_tensorboard()
def init_distributed():
def _initialize_distributed():
"""Initialize torch.distributed and mpu."""
args = get_args()
device_count = torch.cuda.device_count()
......@@ -102,13 +102,6 @@ def init_distributed():
world_size=args.world_size, rank=args.rank,
init_method=init_method)
def _initialize_distributed():
"""Initialize torch.distributed and mpu."""
init_distributed()
args = get_args()
device_count = torch.cuda.device_count()
# Set the model-parallel / data-parallel communicators.
if device_count > 0:
mpu.initialize_model_parallel(args.model_parallel_size)
......
......@@ -125,17 +125,12 @@ class BertModel(MegatronModule):
scaled_init_method = scaled_init_method_normal(args.init_method_std,
args.num_layers)
max_pos_embeds = None
if not add_binary_head and ict_head_size is None:
max_pos_embeds = 2 * args.seq_length
self.language_model, self._language_model_key = get_language_model(
attention_mask_func=bert_attention_mask_func,
num_tokentypes=num_tokentypes,
add_pooler=add_pooler,
init_method=init_method,
scaled_init_method=scaled_init_method,
max_pos_embeds=max_pos_embeds)
scaled_init_method=scaled_init_method)
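# Added commentary (not part of this commit): with the max_pos_embeds override removed, the
# position-embedding size used by get_language_model always comes from
# args.max_position_embeddings (see the language_model.py hunk further down).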
if not self.add_ict_head:
self.lm_head = BertLMHead(
......
......@@ -56,7 +56,7 @@ class DistributedDataParallel(MegatronModule):
if not no_scale and not reduce_after:
coalesced /= dist.get_world_size(group=self.data_parallel_group)
dist.all_reduce(coalesced, group=self.data_parallel_group)
# torch.cuda.synchronize()
torch.cuda.synchronize()
if not no_scale and reduce_after:
coalesced /= dist.get_world_size(group=self.data_parallel_group)
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
......
......@@ -44,7 +44,7 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
init_method, scaled_init_method, max_pos_embeds=None):
init_method, scaled_init_method):
"""Build language model and return along with the key to save."""
args = get_args()
......@@ -60,8 +60,7 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
init_method=init_method,
output_layer_init_method=scaled_init_method,
num_tokentypes=num_tokentypes,
add_pooler=add_pooler,
max_pos_embeds=max_pos_embeds)
add_pooler=add_pooler)
# key used for checkpoints.
language_model_key = 'language_model'
......@@ -268,8 +267,7 @@ class TransformerLanguageModel(MegatronModule):
init_method,
output_layer_init_method,
num_tokentypes=0,
add_pooler=False,
max_pos_embeds=None):
add_pooler=False):
super(TransformerLanguageModel, self).__init__()
args = get_args()
......@@ -278,11 +276,10 @@ class TransformerLanguageModel(MegatronModule):
self.init_method = init_method
self.add_pooler = add_pooler
max_pos_embeds = args.max_position_embeddings if max_pos_embeds is None else max_pos_embeds
# Embeddings
self.embedding = Embedding(self.hidden_size,
args.padded_vocab_size,
max_pos_embeds,
args.max_position_embeddings,
args.hidden_dropout,
self.init_method,
self.num_tokentypes)
......