Commit f42b4d24 authored by Neel Kant

Revise REALMBertModel and REALMRetriever

parent 24034e03
@@ -19,6 +19,7 @@ import pickle
import numpy as np
import torch
import torch.nn.functional as F
from megatron import get_args
from megatron.module import MegatronModule
@@ -217,11 +218,7 @@ class BertModel(MegatronModule):
class REALMBertModel(MegatronModule):
def __init__(self, ict_model, block_hash_data_path):
# consider adding dataset as an argument to constructor
# self.dataset = dataset
# or add a callback
def __init__(self, retriever):
super(REALMBertModel, self).__init__()
bert_args = dict(
num_tokentypes=1,
@@ -231,50 +228,38 @@ class REALMBertModel(MegatronModule):
self.lm_model = BertModel(**bert_args)
self._lm_key = 'realm_lm'
self.ict_model = ict_model
with open(block_hash_data_path, 'rb') as data_file:
data = pickle.load(data_file)
# {block_idx: block_embed} - the main index
self.block_data = data['block_data']
# {hash_num: [start, end, doc, block]} - the hash table
self.hash_data = data['hash_data']
# [embed_size x num_buckets / 2] - the projection matrix used for hashing
self.hash_matrix = self.hash_data['matrix']
def forward(self, tokens, attention_mask, token_types):
# [batch_size x embed_size]
query_logits = self.ict_model.embed_query(tokens, attention_mask, token_types)
# [batch_size x num_buckets / 2]
query_hash_pos = torch.matmul(query_logits, self.hash_matrix)
query_hash_full = torch.cat((query_hash_pos, -query_hash_pos), axis=1)
# [batch_size]
query_hashes = torch.argmax(query_hash_full, axis=1)
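# NOTE: the matmul / cat / argmax above bucket each query by random projection:
# bucket(q) = argmax([q @ M, -q @ M]) for the [embed_size x num_buckets / 2]
# matrix M loaded above, which is presumably the same scheme used to hash the
# evidence blocks into `hash_data` during preprocessing.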
batch_block_embeds = []
for hash in query_hashes:
# TODO: this should be made into a single np.array in preprocessing
bucket_blocks = self.hash_data[hash]
block_indices = bucket_blocks[:, 3]
# [bucket_pop x embed_size]
block_embeds = [self.block_data[idx] for idx in block_indices]
# will become [batch_size x bucket_pop x embed_size]
# will require padding to do tensor multiplication
batch_block_embeds.append(block_embeds)
# [batch_size x max bucket_pop x embed_size]
batch_block_embeds = np.array(batch_block_embeds)
# [batch_size x 1 x max bucket_pop]
retrieval_scores = query_logits.matmul(torch.transpose(batch_block_embeds, 1, 2))
# [batch_size x max bucket_pop]
retrieval_scores = retrieval_scores.squeeze()
# top 5 block indices for each query
top5_vals, top5_indices = torch.topk(retrieval_scores, k=5)
# TODO
# go to dataset, get the blocks
# re-embed the blocks
self.retriever = retriever
self._retriever_key = 'retriever'
def forward(self, tokens, attention_mask):
# [batch_size x 5 x seq_length]
top5_block_tokens, top5_block_attention_mask = self.retriever.retrieve_evidence_blocks(tokens, attention_mask)
# [batch_size x 5]
fresh_block_logits = self.retriever.ict_model.embed_block(top5_block_tokens, top5_block_attention_mask)
block_probs = F.softmax(fresh_block_logits, dim=1)
# [batch_size x 5 x seq_length]
tokens = torch.stack([tokens] * 5, dim=1)
attention_mask = torch.stack([attention_mask] * 5, dim=1)
# [batch_size x 5 x 2 * seq_length]
all_tokens = torch.cat((tokens, top5_block_tokens), axis=2)
all_attention_mask = torch.cat((attention_mask, top5_block_attention_mask), axis=2)
all_token_types = torch.zeros(all_tokens.shape).type(torch.int64).cuda()
# [batch_size x 5 x 2 * seq_length x vocab_size]
lm_logits, _ = self.lm_model.forward(all_tokens, all_attention_mask, all_token_types)
return lm_logits, block_probs
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""For easy load when model is combined with other heads,
add an extra key."""
state_dict_ = {}
state_dict_[self._lm_key] = self.lm_model.state_dict_for_save_checkpoint(destination, prefix, keep_vars)
return state_dict_
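The new forward pairs each query with its five retrieved evidence blocks before running the reader LM. The following is a self-contained sketch of that tensor bookkeeping with dummy data; the sizes are illustrative and not taken from the real configuration.

import torch
import torch.nn.functional as F

batch_size, seq_length, vocab_size = 2, 8, 11

# Stand-ins for the query batch and the retrieved top-5 evidence blocks.
tokens = torch.randint(vocab_size, (batch_size, seq_length))
attention_mask = torch.ones(batch_size, seq_length, dtype=torch.int64)
top5_block_tokens = torch.randint(vocab_size, (batch_size, 5, seq_length))
top5_block_attention_mask = torch.ones(batch_size, 5, seq_length, dtype=torch.int64)

# Replicate each query 5 times so it can be paired with its 5 blocks.
tokens_rep = torch.stack([tokens] * 5, dim=1)                    # [B, 5, S]
mask_rep = torch.stack([attention_mask] * 5, dim=1)              # [B, 5, S]

# Concatenate query and block along the sequence dimension.
all_tokens = torch.cat((tokens_rep, top5_block_tokens), dim=2)   # [B, 5, 2S]
all_attention_mask = torch.cat((mask_rep, top5_block_attention_mask), dim=2)
all_token_types = torch.zeros_like(all_tokens)                   # [B, 5, 2S]

# The block-selection scores are normalized over the 5 candidates per query.
fresh_block_logits = torch.randn(batch_size, 5)
block_probs = F.softmax(fresh_block_logits, dim=1)               # [B, 5]

assert all_tokens.shape == (batch_size, 5, 2 * seq_length)
assert block_probs.shape == (batch_size, 5)

How BertModel consumes the [batch_size x 5 x 2 * seq_length] inputs (for example, whether they are first flattened to [batch_size * 5 x 2 * seq_length]) is not shown in this diff, and the combination of lm_logits with block_probs into a marginal likelihood presumably happens downstream.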
class REALMRetriever(MegatronModule):
@@ -296,22 +281,33 @@ class REALMRetriever(MegatronModule):
query_tokens = torch.cuda.LongTensor(np.array(query_tokens).reshape(1, -1))
query_pad_mask = torch.cuda.LongTensor(np.array(query_pad_mask).reshape(1, -1))
query_embed = self.ict_model.module.module.embed_query(query_tokens, query_pad_mask)
query_hash = self.hashed_index.hash_embeds(query_embed)
assert query_hash.size == 1
top5_block_tokens, _ = self.retrieve_evidence_blocks(query_tokens, query_pad_mask)
for i, block in enumerate(top5_block_tokens):
block_text = self.ict_dataset.decode_tokens(block)
print(' > Block {}: {}'.format(i, block_text))
block_bucket = self.hashed_index.get_block_bucket(query_hash[0])
block_embeds = [self.hashed_index.get_block_embed(arr[3]) for arr in block_bucket]
block_embed_tensor = torch.cuda.HalfTensor(np.array(block_embeds))
def retrieve_evidence_blocks(self, query_tokens, query_pad_mask):
query_embeds = self.ict_model.module.module.embed_query(query_tokens, query_pad_mask)
query_hashes = self.hashed_index.hash_embeds(query_embeds)
retrieval_scores = query_embed.matmul(torch.transpose(block_embed_tensor, 0, 1))
top5_vals, top5_indices = torch.topk(retrieval_scores, k=5, sorted=True)
top5_start_end_doc = [block_bucket[idx][:3] for idx in top5_indices.squeeze()]
block_buckets = [self.hashed_index.get_block_bucket(hash) for hash in query_hashes]
block_embeds = [torch.cuda.HalfTensor(np.array([self.hashed_index.get_block_embed(arr[3])
for arr in bucket])) for bucket in block_buckets]
top5_blocks = [(self.ict_dataset.get_block(*indices)) for indices in top5_start_end_doc]
for i, (block, _) in enumerate(top5_blocks):
block_text = self.ict_dataset.decode_tokens(block)
print(' > Block {}: {}'.format(i, block_text))
all_top5_tokens, all_top5_pad_masks = [], []
for query_embed, embed_tensor, bucket in zip(query_embeds, block_embeds, block_buckets):
retrieval_scores = query_embed.matmul(torch.transpose(embed_tensor, 0, 1))
top5_vals, top5_indices = torch.topk(retrieval_scores, k=5, sorted=True)
top5_start_end_doc = [bucket[idx][:3] for idx in top5_indices.squeeze()]
# top_k tuples of (block_tokens, block_pad_mask)
top5_block_data = [(self.ict_dataset.get_block(*indices)) for indices in top5_start_end_doc]
top5_tokens, top5_pad_masks = zip(*top5_block_data)
all_top5_tokens.append(np.array(top5_tokens))
all_top5_pad_masks.append(np.array(top5_pad_masks))
return all_top5_tokens, all_top5_pad_masks
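The hashed_index object used above is constructed elsewhere; its class is not part of this diff. The sketch below is therefore an assumption about its behavior: it models hash_embeds, get_block_bucket and get_block_embed on the random-projection hashing and the {hash_num: [start, end, doc, block]} / {block_idx: block_embed} layouts that the removed REALMBertModel code handled inline. Only the method names come from the code above; the class name, sizes and data are illustrative.

import numpy as np
import torch

class ToyHashedIndex:
    def __init__(self, embed_size=16, num_buckets=8, num_blocks=200, seed=0):
        rng = np.random.RandomState(seed)
        # [embed_size x num_buckets / 2] projection matrix used for hashing.
        self.matrix = rng.randn(embed_size, num_buckets // 2).astype(np.float32)
        # {block_idx: block_embed}
        self.block_data = {i: rng.randn(embed_size).astype(np.float32)
                           for i in range(num_blocks)}
        # {hash_num: [[start, end, doc, block_idx], ...]}
        self.buckets = {b: [] for b in range(num_buckets)}
        for idx, embed in self.block_data.items():
            bucket_id = int(self.hash_embeds(embed[None])[0])
            self.buckets[bucket_id].append([idx, idx + 1, 0, idx])

    def hash_embeds(self, embeds):
        scores = np.asarray(embeds) @ self.matrix              # [n x num_buckets / 2]
        full = np.concatenate([scores, -scores], axis=1)       # [n x num_buckets]
        return np.argmax(full, axis=1)                         # [n]

    def get_block_bucket(self, hash_num):
        return self.buckets[int(hash_num)]

    def get_block_embed(self, block_idx):
        return self.block_data[int(block_idx)]

index = ToyHashedIndex()
query_embeds = torch.randn(2, 16)                              # [batch_size x embed_size]
query_hashes = index.hash_embeds(query_embeds.numpy())

for query_embed, query_hash in zip(query_embeds, query_hashes):
    bucket = index.get_block_bucket(query_hash)                # rows of [start, end, doc, block]
    if not bucket:
        continue                                               # empty bucket: nothing to score
    embeds = torch.from_numpy(np.array([index.get_block_embed(row[3]) for row in bucket]))
    scores = query_embed.matmul(embeds.t())                    # [bucket_pop]
    k = min(5, scores.numel())
    top_vals, top_indices = torch.topk(scores, k=k, sorted=True)
    top_start_end_doc = [bucket[i][:3] for i in top_indices.tolist()]

The loop at the end mirrors retrieve_evidence_blocks: score every block embedding in the query's bucket against the query embedding, keep the top 5, and pass each (start, end, doc) triple to ict_dataset.get_block to recover the block tokens and pad mask.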
class ICTBertModel(MegatronModule):
......