from collections import defaultdict
import time
import pickle

import faiss
from faiss import index_factory, index_cpu_to_gpu

import numpy as np

from megatron import get_args


# PCA pre-transform variants to benchmark (faiss index_factory prefixes).
PCAS = [
    'PCA', 'PCAR', 'PCAW', 'PCAWR'
]

# PCA to 64 dim gets "first missing" ~ 95% and "mixed" ~ 5% for all
# however, this is pretty hard since the embeds and queries are totally random, would be better to test according to a distribution
# update: Using realistic mean and covariance helps, but then adjusting for inner product makes it unusable again
# CONCLUSION: PCA should not be used for MIPS


# Coarse-quantizer factory strings to benchmark.
QUANTIZERS = [
    'IVF4096_SQ16', # 'IMI2x9',
    'HNSW32_SQ16', # 'IVF4096_HNSW32'
]

# IMI2x9 or any other MultiIndex doesn't support inner product so it's unusable
# IVF4096_HNSW32 doesn't support inner product either


# Vector-encoding factory strings to benchmark.
ENCODINGS = [
    'Flat',
    'PQ16np', # PQ16, PQ16x12(np)
    'SQ4', 'SQ8', 'SQ6', 'SQfp16',
  # 'LSH', 'LSHrt', 'LSHr', 'LSHt'
]

# PQ16 is pretty slow for creating and adding - ~96s for 1e5, 105s for 1e6
# PQ16np is a bit faster but is pretty inaccurate - misses top-1 result 2/3 of time (1e6 embeds)
# PQ16x12(np) gets real slow. Uses 4096 centroids.

# SQfp16 is solid.

# LSH is inaccurate - pretty much always missing the top-1 result (1e6 embeds)


def latest(times):
    """Return the elapsed seconds between the two most recent timestamps."""
    *_, previous, newest = times
    return newest - previous


def get_embed_mean_and_cov():
    """Load precomputed embedding statistics and return (mean, covariance).

    Reads a pickled dict with keys 'embed_mean' and 'embed_whitener' from a
    hard-coded cluster path and reconstructs the covariance matrix as
    W.dot(W.T) from the whitening matrix.

    Returns:
        tuple: (embed_mean, embed_cov) arrays.
    """
    # NOTE(review): pickle.load is unsafe on untrusted files; acceptable here
    # only because the path points at trusted local data.
    path = '/home/dcg-adlr-nkant-data.cosmos1202/hash_data/normed4096_whitened.pkl'
    # Context manager closes the file handle (the original leaked it).
    with open(path, 'rb') as f:
        embed_data = pickle.load(f)
    embed_mean = embed_data['embed_mean']
    whitener = embed_data['embed_whitener']
    embed_cov = whitener.dot(whitener.transpose())

    return embed_mean, embed_cov


def get_embeds_and_queries(mean, cov, num_embeds, num_queries):
    embeds = np.random.multivariate_normal(mean, cov, num_embeds).astype('float32')
    queries = np.random.multivariate_normal(mean, cov, num_queries).astype('float32')
    return embeds, queries


def get_random_embeds_and_queries(d, num_embeds, num_queries):
    """Generate uniform-random float32 embeddings and queries of dimension d.

    Args:
        d: vector dimensionality.
        num_embeds: number of database vectors.
        num_queries: number of query vectors.

    Returns:
        tuple: (embeds, queries) with shapes (num_embeds, d) and (num_queries, d).
    """
    embeds = np.random.rand(num_embeds, d).astype('float32')
    queries = np.random.rand(num_queries, d).astype('float32')
    return embeds, queries


def print_timing_stats(name, create_and_add, search):
    """Print one aligned row of build and search timings for index *name*."""
    row = '{:20s} Create and add embeds: {:10.4f}s  |  Search embeds: {:10.4f}s'
    print(row.format(name, create_and_add, search))


def print_accuracy_stats(name, gold_indices, estimated_indices):
    """Bucket each query's estimated top-k ids against the gold ids and print counts.

    Buckets: 'first_missing' when the gold top-1 id is absent from the
    estimate, 'all_equal' when the full top-k matches exactly, 'mixed'
    otherwise.
    """
    counts = defaultdict(int)

    for gold, estimated in zip(list(gold_indices), list(estimated_indices)):
        if gold[0] not in estimated:
            bucket = 'first_missing'
        elif np.array_equal(gold, estimated):
            bucket = 'all_equal'
        else:
            bucket = 'mixed'
        counts[bucket] += 1

    template = '{:20s} First missing: {:4d}  |  All equal: {:4d}  |  Mixed: {:4d}'
    print(template.format(name, counts['first_missing'], counts['all_equal'], counts['mixed']))


def create_and_test_gold(d, k, embeds, queries):
    """Build an exact (Flat) index on GPU 0 and return gold top-k results.

    The brute-force Flat index is the ground-truth baseline the approximate
    indexes are compared against. Also prints its timing row and a divider.

    Args:
        d: vector dimensionality.
        k: neighbours retrieved per query.
        embeds: database vectors to add.
        queries: query vectors to search with.

    Returns:
        tuple: (distances, indices) from the exact search.
    """
    times = [time.time()]
    res = faiss.StandardGpuResources()
    gold_idx = index_cpu_to_gpu(res, 0, index_factory(d, 'Flat'))
    gold_idx.add(embeds)
    times.append(time.time())
    create_and_add = latest(times)

    distances, indices = gold_idx.search(queries, k)
    times.append(time.time())
    print_timing_stats('Flat', create_and_add, latest(times))
    print('-' * 100)
    return distances, indices


def test_pca(d, k, embeds, queries, pca_dim):
    """Benchmark PCA pre-transform variants against the flat baseline.

    Args:
        d: input vector dimensionality.
        k: neighbours retrieved per query.
        embeds: database vectors to index.
        queries: query vectors to search with.
        pca_dim: output dimensionality of the PCA transform.
    """
    distances, indices = create_and_test_gold(d, k, embeds, queries)

    times = [time.time()]
    all_pca_indices = []
    for s in PCAS:
        # e.g. 'PCAW96,Flat' — a PCA pre-transform followed by exact search.
        pca_idx = index_factory(d, s + "{},Flat".format(pca_dim), faiss.METRIC_INNER_PRODUCT)
        pca_idx.train(embeds)
        pca_idx.add(embeds)
        times.append(time.time())
        create_and_add = latest(times)

        pca_distances, pca_indices = pca_idx.search(queries, k)
        all_pca_indices.append(pca_indices)
        times.append(time.time())
        print_timing_stats(s, create_and_add, latest(times))

    print('\n')
    # Compare each variant's retrieved ids against the exact-search ids.
    for s, pca_indices in zip(PCAS, all_pca_indices):
        print_accuracy_stats(s, indices, pca_indices)


def test_quantizers(d, k, embeds, queries):
    """Benchmark coarse-quantizer factory strings for timing.

    Only timing rows are printed here (no accuracy comparison); the gold
    search still runs so the exact-index baseline prints alongside.

    Args:
        d: vector dimensionality.
        k: neighbours retrieved per query.
        embeds: database vectors to index.
        queries: query vectors to search with.
    """
    distances, indices = create_and_test_gold(d, k, embeds, queries)

    times = [time.time()]
    for s in QUANTIZERS:
        # HNSW factory strings are used as-is; the others get a 'Flat,' prefix.
        if 'HNSW' in s:
            quant_idx = index_factory(d, s, faiss.METRIC_INNER_PRODUCT)
        else:
            quant_idx = index_factory(d, "Flat," + s, faiss.METRIC_INNER_PRODUCT)

        quant_idx.train(embeds)
        quant_idx.add(embeds)
        times.append(time.time())
        create_and_add = latest(times)

        quant_distances, quant_indices = quant_idx.search(queries, k)
        times.append(time.time())
        print_timing_stats(s, create_and_add, latest(times))


def test_encodings(d, k, embeds, queries):
    """Benchmark vector-encoding factory strings and compare accuracy to the baseline.

    Args:
        d: vector dimensionality.
        k: neighbours retrieved per query.
        embeds: database vectors to index.
        queries: query vectors to search with.
    """
    distances, indices = create_and_test_gold(d, k, embeds, queries)

    times = [time.time()]
    all_encode_indices = []
    for s in ENCODINGS:
        encode_idx = index_factory(d, s, faiss.METRIC_INNER_PRODUCT)
        encode_idx.train(embeds)
        encode_idx.add(embeds)
        times.append(time.time())
        create_and_add = latest(times)

        _, encode_indices = encode_idx.search(queries, k)
        all_encode_indices.append(encode_indices)
        times.append(time.time())
        print_timing_stats(s, create_and_add, latest(times))

    print('\n')
    # Accuracy is reported after all timing runs so the tables stay grouped.
    for s, encode_indices in zip(ENCODINGS, all_encode_indices):
        print_accuracy_stats(s, indices, encode_indices)


def run_all_tests():
    """Generate synthetic data from the measured distribution and run all suites.

    Draws 1e6 database vectors and 256 queries, then exercises the PCA,
    quantizer, and encoding benchmarks.
    """
    embed_mean, embed_cov = get_embed_mean_and_cov()
    embeds, queries = get_embeds_and_queries(embed_mean, embed_cov, int(1e6), 256)
    # NOTE(review): d=128 assumes the pickled stats are 128-dimensional — confirm.
    d, k = 128, 10
    test_pca(d, k, embeds, queries, 96)
    test_quantizers(d, k, embeds, queries)
    test_encodings(d, k, embeds, queries)




# Entry point: run the full MIPS index benchmark suite when executed directly.
if __name__ == "__main__":
    run_all_tests()

