"""
This file re-uses implementation from https://github.com/yl-1993/learn-to-cluster
"""
import gc
import os

import faiss
import numpy as np
from tqdm import tqdm

__all__ = ["faiss_search_approx_knn"]


class faiss_index_wrapper:
    """Approximate nearest-neighbor index over a fixed target set, built with faiss.

    Builds an IVF/PQ index over ``target``, optionally clones it onto all
    visible GPUs (replicated in ``"proxy"`` mode or sharded in ``"shard"``
    mode), trains it on a random sample of ``target``, and adds every row
    of ``target`` with its row index as its id.
    """

    def __init__(
        self,
        target,
        nprobe=128,
        index_factory_str=None,
        verbose=False,
        mode="proxy",
        using_gpu=True,
    ):
        """
        Args:
            target: (size, dim) array to index; row ``i`` is stored with id ``i``.
                Assumed float32 as required by faiss — TODO confirm at call site.
            nprobe: number of IVF cells probed per query (recall/speed knob).
            index_factory_str: faiss factory string. Defaults to
                ``"IVF<nlist>,PQ32"`` with ``nlist = min(8192, 16*round(sqrt(size)))``.
            verbose: forwarded to the faiss index's ``verbose`` flag.
            mode: ``"proxy"`` (replicate index on each GPU) or ``"shard"``
                (split the index across GPUs); anything else raises KeyError.
            using_gpu: only honored in ``"proxy"`` mode; when False the CPU
                index is added instead of a GPU clone.

        Raises:
            KeyError: if ``mode`` is neither ``"proxy"`` nor ``"shard"``.
        """
        # GPU resources must outlive the index; keep them on the instance.
        self._res_list = []

        num_gpu = faiss.get_num_gpus()
        print("[faiss gpu] #GPU: {}".format(num_gpu))

        size, dim = target.shape
        assert size > 0, "size: {}".format(size)

        # Default layout: IVF cell count scales with sqrt(size), capped at 8192;
        # PQ with 32 sub-quantizers.
        index_factory_str = (
            "IVF{},PQ{}".format(min(8192, 16 * round(np.sqrt(size))), 32)
            if index_factory_str is None
            else index_factory_str
        )

        cpu_index = faiss.index_factory(dim, index_factory_str)
        # Set nprobe before cloning so the GPU copies inherit it.
        cpu_index.nprobe = nprobe

        if mode == "proxy":
            co = faiss.GpuClonerOptions()
            co.useFloat16 = True
            co.usePrecomputed = False

            # One replica per GPU; queries are dispatched across replicas.
            index = faiss.IndexProxy()
            for i in range(num_gpu):
                res = faiss.StandardGpuResources()
                self._res_list.append(res)
                sub_index = (
                    faiss.index_cpu_to_gpu(res, i, cpu_index, co)
                    if using_gpu
                    else cpu_index
                )
                index.addIndex(sub_index)
        elif mode == "shard":
            co = faiss.GpuMultipleClonerOptions()
            co.useFloat16 = True
            co.usePrecomputed = False
            co.shard = True
            index = faiss.index_cpu_to_all_gpus(cpu_index, co, ngpu=num_gpu)
        else:
            raise KeyError("Unknown index mode")

        # Wrap so rows can be added with explicit ids (row numbers).
        index = faiss.IndexIDMap(index)
        index.verbose = verbose

        # Recover nlist from the factory string (e.g. "IVF8192,PQ32") to
        # decide how many samples to use for training.
        nlist = int(
            float(
                [
                    item
                    for item in index_factory_str.split(",")
                    if "IVF" in item
                ][0].replace("IVF", "")
            )
        )

        # IVF/PQ indexes require training; sample ~256 vectors per IVF cell
        # (sampled with replacement, which faiss tolerates).
        if not index.is_trained:
            indexes_sample_for_train = np.random.randint(0, size, nlist * 256)
            index.train(target[indexes_sample_for_train])

        # Add all vectors, using row numbers as ids.
        target_ids = np.arange(0, size)
        index.add_with_ids(target, target_ids)
        self.index = index

    def search(self, *args, **kargs):
        """Delegate to the underlying faiss index: (queries, k) -> (dists, ids)."""
        return self.index.search(*args, **kargs)

    def __del__(self):
        # Free the (possibly GPU-resident) vectors eagerly, then drop the
        # GPU resource handles.
        self.index.reset()
        del self.index
        for res in self._res_list:
            del res


def batch_search(index, query, k, bs, verbose=False):
    """Search ``index`` for the ``k`` nearest neighbors of each query, in batches.

    Args:
        index: any object exposing ``search(queries, k) -> (dists, ids)``
            (e.g. a faiss index or ``faiss_index_wrapper``).
        query: (n, dim) array of query vectors.
        k: number of neighbors to retrieve per query.
        bs: batch size; queries are searched ``bs`` rows at a time to bound
            peak memory on the index side.
        verbose: when True, display a tqdm progress bar.

    Returns:
        Tuple ``(dists, nbrs)``: float32 (n, k) distances and int64 (n, k)
        neighbor ids.
    """
    n = len(query)
    dists = np.zeros((n, k), dtype=np.float32)
    nbrs = np.zeros((n, k), dtype=np.int64)

    batch_starts = range(0, n, bs)
    if verbose:
        # Only wrap with tqdm when a progress bar was requested; the silent
        # path then carries no tqdm overhead at all.
        batch_starts = tqdm(batch_starts, desc="faiss searching...")
    for sid in batch_starts:
        eid = min(n, sid + bs)  # last batch may be short
        dists[sid:eid], nbrs[sid:eid] = index.search(query[sid:eid], k)
    return dists, nbrs


def faiss_search_approx_knn(
    query,
    target,
    k,
    nprobe=128,
    bs=int(1e6),
    index_factory_str=None,
    verbose=False,
):
    """Approximate k-nearest-neighbor search of ``query`` against ``target``.

    Builds a (GPU-backed, if available) faiss IVF/PQ index over ``target``,
    searches ``query`` in batches, then tears the index down.

    Args:
        query: (n, dim) array of query vectors.
        target: (size, dim) array to index and search against.
        k: number of neighbors per query.
        nprobe: IVF cells probed per query (recall/speed trade-off).
        bs: search batch size; bounds peak memory during search.
        index_factory_str: optional faiss factory string override.
        verbose: enable index verbosity and a search progress bar.

    Returns:
        Tuple ``(dists, nbrs)``: (n, k) distances and (n, k) row indices
        into ``target``.
    """
    index = faiss_index_wrapper(
        target,
        nprobe=nprobe,
        index_factory_str=index_factory_str,
        verbose=verbose,
    )
    dists, nbrs = batch_search(index, query, k=k, bs=bs, verbose=verbose)

    # Drop the wrapper promptly so its __del__ frees GPU memory before the
    # caller continues.
    del index
    gc.collect()
    return dists, nbrs