"""
This file re-uses the implementation from
https://github.com/yl-1993/learn-to-cluster
"""
import gc
5

6
7
8
9
from tqdm import tqdm

from .faiss_gpu import faiss_search_approx_knn

10
11
__all__ = ["faiss_search_knn"]

12
13
14

def precise_dist(feat, nbrs, num_process=4, sort=True, verbose=False):
    """Recompute the distances between every sample and its listed neighbors.

    The inputs are wrapped as torch tensors placed in shared memory, handed
    to ``precise_dist_share_mem`` (which fills the distance buffer and may
    reorder the neighbor buffer), and converted back to NumPy arrays.

    Args:
        feat: NumPy array of features; ``precise_dist_share_mem`` unpacks its
            shape as two dimensions.
        nbrs: NumPy array of neighbor indices into ``feat``
            (presumably one row of neighbors per sample -- confirm with the
            caller).
        num_process: Forwarded to ``precise_dist_share_mem``.
        sort: Forwarded to ``precise_dist_share_mem``.
        verbose: Forwarded to ``precise_dist_share_mem``.

    Returns:
        Tuple ``(dists, nbrs)`` of NumPy arrays: a float32 array with the
        same shape as ``nbrs`` and the (possibly reordered) neighbor indices.
    """
    import torch

    shared_feat = torch.from_numpy(feat).share_memory_()
    shared_nbrs = torch.from_numpy(nbrs).share_memory_()

    # Output buffer: zeros shaped like the neighbor table, cast to float32.
    zero_like_nbrs = torch.zeros_like(shared_nbrs)
    shared_dist = zero_like_nbrs.float().share_memory_()

    options = {"num_process": num_process, "sort": sort, "verbose": verbose}
    precise_dist_share_mem(shared_feat, shared_nbrs, shared_dist, **options)

    # The feature tensor is no longer needed; drop the reference before
    # forcing a collection pass.
    del shared_feat
    gc.collect()

    dists_out = shared_dist.numpy()
    nbrs_out = shared_nbrs.numpy()
    return dists_out, nbrs_out

33
34
35
36
37
38
39
40
41
42

def precise_dist_share_mem(
    feat,
    nbrs,
    dist,
    num_process=16,
    sort=True,
    process_unit=4000,
    verbose=False,
):
    """Fill ``dist`` (and optionally reorder ``nbrs``) chunk by chunk.

    The rows of ``feat`` are split into at most ``num_process`` contiguous
    ranges and ``bmm`` is called on each range in turn. The ranges are
    processed sequentially in the calling process; ``num_process`` only
    controls how the rows are partitioned.

    Args:
        feat: 2-D tensor of features (its shape is unpacked as
            ``(num, _)``).
        nbrs: Tensor of neighbor indices, updated in place by ``bmm`` when
            ``sort`` is true.
        dist: Tensor that receives the distances, written in place by
            ``bmm``.
        num_process: Number of row ranges to split the work into.
        sort: Forwarded to ``bmm``.
        process_unit: Forwarded to ``bmm``.
        verbose: Forwarded to ``bmm``.
    """
    num, _ = feat.shape
    num_per_proc = int(num / num_process) + 1

    for pi in range(num_process):
        sid = pi * num_per_proc
        if sid >= num:
            # Because the chunk size is rounded up, trailing chunks can start
            # at or past the end of the data (e.g. num=100, num_process=16).
            # Such a chunk is empty or has negative length, which makes the
            # buffer allocation in ``bmm`` fail, so stop here.
            break
        eid = min(sid + num_per_proc, num)

        bmm(
            feat=feat,
            nbrs=nbrs,
            dist=dist,
            sid=sid,
            eid=eid,
            sort=sort,
            process_unit=process_unit,
            verbose=verbose,
        )

64
65
66
67

def bmm(
    feat, nbrs, dist, sid, eid, sort=True, process_unit=4000, verbose=False
):
    """Compute distances for rows ``sid:eid`` and store them in ``dist``.

    For each row in the range, the dot product between that row of ``feat``
    and each of its neighbors' rows is computed with a batched matrix
    multiply, clamped to ``[0, 1]``, and stored as ``1 - similarity``.
    (Presumably ``feat`` rows are L2-normalized so the dot product is a
    cosine similarity -- confirm with the caller.)

    Args:
        feat: 2-D tensor of features.
        nbrs: 2-D integer tensor of neighbor indices into ``feat``. When
            ``sort`` is true, rows ``sid:eid`` are reordered in place.
        dist: 2-D tensor; rows ``sid:eid`` are overwritten in place.
        sid: First row (inclusive) of the range to process.
        eid: End row (exclusive) of the range to process.
        sort: If true, order each row's neighbors by descending similarity
            (i.e. ascending distance).
        process_unit: Number of rows handled per batched multiply.
        verbose: If true, show a progress bar over the batches.
    """
    import torch

    num_rows = eid - sid
    if num_rows <= 0:
        # Nothing to do; also avoids allocating a buffer with a negative
        # dimension when the caller passes an inverted range.
        return

    _, cols = dist.shape
    batch_sim = torch.zeros((num_rows, cols), dtype=torch.float32)
    for s in tqdm(
        range(sid, eid, process_unit), desc="bmm", disable=not verbose
    ):
        e = min(eid, s + process_unit)
        # (b, 1, d) x (b, d, k) -> (b, 1, k): one dot product per neighbor.
        query = feat[s:e].unsqueeze(1)
        gallery = feat[nbrs[s:e]].permute(0, 2, 1)
        batch_sim[s - sid : e - sid] = torch.clamp(
            torch.bmm(query, gallery).view(-1, cols), 0.0, 1.0
        )

    if sort:
        sort_unit = int(1e6)
        batch_nbr = nbrs[sid:eid]
        # ``s`` and ``e`` index into the local buffers here, so they are
        # bounded by the local row count rather than the global ``eid``.
        for s in range(0, num_rows, sort_unit):
            e = min(s + sort_unit, num_rows)
            batch_sim[s:e], indices = torch.sort(
                batch_sim[s:e], descending=True
            )
            # Apply the same permutation to the neighbor indices.
            batch_nbr[s:e] = torch.gather(batch_nbr[s:e], 1, indices)
        nbrs[sid:eid] = batch_nbr

    dist[sid:eid] = 1.0 - batch_sim


def faiss_search_knn(
    feat,
    k,
    nprobe=128,
    num_process=4,
    is_precise=True,
    sort=True,
    verbose=False,
):
    """Search the ``k`` nearest neighbors of every row of ``feat`` within ``feat``.

    An approximate search is run first via ``faiss_search_approx_knn`` with
    ``feat`` as both query and target. When ``is_precise`` is true, the
    distances to the returned neighbors are recomputed by ``precise_dist``
    and both outputs are replaced by its results.

    Args:
        feat: Feature array used as both query and target.
        k: Number of neighbors requested from the approximate search.
        nprobe: Forwarded to ``faiss_search_approx_knn``.
        num_process: Forwarded to ``precise_dist``.
        is_precise: Whether to recompute distances with ``precise_dist``.
        sort: Forwarded to ``precise_dist``.
        verbose: Forwarded to both the approximate search and
            ``precise_dist``.

    Returns:
        Tuple ``(dists, nbrs)``.
    """
    approx_dists, approx_nbrs = faiss_search_approx_knn(
        query=feat, target=feat, k=k, nprobe=nprobe, verbose=verbose
    )

    if not is_precise:
        return approx_dists, approx_nbrs

    print("compute precise dist among k={} nearest neighbors".format(k))
    exact_dists, exact_nbrs = precise_dist(
        feat, approx_nbrs, num_process=num_process, sort=sort, verbose=verbose
    )
    return exact_dists, exact_nbrs