training.py 22 KB
Newer Older
1
# coding=utf-8
Mohammad's avatar
Mohammad committed
2
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
3
4
5
6
7
8
9
10
11
12
13
14
15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Mohammad's avatar
Mohammad committed
16
"""Pretrain utilities."""
17
18
19

from datetime import datetime
import math
Mohammad's avatar
Mohammad committed
20
import sys
21
22
23
24
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from apex.optimizers import FusedAdam as Adam

Neel Kant's avatar
Neel Kant committed
25
from megatron import get_args
Mohammad's avatar
Mohammad committed
26
27
from megatron import get_timers
from megatron import get_tensorboard_writer
28
from megatron import mpu
Neel Kant's avatar
Neel Kant committed
29
from megatron import print_rank_0
Mohammad's avatar
Mohammad committed
30
31
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
32
33
from megatron.fp16 import FP16_Module
from megatron.fp16 import FP16_Optimizer
Mohammad's avatar
Mohammad committed
34
from megatron.initialize import initialize_megatron
35
36
37
from megatron.learning_rates import AnnealingLR
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import get_params_for_weight_decay_optimization
Neel Kant's avatar
Neel Kant committed
38
from megatron.model.realm_model import ICTBertModel
39
from megatron.utils import check_adlr_autoresume_termination
40
from megatron.utils import make_data_loader
41
from megatron.utils import report_memory
42
43


44
def pretrain(train_valid_test_dataset_provider, model_provider,
             forward_step_func, extra_args_provider=None, args_defaults=None):
    """Main training program.

    This function runs the following steps in order:
        1) initialize Megatron.
        2) set up the model, optimizer, and learning rate schedule using
           `model_provider`.
        3) call `train_valid_test_dataset_provider` to get the
           train/valid/test datasets.
        4) train the model using `forward_step_func`.

    Arguments:
        train_valid_test_dataset_provider: a function that takes the size of
            train/valid/test dataset and returns `train, valid, test` datasets.
        model_provider: a function that returns a vanilla version of the
            model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
        forward_step_func: a function that takes a `data iterator` and `model`,
            and returns a `loss` scalar with a dictionary with key:values being
            the info we would like to monitor during training, for example
            `lm-loss: value`. We also require that this function add
            `batch generator` to the timers class.
        extra_args_provider: a function that takes a parser and adds arguments
            to it. It is used for programs to add their own arguments.
        args_defaults: a dictionary from argument-name to argument-value. It
            is used to set already parsed arguments. When omitted (None), a
            fresh empty dictionary is used for this call.
    """
    # Avoid sharing one mutable default dictionary across calls.
    if args_defaults is None:
        args_defaults = {}

    # Initialize Megatron, then get the arguments and timers.
    initialize_megatron(extra_args_provider=extra_args_provider,
                        args_defaults=args_defaults)

    args = get_args()
    timers = get_timers()

    # Model, optimizer, and learning rate.
    timers('model and optimizer').start()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
    timers('model and optimizer').stop()

    # Data stuff.
    timers('train/valid/test data iterators').start()
    train_data_iterator, valid_data_iterator, test_data_iterator \
        = build_train_valid_test_data_iterators(
            train_valid_test_dataset_provider)
    timers('train/valid/test data iterators').stop()

    # Print setup timing.
    print_rank_0('done with setups ...')
    timers.log(['model and optimizer', 'train/valid/test data iterators'])
    print_rank_0('training ...')

    iteration = 0
    if args.do_train and args.train_iters > 0:
        iteration = train(forward_step_func,
                          model, optimizer, lr_scheduler,
                          train_data_iterator, valid_data_iterator)

    if args.do_valid:
        prefix = 'the end of training for val data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   valid_data_iterator, model,
                                   iteration, False)

    # Skip saving when the final iteration count is still 0.
    if args.save and iteration != 0:
        save_checkpoint(iteration, model, optimizer, lr_scheduler)

    if args.do_test:
        # Run on test data; results are reported at iteration 0.
        prefix = 'the end of training for test data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   test_data_iterator, model,
                                   0, True)
115
116


Mohammad's avatar
Mohammad committed
117
def get_model(model_provider_func):
    """Build the model and wrap it for fp16 and data-parallel training.

    The model returned by `model_provider_func` is moved to the current CUDA
    device. It is then wrapped in `FP16_Module` when `args.fp16` is set, and
    finally in the DDP implementation named by `args.DDP_impl` ('torch' or
    'local').

    Raises:
        NotImplementedError: if `args.DDP_impl` is neither 'torch' nor
            'local'.
    """
    args = get_args()

    # Construct the bare model on the CPU.
    model = model_provider_func()

    # Only data-parallel rank 0 reports, so each model-parallel rank
    # prints its parameter count once.
    if mpu.get_data_parallel_rank() == 0:
        mp_rank = mpu.get_model_parallel_rank()
        num_params = sum([p.nelement() for p in model.parameters()])
        print(' > number of parameters on model parallel rank {}: {}'.format(
            mp_rank, num_params), flush=True)

    # Move the parameters onto this process's current GPU.
    device = torch.cuda.current_device()
    model.cuda(device)

    if args.fp16:
        model = FP16_Module(model)

    # Wrap for distributed data-parallel training.
    ddp_impl = args.DDP_impl
    if ddp_impl == 'torch':
        return torchDDP(model, device_ids=[device], output_device=device,
                        process_group=mpu.get_data_parallel_group())
    if ddp_impl == 'local':
        return LocalDDP(model)
    raise NotImplementedError('Unknown DDP implementation specified: {}. '
                              'Exiting.'.format(ddp_impl))
149
150


Mohammad's avatar
Mohammad committed
151
def get_optimizer(model):
    """Build the Adam optimizer for `model`, wrapped for fp16 when enabled.

    Torch/local DDP and `FP16_Module` wrappers are peeled off first, so the
    parameter groups are built from the underlying model. Any parameter
    without a `model_parallel` attribute gets it set to False.
    """
    args = get_args()

    # Peel off distributed / fp16 wrappers to reach the underlying model.
    unwrapped = model
    while isinstance(unwrapped, (torchDDP, LocalDDP, FP16_Module)):
        unwrapped = unwrapped.module

    # Parameter groups: weight decay and non-decay.
    param_groups = get_params_for_weight_decay_optimization(unwrapped)

    # Default the model-parallel flag on every parameter that lacks it.
    all_params = (param for group in param_groups
                  for param in group['params'])
    for param in all_params:
        if not hasattr(param, 'model_parallel'):
            param.model_parallel = False

    optimizer = Adam(param_groups,
                     lr=args.lr,
                     weight_decay=args.weight_decay,
                     betas=(args.adam_beta1, args.adam_beta2),
                     eps=args.adam_eps)

    if not args.fp16:
        return optimizer

    # fp16 training: wrap in the loss-scaling optimizer.
    dynamic_loss_args = {
        'scale_window': args.loss_scale_window,
        'min_scale': args.min_scale,
        'delayed_shift': args.hysteresis,
    }
    return FP16_Optimizer(optimizer,
                          static_loss_scale=args.loss_scale,
                          dynamic_loss_scale=args.dynamic_loss_scale,
                          dynamic_loss_args=dynamic_loss_args)


Mohammad's avatar
Mohammad committed
183
def get_learning_rate_scheduler(optimizer):
    """Build the `AnnealingLR` scheduler for `optimizer`.

    The decay horizon is `args.lr_decay_iters` when set, otherwise
    `args.train_iters`, and is clamped to at least 1. The warmup length
    passed to the scheduler is `args.warmup * horizon`, and the schedule
    starts at step 0.
    """
    args = get_args()

    decay_iters = max(1, args.lr_decay_iters
                      if args.lr_decay_iters is not None
                      else args.train_iters)

    return AnnealingLR(
        optimizer,
        start_lr=args.lr,
        warmup_iter=args.warmup * decay_iters,
        total_iters=decay_iters,
        decay_style=args.lr_decay_style,
        last_iter=0,
        min_lr=args.min_lr,
        use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
        override_lr_scheduler=args.override_lr_scheduler)


Mohammad's avatar
Mohammad committed
209
def setup_model_and_optimizer(model_provider_func):
    """Set up the model, optimizer, and learning rate scheduler.

    Builds the wrapped model, its optimizer, and its scheduler. When
    `args.load` is set, all three are restored from the checkpoint and the
    returned iteration is stored in `args.iteration`; otherwise
    `args.iteration` is 0. When `args.iteration` is 0 and the underlying
    (unwrapped) model has an `init_state_dict_from_bert` method, that method
    is called.

    Returns:
        (model, optimizer, lr_scheduler)
    """
    args = get_args()

    model = get_model(model_provider_func)
    optimizer = get_optimizer(model)
    lr_scheduler = get_learning_rate_scheduler(optimizer)

    if args.load is not None:
        args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
    else:
        args.iteration = 0

    # Get the model without FP16 and/or DDP wrappers.
    unwrapped_model = model
    while hasattr(unwrapped_model, 'module'):
        unwrapped_model = unwrapped_model.module

    if args.iteration == 0 and hasattr(unwrapped_model,
                                       'init_state_dict_from_bert'):
        # Use print_rank_0 like the rest of this file so the message is not
        # duplicated by every rank.
        print_rank_0("Initializing ICT from pretrained BERT model")
        unwrapped_model.init_state_dict_from_bert()

    return model, optimizer, lr_scheduler


Mohammad's avatar
Mohammad committed
234
def backward_step(optimizer, model, loss):
    """Backward step: compute, all-reduce, unscale, and clip gradients.

    Each phase is timed under its own 'backward-*' timer:
      * backward-backward: zero the gradients and back-propagate `loss`.
      * backward-allreduce: all-reduce gradients (local DDP only).
      * backward-master-grad: update fp32 master gradients (fp16 only).
      * backward-clip-grad: clip gradients when `args.clip_grad > 0`.
    """
    args = get_args()
    timers = get_timers()

    # Backward pass.
    timers('backward-backward').start()
    if args.fp16:
        # FP16_Optimizer.zero_grad takes `set_grads_to_None`; its backward
        # applies loss scaling.
        optimizer.zero_grad(set_grads_to_None=True)
        optimizer.backward(loss, update_master_grads=False)
    else:
        # Without fp16 the optimizer is the plain apex FusedAdam from
        # get_optimizer. Its zero_grad() has no `set_grads_to_None` keyword,
        # and passing it raised TypeError here.
        optimizer.zero_grad()
        loss.backward()
    timers('backward-backward').stop()

    # All-reduce if needed.
    if args.DDP_impl == 'local':
        timers('backward-allreduce').start()
        model.allreduce_params(reduce_after=False,
                               fp32_allreduce=args.fp32_allreduce)
        timers('backward-allreduce').stop()

    # Update master gradients.
    timers('backward-master-grad').start()
    if args.fp16:
        optimizer.update_master_grads()
    timers('backward-master-grad').stop()

    # Clipping gradients helps prevent the exploding gradient.
    timers('backward-clip-grad').start()
    if args.clip_grad > 0:
        if args.fp16:
            # Clip the fp32 master copies held by the fp16 optimizer.
            optimizer.clip_master_grads(args.clip_grad)
        else:
            mpu.clip_grad_norm(model.parameters(), args.clip_grad)
    timers('backward-clip-grad').stop()
269
270


Mohammad's avatar
Mohammad committed
271
272
def train_step(forward_step_func, data_iterator,
               model, optimizer, lr_scheduler):
    """Run one forward / backward / optimizer step.

    Returns:
        (loss_reduced, skipped_iter): the second value returned by
        `forward_step_func`, and 1 if the update was skipped because the fp16
        optimizer reported an overflow, otherwise 0. The learning rate
        scheduler is stepped only when the update was not skipped.
    """
    args = get_args()
    timers = get_timers()

    # Forward pass.
    timers('forward').start()
    loss, loss_reduced = forward_step_func(data_iterator, model)
    timers('forward').stop()

    # Gradients: backward, reduction across processes, and clipping.
    timers('backward').start()
    backward_step(optimizer, model, loss)
    timers('backward').stop()

    # Parameter update.
    timers('optimizer').start()
    optimizer.step()
    timers('optimizer').stop()

    # On an fp16 overflow, leave the learning rate schedule where it is.
    if args.fp16 and optimizer.overflow:
        return loss_reduced, 1
    lr_scheduler.step()
    return loss_reduced, 0


Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
302
def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                 loss_scale, report_memory_flag, skipped_iter):
    """Log training information such as losses and timing.

    Accumulates `loss_dict` into `total_loss_dict` together with the
    'skipped iterations' and 'got nan' counters. Per-iteration values are
    written to Tensorboard on rank 0 when a writer exists. Every
    `args.log_interval` iterations, averaged losses and timers are printed
    and the accumulators are reset.

    Returns:
        The updated `report_memory_flag`, which is cleared after memory has
        been reported once.
    """
    args = get_args()
    timers = get_timers()
    writer = get_tensorboard_writer()

    # Update losses.
    skipped_iters_key = 'skipped iterations'
    total_loss_dict[skipped_iters_key] = total_loss_dict.get(
        skipped_iters_key, 0) + skipped_iter
    got_nan_key = 'got nan'

    got_nan = False
    for key in loss_dict:
        if not skipped_iter:
            if key in total_loss_dict:
                total_loss_dict[key] = total_loss_dict[key] + loss_dict[key]
            else:
                # Allocate the zero accumulator only for a new key. Using it
                # as a dict.get default would create a CUDA tensor on every
                # call.
                total_loss_dict[key] = torch.cuda.FloatTensor([0.0]) + \
                    loss_dict[key]
        else:
            # Skipped iterations are not accumulated; only record whether
            # the loss was inf or nan.
            value = loss_dict[key].float().sum().item()
            is_nan = math.isinf(value) or math.isnan(value)
            got_nan = got_nan or is_nan

    total_loss_dict[got_nan_key] = total_loss_dict.get(
        got_nan_key, 0) + int(got_nan)

    # Logging: keep only the timers that have actually been created.
    timer_names = ('forward', 'backward', 'backward-backward',
                   'backward-allreduce', 'backward-master-grad',
                   'backward-clip-grad', 'optimizer', 'batch generator')
    timers_to_log = [name for name in timer_names if name in timers.timers]

    # Tensorboard values.
    if writer and torch.distributed.get_rank() == 0:
        writer.add_scalar('learning_rate', learning_rate, iteration)
        for key in loss_dict:
            writer.add_scalar(key, loss_dict[key], iteration)
        if args.fp16:
            writer.add_scalar('loss_scale', loss_scale, iteration)
        # iteration % log_interval, or a full interval at multiples of it.
        normalizer = iteration % args.log_interval
        if normalizer == 0:
            normalizer = args.log_interval
        timers.write(timers_to_log, writer, iteration,
                     normalizer=normalizer)

    if iteration % args.log_interval == 0:
        elapsed_time = timers('interval time').elapsed()
        if writer and torch.distributed.get_rank() == 0:
            writer.add_scalar('iteration_time',
                              elapsed_time / args.log_interval, iteration)
        log_string = ' iteration {:8d}/{:8d} |'.format(iteration,
                                                       args.train_iters)
        log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
            elapsed_time * 1000.0 / args.log_interval)
        log_string += ' learning rate: {:.3E} |'.format(learning_rate)
        # Average only over the iterations whose losses were accumulated.
        num_iterations = max(
            1, args.log_interval - total_loss_dict[skipped_iters_key])
        for key in total_loss_dict:
            if key not in [skipped_iters_key, got_nan_key]:
                avg = total_loss_dict[key].item() / float(num_iterations)
                if avg > 0.0:
                    log_string += ' {}: {:.6E} |'.format(key, avg)
                total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
        if args.fp16:
            log_string += ' loss scale: {:.1f} |'.format(loss_scale)
        log_string += ' number of skipped iterations: {:3d} |'.format(
            total_loss_dict[skipped_iters_key])
        log_string += ' number of nan iterations: {:3d} |'.format(
            total_loss_dict[got_nan_key])
        total_loss_dict[skipped_iters_key] = 0
        total_loss_dict[got_nan_key] = 0
        print_rank_0(log_string)
        if report_memory_flag:
            report_memory('after {} iterations'.format(iteration))
            report_memory_flag = False
        timers.log(timers_to_log, normalizer=args.log_interval)

    return report_memory_flag


393
def train(forward_step_func, model, optimizer, lr_scheduler,
          train_data_iterator, valid_data_iterator):
    """Run the training loop from `args.iteration` up to `args.train_iters`.

    Each iteration runs `train_step` and logs through `training_log`. At
    their configured intervals it then checks ADLR autoresume, saves a
    checkpoint, evaluates on the validation data, and exits the program.

    Returns:
        The iteration count reached when the loop ends.
    """
    args = get_args()
    timers = get_timers()

    # Training mode enables dropout.
    model.train()

    # Running loss sums and counters, maintained by training_log.
    running_losses = {}

    step = args.iteration

    timers('interval time').start()
    memory_flag = True
    while step < args.train_iters:
        step_losses, skipped = train_step(forward_step_func,
                                          train_data_iterator,
                                          model,
                                          optimizer,
                                          lr_scheduler)
        step += 1

        # Logging.
        scale = optimizer.loss_scale if args.fp16 else None
        current_lr = optimizer.param_groups[0]['lr']
        memory_flag = training_log(step_losses, running_losses, current_lr,
                                   step, scale, memory_flag, skipped)

        # Autoresume.
        if args.adlr_autoresume and \
                step % args.adlr_autoresume_interval == 0:
            check_adlr_autoresume_termination(step, model, optimizer,
                                              lr_scheduler)

        # Checkpointing.
        if args.save and args.save_interval and \
                step % args.save_interval == 0:
            save_checkpoint(step, model, optimizer, lr_scheduler)

        # Periodic evaluation on validation data.
        if args.eval_interval and step % args.eval_interval == 0 and \
                args.do_valid:
            evaluate_and_print_results('iteration {}'.format(step),
                                       forward_step_func,
                                       valid_data_iterator, model,
                                       step, False)

        # Forced exit after synchronizing all ranks.
        if args.exit_interval and step % args.exit_interval == 0:
            torch.distributed.barrier()
            now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            my_rank = torch.distributed.get_rank()
            print_rank_0('rank: {} | time: {} | exiting the program at '
                         'iteration {}'.format(my_rank, now, step))
            sys.exit()

    return step
455
456


Mohammad's avatar
Mohammad committed
457
def evaluate(forward_step_func, data_iterator, model, verbose=False):
    """Average the losses reported by `forward_step_func` over evaluation.

    Runs `args.eval_iters` forward passes under `torch.no_grad()` with the
    model in eval mode, then switches the model back to train mode.

    Returns:
        A dict mapping each loss key to its sum divided by `args.eval_iters`.
    """
    args = get_args()

    # Evaluation mode disables dropout.
    model.eval()

    sums = {}
    with torch.no_grad():
        for step in range(1, args.eval_iters + 1):
            if verbose and step % args.log_interval == 0:
                print_rank_0('Evaluating iter {}/{}'.format(step,
                                                            args.eval_iters))
            # Forward evaluation; only the loss dictionary is used.
            _, losses = forward_step_func(data_iterator, model)
            # Accumulate each reported loss locally.
            for key, value in losses.items():
                sums[key] = sums.get(key, 0.) + value

    # Restore training mode.
    model.train()

    for key in sums:
        sums[key] /= args.eval_iters

    return sums


def evaluate_and_print_results(prefix, forward_step_func,
                               data_iterator, model,
                               iteration, verbose=False):
    """Evaluate the model and print each loss with its perplexity.

    Perplexity is `exp(min(20, loss))`. On rank 0, when a Tensorboard writer
    exists, both values are also written at `iteration`.
    """
    writer = get_tensorboard_writer()

    total_loss_dict = evaluate(forward_step_func, data_iterator, model, verbose)

    pieces = [' validation loss at {} | '.format(prefix)]
    for key, loss in total_loss_dict.items():
        value = loss.item()
        # Cap the exponent so a large loss does not overflow math.exp.
        ppl = math.exp(min(20, value))
        pieces.append('{} value: {:.6E} | '.format(key, value))
        pieces.append('{} PPL: {:.6E} | '.format(key, ppl))
        if writer and torch.distributed.get_rank() == 0:
            writer.add_scalar('{} value'.format(key), value, iteration)
            writer.add_scalar('{} ppl'.format(key), ppl, iteration)
    string = ''.join(pieces)

    separator = '-' * (len(string) + 1)
    print_rank_0(separator)
    print_rank_0(string)
    print_rank_0(separator)


512
513
514
def build_train_valid_test_data_iterators(
        build_train_valid_test_datasets_provider):
    """Build the train, validation, and test data iterators.

    On model-parallel rank 0 this does three things:
      * computes the minimum number of samples each split must provide;
      * calls `build_train_valid_test_datasets_provider` with those sizes;
      * wraps the returned datasets in data loaders.
    The resulting do_train/do_valid/do_test flags are broadcast from the
    model-parallel source rank within the model-parallel group and stored on
    `args`. The train and validation loaders are shifted to resume from
    `args.iteration`.

    Arguments:
        build_train_valid_test_datasets_provider: a function that takes a
            list `[train, valid, test]` of sample counts and returns
            `train, valid, test` datasets.

    Returns:
        (train_data_iterator, valid_data_iterator, test_data_iterator); each
        is None when the corresponding data loader is None.
    """
    args = get_args()

    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

    print_rank_0('> building train, validation, and test datasets ...')
    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_model_parallel_rank() == 0:
        # Global batch size across the data-parallel group.
        data_parallel_size = mpu.get_data_parallel_world_size()
        global_batch_size = args.batch_size * data_parallel_size

        # Number of train/valid/test samples.
        train_iters = args.train_iters
        if args.eval_interval:
            # Periodic evaluations during training plus the final one.
            num_evals = train_iters // args.eval_interval + 1
        else:
            # train() skips periodic evaluation when eval_interval is
            # falsy, so only the final validation pass consumes data.
            num_evals = 1
        eval_iters = num_evals * args.eval_iters
        test_iters = args.eval_iters
        train_val_test_num_samples = [train_iters * global_batch_size,
                                      eval_iters * global_batch_size,
                                      test_iters * global_batch_size]
        print_rank_0(' > datasets target sizes (minimum size):')
        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))

        # Build the datasets.
        train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
            train_val_test_num_samples)

        # Build dataloaders.
        train_dataloader = make_data_loader(train_ds)
        valid_dataloader = make_data_loader(valid_ds)
        test_dataloader = make_data_loader(test_ds)

        # Flags to know if we need to do training/validation/testing.
        do_train = train_dataloader is not None and args.train_iters > 0
        do_valid = valid_dataloader is not None and args.eval_iters > 0
        do_test = test_dataloader is not None and args.eval_iters > 0
        flags = torch.cuda.LongTensor(
            [int(do_train), int(do_valid), int(do_test)])
    else:
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Share the do_train/do_valid/do_test flags with the rest of the
    # model-parallel group.
    torch.distributed.broadcast(flags,
                                mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    args.do_train = flags[0].item()
    args.do_valid = flags[1].item()
    args.do_test = flags[2].item()

    # Shift the start iterations.
    if train_dataloader is not None:
        train_dataloader.batch_sampler.start_iter = args.iteration % \
            len(train_dataloader)
        print_rank_0('setting training data start iteration to {}'.
                     format(train_dataloader.batch_sampler.start_iter))
    if valid_dataloader is not None:
        if args.eval_interval:
            # Validation batches consumed by periodic evaluations so far.
            start_iter_val = (args.iteration // args.eval_interval) * \
                args.eval_iters
        else:
            start_iter_val = 0
        valid_dataloader.batch_sampler.start_iter = start_iter_val % \
            len(valid_dataloader)
        print_rank_0('setting validation data start iteration to {}'.
                     format(valid_dataloader.batch_sampler.start_iter))

    # Build iterators; a missing data loader yields a None iterator.
    train_data_iterator, valid_data_iterator, test_data_iterator = (
        iter(loader) if loader is not None else None
        for loader in (train_dataloader, valid_dataloader, test_dataloader))

    return train_data_iterator, valid_data_iterator, test_data_iterator