#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
This file re-uses implementation from https://github.com/yl-1993/learn-to-cluster
"""

import os
import math
import numpy as np
import multiprocessing as mp
from tqdm import tqdm

from utils import Timer
from .faiss_search import faiss_search_knn

__all__ = [
    'knn_faiss', 'knn_faiss_gpu', 'fast_knns2spmat', 'build_knns',
    'knns2ordered_nbrs'
]


def knns2ordered_nbrs(knns, sort=True):
    """Split a knn table into separate distance and neighbour arrays.

    Args:
        knns: array (or list convertible to one) indexed as
            ``knns[i, 0, :]`` -> neighbour ids and ``knns[i, 1, :]`` ->
            distances of sample ``i``.
        sort: when true, reorder every row so distances go from low to high
            (neighbour ids are permuted the same way).

    Returns:
        ``(dists, nbrs)``; ``nbrs`` is cast to ``int32``.
    """
    if isinstance(knns, list):
        knns = np.array(knns)
    nbrs = knns[:, 0, :].astype(np.int32)
    dists = knns[:, 1, :]
    if sort:
        # sort dists from low to high
        nb_idx = np.argsort(dists, axis=1)
        # column vector of row numbers so fancy indexing pairs row i with
        # its own permutation nb_idx[i]
        idxs = np.arange(nb_idx.shape[0]).reshape(-1, 1)
        dists = dists[idxs, nb_idx]
        nbrs = nbrs[idxs, nb_idx]
    return dists, nbrs


def fast_knns2spmat(knns, k, th_sim=0, use_sim=True, fill_value=None):
    """Convert a knn table into an ``n x n`` scipy CSR matrix of edge weights.

    Args:
        knns: either a 3-d array/list indexed ``[i, 0, :]`` (neighbour ids)
            and ``[i, 1, :]`` (distances), or a 2-d per-row collection of
            ``(nbrs, dists)`` pairs whose lengths may be shorter than ``k``.
        k: number of neighbour slots per row; only used to pad the 2-d form.
        th_sim: entries whose weight is below this value are dropped.
        use_sim: when true the stored weight is ``1 - dist``, otherwise the
            distance itself.
        fill_value: when not ``None`` every weight is replaced by this
            constant before thresholding.

    Returns:
        ``scipy.sparse.csr_matrix`` of shape ``(n, n)`` without self-loops.
        The matrix is built from the directed knn edges only; nothing here
        mirrors an edge to its transpose.

    Raises:
        AssertionError: if a distance falls outside ``[0, 1]`` (with a small
            tolerance).
    """
    from scipy.sparse import csr_matrix
    eps = 1e-5
    n = len(knns)
    if isinstance(knns, list):
        knns = np.array(knns)
    if len(knns.shape) == 2:
        # knns saved by hnsw has different shape
        ndarr = np.ones([n, 2, k])
        ndarr[:, 0, :] = -1  # assign unknown dist to 1 and nbr to -1
        for i, (nbr, dist) in enumerate(knns):
            size = len(nbr)
            assert size == len(dist)
            ndarr[i, 0, :size] = nbr[:size]
            ndarr[i, 1, :size] = dist[:size]
        knns = ndarr
    nbrs = knns[:, 0, :]
    dists = knns[:, 1, :]
    assert -eps <= dists.min() <= dists.max(
    ) <= 1 + eps, "min: {}, max: {}".format(dists.min(), dists.max())
    if use_sim:
        sims = 1. - dists
    else:
        sims = dists
    if fill_value is not None:
        print('[fast_knns2spmat] edge fill value:', fill_value)
        # Build a fresh array instead of filling in place: with
        # use_sim=False `sims` is a view of the caller's `knns`, and an
        # in-place fill would silently overwrite the caller's distances.
        sims = np.full_like(sims, fill_value)
    row, col = np.where(sims >= th_sim)
    abs_col = nbrs[row, col]  # convert to absolute column
    # remove the self-loop, and drop padding slots (nbr == -1) which would
    # otherwise reach csr_matrix as a negative column index
    keep = (abs_col != row) & (abs_col >= 0)
    row = row[keep]
    col = col[keep]
    data = sims[row, col]
    col = abs_col[keep]
    assert len(row) == len(col) == len(data)
    spmat = csr_matrix((data, (row, col)), shape=(n, n))
    return spmat


def build_knns(feats, k, knn_method, dump=True):
    """Build a knn index over ``feats`` and return its neighbour table.

    Args:
        feats: feature matrix handed to the selected backend.
        k: number of neighbours to retrieve per sample.
        knn_method: ``'faiss'`` or ``'faiss_gpu'``.
        dump: accepted for interface compatibility; not used here.

    Returns:
        The list produced by the backend's ``get_knns()``: one
        ``(nbrs, dists)`` pair per sample.

    Raises:
        KeyError: if ``knn_method`` is not one of the supported names.
    """
    with Timer('build index'):
        if knn_method == 'faiss':
            index = knn_faiss(feats, k, omp_num_threads=None)
        elif knn_method == 'faiss_gpu':
            index = knn_faiss_gpu(feats, k)
        else:
            raise KeyError(
                'Only support faiss and faiss_gpu currently ({}).'.format(knn_method))
        knns = index.get_knns()
    return knns


class knn():
    """Base class for knn backends.

    Subclasses are expected to populate ``self.knns`` with one
    ``(nbrs, dists)`` pair per sample and should set ``self.verbose``.
    """

    def __init__(self, feats, k, index_path='', verbose=True):
        pass

    def filter_by_th(self, i):
        """Return the neighbours of sample ``i`` whose ``1 - dist`` is not
        below ``self.th``, as a ``(nbrs, dists)`` pair of arrays."""
        nbrs, dists = self.knns[i]
        nbrs = np.asarray(nbrs)
        dists = np.asarray(dists)
        # Negate the "drop" condition (instead of testing >=) so that NaN
        # distances are kept, exactly as the element-wise loop did.
        keep = ~(1 - dists < self.th)
        return (nbrs[keep], dists[keep])

    def get_knns(self, th=None):
        """Return the knn table, optionally filtered by a similarity threshold.

        Args:
            th: when ``None`` or non-positive the stored table is returned
                untouched; otherwise every row is passed through
                ``filter_by_th``.

        Returns:
            List of ``(nbrs, dists)`` pairs, one per sample.
        """
        if th is None or th <= 0.:
            return self.knns
        # Per-row filtering is vectorised in filter_by_th, so one process is
        # used; raise nproc (e.g. mp.cpu_count()) to enable the pool branch.
        nproc = 1
        # Not every subclass assigns self.verbose; fall back to the
        # constructors' default rather than raising AttributeError.
        verbose = getattr(self, 'verbose', True)
        with Timer('filter edges by th {} (CPU={})'.format(th, nproc),
                   verbose):
            self.th = th
            # Retained from the original implementation; never populated.
            self.th_knns = []
            tot = len(self.knns)
            if nproc > 1:
                # Context manager releases the workers even if imap raises.
                with mp.Pool(nproc) as pool:
                    th_knns = list(
                        tqdm(pool.imap(self.filter_by_th, range(tot)),
                             total=tot))
            else:
                th_knns = [self.filter_by_th(i) for i in range(tot)]
            return th_knns


class knn_faiss(knn):
    """Exact knn search on CPU using a faiss inner-product flat index.

    The index is queried with the same features it was built from, and the
    returned scores are stored as distances ``1 - score``.
    NOTE(review): this presumably expects L2-normalised features so that the
    inner product lies in [0, 1] -- confirm with callers.
    ``nprobe``, ``rebuild_index`` and extra keyword arguments are accepted
    but not used by this implementation.
    """

    def __init__(self,
                 feats,
                 k,
                 nprobe=128,
                 omp_num_threads=None,
                 rebuild_index=True,
                 verbose=True,
                 **kwargs):
        import faiss
        if omp_num_threads is not None:
            faiss.omp_set_num_threads(omp_num_threads)
        self.verbose = verbose
        with Timer('[faiss] build index', verbose):
            # faiss consumes C-contiguous float32 matrices
            feats = np.ascontiguousarray(feats, dtype=np.float32)
            # unpacking also rejects inputs that are not 2-d
            _, dim = feats.shape
            index = faiss.IndexFlatIP(dim)
            index.add(feats)
        with Timer('[faiss] query topk {}'.format(k), verbose):
            sims, nbrs = index.search(feats, k=k)
            self.knns = [(np.array(nbr, dtype=np.int32),
                          1 - np.array(sim, dtype=np.float32))
                         for nbr, sim in zip(nbrs, sims)]


class knn_faiss_gpu(knn):
    """knn search delegated to ``faiss_search_knn``.

    The helper's ``(dists, nbrs)`` result is stored row by row in
    ``self.knns`` as ``(int32 nbrs, float32 dists)`` pairs.
    """

    def __init__(self,
                 feats,
                 k,
                 nprobe=128,
                 num_process=4,
                 is_precise=True,
                 sort=True,
                 verbose=True,
                 **kwargs):
        # get_knns() reads self.verbose when a threshold is given
        self.verbose = verbose
        with Timer('[faiss_gpu] query topk {}'.format(k), verbose):
            dists, nbrs = faiss_search_knn(feats,
                                           k=k,
                                           nprobe=nprobe,
                                           num_process=num_process,
                                           is_precise=is_precise,
                                           sort=sort,
                                           verbose=verbose)

            self.knns = [(np.array(nbr, dtype=np.int32),
                          np.array(dist, dtype=np.float32))
                         for nbr, dist in zip(nbrs, dists)]