faiss_gpu.py 3.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
"""
This file reuses the implementation from https://github.com/yl-1993/learn-to-cluster
"""
import os
import gc
import numpy as np
from tqdm import tqdm

import faiss

__all__ = ['faiss_search_approx_knn']

class faiss_index_wrapper():
    """Build a faiss index over ``target`` and expose its ``search`` method.

    The index is created on the CPU with ``faiss.index_factory``, placed on
    the GPUs according to ``mode``, wrapped in ``faiss.IndexIDMap``, trained
    if it reports ``is_trained`` as false, and finally filled with every row
    of ``target`` using the row number as its id.

    Args:
        target: 2-D array; ``target.shape`` is unpacked as ``(size, dim)``.
            Presumably float32 as faiss expects -- not checked here.
        nprobe: value assigned to ``cpu_index.nprobe``.
        index_factory_str: faiss factory string. When ``None`` it defaults
            to ``"IVF{nlist},PQ32"`` with
            ``nlist = min(8192, 16 * round(sqrt(size)))``.
        verbose: value assigned to the ``verbose`` attribute of the index.
        mode: ``'proxy'`` clones the CPU index onto every GPU and registers
            the clones in a ``faiss.IndexProxy``; ``'shard'`` calls
            ``faiss.index_cpu_to_all_gpus`` with ``co.shard = True``.
        using_gpu: only consulted in ``'proxy'`` mode. When false, the CPU
            index is used directly and no GPU resources are created.

    Raises:
        AssertionError: if ``target`` has no rows.
        KeyError: if ``mode`` is neither ``'proxy'`` nor ``'shard'``.
    """

    def __init__(self,
                 target,
                 nprobe=128,
                 index_factory_str=None,
                 verbose=False,
                 mode='proxy',
                 using_gpu=True):
        # Keeps the per-GPU resource objects referenced for as long as the
        # wrapper lives.
        self._res_list = []

        num_gpu = faiss.get_num_gpus()
        print('[faiss gpu] #GPU: {}'.format(num_gpu))

        size, dim = target.shape
        assert size > 0, "size: {}".format(size)
        if index_factory_str is None:
            # int() guarantees an integer nlist in the factory string even
            # if round() returns a float.
            default_nlist = min(8192, 16 * int(round(np.sqrt(size))))
            index_factory_str = "IVF{},PQ{}".format(default_nlist, 32)
        cpu_index = faiss.index_factory(dim, index_factory_str)
        cpu_index.nprobe = nprobe

        if mode == 'proxy':
            if using_gpu:
                co = faiss.GpuClonerOptions()
                co.useFloat16 = True
                co.usePrecomputed = False

                index = faiss.IndexProxy()
                for i in range(num_gpu):
                    res = faiss.StandardGpuResources()
                    self._res_list.append(res)
                    index.addIndex(
                        faiss.index_cpu_to_gpu(res, i, cpu_index, co))
            else:
                # Use the CPU index as-is. Registering the very same object
                # in the proxy once per GPU would make every add hit it
                # several times, and would leave the proxy empty on a
                # machine without GPUs.
                index = cpu_index
        elif mode == 'shard':
            co = faiss.GpuMultipleClonerOptions()
            co.useFloat16 = True
            co.usePrecomputed = False
            co.shard = True
            index = faiss.index_cpu_to_all_gpus(cpu_index,
                                                co,
                                                ngpu=num_gpu)
        else:
            raise KeyError("Unknown index mode")

        index = faiss.IndexIDMap(index)
        index.verbose = verbose

        # training
        if not index.is_trained:
            # nlist decides how many samples are used for training:
            # 256 vectors per IVF list, drawn with replacement.
            nlist = self._parse_nlist(index_factory_str)
            if nlist is None:
                # No usable IVF component in the factory string: there is
                # no list count to size a sample from, so train on all rows.
                train_set = target
            else:
                indexes_sample_for_train = np.random.randint(
                    0, size, nlist * 256)
                train_set = target[indexes_sample_for_train]
            index.train(train_set)

        # add with ids (explicit int64 so the dtype does not depend on the
        # platform's default integer)
        target_ids = np.arange(0, size, dtype=np.int64)
        index.add_with_ids(target, target_ids)
        self.index = index

    @staticmethod
    def _parse_nlist(index_factory_str):
        """Return the list count of the first ``IVF`` item, or ``None``.

        ``"IVF128,PQ32"`` and ``"IVF128.0,PQ32"`` both give ``128``. A
        factory string without an ``IVF`` item, or whose ``IVF`` item is not
        a plain number (e.g. ``"IVF4096_HNSW32"``), gives ``None``.
        """
        for item in index_factory_str.split(","):
            if 'IVF' in item:
                try:
                    return int(float(item.replace("IVF", "")))
                except ValueError:
                    return None
        return None

    def search(self, *args, **kargs):
        """Forward all arguments to ``self.index.search`` and return its result."""
        return self.index.search(*args, **kargs)

    def __del__(self):
        # __init__ may have raised before ``index`` (or even ``_res_list``)
        # was assigned, so every attribute is looked up defensively.
        index = getattr(self, 'index', None)
        if index is not None:
            index.reset()
            del self.index
        del index
        # Drop the resource objects only after the index reference is gone.
        resources = getattr(self, '_res_list', None)
        if resources:
            del resources[:]


def batch_search(index, query, k, bs, verbose=False):
    """Run ``index.search`` over ``query`` in consecutive chunks of ``bs`` rows.

    Args:
        index: object exposing ``search(chunk, k)`` that returns a
            ``(distances, neighbours)`` pair for the chunk.
        query: sliceable sequence of query rows; ``len(query)`` rows are
            searched.
        k: number of neighbours requested per query row.
        bs: chunk size, used as the step of the chunk range.
        verbose: show the tqdm progress bar when true.

    Returns:
        ``(dists, nbrs)``: a float32 and an int64 array, both of shape
        ``(len(query), k)``, holding the per-chunk results stacked in order.
    """
    total = len(query)
    out_shape = (total, k)
    dists = np.zeros(out_shape, dtype=np.float32)
    nbrs = np.zeros(out_shape, dtype=np.int64)

    chunk_starts = tqdm(range(0, total, bs),
                        desc="faiss searching...",
                        disable=not verbose)
    for start in chunk_starts:
        # The last chunk may be shorter than bs.
        stop = min(total, start + bs)
        chunk_dists, chunk_nbrs = index.search(query[start:stop], k)
        dists[start:stop] = chunk_dists
        nbrs[start:stop] = chunk_nbrs
    return dists, nbrs


def faiss_search_approx_knn(query,
                            target,
                            k,
                            nprobe=128,
                            bs=int(1e6),
                            index_factory_str=None,
                            verbose=False):
    """Index ``target`` with faiss and return the ``k`` nearest neighbours of ``query``.

    A ``faiss_index_wrapper`` is built over ``target``, ``query`` is searched
    against it in chunks of ``bs`` rows through ``batch_search``, and the
    wrapper is released again before returning.

    Args:
        query: rows to search for.
        target: rows to index.
        k: number of neighbours per query row.
        nprobe: passed through to ``faiss_index_wrapper``.
        bs: chunk size passed to ``batch_search``.
        index_factory_str: passed through to ``faiss_index_wrapper``.
        verbose: passed to both the wrapper and ``batch_search``.

    Returns:
        The ``(dists, nbrs)`` pair produced by ``batch_search``.
    """
    wrapper_options = dict(nprobe=nprobe,
                           index_factory_str=index_factory_str,
                           verbose=verbose)
    wrapper = faiss_index_wrapper(target, **wrapper_options)
    knn_dists, knn_ids = batch_search(wrapper,
                                      query,
                                      k=k,
                                      bs=bs,
                                      verbose=verbose)

    # Drop the only reference to the wrapper and force a collection so its
    # cleanup runs before the results are handed back.
    del wrapper
    gc.collect()
    return knn_dists, knn_ids