hashed_index.py 9.05 KB
Newer Older
Neel Kant's avatar
Neel Kant committed
1
from collections import defaultdict
Neel Kant's avatar
Neel Kant committed
2
import os
Neel Kant's avatar
Neel Kant committed
3
import pickle
Neel Kant's avatar
Neel Kant committed
4
import shutil
Neel Kant's avatar
Neel Kant committed
5

Neel Kant's avatar
Neel Kant committed
6
7
8
9
10
11
12
13
14
15
16
import numpy as np
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

from megatron import get_args
from megatron import mpu
from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
from megatron.data.bert_dataset import get_indexed_dataset_
from megatron.data.ict_dataset import InverseClozeDataset
from megatron.data.samplers import DistributedBatchSampler
from megatron.initialize import initialize_megatron
Neel Kant's avatar
Neel Kant committed
17
from megatron.model import REALMRetriever
Neel Kant's avatar
Neel Kant committed
18
19
20
21
from megatron.training import get_model
from pretrain_bert_ict import get_batch, model_provider


22
23
24
25
def detach(tensor):
    """Return ``tensor`` as a NumPy array.

    The tensor is first cut from the autograd graph and moved to the CPU,
    because ``.numpy()`` only works on CPU tensors that do not require grad.
    """
    host_tensor = tensor.detach().cpu()
    return host_tensor.numpy()


Neel Kant's avatar
Neel Kant committed
26
27
28
29
30
31
class HashedIndex(object):
    """Class for holding hashed data.

    ``block_data`` maps an integer block index to its embedding, and
    ``hash_data`` maps a bucket id to the list of items that hashed into it.
    The random projection matrix is derived deterministically from ``seed``,
    so independently built shards agree on bucket ids (this is checked when
    shards are consolidated).
    """

    def __init__(self, embed_size, num_buckets, seed=0):
        # Use a private generator instead of reseeding NumPy's global RNG.
        # RandomState(seed).rand(...) yields exactly the same matrix as
        # np.random.seed(seed); np.random.rand(...), so previously saved
        # indices and shards stay compatible.
        rng = np.random.RandomState(seed)
        self.block_data = defaultdict(list)
        self.hash_data = defaultdict(list)
        # Each projection column gives two buckets (+score and -score), so
        # only half as many columns as buckets are needed.
        self.hash_matrix = rng.rand(embed_size, int(num_buckets // 2))

    def state(self):
        """Return the picklable state: block data, hash buckets and projection."""
        state = {
            'block_data': self.block_data,
            'hash_data': self.hash_data,
            'hash_matrix': self.hash_matrix
        }
        return state

    def get_block_bucket(self, hash):
        """Return the list of items stored under bucket id ``hash``."""
        return self.hash_data[hash]

    def get_block_embed(self, block_idx):
        """Return the embedding stored for ``block_idx``."""
        return self.block_data[block_idx]

    def hash_embeds(self, embeds, block_data=None):
        """Hash a batch of embeddings into buckets using the random projection.

        Args:
            embeds: 2-D tensor with one embedding per row. Its last dimension
                must equal ``hash_matrix.shape[0]``.
            block_data: optional iterable aligned with the rows of ``embeds``.
                When given, each item is appended to the bucket its row
                hashed to.

        Returns:
            NumPy array of bucket ids in ``[0, 2 * hash_matrix.shape[1])``,
            one per row of ``embeds``.
        """
        # Build the projection on the input's device and dtype instead of
        # hard-coding torch.cuda.HalfTensor, so fp32 and CPU inputs also work.
        projection = torch.as_tensor(self.hash_matrix, dtype=embeds.dtype, device=embeds.device)
        embed_scores_pos = torch.matmul(embeds, projection)
        # Column j wins for a strongly positive projection, and column
        # j + hash_matrix.shape[1] wins for a strongly negative one.
        embed_scores = torch.cat((embed_scores_pos, -embed_scores_pos), dim=1)
        embed_hashes = torch.argmax(embed_scores, dim=1).detach().cpu().numpy()

        if block_data is not None:
            for bucket, item in zip(embed_hashes, block_data):
                # Plain int keys keep the pickled index free of NumPy scalars.
                self.hash_data[int(bucket)].append(item)

        return embed_hashes

    def assign_block_embeds(self, block_indices, block_embeds, allow_overwrite=False):
        """Assign the embeddings for each block index into a hash map.

        Raises:
            ValueError: if an index is already present and ``allow_overwrite``
                is False.
        """
        for idx, embed in zip(block_indices, block_embeds):
            if not allow_overwrite and int(idx) in self.block_data:
                raise ValueError("Attempted to overwrite a read-only HashedIndex")
            self.block_data[int(idx)] = embed

    def save_shard(self, rank):
        """Pickle this index's state to ``block_hash_data/<rank>.pkl``."""
        dir_name = 'block_hash_data'
        # exist_ok avoids a check-then-create race when several ranks save
        # their shards concurrently.
        os.makedirs(dir_name, exist_ok=True)

        # save the data for each shard
        with open('{}/{}.pkl'.format(dir_name, rank), 'wb') as data_file:
            pickle.dump(self.state(), data_file)

    def consolidate_shards_and_save(self, ignore_shard=0):
        """Combine all the shards made using self.save_shard() and save the result.

        The shard file written for rank ``ignore_shard`` is skipped, presumably
        because its data is already held in memory by this object. The merged
        state is pickled to ``block_hash_data.pkl``, and the shard directory is
        then removed.

        Raises:
            AssertionError: if a shard was built with a different hash matrix,
                or if it repeats block indices that are already present.
        """
        dir_name = 'block_hash_data'
        # Compare the exact filename. A substring test would also skip files
        # such as '10.pkl' and '20.pkl' when ignore_shard is 0.
        ignored_fname = '{}.pkl'.format(ignore_shard)
        # Sorted so that buckets are extended in a reproducible order.
        for fname in sorted(os.listdir(dir_name)):
            if fname == ignored_fname:
                continue
            with open('{}/{}'.format(dir_name, fname), 'rb') as f:
                data = pickle.load(f)
            assert np.array_equal(data['hash_matrix'], self.hash_matrix)

            old_size = len(self.block_data)
            shard_size = len(data['block_data'])
            self.block_data.update(data['block_data'])
            # Shards must cover disjoint sets of block indices.
            assert len(self.block_data) == old_size + shard_size, (old_size, shard_size, len(self.block_data))

            for bucket, items in data['hash_data'].items():
                self.hash_data[bucket].extend(items)

        with open('block_hash_data.pkl', 'wb') as final_file:
            pickle.dump(self.state(), final_file)
        shutil.rmtree(dir_name, ignore_errors=True)

    def clear(self):
        """Clear the data structures to save memory (the hash matrix is kept)."""
        self.block_data = defaultdict(list)
        self.hash_data = defaultdict(list)

    @classmethod
    def load_from_file(cls, fname):
        """Rebuild an index from a pickle written by consolidate_shards_and_save().

        NOTE(review): this uses pickle, so only load files from a trusted source.
        """
        with open(fname, 'rb') as f:
            state_dict = pickle.load(f)
        hash_matrix = state_dict['hash_matrix']

        new_index = cls(hash_matrix.shape[0], hash_matrix.shape[1] * 2)
        new_index.block_data = state_dict['block_data']
        new_index.hash_data = state_dict['hash_data']
        new_index.hash_matrix = hash_matrix
        return new_index

Neel Kant's avatar
Neel Kant committed
115

Neel Kant's avatar
Neel Kant committed
116
117
118
119
120
121
def test_retriever():
    """Manually exercise REALMRetriever against the saved hashed index.

    Initializes Megatron, loads the model checkpoint and the dataset, and
    restores the index from ``block_hash_data.pkl``. It then calls
    ``retrieve_evidence_blocks_text`` on a fixed sample query.
    """
    megatron_defaults = {'tokenizer_type': 'BertWordPieceLowerCase'}
    initialize_megatron(extra_args_provider=None, args_defaults=megatron_defaults)

    retrieval_model = load_checkpoint()
    retrieval_model.eval()
    evidence_dataset = get_dataset()
    index = HashedIndex.load_from_file('block_hash_data.pkl')

    sample_query = "The last monarch from the house of windsor"
    retriever = REALMRetriever(retrieval_model, evidence_dataset, index)
    retriever.retrieve_evidence_blocks_text(sample_query)


Neel Kant's avatar
Neel Kant committed
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
def main():
    """Embed every evidence block with the loaded model and build a HashedIndex.

    Each rank embeds and hashes the blocks from its data loader and writes its
    shard with HashedIndex.save_shard(). After a barrier, data-parallel rank 0
    merges all shards into ``block_hash_data.pkl``, and the other ranks drop
    their in-memory data.
    """

    # TODO
    # consider broadcasting/all-reducing all in memory rather than using the filesystem
    # create a different process group in the same nccl world - don't have to use chkpts on disk or transfer things on disk
    # torch distributed new group, contains a list of ranks, gives back a group which I can hand to the collective operations
    # create a training process group, indexing process group
    # pass the training group to the distributed DDP, instead of the large world process group
    # use indexing process group for the shard-combining
    # communication group between process "8" and process "0" which tells training group that there's a new index
    # also, process 0 sends process 8 the new model

    # if i want to launch a separate process for indexing, may have to work with environment variables to
    # allocate the resources well. Have to subsequently assign the correct gpus to the indexing job
    # consider initializing everything in a single group and break off processes based on the ranks

    initialize_megatron(extra_args_provider=None,
                        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
    args = get_args()
    model = load_checkpoint()
    model.eval()
    dataset = get_dataset()
    data_iter = iter(get_dataloader(dataset))
    hashed_index = HashedIndex(embed_size=128, num_buckets=2048)

    # Loop-invariant: unwrap the model once instead of once per batch.
    # NOTE(review): assumes load_checkpoint() returns a model with two more
    # wrapper layers (e.g. FP16 + local DDP) — confirm.
    actual_model = model.module.module

    i = 0
    # Inference only: disable autograd so embedding batches don't keep graphs alive.
    with torch.no_grad():
        while True:
            try:
                _, _, block_tokens, block_pad_mask, block_indices = get_batch(data_iter)
            except StopIteration:
                # Data iterator exhausted. Any other error now propagates
                # instead of silently producing a truncated index.
                # NOTE(review): assumes get_batch lets next()'s StopIteration
                # propagate — confirm.
                break

            block_indices = detach(block_indices)
            block_logits = actual_model.embed_block(block_tokens, block_pad_mask)
            hashed_index.hash_embeds(block_logits, block_indices)
            # Column 3 of each block_indices row is used as the block's key.
            hashed_index.assign_block_embeds(block_indices[:, 3], detach(block_logits))

            if i % 100 == 0:
                print(i, flush=True)
            i += 1

    hashed_index.save_shard(args.rank)
    torch.distributed.barrier()
    del model

    if mpu.get_data_parallel_rank() == 0:
        # This process's own shard (saved under args.rank) is already in
        # memory, so skip that file rather than assuming it is rank 0's.
        hashed_index.consolidate_shards_and_save(ignore_shard=args.rank)
    else:
        hashed_index.clear()
Neel Kant's avatar
Neel Kant committed
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208


def load_checkpoint():
    """Build the model via get_model(model_provider) and load checkpoint weights.

    The iteration number is read from the tracker file under ``args.load`` and
    must be positive. The weights under the checkpoint's ``'model'`` key are
    loaded into the model, with any outer torch DDP wrapper removed, and all
    ranks synchronize on a barrier before the model is returned.
    """
    args = get_args()
    model = get_model(model_provider)
    if isinstance(model, torchDDP):
        model = model.module

    tracker_path = get_checkpoint_tracker_filename(args.load)
    with open(tracker_path, 'r') as tracker:
        saved_iteration = int(tracker.read().strip())
    assert saved_iteration > 0

    checkpoint_path = get_checkpoint_name(args.load, saved_iteration, False)
    if mpu.get_data_parallel_rank() == 0:
        this_rank = torch.distributed.get_rank()
        print('global rank {} is loading checkpoint {}'.format(this_rank, checkpoint_path))

    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    torch.distributed.barrier()

    if mpu.get_data_parallel_rank() == 0:
        print(' successfully loaded {}'.format(checkpoint_path))
    return model


def get_dataset():
    """Build the 'full' InverseClozeDataset used for indexing.

    Blocks come from the mmap indexed dataset at ``args.data_path``, and titles
    come from ``args.data_path + '-titles'``. The dataset covers one epoch, has
    no sample cap, and uses seed 1.
    """
    args = get_args()
    data_prefix = args.data_path
    blocks = get_indexed_dataset_(data_prefix, 'mmap', True)
    titles = get_indexed_dataset_(data_prefix + '-titles', 'mmap', True)

    return InverseClozeDataset(
        name='full',
        block_dataset=blocks,
        title_dataset=titles,
        data_prefix=data_prefix,
        num_epochs=1,
        max_num_samples=None,
        # The original author noted these two values "don't matter" here.
        # NOTE(review): confirm before reusing this dataset for other purposes.
        max_seq_length=288,
        short_seq_prob=0.0001,
        seed=1,
    )


def get_dataloader(dataset):
    """Return a DataLoader that walks ``dataset`` in order for this data-parallel rank.

    A SequentialSampler is wrapped in a DistributedBatchSampler whose batch
    size is ``args.batch_size`` times the data-parallel world size. The
    sampler is given this rank and the world size, presumably so that each
    rank receives its own slice of every global batch.

    NOTE(review): drop_last=True presumably drops trailing samples that do not
    fill a whole global batch, so those blocks would never be indexed —
    confirm this is intended.
    """
    args = get_args()
    dp_world_size = mpu.get_data_parallel_world_size()
    dp_rank = mpu.get_data_parallel_rank()

    batch_sampler = DistributedBatchSampler(
        torch.utils.data.SequentialSampler(dataset),
        batch_size=args.batch_size * dp_world_size,
        drop_last=True,
        rank=dp_rank,
        world_size=dp_world_size,
    )
    return torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batch_sampler,
        num_workers=args.num_workers,
        pin_memory=True,
    )


if __name__ == "__main__":
    # Script entry point: build the hashed block index.
    main()