from collections import defaultdict
import os
import pickle
import shutil

import faiss
import numpy as np
import torch

from megatron import get_args, mpu


def detach(tensor):
    return tensor.detach().cpu().numpy()


class BlockData(object):
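    """Serializable container for block embeddings (stored as fp16 numpy arrays
    in `embed_data`) and the corresponding block metadata (`meta_data`), both
    keyed by block index."""
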
    def __init__(self, block_data_path=None):
        self.embed_data = dict()
        self.meta_data = dict()
        if block_data_path is None:
            args = get_args()
            block_data_path = args.block_data_path
        self.block_data_path = block_data_path

        block_data_name = os.path.splitext(self.block_data_path)[0]
        self.temp_dir_name = block_data_name + '_tmp'

    def state(self):
        return {
            'embed_data': self.embed_data,
            'meta_data': self.meta_data
        }

    def clear(self):
        """Clear the data structures to save memory"""
        self.embed_data = dict()
        # self.meta_data = dict()

    @classmethod
    def load_from_file(cls, fname):
        print("\n> Unpickling block data", flush=True)
        with open(fname, 'rb') as f:
            state_dict = pickle.load(f)
        print(">> Finished unpickling block data\n", flush=True)

        new_index = cls()
        new_index.embed_data = state_dict['embed_data']
        new_index.meta_data = state_dict['meta_data']
        return new_index

    def add_block_data(self, block_indices, block_embeds, block_metas, allow_overwrite=False):
        for idx, embed, meta in zip(block_indices, block_embeds, block_metas):
            if not allow_overwrite and idx in self.embed_data:
                raise ValueError("Unexpectedly tried to overwrite block data")

            self.embed_data[idx] = np.float16(embed)
            self.meta_data[idx] = meta

    def save_shard(self, rank):
        if not os.path.isdir(self.temp_dir_name):
            os.makedirs(self.temp_dir_name, exist_ok=True)

        # save the data for each shard
        with open('{}/{}.pkl'.format(self.temp_dir_name, rank), 'wb') as data_file:
            pickle.dump(self.state(), data_file)

    def consolidate_shards_and_save(self, ignore_shard=0):
        """Combine all the shards made using self.save_shard()"""
        fnames = os.listdir(self.temp_dir_name)
        for fname in fnames:
            with open('{}/{}'.format(self.temp_dir_name, fname), 'rb') as f:
                data = pickle.load(f)

                old_size = len(self.embed_data)
                shard_size = len(data['embed_data'])
                self.embed_data.update(data['embed_data'])
                self.meta_data.update(data['meta_data'])
                # assert (len(self.embed_data) == old_size + shard_size) or (str(ignore_shard) in fname)

        with open(self.block_data_path, 'wb') as final_file:
            pickle.dump(self.state(), final_file)
        shutil.rmtree(self.temp_dir_name, ignore_errors=True)
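
# Illustrative sketch (not part of the original code): how BlockData is typically
# written as per-rank shards and then merged. Paths and variable names here are
# hypothetical.
#
#   block_data = BlockData('/path/to/block_data.pkl')     # hypothetical path
#   block_data.add_block_data(block_indices, block_embeds, block_metas)
#   block_data.save_shard(rank)                           # writes <name>_tmp/<rank>.pkl
#
#   # later, in a single process, once all ranks have written their shards:
#   merged = BlockData('/path/to/block_data.pkl')
#   merged.consolidate_shards_and_save()                  # merges shards, removes the tmp dir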


class FaissMIPSIndex(object):
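    """Wrapper around a flat (exact) FAISS inner-product index over the block
    embeddings. On CPU the index is wrapped in an IndexIDMap; on GPU the class
    keeps its own `id_map` from FAISS's sequential ids back to block indices and
    remaps search results accordingly."""
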
    def __init__(self, index_type, embed_size, use_gpu=False):
        self.index_type = index_type
        self.embed_size = embed_size
        self.use_gpu = use_gpu
        self.id_map = dict()

        # ALSH (asymmetric LSH) parameters; used by the alsh_* helper methods below
        self.m = 5
        self.u = 0.99
        self.max_norm = None
        self.block_mips_index = None
        self._set_block_index()

    def _set_block_index(self):
        INDEX_TYPES = ['flat_ip']
        if self.index_type not in INDEX_TYPES:
            raise ValueError("Invalid index type specified: {}".format(self.index_type))

        print("\n> Building index", flush=True)
        self.block_mips_index = faiss.index_factory(self.embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)
        if not self.use_gpu:
            self.block_mips_index = faiss.IndexIDMap(self.block_mips_index)
            print(">> Finished building index\n", flush=True)

        if self.use_gpu:
            res = faiss.StandardGpuResources()
            # self.block_mips_index = faiss.index_cpu_to_gpu(res, device, self.block_mips_index)
            config = faiss.GpuIndexFlatConfig()
            config.device = torch.cuda.current_device()
            config.useFloat16 = True
            self.block_mips_index = faiss.GpuIndexFlat(res, self.block_mips_index, config)
            print(">>> Finished building index on GPU {}\n".format(self.block_mips_index.getDevice()), flush=True)

    def reset_index(self):
        del self.block_mips_index
        self._set_block_index()

    def add_block_embed_data(self, all_block_data, clear_block_data=False):
        """Add the embedding of each block to the underlying FAISS index"""
        block_indices, block_embeds = zip(*all_block_data.embed_data.items())
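        # The GPU index is used without an IndexIDMap wrapper, so record the
        # position -> block-index mapping ourselves and translate the search
        # results back in search_mips_index().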
        if self.use_gpu:
            for i, idx in enumerate(block_indices):
                self.id_map[i] = idx
        if clear_block_data:
            all_block_data.clear()

        if self.use_gpu:
            self.block_mips_index.add(np.float32(np.array(block_embeds)))
        else:
            self.block_mips_index.add_with_ids(np.float32(np.array(block_embeds)), np.array(block_indices))

    def search_mips_index(self, query_embeds, top_k, reconstruct=True):
        """Get the top-k blocks by the index distance metric.

        :param reconstruct: if True: return a [num_queries x k x embed_dim] array of blocks
                            if False: return [num_queries x k] array of distances, and another for indices
        """
        query_embeds = np.float32(detach(query_embeds))
        # query_embeds = query_embeds.float()

        with torch.no_grad():
            if reconstruct:
                top_k_block_embeds = self.block_mips_index.search_and_reconstruct(query_embeds, top_k)
                return top_k_block_embeds
            else:
                distances, block_indices = self.block_mips_index.search(query_embeds, top_k)
                if self.use_gpu:
                    # map FAISS's sequential ids back to the original block indices
                    fresh_indices = np.zeros(block_indices.shape, dtype=block_indices.dtype)
                    for i in range(block_indices.shape[0]):
                        for j in range(block_indices.shape[1]):
                            fresh_indices[i, j] = self.id_map[block_indices[i, j]]
                    block_indices = fresh_indices
                    # args = get_args()
                    # if args.rank == 0:
                    #     torch.save({'query_embeds': query_embeds,
                    #                 'id_map': self.id_map,
                    #                 'block_indices': block_indices,
                    #                 'distances': distances}, 'search.data')
                return distances, block_indices
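
    # Illustrative sketch (hypothetical names and sizes, not part of the original
    # code): building the index from a consolidated BlockData file and querying it.
    #
    #   block_data = BlockData.load_from_file(args.block_data_path)
    #   mips_index = FaissMIPSIndex('flat_ip', embed_size=128, use_gpu=True)
    #   mips_index.add_block_embed_data(block_data)
    #   distances, block_indices = mips_index.search_mips_index(
    #       query_embeds, top_k=5, reconstruct=False)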

    # functions below are for ALSH, which currently isn't being used
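    # They appear to follow the ALSH preprocessing of Shrivastava & Li (2014) for
    # maximum inner product search: embeddings are rescaled by u / max_norm when
    # the largest norm exceeds 1, each block vector x is augmented to
    #   [x, ||x||^2, ||x||^4, ..., ||x||^(2^m), 1/2, ..., 1/2]
    # and each query q to
    #   [q, 1/2, ..., 1/2, ||q||^2, ..., ||q||^(2^m)],
    # with the intent that inner-product search over the augmented vectors
    # approximately tracks maximum inner product over the originals.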

    def get_norm_powers_and_halves_array(self, embeds):
        norm = np.linalg.norm(embeds, axis=1)
        norm_powers = [np.multiply(norm, norm)]  # squared L2 norm of each embedding
        for i in range(self.m - 1):
            norm_powers.append(np.multiply(norm_powers[-1], norm_powers[-1]))
        # [num_blocks x self.m]
        norm_powers = np.transpose(np.array(norm_powers))
        halves_array = 0.5 * np.ones(norm_powers.shape)

        return norm_powers, halves_array

    def alsh_block_preprocess_fn(self, block_embeds):
        block_embeds = np.array(block_embeds)
        if self.max_norm is None:
            self.max_norm = max(np.linalg.norm(block_embeds, axis=1))
        if self.max_norm > 1:
            block_embeds = self.u / self.max_norm * block_embeds
        norm_powers, halves_array = self.get_norm_powers_and_halves_array(block_embeds)

        # P'(S(x)) for all x in block_embeds
        return np.float32(np.concatenate((block_embeds, norm_powers, halves_array), axis=1))

    def alsh_query_preprocess_fn(self, query_embeds):
        max_norm = max(np.linalg.norm(query_embeds, axis=1))
        if max_norm > 1:
            query_embeds = self.u / max_norm * query_embeds
        norm_powers, halves_array = self.get_norm_powers_and_halves_array(query_embeds)

        # Q'(S(x)) for all x in query_embeds
        return np.float32(np.concatenate((query_embeds, halves_array, norm_powers), axis=1))


# This was the original hashing scheme, not used anymore

class RandProjectionLSHIndex(object):
    """Class for holding hashed data"""
    def __init__(self, embed_size, num_buckets, whiten=True, seed=0):
        np.random.seed(seed)
        self.hash_data = defaultdict(list)
        hash_matrix = 2 * np.random.rand(embed_size, int(num_buckets / 2)) - 1
        self.hash_matrix = hash_matrix / np.linalg.norm(hash_matrix, axis=0).reshape(1, -1)
        self.embed_mean = None
        self.embed_whitener = None
        self.whiten = whiten

    def state(self):
        state = {
            'hash_data': self.hash_data,
            'hash_matrix': self.hash_matrix,
            'embed_mean': self.embed_mean,
            'embed_whitener': self.embed_whitener,
        }
        return state

    def save_to_file(self):
        args = get_args()
        with open(args.block_index_path, 'wb') as index_file:
            pickle.dump(self.state(), index_file)

    @classmethod
    def load_from_file(cls, fname):
        print(" > Unpickling block hash data")
        with open(fname, 'rb') as f:
            state_dict = pickle.load(f)
        print(" > Finished unpickling")
        hash_matrix = state_dict['hash_matrix']

        new_index = cls(hash_matrix.shape[0], hash_matrix.shape[1] * 2)
        new_index.hash_data = state_dict['hash_data']
        new_index.embed_mean = state_dict.get('embed_mean')
        new_index.embed_whitener = state_dict.get('embed_whitener')
        new_index.hash_matrix = hash_matrix

        return new_index

    def get_block_bucket(self, hash):
        return self.hash_data[hash]

    def hash_embeds(self, embeds, write_block_data=None):
        """Hash a tensor of embeddings using a random projection matrix"""
        embed_scores_pos = torch.matmul(embeds, torch.cuda.FloatTensor(self.hash_matrix).type(embeds.dtype))
        embed_scores = torch.cat((embed_scores_pos, -embed_scores_pos), axis=1)
        embed_hashes = detach(torch.argmax(embed_scores, axis=1))

        if write_block_data is not None:
            for hash, indices in zip(embed_hashes, write_block_data):
                self.hash_data[hash].append(indices)

        return embed_hashes

    def hash_whitened_block_embeds(self, block_data):
        """Transform all block embeds to have zero mean and unit covariance
        when treated as samples from a distribution"""
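        # Whitening: with C = cov(embeds) and W = cholesky(C^-1)^T, the transformed
        # vectors W @ (x - mean) have identity covariance, since W C W^T = I.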
        block_idx, all_embeds = zip(*block_data.embed_data.items())
        arr_embeds = np.transpose(np.array(all_embeds))

        mean = np.mean(arr_embeds, axis=1).reshape(-1, 1)
        centered = arr_embeds - mean
        inv_cov = np.linalg.inv(np.cov(arr_embeds))
        whitener = np.transpose(np.linalg.cholesky(inv_cov))
        whitened = np.float16(np.transpose(whitener.dot(centered)))

        self.embed_mean = mean.reshape(-1)
        self.embed_whitener = whitener
        self.hash_data = defaultdict(list)
        batch_size = 16384
        i = 0

        args = get_args()
        with torch.no_grad():
            while True:
                if args.debug:
                    print(i, flush=True)
                batch_slice = slice(i * batch_size, (i + 1) * batch_size)
                batch_embed = torch.cuda.HalfTensor(whitened[batch_slice])
                batch_meta = [block_data.meta_data[idx] for idx in block_idx[batch_slice]]
                if len(batch_meta) == 0:
                    break

                self.hash_embeds(batch_embed, batch_meta)
                i += 1

    def exact_mips_equals(self, query_embeds, all_block_data, norm_blocks):
        """For each query, determine whether the mips block is in the correct hash bucket"""
        shuffled_block_idx, block_embeds = zip(*all_block_data.items())
        if norm_blocks:
            block_embeds = block_embeds / np.linalg.norm(block_embeds, axis=1).reshape(-1, 1)
        with torch.no_grad():
            query_hashes = self.hash_embeds(query_embeds)

            # [num_query x num_blocks]
            inner_products = torch.matmul(torch.cuda.HalfTensor(query_embeds),
                                          torch.cuda.HalfTensor(np.transpose(np.array(block_embeds))))
            max_inner_product_idxes = detach(torch.argmax(inner_products, axis=1))
            best_blocks = np.array([all_block_data[shuffled_block_idx[idx]] for idx in max_inner_product_idxes])
            best_block_hashes = self.hash_embeds(best_blocks)

            print('Query hashes: ', query_hashes)
            print('Block hashes: ', best_block_hashes)
            equal_arr = np.equal(query_hashes, best_block_hashes).astype(int)

            # array of zeros and ones which can be used for counting success
            return equal_arr

    def exact_mips_test(self, num_queries, all_block_data, norm_blocks):
        if self.whiten:
            if self.embed_mean is None:
                self.hash_whitened_block_embeds(all_block_data)
            embed_size = self.hash_matrix.shape[0]
            query_embeds = np.random.multivariate_normal(np.zeros(embed_size), np.eye(embed_size), num_queries)
            query_embeds = query_embeds / np.linalg.norm(query_embeds, axis=1).reshape(-1, 1)
        else:
            block_idx, all_embeds = zip(*all_block_data.items())
            arr_embeds = np.transpose(np.array(all_embeds))

            mean = np.mean(arr_embeds, axis=1)
            cov = np.cov(arr_embeds)
            query_embeds = np.random.multivariate_normal(mean, cov, num_queries)

        equal_arr = self.exact_mips_equals(query_embeds, all_block_data, norm_blocks)
        print("Num correct: ", sum(equal_arr), " Fraction correct: ", sum(equal_arr) / equal_arr.size)
        print(equal_arr)