"torchvision/vscode:/vscode.git/clone" did not exist on "1f085a0e7989d51537c3ae5a0da78892a018482f"
Commit 5952c558 authored by Neel Kant's avatar Neel Kant
Browse files

Misc mips-related improvements

parent 3f122ce9
......@@ -29,7 +29,7 @@ class HashedIndex(object):
np.random.seed(seed)
self.block_data = defaultdict(list)
self.hash_data = defaultdict(list)
hash_matrix = np.random.rand(embed_size, int(num_buckets / 2))
hash_matrix = 2 * np.random.rand(embed_size, int(num_buckets / 2)) - 1
self.hash_matrix = hash_matrix / np.linalg.norm(hash_matrix, axis=0).reshape(1, -1)
self.embed_mean = None
self.embed_whitener = None
......@@ -130,13 +130,16 @@ class HashedIndex(object):
batch_size = 16384
i = 0
args = get_args()
with torch.no_grad():
hashing_tensor = torch.cuda.HalfTensor(self.hash_matrix)
while True:
if args.debug:
print(i, flush=True)
batch_slice = slice(i * batch_size, (i + 1) * batch_size)
batch_embed = torch.cuda.HalfTensor(whitened[batch_slice])
batch_block_idx = block_idx[batch_slice]
if batch_embed.size == 0:
if len(batch_block_idx) == 0:
break
hash_scores_pos = torch.matmul(batch_embed, hashing_tensor)
......@@ -145,6 +148,8 @@ class HashedIndex(object):
for hash, embed in zip(list(embed_hashes), list(detach(batch_embed))):
# [int] instead of [array<int>] since this is just for analysis rn
self.hash_data[hash].append(batch_block_idx)
i += 1
def create_block_data_index(self):
import faiss
......@@ -175,26 +180,30 @@ class HashedIndex(object):
bb_hash_scores_pos = torch.matmul(torch.cuda.HalfTensor(best_blocks_tensor), torch.cuda.HalfTensor(self.hash_matrix))
bb_hash_scores = torch.cat((bb_hash_scores_pos, -bb_hash_scores_pos), axis=1)
best_block_hashes = detach(torch.argmax(bb_hash_scores, axis=1))
print('Query hashes: ', query_hashes)
print('Block hashes: ', best_block_hashes)
equal_arr = np.equal(query_hashes, best_block_hashes).astype(int)
# array of zeros and ones which can be used for counting success
return equal_arr
def exact_mips_test(self, whitened):
def exact_mips_test(self, whitened, num_queries):
if whitened:
if self.embed_mean is None:
self.whiten_block_embeds()
query_embeds = np.random.multivariate_normal(np.zeros(128), np.eye(128), 256)
query_embeds = np.random.multivariate_normal(np.zeros(128), np.eye(128), num_queries)
else:
block_idx, all_embeds = zip(*self.block_data.items())
arr_embeds = np.transpose(np.array(all_embeds))
mean = np.mean(arr_embeds, axis=1).reshape(-1, 1)
cov = np.cov(arr_embeds)
query_embeds = np.random.multivariate_normal(mean, cov, 256)
query_embeds = np.random.multivariate_normal(mean, cov, num_queries)
equal_arr = self.exact_mips_equals(query_embeds)
print("Num correct: ", sum(equal_arr), " Fraction correct: ", sum(equal_arr) / equal_arr.size)
print(equal_arr)
@classmethod
def load_from_file(cls, fname):
......@@ -206,6 +215,8 @@ class HashedIndex(object):
new_index = HashedIndex(hash_matrix.shape[0], hash_matrix.shape[1] * 2)
new_index.block_data = state_dict['block_data']
new_index.hash_data = state_dict['hash_data']
new_index.embed_mean = state_dict.get('embed_mean')
new_index.embed_whitener = state_dict.get('embed_whitener')
new_index.hash_matrix = hash_matrix
return new_index
......@@ -279,7 +290,7 @@ def main():
hashed_index.hash_embeds(block_logits, block_indices)
hashed_index.assign_block_embeds(block_indices[:, 3], detach(block_logits))
total += block_indices.size
total += block_indices.shape[0]
i += 1
if i % 20 == 0:
print('Batch {:10d} | Total {:10d}'.format(i, total), flush=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment