"examples/2_topo/0_demo_topo/test_xml.cpp" did not exist on "57df3737d9a952055d3df66e188cfbc616b4053c"
realm_index.py 11.8 KB
Newer Older
Neel Kant's avatar
Neel Kant committed
1
2
3
4
5
from collections import defaultdict
import os
import pickle
import shutil

import faiss
import numpy as np
import torch

from megatron import get_args, mpu


def detach(tensor):
    return tensor.detach().cpu().numpy()


class BlockData(object):
    """Container mapping block indices to block embeddings and metadata,
    with helpers for sharded saving and consolidation"""
    def __init__(self):
        self.embed_data = dict()
        self.meta_data = dict()
        self.temp_dir_name = 'temp_block_data'

    def state(self):
        return {
            'embed_data': self.embed_data,
            'meta_data': self.meta_data
        }

    def clear(self):
        """Clear the data structures to save memory"""
        self.embed_data = dict()
        self.meta_data = dict()

    @classmethod
    def load_from_file(cls, fname):
        print(" > Unpickling block data")
        state_dict = pickle.load(open(fname, 'rb'))
        print(" > Finished unpickling")

        new_index = cls()
        new_index.embed_data = state_dict['embed_data']
        new_index.meta_data = state_dict['meta_data']
        return new_index

    def add_block_data(self, block_indices, block_embeds, block_metas, allow_overwrite=False):
        for idx, embed, meta in zip(block_indices, block_embeds, block_metas):
            if not allow_overwrite and idx in self.embed_data:
                raise ValueError("Unexpectedly tried to overwrite block data")

            self.embed_data[idx] = np.float16(embed)
            self.meta_data[idx] = meta

    def save_shard(self, rank):
        if not os.path.isdir(self.temp_dir_name):
            os.mkdir(self.temp_dir_name)

        # save the data for each shard
        with open('{}/{}.pkl'.format(self.temp_dir_name, rank), 'wb') as data_file:
            pickle.dump(self.state(), data_file)

    def consolidate_shards_and_save(self, ignore_shard=0):
        """Combine all the shards made using self.save_shard()"""
        fnames = os.listdir(self.temp_dir_name)
        for fname in fnames:
            with open('{}/{}'.format(self.temp_dir_name, fname), 'rb') as f:
                data = pickle.load(f)

                old_size = len(self.embed_data)
                shard_size = len(data['embed_data'])
                self.embed_data.update(data['embed_data'])
                self.meta_data.update(data['meta_data'])
                assert (len(self.embed_data) == old_size + shard_size) or (str(ignore_shard) in fname)

        args = get_args()
        with open(args.block_data_path, 'wb') as final_file:
            pickle.dump(self.state(), final_file)
        shutil.rmtree(self.temp_dir_name, ignore_errors=True)
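
# A rough sketch of how BlockData is meant to be used for sharded saving (the names
# and the synchronization step here are illustrative, not part of this file):
#
#   block_data = BlockData()
#   block_data.add_block_data(indices, embeds, metas)
#   block_data.save_shard(rank=mpu.get_data_parallel_rank())
#   ...  # wait until every rank has written its shard
#   block_data.consolidate_shards_and_save()  # typically run on a single rank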


class FaissMIPSIndex(object):
    """Wrapper around a FAISS index providing maximum inner product search (MIPS)
    over block embeddings"""
    def __init__(self, index_type, embed_size, use_gpu=False):
        self.index_type = index_type
        self.embed_size = embed_size
        self.use_gpu = use_gpu

        # ALSH (asymmetric LSH) parameters, used only by the preprocessing functions below
        self.m = 5
        self.u = 0.99
        self.max_norm = None
        self.block_mips_index = None
        self._set_block_index()

    def _set_block_index(self):
        INDEX_TYPES = ['flat_ip']
        if self.index_type not in INDEX_TYPES:
            raise ValueError("Invalid index type specified")

        index = faiss.index_factory(self.embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)
        self.block_mips_index = faiss.IndexIDMap(index)
        if self.use_gpu:
            res = faiss.StandardGpuResources()
            device = mpu.get_data_parallel_rank()
            self.block_mips_index = faiss.index_cpu_to_gpu(res, device, self.block_mips_index)

    def reset_index(self):
        self._set_block_index()

    def add_block_embed_data(self, all_block_data, clear_block_data=False):
        """Add the embedding of each block to the underlying FAISS index"""
        block_indices, block_embeds = zip(*all_block_data.embed_data.items())
        if clear_block_data:
            all_block_data.clear()

        if self.index_type == 'flat_l2':
            block_embeds = self.alsh_block_preprocess_fn(block_embeds)
        # FAISS works on float32, so cast up from the float16 block embeddings here
        self.block_mips_index.add_with_ids(np.float32(np.array(block_embeds)), np.array(block_indices))

    def search_mips_index(self, query_embeds, top_k, reconstruct=True):
        """Get the top-k blocks by the index distance metric.

        :param reconstruct: if True: return a [num_queries x k x embed_dim] array of blocks
                            if False: return [num_queries x k] array of distances, and another for indices
        """
        if self.index_type == 'flat_l2':
            query_embeds = self.alsh_query_preprocess_fn(query_embeds)
        query_embeds = np.float32(query_embeds)

        if reconstruct:
            top_k_block_embeds = self.block_mips_index.search_and_reconstruct(query_embeds, top_k)
            return top_k_block_embeds
        else:
            distances, block_indices = self.block_mips_index.search(query_embeds, top_k)
            return distances, block_indices
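
    # Hypothetical usage of the search above (illustrative shapes only): with
    # reconstruct=False the call returns two [num_queries x top_k] arrays, inner
    # products and block ids; with reconstruct=True it forwards whatever FAISS's
    # search_and_reconstruct returns for the same query batch.
    #
    #   mips_index = FaissMIPSIndex('flat_ip', embed_size=128)
    #   mips_index.add_block_embed_data(block_data)
    #   queries = np.random.randn(4, 128)
    #   scores, ids = mips_index.search_mips_index(queries, top_k=5, reconstruct=False)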

    # functions below are for ALSH, which currently isn't being used

    def get_norm_powers_and_halves_array(self, embeds):
        norm = np.linalg.norm(embeds, axis=1)
        norm_powers = [np.multiply(norm, norm)]  # squared L2 norms of all
        for i in range(self.m - 1):
            norm_powers.append(np.multiply(norm_powers[-1], norm_powers[-1]))
        # [num_blocks x self.m]
        norm_powers = np.transpose(np.array(norm_powers))
        halves_array = 0.5 * np.ones(norm_powers.shape)

        return norm_powers, halves_array

    def alsh_block_preprocess_fn(self, block_embeds):
        block_embeds = np.array(block_embeds)
        if self.max_norm is None:
            self.max_norm = max(np.linalg.norm(block_embeds, axis=1))
        if self.max_norm > 1:
            block_embeds = self.u / self.max_norm * block_embeds
        norm_powers, halves_array = self.get_norm_powers_and_halves_array(block_embeds)

        # P'(S(x)) for all x in block_embeds
        return np.float32(np.concatenate((block_embeds, norm_powers, halves_array), axis=1))

    def alsh_query_preprocess_fn(self, query_embeds):
        max_norm = max(np.linalg.norm(query_embeds, axis=1))
        if max_norm > 1:
            query_embeds = self.u / max_norm * query_embeds
        norm_powers, halves_array = self.get_norm_powers_and_halves_array(query_embeds)

        # Q'(S(x)) for all x in query_embeds
        return np.float32(np.concatenate((query_embeds, halves_array, norm_powers), axis=1))
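
    # For reference, these transforms appear to follow the L2-ALSH construction of
    # Shrivastava & Li (2014). After rescaling so every norm is at most u < 1:
    #   P'(x) = [x; ||x||^2, ||x||^4, ..., ||x||^(2^m); 1/2, ..., 1/2]
    #   Q'(q) = [q; 1/2, ..., 1/2; ||q||^2, ||q||^4, ..., ||q||^(2^m)]
    # which gives
    #   ||Q'(q) - P'(x)||^2 = m/2 - 2 q.x + ||x||^(2^(m+1)) + ||q||^(2^(m+1))
    # The trailing norm terms shrink towards zero as m grows, so minimizing L2
    # distance over the transformed vectors approximates maximizing the inner product.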


# This was the original hashing scheme, not used anymore

class RandProjectionLSHIndex(object):
    """Class for holding hashed data"""
    def __init__(self, embed_size, num_buckets, whiten=True, seed=0):
        np.random.seed(seed)
        self.hash_data = defaultdict(list)
        hash_matrix = 2 * np.random.rand(embed_size, int(num_buckets / 2)) - 1
        self.hash_matrix = hash_matrix / np.linalg.norm(hash_matrix, axis=0).reshape(1, -1)
        self.embed_mean = None
        self.embed_whitener = None
        self.whiten = whiten

    def state(self):
        state = {
            'hash_data': self.hash_data,
            'hash_matrix': self.hash_matrix,
            'embed_mean': self.embed_mean,
            'embed_whitener': self.embed_whitener,
        }
        return state

    def save_to_file(self):
        args = get_args()
        with open(args.block_index_path, 'wb') as index_file:
            pickle.dump(self.state(), index_file)

    @classmethod
    def load_from_file(cls, fname):
        print(" > Unpickling block hash data")
        state_dict = pickle.load(open(fname, 'rb'))
        print(" > Finished unpickling")
        hash_matrix = state_dict['hash_matrix']

        new_index = cls(hash_matrix.shape[0], hash_matrix.shape[1] * 2)
        new_index.hash_data = state_dict['hash_data']
        new_index.embed_mean = state_dict.get('embed_mean')
        new_index.embed_whitener = state_dict.get('embed_whitener')
        new_index.hash_matrix = hash_matrix

        return new_index

    def get_block_bucket(self, hash):
        return self.hash_data[hash]

    def hash_embeds(self, embeds, write_block_data=None):
        """Hash a tensor of embeddings using a random projection matrix"""
        embed_scores_pos = torch.matmul(embeds, torch.cuda.FloatTensor(self.hash_matrix).type(embeds.dtype))
        embed_scores = torch.cat((embed_scores_pos, -embed_scores_pos), axis=1)
        embed_hashes = detach(torch.argmax(embed_scores, axis=1))

        if write_block_data is not None:
            for hash, indices in zip(embed_hashes, write_block_data):
                self.hash_data[hash].append(indices)

        return embed_hashes
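
    # Note on the bucketing above: each embedding is projected onto num_buckets / 2
    # random unit directions; concatenating the scores with their negations and taking
    # the argmax assigns the embedding to the signed direction it is most aligned
    # with, giving num_buckets possible hash values in total.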

    def hash_whitened_block_embeds(self, block_data):
        """Transform all block embeds to have zero mean and unit covariance
        when treated as samples from a distribution"""
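        # Why the Cholesky factor works here: if cov(x) = C and C^{-1} = L L^T, then
        # with W = L^T the transformed samples W (x - mean) have covariance
        # W C W^T = L^T (L L^T)^{-1} L = I, i.e. zero mean and unit covariance.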
        block_idx, all_embeds = zip(*block_data.embed_data.items())
        arr_embeds = np.transpose(np.array(all_embeds))

        mean = np.mean(arr_embeds, axis=1).reshape(-1, 1)
        centered = arr_embeds - mean
        inv_cov = np.linalg.inv(np.cov(arr_embeds))
        whitener = np.transpose(np.linalg.cholesky(inv_cov))
        whitened = np.float16(np.transpose(whitener.dot(centered)))

        self.embed_mean = mean.reshape(-1)
        self.embed_whitener = whitener
        self.hash_data = defaultdict(list)
        batch_size = 16384
        i = 0

        args = get_args()
        with torch.no_grad():
            while True:
                if args.debug:
                    print(i, flush=True)
                batch_slice = slice(i * batch_size, (i + 1) * batch_size)
                batch_embed = torch.cuda.HalfTensor(whitened[batch_slice])
                batch_meta = [block_data.meta_data[idx] for idx in block_idx[batch_slice]]
                if len(batch_meta) == 0:
                    break

                self.hash_embeds(batch_embed, batch_meta)
                i += 1

    def exact_mips_equals(self, query_embeds, all_block_data, norm_blocks):
        """For each query, determine whether the mips block is in the correct hash bucket"""
        shuffled_block_idx, block_embeds = zip(*all_block_data.items())
        if norm_blocks:
            block_embeds = block_embeds / np.linalg.norm(block_embeds, axis=1).reshape(-1, 1)
        with torch.no_grad():
            query_hashes = self.hash_embeds(query_embeds)

            # [num_query x num_blocks]
            inner_products = torch.matmul(torch.cuda.HalfTensor(query_embeds),
                                          torch.cuda.HalfTensor(np.transpose(np.array(block_embeds))))
            max_inner_product_idxes = detach(torch.argmax(inner_products, axis=1))
            best_blocks = np.array([all_block_data[shuffled_block_idx[idx]] for idx in max_inner_product_idxes])
            best_block_hashes = self.hash_embeds(best_blocks)

            print('Query hashes: ', query_hashes)
            print('Block hashes: ', best_block_hashes)
            equal_arr = np.equal(query_hashes, best_block_hashes).astype(int)

            # array of zeros and ones which can be used for counting success
            return equal_arr

    def exact_mips_test(self, num_queries, all_block_data, norm_blocks):
        if self.whiten:
            if self.embed_mean is None:
                self.hash_whitened_block_embeds(all_block_data)
            embed_size = self.hash_matrix.shape[0]
            query_embeds = np.random.multivariate_normal(np.zeros(embed_size), np.eye(embed_size), num_queries)
            query_embeds = query_embeds / np.linalg.norm(query_embeds, axis=1).reshape(-1, 1)
        else:
            block_idx, all_embeds = zip(*all_block_data.items())
            arr_embeds = np.transpose(np.array(all_embeds))

            # np.random.multivariate_normal expects a 1-D mean vector
            mean = np.mean(arr_embeds, axis=1)
            cov = np.cov(arr_embeds)
            query_embeds = np.random.multivariate_normal(mean, cov, num_queries)

        equal_arr = self.exact_mips_equals(query_embeds, all_block_data, norm_blocks)
        print("Num correct: ", sum(equal_arr), " Fraction correct: ", sum(equal_arr) / equal_arr.size)
        print(equal_arr)