indexer.py 11 KB
Newer Older
1
import os
2
import sys
3
4
import time

Neel Kant's avatar
Neel Kant committed
5
import torch
6
import torch.distributed as dist
Neel Kant's avatar
Neel Kant committed
7
8
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

9
from megatron import get_args, get_adlr_autoresume, print_rank_0
Neel Kant's avatar
Neel Kant committed
10
11
12
from megatron import mpu
from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
from megatron.data.bert_dataset import get_indexed_dataset_
13
from megatron.data.realm_dataset import ICTDataset
14
from megatron.data.realm_index import detach, BlockData, FaissMIPSIndex
Neel Kant's avatar
Neel Kant committed
15
16
from megatron.data.samplers import DistributedBatchSampler
from megatron.initialize import initialize_megatron
Neel Kant's avatar
Neel Kant committed
17
from megatron.model import REALMRetriever
18
19
20
21
from megatron.global_vars import set_global_variables
from megatron.mpu.initialize import get_index_ready, get_index_group, get_train_group
from megatron.mpu.initialize import set_data_parallel_group, set_model_parallel_group, init_realm_groups
from megatron.initialize import init_distributed, _init_autoresume, _set_random_seed, _write_args_to_tensorboard
Neel Kant's avatar
Neel Kant committed
22
from megatron.training import get_model
23
from megatron.utils import check_adlr_autoresume_termination
Neel Kant's avatar
Neel Kant committed
24
25
26
from pretrain_bert_ict import get_batch, model_provider


# Module-level handle to the index-ready flag. It starts as None and is set
# from get_index_ready() during initialization and in AsyncIndexBuilder.
INDEX_READY = None


def pprint(*args):
    """Print ``args`` space-separated on one line and flush stdout right away."""
    line = ' '.join(str(item) for item in args)
    print(line, flush=True)


def initialize_and_run_async_megatron(extra_args_provider=None, args_defaults=None,
                                      ignore_unknown_args=False, allow_no_cuda=False):
    """Initialize Megatron for asynchronous REALM training and dispatch each rank.

    Ranks below ``args.max_training_rank`` are trainers: they synchronize on the
    train group and return. All remaining ranks become index builders and enter
    ``AsyncIndexBuilder.run_async``, which loops forever and does not return.

    Args:
        extra_args_provider: forwarded to ``set_global_variables``.
        args_defaults: forwarded to ``set_global_variables``; ``None`` means
            no defaults (an empty dict).
        ignore_unknown_args: forwarded to ``set_global_variables``.
        allow_no_cuda: when False, assert that CUDA is available.
    """
    if args_defaults is None:
        # Avoid sharing one mutable default dict across calls.
        args_defaults = {}

    if not allow_no_cuda:
        # Make sure cuda is available.
        assert torch.cuda.is_available(), 'Megatron requires CUDA.'

    # Parse args, build tokenizer, and set adlr-autoresume,
    # tensorboard-writer, and timers.
    set_global_variables(extra_args_provider=extra_args_provider,
                         args_defaults=args_defaults,
                         ignore_unknown_args=ignore_unknown_args)

    # instead of _initialize_distributed()
    init_distributed()
    setup_realm_groups_and_vars()
    global INDEX_READY
    INDEX_READY = get_index_ready()
    pprint('finished setting up groups')

    # Autoresume
    _init_autoresume()
    pprint('finished setting up autoresume')

    # Random seeds for reproducibility.
    args = get_args()
    if args.rank == 0:
        pprint('> setting random seeds to {} ...'.format(args.seed))
    _set_random_seed(args.seed)

    # Write arguments to tensorboard.
    _write_args_to_tensorboard()
    pprint('finished writing args to tensorboard')

    # Every rank (trainers and indexers) must reach this point before splitting.
    torch.distributed.barrier()

    if args.rank < args.max_training_rank:
        torch.distributed.barrier(get_train_group())
        pprint("All trainers ready.")
        return
    else:
        runner = AsyncIndexBuilder(args.rank)
        torch.distributed.barrier(get_index_group())
        pprint("All indexers ready.")
        # Infinite build/signal loop; this call does not return.
        runner.run_async()


def setup_realm_groups_and_vars():
    """Create the process groups used by asynchronous REALM training.

    Every rank becomes its own model-parallel group, because model parallelism
    is assumed to be off. The trainer and indexer groups are set up via
    ``init_realm_groups``. The data-parallel group is then the train group for
    ranks below ``args.max_training_rank`` and the index group otherwise.
    """
    args = get_args()
    world_size = dist.get_world_size()
    max_training_rank = args.max_training_rank

    # Assuming no model parallelism right now: each rank is a singleton group.
    # torch.distributed.new_group must be entered by *every* process with the
    # same ranks and in the same order, even by non-members. So all ranks create
    # all singleton groups, and each rank keeps only its own.
    own_model_parallel_group = None
    for group_rank in range(world_size):
        group = dist.new_group([group_rank])
        if group_rank == args.rank:
            own_model_parallel_group = group
    set_model_parallel_group(own_model_parallel_group)
    init_realm_groups(max_training_rank, world_size)

    if args.rank < max_training_rank:
        set_data_parallel_group(get_train_group())
    else:
        set_data_parallel_group(get_index_group())


class AsyncIndexBuilder(object):
    """Repeatedly rebuild the block-embedding index on the dedicated indexer ranks.

    One instance runs on each rank ``>= args.max_training_rank``. The rank equal
    to ``args.max_training_rank`` is the main builder. It consolidates the
    per-rank shards and flips the module-level ``INDEX_READY`` flag.
    """

    def __init__(self, rank):
        self.rank = rank
        args = get_args()
        self.is_main_builder = self.rank == args.max_training_rank
        self.main_builder_idx = args.max_training_rank
        self.debug = args.debug

        self.model = None
        self.dataloader = None
        self.block_data = None
        self.load_attributes()

        global INDEX_READY
        INDEX_READY = get_index_ready()

    def run_async(self):
        """Build, save and signal the index forever, reloading the model each round.

        This method never returns.
        """
        while True:
            print("Starting (again!)", flush=True)
            self.build_and_save_index()
            self.send_index_ready_signal()
            # NOTE(review): INDEX_READY is presumably flipped back by the async
            # broadcast from rank 0 posted in send_index_ready_signal.
            while INDEX_READY == 1:
                print("Waiting for new model checkpoint.", flush=True)
                time.sleep(5)

            self.load_attributes()

    def load_attributes(self):
        """(Re)load the block model, a fresh one-epoch dataloader and an empty BlockData."""
        try:
            # Prefer the ICT weights stored inside the REALM checkpoint...
            self.model = load_ict_checkpoint(only_block_model=True, no_grad=True, from_realm_chkpt=True)
        except Exception:
            # ...and fall back to the standalone ICT checkpoint if that fails.
            # (Exception rather than a bare except so Ctrl-C / SystemExit propagate.)
            self.model = load_ict_checkpoint(only_block_model=True, no_grad=True, from_realm_chkpt=False)
        self.model.eval()
        self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset()))
        self.block_data = BlockData()

    def build_and_save_index(self):
        """Embed this rank's share of one epoch, save it, and merge shards on the main builder.

        Calls ``sys.exit(0)`` if an ADLR autoresume termination is requested.
        """
        num_batches = 0
        total = 0
        while True:
            with torch.no_grad():
                try:
                    query_tokens, query_pad_mask, \
                    block_tokens, block_pad_mask, block_index_data = get_batch(self.dataloader)
                except StopIteration:
                    # The one-epoch dataloader iterator is exhausted.
                    break

                block_index_data = detach(block_index_data)
                # Column 3 is used as the block index; columns 0-2 are kept as metadata.
                block_indices = block_index_data[:, 3]
                block_meta = block_index_data[:, :3]

                block_logits = detach(self.model(None, None, block_tokens, block_pad_mask, only_block=True))
                self.block_data.add_block_data(block_indices, block_logits, block_meta)

                total += block_indices.size
                num_batches += 1
                if num_batches % 10 == 0:
                    print('Batch {:10d} | Total {:10d}'.format(num_batches, total), flush=True)
                    if self.debug:
                        break

                autoresume = get_adlr_autoresume()
                # NOTE(review): the autoresume object is presumably None when
                # autoresume is not configured, so guard before using it.
                if autoresume is not None and autoresume.termination_requested():
                    print_rank_0(">>> autoresume termination request found!")
                    # NOTE(review): indexer ranks are >= max_training_rank, so this
                    # rank-0 check only matches when max_training_rank == 0 — confirm intent.
                    if torch.distributed.get_rank() == 0:
                        autoresume.request_resume()
                    print_rank_0(">>> training terminated. Returning")
                    sys.exit(0)

        self.block_data.save_shard(self.rank)
        # Only indexer ranks execute this method, so synchronize over the index
        # group (as send_index_ready_signal does) rather than the whole world,
        # which would also require every trainer to enter a matching barrier.
        torch.distributed.barrier(get_index_group())
        del self.model

        if self.is_main_builder:
            self.block_data.consolidate_shards_and_save(ignore_shard=self.rank)
        self.block_data.clear()

    def send_index_ready_signal(self):
        """Flip INDEX_READY on the main builder and post the flag broadcasts.

        The main builder's value is broadcast from ``self.main_builder_idx``. Then,
        after an index-group barrier, a second broadcast rooted at rank 0 is posted.
        Neither async broadcast is waited on here.
        """
        global INDEX_READY
        if self.is_main_builder:
            INDEX_READY = 1 - INDEX_READY
            print("Switched INDEX_READY", flush=True)
        print(time.ctime(time.time()), flush=True)

        send_handle = dist.broadcast(INDEX_READY, self.main_builder_idx, async_op=True)

        torch.distributed.barrier(get_index_group())
        recv_handle = dist.broadcast(INDEX_READY, 0, async_op=True)


class BasicIndexBuilder(object):
    """Build the block-embedding index once, from the standalone ICT checkpoint.

    Every rank embeds its share of one epoch and saves a shard. After a barrier,
    rank 0 consolidates the shards.
    """

    def __init__(self):
        args = get_args()
        self.rank = args.rank
        self.model = load_ict_checkpoint(only_block_model=True, no_grad=True, from_realm_chkpt=False)
        self.model.eval()
        self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset()))
        self.block_data = BlockData()

    def build_and_save_index(self):
        """Embed all of this rank's batches, save the shard and merge on rank 0."""
        num_batches = 0
        total = 0
        while True:
            with torch.no_grad():
                try:
                    query_tokens, query_pad_mask, \
                    block_tokens, block_pad_mask, block_index_data = get_batch(self.dataloader)
                except StopIteration:
                    # The one-epoch dataloader iterator is exhausted.
                    break

                block_index_data = detach(block_index_data)
                # Column 3 is used as the block index; columns 0-2 are kept as metadata.
                block_indices = block_index_data[:, 3]
                block_meta = block_index_data[:, :3]

                block_logits = detach(self.model(None, None, block_tokens, block_pad_mask, only_block=True))
                self.block_data.add_block_data(block_indices, block_logits, block_meta)

                total += block_indices.size
                num_batches += 1
                if num_batches % 2000 == 0:
                    print('Batch {:10d} | Total {:10d}'.format(num_batches, total), flush=True)

        self.block_data.save_shard(self.rank)
        # All shards must be on disk before rank 0 consolidates them.
        torch.distributed.barrier()
        del self.model

        if self.rank == 0:
            self.block_data.consolidate_shards_and_save(ignore_shard=self.rank)
        self.block_data.clear()


def load_ict_checkpoint(only_query_model=False, only_block_model=False, no_grad=False, from_realm_chkpt=False):
    """Build an ICT model and load weights from the latest tracked checkpoint.

    Args:
        only_query_model: build only the query side and drop ``'context_model'``
            from the loaded state dict.
        only_block_model: build only the block side and drop ``'question_model'``
            from the loaded state dict.
        no_grad: call ``load_state_dict`` inside ``torch.no_grad()``.
        from_realm_chkpt: load from ``args.load`` and take the ICT weights under
            ``state_dict['model']['retriever']['ict_model']``. Otherwise load from
            ``args.ict_load`` and use ``state_dict['model']``.

    Returns:
        The loaded model, unwrapped from ``torchDDP`` if it was wrapped.

    Raises:
        ValueError: if the tracker file contains neither a positive iteration
            number nor ``'release'``.
    """
    args = get_args()
    model = get_model(lambda: model_provider(only_query_model, only_block_model))

    if isinstance(model, torchDDP):
        model = model.module

    load_path = args.load if from_realm_chkpt else args.ict_load

    tracker_filename = get_checkpoint_tracker_filename(load_path)
    with open(tracker_filename, 'r') as f:
        metastring = f.read().strip()

    # The tracker holds either an iteration number or the literal 'release'
    # (presumably the same convention as Megatron's own checkpoint loader).
    iteration = 0
    release = False
    try:
        iteration = int(metastring)
    except ValueError:
        release = metastring == 'release'
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not (iteration > 0 or release):
        raise ValueError('error parsing metadata file {}'.format(tracker_filename))

    checkpoint_name = get_checkpoint_name(load_path, iteration, release)
    if mpu.get_data_parallel_rank() == 0:
        print('global rank {} is loading checkpoint {}'.format(
            torch.distributed.get_rank(), checkpoint_name))

    state_dict = torch.load(checkpoint_name, map_location='cpu')
    ict_state_dict = state_dict['model']
    if from_realm_chkpt:
        # A REALM checkpoint nests the ICT weights under the retriever.
        ict_state_dict = ict_state_dict['retriever']['ict_model']

    if only_query_model:
        ict_state_dict.pop('context_model')
    if only_block_model:
        ict_state_dict.pop('question_model')
    if no_grad:
        with torch.no_grad():
            model.load_state_dict(ict_state_dict)
    else:
        model.load_state_dict(ict_state_dict)
    torch.distributed.barrier()

    if mpu.get_data_parallel_rank() == 0:
        print(' successfully loaded {}'.format(checkpoint_name))

    return model


def get_ict_dataset(use_titles=True):
    """Build a single-epoch ``ICTDataset`` named ``'full'`` for index building.

    Block text comes from ``args.data_path`` and titles come from
    ``args.titles_data_path``. Both are opened as ``'mmap'`` indexed datasets.

    Args:
        use_titles: forwarded to ``ICTDataset``.

    Returns:
        The constructed ``ICTDataset``.
    """
    args = get_args()
    block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
    titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True)

    kwargs = dict(
        name='full',
        block_dataset=block_dataset,
        title_dataset=titles_dataset,
        data_prefix=args.data_path,
        num_epochs=1,
        max_num_samples=None,
        max_seq_length=288,  # doesn't matter
        short_seq_prob=0.0001,  # doesn't matter
        seed=1,
        query_in_block_prob=1,
        use_titles=use_titles
    )
    dataset = ICTDataset(**kwargs)
    return dataset


def get_one_epoch_dataloader(dataset):
    """Return a DataLoader that walks ``dataset`` once, in order, split across data-parallel ranks.

    A global batch of ``args.batch_size * data_parallel_world_size`` sequential
    indices is drawn at a time. ``DistributedBatchSampler`` then hands this rank
    its part of that global batch.

    NOTE(review): ``drop_last=True`` presumably discards the final partial
    global batch, so trailing samples may never be indexed. Confirm this is
    acceptable for a full index.
    """
    args = get_args()

    world_size = mpu.get_data_parallel_world_size()
    print(world_size, flush=True)
    rank = mpu.get_data_parallel_rank()
    print(rank, flush=True)
    global_batch_size = args.batch_size * world_size
    num_workers = args.num_workers

    sampler = torch.utils.data.SequentialSampler(dataset)
    batch_sampler = DistributedBatchSampler(sampler,
                                            batch_size=global_batch_size,
                                            drop_last=True,
                                            rank=rank,
                                            world_size=world_size)

    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True)


if __name__ == "__main__":
    # Standalone, synchronous index build: every rank embeds its share and
    # rank 0 consolidates the shards.
    initialize_megatron(extra_args_provider=None,
                        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
    index_builder = BasicIndexBuilder()
    index_builder.build_and_save_index()