Commit 56bd4804 authored by Neel Kant
Browse files

Reconcile changes with head node

parent 5952c558
......@@ -34,7 +34,11 @@ class HashedIndex(object):
self.embed_mean = None
self.embed_whitener = None
self.whiten = whiten
# alsh
self.m = 5
self.u = 0.99
self.max_norm = None
def state(self):
state = {
......@@ -157,10 +161,44 @@ class HashedIndex(object):
block_embeds = np.array(block_embeds)
index = faiss.IndexFlatL2(block_embeds.shape[1])
index.add(block_embeds)
alsh_preprocessed_blocks = self.alsh_block_preprocess_fn()
index.add(alsh_preprocessed_blocks)
print('Total blocks in index: ', index.ntotal)
self.block_index = index
def get_norm_powers_and_halves_array(self, embeds):
    """Build the two padding blocks appended to embeddings for ALSH.

    Args:
        embeds: 2-D array-like with one embedding per row.

    Returns:
        A ``(norm_powers, halves_array)`` tuple of equally shaped arrays.
        Row ``r`` of ``norm_powers`` holds the squared L2 norm of
        ``embeds[r]`` followed by its repeated squarings (norm**2, norm**4,
        norm**8, ...), ``self.m`` columns in total when ``self.m >= 1``.
        ``halves_array`` is filled with 0.5 everywhere.
    """
    # Start from the squared L2 norm of every row ...
    power_columns = [np.square(np.linalg.norm(embeds, axis=1))]
    # ... and keep squaring the most recent column until there are self.m.
    while len(power_columns) < self.m:
        power_columns.append(np.square(power_columns[-1]))

    # Columns were collected as rows; transpose to [num_rows x self.m].
    norm_powers = np.array(power_columns).T
    halves_array = np.full(norm_powers.shape, 0.5)
    return norm_powers, halves_array
def alsh_block_preprocess_fn(self):
    """Return the ALSH-preprocessed block embeddings, P'(S(x)) for all blocks.

    Reads every embedding stored in ``self.block_data`` (in the mapping's
    iteration order), optionally rescales them, and appends the norm-power
    and one-half padding columns produced by
    ``get_norm_powers_and_halves_array``.

    Side effects:
        Sets ``self.max_norm`` to the largest block L2 norm the first time
        it is called (i.e. while ``self.max_norm`` is None). Later calls
        reuse the stored value rather than recomputing it.

    Returns:
        2-D array whose rows are ``[embed | norm_powers | halves]``.

    Raises:
        ValueError: if ``self.block_data`` is empty.
    """
    if not self.block_data:
        raise ValueError('block_data is empty; there are no blocks to preprocess')

    # Only the embeddings are needed here; the block ids are not used.
    block_embeds = np.array(list(self.block_data.values()))

    if self.max_norm is None:
        # Vectorised reduction instead of the Python builtin max() over an
        # ndarray, which iterates element by element.
        self.max_norm = np.max(np.linalg.norm(block_embeds, axis=1))

    # S(x): shrink all blocks so the largest norm becomes self.u. The
    # rescale is skipped when the largest norm is already <= 1.
    # NOTE(review): norms in (self.u, 1] are left untouched by this
    # condition -- confirm that is intended.
    if self.max_norm > 1:
        block_embeds = self.u / self.max_norm * block_embeds

    norm_powers, halves_array = self.get_norm_powers_and_halves_array(block_embeds)
    # P'(S(x)) for all x in block_embeds
    return np.concatenate((block_embeds, norm_powers, halves_array), axis=1)
def alsh_query_preprocess_fn(self, query_embeds):
    """Return the ALSH-preprocessed query embeddings, Q'(S(x)) for all queries.

    Args:
        query_embeds: 2-D array with one query embedding per row.

    Returns:
        2-D array whose rows are ``[embed | halves | norm_powers]``. Note
        the padding order is the reverse of the block-side layout
        (``[embed | norm_powers | halves]``). An empty batch yields an
        array with zero rows.

    The rescale factor is derived from the largest norm *within this
    batch*, so the same query can be scaled differently depending on which
    other queries it is batched with.
    """
    norm = np.linalg.norm(query_embeds, axis=1)
    # Vectorised reduction; an empty batch has no maximum, so treat it as 0
    # and skip the rescale instead of raising.
    max_norm = norm.max() if norm.size else 0.0

    # S(x): shrink the batch so its largest norm becomes self.u. Skipped
    # when every query norm is already <= 1.
    if max_norm > 1:
        query_embeds = self.u / max_norm * query_embeds

    norm_powers, halves_array = self.get_norm_powers_and_halves_array(query_embeds)
    # Q'(S(x)) for all x in query_embeds
    return np.concatenate((query_embeds, halves_array, norm_powers), axis=1)
def exact_mips_equals(self, query_embeds):
"""For each query, determine whether the mips block is in the correct hash bucket"""
_, block_embeds = zip(*self.block_data.items())
......@@ -188,11 +226,17 @@ class HashedIndex(object):
# array of zeros and ones which can be used for counting success
return equal_arr
def exact_mips_test(self, whitened, num_queries):
def exact_mips_test(self, whitened, num_queries, alsh):
if whitened:
if self.embed_mean is None:
self.whiten_block_embeds()
query_embeds = np.random.multivariate_normal(np.zeros(128), np.eye(128), num_queries)
if alsh:
self.create_block_data_index()
alsh_queries = self.alsh_query_preprocess_fn(query_embeds)
neighbor_ids, distances = self.block_idx.search(alsh_queries, 5)
print('DONE')
return
else:
block_idx, all_embeds = zip(*self.block_data.items())
arr_embeds = np.transpose(np.array(all_embeds))
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment