Commit c1c958fa authored by Neel Kant's avatar Neel Kant
Browse files

Implement MIPS with FAISS

parent 56bd4804
...@@ -160,8 +160,8 @@ class HashedIndex(object): ...@@ -160,8 +160,8 @@ class HashedIndex(object):
self.block_idx, block_embeds = zip(*self.block_data.items()) self.block_idx, block_embeds = zip(*self.block_data.items())
block_embeds = np.array(block_embeds) block_embeds = np.array(block_embeds)
index = faiss.IndexFlatL2(block_embeds.shape[1])
alsh_preprocessed_blocks = self.alsh_block_preprocess_fn() alsh_preprocessed_blocks = self.alsh_block_preprocess_fn()
index = faiss.IndexFlatL2(alsh_preprocessed_blocks.shape[1])
index.add(alsh_preprocessed_blocks) index.add(alsh_preprocessed_blocks)
print('Total blocks in index: ', index.ntotal) print('Total blocks in index: ', index.ntotal)
self.block_index = index self.block_index = index
...@@ -187,7 +187,7 @@ class HashedIndex(object): ...@@ -187,7 +187,7 @@ class HashedIndex(object):
norm_powers, halves_array = self.get_norm_powers_and_halves_array(block_embeds) norm_powers, halves_array = self.get_norm_powers_and_halves_array(block_embeds)
# P'(S(x)) for all x in block_embeds # P'(S(x)) for all x in block_embeds
return np.concatenate((block_embeds, norm_powers, halves_array), axis=1) return np.float32(np.concatenate((block_embeds, norm_powers, halves_array), axis=1))
def alsh_query_preprocess_fn(self, query_embeds): def alsh_query_preprocess_fn(self, query_embeds):
norm = np.linalg.norm(query_embeds, axis=1) norm = np.linalg.norm(query_embeds, axis=1)
...@@ -197,7 +197,7 @@ class HashedIndex(object): ...@@ -197,7 +197,7 @@ class HashedIndex(object):
norm_powers, halves_array = self.get_norm_powers_and_halves_array(query_embeds) norm_powers, halves_array = self.get_norm_powers_and_halves_array(query_embeds)
# Q'(S(x)) for all x in query_embeds # Q'(S(x)) for all x in query_embeds
return np.concatenate((query_embeds, halves_array, norm_powers), axis=1) return np.float32(np.concatenate((query_embeds, halves_array, norm_powers), axis=1))
def exact_mips_equals(self, query_embeds): def exact_mips_equals(self, query_embeds):
"""For each query, determine whether the mips block is in the correct hash bucket""" """For each query, determine whether the mips block is in the correct hash bucket"""
...@@ -234,7 +234,7 @@ class HashedIndex(object): ...@@ -234,7 +234,7 @@ class HashedIndex(object):
if alsh: if alsh:
self.create_block_data_index() self.create_block_data_index()
alsh_queries = self.alsh_query_preprocess_fn(query_embeds) alsh_queries = self.alsh_query_preprocess_fn(query_embeds)
neighbor_ids, distances = self.block_idx.search(alsh_queries, 5) neighbor_ids, distances = self.block_index.search(alsh_queries, 5)
print('DONE') print('DONE')
return return
else: else:
...@@ -313,7 +313,7 @@ def main(): ...@@ -313,7 +313,7 @@ def main():
model.eval() model.eval()
dataset = get_ict_dataset() dataset = get_ict_dataset()
data_iter = iter(get_one_epoch_dataloader(dataset)) data_iter = iter(get_one_epoch_dataloader(dataset))
hashed_index = HashedIndex(embed_size=128, num_buckets=4096, whiten=True) hashed_index = HashedIndex(embed_size=128, num_buckets=32, whiten=True)
i = 1 i = 1
total = 0 total = 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment