training.py 22.9 KB
Newer Older
1
# coding=utf-8
Mohammad's avatar
Mohammad committed
2
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
3
4
5
6
7
8
9
10
11
12
13
14
15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Mohammad's avatar
Mohammad committed
16
"""Pretrain utilities."""
17
18
19

from datetime import datetime
import math
Mohammad's avatar
Mohammad committed
20
import sys
21
import time
22
23
24
25
26

import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from apex.optimizers import FusedAdam as Adam

Mohammad's avatar
Mohammad committed
27
28
29
from megatron import get_args
from megatron import get_timers
from megatron import get_tensorboard_writer
30
from megatron import mpu
Mohammad's avatar
Mohammad committed
31
32
33
from megatron import print_rank_0
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
34
35
from megatron.fp16 import FP16_Module
from megatron.fp16 import FP16_Optimizer
Mohammad's avatar
Mohammad committed
36
from megatron.initialize import initialize_megatron
37
38
39
from megatron.learning_rates import AnnealingLR
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import get_params_for_weight_decay_optimization
40
from megatron.mpu.initialize import get_index_ready, get_train_group, get_data_parallel_group, get_gloo_comm_group
Neel Kant's avatar
Neel Kant committed
41
from megatron.model.realm_model import ICTBertModel
42
from megatron.utils import check_adlr_autoresume_termination
43
from megatron.utils import make_data_loader
44
from megatron.utils import report_memory
45
46
47


# Module-level index-reload flag. pretrain() sets it from
# mpu.initialize.get_index_ready() when a custom initializer_func is given.
# train() passes it to torch.distributed.broadcast over the gloo comm group,
# and rank 0 flips it (1 - INDEX_READY) after each index reload.
# NOTE(review): presumably a tensor, since it is broadcast; confirm against
# get_index_ready().
INDEX_READY = None
48
49


50
def pretrain(train_valid_test_dataset_provider, model_provider,
             forward_step_func, extra_args_provider=None, args_defaults=None,
             initializer_func=None):
    """Main training program.

    This function will run the following in the order provided:
        1) initialize Megatron.
        2) setup model, optimizer and lr schedule using the model_provider.
        3) call train_valid_test_dataset_provider to get train/val/test
           datasets.
        4) train the model using the forward_step_func.

    Arguments:
        train_valid_test_dataset_provider: a function that takes the size of
            train/valid/test dataset and returns `train, valid, test` datasets.
        model_provider: a function that returns a vanilla version of the
            model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
        forward_step_func: a function that takes a `data iterator` and `model`,
            and returns a `loss` scalar with a dictionary with key:values being
            the info we would like to monitor during training, for example
            `lm-loss: value`. We also require that this function add
            `batch generator` to the timers class.
        extra_args_provider: a function that takes a parser and adds arguments
            to it. It is used for programs to add their own arguments.
        args_defaults: a dictionary from argument-name to argument-value. It
            is used to set already-parsed arguments. Defaults to an empty
            dict.
        initializer_func: optional replacement for `initialize_megatron`,
            called with the same keyword arguments. When given, the
            module-level `INDEX_READY` is set from `get_index_ready()`
            afterwards.
    """
    global INDEX_READY

    # Use a fresh dict per call instead of a shared mutable default.
    if args_defaults is None:
        args_defaults = {}

    # Initialize and get arguments, timers, and Tensorboard writer.
    if initializer_func is None:
        initialize_megatron(extra_args_provider=extra_args_provider,
                            args_defaults=args_defaults)
    else:
        initializer_func(extra_args_provider=extra_args_provider,
                         args_defaults=args_defaults)
        INDEX_READY = get_index_ready()

    args = get_args()
    timers = get_timers()

    # Rank 0 downloads the stanza English NER model (conll03) into the
    # relative directory 'stanza' when args.cased_data_path is set.
    # NOTE(review): other ranks do not wait for this download; confirm they
    # do not need it before training starts.
    if args.rank == 0 and args.cased_data_path is not None:
        import stanza
        stanza.download('en', processors={'ner': 'conll03'}, dir='stanza')

    # Model, optimizer, and learning rate.
    timers('model and optimizer').start()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
    timers('model and optimizer').stop()

    # Data stuff.
    timers('train/valid/test data iterators').start()
    train_data_iterator, valid_data_iterator, test_data_iterator \
        = build_train_valid_test_data_iterators(
            train_valid_test_dataset_provider)
    timers('train/valid/test data iterators').stop()

    # Print setup timing.
    print_rank_0('done with setups ...')
    timers.log(['model and optimizer', 'train/valid/test data iterators'])
    print_rank_0('training ...')

    iteration = 0
    if args.do_train and args.train_iters > 0:
        iteration, _ = train(forward_step_func,
                             model, optimizer, lr_scheduler,
                             train_data_iterator, valid_data_iterator)

    if args.do_valid:
        prefix = 'the end of training for val data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   valid_data_iterator, model,
                                   iteration, False)

    if args.save and iteration != 0:
        save_checkpoint(iteration, model, optimizer, lr_scheduler)

    if args.do_test:
        # Run on test data.
        prefix = 'the end of training for test data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   test_data_iterator, model,
                                   0, True)
133
134


Mohammad's avatar
Mohammad committed
135
def get_model(model_provider_func):
    """Build the model and wrap it for fp16 and data-parallel training.

    The model is created on the CPU by `model_provider_func`, moved to the
    current CUDA device, wrapped in `FP16_Module` when `args.fp16` is set,
    and finally wrapped in torch's or Megatron's local DDP according to
    `args.DDP_impl`.

    Raises:
        NotImplementedError: if `args.DDP_impl` is neither 'torch' nor
            'local'.
    """
    args = get_args()

    # Build model on cpu.
    net = model_provider_func()

    # Only the first data-parallel replica reports the parameter count.
    if mpu.get_data_parallel_rank() == 0:
        mp_rank = mpu.get_model_parallel_rank()
        n_params = sum([p.nelement() for p in net.parameters()])
        print(' > number of parameters on model parallel rank {}: {}'.format(
            mp_rank, n_params), flush=True)

    # GPU allocation.
    net.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16:
        net = FP16_Module(net)

    # Wrap model for distributed training.
    ddp_impl = args.DDP_impl
    if ddp_impl == 'local':
        return LocalDDP(net)
    if ddp_impl == 'torch':
        device = torch.cuda.current_device()
        return torchDDP(net, device_ids=[device], output_device=device,
                        process_group=mpu.get_data_parallel_group())

    raise NotImplementedError('Unknown DDP implementation specified: {}. '
                              'Exiting.'.format(args.DDP_impl))
167
168


Mohammad's avatar
Mohammad committed
169
def get_optimizer(model):
    """Set up the Adam optimizer, wrapped in FP16_Optimizer when fp16 is on.

    Wrapper modules (torch DDP, local DDP, FP16_Module) are peeled off first,
    so parameter groups are computed on the underlying network. Every
    parameter that lacks a `model_parallel` attribute gets it set to False.
    """
    args = get_args()

    # Peel off wrappers to reach the bare network.
    inner = model
    while isinstance(inner, (torchDDP, LocalDDP, FP16_Module)):
        inner = inner.module

    # Build parameter groups (weight decay and non-decay).
    param_groups = get_params_for_weight_decay_optimization(inner)

    # Add model parallel attribute if it is not set.
    for group in param_groups:
        for param in group['params']:
            if not hasattr(param, 'model_parallel'):
                param.model_parallel = False

    # Use Adam.
    optimizer = Adam(param_groups, lr=args.lr, weight_decay=args.weight_decay)

    if not args.fp16:
        return optimizer

    # Wrap into fp16 optimizer.
    loss_scaler_args = {
        'scale_window': args.loss_scale_window,
        'min_scale': args.min_scale,
        'delayed_shift': args.hysteresis,
    }
    return FP16_Optimizer(optimizer,
                          static_loss_scale=args.loss_scale,
                          dynamic_loss_scale=args.dynamic_loss_scale,
                          dynamic_loss_args=loss_scaler_args)


Mohammad's avatar
Mohammad committed
200
def get_learning_rate_scheduler(optimizer):
    """Build the AnnealingLR learning-rate scheduler for `optimizer`.

    The decay horizon is `args.lr_decay_iters` when set, otherwise
    `args.train_iters`, clamped to at least 1. The warmup length is
    `args.warmup * horizon`. The schedule starts at iteration 0.
    """
    args = get_args()

    decay_iters = args.lr_decay_iters
    if decay_iters is None:
        decay_iters = args.train_iters
    total_iters = max(1, decay_iters)

    return AnnealingLR(
        optimizer,
        start_lr=args.lr,
        warmup_iter=args.warmup * total_iters,
        total_iters=total_iters,
        decay_style=args.lr_decay_style,
        last_iter=0,
        min_lr=args.min_lr,
        use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
        override_lr_scheduler=args.override_lr_scheduler)


Mohammad's avatar
Mohammad committed
226
def setup_model_and_optimizer(model_provider_func):
    """Setup model, optimizer and learning-rate scheduler.

    Loads a checkpoint when `args.load` is set and records the returned
    iteration in `args.iteration` (0 otherwise). When starting from
    iteration 0 with an ICTBertModel, the model's
    `init_state_dict_from_bert()` is called.

    Returns:
        (model, optimizer, lr_scheduler)
    """
    args = get_args()

    model = get_model(model_provider_func)
    optimizer = get_optimizer(model)
    lr_scheduler = get_learning_rate_scheduler(optimizer)

    if args.load is not None:
        args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
    else:
        args.iteration = 0

    # Peel off DDP / FP16 wrappers with the same rule as get_optimizer().
    # The previous `model.module.module` assumed exactly two wrapper levels
    # and raised AttributeError when fp16 was disabled (model = DDP(raw)).
    unwrapped_model = model
    while isinstance(unwrapped_model, (torchDDP, LocalDDP, FP16_Module)):
        unwrapped_model = unwrapped_model.module

    if args.iteration == 0 and isinstance(unwrapped_model, ICTBertModel):
        print_rank_0('> initializing ICT model via init_state_dict_from_bert()')
        unwrapped_model.init_state_dict_from_bert()

    return model, optimizer, lr_scheduler


Mohammad's avatar
Mohammad committed
248
def backward_step(optimizer, model, loss):
    """Backward step: compute, all-reduce, and clip gradients.

    Steps, in order:
        1) zero the gradients and back-propagate `loss` (through
           `optimizer.backward` when fp16 is enabled);
        2) all-reduce gradients when `args.DDP_impl == 'local'`;
        3) update the master gradients (fp16 only);
        4) clip gradients to `args.clip_grad` when it is positive.
    The optimizer step itself is left to the caller.
    """
    args = get_args()
    timers = get_timers()
    # NOTE(review): unconditional device sync on every step; possibly a
    # debugging leftover that costs throughput. Confirm before removing.
    torch.cuda.synchronize()

    # Backward pass.
    if args.fp16:
        optimizer.zero_grad(set_grads_to_None=True)
        # Master grads are updated explicitly below, after the all-reduce.
        optimizer.backward(loss, update_master_grads=False)
    else:
        optimizer.zero_grad()
        loss.backward()

    # All-reduce if needed.
    if args.DDP_impl == 'local':
        timers('allreduce').start()
        model.allreduce_params(reduce_after=False,
                               fp32_allreduce=args.fp32_allreduce)
        timers('allreduce').stop()

    # Update master gradients.
    if args.fp16:
        optimizer.update_master_grads()

    # Clipping gradients helps prevent the exploding gradient.
    if args.clip_grad > 0:
        if not args.fp16:
            mpu.clip_grad_norm(model.parameters(), args.clip_grad)
        else:
            optimizer.clip_master_grads(args.clip_grad)


Mohammad's avatar
Mohammad committed
282
283
def train_step(forward_step_func, data_iterator,
               model, optimizer, lr_scheduler):
    """Single training step: forward, backward, then optimizer update.

    Gradient computation, all-reduce and clipping all happen inside
    backward_step(). Under fp16, an optimizer overflow is counted as a
    skipped iteration and the learning-rate scheduler is not stepped.

    Returns:
        (loss_reduced, skipped_iter): `loss_reduced` is the second value
        returned by `forward_step_func`. `skipped_iter` is 1 on an fp16
        overflow, else 0.
    """
    args = get_args()
    timers = get_timers()

    timers('forward').start()
    loss, loss_reduced = forward_step_func(data_iterator, model)
    timers('forward').stop()

    timers('backward').start()
    backward_step(optimizer, model, loss)
    timers('backward').stop()

    timers('optimizer').start()
    optimizer.step()
    timers('optimizer').stop()

    overflowed = args.fp16 and optimizer.overflow
    if overflowed:
        return loss_reduced, 1
    lr_scheduler.step()
    return loss_reduced, 0


Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
314
def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                 loss_scale, report_memory_flag):
    """Log training information such as losses, timing, ....

    Accumulates `loss_dict` into `total_loss_dict` and writes per-iteration
    scalars to TensorBoard on rank 0 when a writer exists. Every
    `args.log_interval` iterations it also prints the averaged losses and
    timers, then resets each accumulated loss to 0.0.

    Returns:
        The updated `report_memory_flag`, which is cleared after the first
        memory report.
    """
    args = get_args()
    timers = get_timers()
    writer = get_tensorboard_writer()

    # Update losses.
    for key in loss_dict:
        total_loss_dict[key] = total_loss_dict.get(key, 0.) + loss_dict[key]

    # Logging.
    timers_to_log = []

    def add_to_logging(name):
        if name in timers.timers:
            timers_to_log.append(name)
    add_to_logging('forward')
    add_to_logging('backward')
    add_to_logging('allreduce')
    add_to_logging('optimizer')
    add_to_logging('batch generator')

    # Tensorboard values.
    if writer and torch.distributed.get_rank() == 0:
        writer.add_scalar('learning_rate', learning_rate, iteration)
        for key in loss_dict:
            writer.add_scalar(key, loss_dict[key], iteration)
        if args.fp16:
            writer.add_scalar('loss_scale', loss_scale, iteration)
        normalizer = iteration % args.log_interval
        if normalizer == 0:
            normalizer = args.log_interval
        timers.write(timers_to_log, writer, iteration,
                     normalizer=normalizer)

    if iteration % args.log_interval == 0:
        elapsed_time = timers('interval time').elapsed()
        if writer and torch.distributed.get_rank() == 0:
            writer.add_scalar('iteration_time',
                              elapsed_time / args.log_interval, iteration)
        log_string = ' iteration {:8d}/{:8d} |'.format(iteration,
                                                       args.train_iters)
        log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
            elapsed_time * 1000.0 / args.log_interval)
        log_string += ' learning rate: {:.3E} |'.format(learning_rate)
        for key in total_loss_dict:
            # float() accepts both a one-element tensor and the plain 0.0
            # left by a previous reset, so a key that received no new loss
            # during this interval no longer crashes on `.item()`.
            avg = float(total_loss_dict[key]) / args.log_interval
            log_string += ' {}: {:.6E} |'.format(key, avg)
            total_loss_dict[key] = 0.0
        if args.fp16:
            log_string += ' loss scale: {:.1f} |'.format(loss_scale)
        print_rank_0(log_string)
        if report_memory_flag:
            report_memory('after {} iterations'.format(iteration))
            report_memory_flag = False
        timers.log(timers_to_log, normalizer=args.log_interval)

    return report_memory_flag


375
def train(forward_step_func, model, optimizer, lr_scheduler,
          train_data_iterator, valid_data_iterator):
    """Train the model function.

    Runs train_step() until `args.train_iters`. Along the way it handles
    logging, ADLR autoresume, periodic checkpointing, evaluation and early
    exit.

    When `args.max_training_rank` is set, training also coordinates index
    reloads over the gloo comm group. Every `args.index_reload_interval`
    iterations it:
        1) polls (sleeping 5 s between checks, without training) until an
           async broadcast from rank `args.max_training_rank` completes;
        2) saves a checkpoint;
        3) flips `INDEX_READY` on rank 0 and broadcasts it from rank 0;
        4) calls `retriever.reload_index()` on the unwrapped model;
        5) re-posts the async receive.

    Returns:
        (iteration, skipped_iters): the final iteration and the number of
        iterations skipped because of fp16 overflow.
    """
    global INDEX_READY
    args = get_args()
    timers = get_timers()

    # Turn on training mode which enables dropout.
    model.train()

    # Tracking loss.
    total_loss_dict = {}

    # Iterations.
    iteration = args.iteration
    skipped_iters = 0

    timers('interval time').start()
    report_memory_flag = True
    print('>>> Starting train()', flush=True)

    if args.max_training_rank is not None:
        # Synchronize for start, then post an asynchronous receive that
        # completes when rank `max_training_rank` broadcasts INDEX_READY.
        torch.distributed.broadcast(INDEX_READY, 0, group=get_gloo_comm_group())
        recv_handle = torch.distributed.broadcast(
            INDEX_READY, args.max_training_rank,
            group=get_gloo_comm_group(), async_op=True)
        last_reload_iteration = iteration

    while iteration < args.train_iters:
        if args.max_training_rank is not None and \
                iteration >= last_reload_iteration + args.index_reload_interval:
            if not recv_handle.is_completed():
                # Broadcast from max_training_rank not received yet: sleep
                # and poll again without running a training step.
                time.sleep(5)
                continue

            # NOTE(review): the received INDEX_READY value is not checked
            # here; completion alone triggers the reload.
            true_model = model
            if hasattr(true_model, 'module'):
                true_model = true_model.module
                if hasattr(true_model, 'module'):
                    true_model = true_model.module

            print("> Saving model and reloading index", flush=True)
            save_checkpoint(iteration, model, optimizer, lr_scheduler)
            if args.rank == 0:
                INDEX_READY = 1 - INDEX_READY
            # Send handle.
            torch.distributed.broadcast(INDEX_READY, 0, group=get_gloo_comm_group())
            true_model.retriever.reload_index()

            torch.cuda.synchronize()

            recv_handle = torch.distributed.broadcast(
                INDEX_READY, args.max_training_rank,
                group=get_gloo_comm_group(), async_op=True)
            last_reload_iteration = iteration

        loss_dict, skipped_iter = train_step(forward_step_func,
                                             train_data_iterator,
                                             model,
                                             optimizer,
                                             lr_scheduler)
        skipped_iters += skipped_iter
        iteration += 1

        # Logging.
        loss_scale = None
        if args.fp16:
            loss_scale = optimizer.loss_scale
        report_memory_flag = training_log(loss_dict, total_loss_dict,
                                          optimizer.param_groups[0]['lr'],
                                          iteration, loss_scale,
                                          report_memory_flag)

        # Autoresume
        if args.adlr_autoresume and \
           (iteration % args.adlr_autoresume_interval == 0):
            check_adlr_autoresume_termination(iteration, model, optimizer,
                                              lr_scheduler)

        # Checkpointing
        if args.save and args.save_interval and \
           iteration % args.save_interval == 0:
            save_checkpoint(iteration, model, optimizer, lr_scheduler)

        # Evaluation
        if args.eval_interval and iteration % args.eval_interval == 0 and \
           args.do_valid:
            prefix = 'iteration {}'.format(iteration)
            evaluate_and_print_results(prefix, forward_step_func,
                                       valid_data_iterator, model,
                                       iteration, False)

        if args.exit_interval and iteration % args.exit_interval == 0:
            torch.distributed.barrier(get_data_parallel_group())
            time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            rank = torch.distributed.get_rank()
            print_rank_0('rank: {} | time: {} | exiting the program at '
                         'iteration {}'.format(rank, time_str, iteration))
            sys.exit()

    return iteration, skipped_iters


Mohammad's avatar
Mohammad committed
481
def evaluate(forward_step_func, data_iterator, model, verbose=False):
    """Evaluation.

    Runs `args.eval_iters` forward passes under torch.no_grad() with the
    model in eval mode. Returns, per key, the sum of the losses produced by
    `forward_step_func` divided by `args.eval_iters`. The model is put back
    into train mode before returning.
    """
    args = get_args()
    num_iters = args.eval_iters

    # Turn on evaluation mode which disables dropout.
    model.eval()

    summed = {}
    with torch.no_grad():
        for step in range(1, num_iters + 1):
            if verbose and step % args.log_interval == 0:
                print_rank_0('Evaluating iter {}/{}'.format(step, num_iters))
            # Forward evaluation.
            _, loss_dict = forward_step_func(data_iterator, model)
            # Accumulate losses per key (local to this process).
            for key in loss_dict:
                summed[key] = summed.get(key, 0.) + loss_dict[key]

    # Move model back to the train mode.
    model.train()

    return {key: value / num_iters for key, value in summed.items()}


def evaluate_and_print_results(prefix, forward_step_func,
                               data_iterator, model,
                               iteration, verbose=False):
    """Helper function to evaluate and dump results on screen.

    Prints each averaged loss and its perplexity (exp of the loss, with the
    exponent capped at 20) between two dashed lines via print_rank_0. Both
    values are also written to TensorBoard on rank 0 when a writer exists.
    """
    writer = get_tensorboard_writer()

    total_loss_dict = evaluate(forward_step_func, data_iterator, model, verbose)

    pieces = [' validation loss at {} | '.format(prefix)]
    for key, loss in total_loss_dict.items():
        loss_value = loss.item()
        ppl = math.exp(min(20, loss_value))
        pieces.append('{} value: {:.6E} | '.format(key, loss_value))
        pieces.append('{} PPL: {:.6E} | '.format(key, ppl))
        if writer and torch.distributed.get_rank() == 0:
            writer.add_scalar('{} value'.format(key), loss_value, iteration)
            writer.add_scalar('{} ppl'.format(key), ppl, iteration)
    string = ''.join(pieces)

    separator = '-' * (len(string) + 1)
    print_rank_0(separator)
    print_rank_0(string)
    print_rank_0(separator)


536
537
538
def build_train_valid_test_data_iterators(
        build_train_valid_test_datasets_provider):
    """Build train/valid/test data iterators.

    Datasets and dataloaders are built only on model-parallel rank 0. The
    resulting do_train/do_valid/do_test flags are broadcast to the rest of
    the model-parallel group and stored on `args`. Sampler start iterations
    are shifted according to `args.iteration` so a resumed run continues
    from where it left off.

    Arguments:
        build_train_valid_test_datasets_provider: a function that takes a
            list `[train, valid, test]` of minimum sample counts and returns
            `train, valid, test` datasets.

    Returns:
        (train_data_iterator, valid_data_iterator, test_data_iterator). Each
        is None when the corresponding dataloader was not built.
    """
    args = get_args()

    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

    print_rank_0('> building train, validation, and test datasets ...')
    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_model_parallel_rank() == 0:
        # Global batch size across data-parallel replicas.
        data_parallel_size = mpu.get_data_parallel_world_size()
        global_batch_size = args.batch_size * data_parallel_size

        # Number of train/valid/test samples. Validation runs every
        # `eval_interval` iterations plus once at the end of training. A
        # falsy eval_interval (which train() accepts) means only the final
        # run; previously it raised in the division.
        train_iters = args.train_iters
        if args.eval_interval:
            num_evals = train_iters // args.eval_interval + 1
        else:
            num_evals = 1
        eval_iters = num_evals * args.eval_iters
        test_iters = args.eval_iters
        train_val_test_num_samples = [train_iters * global_batch_size,
                                      eval_iters * global_batch_size,
                                      test_iters * global_batch_size]
        print_rank_0(' > datasets target sizes (minimum size):')
        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))

        # Build the datasets.
        train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
            train_val_test_num_samples)

        # Build dataloaders.
        train_dataloader = make_data_loader(train_ds)
        valid_dataloader = make_data_loader(valid_ds)
        test_dataloader = make_data_loader(test_ds)

        # Flags to know if we need to do training/validation/testing.
        do_train = train_dataloader is not None and args.train_iters > 0
        do_valid = valid_dataloader is not None and args.eval_iters > 0
        do_test = test_dataloader is not None and args.eval_iters > 0
        # Pack the flags so they can be broadcast to the other ranks.
        flags = torch.cuda.LongTensor(
            [int(do_train), int(do_valid), int(do_test)])
    else:
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Broadcast the do_train/do_valid/do_test flags from model-parallel
    # rank 0 to the rest of the model-parallel group.
    torch.distributed.broadcast(flags,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    args.do_train = flags[0].item()
    args.do_valid = flags[1].item()
    args.do_test = flags[2].item()

    # Shift the start iterations.
    if train_dataloader is not None:
        train_dataloader.batch_sampler.start_iter = args.iteration % \
            len(train_dataloader)
        print_rank_0('setting training data start iteration to {}'.
                     format(train_dataloader.batch_sampler.start_iter))
    if valid_dataloader is not None:
        if args.eval_interval:
            start_iter_val = (args.iteration // args.eval_interval) * \
                args.eval_iters
        else:
            start_iter_val = 0
        valid_dataloader.batch_sampler.start_iter = start_iter_val % \
            len(valid_dataloader)
        print_rank_0('setting validation data start iteration to {}'.
                     format(valid_dataloader.batch_sampler.start_iter))

    # Build iterators.
    train_data_iterator = iter(train_dataloader) \
        if train_dataloader is not None else None
    valid_data_iterator = iter(valid_dataloader) \
        if valid_dataloader is not None else None
    test_data_iterator = iter(test_dataloader) \
        if test_dataloader is not None else None

    return train_data_iterator, valid_data_iterator, test_data_iterator