knn.py 5.5 KB
Newer Older
1
2
3
4
5
6
7
8
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
This file re-uses implementation from https://github.com/yl-1993/learn-to-cluster
"""

import math
import multiprocessing as mp
9
import os
10

11
12
import numpy as np
from tqdm import tqdm
13
from utils import Timer
14

15
16
17
from .faiss_search import faiss_search_knn

# Names exported by ``from <module> import *``.
__all__ = [
    "knn_faiss",
    "knn_faiss_gpu",
    "fast_knns2spmat",
    "build_knns",
    "knns2ordered_nbrs",
]

25

26
27
28
29
30
31
32
33
34
35
36
37
38
def knns2ordered_nbrs(knns, sort=True):
    """Split stacked kNN results into a distance matrix and a neighbor matrix.

    ``knns`` is indexed as ``knns[:, 0, :]`` for neighbor ids and
    ``knns[:, 1, :]`` for distances; a list is converted with ``np.array``
    first. When ``sort`` is true, every row is reordered so that its
    distances are ascending and the neighbor ids are permuted to match.

    Returns:
        ``(dists, nbrs)``, where ``nbrs`` has been cast to ``np.int32``.
    """
    stacked = np.array(knns) if isinstance(knns, list) else knns
    dists = stacked[:, 1, :]
    nbrs = stacked[:, 0, :].astype(np.int32)
    if not sort:
        return dists, nbrs
    # One permutation per row (low to high distance), applied to both outputs.
    order = np.argsort(dists, axis=1)
    return (
        np.take_along_axis(dists, order, axis=1),
        np.take_along_axis(nbrs, order, axis=1),
    )

39

40
41
42
def fast_knns2spmat(knns, k, th_sim=0, use_sim=True, fill_value=None):
    """Convert kNN results into an ``n x n`` sparse affinity matrix.

    Entry ``(i, j)`` is set when ``j`` appears in row ``i``'s neighbor list,
    its value passes ``th_sim`` and ``j != i``. The matrix is NOT
    symmetrized here: ``(j, i)`` is only present if ``i`` is also listed as a
    neighbor of ``j``.

    Args:
        knns: Either an ``(n, 2, k)`` array/list with neighbor ids at
            ``[:, 0, :]`` and distances at ``[:, 1, :]``, or a 2-D container
            of per-sample ``(nbr, dist)`` pairs (the layout the original
            comment attributes to hnsw dumps). The input is never modified.
        k: Row width used when padding the 2-D layout.
        th_sim: Minimum edge value to keep (compared with ``>=``).
        use_sim: If true, edge values are ``1 - dist``; otherwise the
            distances themselves are used.
        fill_value: If not ``None``, every edge value is replaced by this
            constant before thresholding.

    Returns:
        ``scipy.sparse.csr_matrix`` of shape ``(n, n)``.

    Raises:
        AssertionError: if distances fall outside ``[0, 1]`` (with a small
            tolerance) or a 2-D row has mismatched nbr/dist lengths.
    """
    from scipy.sparse import csr_matrix

    eps = 1e-5
    n = len(knns)
    if isinstance(knns, list):
        knns = np.array(knns)
    if len(knns.shape) == 2:
        # knns saved by hnsw has different shape: one (nbr, dist) pair per
        # row, possibly shorter than k. Pad unknown dist with 1, nbr with -1.
        padded = np.ones([n, 2, k])
        padded[:, 0, :] = -1
        for i, (nbr, dist) in enumerate(knns):
            size = len(nbr)
            assert size == len(dist)
            padded[i, 0, :size] = nbr[:size]
            padded[i, 1, :size] = dist[:size]
        knns = padded
    nbrs = knns[:, 0, :]
    dists = knns[:, 1, :]
    assert (
        -eps <= dists.min() <= dists.max() <= 1 + eps
    ), "min: {}, max: {}".format(dists.min(), dists.max())
    sims = 1.0 - dists if use_sim else dists
    if fill_value is not None:
        print("[fast_knns2spmat] edge fill value:", fill_value)
        # Build a new array: with use_sim=False ``sims`` is a view of the
        # caller's data and an in-place fill would overwrite their input.
        sims = np.full_like(sims, fill_value)
    keep = sims >= th_sim
    # Padding / missing neighbours are marked -1; they are not real edges and
    # would be rejected by csr_matrix as negative column indices.
    keep &= nbrs >= 0
    # remove the self-loop
    keep &= nbrs != np.arange(n).reshape(-1, 1)
    row, col = np.nonzero(keep)
    data = sims[row, col]
    col = nbrs[row, col]  # convert to absolute column
    assert len(row) == len(col) == len(data)
    return csr_matrix((data, (row, col)), shape=(n, n))

82
83
84
85

def build_knns(feats, k, knn_method, dump=True):
    """Run a kNN search over ``feats`` with the selected backend.

    Args:
        feats: Feature matrix handed unchanged to the backend class.
        k: Number of neighbors to retrieve per sample.
        knn_method: ``"faiss"`` or ``"faiss_gpu"``.
        dump: Accepted but not used by this function.

    Returns:
        Whatever the backend's ``get_knns()`` returns.

    Raises:
        KeyError: if ``knn_method`` is not one of the supported names.
    """
    with Timer("build index"):
        if knn_method not in ("faiss", "faiss_gpu"):
            raise KeyError(
                "Only support faiss and faiss_gpu currently ({}).".format(
                    knn_method
                )
            )
        if knn_method == "faiss_gpu":
            searcher = knn_faiss_gpu(feats, k)
        else:
            searcher = knn_faiss(feats, k, omp_num_threads=None)
        result = searcher.get_knns()
    return result


99
100
class knn:
    """Base class for kNN search results.

    Subclasses are expected to set ``self.knns`` to a sequence with one
    ``(nbrs, dists)`` pair per sample; this class only provides threshold
    filtering on top of that attribute.
    """

    # Fallback for subclasses whose __init__ does not assign ``self.verbose``;
    # get_knns() reads it when timing the filtering step.
    verbose = True

    def __init__(self, feats, k, index_path="", verbose=True):
        pass

    def filter_by_th(self, i):
        """Return sample ``i``'s neighbors whose similarity passes ``self.th``.

        Similarity is ``1 - dist``; entries with ``1 - dist < self.th`` are
        dropped. Requires ``self.th`` to be set (``get_knns`` does this).

        Returns:
            ``(nbrs, dists)`` arrays keeping the dtypes of the stored data,
            including when no neighbor survives.
        """
        nbrs, dists = self.knns[i]
        nbrs = np.asarray(nbrs)
        dists = np.asarray(dists)
        # Negate the "drop" test rather than using ">=" so that NaN distances
        # are kept, matching the element-wise rule above.
        keep = ~(1 - dists < self.th)
        return (nbrs[keep], dists[keep])

    def get_knns(self, th=None):
        """Return the stored kNN pairs, optionally filtered by similarity.

        Args:
            th: Similarity threshold. ``None`` or a non-positive value
                returns ``self.knns`` itself, unfiltered.

        Returns:
            ``self.knns``, or a new list of ``(nbrs, dists)`` pairs produced
            by ``filter_by_th`` for every sample.
        """
        if th is None or th <= 0.0:
            return self.knns
        # Filtering is vectorised per sample and runs in this process only.
        nproc = 1
        with Timer(
            "filter edges by th {} (CPU={})".format(th, nproc), self.verbose
        ):
            self.th = th  # read by filter_by_th
            return [self.filter_by_th(i) for i in range(len(self.knns))]

138

139
class knn_faiss(knn):
    """kNN search backed by a ``faiss.IndexFlatIP`` index.

    The index is built from ``feats`` and queried with the same ``feats``.
    For each sample, ``self.knns`` stores ``(nbrs, dists)`` where ``nbrs``
    is an int32 array of neighbor ids and ``dists`` is ``1 - sim`` as
    float32.
    NOTE(review): ``1 - sim`` is only a [0, 1] distance if the features are
    L2-normalized; nothing here checks that -- confirm with callers.

    ``nprobe``, ``rebuild_index`` and extra keyword arguments are accepted
    but not used by this constructor.
    """

    def __init__(
        self,
        feats,
        k,
        nprobe=128,
        omp_num_threads=None,
        rebuild_index=True,
        verbose=True,
        **kwargs
    ):
        import faiss

        if omp_num_threads is not None:
            faiss.omp_set_num_threads(omp_num_threads)
        self.verbose = verbose

        with Timer("[faiss] build index", verbose):
            feats = feats.astype("float32")
            # Unpacking also enforces that feats is two-dimensional.
            _, dim = feats.shape
            flat_index = faiss.IndexFlatIP(dim)
            flat_index.add(feats)

        with Timer("[faiss] query topk {}".format(k), verbose):
            sims, nbrs = flat_index.search(feats, k=k)
            pairs = []
            for nbr_row, sim_row in zip(nbrs, sims):
                ids = np.array(nbr_row, dtype=np.int32)
                # Turn inner-product similarity into a distance.
                dist_row = 1 - np.array(sim_row, dtype=np.float32)
                pairs.append((ids, dist_row))
            self.knns = pairs

170
171

class knn_faiss_gpu(knn):
    """kNN search delegated to ``faiss_search_knn``.

    For each sample, ``self.knns`` stores ``(nbrs, dists)`` with ``nbrs``
    as an int32 array and ``dists`` as a float32 array. The values returned
    by ``faiss_search_knn`` are stored as-is, with no ``1 - x`` conversion.
    NOTE(review): presumably the helper already returns distances rather
    than similarities -- confirm against ``faiss_search``.

    Extra keyword arguments are accepted but not used by this constructor.
    """

    def __init__(
        self,
        feats,
        k,
        nprobe=128,
        num_process=4,
        is_precise=True,
        sort=True,
        verbose=True,
        **kwargs
    ):
        # get_knns() on the base class reads self.verbose when filtering by
        # threshold, so it must be stored here as knn_faiss does.
        self.verbose = verbose
        with Timer("[faiss_gpu] query topk {}".format(k), verbose):
            dists, nbrs = faiss_search_knn(
                feats,
                k=k,
                nprobe=nprobe,
                num_process=num_process,
                is_precise=is_precise,
                sort=sort,
                verbose=verbose,
            )

            self.knns = [
                (
                    np.array(nbr, dtype=np.int32),
                    np.array(dist, dtype=np.float32),
                )
                for nbr, dist in zip(nbrs, dists)
            ]