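"""Standalone/async index builder for REALM-style retrieval.

Ranks below args.max_training_rank act as trainers; the remaining ranks run an
asynchronous IndexBuilder that repeatedly embeds every block with the latest
ICT/REALM block model and saves the embeddings to disk as BlockData shards.
"""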
import os
import sys

import time

import torch
import torch.distributed as dist

from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

from megatron import get_args, get_adlr_autoresume, print_rank_0
from megatron import mpu
from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
from megatron.data.bert_dataset import get_indexed_dataset_
from megatron.data.realm_dataset import ICTDataset
from megatron.data.realm_index import detach, BlockData, FaissMIPSIndex
from megatron.data.samplers import DistributedBatchSampler
from megatron.global_vars import set_global_variables
from megatron.initialize import initialize_megatron
from megatron.initialize import init_distributed, _init_autoresume, _set_random_seed, _write_args_to_tensorboard
from megatron.model import REALMRetriever
from megatron.mpu.initialize import get_index_ready, get_index_group, get_train_group, get_data_parallel_group, get_gloo_comm_group
from megatron.mpu.initialize import set_data_parallel_group, set_model_parallel_group, init_realm_groups
from megatron.training import get_model
from megatron.utils import check_adlr_autoresume_termination

from pretrain_bert_ict import get_batch, model_provider


INDEX_READY = None


def pprint(*args):
    print(*args, flush=True)


def initialize_and_run_async_megatron(extra_args_provider=None, args_defaults={},
                                      ignore_unknown_args=False, allow_no_cuda=False):
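    """Initialize Megatron and split the world into trainer and indexer ranks.

    Ranks below args.max_training_rank return to the caller to run training;
    all remaining ranks construct an AsyncIndexBuilder and loop forever in
    run_async(), so this call never returns for them.
    """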
    if not allow_no_cuda:
        # Make sure cuda is available.
        assert torch.cuda.is_available(), 'Megatron requires CUDA.'

    # Parse args, build tokenizer, and set adlr-autoresume,
    # tensorboard-writer, and timers.
    set_global_variables(extra_args_provider=extra_args_provider,
                         args_defaults=args_defaults,
                         ignore_unknown_args=ignore_unknown_args)

    # instead of _initialize_distributed()
    init_distributed()
    setup_realm_groups_and_vars()
    global INDEX_READY
    INDEX_READY = get_index_ready()
    pprint('finished setting up groups')

    # Autoresume
    _init_autoresume()
    pprint('finished setting up autoresume')

    # Random seeds for reproducibility.
    args = get_args()
    if args.rank == 0:
        pprint('> setting random seeds to {} ...'.format(args.seed))
    _set_random_seed(args.seed)

    # Write arguments to tensorboard.
    _write_args_to_tensorboard()
    pprint('finished writing args to tensorboard')

    torch.distributed.barrier()

    if args.rank < args.max_training_rank:
        torch.distributed.barrier(get_data_parallel_group())
        pprint("All trainers ready.")
        return
    else:
        runner = AsyncIndexBuilder(args.rank)
        torch.distributed.barrier(get_data_parallel_group())
        pprint("All indexers ready.")
        runner.run_async()


def setup_realm_groups_and_vars():
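    """Create the process groups used for asynchronous REALM training/indexing.

    Assumes no model parallelism: every rank gets its own model-parallel group,
    and the data-parallel group is set to either the training group or the
    indexing group depending on whether the rank is below max_training_rank.
    """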
    args = get_args()
    world_size = dist.get_world_size()
    max_training_rank = args.max_training_rank

    # assuming no model parallelism right now
    set_model_parallel_group(dist.new_group([args.rank]))
    init_realm_groups(max_training_rank, world_size)

    if args.rank < max_training_rank:
        set_data_parallel_group(get_train_group())
    else:
        set_data_parallel_group(get_index_group())


class IndexBuilder(object):
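    """Embeds every block in the dataset with the ICT block model and writes
    the resulting BlockData index to disk, one shard per data-parallel rank."""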
    def __init__(self):
        args = get_args()
        self.rank = args.rank
        self.debug = args.debug

        self.model = None
        self.dataloader = None
        self.block_data = None
        self.load_attributes()
        self.is_main_builder = args.rank == 0

    def load_attributes(self):
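        """Load the block model from the ICT checkpoint and set up a fresh
        one-epoch dataloader and an empty BlockData store."""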
        self.model = load_ict_checkpoint(only_block_model=True, no_grad=True, from_realm_chkpt=False)
        self.model.eval()
        self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset()))
        self.block_data = BlockData()

    def build_and_save_index(self):
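        """Embed one epoch of blocks, save this rank's shard of BlockData, and
        have the main builder consolidate all shards into a single index."""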
        i = 1
        total = 0
        while True:
            with torch.no_grad():
                try:
                    query_tokens, query_pad_mask, \
                    block_tokens, block_pad_mask, block_index_data = get_batch(self.dataloader)
                except Exception:
                    # the dataloader has been exhausted; stop embedding blocks
                    break

                block_index_data = detach(block_index_data)
                block_indices = block_index_data[:, 3]
                block_meta = block_index_data[:, :3]

                block_logits = detach(self.model(None, None, block_tokens, block_pad_mask, only_block=True))
                self.block_data.add_block_data(block_indices, block_logits, block_meta)

                total += block_indices.size
                i += 1
                if i % 1000 == 0:
                    print('Batch {:10d} | Total {:10d}'.format(i, total), flush=True)
                    if self.debug:
                        break

        self.block_data.save_shard(self.rank)
        torch.distributed.barrier(get_data_parallel_group())
        del self.model

        if self.is_main_builder:
            self.block_data.consolidate_shards_and_save(ignore_shard=self.rank)
        self.block_data.clear()


class AsyncIndexBuilder(IndexBuilder):
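    """Index builder that runs concurrently with training: it repeatedly
    rebuilds the block index from the newest checkpoint and uses the
    INDEX_READY flag to signal the trainers when a fresh index is ready."""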
    def __init__(self, rank):
        self.rank = rank
        args = get_args()
        self.is_main_builder = self.rank == args.max_training_rank
        self.main_builder_idx = args.max_training_rank
        self.debug = args.debug

        self.model = None
        self.dataloader = None
        self.block_data = None
        self.load_attributes()

        global INDEX_READY
        INDEX_READY = get_index_ready()

    def run_async(self):
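        """Wait for the start signal from rank 0, then loop: build and save the
        index, signal the trainers, wait while INDEX_READY is set, then reload
        the block model from the newest checkpoint and start again."""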
        global INDEX_READY
        # synchronize for start
        dist.broadcast(INDEX_READY, 0, group=get_gloo_comm_group())

        while True:
            print("Starting (again!)", flush=True)
            self.build_and_save_index()
            self.send_index_ready_signal()
            while INDEX_READY == 1:
                print("Waiting for new model checkpoint.", flush=True)
                time.sleep(5)

            self.load_attributes()

    def load_attributes(self):
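        """Same as IndexBuilder.load_attributes, but prefer the block model from
        the latest REALM checkpoint, falling back to the plain ICT checkpoint
        if no REALM checkpoint is available yet."""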
        try:
            self.model = load_ict_checkpoint(only_block_model=True, no_grad=True, from_realm_chkpt=True)
        except Exception:
            print(">>>>> No realm chkpt available", flush=True)
            self.model = load_ict_checkpoint(only_block_model=True, no_grad=True, from_realm_chkpt=False)
        self.model.eval()
        self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset()))
        self.block_data = BlockData()

    def send_index_ready_signal(self):
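        """Toggle INDEX_READY on the main builder and broadcast it over the gloo
        group, then wait for the flag broadcast back from rank 0 and for all
        index builders to synchronize."""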
        global INDEX_READY
        if self.is_main_builder:
            INDEX_READY = 1 - INDEX_READY
            print("Switched INDEX_READY", flush=True)
        torch.cuda.synchronize()

        # send handle
        dist.broadcast(INDEX_READY, self.main_builder_idx, group=get_gloo_comm_group(), async_op=True)

        # recv handle
        dist.broadcast(INDEX_READY, 0, group=get_gloo_comm_group())
        torch.distributed.barrier(get_data_parallel_group())


def load_ict_checkpoint(only_query_model=False, only_block_model=False, no_grad=False, from_realm_chkpt=False):
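    """Load ICT retriever weights (query and/or block model) from either the
    plain ICT checkpoint (args.ict_load) or the latest REALM checkpoint
    (args.load) and return the model."""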
    args = get_args()
    model = get_model(lambda: model_provider(only_query_model, only_block_model))

    if isinstance(model, torchDDP):
        model = model.module

    load_path = args.load if from_realm_chkpt else args.ict_load

    tracker_filename = get_checkpoint_tracker_filename(load_path)
    with open(tracker_filename, 'r') as f:
        iteration = int(f.read().strip())

    # assert iteration > 0
    checkpoint_name = get_checkpoint_name(load_path, iteration, False)
    if mpu.get_data_parallel_rank() == 0:
        print('global rank {} is loading checkpoint {}'.format(
            torch.distributed.get_rank(), checkpoint_name))

    state_dict = torch.load(checkpoint_name, map_location='cpu')
    ict_state_dict = state_dict['model']
    if from_realm_chkpt:
        print(">>>> Attempting to get ict state dict from realm", flush=True)
        ict_state_dict = ict_state_dict['retriever']['ict_model']

    if only_query_model:
        ict_state_dict.pop('context_model')
    if only_block_model:
        ict_state_dict.pop('question_model')

    if no_grad:
        with torch.no_grad():
            model.load_state_dict(ict_state_dict)
    else:
        model.load_state_dict(ict_state_dict)
    torch.distributed.barrier(get_data_parallel_group())

    if mpu.get_data_parallel_rank() == 0:
        print(' successfully loaded {}'.format(checkpoint_name))

    return model


def get_ict_dataset(use_titles=True):
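    """Build an ICTDataset spanning the full indexed block and title data; the
    sequence-length arguments below are irrelevant for indexing."""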
    args = get_args()
    block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
    titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True)

    kwargs = dict(
        name='full',
        block_dataset=block_dataset,
        title_dataset=titles_dataset,
        data_prefix=args.data_path,
        num_epochs=1,
        max_num_samples=None,
        max_seq_length=288,  # doesn't matter
        short_seq_prob=0.0001,  # doesn't matter
        seed=1,
        query_in_block_prob=1,
        use_titles=use_titles
    )
    dataset = ICTDataset(**kwargs)
    return dataset


def get_one_epoch_dataloader(dataset, batch_size=None):
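    """Return a dataloader that walks the dataset sequentially exactly once,
    sharded across data-parallel ranks with DistributedBatchSampler."""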
    args = get_args()

    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    if batch_size is None:
        batch_size = args.batch_size
    global_batch_size = batch_size * world_size
    num_workers = args.num_workers

    sampler = torch.utils.data.SequentialSampler(dataset)
    batch_sampler = DistributedBatchSampler(sampler,
                                            batch_size=global_batch_size,
                                            drop_last=True,
                                            rank=rank,
                                            world_size=world_size)

    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True)


if __name__ == "__main__":
    initialize_megatron(extra_args_provider=None,
                        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
    index_builder = IndexBuilder()
    index_builder.build_and_save_index()