# faiss_test.py
# NOTE(review): the original upload carried web-viewer chrome above this
# point (file size, blame/avatar lines, and gutter line numbers); it was
# not Python source and has been replaced by this comment header.
from collections import defaultdict
import time

import faiss
from faiss import index_factory
import numpy as np

from megatron import get_args


# FAISS index_factory() spec fragments benchmarked below. Each test_*
# function combines these with a 'Flat' component and compares results
# against an exact (brute-force) baseline.

# Vector-transform prefixes tried by test_pca as '<spec><dim>,Flat'.
# Suffixes: R = random rotation, W = whitening (per FAISS factory syntax).
PCAS = [
    'PCA', 'PCAR', 'PCAW', 'PCAWR'
]

# PCA to 64 dim gets "first missing" ~ 95% and "mixed" ~ 5% for all
# however, this is pretty hard since the embeds and queries are totally random, would be better to test according to a distribution


# Coarse-quantizer / graph layouts tried by test_quantizers. IVF = inverted
# file, IMI = inverted multi-index, HNSW = graph-based search.
QUANTIZERS = [
    'IVF4096', 'IMI2x9',
    'HNSW32', 'IVF4096_HNSW32'
]


# Vector-encoding specs tried by test_encodings (each is a complete factory
# string on its own). PQ = product quantizer, SQ = scalar quantizer.
ENCODINGS = [
    'Flat',
    'PQ16np', # PQ16, PQ16x12(np)
    'SQ4', 'SQ8', 'SQ6', 'SQfp16',
  # 'LSH', 'LSHrt', 'LSHr', 'LSHt'
]

# PQ16 is pretty slow for creating and adding - ~96s for 1e5, 105s for 1e6
# PQ16np is a bit faster but is pretty inaccurate - misses top-1 result 2/3 of time (1e6 embeds)
# PQ16x12(np) gets real slow. Uses 4096 centroids.

# SQfp16 is solid.

# LSH is inaccurate - pretty much always missing the top-1 result (1e6 embeds)


def latest(times):
    """Return the seconds elapsed between the two most recent timestamps."""
    previous, newest = times[-2], times[-1]
    return newest - previous


def get_embeds_and_queries(d, num_embeds, num_queries):
    """Draw uniform-random float32 matrices of width ``d``: one of
    ``num_embeds`` database vectors and one of ``num_queries`` queries."""
    def _random_matrix(rows):
        return np.random.rand(rows, d).astype('float32')

    return _random_matrix(num_embeds), _random_matrix(num_queries)


def print_timing_stats(name, create_and_add, search):
    """Print one aligned line of build-and-add vs. search timings (seconds)."""
    template = '{:20s} Create and add embeds: {:10.4f}s  |  Search embeds: {:10.4f}s'
    print(template.format(name, create_and_add, search))


def print_accuracy_stats(name, gold_indices, estimated_indices):
    """Score estimated neighbor lists against the gold (exact) lists.

    Per query: 'first_missing' when the gold top-1 id is absent from the
    estimate, 'all_equal' when both lists match exactly, 'mixed' otherwise.
    Prints one aligned summary line.
    """
    tally = defaultdict(int)
    for gold, estimated in zip(list(gold_indices), list(estimated_indices)):
        if gold[0] not in estimated:
            bucket = 'first_missing'
        elif np.array_equal(gold, estimated):
            bucket = 'all_equal'
        else:
            bucket = 'mixed'
        tally[bucket] += 1

    counts = (tally['first_missing'], tally['all_equal'], tally['mixed'])
    print('{:20s} First missing: {:4d}  |  All equal: {:4d}  |  Mixed: {:4d}'.format(name, *counts))


def create_and_test_gold(d, k, embeds, queries):
    """Build the exact ('Flat') baseline index, time build and search, and
    return the gold (distances, indices) that approximate indexes are
    scored against."""
    t_start = time.time()
    gold_idx = index_factory(d, 'Flat')
    gold_idx.add(embeds)
    t_built = time.time()

    distances, indices = gold_idx.search(queries, k)
    t_searched = time.time()

    print_timing_stats('Flat', t_built - t_start, t_searched - t_built)
    print('-' * 100)
    return distances, indices


def test_pca(d, k, num_embeds, num_queries, pca_dim):
    """Benchmark every PCA transform variant (reduced to ``pca_dim`` dims,
    'Flat' encoding) against the exact baseline, then report accuracy."""
    embeds, queries = get_embeds_and_queries(d, num_embeds, num_queries)
    _, gold_indices = create_and_test_gold(d, k, embeds, queries)

    timestamps = [time.time()]
    results = []
    for spec in PCAS:
        pca_idx = index_factory(d, spec + "{},Flat".format(pca_dim))
        pca_idx.train(embeds)
        pca_idx.add(embeds)
        timestamps.append(time.time())
        build_time = latest(timestamps)

        _, found = pca_idx.search(queries, k)
        results.append(found)
        timestamps.append(time.time())
        print_timing_stats(spec, build_time, latest(timestamps))

    print('\n')
    for spec, found in zip(PCAS, results):
        print_accuracy_stats(spec, gold_indices, found)


def test_quantizers(d, k, num_embeds, num_queries):
    """Benchmark coarse-quantizer/graph index layouts against the exact
    baseline, then report accuracy vs. the gold results.

    Fixes: the factory string was previously built as 'Flat,' + s (e.g.
    'Flat,IVF4096'), but a FAISS index_factory spec puts the coarse
    quantizer FIRST and the encoding second ('IVF4096,Flat'). Standalone
    HNSW specs ('HNSW32') are already complete factory strings and get no
    suffix. Also added the accuracy report, for consistency with test_pca
    and test_encodings.
    """
    embeds, queries = get_embeds_and_queries(d, num_embeds, num_queries)
    distances, indices = create_and_test_gold(d, k, embeds, queries)

    times = [time.time()]
    all_quant_indices = []
    for s in QUANTIZERS:
        if 'HNSW' in s and '_' not in s:
            factory_str = s
        else:
            # Coarse quantizer first, then the 'Flat' encoding.
            factory_str = s + ",Flat"
        quant_idx = index_factory(d, factory_str)

        quant_idx.train(embeds)
        quant_idx.add(embeds)
        times.append(time.time())
        create_and_add = latest(times)

        quant_distances, quant_indices = quant_idx.search(queries, k)
        all_quant_indices.append(quant_indices)
        times.append(time.time())
        print_timing_stats(s, create_and_add, latest(times))

    # Accuracy vs. gold, matching the sibling test_* helpers.
    print('\n')
    for s, quant_indices in zip(QUANTIZERS, all_quant_indices):
        print_accuracy_stats(s, indices, quant_indices)


def test_encodings(d, k, num_embeds, num_queries):
    """Benchmark each vector-encoding spec against the exact baseline,
    then report accuracy vs. the gold results."""
    embeds, queries = get_embeds_and_queries(d, num_embeds, num_queries)
    _, gold_indices = create_and_test_gold(d, k, embeds, queries)

    timestamps = [time.time()]
    results = []
    for spec in ENCODINGS:
        encode_idx = index_factory(d, spec)
        encode_idx.train(embeds)
        encode_idx.add(embeds)
        timestamps.append(time.time())
        build_time = latest(timestamps)

        _, found = encode_idx.search(queries, k)
        results.append(found)
        timestamps.append(time.time())
        print_timing_stats(spec, build_time, latest(timestamps))

    print('\n')
    for spec, found in zip(ENCODINGS, results):
        print_accuracy_stats(spec, gold_indices, found)