Commit 3f122ce9 authored by Neel Kant

Write MIPS tests in HashedIndex

parent 9d225b44
@@ -34,6 +34,7 @@ class HashedIndex(object):
        self.embed_mean = None
        self.embed_whitener = None
        self.whiten = whiten
        self.m = 5

    def state(self):
        state = {
@@ -120,7 +121,7 @@ class HashedIndex(object):
        centered = arr_embeds - mean
        inv_cov = np.linalg.inv(np.cov(arr_embeds))
        # whitener = L.T where inv_cov = L @ L.T, so the whitened embeddings
        # end up with identity covariance: L.T @ cov @ L = I
        whitener = np.transpose(np.linalg.cholesky(inv_cov))
        # this commit casts the whitened embeddings to float16
        # (was: whitened = np.transpose(whitener.dot(centered)))
        whitened = np.float16(np.transpose(whitener.dot(centered)))
        self.embed_mean = mean.reshape(-1)
        self.embed_whitener = whitener
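        # A quick numerical sanity check of the identity above (illustrative
        # only, not part of this commit):
        #
        #     X = np.random.randn(128, 10000)
        #     L = np.linalg.cholesky(np.linalg.inv(np.cov(X)))
        #     W = L.T.dot(X - X.mean(axis=1, keepdims=True))
        #     np.allclose(np.cov(W), np.eye(128), atol=0.1)  # -> True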
@@ -145,6 +146,56 @@ class HashedIndex(object):
        # [int] instead of [array<int>] since this is just for analysis right now
        self.hash_data[hash].append(batch_block_idx)

    def create_block_data_index(self):
        import faiss

        self.block_idx, block_embeds = zip(*self.block_data.items())
        # faiss requires a contiguous float32 array
        block_embeds = np.array(block_embeds, dtype=np.float32)
        index = faiss.IndexFlatL2(block_embeds.shape[1])
        index.add(block_embeds)
        print('Total blocks in index: ', index.ntotal)
        self.block_index = index
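    # A hypothetical usage sketch (not part of this commit), assuming the index
    # has been built: faiss.IndexFlatL2.search returns squared-L2 distances and
    # the positions of the k nearest blocks; positions map back to block
    # indices through self.block_idx.
    #
    #     queries = np.random.randn(16, 128).astype(np.float32)
    #     distances, positions = self.block_index.search(queries, 5)
    #     nearest = [[self.block_idx[p] for p in row] for row in positions]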

    def exact_mips_equals(self, query_embeds):
        """For each query, determine whether the MIPS block is in the correct hash bucket"""
        _, block_embeds = zip(*self.block_data.items())
        with torch.no_grad():
            # get hashes for the queries
            hash_scores_pos = torch.matmul(torch.cuda.HalfTensor(query_embeds), torch.cuda.HalfTensor(self.hash_matrix))
            hash_scores = torch.cat((hash_scores_pos, -hash_scores_pos), axis=1)
            query_hashes = detach(torch.argmax(hash_scores, axis=1))

            # [num_query x num_blocks]
            inner_products = torch.matmul(torch.cuda.HalfTensor(query_embeds),
                                          torch.cuda.HalfTensor(np.transpose(np.array(block_embeds))))
            # argmax gives positions into block_embeds, not keys of block_data
            max_inner_product_idxes = detach(torch.argmax(inner_products, axis=1))
            best_blocks = [block_embeds[idx] for idx in max_inner_product_idxes]
            best_blocks_tensor = torch.cuda.HalfTensor(np.array(best_blocks))

            # bb = best_blocks
            bb_hash_scores_pos = torch.matmul(best_blocks_tensor, torch.cuda.HalfTensor(self.hash_matrix))
            bb_hash_scores = torch.cat((bb_hash_scores_pos, -bb_hash_scores_pos), axis=1)
            best_block_hashes = detach(torch.argmax(bb_hash_scores, axis=1))

        # array of zeros and ones which can be used for counting success
        equal_arr = np.equal(query_hashes, best_block_hashes).astype(int)
        return equal_arr

    def exact_mips_test(self, whitened):
        if whitened:
            if self.embed_mean is None:
                self.whiten_block_embeds()
            # whitened block embeds are ~N(0, I), so draw the queries from a
            # standard normal of the same dimension
            query_embeds = np.random.multivariate_normal(np.zeros(128), np.eye(128), 256)
        else:
            block_idx, all_embeds = zip(*self.block_data.items())
            arr_embeds = np.transpose(np.array(all_embeds))
            # np.random.multivariate_normal expects a 1-D mean
            mean = np.mean(arr_embeds, axis=1)
            cov = np.cov(arr_embeds)
            query_embeds = np.random.multivariate_normal(mean, cov, 256)
        equal_arr = self.exact_mips_equals(query_embeds)
        print("Num correct: ", sum(equal_arr), " Fraction correct: ", sum(equal_arr) / equal_arr.size)

    @classmethod
    def load_from_file(cls, fname):
        print(" > Unpickling block hash data")
@@ -159,23 +210,6 @@ class HashedIndex(object):
        return new_index
    # NOTE: whiten_and_rehash below is removed by this commit
    @classmethod
    def whiten_and_rehash(cls, fname):
        """Load up a HashedIndex, whiten it and rehash"""
        index = cls.load_from_file(fname)
        all_vectors = []
        for block_embed in index.block_data.values():
            all_vectors.append(block_embed)
        arr_vectors = np.transpose(np.array(all_vectors))
        mean = np.mean(arr_vectors, axis=1)
        cov = np.cov(arr_vectors)
        inv_cov = np.linalg.inv(cov)
def test_retriever():
    initialize_megatron(extra_args_provider=None,

@@ -239,7 +273,7 @@ def main():
    block_indices = detach(block_indices)
    block_logits = model(None, None, block_tokens, block_pad_mask, only_block=True)

    # If whitened, then hashing needs to be done after whitening the block embeds,
    # which is done in consolidate_shards_and_save()
    if not whiten:
        hashed_index.hash_embeds(block_logits, block_indices)