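"""Data structures for REALM-style block retrieval: BlockData holds block embeddings and
metadata, FaissMIPSIndex wraps a FAISS index for maximum inner product search over block
embeddings, and RandProjectionLSHIndex is a legacy random-projection hashing scheme."""
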
from collections import defaultdict
import os
import pickle
import shutil

import faiss
import numpy as np
import torch

from megatron import get_args, mpu


def detach(tensor):
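    """Detach a tensor from the autograd graph and return it as a CPU numpy array."""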
    return tensor.detach().cpu().numpy()


class BlockData(object):
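    """Serializable container for block embeddings and metadata, keyed by block index."""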
    def __init__(self):
        args = get_args()
        self.embed_data = dict()
        self.meta_data = dict()
        block_data_path = os.path.splitext(args.block_data_path)[0]
        self.temp_dir_name = block_data_path + '_tmp'

    def state(self):
        return {
            'embed_data': self.embed_data,
            'meta_data': self.meta_data
        }

    def clear(self):
        """Clear the embedding data to save memory (meta_data is intentionally kept)."""
        self.embed_data = dict()

    @classmethod
    def load_from_file(cls, fname):
        print("\n> Unpickling block data", flush=True)
Neel Kant's avatar
Neel Kant committed
39
        state_dict = pickle.load(open(fname, 'rb'))
40
        print(">> Finished unpickling block data\n", flush=True)

        new_index = cls()
        new_index.embed_data = state_dict['embed_data']
        new_index.meta_data = state_dict['meta_data']
        return new_index

    def add_block_data(self, block_indices, block_embeds, block_metas, allow_overwrite=False):
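        """Store embeddings (as float16) and metadata for the given block indices."""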
        for idx, embed, meta in zip(block_indices, block_embeds, block_metas):
            if not allow_overwrite and idx in self.embed_data:
                raise ValueError("Unexpectedly tried to overwrite block data")

            self.embed_data[idx] = np.float16(embed)
            self.meta_data[idx] = meta

    def save_shard(self, rank):
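        """Pickle this rank's data to its own shard file in the temporary directory."""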
        if not os.path.isdir(self.temp_dir_name):
            os.mkdir(self.temp_dir_name)

        # save the data for each shard
        with open('{}/{}.pkl'.format(self.temp_dir_name, rank), 'wb') as data_file:
            pickle.dump(self.state(), data_file)

    def consolidate_shards_and_save(self, ignore_shard=0):
        """Combine all the shards made using self.save_shard()"""
        fnames = os.listdir(self.temp_dir_name)
        for fname in fnames:
            with open('{}/{}'.format(self.temp_dir_name, fname), 'rb') as f:
                data = pickle.load(f)

                old_size = len(self.embed_data)
                shard_size = len(data['embed_data'])
                self.embed_data.update(data['embed_data'])
                self.meta_data.update(data['meta_data'])
                # assert (len(self.embed_data) == old_size + shard_size) or (str(ignore_shard) in fname)

        args = get_args()
        with open(args.block_data_path, 'wb') as final_file:
            pickle.dump(self.state(), final_file)
        shutil.rmtree(self.temp_dir_name, ignore_errors=True)


class FaissMIPSIndex(object):
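    """Wrapper over a flat FAISS inner-product index used for maximum inner product
    search (MIPS) over block embeddings, on CPU or GPU."""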
    def __init__(self, index_type, embed_size, use_gpu=False):
        self.index_type = index_type
        self.embed_size = embed_size
        self.use_gpu = use_gpu
        self.id_map = dict()

        # ALSH (asymmetric LSH for MIPS) parameters; used by the alsh_* helpers below, which are not currently enabled
        self.m = 5
        self.u = 0.99
        self.max_norm = None
        self.block_mips_index = None
        self._set_block_index()

    def _set_block_index(self):
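        """Build a flat inner-product FAISS index, wrapped in an IndexIDMap on CPU or cloned onto the current GPU."""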
        INDEX_TYPES = ['flat_ip']
        if self.index_type not in INDEX_TYPES:
            raise ValueError("Invalid index type specified")

        print("\n> Building index", flush=True)
        self.block_mips_index = faiss.index_factory(self.embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)
        if not self.use_gpu:
            self.block_mips_index = faiss.IndexIDMap(self.block_mips_index)
Neel Kant's avatar
Neel Kant committed
105
            print(">> Finished building index\n", flush=True)
106

Neel Kant's avatar
Neel Kant committed
107
108
        if self.use_gpu:
            res = faiss.StandardGpuResources()
109
110
111
112
113
            # self.block_mips_index = faiss.index_cpu_to_gpu(res, device, self.block_mips_index)
            config = faiss.GpuIndexFlatConfig()
            config.device = torch.cuda.current_device()
            config.useFloat16 = True
            self.block_mips_index = faiss.GpuIndexFlat(res, self.block_mips_index, config)
Neel Kant's avatar
Neel Kant committed
114
            print(">>> Finished building index on GPU {}\n".format(self.block_mips_index.getDevice()), flush=True)

    def reset_index(self):
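        """Delete the existing FAISS index and build a fresh, empty one."""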
        del self.block_mips_index
        self._set_block_index()

    def add_block_embed_data(self, all_block_data, clear_block_data=False):
        """Add the embedding of each block to the underlying FAISS index"""
        block_indices, block_embeds = zip(*all_block_data.embed_data.items())
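        # on GPU, the flat index is not wrapped in an IndexIDMap, so track the mapping
        # from FAISS row position to block index manually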
        if self.use_gpu:
            for i, idx in enumerate(block_indices):
                self.id_map[i] = idx
        if clear_block_data:
            all_block_data.clear()

        if self.use_gpu:
            self.block_mips_index.add(np.float32(np.array(block_embeds)))
        else:
            self.block_mips_index.add_with_ids(np.float32(np.array(block_embeds)), np.array(block_indices))

    def search_mips_index(self, query_embeds, top_k, reconstruct=True):
        """Get the top-k blocks by the index distance metric.

        :param reconstruct: if True: return a [num_queries x k x embed_dim] array of blocks
                            if False: return [num_queries x k] array of distances, and another for indices
        """
        query_embeds = np.float32(detach(query_embeds))
        # query_embeds = query_embeds.float()

        with torch.no_grad():
            if reconstruct:
                top_k_block_embeds = self.block_mips_index.search_and_reconstruct(query_embeds, top_k)
                return top_k_block_embeds
            else:
                distances, block_indices = self.block_mips_index.search(query_embeds, top_k)
                if self.use_gpu:
                    fresh_indices = np.zeros(block_indices.shape, dtype=block_indices.dtype)
                    for i in range(block_indices.shape[0]):
                        for j in range(block_indices.shape[1]):
                            fresh_indices[i, j] = self.id_map[block_indices[i, j]]
                    block_indices = fresh_indices
                    # args = get_args()
                    # if args.rank == 0:
                    #     torch.save({'query_embeds': query_embeds,
                    #                 'id_map': self.id_map,
                    #                 'block_indices': block_indices,
                    #                 'distances': distances}, 'search.data')
                return distances, block_indices

    # functions below are for ALSH, which currently isn't being used

    def get_norm_powers_and_halves_array(self, embeds):
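        """Return [||x||^2, ||x||^4, ..., ||x||^(2^m)] per row, plus a matching array of 0.5 entries."""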
        norm = np.linalg.norm(embeds, axis=1)
        norm_powers = [np.multiply(norm, norm)]  # squared L2 norms of all
        for i in range(self.m - 1):
            norm_powers.append(np.multiply(norm_powers[-1], norm_powers[-1]))
        # [num_blocks x self.m]
        norm_powers = np.transpose(np.array(norm_powers))
        halves_array = 0.5 * np.ones(norm_powers.shape)

        return norm_powers, halves_array

    def alsh_block_preprocess_fn(self, block_embeds):
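        """ALSH block transform P'(S(x)): scale blocks so norms are <= u, then append norm powers and halves."""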
        block_embeds = np.array(block_embeds)
        if self.max_norm is None:
            self.max_norm = max(np.linalg.norm(block_embeds, axis=1))
        if self.max_norm > 1:
            block_embeds = self.u / self.max_norm * block_embeds
        norm_powers, halves_array = self.get_norm_powers_and_halves_array(block_embeds)

        # P'(S(x)) for all x in block_embeds
        return np.float32(np.concatenate((block_embeds, norm_powers, halves_array), axis=1))

    def alsh_query_preprocess_fn(self, query_embeds):
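        """ALSH query transform Q'(S(x)): scale queries so norms are <= u, then append halves and norm powers."""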
        max_norm = max(np.linalg.norm(query_embeds, axis=1))
        if max_norm > 1:
            query_embeds = self.u / max_norm * query_embeds
        norm_powers, halves_array = self.get_norm_powers_and_halves_array(query_embeds)

        # Q'(S(x)) for all x in query_embeds
        return np.float32(np.concatenate((query_embeds, halves_array, norm_powers), axis=1))


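# Example usage sketch of the classes above (names and values such as embed_size=128,
# top_k=5, and query_embeds are illustrative assumptions, not taken from this repo):
#
#   block_data = BlockData.load_from_file(args.block_data_path)
#   mips_index = FaissMIPSIndex('flat_ip', embed_size=128, use_gpu=False)
#   mips_index.add_block_embed_data(block_data)
#   dists, block_idx = mips_index.search_mips_index(query_embeds, top_k=5, reconstruct=False)
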
# This was the original hashing scheme, not used anymore

class RandProjectionLSHIndex(object):
    """Class for holding hashed data"""
    def __init__(self, embed_size, num_buckets, whiten=True, seed=0):
        np.random.seed(seed)
        self.hash_data = defaultdict(list)
        hash_matrix = 2 * np.random.rand(embed_size, int(num_buckets / 2)) - 1
        self.hash_matrix = hash_matrix / np.linalg.norm(hash_matrix, axis=0).reshape(1, -1)
        self.embed_mean = None
        self.embed_whitener = None
        self.whiten = whiten

    def state(self):
        state = {
            'hash_data': self.hash_data,
            'hash_matrix': self.hash_matrix,
            'embed_mean': self.embed_mean,
            'embed_whitener': self.embed_whitener,
        }
        return state

    def save_to_file(self):
        args = get_args()
        with open(args.block_index_path, 'wb') as index_file:
            pickle.dump(self.state(), index_file)

    @classmethod
    def load_from_file(cls, fname):
        print(" > Unpickling block hash data")
        with open(fname, 'rb') as index_file:
            state_dict = pickle.load(index_file)
        print(" > Finished unpickling")
        hash_matrix = state_dict['hash_matrix']

        new_index = cls(hash_matrix.shape[0], hash_matrix.shape[1] * 2)
        new_index.hash_data = state_dict['hash_data']
        new_index.embed_mean = state_dict.get('embed_mean')
        new_index.embed_whitener = state_dict.get('embed_whitener')
        new_index.hash_matrix = hash_matrix

        return new_index

    def get_block_bucket(self, hash):
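        """Return the list of block entries stored under the given hash bucket."""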
        return self.hash_data[hash]

    def hash_embeds(self, embeds, write_block_data=None):
        """Hash a tensor of embeddings using a random projection matrix"""
        embed_scores_pos = torch.matmul(embeds, torch.cuda.FloatTensor(self.hash_matrix).type(embeds.dtype))
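        # each column of the projection matrix defines two buckets (one per sign);
        # an embedding is assigned to the bucket with the largest projection score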
        embed_scores = torch.cat((embed_scores_pos, -embed_scores_pos), axis=1)
        embed_hashes = detach(torch.argmax(embed_scores, axis=1))

        if write_block_data is not None:
            for hash, indices in zip(embed_hashes, write_block_data):
                self.hash_data[hash].append(indices)

        return embed_hashes

    def hash_whitened_block_embeds(self, block_data):
        """Transform all block embeds to have zero mean and unit covariance
        when treated as samples from a distribution"""
        block_idx, all_embeds = zip(*block_data.embed_data.items())
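        # whitening: subtract the mean and apply the Cholesky factor of the inverse
        # covariance so the embeddings have approximately identity covariance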
        arr_embeds = np.transpose(np.array(all_embeds))

        mean = np.mean(arr_embeds, axis=1).reshape(-1, 1)
        centered = arr_embeds - mean
        inv_cov = np.linalg.inv(np.cov(arr_embeds))
        whitener = np.transpose(np.linalg.cholesky(inv_cov))
        whitened = np.float16(np.transpose(whitener.dot(centered)))

        self.embed_mean = mean.reshape(-1)
        self.embed_whitener = whitener
        self.hash_data = defaultdict(list)
        batch_size = 16384
        i = 0

        args = get_args()
        with torch.no_grad():
            while True:
                if args.debug:
                    print(i, flush=True)
                batch_slice = slice(i * batch_size, (i + 1) * batch_size)
                batch_embed = torch.cuda.HalfTensor(whitened[batch_slice])
                batch_meta = [block_data.meta_data[idx] for idx in block_idx[batch_slice]]
                if len(batch_meta) == 0:
                    break

                self.hash_embeds(batch_embed, batch_meta)
                i += 1

    def exact_mips_equals(self, query_embeds, all_block_data, norm_blocks):
        """For each query, determine whether the mips block is in the correct hash bucket"""
        shuffled_block_idx, block_embeds = zip(*all_block_data.items())
        if norm_blocks:
            block_embeds = block_embeds / np.linalg.norm(block_embeds, axis=1).reshape(-1, 1)
        with torch.no_grad():
            # hash_embeds expects a CUDA tensor, so convert the numpy queries first
            query_hashes = self.hash_embeds(torch.cuda.HalfTensor(query_embeds))

            # [num_query x num_blocks]
            inner_products = torch.matmul(torch.cuda.HalfTensor(query_embeds),
                                          torch.cuda.HalfTensor(np.transpose(np.array(block_embeds))))
            max_inner_product_idxes = detach(torch.argmax(inner_products, axis=1))
            best_blocks = np.array([all_block_data[shuffled_block_idx[idx]] for idx in max_inner_product_idxes])
            best_block_hashes = self.hash_embeds(torch.cuda.HalfTensor(best_blocks))

            print('Query hashes: ', query_hashes)
            print('Block hashes: ', best_block_hashes)
            equal_arr = np.equal(query_hashes, best_block_hashes).astype(int)

            # array of zeros and ones which can be used for counting success
            return equal_arr

    def exact_mips_test(self, num_queries, all_block_data, norm_blocks):
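        """Sample random queries and report how often the exact MIPS block lands in the query's hash bucket."""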
        if self.whiten:
            if self.embed_mean is None:
                self.hash_whitened_block_embeds(all_block_data)
            embed_size = self.hash_matrix.shape[0]
            query_embeds = np.random.multivariate_normal(np.zeros(embed_size), np.eye(embed_size), num_queries)
            query_embeds = query_embeds / np.linalg.norm(query_embeds, axis=1).reshape(-1, 1)
        else:
            block_idx, all_embeds = zip(*all_block_data.items())
            arr_embeds = np.transpose(np.array(all_embeds))

            mean = np.mean(arr_embeds, axis=1)
            cov = np.cov(arr_embeds)
            query_embeds = np.random.multivariate_normal(mean, cov, num_queries)

        equal_arr = self.exact_mips_equals(query_embeds, all_block_data, norm_blocks)
        print("Num correct: ", sum(equal_arr), " Fraction correct: ", sum(equal_arr) / equal_arr.size)
        print(equal_arr)