Commit 56e81e99 authored by Neel Kant

Complete refactor of RandProjectionLSHIndex

parent 642802e0
@@ -6,7 +6,7 @@ from megatron import mpu
 from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
 from megatron.data.bert_dataset import get_indexed_dataset_
 from megatron.data.realm_dataset import InverseClozeDataset
-from megatron.data.realm_index import BlockData, RandProjectionLSHIndex
+from megatron.data.realm_index import detach, BlockData, RandProjectionLSHIndex
 from megatron.data.samplers import DistributedBatchSampler
 from megatron.initialize import initialize_megatron
 from megatron.model import REALMRetriever
@@ -14,10 +14,6 @@ from megatron.training import get_model
 from pretrain_bert_ict import get_batch, model_provider
 
-def detach(tensor):
-    return tensor.detach().cpu().numpy()
-
 def test_retriever():
     initialize_megatron(extra_args_provider=None,
                         args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
@@ -71,26 +67,27 @@ def main():
     i = 1
     total = 0
     while True:
-        try:
-            query_tokens, query_pad_mask, \
-            block_tokens, block_pad_mask, block_index_data = get_batch(data_iter)
-        except:
-            break
+        with torch.no_grad():
+            try:
+                query_tokens, query_pad_mask, \
+                block_tokens, block_pad_mask, block_index_data = get_batch(data_iter)
+            except:
+                break
 
-        block_index_data = detach(block_index_data)
-        block_indices = block_index_data[:, 3]
-        block_meta = block_index_data[:, :3]
-        block_logits = model(None, None, block_tokens, block_pad_mask, only_block=True)
-        all_block_data.add_block_data(block_indices, block_logits, block_meta)
+            block_index_data = detach(block_index_data)
+            block_indices = block_index_data[:, 3]
+            block_meta = block_index_data[:, :3]
+            block_logits = detach(model(None, None, block_tokens, block_pad_mask, only_block=True))
+            all_block_data.add_block_data(block_indices, block_logits, block_meta)
 
-        total += block_indices.size
-        i += 1
-        if i % 20 == 0:
-            print('Batch {:10d} | Total {:10d}'.format(i, total), flush=True)
-        if args.debug:
-            break
+            total += block_indices.size
+            i += 1
+            if i % 20 == 0:
+                print('Batch {:10d} | Total {:10d}'.format(i, total), flush=True)
+            if args.debug:
+                break
 
     all_block_data.save_shard(args.rank)
     torch.distributed.barrier()
     del model
...
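The hunk above wraps the whole indexing step in torch.no_grad() and detaches block_logits before caching it, so the stored embeddings no longer keep each batch's autograd graph (and its GPU activations) alive. A minimal standalone sketch of that pattern; the toy model and shapes here are illustrative assumptions, not part of this commit:

import torch

def detach(tensor):
    # same helper this commit moves into megatron.data.realm_index
    return tensor.detach().cpu().numpy()

model = torch.nn.Linear(8, 4)      # stand-in for the block encoder
cache = []
with torch.no_grad():              # no graph is built during inference
    for _ in range(3):
        logits = model(torch.randn(2, 8))
        cache.append(detach(logits))  # plain numpy array; no tensors or graphs retained
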
@@ -24,11 +24,9 @@ from torch.utils.data import Dataset
 from megatron import get_tokenizer
 from megatron import mpu
-from megatron.data.dataset_utils import build_training_sample
 from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
 from megatron import print_rank_0
 
-DATASET_TYPES = ['standard_bert', 'ict', 'realm']
 
 class BertDataset(Dataset):
@@ -64,6 +62,7 @@ class BertDataset(Dataset):
         self.sep_id = tokenizer.sep
         self.mask_id = tokenizer.mask
         self.pad_id = tokenizer.pad
+        from megatron.data.dataset_utils import build_training_sample
         self.build_sample_fn = build_training_sample
 
     def __len__(self):
...
@@ -23,9 +23,9 @@ import itertools
 import numpy as np
 
 from megatron import print_rank_0
-from megatron.data.bert_dataset import DATASET_TYPES, get_indexed_dataset_, get_train_valid_test_split_, BertDataset
-from megatron.data.realm_dataset import InverseClozeDataset
+from megatron.data.bert_dataset import get_indexed_dataset_, get_train_valid_test_split_, BertDataset
+
+DATASET_TYPES = ['standard_bert', 'ict', 'realm']
 
 def compile_helper():
     """Compile helper function at runtime. Make sure this
@@ -454,6 +454,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
     print_split_stats('test', 2)
 
     def build_dataset(index, name):
+        from megatron.data.realm_dataset import InverseClozeDataset
         from megatron.data.realm_dataset import RealmDataset
         dataset = None
         if splits[index + 1] > splits[index]:
@@ -502,4 +503,4 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
     valid_dataset = build_dataset(1, 'valid')
     test_dataset = build_dataset(2, 'test')
 
-    return (train_dataset, valid_dataset, test_dataset)
\ No newline at end of file
+    return (train_dataset, valid_dataset, test_dataset)
@@ -12,7 +12,7 @@ from megatron import get_tokenizer, print_rank_0, mpu
 from megatron.data.bert_dataset import BertDataset
 from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy
 
-qa_nlp = spacy.load('en_core_web_lg')
+#qa_nlp = spacy.load('en_core_web_lg')
 
 class RealmDataset(BertDataset):
...
@@ -3,13 +3,16 @@ import os
 import pickle
 import shutil
 
-from hashed_index import detach
 import numpy as np
 import torch
 
 from megatron import get_args
 
+
+def detach(tensor):
+    return tensor.detach().cpu().numpy()
+
 
 class BlockData(object):
     def __init__(self):
         self.embed_data = dict()
@@ -43,7 +46,7 @@ class BlockData(object):
         if not allow_overwrite and idx in self.embed_data:
             raise ValueError("Unexpectedly tried to overwrite block data")
 
-        self.embed_data[idx] = embed
+        self.embed_data[idx] = np.float16(embed)
         self.meta_data[idx] = meta
 
     def save_shard(self, rank):
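The change above stores cached block embeddings as np.float16, halving the memory footprint of each shard; the matching HalfTensor cast in hash_embeds below keeps the matmul dtype-consistent. A quick illustrative check of the saving, with an assumed embedding size of 128:

import numpy as np

embed = np.random.randn(128).astype(np.float32)  # assumed embedding size
half = np.float16(embed)                         # same cast as add_block_data
print(embed.nbytes, half.nbytes)                 # 512 vs 256 bytes per block
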
@@ -213,7 +216,7 @@ class RandProjectionLSHIndex(object):
     def hash_embeds(self, embeds, write_block_data=None):
         """Hash a tensor of embeddings using a random projection matrix"""
-        embed_scores_pos = torch.matmul(embeds, torch.cuda.FloatTensor(self.hash_matrix))
+        embed_scores_pos = torch.matmul(embeds, torch.cuda.HalfTensor(self.hash_matrix))
         embed_scores = torch.cat((embed_scores_pos, -embed_scores_pos), axis=1)
         embed_hashes = detach(torch.argmax(embed_scores, axis=1))
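For context, hash_embeds implements signed random-projection LSH: scoring each embedding against k random directions, concatenating the scores with their negations, and taking the argmax assigns the embedding to whichever of the 2k signed directions it is most aligned with. A numpy sketch of the same bucketing scheme; the dimensions and names are illustrative assumptions:

import numpy as np

rng = np.random.default_rng(0)
embed_dim, num_proj = 128, 16
hash_matrix = rng.standard_normal((embed_dim, num_proj))

def hash_embeds(embeds):
    scores_pos = embeds @ hash_matrix                           # (n, num_proj)
    scores = np.concatenate([scores_pos, -scores_pos], axis=1)  # (n, 2 * num_proj)
    return np.argmax(scores, axis=1)                            # bucket id in [0, 2 * num_proj)

buckets = hash_embeds(rng.standard_normal((4, embed_dim)))
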
@@ -226,7 +229,7 @@ class RandProjectionLSHIndex(object):
     def hash_whitened_block_embeds(self, block_data):
         """Transform all block embeds to have zero mean and unit covariance
         when treated as samples from a distribution"""
-        block_idx, all_embeds = zip(block_data.embed_data.items())
+        block_idx, all_embeds = zip(*block_data.embed_data.items())
         arr_embeds = np.transpose(np.array(all_embeds))
         mean = np.mean(arr_embeds, axis=1).reshape(-1, 1)
...
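The hunk above shows only the centering step of hash_whitened_block_embeds, while the docstring promises zero mean and unit covariance. One standard way to complete such a transform is ZCA whitening via an eigendecomposition of the sample covariance; the sketch below assumes that approach (the commit's actual decomposition is not visible in this hunk):

import numpy as np

rng = np.random.default_rng(0)
arr_embeds = rng.standard_normal((128, 1000))      # (dim, num_blocks), laid out as in the diff

mean = np.mean(arr_embeds, axis=1).reshape(-1, 1)
centered = arr_embeds - mean                       # zero mean per dimension
cov = centered @ centered.T / centered.shape[1]    # (dim, dim) sample covariance
eigvals, eigvecs = np.linalg.eigh(cov)
# ZCA whitening: rotate, rescale each eigen-direction to unit variance, rotate back
whitened = eigvecs @ np.diag(eigvals ** -0.5) @ eigvecs.T @ centered
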