Commit 9b9b8e01 authored by Neel Kant

Minor adjustments to fit QA codebase

parent 6e256445
@@ -167,10 +167,7 @@ class REALMRetriever(MegatronModule):
         self.hashed_index.reset_index()
         self.hashed_index.add_block_embed_data(self.block_data)
 
-    def retrieve_evidence_blocks_text(self, query_text):
-        """Get the top k evidence blocks for query_text in text form"""
-        print("-" * 100)
-        print("Query: ", query_text)
+    def prep_query_text_for_retrieval(self, query_text):
         padless_max_len = self.ict_dataset.max_seq_length - 2
         query_tokens = self.ict_dataset.encode_text(query_text)[:padless_max_len]
@@ -178,6 +175,13 @@ class REALMRetriever(MegatronModule):
         query_tokens = torch.cuda.LongTensor(np.array(query_tokens).reshape(1, -1))
         query_pad_mask = torch.cuda.LongTensor(np.array(query_pad_mask).reshape(1, -1))
+        return query_tokens, query_pad_mask
+
+    def retrieve_evidence_blocks_text(self, query_text):
+        """Get the top k evidence blocks for query_text in text form"""
+        print("-" * 100)
+        print("Query: ", query_text)
+        query_tokens, query_pad_mask = self.prep_query_text_for_retrieval(query_text)
         topk_block_tokens, _ = self.retrieve_evidence_blocks(query_tokens, query_pad_mask)
         for i, block in enumerate(topk_block_tokens[0]):
             block_text = self.ict_dataset.decode_tokens(block)
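The first two hunks split query tokenization out of retrieve_evidence_blocks_text into a reusable prep_query_text_for_retrieval helper, so other entry points in the QA codebase can feed raw question strings to retrieval. A minimal usage sketch, assuming an already-constructed retriever (the variable name and query string here are hypothetical, not part of the diff):

    # `retriever` is a hypothetical, fully initialized REALMRetriever.
    query = "who wrote the opera carmen"

    # Tokenize, truncate, and pad the query once...
    query_tokens, query_pad_mask = retriever.prep_query_text_for_retrieval(query)

    # ...then either fetch token tensors for a forward pass,
    topk_tokens, topk_pad_masks = retriever.retrieve_evidence_blocks(
        query_tokens, query_pad_mask)

    # or print the top-k evidence blocks as text for inspection.
    retriever.retrieve_evidence_blocks_text(query)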
@@ -186,7 +190,10 @@ class REALMRetriever(MegatronModule):
     def retrieve_evidence_blocks(self, query_tokens, query_pad_mask, query_block_indices=None, include_null_doc=False):
         """Embed blocks to be used in a forward pass"""
         with torch.no_grad():
-            true_model = self.ict_model.module.module
+            if hasattr(self.ict_model, 'module'):
+                true_model = self.ict_model.module
+            else:
+                true_model = self.ict_model
             query_embeds = detach(true_model.embed_query(query_tokens, query_pad_mask))
             _, block_indices = self.hashed_index.search_mips_index(query_embeds, top_k=self.top_k, reconstruct=False)
             all_topk_tokens, all_topk_pad_masks = [], []
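The old code unconditionally reached through two wrapper layers (self.ict_model.module.module), which breaks when the model is wrapped once or not at all, as when dropped into the QA codebase. The new hasattr check unwraps a single optional layer; a sketch of the same idiom generalized to any nesting depth (the helper name is an assumption, not part of the diff):

    def unwrap_model(model):
        # DDP- and FP16-style containers each expose the wrapped
        # network as `.module`; peel layers until the bare model remains.
        while hasattr(model, 'module'):
            model = model.module
        return model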
@@ -195,11 +202,12 @@ class REALMRetriever(MegatronModule):
             if query_block_indices is None:
                 query_block_indices = [-1] * len(block_indices)
+            top_k_offset = int(include_null_doc)
             for query_idx, indices in enumerate(block_indices):
                 # [k x meta_dim]
                 # exclude trivial candidate if it appears, else just trim the weakest in the top-k
                 topk_metas = [self.block_data.meta_data[idx] for idx in indices if idx != query_block_indices[query_idx]]
-                topk_block_data = [self.ict_dataset.get_block(*block_meta) for block_meta in topk_metas[:self.top_k - 1]]
+                topk_block_data = [self.ict_dataset.get_block(*block_meta) for block_meta in topk_metas[:self.top_k - top_k_offset]]
                 if include_null_doc:
                     topk_block_data.append(self.ict_dataset.get_null_block())
                 topk_tokens, topk_pad_masks = zip(*topk_block_data)
...
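The last hunk replaces the hard-coded self.top_k - 1 with self.top_k - int(include_null_doc), so a real block is only trimmed when a null block will be appended in its place, keeping the candidate list at exactly top_k entries either way. The slicing arithmetic in isolation (a standalone illustration, not code from the diff):

    top_k = 5

    # include_null_doc=False: int(False) == 0, keep all 5 retrieved blocks.
    assert top_k - int(False) == 5

    # include_null_doc=True: int(True) == 1, keep 4 retrieved blocks and
    # append the null block, for 5 candidates total.
    assert top_k - int(True) == 4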