from collections import defaultdict
import time
import pickle

import faiss
from faiss import index_factory, index_cpu_to_gpu
import numpy as np

from megatron import get_args


PCAS = [
    'PCA', 'PCAR', 'PCAW', 'PCAWR'
]

# PCA to 64 dim gets "first missing" ~ 95% and "mixed" ~ 5% for all variants
# However, this test is hard: the embeds and queries here are totally random, so it would be better to test against a realistic distribution
# update: Using a realistic mean and covariance helps, but then adjusting for inner product makes it unusable again
# CONCLUSION: PCA should not be used for MIPS


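# Not part of the original experiments: a minimal sketch of why truncating
# principal components breaks MIPS. The dropped directions still contribute to
# the inner product q.x, so the exact top-1 need not survive the projection.
# Everything below (names, sizes) is illustrative only.
def _pca_truncation_changes_mips(d=8, pca_dim=2, n=1000, seed=0):
    rng = np.random.RandomState(seed)
    embeds = rng.randn(n, d).astype('float32')
    query = rng.randn(d).astype('float32')
    # top `pca_dim` principal directions of the (centered) database
    _, _, vt = np.linalg.svd(embeds - embeds.mean(axis=0), full_matrices=False)
    proj = vt[:pca_dim]
    exact_top = int(np.argmax(embeds @ query))
    reduced_top = int(np.argmax((embeds @ proj.T) @ (proj @ query)))
    return exact_top, reduced_top  # these typically disagree when pca_dim << d

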
QUANTIZERS = [
    'IVF4096_SQ16', # 'IMI2x9',
    'HNSW32_SQ16', # 'IVF4096_HNSW32'
]

# IMI2x9 (or any other multi-index) doesn't support inner product, so it's unusable here
# IVF4096_HNSW32 doesn't support inner product either



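# Not from the original runs: for IVF-style quantizers like 'IVF4096_SQ16' above,
# recall depends heavily on nprobe (how many inverted lists are probed per query).
# A minimal sketch, assuming `index` has an IVF layer such as the CPU indexes
# built in test_quantizers() below:
def _set_nprobe(index, nprobe=64):
    ivf = faiss.extract_index_ivf(index)  # raises if there is no IVF layer
    ivf.nprobe = nprobe  # default is 1; larger values trade speed for recall
    return index

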
ENCODINGS = [
    'Flat',
    'PQ16np', # PQ16, PQ16x12(np)
    'SQ4', 'SQ8', 'SQ6', 'SQfp16',
  # 'LSH', 'LSHrt', 'LSHr', 'LSHt'
]

# PQ16 is pretty slow for creating and adding - ~96s for 1e5 embeds, 105s for 1e6 embeds
# PQ16np is a bit faster but pretty inaccurate - misses the top-1 result 2/3 of the time (1e6 embeds)
# PQ16x12(np) gets really slow. Uses 4096 centroids.

# SQfp16 is solid.

# LSH is inaccurate - it pretty much always misses the top-1 result (1e6 embeds)


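# Not part of the original script: based on the notes above, SQfp16 with inner
# product looks like the practical pick here. A minimal sketch mirroring what
# test_encodings() below builds for the 'SQfp16' entry:
def _build_sqfp16_index(d, embeds):
    index = index_factory(d, 'SQfp16', faiss.METRIC_INNER_PRODUCT)
    index.train(embeds)  # range training matters for SQ4/SQ6/SQ8; harmless for fp16
    index.add(embeds)
    return index

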
def latest(times):
    return times[-1] - times[-2]


def get_embed_mean_and_cov():
    # Mean and whitening matrix precomputed from real embeddings.
    with open('/home/dcg-adlr-nkant-data.cosmos1202/hash_data/normed4096_whitened.pkl', 'rb') as f:
        embed_data = pickle.load(f)
    embed_mean = embed_data['embed_mean']
    whitener = embed_data['embed_whitener']
    # Build a covariance matrix from the stored whitener so that
    # get_embeds_and_queries() can sample realistic Gaussian data.
    embed_cov = whitener.dot(whitener.transpose())

    return embed_mean, embed_cov


def get_embeds_and_queries(mean, cov, num_embeds, num_queries):
    embeds = np.random.multivariate_normal(mean, cov, num_embeds).astype('float32')
    queries = np.random.multivariate_normal(mean, cov, num_queries).astype('float32')
    return embeds, queries


def get_random_embeds_and_queries(d, num_embeds, num_queries):
    embeds = np.random.rand(num_embeds, d).astype('float32')
    queries = np.random.rand(num_queries, d).astype('float32')
    return embeds, queries



def print_timing_stats(name, create_and_add, search):
    print('{:20s} Create and add embeds: {:10.4f}s  |  Search embeds: {:10.4f}s'.format(name, create_and_add, search))


def print_accuracy_stats(name, gold_indices, estimated_indices):
    # Buckets: 'first_missing' = gold top-1 absent from the approximate results,
    # 'all_equal' = full top-k matches gold, 'mixed' = top-1 present but the lists differ.
    gold_indices, estimated_indices = list(gold_indices), list(estimated_indices)
    results = defaultdict(int)

    for gold, estimated in zip(gold_indices, estimated_indices):
        if gold[0] not in estimated:
            results['first_missing'] += 1
        elif np.array_equal(gold, estimated):
            results['all_equal'] += 1
        else:
            results['mixed'] += 1
    result_strs = ['first_missing', 'all_equal', 'mixed']
    print('{:20s} First missing: {:4d}  |  All equal: {:4d}  |  Mixed: {:4d}'.format(name, *[results[s] for s in result_strs]))


def create_and_test_gold(d, k, embeds, queries):
    times = [time.time()]
    # Exact brute-force search on GPU serves as the ground truth (note: the factory
    # default metric here is L2; the approximate indexes use inner product).
    res = faiss.StandardGpuResources()
    gold_idx = index_cpu_to_gpu(res, 0, index_factory(d, 'Flat'))
    gold_idx.add(embeds)
    times.append(time.time())
    create_and_add = latest(times)

    distances, indices = gold_idx.search(queries, k)
    times.append(time.time())
    print_timing_stats('Flat', create_and_add, latest(times))
    print('-' * 100)
    return distances, indices


def test_pca(d, k, embeds, queries, pca_dim):

    distances, indices = create_and_test_gold(d, k, embeds, queries)

    times = [time.time()]
    all_pca_indices = []
    for s in PCAS:
        pca_idx = index_factory(d, s + "{},Flat".format(pca_dim), faiss.METRIC_INNER_PRODUCT)
        pca_idx.train(embeds)
        pca_idx.add(embeds)
        times.append(time.time())
        create_and_add = latest(times)

        pca_distances, pca_indices = pca_idx.search(queries, k)
        all_pca_indices.append(pca_indices)
        times.append(time.time())
        print_timing_stats(s, create_and_add, latest(times))

    print('\n')
    for s, pca_indices in zip(PCAS, all_pca_indices):
        print_accuracy_stats(s, indices, pca_indices)


def test_quantizers(d, k, embeds, queries):

    distances, indices = create_and_test_gold(d, k, embeds, queries)

    times = [time.time()]
    for s in QUANTIZERS:
        if 'HNSW' in s:
            quant_idx = index_factory(d, s, faiss.METRIC_INNER_PRODUCT)
        else:
            quant_idx = index_factory(d, "Flat," + s, faiss.METRIC_INNER_PRODUCT)

        quant_idx.train(embeds)
        quant_idx.add(embeds)
        times.append(time.time())
        create_and_add = latest(times)

        quant_distances, quant_indices = quant_idx.search(queries, k)
        times.append(time.time())
        print_timing_stats(s, create_and_add, latest(times))


def test_encodings(d, k, embeds, queries):

    distances, indices = create_and_test_gold(d, k, embeds, queries)

    times = [time.time()]
    all_encode_indices = []
    for s in ENCODINGS:
        encode_idx = index_factory(d, s, faiss.METRIC_INNER_PRODUCT)

        encode_idx.train(embeds)
        encode_idx.add(embeds)
        times.append(time.time())
        create_and_add = latest(times)

        _, encode_indices = encode_idx.search(queries, k)
        all_encode_indices.append(encode_indices)
        times.append(time.time())
        print_timing_stats(s, create_and_add, latest(times))

    print('\n')
    for s, encode_indices in zip(ENCODINGS, all_encode_indices):
        print_accuracy_stats(s, indices, encode_indices)


def run_all_tests():
    mean, cov = get_embed_mean_and_cov()
    embeds, queries = get_embeds_and_queries(mean, cov, int(1e6), 256)
    d = 128
    k = 10
    test_pca(d, k, embeds, queries, 96)
    test_quantizers(d, k, embeds, queries)
    test_encodings(d, k, embeds, queries)


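# Not part of the original entry point: get_random_embeds_and_queries() above is
# unused by run_all_tests(). A hedged sketch of the uniform-random variant
# mentioned in the PCA notes near the top of this file:
def run_all_tests_random(d=128, k=10):
    embeds, queries = get_random_embeds_and_queries(d, int(1e6), 256)
    test_pca(d, k, embeds, queries, 96)
    test_quantizers(d, k, embeds, queries)
    test_encodings(d, k, embeds, queries)

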
if __name__ == "__main__":
    run_all_tests()

