Commit 9a617f6c authored by Neel Kant

Add REALMRetriever and some misc

parent 5235ed87
@@ -218,9 +218,13 @@ class BertModel(MegatronModule):
class REALMBertModel(MegatronModule):
    def __init__(self, ict_model, block_hash_data_path):
        # consider adding dataset as an argument to constructor
        # self.dataset = dataset
        # or add a callback
        super(REALMBertModel, self).__init__()
        bert_args = dict(
-            num_tokentypes=2,
+            num_tokentypes=1,
            add_binary_head=False,
            parallel_output=True
        )
@@ -265,8 +269,49 @@ class REALMBertModel(MegatronModule):
        retrieval_scores = query_logits.matmul(torch.transpose(batch_block_embeds, 1, 2))
        # [batch_size x max bucket_pop]
        retrieval_scores = retrieval_scores.squeeze()

        # top 5 block indices for each query
        top5_vals, top5_indices = torch.topk(retrieval_scores, k=5)

        # TODO
        # go to dataset, get the blocks
        # re-embed the blocks
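        # Expanding the TODO above: in REALM-style training the top-5 indices would be
        # mapped back to their (start, end, doc) entries, the raw blocks fetched from
        # the dataset, and those blocks re-embedded with the current ICT block encoder,
        # so that retrieval scores reflect up-to-date weights even though the hashed
        # index itself is typically refreshed only periodically.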
class REALMRetriever(MegatronModule):
    """Retriever which uses a pretrained ICTBertModel and a hashed_index"""
    def __init__(self, ict_model, ict_dataset, hashed_index, top_k=5):
        super(REALMRetriever, self).__init__()
        self.ict_model = ict_model
        self.ict_dataset = ict_dataset
        self.hashed_index = hashed_index
        self.top_k = top_k
    def retrieve_evidence_blocks_text(self, query_text):
        """Get the top k evidence blocks for query_text in text form"""
        print("-" * 100)
        print("Query: ", query_text)

        # tokenize the query and pad it to the dataset's sequence length
        padless_max_len = self.ict_dataset.max_seq_length - 2
        query_tokens = self.ict_dataset.encode_text(query_text)[:padless_max_len]
        query_tokens, query_pad_mask = self.ict_dataset.concat_and_pad_tokens(query_tokens)

        query_tokens = torch.cuda.IntTensor(np.array(query_tokens).reshape(1, -1))
        query_pad_mask = torch.cuda.IntTensor(np.array(query_pad_mask).reshape(1, -1))

        # embed the query and look up its hash bucket of candidate blocks
        query_embed = self.ict_model.embed_query(query_tokens, query_pad_mask)
        query_hash = self.hashed_index.hash_embeds(query_embed)
        assert query_hash.size == 1

        # each bucket row holds the block's (start, end, doc) indices plus its embed index
        block_bucket = self.hashed_index.get_block_bucket(query_hash[0])
        block_embeds = [self.hashed_index.get_block_embed(idx) for idx in block_bucket[:, 3]]
        block_embed_tensor = torch.cuda.HalfTensor(np.array(block_embeds))

        # score every block in the bucket against the query and keep the top k
        retrieval_scores = query_embed.matmul(torch.transpose(block_embed_tensor, 0, 1))
        top_k_vals, top_k_indices = torch.topk(retrieval_scores.squeeze(), k=self.top_k, sorted=True)

        top_k_start_end_doc = [block_bucket[int(idx)][:3] for idx in top_k_indices]
        top_k_blocks = [self.ict_dataset.get_block(*indices) for indices in top_k_start_end_doc]
        for i, (block, _) in enumerate(top_k_blocks):
            block_text = self.ict_dataset.decode_tokens(block)
            print(' > Block {}: {}'.format(i, block_text))
class ICTBertModel(MegatronModule):
......
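A minimal usage sketch for the new retriever, assuming an ICT model, ICT dataset, and hashed index have already been built; the get_ict_model, get_ict_dataset, and load_hashed_index helpers below are illustrative stand-ins, not functions added by this commit:

    # Hypothetical wiring; only REALMRetriever itself comes from this change.
    ict_model = get_ict_model()
    ict_dataset = get_ict_dataset()
    hashed_index = load_hashed_index('<block_hash_data_path>')

    retriever = REALMRetriever(ict_model, ict_dataset, hashed_index, top_k=5)
    retriever.retrieve_evidence_blocks_text("who wrote the opera carmen")
    # prints the query, then ' > Block 0: ...' through ' > Block 4: ...'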
@@ -178,14 +178,6 @@ class FullTokenizer(object):
        return convert_by_vocab(self.vocab, tokens)

    def convert_ids_to_tokens(self, ids):
-        train_sample = {
-            'text': tokens_np,
-            'types': tokentypes_np,
-            'labels': labels_np,
-            'is_random': int(is_next_random),
-            'loss_mask': loss_mask_np,
-            'padding_mask': padding_mask_np,
-            'truncated': int(truncated)}
        return convert_by_vocab(self.inv_vocab, ids)

    def vocab_size(self):
......
@@ -32,7 +32,7 @@ def build_tokenizer(args):
    assert args.vocab_file is not None
    if args.tokenizer_type == 'BertWordPieceLowerCase':
        tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file,
-                                           lower_case=True)
+                                            lower_case=True)
    elif args.tokenizer_type == 'GPT2BPETokenizer':
        assert args.merge_file is not None
        tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file)
......
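For reference, a rough sketch of how the selection logic above gets exercised; only the tokenizer_type, vocab_file, and merge_file fields are taken from this hunk, and the real build_tokenizer may read additional settings (e.g. rank or vocab-padding options) from args that are not shown here:

    from argparse import Namespace

    # Illustrative args only; the file paths are placeholders.
    bert_args = Namespace(tokenizer_type='BertWordPieceLowerCase',
                          vocab_file='bert-vocab.txt', merge_file=None)
    bert_tokenizer = build_tokenizer(bert_args)

    gpt2_args = Namespace(tokenizer_type='GPT2BPETokenizer',
                          vocab_file='gpt2-vocab.json', merge_file='gpt2-merges.txt')
    gpt2_tokenizer = build_tokenizer(gpt2_args)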