"""
This file re-uses implementation from https://github.com/yl-1993/learn-to-cluster
"""

import gc

from tqdm import tqdm

from .faiss_gpu import faiss_search_approx_knn

__all__ = ['faiss_search_knn']


def precise_dist(feat, nbrs, num_process=4, sort=True, verbose=False):
    """Recompute exact distances between every row of ``feat`` and its neighbours.

    Both inputs are wrapped with ``torch.from_numpy`` and so must be numpy
    arrays. ``feat`` is indexed as ``(num, dim)`` and ``nbrs`` as ``(num, k)``
    row indices into ``feat``.

    Args:
        feat: numpy feature array, one row per sample.
        nbrs: numpy array of neighbour row indices, one row per sample.
        num_process: number of row chunks the work is split into. The chunks
            are processed one after another in the current process.
        sort: if True, reorder each row's neighbours by decreasing similarity.
        verbose: if True, show a tqdm progress bar for each chunk.

    Returns:
        Tuple ``(dists, nbrs)`` of numpy arrays. ``dists`` is float32, has
        the same shape as ``nbrs`` and holds
        ``1 - clamp(dot(feat[i], feat[nbrs[i, j]]), 0, 1)``. ``nbrs`` is
        reordered row-wise when ``sort`` is True.
    """
    import torch

    feat_share = torch.from_numpy(feat).share_memory_()
    nbrs_share = torch.from_numpy(nbrs).share_memory_()
    # Output buffer: one float32 distance per (row, neighbour) slot.
    dist_share = torch.zeros(
        tuple(nbrs_share.shape), dtype=torch.float32).share_memory_()

    precise_dist_share_mem(feat_share, nbrs_share, dist_share,
                           num_process=num_process, sort=sort,
                           verbose=verbose)

    # The feature copy is no longer needed; release it before returning.
    del feat_share
    gc.collect()
    return dist_share.numpy(), nbrs_share.numpy()


def precise_dist_share_mem(feat, nbrs, dist, num_process=16, sort=True,
                           process_unit=4000, verbose=False):
    """Fill ``dist`` (and optionally re-sort ``nbrs``) in place, chunk by chunk.

    Rows are split into at most ``num_process`` contiguous chunks, each
    handled by :func:`bmm`. Despite the name, no worker processes are
    spawned; the chunks run sequentially.

    Args:
        feat: 2-D torch tensor of features; the row count is read from it.
        nbrs: torch tensor of neighbour indices; modified in place when
            ``sort`` is True.
        dist: float torch tensor shaped like ``nbrs``; written in place.
        num_process: upper bound on the number of chunks.
        sort: forwarded to :func:`bmm`.
        process_unit: rows per batched matmul inside :func:`bmm`.
        verbose: forwarded to :func:`bmm`.
    """
    num, _ = feat.shape
    num_per_proc = int(num / num_process) + 1
    # Step by chunk size and stop at ``num``. Chunks starting at or beyond
    # ``num`` would have eid < sid and make ``bmm`` allocate a negative-sized
    # tensor, so they are never generated.
    for sid in range(0, num, num_per_proc):
        eid = min(sid + num_per_proc, num)
        bmm(feat=feat, nbrs=nbrs, dist=dist, sid=sid, eid=eid, sort=sort,
            process_unit=process_unit, verbose=verbose)


def bmm(feat, nbrs, dist, sid, eid, sort=True, process_unit=4000,
        verbose=False):
    """Compute ``dist[sid:eid]`` from batched dot products of features.

    For each row ``i`` in ``[sid, eid)`` the similarity to each neighbour is
    ``clamp(feat[i] . feat[nbrs[i, j]], 0, 1)``; ``dist`` receives
    ``1 - similarity``. Presumably ``feat`` rows are L2-normalised so this is
    cosine similarity -- confirm with the caller.

    Args:
        feat: 2-D feature tensor ``(num, dim)`` (``torch.bmm`` requires the
            3-D shapes built below).
        nbrs: neighbour index tensor ``(num, k)``. Its rows ``[sid, eid)``
            are reordered in place when ``sort`` is True.
        dist: float tensor ``(num, k)``; rows ``[sid, eid)`` are overwritten.
        sid, eid: half-open row range to process.
        sort: if True, sort each row by decreasing similarity and permute
            ``nbrs`` accordingly.
        process_unit: number of rows per ``torch.bmm`` call.
        verbose: if True, show a tqdm progress bar.
    """
    import torch

    if eid <= sid:
        # Empty or inverted range: nothing to compute.
        return

    num_rows = eid - sid
    _, cols = dist.shape
    batch_sim = torch.zeros((num_rows, cols), dtype=torch.float32)
    for s in tqdm(range(sid, eid, process_unit), desc='bmm',
                  disable=not verbose):
        e = min(eid, s + process_unit)
        # NOTE(review): a negative index in nbrs would silently wrap to the
        # last rows of feat -- confirm upstream never emits placeholders.
        query = feat[s:e].unsqueeze(1)              # (b, 1, dim)
        gallery = feat[nbrs[s:e]].permute(0, 2, 1)  # (b, dim, k)
        batch_sim[s - sid:e - sid] = torch.clamp(
            torch.bmm(query, gallery).view(-1, cols), 0.0, 1.0)

    if sort:
        sort_unit = int(1e6)
        # Basic slicing returns a view, so writes into batch_nbr update nbrs.
        batch_nbr = nbrs[sid:eid]
        # Offsets here are local to the chunk, so bound them by num_rows.
        for s in range(0, num_rows, sort_unit):
            e = min(s + sort_unit, num_rows)
            batch_sim[s:e], indices = torch.sort(batch_sim[s:e],
                                                 descending=True)
            batch_nbr[s:e] = torch.gather(batch_nbr[s:e], 1, indices)

    dist[sid:eid] = 1. - batch_sim


def faiss_search_knn(feat, k, nprobe=128, num_process=4, is_precise=True,
                     sort=True, verbose=False):
    """Search the ``k`` nearest neighbours of every row of ``feat`` in ``feat``.

    An approximate faiss search produces candidate neighbours. When
    ``is_precise`` is True, the distances to those candidates are recomputed
    exactly by :func:`precise_dist`, which may also re-sort them.

    Args:
        feat: feature matrix used as both queries and database.
        k: number of neighbours per row.
        nprobe: forwarded to ``faiss_search_approx_knn``.
        num_process: number of chunks for the exact recomputation.
        is_precise: if True, replace the approximate distances with exact ones.
        sort: if True (and ``is_precise``), sort neighbours by distance.
        verbose: forwarded to both stages.

    Returns:
        Tuple ``(dists, nbrs)``.
    """
    dists, nbrs = faiss_search_approx_knn(query=feat, target=feat, k=k,
                                          nprobe=nprobe, verbose=verbose)
    if is_precise:
        print('compute precise dist among k={} nearest neighbors'.format(k))
        dists, nbrs = precise_dist(feat, nbrs, num_process=num_process,
                                   sort=sort, verbose=verbose)
    return dists, nbrs