# train_searcher.py: train a ScaNN searcher over a retrieval database and serialize it.
1
2
import argparse
import glob
3
4
import os
import sys
5
6
from multiprocessing import cpu_count

7
8
import numpy as np
import scann
9
from ldm.util import parallel_data_prefetch
10
from tqdm import tqdm
11
12
13
14
15
16


def search_bruteforce(searcher):
    """Finalize a ScaNN builder that scores candidates by brute force.

    Args:
        searcher: ScaNN builder, as created by ``scann.scann_ops_pybind.builder``.

    Returns:
        Whatever the builder's ``build()`` returns (the trained searcher).
    """
    brute_force_config = searcher.score_brute_force()
    return brute_force_config.build()


17
18
19
20
21
22
23
24
25
26
27
def search_partioned_ah(
    searcher, dims_per_block, aiq_threshold, reorder_k, partioning_trainsize, num_leaves, num_leaves_to_search
):
    """Configure tree partitioning + asymmetric hashing + reordering, then build.

    Args:
        searcher: ScaNN builder, as created by ``scann.scann_ops_pybind.builder``.
        dims_per_block: Dimensions per asymmetric-hashing block (``score_ah``).
        aiq_threshold: Anisotropic quantization threshold for ``score_ah``.
        reorder_k: Number of candidates passed to the ``reorder`` step.
        partioning_trainsize: Sample size used to train the partitioning tree.
        num_leaves: Number of partitions in the tree.
        num_leaves_to_search: Number of partitions probed per query.

    Returns:
        Whatever the builder's ``build()`` returns (the trained searcher).
    """
    partitioned = searcher.tree(
        num_leaves=num_leaves,
        num_leaves_to_search=num_leaves_to_search,
        training_sample_size=partioning_trainsize,
    )
    hashed = partitioned.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold)
    reordered = hashed.reorder(reorder_k)
    return reordered.build()
28
29
30


def search_ah(searcher, dims_per_block, aiq_threshold, reorder_k):
    """Configure asymmetric hashing (AH) scoring with reordering, then build.

    Args:
        searcher: ScaNN builder, as created by ``scann.scann_ops_pybind.builder``.
        dims_per_block: Dimensions per asymmetric-hashing block (``score_ah``).
        aiq_threshold: Anisotropic quantization threshold for ``score_ah``.
        reorder_k: Number of candidates passed to the ``reorder`` step.

    Returns:
        Whatever the builder's ``build()`` returns (the trained searcher).
    """
    return (
        searcher.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold)
        .reorder(reorder_k)
        .build()
    )
34
35


36
def load_datapool(dpath):
    """Load the retrieval database stored as ``.npz`` archive(s) in ``dpath``.

    Every ``*.npz`` file directly inside ``dpath`` is read. A single file is
    loaded directly. Several files are loaded through ``parallel_data_prefetch``
    and their arrays are concatenated key by key along axis 0.

    Args:
        dpath: Directory containing the ``.npz`` files.

    Returns:
        dict mapping each array name found in the archives to a numpy array.
        An ``"embedding"`` entry is required (its length is reported at the end).

    Raises:
        ValueError: if ``dpath`` contains no ``.npz`` files.
    """

    def load_single_file(saved_embeddings):
        # Read every array eagerly, then let ``with`` close the archive handle.
        with np.load(saved_embeddings) as compressed:
            return {key: compressed[key] for key in compressed.files}

    def load_multi_files(data_archive):
        # Worker function: ``data_archive`` is this worker's slice of file paths.
        # Returns {key: [one array per file, in slice order]}.
        database = {}
        for path in tqdm(data_archive, desc=f"Loading datapool from {len(data_archive)} individual files."):
            with np.load(path) as archive:
                for key in archive.files:
                    database.setdefault(key, []).append(archive[key])
        return database

    print(f'Load saved patch embedding from "{dpath}"')
    # NOTE(review): glob order is filesystem-dependent; searcher indices refer to
    # this order, so any code that reloads the pool must use the same ordering.
    file_content = glob.glob(os.path.join(dpath, "*.npz"))

    if len(file_content) == 1:
        data_pool = load_single_file(file_content[0])
    elif len(file_content) > 1:
        # Hand workers file paths rather than open NpzFile objects, so each
        # archive is opened and closed by the worker that reads it.
        prefetched_data = parallel_data_prefetch(
            load_multi_files,
            file_content,
            n_proc=min(len(file_content), cpu_count()),
            target_data_type="dict",
        )
        # ``prefetched_data`` is a sequence of per-worker dicts (presumably in
        # worker order). Flatten worker -> file lists and join on the sample
        # axis. The previous ``concatenate(..., axis=1)[0]`` kept only the first
        # file of each worker whenever a worker received more than one archive.
        data_pool = {
            key: np.concatenate([arr for part in prefetched_data for arr in part[key]], axis=0)
            for key in prefetched_data[0]
        }
    else:
        raise ValueError(f'No npz-files in specified path "{dpath}" is this directory existing?')

    print(f'Finished loading of retrieval database of length {data_pool["embedding"].shape[0]}.')
    return data_pool


71
72
73
74
75
76
77
78
79
80
81
def train_searcher(
    opt,
    metric="dot_product",
    partioning_trainsize=None,
    reorder_k=None,
    # TODO: tune
    aiq_thld=0.2,
    dims_per_block=2,
    num_leaves=None,
    num_leaves_to_search=None,
):
    """Train a ScaNN searcher over the retrieval database and serialize it.

    The ``"embedding"`` array of the database is L2-normalized row-wise. The
    search strategy is then chosen from the pool size:
    - brute force below 2e4 entries;
    - asymmetric hashing (AH) + reordering below 1e5;
    - tree partitioning + AH + reordering otherwise.

    Args:
        opt: Parsed CLI namespace. Reads ``database`` (folder of ``.npz`` files),
            ``knn`` (number of neighbours to optimize for) and ``target_path``
            (output folder for the serialized searcher).
        metric: Distance metric name passed to the ScaNN builder.
        partioning_trainsize: Sample size for training the partitioning tree.
            Falsy values mean a tenth of the pool.
        reorder_k: Candidates passed to the reorder step. Falsy values mean
            ``2 * opt.knn``.
        aiq_thld: Anisotropic quantization threshold for AH scoring.
        dims_per_block: Dimensions per AH block.
        num_leaves: Number of partitions. Falsy values mean ``int(sqrt(pool_size))``.
        num_leaves_to_search: Partitions probed per query. Falsy values mean
            ``max(num_leaves // 20, 1)``.
    """
    data_pool = load_datapool(opt.database)
    k = opt.knn

    if not reorder_k:
        reorder_k = 2 * k

    embeddings = data_pool["embedding"]
    pool_size = embeddings.shape[0]

    # L2-normalize every row; with the default dot_product metric this ranks by
    # cosine similarity. All-zero rows would otherwise become NaN (0 / 0) and be
    # handed to ScaNN, so they are divided by 1 and stay zero.
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    searcher = scann.scann_ops_pybind.builder(embeddings / np.where(norms == 0, 1, norms), k, metric)

    print(*(["#"] * 100))
    print("Initializing scaNN searcher with the following values:")
    print(f"k: {k}")
    print(f"metric: {metric}")
    print(f"reorder_k: {reorder_k}")
    print(f"anisotropic_quantization_threshold: {aiq_thld}")
    print(f"dims_per_block: {dims_per_block}")
    print(*(["#"] * 100))
    print("Start training searcher....")
    print(f"N samples in pool is {pool_size}")

    # Thresholds reflect the recommended design choices proposed at
    # https://github.com/google-research/google-research/blob/aca5f2e44e301af172590bb8e65711f0c9ee0cfd/scann/docs/algorithms.md
    if pool_size < 2e4:
        print("Using brute force search.")
        searcher = search_bruteforce(searcher)
    elif pool_size < 1e5:
        print("Using asymmetric hashing search and reordering.")
        searcher = search_ah(searcher, dims_per_block, aiq_thld, reorder_k)
    else:
        print("Using partitioning, asymmetric hashing search and reordering.")

        if not partioning_trainsize:
            # Train the partitioning tree on a tenth of the pool by default.
            partioning_trainsize = pool_size // 10
        if not num_leaves:
            num_leaves = int(np.sqrt(pool_size))
        if not num_leaves_to_search:
            # Probe about 1/20 of the partitions per query, but at least one.
            num_leaves_to_search = max(num_leaves // 20, 1)

        print("Partitioning params:")
        print(f"num_leaves: {num_leaves}")
        print(f"num_leaves_to_search: {num_leaves_to_search}")
        searcher = search_partioned_ah(
            searcher, dims_per_block, aiq_thld, reorder_k, partioning_trainsize, num_leaves, num_leaves_to_search
        )

    print("Finish training searcher")
    searcher_savedir = opt.target_path
    os.makedirs(searcher_savedir, exist_ok=True)
    searcher.serialize(searcher_savedir)
    print(f'Saved trained searcher under "{searcher_savedir}"')

139
140

if __name__ == "__main__":
    # NOTE(review): this runs after the module-level imports, so it cannot help
    # resolve `ldm` imported at the top of the file; it only affects later imports.
    sys.path.append(os.getcwd())
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--database",
        "-d",
        default="data/rdm/retrieval_databases/openimages",
        type=str,
        help="path to folder containing the clip feature of the database",
    )
    parser.add_argument(
        "--target_path",
        "-t",
        default="data/rdm/searchers/openimages",
        type=str,
        help="path to the target folder where the searcher shall be stored.",
    )
    parser.add_argument(
        "--knn",
        "-k",
        default=20,
        type=int,
        help="number of nearest neighbors, for which the searcher shall be optimized",
    )

    # parse_known_args ignores unrecognized CLI flags instead of rejecting them.
    opt, _ = parser.parse_known_args()

    train_searcher(opt)