faiss_search.py 3.44 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
"""
This file re-uses implementation from https://github.com/yl-1993/learn-to-cluster
"""
import gc
from tqdm import tqdm

from .faiss_gpu import faiss_search_approx_knn

__all__ = ['faiss_search_knn']

def precise_dist(feat, nbrs, num_process=4, sort=True, verbose=False):
    """Recompute neighbour distances exactly for an existing k-NN result.

    The numpy inputs are wrapped as torch tensors placed in shared memory,
    handed to ``precise_dist_share_mem`` (which fills the distance tensor and
    may reorder the neighbour tensor in place), and converted back to numpy.

    Args:
        feat: numpy array of features (must be accepted by
            ``torch.from_numpy``).
        nbrs: numpy array of neighbour indices, one row per feature row.
        num_process: forwarded to ``precise_dist_share_mem``.
        sort: forwarded to ``precise_dist_share_mem``.
        verbose: forwarded to ``precise_dist_share_mem``.

    Returns:
        Tuple ``(dists, nbrs)`` of numpy arrays; ``dists`` is float32 and has
        the same shape as ``nbrs``.
    """
    import torch

    feat_t = torch.from_numpy(feat)
    feat_t.share_memory_()

    nbrs_t = torch.from_numpy(nbrs)
    nbrs_t.share_memory_()

    # Output buffer: same shape as the neighbour table, but float32.
    dist_t = torch.zeros_like(nbrs_t).float()
    dist_t.share_memory_()

    options = dict(num_process=num_process, sort=sort, verbose=verbose)
    precise_dist_share_mem(feat_t, nbrs_t, dist_t, **options)

    # The feature tensor is no longer needed once distances are filled in;
    # drop it before building the numpy results.
    del feat_t
    gc.collect()

    dists_np = dist_t.numpy()
    nbrs_np = nbrs_t.numpy()
    return dists_np, nbrs_np

def precise_dist_share_mem(feat,
                           nbrs,
                           dist,
                           num_process=16,
                           sort=True,
                           process_unit=4000,
                           verbose=False):
    """Fill ``dist`` (and optionally reorder ``nbrs``) chunk by chunk.

    The rows of ``feat`` are partitioned into at most ``num_process``
    contiguous ranges and each range is passed to ``bmm``, which writes its
    results into ``dist`` / ``nbrs`` in place. The ranges are processed one
    after another in the calling process; ``num_process`` only controls how
    the rows are partitioned.

    Args:
        feat: 2-D tensor; its first dimension is the number of rows.
        nbrs: neighbour tensor forwarded to ``bmm``.
        dist: output tensor forwarded to ``bmm``.
        num_process: number of row ranges to split the work into.
        sort: forwarded to ``bmm``.
        process_unit: forwarded to ``bmm``.
        verbose: forwarded to ``bmm``.

    Returns:
        None. Results are written into ``dist`` and ``nbrs``.
    """
    num, _ = feat.shape
    num_per_proc = num // num_process + 1

    for pi in range(num_process):
        sid = pi * num_per_proc
        if sid >= num:
            # Because of the "+ 1" above the ranges can cover every row before
            # the loop runs out (e.g. 8 rows split 4 ways gives ranges of 3,
            # so the 4th would start at row 9). Such a range is empty or
            # inverted (sid > eid) and is not a valid chunk for bmm.
            break
        eid = min(sid + num_per_proc, num)

        bmm(feat=feat,
            nbrs=nbrs,
            dist=dist,
            sid=sid,
            eid=eid,
            sort=sort,
            process_unit=process_unit,
            verbose=verbose)

def bmm(feat,
        nbrs,
        dist,
        sid,
        eid,
        sort=True,
        process_unit=4000,
        verbose=False):
    """Compute exact distances between rows ``sid:eid`` and their neighbours.

    For every row ``i`` in ``[sid, eid)`` the dot product between ``feat[i]``
    and each ``feat[nbrs[i, j]]`` is computed with a batched matrix multiply,
    clamped to ``[0, 1]`` and stored as ``1 - similarity`` in ``dist[i]``.
    With ``sort=True`` each row of ``nbrs[sid:eid]`` is additionally reordered
    in place so that similarities are descending (distances ascending).

    Args:
        feat: 2-D feature tensor, indexed by row.
        nbrs: 2-D integer tensor of neighbour row indices into ``feat``; its
            number of columns must equal that of ``dist``.
        dist: 2-D output tensor; rows ``sid:eid`` are overwritten.
        sid: first row (inclusive) of the range to process.
        eid: last row (exclusive) of the range to process.
        sort: whether to sort each row's neighbours by similarity.
        process_unit: number of rows multiplied per batch.
        verbose: show a tqdm progress bar when True.

    Returns:
        None. ``dist`` (and ``nbrs`` when sorting) are modified in place.
    """
    import torch
    num_rows = eid - sid
    if num_rows <= 0:
        # Empty or inverted range: nothing to compute, and a negative row
        # count cannot be used to allocate the similarity buffer below.
        return

    _, cols = dist.shape
    batch_sim = torch.zeros((num_rows, cols), dtype=torch.float32)
    for s in tqdm(range(sid, eid, process_unit),
                  desc='bmm',
                  disable=not verbose):
        e = min(eid, s + process_unit)
        # (b, 1, d) x (b, d, k) -> (b, 1, k): one dot product per neighbour.
        query = feat[s:e].unsqueeze(1)
        gallery = feat[nbrs[s:e]].permute(0, 2, 1)
        # NOTE(review): the clamp to [0, 1] presumes rows of feat are
        # L2-normalised so the dot product is a cosine similarity - confirm.
        batch_sim[s - sid:e - sid] = torch.clamp(
            torch.bmm(query, gallery).view(-1, cols), 0.0, 1.0)

    if sort:
        # Sort in bounded slices to limit the size of temporary tensors.
        sort_unit = int(1e6)
        batch_nbr = nbrs[sid:eid]
        for s in range(0, num_rows, sort_unit):
            # s and e are local offsets into batch_sim / batch_nbr, so the
            # upper bound is the local row count rather than the global eid.
            e = min(s + sort_unit, num_rows)
            batch_sim[s:e], indices = torch.sort(batch_sim[s:e],
                                                 descending=True)
            # Apply the same permutation to the neighbour indices.
            batch_nbr[s:e] = torch.gather(batch_nbr[s:e], 1, indices)
        nbrs[sid:eid] = batch_nbr
    dist[sid:eid] = 1. - batch_sim

def faiss_search_knn(feat,
                     k,
                     nprobe=128,
                     num_process=4,
                     is_precise=True,
                     sort=True,
                     verbose=False):
    """Search the ``k`` nearest neighbours of every row of ``feat`` in itself.

    An approximate search is run first via ``faiss_search_approx_knn`` with
    ``feat`` as both query and target. When ``is_precise`` is true, the
    distances to the returned neighbours are then recomputed by
    ``precise_dist``.

    Args:
        feat: feature array used as both query and target set.
        k: number of neighbours per row.
        nprobe: forwarded to ``faiss_search_approx_knn``.
        num_process: forwarded to ``precise_dist``.
        is_precise: recompute distances with ``precise_dist`` when True.
        sort: forwarded to ``precise_dist``.
        verbose: forwarded to both search steps.

    Returns:
        Tuple ``(dists, nbrs)``.
    """
    approx_result = faiss_search_approx_knn(
        query=feat, target=feat, k=k, nprobe=nprobe, verbose=verbose)
    dists, nbrs = approx_result

    # Approximate result is good enough for the caller.
    if not is_precise:
        return dists, nbrs

    print('compute precise dist among k={} nearest neighbors'.format(k))
    refine_options = dict(num_process=num_process, sort=sort, verbose=verbose)
    dists, nbrs = precise_dist(feat, nbrs, **refine_options)
    return dists, nbrs