Commit 1e01b3a2 authored by Neel Kant

Corrected exact_mips_test

parent c1c958fa
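The central fix: exact_mips_equals previously indexed self.block_data directly with the argmax position of the inner-product matrix, but zip(*self.block_data.items()) yields embeddings in dictionary order, so the argmax position has to be mapped back to the original block id first. A minimal standalone sketch of the corrected lookup (the two-entry block_data and the query values are hypothetical, only the lookup pattern comes from this commit):

import numpy as np

# Hypothetical block_data; keys are block ids and need not be 0..N-1 or sorted.
block_data = {17: np.array([0.1, 0.9]), 3: np.array([0.8, 0.2])}
shuffled_block_idx, block_embeds = zip(*block_data.items())

query = np.array([[1.0, 0.0]])
inner_products = query @ np.array(block_embeds).T        # shape (num_queries, num_blocks)
best_pos = int(np.argmax(inner_products, axis=1)[0])     # position within the zipped tuple
best_block = block_data[shuffled_block_idx[best_pos]]    # corrected: map position -> block id
# The old code did block_data[best_pos], indexing by position instead of block id.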
@@ -39,6 +39,7 @@ class HashedIndex(object):
         self.m = 5
         self.u = 0.99
         self.max_norm = None
+        self.block_index = None
 
     def state(self):
         state = {
@@ -149,9 +150,9 @@ class HashedIndex(object):
                 hash_scores_pos = torch.matmul(batch_embed, hashing_tensor)
                 embed_scores = torch.cat((hash_scores_pos, -hash_scores_pos), axis=1)
                 embed_hashes = detach(torch.argmax(embed_scores, axis=1))
-                for hash, embed in zip(list(embed_hashes), list(detach(batch_embed))):
+                for idx, hash in zip(batch_block_idx, list(embed_hashes)):
                     # [int] instead of [array<int>] since this is just for analysis rn
-                    self.hash_data[hash].append(batch_block_idx)
+                    self.hash_data[hash].append(idx)
                 i += 1
@@ -190,8 +191,7 @@ class HashedIndex(object):
         return np.float32(np.concatenate((block_embeds, norm_powers, halves_array), axis=1))
 
     def alsh_query_preprocess_fn(self, query_embeds):
-        norm = np.linalg.norm(query_embeds, axis=1)
-        max_norm = max(norm)
+        max_norm = max(np.linalg.norm(query_embeds, axis=1))
         if max_norm > 1:
             query_embeds = self.u / max_norm * query_embeds
         norm_powers, halves_array = self.get_norm_powers_and_halves_array(query_embeds)
@@ -199,9 +199,11 @@ class HashedIndex(object):
         # Q'(S(x)) for all x in query_embeds
         return np.float32(np.concatenate((query_embeds, halves_array, norm_powers), axis=1))
 
-    def exact_mips_equals(self, query_embeds):
+    def exact_mips_equals(self, query_embeds, norm_blocks):
         """For each query, determine whether the mips block is in the correct hash bucket"""
-        _, block_embeds = zip(*self.block_data.items())
+        shuffled_block_idx, block_embeds = zip(*self.block_data.items())
+        if norm_blocks:
+            block_embeds = block_embeds / np.linalg.norm(block_embeds, axis=1).reshape(-1, 1)
         with torch.no_grad():
             # get hashes for the queries
             hash_scores_pos = torch.matmul(torch.cuda.HalfTensor(query_embeds), torch.cuda.HalfTensor(self.hash_matrix))
@@ -212,10 +214,10 @@ class HashedIndex(object):
             inner_products = torch.matmul(torch.cuda.HalfTensor(query_embeds),
                                           torch.cuda.HalfTensor(np.transpose(np.array(block_embeds))))
             max_inner_product_idxes = detach(torch.argmax(inner_products, axis=1))
-            best_blocks = [self.block_data[idx] for idx in max_inner_product_idxes]
+            best_blocks = [self.block_data[shuffled_block_idx[idx]] for idx in max_inner_product_idxes]
             best_blocks_tensor = torch.cuda.HalfTensor(np.array(best_blocks))
             # bb = best_blocks
-            bb_hash_scores_pos = torch.matmul(torch.cuda.HalfTensor(best_blocks_tensor), torch.cuda.HalfTensor(self.hash_matrix))
+            bb_hash_scores_pos = torch.matmul(best_blocks_tensor, torch.cuda.HalfTensor(self.hash_matrix))
             bb_hash_scores = torch.cat((bb_hash_scores_pos, -bb_hash_scores_pos), axis=1)
             best_block_hashes = detach(torch.argmax(bb_hash_scores, axis=1))
@@ -226,13 +228,15 @@ class HashedIndex(object):
         # array of zeros and ones which can be used for counting success
         return equal_arr
 
-    def exact_mips_test(self, whitened, num_queries, alsh):
+    def exact_mips_test(self, num_queries, whitened, norm_blocks, alsh):
         if whitened:
             if self.embed_mean is None:
                 self.whiten_block_embeds()
             query_embeds = np.random.multivariate_normal(np.zeros(128), np.eye(128), num_queries)
+            query_embeds = query_embeds / np.linalg.norm(query_embeds, axis=1).reshape(-1, 1)
             if alsh:
-                self.create_block_data_index()
+                if self.block_index is None:
+                    self.create_block_data_index()
                 alsh_queries = self.alsh_query_preprocess_fn(query_embeds)
                 neighbor_ids, distances = self.block_index.search(alsh_queries, 5)
                 print('DONE')
@@ -245,7 +249,7 @@ class HashedIndex(object):
             cov = np.cov(arr_embeds)
             query_embeds = np.random.multivariate_normal(mean, cov, num_queries)
 
-        equal_arr = self.exact_mips_equals(query_embeds)
+        equal_arr = self.exact_mips_equals(query_embeds, norm_blocks)
        print("Num correct: ", sum(equal_arr), " Fraction correct: ", sum(equal_arr) / equal_arr.size)
        print(equal_arr)
...
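With this change, exact_mips_test takes (num_queries, whitened, norm_blocks, alsh), exact_mips_equals takes the extra norm_blocks flag, and the self.block_index is None guard means the ALSH block index is built only once. A rough driver sketch; the constructor call and argument values are assumptions and not part of this diff:

# Hypothetical usage; only the parameter names and order come from this commit.
index = HashedIndex()
index.exact_mips_test(num_queries=1000, whitened=True, norm_blocks=True, alsh=True)
# A second ALSH call reuses self.block_index instead of rebuilding it, thanks to the new guard.
index.exact_mips_test(num_queries=1000, whitened=True, norm_blocks=False, alsh=True)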