Commit 183ad176 authored by Neel Kant

Refactor to build BlockData, FaissMIPSIndex, RandProjectionLSHIndex

parent 0104f910
from collections import defaultdict
import os
import pickle
import shutil
import numpy as np
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
@@ -11,7 +5,8 @@ from megatron import get_args
from megatron import mpu
from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
from megatron.data.bert_dataset import get_indexed_dataset_
from megatron.data.ict_dataset import InverseClozeDataset
from megatron.data.realm_dataset import InverseClozeDataset
from megatron.data.realm_index import BlockData, RandProjectionLSHIndex
from megatron.data.samplers import DistributedBatchSampler
from megatron.initialize import initialize_megatron
from megatron.model import REALMRetriever
@@ -23,253 +18,6 @@ def detach(tensor):
return tensor.detach().cpu().numpy()
class HashedIndex(object):
"""Class for holding hashed data"""
def __init__(self, embed_size, num_buckets, whiten=False, seed=0):
np.random.seed(seed)
self.block_data = defaultdict(list)
self.hash_data = defaultdict(list)
hash_matrix = 2 * np.random.rand(embed_size, int(num_buckets / 2)) - 1
self.hash_matrix = hash_matrix / np.linalg.norm(hash_matrix, axis=0).reshape(1, -1)
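# num_buckets / 2 random directions with unit-norm columns; hash_embeds() below buckets an
# embedding via the argmax over the signed projection scores [scores, -scores]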
self.embed_mean = None
self.embed_whitener = None
self.whiten = whiten
# alsh
self.m = 5
self.u = 0.99
self.max_norm = None
self.block_index = None
def state(self):
state = {
'block_data': self.block_data,
'hash_data': self.hash_data,
'hash_matrix': self.hash_matrix,
'embed_mean': self.embed_mean,
'embed_whitener': self.embed_whitener,
}
return state
def get_block_bucket(self, hash):
return self.hash_data[hash]
def get_block_embed(self, block_idx):
return self.block_data[block_idx]
def hash_embeds(self, embeds, block_data=None):
"""Hash a tensor of embeddings using a random projection matrix"""
embed_scores_pos = torch.matmul(embeds, torch.cuda.FloatTensor(self.hash_matrix))
embed_scores = torch.cat((embed_scores_pos, -embed_scores_pos), axis=1)
embed_hashes = detach(torch.argmax(embed_scores, axis=1))
if block_data is not None:
for hash, indices in zip(embed_hashes, block_data):
self.hash_data[hash].append(indices)
return embed_hashes
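Editorial aside (not part of the commit): a minimal NumPy-only sketch of the random-projection hashing that hash_embeds implements above (same math, CPU only); the sizes and seed are illustrative assumptions.
import numpy as np
embed_size, num_buckets = 128, 32                                     # illustrative sizes
rng = np.random.RandomState(0)
hash_matrix = 2 * rng.rand(embed_size, num_buckets // 2) - 1          # random directions
hash_matrix /= np.linalg.norm(hash_matrix, axis=0, keepdims=True)     # unit-norm columns
embeds = rng.randn(4, embed_size)                                     # e.g. four block embeddings
scores = embeds @ hash_matrix                                         # signed projections
buckets = np.argmax(np.concatenate([scores, -scores], axis=1), axis=1)
# buckets[i] lies in [0, num_buckets): the direction (or its negation) most aligned with embeds[i]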
def assign_block_embeds(self, block_indices, block_embeds, allow_overwrite=False):
"""Assign the embeddings for each block index into a hash map"""
for idx, embed in zip(block_indices, block_embeds):
if not allow_overwrite and int(idx) in self.block_data:
raise ValueError("Attempted to overwrite a read-only HashedIndex")
self.block_data[int(idx)] = np.float16(embed)
def save_shard(self, rank):
dir_name = 'block_hash_data'
if not os.path.isdir(dir_name):
os.mkdir(dir_name)
# save the data for each shard
with open('{}/{}.pkl'.format(dir_name, rank), 'wb') as data_file:
pickle.dump(self.state(), data_file)
def consolidate_shards_and_save(self, ignore_shard=0):
"""Combine all the shards made using self.save_shard()"""
dir_name = 'block_hash_data'
fnames = os.listdir(dir_name)
for fname in fnames:
with open('{}/{}'.format(dir_name, fname), 'rb') as f:
data = pickle.load(f)
assert np.array_equal(data['hash_matrix'], self.hash_matrix)
old_size = len(self.block_data)
shard_size = len(data['block_data'])
self.block_data.update(data['block_data'])
assert (len(self.block_data) == old_size + shard_size) or (str(ignore_shard) in fname)
if not self.whiten:
for bucket, items in data['hash_data'].items():
self.hash_data[bucket].extend(items)
if self.whiten:
self.whiten_block_embeds()
args = get_args()
with open(args.hash_data_path, 'wb') as final_file:
pickle.dump(self.state(), final_file)
shutil.rmtree(dir_name, ignore_errors=True)
def clear(self):
"""Clear the data structures to save memory"""
self.block_data = dict()
self.hash_data = defaultdict(list)
def whiten_block_embeds(self):
"""Transform all block embeds to have zero mean and unit covariance
when treated as samples from a distribution"""
block_idx, all_embeds = zip(*self.block_data.items())
arr_embeds = np.transpose(np.array(all_embeds))
mean = np.mean(arr_embeds, axis=1).reshape(-1, 1)
centered = arr_embeds - mean
inv_cov = np.linalg.inv(np.cov(arr_embeds))
whitener = np.transpose(np.linalg.cholesky(inv_cov))
whitened = np.float16(np.transpose(whitener.dot(centered)))
self.embed_mean = mean.reshape(-1)
self.embed_whitener = whitener
self.block_data = dict(zip(block_idx, list(whitened)))
self.hash_data = defaultdict(list)
batch_size = 16384
i = 0
args = get_args()
with torch.no_grad():
hashing_tensor = torch.cuda.HalfTensor(self.hash_matrix)
while True:
if args.debug:
print(i, flush=True)
batch_slice = slice(i * batch_size, (i + 1) * batch_size)
batch_embed = torch.cuda.HalfTensor(whitened[batch_slice])
batch_block_idx = block_idx[batch_slice]
if len(batch_block_idx) == 0:
break
hash_scores_pos = torch.matmul(batch_embed, hashing_tensor)
embed_scores = torch.cat((hash_scores_pos, -hash_scores_pos), axis=1)
embed_hashes = detach(torch.argmax(embed_scores, axis=1))
for idx, hash in zip(batch_block_idx, list(embed_hashes)):
# [int] instead of [array<int>] since this is just for analysis right now
self.hash_data[hash].append(idx)
i += 1
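Editorial aside (not part of the commit): a small NumPy check of the whitening step above; the whitener is the transposed Cholesky factor of the inverse covariance, so the whitened embeddings end up with zero mean and identity covariance. The data below is synthetic and the sizes are illustrative.
import numpy as np
rng = np.random.RandomState(0)
embeds = rng.randn(128, 10000) * np.linspace(0.5, 2.0, 128)[:, None]        # [embed_size x num_samples]
mean = embeds.mean(axis=1, keepdims=True)
centered = embeds - mean
whitener = np.transpose(np.linalg.cholesky(np.linalg.inv(np.cov(embeds))))  # as in whiten_block_embeds
whitened = whitener.dot(centered)
print(np.allclose(np.cov(whitened), np.eye(128), atol=1e-6))                # True: identity covariance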
def create_block_data_index(self):
import faiss
self.block_idx, block_embeds = zip(*self.block_data.items())
block_embeds = np.array(block_embeds)
alsh_preprocessed_blocks = self.alsh_block_preprocess_fn()
index = faiss.IndexFlatL2(alsh_preprocessed_blocks.shape[1])
index.add(alsh_preprocessed_blocks)
print('Total blocks in index: ', index.ntotal)
self.block_index = index
def get_norm_powers_and_halves_array(self, embeds):
norm = np.linalg.norm(embeds, axis=1)
norm_powers = [np.multiply(norm, norm)] # squared L2 norms of all
for i in range(self.m - 1):
norm_powers.append(np.multiply(norm_powers[-1], norm_powers[-1]))
# [num_blocks x self.m]
norm_powers = np.transpose(np.array(norm_powers))
halves_array = 0.5 * np.ones(norm_powers.shape)
return norm_powers, halves_array
def alsh_block_preprocess_fn(self):
block_idx, block_embeds = zip(*self.block_data.items())
block_embeds = np.array(block_embeds)
if self.max_norm is None:
self.max_norm = max(np.linalg.norm(block_embeds, axis=1))
if self.max_norm > 1:
block_embeds = self.u / self.max_norm * block_embeds
norm_powers, halves_array = self.get_norm_powers_and_halves_array(block_embeds)
# P'(S(x)) for all x in block_embeds
return np.float32(np.concatenate((block_embeds, norm_powers, halves_array), axis=1))
def alsh_query_preprocess_fn(self, query_embeds):
max_norm = max(np.linalg.norm(query_embeds, axis=1))
if max_norm > 1:
query_embeds = self.u / max_norm * query_embeds
norm_powers, halves_array = self.get_norm_powers_and_halves_array(query_embeds)
# Q'(S(x)) for all x in query_embeds
return np.float32(np.concatenate((query_embeds, halves_array, norm_powers), axis=1))
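Editorial aside (not part of the commit): the two preprocessing functions above appear to follow the asymmetric LSH construction of Shrivastava & Li for MIPS, mapping blocks to P'(S(x)) = [x; ||x||^2, ..., ||x||^(2^m); 1/2, ..., 1/2] and queries to Q'(S(q)) = [q; 1/2, ..., 1/2; ||q||^2, ..., ||q||^(2^m)], after both sides are rescaled (by u / max_norm when needed) so that all norms stay at most 1. The squared L2 distance then satisfies ||Q' - P'||^2 = m/2 + ||x||^(2^(m+1)) + ||q||^(2^(m+1)) - 2<q, x>, so as the residual norm terms shrink, an L2 nearest-neighbor search over the preprocessed vectors (e.g. the faiss IndexFlatL2 built in create_block_data_index) approximates maximum inner product search. A NumPy check of that identity on synthetic data:
import numpy as np
m, u = 5, 0.99                                          # same constants as self.m, self.u above
rng = np.random.RandomState(0)
blocks, query = rng.randn(1000, 16), rng.randn(1, 16)
def scale(x):
    max_norm = np.linalg.norm(x, axis=1).max()
    return x * (u / max_norm) if max_norm > 1 else x
def powers_and_halves(x):
    powers = [np.einsum('ij,ij->i', x, x)]              # squared norms
    for _ in range(m - 1):
        powers.append(powers[-1] ** 2)                  # ||x||^2, ||x||^4, ..., ||x||^(2^m)
    powers = np.stack(powers, axis=1)
    return powers, np.full_like(powers, 0.5)
b, q = scale(blocks), scale(query)
(bp, bh), (qp, qh) = powers_and_halves(b), powers_and_halves(q)
P = np.concatenate([b, bp, bh], axis=1)                 # mirrors alsh_block_preprocess_fn
Q = np.concatenate([q, qh, qp], axis=1)                 # mirrors alsh_query_preprocess_fn
dist = ((P - Q) ** 2).sum(axis=1)
resid = np.linalg.norm(b, axis=1) ** (2 ** (m + 1)) + np.linalg.norm(q) ** (2 ** (m + 1))
print(np.allclose(dist, m / 2 + resid - 2 * (b @ q.ravel())))   # True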
def exact_mips_equals(self, query_embeds, norm_blocks):
"""For each query, determine whether the mips block is in the correct hash bucket"""
shuffled_block_idx, block_embeds = zip(*self.block_data.items())
if norm_blocks:
block_embeds = block_embeds / np.linalg.norm(block_embeds, axis=1).reshape(-1, 1)
with torch.no_grad():
# get hashes for the queries
hash_scores_pos = torch.matmul(torch.cuda.HalfTensor(query_embeds), torch.cuda.HalfTensor(self.hash_matrix))
hash_scores = torch.cat((hash_scores_pos, -hash_scores_pos), axis=1)
query_hashes = detach(torch.argmax(hash_scores, axis=1))
# [num_query x num_blocks]
inner_products = torch.matmul(torch.cuda.HalfTensor(query_embeds),
torch.cuda.HalfTensor(np.transpose(np.array(block_embeds))))
max_inner_product_idxes = detach(torch.argmax(inner_products, axis=1))
best_blocks = [self.block_data[shuffled_block_idx[idx]] for idx in max_inner_product_idxes]
best_blocks_tensor = torch.cuda.HalfTensor(np.array(best_blocks))
# bb = best_blocks
bb_hash_scores_pos = torch.matmul(best_blocks_tensor, torch.cuda.HalfTensor(self.hash_matrix))
bb_hash_scores = torch.cat((bb_hash_scores_pos, -bb_hash_scores_pos), axis=1)
best_block_hashes = detach(torch.argmax(bb_hash_scores, axis=1))
print('Query hashes: ', query_hashes)
print('Block hashes: ', best_block_hashes)
equal_arr = np.equal(query_hashes, best_block_hashes).astype(int)
# array of zeros and ones which can be used for counting success
return equal_arr
def exact_mips_test(self, num_queries, whitened, norm_blocks, alsh):
if whitened:
if self.embed_mean is None:
self.whiten_block_embeds()
query_embeds = np.random.multivariate_normal(np.zeros(128), np.eye(128), num_queries)
query_embeds = query_embeds / np.linalg.norm(query_embeds, axis=1).reshape(-1, 1)
if alsh:
if self.block_index is None:
self.create_block_data_index()
alsh_queries = self.alsh_query_preprocess_fn(query_embeds)
neighbor_ids, distances = self.block_index.search(alsh_queries, 5)
print('DONE')
return
else:
block_idx, all_embeds = zip(*self.block_data.items())
arr_embeds = np.transpose(np.array(all_embeds))
mean = np.mean(arr_embeds, axis=1).reshape(-1, 1)
cov = np.cov(arr_embeds)
query_embeds = np.random.multivariate_normal(mean, cov, num_queries)
equal_arr = self.exact_mips_equals(query_embeds, norm_blocks)
print("Num correct: ", sum(equal_arr), " Fraction correct: ", sum(equal_arr) / equal_arr.size)
print(equal_arr)
@classmethod
def load_from_file(cls, fname):
print(" > Unpickling block hash data")
state_dict = pickle.load(open(fname, 'rb'))
print(" > Finished unpickling")
hash_matrix = state_dict['hash_matrix']
new_index = HashedIndex(hash_matrix.shape[0], hash_matrix.shape[1] * 2)
new_index.block_data = state_dict['block_data']
new_index.hash_data = state_dict['hash_data']
new_index.embed_mean = state_dict.get('embed_mean')
new_index.embed_whitener = state_dict.get('embed_whitener')
new_index.hash_matrix = hash_matrix
return new_index
def test_retriever():
initialize_megatron(extra_args_provider=None,
args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
@@ -317,42 +65,42 @@ def main():
model.eval()
dataset = get_ict_dataset()
data_iter = iter(get_one_epoch_dataloader(dataset))
hashed_index = HashedIndex(embed_size=128, num_buckets=32, whiten=True)
all_block_data = BlockData()
hashed_index = RandProjectionLSHIndex(embed_size=128, num_buckets=32, whiten=True)
i = 1
total = 0
whiten = False
while True:
try:
query_tokens, query_pad_mask, \
block_tokens, block_pad_mask, block_indices = get_batch(data_iter)
block_tokens, block_pad_mask, block_index_data = get_batch(data_iter)
except:
break
block_indices = detach(block_indices)
block_logits = model(None, None, block_tokens, block_pad_mask, only_block=True)
block_index_data = detach(block_index_data)
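# rows of block_index_data are [start_idx, end_idx, doc_idx, block_idx],
# i.e. the 'block_data' array built in InverseClozeDataset.__getitem__ below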
block_indices = block_index_data[:, 3]
block_meta = block_index_data[:, :3]
# If whitened, then hashing needs to be done after whitening the block embeds
# which is done in consolidate_shards_and_save()
if not whiten:
hashed_index.hash_embeds(block_logits, block_indices)
hashed_index.assign_block_embeds(block_indices[:, 3], detach(block_logits))
block_logits = model(None, None, block_tokens, block_pad_mask, only_block=True)
all_block_data.add_block_data(block_indices, block_logits, block_meta)
total += block_indices.shape[0]
total += block_indices.size
i += 1
if i % 20 == 0:
print('Batch {:10d} | Total {:10d}'.format(i, total), flush=True)
if args.debug:
break
hashed_index.save_shard(args.rank)
all_block_data.save_shard(args.rank)
torch.distributed.barrier()
del model
if args.rank == 0:
hashed_index.consolidate_shards_and_save()
all_block_data.consolidate_shards_and_save()
hashed_index.hash_whitened_block_embeds(all_block_data)
hashed_index.save_to_file()
else:
hashed_index.clear()
all_block_data.clear()
def load_ict_checkpoint(only_query_model=False, only_block_model=False, no_grad=False):
......
import itertools
import os
import random
import time
import numpy as np
import spacy
import torch
from torch.utils.data import Dataset
from megatron import get_tokenizer
from megatron.data.bert_dataset import BertDataset, get_samples_mapping_
from megatron import get_tokenizer, print_rank_0, mpu
from megatron.data.bert_dataset import BertDataset
from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy
#qa_nlp = spacy.load('en_core_web_lg')
qa_nlp = None
qa_nlp = spacy.load('en_core_web_lg')
class RealmDataset(BertDataset):
"""Dataset containing simple masked sentences for masked language modeling.
@@ -74,3 +79,170 @@ def spacy_ner(block_text):
answers.append(str(ent.text))
candidates['starts'] = starts
candidates['answers'] = answers
class InverseClozeDataset(Dataset):
"""Dataset containing sentences and their blocks for an inverse cloze task."""
def __init__(self, name, block_dataset, title_dataset, data_prefix,
num_epochs, max_num_samples, max_seq_length,
short_seq_prob, seed):
self.name = name
self.seed = seed
self.max_seq_length = max_seq_length
self.block_dataset = block_dataset
self.title_dataset = title_dataset
self.short_seq_prob = short_seq_prob
self.rng = random.Random(self.seed)
self.samples_mapping = self.get_samples_mapping(
data_prefix, num_epochs, max_num_samples)
self.tokenizer = get_tokenizer()
self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
self.vocab_id_to_token_list = self.tokenizer.inv_vocab
self.cls_id = self.tokenizer.cls
self.sep_id = self.tokenizer.sep
self.mask_id = self.tokenizer.mask
self.pad_id = self.tokenizer.pad
def __len__(self):
return self.samples_mapping.shape[0]
def __getitem__(self, idx):
start_idx, end_idx, doc_idx, block_idx = self.samples_mapping[idx]
title = list(self.title_dataset[int(doc_idx)])
block = [list(self.block_dataset[i]) for i in range(start_idx, end_idx)]
assert len(block) > 1
# avoid selecting the first or last sentence to be the query.
if len(block) == 2:
rand_sent_idx = int(self.rng.random() > 0.5)
else:
rand_sent_idx = self.rng.randint(1, len(block) - 2)
# keep the query in the context 10% of the time.
if self.rng.random() < 0.1:
query = block[rand_sent_idx].copy()
else:
query = block.pop(rand_sent_idx)
# still need to truncate because a block is only closed once
# its cumulative sentence length has exceeded max_seq_length.
query = query[:self.max_seq_length - 2]
block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]
query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
sample = {
'query_tokens': np.array(query_tokens),
'query_pad_mask': np.array(query_pad_mask),
'block_tokens': np.array(block_tokens),
'block_pad_mask': np.array(block_pad_mask),
'block_data': np.array([start_idx, end_idx, doc_idx, block_idx]).astype(np.int64)
}
return sample
def encode_text(self, text):
return self.tokenizer.tokenize(text)
def decode_tokens(self, token_ids):
tokens = self.tokenizer.tokenizer.convert_ids_to_tokens(token_ids)
return ' '.join(token for token in tokens if token != '[PAD]')
def get_block(self, start_idx, end_idx, doc_idx):
"""Get the IDs for an evidence block plus the title of the corresponding document"""
block = [list(self.block_dataset[i]) for i in range(start_idx, end_idx)]
title = list(self.title_dataset[int(doc_idx)])
block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]
block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
return (block_tokens, block_pad_mask)
def concat_and_pad_tokens(self, tokens, title=None):
"""concat with special tokens and pad sequence to self.max_seq_length"""
tokens = [self.cls_id] + tokens + [self.sep_id]
if title is not None:
tokens += title + [self.sep_id]
assert len(tokens) <= self.max_seq_length, len(tokens)
num_pad = self.max_seq_length - len(tokens)
pad_mask = [1] * len(tokens) + [0] * num_pad
tokens += [self.pad_id] * num_pad
return tokens, pad_mask
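Editorial aside (not part of the commit): a toy run of the concat_and_pad_tokens layout above, with made-up ids (cls=101, sep=102, pad=0) and max_seq_length=10 purely for illustration.
cls_id, sep_id, pad_id, max_len = 101, 102, 0, 10   # illustrative values only
def concat_and_pad(tokens, title=None):
    tokens = [cls_id] + tokens + [sep_id]
    if title is not None:
        tokens += title + [sep_id]
    num_pad = max_len - len(tokens)
    pad_mask = [1] * len(tokens) + [0] * num_pad
    return tokens + [pad_id] * num_pad, pad_mask
print(concat_and_pad([7, 8, 9]))            # query: [CLS] q [SEP] + padding
# ([101, 7, 8, 9, 102, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 0, 0, 0, 0, 0])
print(concat_and_pad([7, 8, 9], [5, 6]))    # block: [CLS] block [SEP] title [SEP] + padding
# ([101, 7, 8, 9, 102, 5, 6, 102, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 0, 0])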
def get_samples_mapping(self, data_prefix, num_epochs, max_num_samples):
if not num_epochs:
if not max_num_samples:
raise ValueError("Need to specify either max_num_samples "
"or num_epochs")
num_epochs = np.iinfo(np.int32).max - 1
if not max_num_samples:
max_num_samples = np.iinfo(np.int64).max - 1
# Filename of the index mapping
indexmap_filename = data_prefix
indexmap_filename += '_{}_indexmap'.format(self.name)
if num_epochs != (np.iinfo(np.int32).max - 1):
indexmap_filename += '_{}ep'.format(num_epochs)
if max_num_samples != (np.iinfo(np.int64).max - 1):
indexmap_filename += '_{}mns'.format(max_num_samples)
indexmap_filename += '_{}msl'.format(self.max_seq_length)
indexmap_filename += '_{}s'.format(self.seed)
indexmap_filename += '.npy'
# Build the indexed mapping if not exist.
if torch.distributed.get_rank() == 0 and \
not os.path.isfile(indexmap_filename):
print(' > WARNING: could not find index map file {}, building '
'the indices on rank 0 ...'.format(indexmap_filename))
# Make sure the types match the helpers input types.
assert self.block_dataset.doc_idx.dtype == np.int64
assert self.block_dataset.sizes.dtype == np.int32
# Build samples mapping
verbose = torch.distributed.get_rank() == 0
start_time = time.time()
print_rank_0(' > building samples index mapping for {} ...'.format(
self.name))
from megatron.data.dataset_utils import compile_helper
compile_helper()
from megatron.data import helpers
samples_mapping = helpers.build_blocks_mapping(
self.block_dataset.doc_idx,
self.block_dataset.sizes,
self.title_dataset.sizes,
num_epochs,
max_num_samples,
self.max_seq_length-3, # account for added tokens
self.seed,
verbose)
print_rank_0(' > done building samples index mapping')
np.save(indexmap_filename, samples_mapping, allow_pickle=True)
print_rank_0(' > saved the index mapping in {}'.format(
indexmap_filename))
# Make sure all the ranks have built the mapping
print_rank_0(' > elapsed time to build and save samples mapping '
'(seconds): {:4f}'.format(
time.time() - start_time))
# This should be a barrier, but the NCCL barrier assumes
# device_index == rank, which does not hold in the
# model-parallel case
counts = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
assert counts[0].item() == torch.distributed.get_world_size(
group=mpu.get_data_parallel_group())
# Load indexed dataset.
print_rank_0(' > loading indexed mapping from {}'.format(
indexmap_filename))
start_time = time.time()
samples_mapping = np.load(indexmap_filename, allow_pickle=True)
print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
time.time() - start_time))
print_rank_0(' total number of samples: {}'.format(
samples_mapping.shape[0]))
return samples_mapping
@@ -18,7 +18,8 @@
import torch
import torch.nn.functional as F
from hashed_index import HashedIndex, load_ict_checkpoint, get_ict_dataset
from hashed_index import load_ict_checkpoint, get_ict_dataset
from megatron.data.realm_dataset import HashedIndex
from megatron import get_args
from megatron import get_timers
from megatron import mpu
......