knn.py 5.58 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
This file re-uses implementation from https://github.com/yl-1993/learn-to-cluster
"""

import os
import math
import numpy as np
import multiprocessing as mp
from tqdm import tqdm

from utils import Timer
from .faiss_search import faiss_search_knn

# Names exported by ``from <module> import *``. The ``knn`` base class defined
# in this file is not part of this list.
__all__ = [
    'knn_faiss', 'knn_faiss_gpu',
    'fast_knns2spmat', 'build_knns',
    'knns2ordered_nbrs'
]

def knns2ordered_nbrs(knns, sort=True):
    """Split packed k-NN results into distance and neighbour arrays.

    Args:
        knns: list or array indexable as ``knns[:, 0, :]`` (neighbour ids)
            and ``knns[:, 1, :]`` (distances). A list is converted with
            ``np.array`` first.
        sort: when true, every row is reordered by ascending distance and the
            neighbour ids are permuted the same way.

    Returns:
        Tuple ``(dists, nbrs)``; ``nbrs`` is cast to ``np.int32``.
    """
    packed = np.array(knns) if isinstance(knns, list) else knns
    neighbours = packed[:, 0, :].astype(np.int32)
    distances = packed[:, 1, :]
    if not sort:
        return distances, neighbours
    # Per-row permutation that orders the distances from low to high.
    order = np.argsort(distances, axis=1)
    return (np.take_along_axis(distances, order, axis=1),
            np.take_along_axis(neighbours, order, axis=1))

def _pad_ragged_knns(knns, k):
    """Pack per-row ``(nbrs, dists)`` pairs of varying length into (n, 2, k)."""
    padded = np.ones([len(knns), 2, k])
    padded[:, 0, :] = -1  # assign unknown dist to 1 and nbr to -1
    for i, (nbr, dist) in enumerate(knns):
        size = len(nbr)
        assert size == len(dist)
        padded[i, 0, :size] = nbr[:size]
        padded[i, 1, :size] = dist[:size]
    return padded


def fast_knns2spmat(knns, k, th_sim=0, use_sim=True, fill_value=None):
    """Convert k-NN results into an ``n x n`` scipy CSR matrix of edge weights.

    Entry ``(i, j)`` is set when ``j`` is listed as a neighbour of ``i`` and
    the edge weight is at least ``th_sim``. The matrix is not explicitly
    symmetrised, self-loops are dropped, and neighbours with a negative id
    (padding for unknown neighbours) are ignored.

    Args:
        knns: either an array/list of shape ``(n, 2, k)`` holding neighbour
            ids in ``[:, 0, :]`` and distances in ``[:, 1, :]``, or a sequence
            of ``n`` ``(nbrs, dists)`` pairs with at most ``k`` entries each
            (rows are padded with neighbour ``-1`` / distance ``1``).
        k: number of neighbour slots per row used when padding.
        th_sim: minimum edge weight kept.
        use_sim: when true the weight is ``1 - dist``; otherwise the stored
            values are used as weights directly.
        fill_value: when not ``None`` every weight is replaced by this value
            before thresholding. The input array is left untouched.

    Returns:
        ``scipy.sparse.csr_matrix`` of shape ``(n, n)``.

    Raises:
        AssertionError: if the stored distances fall outside ``[0, 1]``
            (with a small tolerance) or a ragged row has mismatched lengths.
    """
    from scipy.sparse import csr_matrix
    eps = 1e-5
    n = len(knns)
    if isinstance(knns, list):
        try:
            knns = np.array(knns)
        except ValueError:
            # Recent NumPy refuses to build an array from rows of different
            # length; pad them directly instead.
            knns = _pad_ragged_knns(knns, k)
    if len(knns.shape) == 2:
        # knns saved by hnsw has different shape
        knns = _pad_ragged_knns(knns, k)
    nbrs = knns[:, 0, :]
    dists = knns[:, 1, :]
    assert -eps <= dists.min() <= dists.max(
    ) <= 1 + eps, "min: {}, max: {}".format(dists.min(), dists.max())
    if use_sim:
        sims = 1. - dists
    else:
        sims = dists
    if fill_value is not None:
        print('[fast_knns2spmat] edge fill value:', fill_value)
        # Build a new array: with use_sim=False `sims` is a view of the
        # caller's data and must not be overwritten in place.
        sims = np.full_like(sims, fill_value)
    row, col = np.where(sims >= th_sim)
    target = nbrs[row, col]
    # Drop padded/unknown neighbours (negative id) and self-loops.
    keep = (target >= 0) & (target != row)
    row = row[keep]
    data = sims[row, col[keep]]
    # Convert to absolute column; ids may be stored as floats, so cast them.
    col = target[keep].astype(np.int64)
    assert len(row) == len(col) == len(data)
    spmat = csr_matrix((data, (row, col)), shape=(n, n))
    return spmat

def build_knns(feats, k, knn_method, dump=True):
    """Run a k-NN search over ``feats`` with the selected backend.

    Args:
        feats: feature matrix handed to the backend unchanged.
        k: number of neighbours to retrieve per row.
        knn_method: ``'faiss'`` or ``'faiss_gpu'``.
        dump: accepted for interface compatibility; not used by this function.

    Returns:
        Whatever ``get_knns()`` of the chosen backend returns.

    Raises:
        KeyError: if ``knn_method`` is not one of the supported names.
    """
    with Timer('build index'):
        if knn_method not in ('faiss', 'faiss_gpu'):
            raise KeyError(
                'Only support faiss and faiss_gpu currently ({}).'.format(knn_method))
        if knn_method == 'faiss':
            searcher = knn_faiss(feats, k, omp_num_threads=None)
        else:
            searcher = knn_faiss_gpu(feats, k)
        neighbours = searcher.get_knns()
    return neighbours


class knn():
    """Base class for k-NN backends: stores results and filters them by threshold.

    Subclasses are expected to fill ``self.knns`` with one
    ``(nbrs, dists)`` pair per query row.
    """

    # Fallback read by get_knns() when a subclass __init__ never sets
    # ``self.verbose`` itself.
    verbose = True

    def __init__(self, feats, k, index_path='', verbose=True):
        """Remember the verbosity flag.

        ``feats``, ``k`` and ``index_path`` are accepted for interface
        compatibility but are not used by the base class.
        """
        self.verbose = verbose

    def filter_by_th(self, i):
        """Return the neighbours of row ``i`` whose similarity reaches ``self.th``.

        Similarity is taken as ``1 - dist``. ``self.th`` must have been set
        beforehand (``get_knns`` does this).

        Returns:
            Tuple ``(nbrs, dists)`` of numpy arrays with the surviving entries.
        """
        th_nbrs = []
        th_dists = []
        nbrs, dists = self.knns[i]
        for n, dist in zip(nbrs, dists):
            if 1 - dist < self.th:
                continue
            th_nbrs.append(n)
            th_dists.append(dist)
        return (np.array(th_nbrs), np.array(th_dists))

    def get_knns(self, th=None):
        """Return the stored k-NN results, optionally filtered by similarity.

        Args:
            th: similarity threshold. ``None`` or a non-positive value returns
                ``self.knns`` unchanged.

        Returns:
            ``self.knns``, or a list with one filtered ``(nbrs, dists)`` pair
            per row. The filtered list is also kept in ``self.th_knns``.
        """
        if th is None or th <= 0.:
            return self.knns
        # TODO: optimize the filtering process by numpy
        # Single process for now; the pool branch below is only used when
        # nproc is raised (e.g. to mp.cpu_count()).
        nproc = 1
        with Timer('filter edges by th {} (CPU={})'.format(th, nproc),
                   self.verbose):
            self.th = th
            tot = len(self.knns)
            if nproc > 1:
                pool = mp.Pool(nproc)
                try:
                    th_knns = list(
                        tqdm(pool.imap(self.filter_by_th, range(tot)),
                             total=tot))
                finally:
                    # Release the workers even if filtering fails.
                    pool.close()
                    pool.join()
            else:
                th_knns = [self.filter_by_th(i) for i in range(tot)]
            self.th_knns = th_knns
            return th_knns

class knn_faiss(knn):
    """k-NN backend that searches a ``faiss.IndexFlatIP`` built from the features."""

    def __init__(self,
                 feats,
                 k,
                 nprobe=128,
                 omp_num_threads=None,
                 rebuild_index=True,
                 verbose=True,
                 **kwargs):
        """Index ``feats`` and query every row for its top-``k`` neighbours.

        Args:
            feats: 2-D feature array; cast to float32 before indexing.
            k: number of neighbours retrieved per row.
            nprobe: accepted but not used by this backend.
            omp_num_threads: when not ``None``, forwarded to
                ``faiss.omp_set_num_threads``.
            rebuild_index: accepted but not used by this backend.
            verbose: passed to the timers and kept on ``self.verbose``.
            **kwargs: ignored.

        The result is stored in ``self.knns`` as a list of
        ``(nbrs, dists)`` pairs, where ``dists`` is ``1 - sim`` of the
        scores returned by the index search.
        """
        import faiss

        if omp_num_threads is not None:
            faiss.omp_set_num_threads(omp_num_threads)
        self.verbose = verbose

        with Timer('[faiss] build index', verbose):
            feats = feats.astype('float32')
            _, dim = feats.shape
            index = faiss.IndexFlatIP(dim)
            index.add(feats)

        with Timer('[faiss] query topk {}'.format(k), verbose):
            sims, nbrs = index.search(feats, k=k)
            pairs = []
            for row_nbrs, row_sims in zip(nbrs, sims):
                ids = np.array(row_nbrs, dtype=np.int32)
                # Stored as a distance: one minus the returned similarity.
                dist = 1 - np.array(row_sims, dtype=np.float32)
                pairs.append((ids, dist))
            self.knns = pairs

class knn_faiss_gpu(knn):
    """k-NN backend that delegates the search to ``faiss_search_knn``."""

    def __init__(self,
                 feats,
                 k,
                 nprobe=128,
                 num_process=4,
                 is_precise=True,
                 sort=True,
                 verbose=True,
                 **kwargs):
        """Query the top-``k`` neighbours of every row of ``feats``.

        Args:
            feats: feature matrix forwarded to ``faiss_search_knn``.
            k: number of neighbours retrieved per row.
            nprobe, num_process, is_precise, sort: forwarded unchanged to
                ``faiss_search_knn``.
            verbose: forwarded to the timer and to ``faiss_search_knn``, and
                kept on ``self.verbose``.
            **kwargs: ignored.

        The result is stored in ``self.knns`` as a list of
        ``(nbrs, dists)`` pairs (int32 ids, float32 distances).
        """
        # The inherited get_knns() reads self.verbose when filtering by
        # threshold, so it has to be set here.
        self.verbose = verbose
        with Timer('[faiss_gpu] query topk {}'.format(k), verbose):
            dists, nbrs = faiss_search_knn(feats,
                                           k=k,
                                           nprobe=nprobe,
                                           num_process=num_process,
                                           is_precise=is_precise,
                                           sort=sort,
                                           verbose=verbose)

            self.knns = [(np.array(nbr, dtype=np.int32),
                          np.array(dist, dtype=np.float32))
                         for nbr, dist in zip(nbrs, dists)]