query.py 8.62 KB
Newer Older
Lawrence McAfee's avatar
Retro  
Lawrence McAfee committed
1
2
3
4
# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.

import numpy as np
import os
liangjing's avatar
v1  
liangjing committed
5
import psutil
Lawrence McAfee's avatar
Retro  
Lawrence McAfee committed
6
7
8
9
import time
import torch
from tqdm import tqdm

liangjing's avatar
v1  
liangjing committed
10
from megatron import get_retro_args, print_rank_0
Lawrence McAfee's avatar
Retro  
Lawrence McAfee committed
11
12
from tools.bert_embedding import BertEmbedder
from tools.bert_embedding.utils import get_missing_blocks_by_rank
liangjing's avatar
v1  
liangjing committed
13
14
from tools.retro.db.utils import \
    get_merged_train_dataset as get_db_merged_train_dataset
Lawrence McAfee's avatar
Retro  
Lawrence McAfee committed
15
16
from tools.retro.external_libs import faiss, h5py
from tools.retro.index.factory import IndexFactory
liangjing's avatar
v1  
liangjing committed
17
from tools.retro.index.utils import get_index_dir
Lawrence McAfee's avatar
Retro  
Lawrence McAfee committed
18
19
from tools.retro.utils import GPTToTextDataset

liangjing's avatar
v1  
liangjing committed
20
from .chunk_dataset import get_chunk_dataset_map as get_query_dataset_map
Lawrence McAfee's avatar
Retro  
Lawrence McAfee committed
21
22


liangjing's avatar
v1  
liangjing committed
23
def get_index(ondisk=False):
    '''Read the added Faiss index from disk.

    Args:
        ondisk (bool): if True, memory-map the index file rather than
            loading it entirely into RAM (useful for very large indexes).

    Returns:
        A Faiss index with query-time search parameters (efSearch,
        nprobe) applied from the retro args.
    '''

    args = get_retro_args()

    # Load index (optionally memory-mapped).
    index_wrapper = IndexFactory.get_index(args.retro_index_type)
    # NOTE(review): the returned path was bound to an unused local; the call
    # is kept in case get_index_dir() has side effects (e.g., creating the
    # directory) — confirm and drop if it is a pure getter.
    get_index_dir()
    added_index_path = index_wrapper.get_added_index_path()
    if ondisk:
        index = faiss.read_index(added_index_path, faiss.IO_FLAG_MMAP)
    else:
        index = faiss.read_index(added_index_path)

    # Search parameters.
    faiss.ParameterSpace().set_index_parameter(index, "efSearch",
                                               args.retro_query_ef_search)
    faiss.ParameterSpace().set_index_parameter(index, "nprobe",
                                               args.retro_query_nprobe)

    return index


def embed_block(gpt_dataset, block, embedder):
    '''Embed the chunks covered by one block's sample range.'''
    start, stop = block["range"]
    text_subset = torch.utils.data.Subset(GPTToTextDataset(gpt_dataset),
                                          range(start, stop))
    return embedder.embed_text_dataset(text_subset)


liangjing's avatar
v1  
liangjing committed
55
56
57
def query_embeddings(db_dataset, index,
                     embeddings, chunk_id_range,
                     sample_map, n_chunks_per_sample,
                     verbose=True):
    '''Query neighbors of a block of embeddings.

    Args:
        db_dataset: chunk database; `db_dataset.doc_tuples[i]` maps a
            neighbor chunk id to its (dataset_idx, doc_id) pair.
        index: Faiss index to search.
        embeddings: 2-D array of query embeddings, one row per chunk.
        chunk_id_range: (min_chunk_id, max_chunk_id) covered by this block.
        sample_map: sample id -> {"dataset_idx", "doc_ids"}.
        n_chunks_per_sample: chunks per sample (maps chunk id -> sample id).
        verbose: print search progress and timing.

    Returns:
        (query_neighbor_ids, filtered_neighbor_ids): the raw Faiss result
        rows, and rows with same-document neighbors removed, truncated and
        padded with -1 to args.retro_query_num_neighbors_save columns.
    '''

    args = get_retro_args()

    # Query neighbor ids.
    if verbose: print_rank_0("search.")
    t = time.time()
    assert index.ntotal > 0, "check we don't accidentally have an empty index."
    _, query_neighbor_ids = \
        index.search(embeddings, args.retro_query_num_neighbors_query)
    if verbose: print_rank_0("  time : %.3f sec." % (time.time() - t))

    # Filter banned neighbor ids.
    if verbose: print_rank_0("filter banned neighbor ids.")
    filtered_neighbor_ids = np.full(
        shape=(len(query_neighbor_ids), args.retro_query_num_neighbors_save),
        fill_value=-1,
        dtype="int64",
    )
    min_chunk_id, max_chunk_id = chunk_id_range
    for chunk_id in range(min_chunk_id, max_chunk_id):

        sample_id = chunk_id // n_chunks_per_sample
        sample = sample_map[sample_id]
        sample_dataset_idx = sample["dataset_idx"].item()
        sample_doc_ids = sample["doc_ids"].tolist()
        # Use a set: membership is tested once per candidate neighbor, so
        # O(1) lookup beats the original O(len(doc_ids)) list scan.
        sample_doc_tuples = {(sample_dataset_idx, d) for d in sample_doc_ids}

        # Get valid neighbors (!= -1).
        query_row = [ i for i in query_neighbor_ids[chunk_id-min_chunk_id]
                      if i >= 0 ]

        # Filter row: drop neighbors that come from the same documents as
        # the query sample (avoids retrieving the sample's own text).
        filtered_row = [ i for i in query_row
                         if tuple(db_dataset.doc_tuples[i].tolist())
                         not in sample_doc_tuples ]
        filtered_row = filtered_row[:args.retro_query_num_neighbors_save]
        filtered_row += \
            [-1] * (args.retro_query_num_neighbors_save - len(filtered_row))
        filtered_neighbor_ids[chunk_id-min_chunk_id] = filtered_row

    return query_neighbor_ids, filtered_neighbor_ids


liangjing's avatar
v1  
liangjing committed
103
104
105
def query_embedding_block(db_dataset, index,
                          embeddings, chunk_id_range,
                          sample_map, n_chunks_per_sample):
    '''Query neighbors of a block of embeddings, in fixed-size sub-blocks.

    Splits the block into sub-blocks of 1000 embeddings to bound per-search
    memory, queries each sub-block, then concatenates the results.
    '''

    sub_block_len = 1000
    raw_results = []
    filtered_results = []

    # Query each sub-block and collect its (raw, filtered) neighbor ids.
    for start in tqdm(
            range(0, len(embeddings), sub_block_len),
            "search",
    ):
        stop = min(len(embeddings), start + sub_block_len)
        sub_chunk_id_range = (
            chunk_id_range[0] + start,
            chunk_id_range[0] + stop,
        )
        raw_ids, filtered_ids = \
            query_embeddings(db_dataset, index,
                             embeddings[start:stop], sub_chunk_id_range,
                             sample_map, n_chunks_per_sample,
                             verbose=False)
        raw_results.append(raw_ids)
        filtered_results.append(filtered_ids)

    # Concatenate sub-block results into full-block arrays.
    query_neighbor_ids = np.concatenate(raw_results, axis=0)
    filtered_neighbor_ids = np.concatenate(filtered_results, axis=0)

    return query_neighbor_ids, filtered_neighbor_ids


liangjing's avatar
v1  
liangjing committed
138
139
140
def query_block_neighbors(db_dataset, query_dataset,
                          index, embedder,
                          block):
    '''Query neighbors of a dataset block (i.e., range).

    Embeds the block's chunks, queries the index for filtered neighbor
    ids, and writes them to the block's HDF5 file under "neighbors".
    '''

    n_chunks_per_sample = query_dataset.n_chunks_per_sample

    # Sample map: sample id -> dataset idx + doc ids, for every sample
    # touched by this block's chunk range (used for same-doc filtering).
    sample_ids = sorted(list(set(chunk_id // n_chunks_per_sample
                                 for chunk_id in range(*block["range"]))))
    sample_map = {}
    for i in sample_ids:
        sample = query_dataset.sample_dataset[i]
        sample_map[i] = {
            "dataset_idx" : sample["dataset_idx"],
            "doc_ids" : sample["doc_ids"],
        }

    # Embed block.
    embeddings = embed_block(query_dataset, block, embedder)

    # Query embeddings.
    _, filtered_neighbor_ids = query_embedding_block(
        db_dataset, index,
        embeddings, block["range"],
        sample_map, n_chunks_per_sample)

    # Save neighbors. Context manager guarantees the file is closed even
    # if the write raises (original left the handle open on error).
    print_rank_0("save neighbors.")
    os.makedirs(os.path.dirname(block["path"]), exist_ok=True)
    with h5py.File(block["path"], "w") as f:
        f.create_dataset("neighbors", data=filtered_neighbor_ids)


liangjing's avatar
v1  
liangjing committed
174
175
176
def query_dataset_neighbors(db_dataset, query_dataset,
                            prefix, neighbor_dir,
                            index, embedder):
    '''Query neighbors of each chunk within a dataset.

    Determines which neighbor blocks are missing on disk (work is split
    across ranks), queries and saves each missing block, and synchronizes
    all ranks after every block.
    '''

    args = get_retro_args()

    def validate(f):
        '''Sanity-check an existing neighbor file's column count.'''
        # Bug fix: the failure message previously read
        # args.retro_num_neighbors_target — a different attribute than the
        # one compared — so a failing assert raised AttributeError instead
        # of reporting the mismatch.
        assert f["neighbors"].shape[1] == args.retro_query_num_neighbors_save, \
            "neighbors.shape == %s; retro_query_num_neighbors_save == %d." % (
                str(f["neighbors"].shape),
                args.retro_query_num_neighbors_save,
            )
    n_missing_blocks, missing_neighbor_blocks = get_missing_blocks_by_rank(
        neighbor_dir,
        len(query_dataset),
        args.retro_block_size,
        validate=validate,
    )

    # Query each block.
    for block_index, block in enumerate(missing_neighbor_blocks):

        if block is not None:

            # Progress.
            print_rank_0("query '%s' block %d / %d ... %s ... mem %.3f gb, %.1f%%." % (
                prefix,
                block_index,
                len(missing_neighbor_blocks),
                os.path.basename(block["path"]),
                psutil.virtual_memory()[3] / 1024**3,
                psutil.virtual_memory()[2],
            ))

            # Query block neighbors.
            query_block_neighbors(db_dataset, query_dataset,
                                  index, embedder,
                                  block)

        # Synchronize progress across all ranks. (for easier observation)
        print_rank_0(" > waiting for other ranks to finish block.")
        torch.distributed.barrier()


def query_pretraining_neighbors():
    '''Query pretraining datasets (train & valid).'''

    retro_args = get_retro_args()

    # Bound Faiss CPU parallelism.
    faiss.omp_set_num_threads(64)

    # Chunk database (doc tuples enable same-document filtering).
    print_rank_0("load chunk db dataset.")
    chunk_db = get_db_merged_train_dataset()
    chunk_db.load_doc_tuples()

    # Search index.
    print_rank_0(" > get index.")
    search_index = get_index()

    # Datasets to query.
    print_rank_0(" > get dataset map.")
    dataset_map = get_query_dataset_map()

    # Bert embedder, for turning chunk text into query vectors.
    bert_embedder = BertEmbedder(retro_args.retro_bert_batch_size,
                                 retro_args.retro_bert_max_chunk_length,
                                 retro_args.bert_embedder_type)

    # Query neighbors for every dataset in the map.
    print_rank_0(" > query.")
    for prefix, info in dataset_map.items():
        print_rank_0(" > query '%s' dataset ... %d samples." %
                     (prefix, len(info["data"])))
        query_dataset_neighbors(chunk_db, info["data"],
                                prefix, info["neighbor_dir"],
                                search_index, bert_embedder)