Commit 1e01b3a2 authored by Neel Kant's avatar Neel Kant
Browse files

Corrected exact_mips_test

parent c1c958fa
......@@ -39,6 +39,7 @@ class HashedIndex(object):
self.m = 5
self.u = 0.99
self.max_norm = None
self.block_index = None
def state(self):
state = {
......@@ -149,9 +150,9 @@ class HashedIndex(object):
hash_scores_pos = torch.matmul(batch_embed, hashing_tensor)
embed_scores = torch.cat((hash_scores_pos, -hash_scores_pos), axis=1)
embed_hashes = detach(torch.argmax(embed_scores, axis=1))
for hash, embed in zip(list(embed_hashes), list(detach(batch_embed))):
for idx, hash in zip(batch_block_idx, list(embed_hashes)):
# [int] instead of [array<int>] since this is just for analysis rn
self.hash_data[hash].append(batch_block_idx)
self.hash_data[hash].append(idx)
i += 1
......@@ -190,8 +191,7 @@ class HashedIndex(object):
return np.float32(np.concatenate((block_embeds, norm_powers, halves_array), axis=1))
def alsh_query_preprocess_fn(self, query_embeds):
norm = np.linalg.norm(query_embeds, axis=1)
max_norm = max(norm)
max_norm = max(np.linalg.norm(query_embeds, axis=1))
if max_norm > 1:
query_embeds = self.u / max_norm * query_embeds
norm_powers, halves_array = self.get_norm_powers_and_halves_array(query_embeds)
......@@ -199,9 +199,11 @@ class HashedIndex(object):
# Q'(S(x)) for all x in query_embeds
return np.float32(np.concatenate((query_embeds, halves_array, norm_powers), axis=1))
def exact_mips_equals(self, query_embeds):
def exact_mips_equals(self, query_embeds, norm_blocks):
"""For each query, determine whether the mips block is in the correct hash bucket"""
_, block_embeds = zip(*self.block_data.items())
shuffled_block_idx, block_embeds = zip(*self.block_data.items())
if norm_blocks:
block_embeds = block_embeds / np.linalg.norm(block_embeds, axis=1).reshape(-1, 1)
with torch.no_grad():
# get hashes for the queries
hash_scores_pos = torch.matmul(torch.cuda.HalfTensor(query_embeds), torch.cuda.HalfTensor(self.hash_matrix))
......@@ -212,10 +214,10 @@ class HashedIndex(object):
inner_products = torch.matmul(torch.cuda.HalfTensor(query_embeds),
torch.cuda.HalfTensor(np.transpose(np.array(block_embeds))))
max_inner_product_idxes = detach(torch.argmax(inner_products, axis=1))
best_blocks = [self.block_data[idx] for idx in max_inner_product_idxes]
best_blocks = [self.block_data[shuffled_block_idx[idx]] for idx in max_inner_product_idxes]
best_blocks_tensor = torch.cuda.HalfTensor(np.array(best_blocks))
# bb = best_blocks
bb_hash_scores_pos = torch.matmul(torch.cuda.HalfTensor(best_blocks_tensor), torch.cuda.HalfTensor(self.hash_matrix))
bb_hash_scores_pos = torch.matmul(best_blocks_tensor, torch.cuda.HalfTensor(self.hash_matrix))
bb_hash_scores = torch.cat((bb_hash_scores_pos, -bb_hash_scores_pos), axis=1)
best_block_hashes = detach(torch.argmax(bb_hash_scores, axis=1))
......@@ -226,13 +228,15 @@ class HashedIndex(object):
# array of zeros and ones which can be used for counting success
return equal_arr
def exact_mips_test(self, whitened, num_queries, alsh):
def exact_mips_test(self, num_queries, whitened, norm_blocks, alsh):
if whitened:
if self.embed_mean is None:
self.whiten_block_embeds()
query_embeds = np.random.multivariate_normal(np.zeros(128), np.eye(128), num_queries)
query_embeds = query_embeds / np.linalg.norm(query_embeds, axis=1).reshape(-1, 1)
if alsh:
self.create_block_data_index()
if self.block_index is None:
self.create_block_data_index()
alsh_queries = self.alsh_query_preprocess_fn(query_embeds)
neighbor_ids, distances = self.block_index.search(alsh_queries, 5)
print('DONE')
......@@ -245,7 +249,7 @@ class HashedIndex(object):
cov = np.cov(arr_embeds)
query_embeds = np.random.multivariate_normal(mean, cov, num_queries)
equal_arr = self.exact_mips_equals(query_embeds)
equal_arr = self.exact_mips_equals(query_embeds, norm_blocks)
print("Num correct: ", sum(equal_arr), " Fraction correct: ", sum(equal_arr) / equal_arr.size)
print(equal_arr)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment