combined_index.py 6.22 KB
Newer Older
huchen's avatar
huchen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import faiss
import numpy as np


class CombinedIndex:
    """
    Searchable index assembled from several inverted-list index files.

    Every file in ``invlist_fnames`` is loaded with ``faiss.read_index`` and
    its inverted lists are combined into one ``faiss.VStackInvertedLists``.
    When ``masked_index_fname`` is given, the inverted lists of that index
    are layered over the stack with ``faiss.MaskedInvertedLists``.
    The resulting inverted lists are then plugged into the empty index read
    from ``empty_index_fname``, which contains the info on how to perform
    searches.

    Attributes set by the constructor:
        indexes: the indexes read from ``invlist_fnames``, in order
        big_il: the inverted lists installed in the searchable index
        big_il_base: the un-masked stack (only set when a mask is used)
        masked_index: the masking index (only set when a mask is used)
        index: the searchable index
    """

    def __init__(self, invlist_fnames, empty_index_fname,
                 masked_index_fname=None):
        """
        Load all parts and assemble the searchable index.

        Args:
            invlist_fnames: iterable of paths of the indexes whose inverted
                lists are stacked.
            empty_index_fname: path of the index that receives the stacked
                inverted lists.
            masked_index_fname: optional path of an index whose inverted
                lists are combined with the stack through
                ``faiss.MaskedInvertedLists``.

        Raises:
            AssertionError: if any file of ``invlist_fnames`` does not
                exist. This is checked for all files before anything is
                loaded, and the message names the missing files.
        """
        # materialize: the names are traversed twice (check, then load)
        invlist_fnames = list(invlist_fnames)
        missing = [
            fname for fname in invlist_fnames
            if not os.path.exists(fname)
        ]
        if missing:
            # same exception type as before, but fail before loading
            # anything and say which files are absent
            raise AssertionError(
                'inverted list file(s) not found: %s' % ', '.join(missing))

        self.indexes = indexes = []
        ilv = faiss.InvertedListsPtrVector()

        for fname in invlist_fnames:
            print('reading', fname, end='\r', flush=True)
            index = faiss.read_index(fname)
            indexes.append(index)
            ilv.push_back(faiss.extract_index_ivf(index).invlists)
        print()

        self.big_il = faiss.VStackInvertedLists(ilv.size(), ilv.data())
        if masked_index_fname:
            self.big_il_base = self.big_il
            print('loading', masked_index_fname)
            self.masked_index = faiss.read_index(
                masked_index_fname,
                faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY)
            # masked lists first, plain stack second
            self.big_il = faiss.MaskedInvertedLists(
                faiss.extract_index_ivf(self.masked_index).invlists,
                self.big_il_base)

        print('loading empty index', empty_index_fname)
        self.index = faiss.read_index(empty_index_fname)
        ntotal = self.big_il.compute_ntotal()

        print('replace invlists')
        index_ivf = faiss.extract_index_ivf(self.index)
        # NOTE(review): the False argument is presumably faiss's `own` flag
        # (the index does not take ownership; self.big_il keeps the
        # reference) -- confirm against the faiss API
        index_ivf.replace_invlists(self.big_il, False)
        index_ivf.ntotal = self.index.ntotal = ntotal
        index_ivf.parallel_mode = 1   # seems reasonable to do this all the time

        # the code below relies on the quantizer exposing an `hnsw` field
        quantizer = faiss.downcast_index(index_ivf.quantizer)
        quantizer.hnsw.efSearch = 1024

    ############################################################
    # Expose fields and functions of the index as methods so that they
    # can be called by RPC

    def search(self, x, k):
        """Forward to ``self.index.search(x, k)`` and return its result."""
        return self.index.search(x, k)

    def range_search(self, x, radius):
        """Forward to ``self.index.range_search(x, radius)``."""
        return self.index.range_search(x, radius)

    def transform_and_assign(self, xq):
        """
        Apply the pre-transform (if any) and run the coarse quantizer.

        Args:
            xq: query vectors.

        Returns:
            ``(xq, list_nos, coarse_dis)``: the (possibly transformed)
            queries and the second and first outputs of the quantizer
            search over ``nprobe`` results.

        Raises:
            AssertionError: if the index is an ``IndexPreTransform`` whose
                chain does not hold exactly one transform.
        """
        index = self.index

        if isinstance(index, faiss.IndexPreTransform):
            chain_size = index.chain.size()
            if chain_size != 1:
                # explicit raise: must not be skipped under `python -O`,
                # otherwise only the first transform would be applied
                raise AssertionError(
                    'expected exactly 1 pre-transform, found %d' % chain_size)
            vt = index.chain.at(0)
            xq = vt.apply_py(xq)

        # perform quantization
        index_ivf = faiss.extract_index_ivf(index)
        quantizer = index_ivf.quantizer
        coarse_dis, list_nos = quantizer.search(xq, index_ivf.nprobe)
        return xq, list_nos, coarse_dis

    @staticmethod
    def _check_preassigned(index_ivf, xq, list_nos, coarse_dis):
        """
        Validate the inputs of the *_preassigned searches; return n.

        The checks are explicit raises rather than ``assert`` statements
        because the arrays are handed to faiss as raw pointers right after:
        the validation has to stay active under ``python -O``.
        """
        n, d = xq.shape
        if d != index_ivf.d:
            raise AssertionError(
                'query dimension %d does not match index dimension %d'
                % (d, index_ivf.d))
        n2, d2 = list_nos.shape
        if list_nos.shape != coarse_dis.shape:
            raise AssertionError(
                'list_nos shape %s differs from coarse_dis shape %s'
                % (list_nos.shape, coarse_dis.shape))
        if n2 != n:
            raise AssertionError(
                'got assignments for %d queries, expected %d' % (n2, n))
        if d2 != index_ivf.nprobe:
            raise AssertionError(
                'got %d assignments per query, index nprobe is %d'
                % (d2, index_ivf.nprobe))
        return n

    def ivf_search_preassigned(self, xq, list_nos, coarse_dis, k):
        """
        k-nearest-neighbour search with precomputed coarse assignments.

        Args:
            xq: queries, 2-D with ``index_ivf.d`` columns.
            list_nos: 2-D array with one row per query and ``nprobe``
                columns.
            coarse_dis: array of the same shape as ``list_nos``.
            k: number of results per query.

        Returns:
            ``(D, I)``: float32 and int64 arrays of shape ``(n, k)``.

        Raises:
            AssertionError: if the shapes are inconsistent.
        """
        index_ivf = faiss.extract_index_ivf(self.index)
        n = self._check_preassigned(index_ivf, xq, list_nos, coarse_dis)
        D = np.empty((n, k), dtype='float32')
        I = np.empty((n, k), dtype='int64')
        # NOTE(review): the trailing False is presumably faiss's
        # `store_pairs` flag -- confirm against the faiss API
        index_ivf.search_preassigned(
            n, faiss.swig_ptr(xq), k,
            faiss.swig_ptr(list_nos), faiss.swig_ptr(coarse_dis),
            faiss.swig_ptr(D), faiss.swig_ptr(I), False)
        return D, I

    def ivf_range_search_preassigned(self, xq, list_nos, coarse_dis, radius):
        """
        Range search with precomputed coarse assignments.

        Args:
            xq: queries, 2-D with ``index_ivf.d`` columns.
            list_nos: 2-D array with one row per query and ``nprobe``
                columns.
            coarse_dis: array of the same shape as ``list_nos``.
            radius: search radius passed on to faiss.

        Returns:
            ``(lims, D, I)``: ``lims`` has ``n + 1`` entries; ``D`` and
            ``I`` each have ``lims[-1]`` entries.

        Raises:
            AssertionError: if the shapes are inconsistent.
        """
        index_ivf = faiss.extract_index_ivf(self.index)
        n = self._check_preassigned(index_ivf, xq, list_nos, coarse_dis)
        res = faiss.RangeSearchResult(n)

        index_ivf.range_search_preassigned(
            n, faiss.swig_ptr(xq), radius,
            faiss.swig_ptr(list_nos), faiss.swig_ptr(coarse_dis),
            res)

        # copies are taken, presumably so that the returned arrays do not
        # depend on the buffers held by `res`
        lims = faiss.rev_swig_ptr(res.lims, n + 1).copy()
        nd = int(lims[-1])
        D = faiss.rev_swig_ptr(res.distances, nd).copy()
        I = faiss.rev_swig_ptr(res.labels, nd).copy()
        return lims, D, I

    def set_nprobe(self, nprobe):
        """Set ``nprobe`` on the IVF index extracted from ``self.index``."""
        index_ivf = faiss.extract_index_ivf(self.index)
        index_ivf.nprobe = nprobe

    def set_parallel_mode(self, pm):
        """Set ``parallel_mode`` on the IVF index of ``self.index``."""
        index_ivf = faiss.extract_index_ivf(self.index)
        index_ivf.parallel_mode = pm

    def get_ntotal(self):
        """Return ``self.index.ntotal``."""
        return self.index.ntotal

    def set_prefetch_nthread(self, nt):
        """Set ``prefetch_nthread`` on the inverted lists of every slice."""
        for idx in self.indexes:
            il = faiss.downcast_InvertedLists(
                faiss.extract_index_ivf(idx).invlists)
            # NOTE(review): bare read kept from the original; presumably it
            # makes the call fail if the downcast lists have no such field
            # instead of silently creating a Python attribute -- confirm
            il.prefetch_nthread
            il.prefetch_nthread = nt

    def set_omp_num_threads(self, nt):
        """Forward to ``faiss.omp_set_num_threads(nt)``."""
        faiss.omp_set_num_threads(nt)

class CombinedIndexDeep1B(CombinedIndex):
    """
    CombinedIndex wired to a fixed set of files: one trained (empty)
    index plus 50 slice indexes, all under a hard-coded work directory.

    NOTE(review): the previous docstring mentioned a "photodna" index while
    the class name says Deep1B -- confirm which data these paths hold.
    """

    def __init__(self):
        # everything is located relative to this directory
        base = "/checkpoint/matthijs/ondisk_distributed/"

        # the 50 slices whose inverted lists get stacked
        slice_fnames = []
        for slice_no in range(50):
            slice_fnames.append(
                '%s/hslices/slice%d.faissindex' % (base, slice_no))

        # empty index with the proper quantizer
        trained_fname = base + 'trained.faissindex'

        # no index whose invlists override the big one is used here
        CombinedIndex.__init__(
            self, slice_fnames, trained_fname, None)


def ivecs_read(fname):
    """Read an .ivecs file into a 2-D int32 array.

    The file is parsed as a flat sequence of int32 values. The first value
    gives the dimension ``d``; the values are then arranged in rows of
    ``d + 1`` entries and the leading column is dropped (it holds ``d`` for
    the first row, and presumably for every other row as well).

    Args:
        fname: path of the file to read.

    Returns:
        A freshly allocated int32 array of shape ``(n, d)``. An empty file
        yields an array of shape ``(0, 0)``.

    Raises:
        ValueError: from the reshape, if the number of values in the file
            is not a multiple of ``d + 1``.
    """
    a = np.fromfile(fname, dtype='int32')
    if a.size == 0:
        # no header to read the dimension from: zero vectors
        return np.empty((0, 0), dtype='int32')
    # plain Python int so that d + 1 is not computed in int32
    d = int(a[0])
    return a.reshape(-1, d + 1)[:, 1:].copy()


def fvecs_read(fname):
    """Read a file with ivecs_read and return its payload viewed as float32."""
    as_int32 = ivecs_read(fname)
    # reinterpret the 4-byte values in place; no numeric conversion happens
    return as_int32.view('float32')


if __name__ == '__main__':
    # Small benchmark: load the combined index, then report recall of the
    # first result and the search time for several nprobe settings.
    import time

    ci = CombinedIndexDeep1B()
    print('loaded index of size ', ci.index.ntotal)

    deep1bdir = "/datasets01_101/simsearch/041218/deep1b/"

    xq = fvecs_read(deep1bdir + "deep1B_queries.fvecs")
    gt_fname = deep1bdir + "deep1B_groundtruth.ivecs"
    gt = ivecs_read(gt_fname)

    # neither of these depends on nprobe, so compute them once
    gt_top1 = gt[:, 0]
    nq = len(xq)

    for nprobe in 1, 10, 100, 1000:
        ci.set_nprobe(nprobe)
        # perf_counter is the clock intended for measuring short durations;
        # unlike time.time it is not affected by system clock adjustments
        t0 = time.perf_counter()
        D, I = ci.search(xq, 100)
        t1 = time.perf_counter()
        # fraction of queries whose first result is the ground-truth one
        recall_at_1 = (I[:, 0] == gt_top1).sum() / nq
        print('nprobe=%d 1-recall@1=%.4f t=%.2fs' % (
            nprobe, recall_at_1,
            t1 - t0
        ))