import os
import sys
import time

import torch
import torch.distributed as dist
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

from megatron import get_args, get_adlr_autoresume, print_rank_0
from megatron import mpu
from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
from megatron.data.bert_dataset import get_indexed_dataset_
from megatron.data.realm_dataset import ICTDataset
from megatron.data.realm_index import detach, BlockData, FaissMIPSIndex
from megatron.data.samplers import DistributedBatchSampler
from megatron.global_vars import set_global_variables
from megatron.initialize import initialize_megatron, init_distributed, _init_autoresume, _set_random_seed, _write_args_to_tensorboard
from megatron.model import REALMRetriever
from megatron.mpu.initialize import get_index_ready, get_index_group, get_train_group, get_data_parallel_group, get_gloo_comm_group
from megatron.mpu.initialize import set_data_parallel_group, set_model_parallel_group, init_realm_groups
from megatron.training import get_model
from megatron.utils import check_adlr_autoresume_termination

from pretrain_bert_ict import get_batch, model_provider


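# Flag toggled by the main index builder and exchanged with the trainers over
# a gloo group whenever a freshly built index is ready to be loaded.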
INDEX_READY = None


def pprint(*args):
    print(*args, flush=True)


def initialize_and_run_async_megatron(extra_args_provider=None, args_defaults={},
                                      ignore_unknown_args=False, allow_no_cuda=False):
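    """Initialize Megatron for asynchronous REALM-style training.

    Ranks below args.max_training_rank return from this function and proceed
    to training; the remaining ranks become asynchronous index builders and
    never return from runner.run_async().
    """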
    if not allow_no_cuda:
        # Make sure cuda is available.
        assert torch.cuda.is_available(), 'Megatron requires CUDA.'

    # Parse args, build tokenizer, and set adlr-autoresume,
    # tensorboard-writer, and timers.
    set_global_variables(extra_args_provider=extra_args_provider,
                         args_defaults=args_defaults,
                         ignore_unknown_args=ignore_unknown_args)

    # Initialize torch.distributed directly, instead of the full _initialize_distributed().
    init_distributed()
    setup_realm_groups_and_vars()
    global INDEX_READY
    INDEX_READY = get_index_ready()
    pprint('finished setting up groups')

    # Autoresume
    _init_autoresume()
    pprint('finished setting up autoresume')

    # Random seeds for reproducibility.
    args = get_args()
    if args.rank == 0:
        pprint('> setting random seeds to {} ...'.format(args.seed))
    _set_random_seed(args.seed)

    # Write arguments to tensorboard.
    _write_args_to_tensorboard()
    pprint('finished writing args to tensorboard')

    torch.distributed.barrier()

    if args.rank < args.max_training_rank:
        torch.distributed.barrier(get_data_parallel_group())
        pprint("All trainers ready.")
        return
    else:
        runner = AsyncIndexBuilder(args.rank)
        torch.distributed.barrier(get_data_parallel_group())
        pprint("All indexers ready.")
        runner.run_async()


def setup_realm_groups_and_vars():
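    """Create the process groups needed for asynchronous REALM training.

    Every rank gets a trivial model-parallel group containing only itself;
    ranks below max_training_rank share the training data-parallel group,
    and the remaining ranks share the index-building data-parallel group.
    """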
    args = get_args()
    world_size = dist.get_world_size()
    max_training_rank = args.max_training_rank

    # assuming no model parallelism right now
    set_model_parallel_group(dist.new_group([args.rank]))
    init_realm_groups(max_training_rank, world_size)

    if args.rank < max_training_rank:
        set_data_parallel_group(get_train_group())
    else:
        set_data_parallel_group(get_index_group())


class IndexBuilder(object):
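    """Build a BlockData of block embeddings with a single pass over a dataset.

    Loads the block encoder from an ICT checkpoint, embeds every block, saves
    a per-rank shard, and has the main builder consolidate all shards.
    """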
    def __init__(self):
        args = get_args()
        self.debug = args.debug
        self.rank = args.rank
        self.model = None
        self.dataloader = None
        self.block_data = None
        self.load_attributes()
        self.is_main_builder = args.rank == 0

    def load_attributes(self):
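        """Load the block encoder, a fresh dataloader, and an empty BlockData."""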
        self.model = load_ict_checkpoint(only_block_model=True, no_grad=True, from_realm_chkpt=False)
        self.model.eval()
        self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset()))
        self.block_data = BlockData()

    def build_and_save_index(self):
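        """Embed every block in the dataset and save the result.

        Each rank writes its own shard; after a barrier, the main builder
        consolidates all shards into a single BlockData file.
        """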
        i = 1
        total = 0
        while True:
            with torch.no_grad():
                try:
                    query_tokens, query_pad_mask, \
                    block_tokens, block_pad_mask, block_index_data = get_batch(self.dataloader)
                except StopIteration:
                    # The dataloader is exhausted; one epoch is complete.
                    break

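                # Each row of block_index_data is assumed to hold (start_idx, end_idx, doc_idx, block_idx).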
                block_index_data = detach(block_index_data)
                block_indices = block_index_data[:, 3]
                block_meta = block_index_data[:, :3]

                block_logits = detach(self.model(None, None, block_tokens, block_pad_mask, only_block=True))
                self.block_data.add_block_data(block_indices, block_logits, block_meta)

                total += block_indices.size
                i += 1
                if i % 1000 == 0:
                    print('Batch {:10d} | Total {:10d}'.format(i, total), flush=True)
                    if self.debug:
                        break

        self.block_data.save_shard(self.rank)
        torch.distributed.barrier(get_data_parallel_group())
        del self.model

        if self.is_main_builder:
            self.block_data.consolidate_shards_and_save(ignore_shard=self.rank)
        self.block_data.clear()


class AsyncIndexBuilder(IndexBuilder):
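    """Index builder that runs continuously alongside training.

    Repeatedly rebuilds the index from the newest available checkpoint and
    toggles the INDEX_READY flag over the gloo group so the trainers know
    when a fresh index can be loaded.
    """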
    def __init__(self, rank):
        self.rank = rank
        args = get_args()
        self.is_main_builder = self.rank == args.max_training_rank
        self.main_builder_idx = args.max_training_rank
        self.debug = args.debug

        self.model = None
        self.dataloader = None
        self.block_data = None
        self.load_attributes()

        global INDEX_READY
        INDEX_READY = get_index_ready()

    def run_async(self):
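        """Loop forever: build an index, signal readiness, reload the newest checkpoint."""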
        global INDEX_READY
        # synchronize for start
        dist.broadcast(INDEX_READY, 0, group=get_gloo_comm_group())
        while True:
            print("Starting (again!)", flush=True)
            self.build_and_save_index()
            self.send_index_ready_signal()
            self.load_attributes()

    def load_attributes(self):
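        """Load the block encoder, preferring a REALM checkpoint over the initial ICT one."""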
        try:
            self.model = load_ict_checkpoint(only_block_model=True, no_grad=True, from_realm_chkpt=True)
        except Exception:
            # No REALM checkpoint has been written yet; fall back to the ICT weights.
            print(">>>>> No realm chkpt available", flush=True)
            self.model = load_ict_checkpoint(only_block_model=True, no_grad=True, from_realm_chkpt=False)
        self.model.eval()
        self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset()))
        self.block_data = BlockData()

    def send_index_ready_signal(self):
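        """Toggle INDEX_READY on the main builder and exchange the flag with the trainers."""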
        global INDEX_READY
        if self.is_main_builder:
            INDEX_READY = 1 - INDEX_READY
            print("Switched INDEX_READY", flush=True)
        torch.cuda.synchronize()

        # Broadcast the toggled flag from the main builder (non-blocking).
        dist.broadcast(INDEX_READY, self.main_builder_idx, group=get_gloo_comm_group(), async_op=True)

        # Receive the trainers' copy of the flag, broadcast from rank 0 (blocking).
        dist.broadcast(INDEX_READY, 0, group=get_gloo_comm_group())


def load_ict_checkpoint(only_query_model=False, only_block_model=False, no_grad=False, from_realm_chkpt=False):
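    """Load an ICT model, optionally keeping only its query or block half.

    With from_realm_chkpt=True the weights come from a REALM checkpoint
    (args.load); otherwise they come from the original ICT checkpoint
    (args.ict_load).
    """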
    args = get_args()
    model = get_model(lambda: model_provider(only_query_model, only_block_model))

    if isinstance(model, torchDDP):
        model = model.module

    load_path = args.load if from_realm_chkpt else args.ict_load

    tracker_filename = get_checkpoint_tracker_filename(load_path)
    with open(tracker_filename, 'r') as f:
        iteration = int(f.read().strip())

    # assert iteration > 0
    checkpoint_name = get_checkpoint_name(load_path, iteration, False)
    if mpu.get_data_parallel_rank() == 0:
        print('global rank {} is loading checkpoint {}'.format(
            torch.distributed.get_rank(), checkpoint_name))

    state_dict = torch.load(checkpoint_name, map_location='cpu')
    ict_state_dict = state_dict['model']
    if from_realm_chkpt:
        print(">>>> Attempting to get ict state dict from realm", flush=True)
        ict_state_dict = ict_state_dict['retriever']['ict_model']

    if only_query_model:
        ict_state_dict.pop('context_model')
    if only_block_model:
        ict_state_dict.pop('question_model')

    if no_grad:
        with torch.no_grad():
            model.load_state_dict(ict_state_dict)
    else:
        model.load_state_dict(ict_state_dict)
    torch.distributed.barrier(get_data_parallel_group())

    if mpu.get_data_parallel_rank() == 0:
        print(' successfully loaded {}'.format(checkpoint_name))

    return model


def get_ict_dataset(use_titles=True):
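    """Construct an ICTDataset over the full corpus, configured for indexing rather than training."""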
    args = get_args()
    block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
    titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True)

    kwargs = dict(
        name='full',
        block_dataset=block_dataset,
        title_dataset=titles_dataset,
        data_prefix=args.data_path,
        num_epochs=1,
        max_num_samples=None,
        max_seq_length=args.seq_length,
        short_seq_prob=0.0001,  # value is irrelevant for index building
        seed=1,
        query_in_block_prob=1,
        use_titles=use_titles
    )
    dataset = ICTDataset(**kwargs)
    return dataset


def get_one_epoch_dataloader(dataset, batch_size=None):
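    """Build a dataloader that makes one sequential pass over the dataset, sharded across data-parallel ranks."""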
    args = get_args()

    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    if batch_size is None:
        batch_size = args.batch_size
    global_batch_size = batch_size * world_size
    num_workers = args.num_workers

    sampler = torch.utils.data.SequentialSampler(dataset)
    batch_sampler = DistributedBatchSampler(sampler,
                                            batch_size=global_batch_size,
                                            drop_last=True,
                                            rank=rank,
                                            world_size=world_size)

    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True)


if __name__ == "__main__":
    initialize_megatron(extra_args_provider=None,
                        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
    index_builder = IndexBuilder()
    index_builder.build_and_save_index()