Commit 6e256445 authored by Neel Kant

faiss use_gpu

parent 0c077a2c
@@ -45,8 +45,6 @@ ENCODINGS = [
# LSH is inaccurate - pretty much always missing the top-1 result (1e6 embeds)
def latest(times):
    return times[-1] - times[-2]
@@ -185,8 +183,6 @@ def run_all_tests():
    test_encodings(d, k, embeds, queries)

if __name__ == "__main__":
    run_all_tests()
@@ -43,8 +43,11 @@ def test_retriever():
    model = load_ict_checkpoint(only_block_model=True)
    model.eval()
    dataset = get_ict_dataset()
    hashed_index = HashedIndex.load_from_file(args.hash_data_path)
    retriever = REALMRetriever(model, dataset, hashed_index)
    block_data = BlockData.load_from_file(args.block_data_path)
    mips_index = FaissMIPSIndex('flat_ip', 128)
    mips_index.add_block_embed_data(block_data)
    retriever = REALMRetriever(model, dataset, mips_index, top_k=5)
    strs = [
        "The last monarch from the house of windsor",
@@ -58,8 +61,6 @@ def test_retriever():
def main():
    initialize_megatron(extra_args_provider=None,
                        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
    args = get_args()
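The test_retriever changes above register the block embeddings with mips_index.add_block_embed_data(block_data); that method's body is not part of this diff. A minimal sketch of what such a step could look like, assuming BlockData exposes an embed_data dict mapping block ids to numpy embeddings (the helper name and the embed_data attribute are assumptions, not the commit's implementation):

import numpy as np

def add_block_embed_data(mips_index, block_data):
    # Hypothetical helper: stack the assumed {block_id: embedding} mapping into
    # contiguous arrays and register them with the IndexIDMap, keyed by block id.
    block_ids = np.array(list(block_data.embed_data.keys()), dtype='int64')
    embeds = np.array(list(block_data.embed_data.values()), dtype='float32')
    mips_index.block_mips_index.add_with_ids(embeds, block_ids)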
@@ -116,7 +117,6 @@ def main():
    set_model_com_file_not_ready()

def load_ict_checkpoint(only_query_model=False, only_block_model=False, no_grad=False, from_realm_chkpt=False):
    args = get_args()
    model = get_model(lambda: model_provider(only_query_model, only_block_model))
@@ -3,10 +3,11 @@ import os
import pickle
import shutil
import faiss
import numpy as np
import torch
from megatron import get_args
from megatron import get_args, mpu
def detach(tensor):
@@ -77,10 +78,10 @@ class BlockData(object):
class FaissMIPSIndex(object):
    def __init__(self, index_type, embed_size, **index_kwargs):
    def __init__(self, index_type, embed_size, use_gpu=False):
        self.index_type = index_type
        self.embed_size = embed_size
        self.index_kwargs = dict(index_kwargs)
        self.use_gpu = use_gpu

        # alsh
        self.m = 5
@@ -89,27 +90,17 @@ class FaissMIPSIndex(object):
        self.block_mips_index = None
        self._set_block_index()

    @classmethod
    def load_from_file(cls, fname):
        print(" > Unpickling block index data")
        state_dict = pickle.load(open(fname, 'rb'))
        print(" > Finished unpickling")
        index_type = state_dict['index_type']
        index_kwargs = state_dict['index_kwargs']
        embed_size = state_dict['embed_size']
        new_index = cls(index_type, embed_size, **index_kwargs)
        return new_index

    def _set_block_index(self):
        import faiss
        INDEX_TYPES = ['flat_l2', 'flat_ip']
        if self.index_type not in INDEX_TYPES:
            raise ValueError("Invalid index type specified")
        index = faiss.index_factory(self.embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)
        self.block_mips_index = faiss.IndexIDMap(index)
        if self.use_gpu:
            res = faiss.StandardGpuResources()
            device = mpu.get_data_parallel_rank()
            self.block_mips_index = faiss.index_cpu_to_gpu(res, device, self.block_mips_index)

    def reset_index(self):
        self._set_block_index()
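For reference, the new _set_block_index path can be exercised in isolation. Below is a minimal sketch (not part of this commit) that builds the same flat inner-product index, moves it to a GPU, and runs a top-5 MIPS query; it assumes a faiss-gpu build and a single device, and for simplicity it moves the flat index to the GPU before wrapping it in an IndexIDMap, whereas the commit wraps first and then moves.

import faiss
import numpy as np

embed_size = 128
flat_index = faiss.index_factory(embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)

# Mirrors the use_gpu branch; device 0 stands in for mpu.get_data_parallel_rank()
res = faiss.StandardGpuResources()
gpu_index = faiss.index_cpu_to_gpu(res, 0, flat_index)
index = faiss.IndexIDMap(gpu_index)

# Fake block embeddings keyed by block id
block_embeds = np.random.rand(1000, embed_size).astype('float32')
block_ids = np.arange(1000, dtype='int64')
index.add_with_ids(block_embeds, block_ids)

queries = np.random.rand(4, embed_size).astype('float32')
scores, top_block_ids = index.search(queries, 5)  # inner-product scores, block ids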
@@ -178,8 +178,8 @@ class REALMRetriever(MegatronModule):
        query_tokens = torch.cuda.LongTensor(np.array(query_tokens).reshape(1, -1))
        query_pad_mask = torch.cuda.LongTensor(np.array(query_pad_mask).reshape(1, -1))
        top5_block_tokens, _ = self.retrieve_evidence_blocks(query_tokens, query_pad_mask)
        for i, block in enumerate(top5_block_tokens[0]):
        topk_block_tokens, _ = self.retrieve_evidence_blocks(query_tokens, query_pad_mask)
        for i, block in enumerate(topk_block_tokens[0]):
            block_text = self.ict_dataset.decode_tokens(block)
            print('\n > Block {}: {}'.format(i, block_text))