Commit 256eb6ed authored by Neel Kant's avatar Neel Kant

Enhance hashed_index and make further improvements elsewhere

parent 9f9b2cf8
@@ -17,6 +17,10 @@ from megatron.training import get_model
from pretrain_bert_ict import get_batch, model_provider
def detach(tensor):
return tensor.detach().cpu().numpy()
def embed_docs():
initialize_megatron(extra_args_provider=None,
args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
@@ -45,12 +49,13 @@ def embed_docs():
block_hash_pos = torch.matmul(block_logits, hash_matrix)
block_hash_full = torch.cat((block_hash_pos, -block_hash_pos), axis=1)
block_hashes = torch.argmax(block_hash_full, axis=1).detach().cpu().numpy()
block_hashes = detach(torch.argmax(block_hash_full, axis=1))
for hash, indices_array in zip(block_hashes, block_indices):
hash_data[int(hash)].append(indices_array.detach().cpu().numpy())
hash_data[int(hash)].append(detach(indices_array))
block_logits = block_logits.detach().cpu().numpy()
block_indices = block_indices.detach().cpu().numpy()[:, 3]
block_logits = detach(block_logits)
# originally this has [start_idx, end_idx, doc_idx, block_idx]
block_indices = detach(block_indices)[:, 3]
for logits, idx in zip(block_logits, block_indices):
block_data[int(idx)] = logits
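The bucketing above is a sign-based random-projection hash: each block embedding is projected onto num_buckets / 2 random directions, the projections are concatenated with their negations, and the argmax picks the single signed direction the embedding agrees with most. A minimal sketch of that step, not part of the commit, with illustrative sizes (embed_size, num_buckets, batch_size are assumptions):

import torch

# Illustrative sizes only; the real values come from the model and args.
embed_size, num_buckets, batch_size = 128, 32, 4

# Random projection matrix: [embed_size x num_buckets / 2].
hash_matrix = torch.randn(embed_size, num_buckets // 2)

# Block embeddings: [batch_size x embed_size].
block_logits = torch.randn(batch_size, embed_size)

# Project onto the hash directions: [batch_size x num_buckets / 2].
block_hash_pos = torch.matmul(block_logits, hash_matrix)

# Appending the negated projections yields one score per bucket, so the
# argmax over [batch_size x num_buckets] picks the bucket whose signed
# direction the embedding aligns with most strongly.
block_hash_full = torch.cat((block_hash_pos, -block_hash_pos), dim=1)
block_hashes = torch.argmax(block_hash_full, dim=1)  # [batch_size]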
@@ -68,6 +73,10 @@ def embed_docs():
torch.distributed.barrier()
all_data.clear()
del all_data
del model
if mpu.get_data_parallel_rank() == 0:
all_block_data = defaultdict(dict)
dir_name = 'block_hash_data'
@@ -80,9 +89,7 @@ def embed_docs():
with open('block_hash_data.pkl', 'wb') as final_file:
pickle.dump(all_block_data, final_file)
os.rmdir(dir_name)
return
def load_checkpoint():
......
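For context on the consolidation step: each data-parallel rank appears to dump its shard of block embeddings and hash buckets under dir_name, and rank 0 merges them into a single block_hash_data.pkl whose 'block_data' and 'hash_data' keys are read back by REALMBertModel below. A hedged sketch of such a merge, assuming per-rank pickles carrying the same two keys (the file layout, names, and helper are assumptions, not taken from the commit):

import os
import pickle
from collections import defaultdict

def merge_rank_shards(dir_name='block_hash_data', out_path='block_hash_data.pkl'):
    """Hypothetical rank-0 merge of per-rank shards into one index file."""
    merged = {'block_data': {}, 'hash_data': defaultdict(list)}
    for fname in sorted(os.listdir(dir_name)):
        path = os.path.join(dir_name, fname)
        with open(path, 'rb') as f:
            shard = pickle.load(f)
        # {block_idx: block_embed} shards are disjoint, so update() suffices.
        merged['block_data'].update(shard['block_data'])
        # {hash_num: [[start, end, doc, block], ...]} buckets are concatenated.
        for hash_num, rows in shard['hash_data'].items():
            merged['hash_data'][hash_num].extend(rows)
        os.remove(path)
    with open(out_path, 'wb') as final_file:
        pickle.dump(merged, final_file)
    os.rmdir(dir_name)  # only succeeds once the per-rank files are gone

In the actual file, hash_data also stores the projection matrix under the 'matrix' key, which REALMBertModel reads back as self.hash_matrix.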
@@ -15,6 +15,8 @@
"""BERT model."""
import pickle
import numpy as np
import torch
@@ -215,7 +217,7 @@ class BertModel(MegatronModule):
class REALMBertModel(MegatronModule):
def __init__(self, ict_model_path, block_hash_data_path):
def __init__(self, ict_model, block_hash_data_path):
super(REALMBertModel, self).__init__()
bert_args = dict(
num_tokentypes=2,
@@ -226,17 +228,21 @@ class REALMBertModel(MegatronModule):
self._lm_key = 'realm_lm'
self.ict_model = ict_model
self.ict_dataset = ict_dataset
self.block_hash_data = block_hash_data
with open(block_hash_data_path, 'rb') as data_file:
data = pickle.load(data_file)
# {block_idx: block_embed} - the main index
self.block_data = data['block_data']
# {hash_num: [start, end, doc, block]} - the hash table
self.hash_data = data['hash_data']
# [embed_size x num_buckets / 2] - the projection matrix used for hashing
self.hash_matrix = self.hash_data['matrix']
def forward(self, tokens, attention_mask, token_types):
# [batch_size x embed_size]
query_logits = self.ict_model.embed_query(tokens, attention_mask, token_types)
hash_matrix_pos = self.hash_data['matrix']
# [batch_size, num_buckets / 2]
query_hash_pos = torch.matmul(query_logits, hash_matrix_pos)
# [batch_size x num_buckets / 2]
query_hash_pos = torch.matmul(query_logits, self.hash_matrix)
query_hash_full = torch.cat((query_hash_pos, -query_hash_pos), axis=1)
# [batch_size]
@@ -247,15 +253,19 @@ class REALMBertModel(MegatronModule):
# TODO: this should be made into a single np.array in preprocessing
bucket_blocks = self.hash_data[hash]
block_indices = bucket_blocks[:, 3]
# [bucket_pop, embed_size]
# [bucket_pop x embed_size]
block_embeds = [self.block_data[idx] for idx in block_indices]
# will become [batch_size, bucket_pop, embed_size]
# will become [batch_size x bucket_pop x embed_size]
# will require padding to do tensor multiplication
batch_block_embeds.append(block_embeds)
# [batch_size x max bucket_pop x embed_size]
batch_block_embeds = np.array(batch_block_embeds)
retrieval_scores = query_logits.matmul(torch.transpose(batch_block_embeds, 0, 1))
# [batch_size x 1 x max bucket_pop]
retrieval_scores = query_logits.matmul(torch.transpose(batch_block_embeds, 1, 2))
# [batch_size x max bucket_pop]
retrieval_scores = retrieval_scores.squeeze()
top5_vals, top5_indices = torch.topk(retrieval_scores, k=5)
......
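The forward pass above hashes each query with the same projection matrix, gathers the block embeddings stored for that bucket, and scores them against the query before taking the top 5. Because bucket populations differ per query, the stacked embeddings need padding, as the comment in the diff notes. A minimal sketch of that scoring step, not part of the commit; the padding strategy and names are assumptions:

import torch
import torch.nn.functional as F

def score_bucket(query_logits, batch_block_embeds, k=5):
    """Hypothetical scoring step.

    query_logits:       [batch_size x embed_size]
    batch_block_embeds: list of [bucket_pop x embed_size] tensors, one per query
    """
    # Pad every bucket to the largest population so they stack into one tensor.
    max_pop = max(embeds.shape[0] for embeds in batch_block_embeds)
    padded = torch.stack([
        F.pad(embeds, (0, 0, 0, max_pop - embeds.shape[0]))
        for embeds in batch_block_embeds
    ])  # [batch_size x max_pop x embed_size]

    # [batch_size x 1 x embed_size] @ [batch_size x embed_size x max_pop]
    #   -> [batch_size x 1 x max_pop]
    scores = query_logits.unsqueeze(1).matmul(padded.transpose(1, 2))
    scores = scores.squeeze(1)  # [batch_size x max_pop]

    # Top-k retrieval scores and the positions of the best blocks per query.
    return torch.topk(scores, k=k)

Padded rows are all zeros, so their scores are exactly 0; a mask would be needed if real blocks can score negatively and padding must never win a top-k slot.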
@@ -17,24 +17,49 @@
import torch
import torch.nn.functional as F
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron import get_args
from megatron import get_timers
from megatron import mpu
from megatron import print_rank_0
from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
from megatron.data.bert_dataset import build_train_valid_test_datasets
from megatron.model import ICTBertModel, REALMBertModel
from megatron.training import pretrain
from megatron.training import get_model, pretrain
from megatron.utils import reduce_losses
from pretrain_bert_ict import model_provider as ict_model_provider
num_batches = 0
def model_provider():
"""Build the model."""
args = get_args()
print_rank_0('building BERT models ...')
realm_model = REALMBertModel(args.ict_model_path,
ict_model = get_model(ict_model_provider)
if isinstance(ict_model, torchDDP):
model = ict_model.module
tracker_filename = get_checkpoint_tracker_filename(args.load)
with open(tracker_filename, 'r') as f:
iteration = int(f.read().strip())
assert iteration > 0
checkpoint_name = get_checkpoint_name(args.load, iteration, False)
if mpu.get_data_parallel_rank() == 0:
print('global rank {} is loading checkpoint {}'.format(
torch.distributed.get_rank(), checkpoint_name))
state_dict = torch.load(checkpoint_name, map_location='cpu')
ict_model.load_state_dict(state_dict['model'])
torch.distributed.barrier()
if mpu.get_data_parallel_rank() == 0:
print(' successfully loaded {}'.format(checkpoint_name))
realm_model = REALMBertModel(ict_model,
args.block_hash_data_path)
return ict_model
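The checkpoint-loading steps inside model_provider (read the tracker file for the latest iteration, resolve the checkpoint path, load the state dict) could be viewed as one helper. A sketch that simply restates the logic from the diff, with an illustrative function name:

import torch
from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name

def load_ict_checkpoint(ict_model, load_dir):
    """Hypothetical helper mirroring the loading logic in model_provider above."""
    # The tracker file records the latest saved iteration.
    tracker_filename = get_checkpoint_tracker_filename(load_dir)
    with open(tracker_filename, 'r') as f:
        iteration = int(f.read().strip())
    assert iteration > 0

    # Resolve the checkpoint path for that iteration and load it on CPU.
    checkpoint_name = get_checkpoint_name(load_dir, iteration, False)
    state_dict = torch.load(checkpoint_name, map_location='cpu')
    ict_model.load_state_dict(state_dict['model'])
    return ict_model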
@@ -107,8 +132,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
masked_lm_prob=args.mask_prob,
short_seq_prob=args.short_seq_prob,
seed=args.seed,
skip_warmup=(not args.mmap_warmup),
ict_dataset=True)
skip_warmup=(not args.mmap_warmup))
print_rank_0("> finished creating BERT ICT datasets ...")
return train_ds, valid_ds, test_ds
......