Commit 150f2384 authored by Neel Kant

Update faiss_test

parent c9ca82bd
 from collections import defaultdict
 import time
+import pickle
 import faiss
-from faiss import index_factory
+from faiss import index_factory, index_cpu_to_gpu
 import numpy as np
 from megatron import get_args
@@ -14,13 +15,19 @@ PCAS = [
 # PCA to 64 dim gets "first missing" ~ 95% and "mixed" ~ 5% for all
 # however, this is pretty hard since the embeds and queries are totally random, would be better to test according to a distribution
+# update: Using realistic mean and covariance helps, but then adjusting for inner product makes it unusable again
+# CONCLUSION: PCA should not be used for MIPS
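A minimal numpy-only sketch (not part of this commit) of why PCA interacts badly with maximum inner product search: the projection is fitted to preserve variance, not inner products, so the top-1 result under the full-dimensional inner product can change after projection. Dimensions, counts and the seed below are arbitrary.

import numpy as np

np.random.seed(0)
d, d_pca, n = 128, 64, 10000
embeds = np.random.randn(n, d).astype('float32')
query = np.random.randn(d).astype('float32')

# PCA basis: top-d_pca right singular vectors of the centered database
centered = embeds - embeds.mean(axis=0)
_, _, vt = np.linalg.svd(centered, full_matrices=False)
proj = vt[:d_pca]                                         # (d_pca, d)

top_full = np.argmax(embeds @ query)                      # exact MIPS answer
top_pca = np.argmax((embeds @ proj.T) @ (proj @ query))   # MIPS after PCA projection
print('top-1 preserved after PCA:', top_full == top_pca)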
 QUANTIZERS = [
-    'IVF4096', 'IMI2x9',
-    'HNSW32', 'IVF4096_HNSW32'
+    'IVF4096_SQ16', # 'IMI2x9',
+    'HNSW32_SQ16', # 'IVF4096_HNSW32'
 ]
+# IMI2x9 or any other MultiIndex doesn't support inner product so it's unusable
+# IVF4096_HNSW32 doesn't support inner product either
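A hedged probe (not part of this commit) for checking up front whether a factory string works with faiss.METRIC_INNER_PRODUCT: build, train and search a small index inside a try/except, since depending on the faiss version an unsupported combination can fail at construction, at train time, or only at search time. The factory strings below are illustrative, not the entries of QUANTIZERS.

def supports_inner_product(d, factory_string, sample):
    # Returns True if the index can be built, trained, filled and searched with the IP metric.
    try:
        idx = faiss.index_factory(d, factory_string, faiss.METRIC_INNER_PRODUCT)
        idx.train(sample)
        idx.add(sample)
        idx.search(sample[:1], 1)
        return True
    except Exception:
        return False

sample = np.random.rand(2048, 128).astype('float32')
for s in ['IVF64,Flat', 'IMI2x4,Flat', 'HNSW32,Flat']:
    print(s, supports_inner_product(128, s, sample))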
 ENCODINGS = [
     'Flat',
@@ -38,16 +45,34 @@ ENCODINGS = [
 # LSH is inaccurate - pretty much always missing the top-1 result (1e6 embeds)
 def latest(times):
     return times[-1] - times[-2]
-def get_embeds_and_queries(d, num_embeds, num_queries):
+def get_embed_mean_and_cov():
+    embed_data = pickle.load(open('/home/dcg-adlr-nkant-data.cosmos1202/hash_data/normed4096_whitened.pkl', 'rb'))
+    embed_mean = embed_data['embed_mean']
+    whitener = embed_data['embed_whitener']
+    embed_cov = whitener.dot(whitener.transpose())
+    return embed_mean, embed_cov
+
+def get_embeds_and_queries(mean, cov, num_embeds, num_queries):
+    embeds = np.random.multivariate_normal(mean, cov, num_embeds).astype('float32')
+    queries = np.random.multivariate_normal(mean, cov, num_queries).astype('float32')
+    return embeds, queries
+
+def get_random_embeds_and_queries(d, num_embeds, num_queries):
     embeds = np.random.rand(num_embeds, d).astype('float32')
     queries = np.random.rand(num_queries, d).astype('float32')
     return embeds, queries
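Since embed_cov is formed as whitener · whitenerᵀ, the same Gaussian can be sampled directly through the whitener, which skips the factorization that np.random.multivariate_normal performs internally and is noticeably cheaper for large d. A sketch with a hypothetical helper name, assuming whitener is the same d × d matrix loaded above:

def sample_from_whitener(mean, whitener, num_samples):
    # If z ~ N(0, I), then mean + z @ whitener.T has covariance whitener @ whitener.T,
    # i.e. the same distribution that get_embeds_and_queries draws from via the explicit cov.
    z = np.random.randn(num_samples, whitener.shape[0])
    return (mean + z.dot(whitener.transpose())).astype('float32')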
 def print_timing_stats(name, create_and_add, search):
     print('{:20s} Create and add embeds: {:10.4f}s | Search embeds: {:10.4f}s'.format(name, create_and_add, search))
@@ -69,7 +94,8 @@ def print_accuracy_stats(name, gold_indices, estimated_indices):
 def create_and_test_gold(d, k, embeds, queries):
     times = [time.time()]
-    gold_idx = index_factory(d, 'Flat')
+    res = faiss.StandardGpuResources()
+    gold_idx = index_cpu_to_gpu(res, 0, index_factory(d, 'Flat'))
     gold_idx.add(embeds)
     times.append(time.time())
     create_and_add = latest(times)
@@ -81,15 +107,14 @@ def create_and_test_gold(d, k, embeds, queries):
     return distances, indices
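The gold baseline above pins the exact flat search to GPU 0. If several GPUs are available, the same brute-force index can be sharded across all of them; a hedged sketch, assuming a faiss build with GPU support (the helper name is hypothetical, and the metric is left at the factory default to match the baseline above):

def create_gold_index_all_gpus(d, embeds):
    cpu_idx = index_factory(d, 'Flat')
    co = faiss.GpuMultipleClonerOptions()
    co.shard = True  # split the database across GPUs instead of replicating it
    gold_idx = faiss.index_cpu_to_all_gpus(cpu_idx, co)
    gold_idx.add(embeds)
    return gold_idx

Sharding keeps per-GPU memory roughly proportional to 1/num_gpus of the database, which matters at the 1e6-vector scale used below.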
-def test_pca(d, k, num_embeds, num_queries, pca_dim):
-    embeds, queries = get_embeds_and_queries(d, num_embeds, num_queries)
+def test_pca(d, k, embeds, queries, pca_dim):
     distances, indices = create_and_test_gold(d, k, embeds, queries)
     times = [time.time()]
     all_pca_indices = []
     for s in PCAS:
-        pca_idx = index_factory(d, s + "{},Flat".format(pca_dim))
+        pca_idx = index_factory(d, s + "{},Flat".format(pca_dim), faiss.METRIC_INNER_PRODUCT)
         pca_idx.train(embeds)
         pca_idx.add(embeds)
         times.append(time.time())
@@ -105,17 +130,16 @@ def test_pca(d, k, num_embeds, num_queries, pca_dim):
         print_accuracy_stats(s, indices, pca_indices)
-def test_quantizers(d, k, num_embeds, num_queries):
-    embeds, queries = get_embeds_and_queries(d, num_embeds, num_queries)
+def test_quantizers(d, k, embeds, queries):
     distances, indices = create_and_test_gold(d, k, embeds, queries)
     times = [time.time()]
     for s in QUANTIZERS:
-        if 'HNSW' in s and '_' not in s:
-            quant_idx = index_factory(d, s)
+        if 'HNSW' in s:
+            quant_idx = index_factory(d, s, faiss.METRIC_INNER_PRODUCT)
         else:
-            quant_idx = index_factory(d, "Flat," + s)
+            quant_idx = index_factory(d, "Flat," + s, faiss.METRIC_INNER_PRODUCT)
         quant_idx.train(embeds)
         quant_idx.add(embeds)
@@ -127,15 +151,14 @@ def test_quantizers(d, k, num_embeds, num_queries):
         print_timing_stats(s, create_and_add, latest(times))
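Recall and latency for the IVF and HNSW variants above depend heavily on search-time parameters that this test leaves at their defaults; a hedged sketch of how they are usually adjusted through faiss.ParameterSpace (the helper name and the values are illustrative):

def set_search_params(quant_idx, factory_string):
    ps = faiss.ParameterSpace()
    if 'IVF' in factory_string:
        ps.set_index_parameter(quant_idx, 'nprobe', 32)    # coarse lists visited per query
    if 'HNSW' in factory_string:
        ps.set_index_parameter(quant_idx, 'efSearch', 64)  # graph exploration width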
-def test_encodings(d, k, num_embeds, num_queries):
-    embeds, queries = get_embeds_and_queries(d, num_embeds, num_queries)
+def test_encodings(d, k, embeds, queries):
     distances, indices = create_and_test_gold(d, k, embeds, queries)
     times = [time.time()]
     all_encode_indices = []
     for s in ENCODINGS:
-        encode_idx = index_factory(d, s)
+        encode_idx = index_factory(d, s, faiss.METRIC_INNER_PRODUCT)
         encode_idx.train(embeds)
         encode_idx.add(embeds)
@@ -152,6 +175,22 @@ def test_encodings(d, k, num_embeds, num_queries):
         print_accuracy_stats(s, indices, encode_indices)
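The diff does not show the body of print_accuracy_stats, so purely as an illustration of the kind of metric behind the "first missing" / "mixed" comments at the top of the file, here is a hypothetical top-1 recall helper (not the file's implementation):

def top1_recall(gold_indices, estimated_indices):
    # Fraction of queries whose exact top-1 id appears anywhere in the estimated top-k.
    hits = [gold[0] in est for gold, est in zip(gold_indices, estimated_indices)]
    return sum(hits) / len(hits)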
+def run_all_tests():
+    mean, cov = get_embed_mean_and_cov()
+    embeds, queries = get_embeds_and_queries(mean, cov, int(1e6), 256)
+    d = 128
+    k = 10
+    test_pca(d, k, embeds, queries, 96)
+    test_quantizers(d, k, embeds, queries)
+    test_encodings(d, k, embeds, queries)
+
+if __name__ == "__main__":
+    run_all_tests()
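run_all_tests depends on the whitened-embedding pickle and on a GPU for the gold index; a hedged, CPU-only sanity check that reuses get_random_embeds_and_queries and the top1_recall sketch above could look like this (function name, factory string and sizes are arbitrary):

def quick_sanity_check():
    d, k = 128, 10
    embeds, queries = get_random_embeds_and_queries(d, int(1e4), 64)

    # Exact inner-product baseline on CPU.
    gold = faiss.index_factory(d, 'Flat', faiss.METRIC_INNER_PRODUCT)
    gold.add(embeds)
    _, gold_ids = gold.search(queries, k)

    # One approximate index to compare against.
    approx = faiss.index_factory(d, 'IVF256,Flat', faiss.METRIC_INNER_PRODUCT)
    approx.train(embeds)
    approx.add(embeds)
    _, approx_ids = approx.search(queries, k)

    print('top-1 recall:', top1_recall(gold_ids, approx_ids))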