Commit ac79d374 authored by Neel Kant

Debug test_retriever

parent 3fb02b8e
@@ -103,7 +103,9 @@ class HashedIndex(object):
     @classmethod
     def load_from_file(cls, fname):
+        print(" > Unpickling block hash data")
         state_dict = pickle.load(open(fname, 'rb'))
+        print(" > Finished unpickling")
         hash_matrix = state_dict['hash_matrix']
         new_index = HashedIndex(hash_matrix.shape[0], hash_matrix.shape[1] * 2)
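
The functional change in this hunk is just the pair of progress prints bracketing the slow pickle.load. One incidental nit the diff keeps: the file handle from open(fname, 'rb') is never closed. A minimal sketch of the same load path with a context manager (load_hash_matrix is a hypothetical helper; only the 'hash_matrix' key is taken from the diff):

import pickle

def load_hash_matrix(fname):
    # Open via a context manager so the handle is closed promptly;
    # pickle.load(open(fname, 'rb')) above leaks the file object.
    with open(fname, 'rb') as f:
        state_dict = pickle.load(f)
    return state_dict['hash_matrix']
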
@@ -121,7 +123,16 @@ def test_retriever():
     dataset = get_dataset()
     hashed_index = HashedIndex.load_from_file('block_hash_data.pkl')
     retriever = REALMRetriever(model, dataset, hashed_index)
-    retriever.retrieve_evidence_blocks_text("The last monarch from the house of windsor")
+    strs = [
+        "The last monarch from the house of windsor",
+        "married to Elvis Presley",
+        "tallest building in the world today",
+        "who makes graphics cards"
+    ]
+    for s in strs:
+        retriever.retrieve_evidence_blocks_text(s)


 def main():
@@ -246,4 +257,4 @@ def get_dataloader(dataset):

 if __name__ == "__main__":
-    main()
+    test_retriever()
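
Together these two hunks turn the single hard-coded query into a small smoke-test battery and switch the entry point from main() to test_retriever(). The same loop generalizes to queries read from a file; a hedged sketch (run_queries and the one-query-per-line format are assumptions, only retrieve_evidence_blocks_text comes from the diff):

def run_queries(retriever, query_file):
    # Hypothetical helper mirroring the loop added in test_retriever:
    # one query per line, blank lines skipped.
    with open(query_file) as f:
        for line in f:
            query = line.strip()
            if query:
                retriever.retrieve_evidence_blocks_text(query)
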
@@ -84,10 +84,10 @@ class InverseClozeDataset(Dataset):
     def get_block(self, start_idx, end_idx, doc_idx):
         """Get the IDs for an evidence block plus the title of the corresponding document"""
-        block = [self.context_dataset[i] for i in range(start_idx, end_idx)]
-        title = list(self.titles_dataset[int(doc_idx)])
+        block = [list(self.block_dataset[i]) for i in range(start_idx, end_idx)]
+        title = list(self.title_dataset[int(doc_idx)])

-        block = list(itertools.chain(*block))[self.max_seq_length - (3 + len(title))]
+        block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]
         block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)

         return block_tokens, block_pad_mask
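
Two fixes here: get_block now reads from the actual attribute names (block_dataset, title_dataset), and the truncation becomes a slice, [:self.max_seq_length - ...], where the old code indexed a single element. A sketch of the length budget the slice enforces, assuming concat_and_pad_tokens lays out the sequence as [CLS] title [SEP] block [SEP], i.e. 3 special tokens (an assumption; the layout is not shown in this diff):

max_seq_length = 288          # example value, not taken from the diff
title = [2054, 2003]          # example title token IDs
budget = max_seq_length - (3 + len(title))
block = list(range(1000))     # stand-in for the chained block token IDs
truncated = block[:budget]    # the slice, not a single-element index
assert len(truncated) + len(title) + 3 <= max_seq_length
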
@@ -293,20 +293,20 @@ class REALMRetriever(MegatronModule):
         query_tokens = self.ict_dataset.encode_text(query_text)[:padless_max_len]
         query_tokens, query_pad_mask = self.ict_dataset.concat_and_pad_tokens(query_tokens)

-        query_tokens = torch.cuda.IntTensor(np.array(query_tokens).reshape(1, -1))
-        query_pad_mask = torch.cuda.IntTensor(np.array(query_pad_mask).reshape(1, -1))
+        query_tokens = torch.cuda.LongTensor(np.array(query_tokens).reshape(1, -1))
+        query_pad_mask = torch.cuda.LongTensor(np.array(query_pad_mask).reshape(1, -1))

-        query_embed = self.ict_model.embed_query(query_tokens, query_pad_mask)
+        query_embed = self.ict_model.module.module.embed_query(query_tokens, query_pad_mask)
         query_hash = self.hashed_index.hash_embeds(query_embed)
         assert query_hash.size == 1

         block_bucket = self.hashed_index.get_block_bucket(query_hash[0])
-        block_embeds = [self.hashed_index.get_block_embed[idx] for idx in block_bucket[:, 3]]
+        block_embeds = [self.hashed_index.get_block_embed(arr[3]) for arr in block_bucket]
         block_embed_tensor = torch.cuda.HalfTensor(np.array(block_embeds))

         retrieval_scores = query_embed.matmul(torch.transpose(block_embed_tensor, 0, 1))
         top5_vals, top5_indices = torch.topk(retrieval_scores, k=5, sorted=True)
-        top5_start_end_doc = [block_bucket[idx][:3] for idx in top5_indices]
+        top5_start_end_doc = [block_bucket[idx][:3] for idx in top5_indices.squeeze()]

         top5_blocks = [(self.ict_dataset.get_block(*indices)) for indices in top5_start_end_doc]
         for i, (block, _) in enumerate(top5_blocks):
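
The scoring path this hunk debugs: embed the query, hash it into an LSH bucket, score every block embedding in that bucket by inner product, and take the top 5. torch.topk over a (1, bucket_size) score matrix returns (1, 5) tensors, so iterating top5_indices directly yields one row rather than five scalars; the added .squeeze() is what lets the list comprehension index blocks one at a time. A self-contained sketch of that step on CPU (shapes are illustrative assumptions):

import torch

query_embed = torch.randn(1, 128)           # (1, d) query embedding
block_embed_tensor = torch.randn(50, 128)   # (bucket_size, d) block embeddings
scores = query_embed.matmul(block_embed_tensor.t())  # (1, bucket_size)
top5_vals, top5_indices = torch.topk(scores, k=5, sorted=True)
# topk over the last dim of a (1, bucket_size) tensor gives (1, 5) results;
# squeeze() reduces the index tensor to shape (5,) so it iterates as scalars,
# which is the bug the top5_indices.squeeze() change fixes.
for idx in top5_indices.squeeze():
    print(int(idx), float(scores[0, idx]))
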