training.py 27.9 KB
Newer Older
1
# coding=utf-8
Mohammad's avatar
Mohammad committed
2
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
3
4
5
6
7
8
9
10
11
12
13
14
15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Mohammad's avatar
Mohammad committed
16
"""Pretrain utilities."""
17
18
19

from datetime import datetime
import math
Mohammad's avatar
Mohammad committed
20
import sys
21
22
23
24
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from apex.optimizers import FusedAdam as Adam

Neel Kant's avatar
Neel Kant committed
25
from megatron import get_args
Mohammad's avatar
Mohammad committed
26
27
from megatron import get_timers
from megatron import get_tensorboard_writer
28
from megatron import mpu
Neel Kant's avatar
Neel Kant committed
29
from megatron import print_rank_0
30
from megatron import print_rank_last
Mohammad's avatar
Mohammad committed
31
32
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
33
34
from megatron.fp16 import FP16_Module
from megatron.fp16 import FP16_Optimizer
Mohammad's avatar
Mohammad committed
35
from megatron.initialize import initialize_megatron
36
37
38
from megatron.learning_rates import AnnealingLR
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import get_params_for_weight_decay_optimization
Neel Kant's avatar
Neel Kant committed
39
from megatron.model.realm_model import ICTBertModel
40
from megatron.utils import check_adlr_autoresume_termination
41
from megatron.utils import make_data_loader
42
from megatron.utils import report_memory
43
44


45
def pretrain(train_valid_test_dataset_provider, model_provider,
             forward_step_func, extra_args_provider=None, args_defaults=None):
    """Main training program.

    This function runs the following steps in the order provided:
        1) initialize Megatron.
        2) set up the model, optimizer and lr schedule using model_provider.
        3) call train_valid_test_dataset_provider to get the
           train/valid/test datasets.
        4) train the model using forward_step_func.
        5) optionally evaluate on the validation data, save a checkpoint
           and evaluate on the test data.

    Arguments:
        train_valid_test_dataset_provider: a function that takes the size of
            train/valid/test dataset and returns `train, valid, test` datasets.
        model_provider: a function that returns a vanilla version of the
            model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
        forward_step_func: a function called as
            `forward_step_func(data_iterator, model, input_tensor)`. On the
            last pipeline stage its result is unpacked as a `loss` and a
            dictionary with key:values being the info we would like to
            monitor during training, for example `lm-loss: value`. We also
            require that this function add `batch generator` to the timers
            class.
        extra_args_provider: a function that takes a parser and adds arguments
            to it. It is used for programs to add their own arguments.
        args_defaults: a dictionary from argument-name to argument-value. It
            is used to set already parsed arguments. `None` (the default) is
            treated as an empty dictionary.
    """

    # Use a fresh dictionary per call: a `{}` default in the signature is a
    # single object shared by every call and by anything that mutates it.
    if args_defaults is None:
        args_defaults = {}

    # Initialize and get arguments, timers, and Tensorboard writer.
    initialize_megatron(extra_args_provider=extra_args_provider,
                        args_defaults=args_defaults)

    args = get_args()
    timers = get_timers()

    # Model, optimizer, and learning rate.
    timers('model and optimizer').start()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
    timers('model and optimizer').stop()

    # Data stuff.
    timers('train/valid/test data iterators').start()
    train_data_iterator, valid_data_iterator, test_data_iterator \
        = build_train_valid_test_data_iterators(
            train_valid_test_dataset_provider)
    timers('train/valid/test data iterators').stop()

    # Print setup timing.
    print_rank_0('done with setups ...')
    timers.log(['model and optimizer', 'train/valid/test data iterators'])
    print_rank_0('training ...')

    # `iteration` stays 0 when no training is run; the checkpoint below is
    # only written when at least one iteration has been reached.
    iteration = 0
    if args.do_train and args.train_iters > 0:
        iteration = train(forward_step_func,
                          model, optimizer, lr_scheduler,
                          train_data_iterator, valid_data_iterator)

    if args.do_valid:
        prefix = 'the end of training for val data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   valid_data_iterator, model,
                                   iteration, False)

    if args.save and iteration != 0:
        save_checkpoint(iteration, model, optimizer, lr_scheduler)

    if args.do_test:
        # Run on test data.
        prefix = 'the end of training for test data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   test_data_iterator, model,
                                   0, True)
116
117


Mohammad's avatar
Mohammad committed
118
def get_model(model_provider_func):
    """Build the model, move it to the GPU and wrap it for data parallelism.

    Arguments:
        model_provider_func: a function taking no arguments that returns the
            model to train.

    Returns the model wrapped in the DDP implementation selected by
    `args.DDP_impl` ('torch' or 'local'); raises NotImplementedError for any
    other value.
    """
    args = get_args()

    # The provider hands back the plain, unwrapped model.
    model = model_provider_func()

    # Report the parameter count from data parallel rank 0 only.
    if mpu.get_data_parallel_rank() == 0:
        num_parameters = sum([p.nelement() for p in model.parameters()])
        message = (' > number of parameters on (tensor, pipeline) '
                   'model parallel rank ({}, {}): {}')
        print(message.format(mpu.get_tensor_model_parallel_rank(),
                             mpu.get_pipeline_model_parallel_rank(),
                             num_parameters),
              flush=True)

    # Move the model to the current GPU.
    model.cuda(torch.cuda.current_device())

    # Optional fp16 wrapper.
    if args.fp16:
        model = FP16_Module(model)

    # Pipelining is only supported together with the local DDP wrapper.
    if args.use_pipelining:
        assert args.DDP_impl == 'local'

    # Wrap the model for distributed training.
    if args.DDP_impl == 'torch':
        device = torch.cuda.current_device()
        return torchDDP(model, device_ids=[device], output_device=device,
                        process_group=mpu.get_data_parallel_group())
    if args.DDP_impl == 'local':
        return LocalDDP(model)

    raise NotImplementedError('Unknown DDP implementation specified: {}. '
                              'Exiting.'.format(args.DDP_impl))
155
156


Mohammad's avatar
Mohammad committed
157
def get_optimizer(model):
    """Set up the optimizer.

    Builds Adam over the model's parameter groups and, when `args.fp16` is
    set, wraps it in `FP16_Optimizer`.
    """
    args = get_args()

    # Strip the DDP / fp16 wrappers before building the parameter groups
    # (weight decay and non-decay).
    bare_model = model
    while isinstance(bare_model, (torchDDP, LocalDDP, FP16_Module)):
        bare_model = bare_model.module
    param_groups = get_params_for_weight_decay_optimization(bare_model)

    # Give every parameter a `tensor_model_parallel` attribute, defaulting to
    # False where it has not been set.
    for group in param_groups:
        for param in group['params']:
            if not hasattr(param, 'tensor_model_parallel'):
                param.tensor_model_parallel = False

    # Use Adam.
    optimizer = Adam(param_groups,
                     lr=args.lr,
                     weight_decay=args.weight_decay,
                     betas=(args.adam_beta1, args.adam_beta2),
                     eps=args.adam_eps)
    if not args.fp16:
        return optimizer

    # Wrap into fp16 optimizer.
    dynamic_loss_args = {
        'scale_window': args.loss_scale_window,
        'min_scale': args.min_scale,
        'delayed_shift': args.hysteresis}
    return FP16_Optimizer(optimizer,
                          static_loss_scale=args.loss_scale,
                          dynamic_loss_scale=args.dynamic_loss_scale,
                          dynamic_loss_args=dynamic_loss_args)


Mohammad's avatar
Mohammad committed
189
def get_learning_rate_scheduler(optimizer):
    """Build the learning rate scheduler.

    The schedule runs over `args.lr_decay_iters` iterations when that is
    set, otherwise over `args.train_iters`, and never over fewer than one.
    """
    args = get_args()

    # Length of the schedule, with a floor of one iteration.
    if args.lr_decay_iters is None:
        schedule_iters = args.train_iters
    else:
        schedule_iters = args.lr_decay_iters
    schedule_iters = max(1, schedule_iters)

    # `args.warmup` is multiplied by the schedule length to get the number of
    # warmup iterations; the scheduler always starts from step 0.
    return AnnealingLR(
        optimizer,
        start_lr=args.lr,
        warmup_iter=args.warmup * schedule_iters,
        total_iters=schedule_iters,
        decay_style=args.lr_decay_style,
        last_iter=0,
        min_lr=args.min_lr,
        use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
        override_lr_scheduler=args.override_lr_scheduler)


Mohammad's avatar
Mohammad committed
215
def setup_model_and_optimizer(model_provider_func):
    """Set up the model, optimizer and learning rate scheduler.

    Also sets `args.iteration`: to the value returned by `load_checkpoint`
    when `args.load` is given, otherwise to 0.

    Returns the tuple `(model, optimizer, lr_scheduler)`.
    """
    args = get_args()

    model = get_model(model_provider_func)
    optimizer = get_optimizer(model)
    lr_scheduler = get_learning_rate_scheduler(optimizer)

    # Resume from a checkpoint when one was requested.
    if args.load is None:
        args.iteration = 0
    else:
        args.iteration = load_checkpoint(model, optimizer, lr_scheduler)

    # Follow `.module` links down to the innermost model, i.e. the one
    # without FP16 and/or DDP wrappers.
    innermost_model = model
    while hasattr(innermost_model, 'module'):
        innermost_model = innermost_model.module

    # On a fresh run, models that expose `init_state_dict_from_bert` are
    # initialized through it.
    starting_fresh = args.iteration == 0
    if starting_fresh and hasattr(innermost_model, 'init_state_dict_from_bert'):
        print("Initializing ICT from pretrained BERT model", flush=True)
        innermost_model.init_state_dict_from_bert()

    return model, optimizer, lr_scheduler


240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
def communicate(tensor_send_next, tensor_send_prev, recv_forward, recv_backward):
    """Communicate tensors between stages using torch.distributed.ring_exchange(.) API.

    Arguments:
        tensor_send_next: tensor passed as `tensor_send_next` to the
            exchange, or None.
        tensor_send_prev: tensor passed as `tensor_send_prev` to the
            exchange, or None.
        recv_forward: if true, allocate a buffer to receive from the
            previous stage.
        recv_backward: if true, allocate a buffer to receive from the
            next stage.

    Returns `(tensor_recv_prev, tensor_recv_next)`; an entry is None when the
    matching receive flag is false.
    """
    args = get_args()

    # Every exchanged tensor has the same shape.
    tensor_shape = (args.batch_size, args.seq_length, args.hidden_size)

    def _new_receive_buffer():
        # Placeholder tensor that the exchange fills in.
        return torch.empty(tensor_shape,
                           requires_grad=True,
                           dtype=args.params_dtype).cuda()

    # Create placeholder tensors only for the directions actually received.
    tensor_recv_prev = _new_receive_buffer() if recv_forward else None
    tensor_recv_next = _new_receive_buffer() if recv_backward else None

    # Send tensors in both the forward and backward directions as appropriate.
    torch.distributed.ring_exchange(
        tensor_send_prev=tensor_send_prev,
        tensor_recv_prev=tensor_recv_prev,
        tensor_send_next=tensor_send_next,
        tensor_recv_next=tensor_recv_next,
        group=mpu.get_pipeline_model_parallel_group())

    return tensor_recv_prev, tensor_recv_next


def backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad):
    """Backward step.

    Runs the backward pass from `output_tensor` (seeded with
    `output_tensor_grad`) and returns the gradient collected on
    `input_tensor`, or None when `input_tensor` is None.
    """
    args = get_args()
    timers = get_timers()

    # Ask autograd to keep the gradient on the input tensor so that it can be
    # read back after the backward pass.
    has_input = input_tensor is not None
    if has_input:
        input_tensor.retain_grad()

    # Backward pass; in fp16 mode it goes through the optimizer wrapper.
    timers('backward-backward').start()
    if not args.fp16:
        torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
    else:
        optimizer.backward(output_tensor, update_master_grads=False,
                           output_tensor_grad=output_tensor_grad)
    timers('backward-backward').stop()

    # Collect the grad of the input_tensor.
    return input_tensor.grad if has_input else None


def train_step(forward_step_func, data_iterator,
               model, optimizer, lr_scheduler):
    """Single training step.

    Runs the forward pass for every microbatch, then the backward pass for
    every microbatch, reduces gradients, clips them, and steps the optimizer
    and the learning rate scheduler.

    Arguments:
        forward_step_func: called as
            `forward_step_func(data_iterator, model, input_tensor)`.
        data_iterator: passed through to `forward_step_func`.
        model: the (possibly wrapped) model to train.
        optimizer: the optimizer; in fp16 mode it must provide `backward`,
            `update_master_grads`, `clip_master_grads` and `overflow`.
        lr_scheduler: stepped once per non-skipped iteration.

    Returns:
        `(loss_reduced, skipped_iter)`. `loss_reduced` holds the per-key
        losses averaged over the microbatches on the last pipeline stage and
        is an empty dict on every other stage. `skipped_iter` is 1 when the
        fp16 optimizer reported an overflow, otherwise 0.
    """
    args = get_args()
    timers = get_timers()

    # Set grad to zero.
    if args.fp16:
        optimizer.zero_grad(set_grads_to_None=True)
    else:
        optimizer.zero_grad()

    # Compute number of microbatches in a minibatch.
    num_microbatches_to_pipeline = args.pipeline_model_parallel_size \
            if args.use_pipelining else 1

    input_tensors = []
    output_tensors = []
    losses_reduced = []

    # Run forward pass for all microbatches in minibatch.
    for _ in range(num_microbatches_to_pipeline):
        if not mpu.is_pipeline_first_stage():
            input_tensor, _ = communicate(
                tensor_send_next=None,
                tensor_send_prev=None,
                recv_forward=True,
                recv_backward=False)
        else:
            input_tensor = None

        # Forward model for one step.
        timers('forward').start()
        output_tensor = forward_step_func(data_iterator, model, input_tensor)
        timers('forward').stop()

        if mpu.is_pipeline_last_stage():
            # The last stage produces the loss plus its reduced copy for
            # logging; only the loss is kept for the backward pass.
            loss, loss_reduced = output_tensor
            output_tensor = loss
            losses_reduced.append(loss_reduced)
        else:
            communicate(
                tensor_send_next=output_tensor,
                tensor_send_prev=None,
                recv_forward=False,
                recv_backward=False)

        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)

    # Run backward pass for all microbatches in minibatch, in the order they
    # went forward. Popping drops the list's reference to each tensor as soon
    # as its backward pass has been issued.
    for _ in range(num_microbatches_to_pipeline):
        input_tensor = input_tensors.pop(0)
        output_tensor = output_tensors.pop(0)

        if mpu.is_pipeline_last_stage():
            output_grad_tensor = None
        else:
            _, output_grad_tensor = communicate(
                tensor_send_next=None,
                tensor_send_prev=None,
                recv_forward=False,
                recv_backward=True)

        # Backward pass for one step.
        # TODO: This timer is a bit redundant now with backward-backward.
        timers('backward').start()
        input_grad_tensor = \
            backward_step(optimizer, model, input_tensor, output_tensor, output_grad_tensor)
        timers('backward').stop()

        if not mpu.is_pipeline_first_stage():
            communicate(
                tensor_send_next=None,
                tensor_send_prev=input_grad_tensor,
                recv_forward=False,
                recv_backward=False)

    # All-reduce if needed. The timer is named 'backward-allreduce' so that
    # it is picked up by the timer names reported in training_log().
    if args.DDP_impl == 'local':
        timers('backward-allreduce').start()
        model.allreduce_params(reduce_after=False,
                               fp32_allreduce=args.fp32_allreduce)
        timers('backward-allreduce').stop()

    # All-reduce the word embedding gradient across first and last stages.
    # This has to happen BEFORE the master gradients are updated: in fp16
    # mode update_master_grads() is expected to take its values from the
    # model gradients, so a reduction done afterwards would not reach the
    # gradients the optimizer steps on.
    if (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()) and \
            args.pipeline_model_parallel_size > 1:
        unwrapped_model = model
        while isinstance(unwrapped_model, (torchDDP, LocalDDP, FP16_Module)):
            unwrapped_model = unwrapped_model.module

        word_embeddings_weight = unwrapped_model.word_embeddings_weight()
        torch.distributed.all_reduce(word_embeddings_weight.grad,
                                     group=mpu.get_embedding_group())

    # Update master gradients.
    timers('backward-master-grad').start()
    if args.fp16:
        optimizer.update_master_grads()
    timers('backward-master-grad').stop()

    # Clipping gradients helps prevent the exploding gradient.
    timers('backward-clip-grad').start()
    if args.clip_grad > 0.:
        if not args.fp16:
            parameters = []
            parameter_names = []
            for parameter_name, parameter in model.named_parameters():
                parameters.append(parameter)
                parameter_names.append(parameter_name)
            mpu.clip_grad_norm(parameters, args.clip_grad,
                               parameter_names=parameter_names)
        else:
            optimizer.clip_master_grads(args.clip_grad)
    timers('backward-clip-grad').stop()

    # Update parameters.
    timers('optimizer').start()
    optimizer.step()
    timers('optimizer').stop()

    # Update learning rate, unless the fp16 optimizer overflowed, in which
    # case the iteration is reported as skipped.
    skipped_iter = 0
    if not (args.fp16 and optimizer.overflow):
        lr_scheduler.step()
    else:
        skipped_iter = 1

    if mpu.is_pipeline_last_stage():
        # Average loss across microbatches.
        loss_reduced = {}
        for key in losses_reduced[0]:
            losses_reduced_for_key = [x[key] for x in losses_reduced]
            loss_reduced[key] = sum(losses_reduced_for_key) / \
                    len(losses_reduced_for_key)
        return loss_reduced, skipped_iter
    return {}, skipped_iter
433
434


Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
435
def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                 loss_scale, report_memory_flag, skipped_iter):
    """Log training information such as losses, timing, ....

    Accumulates `loss_dict` into `total_loss_dict` (which is modified in
    place), writes scalars to Tensorboard when a writer is available on
    rank 0, and every `args.log_interval` iterations prints a summary line
    and resets the accumulated values.

    Returns `report_memory_flag`, set to False once memory has been reported.
    """
    args = get_args()
    timers = get_timers()
    writer = get_tensorboard_writer()

    skipped_iters_key = 'skipped iterations'
    got_nan_key = 'got nan'

    # Count skipped iterations.
    total_loss_dict[skipped_iters_key] = \
        total_loss_dict.get(skipped_iters_key, 0) + skipped_iter

    # Accumulate the losses of regular iterations; on a skipped iteration
    # only check whether any loss was inf or nan.
    got_nan = False
    for key in loss_dict:
        if skipped_iter:
            value = loss_dict[key].float().sum().item()
            got_nan = got_nan or math.isinf(value) or math.isnan(value)
        else:
            running_total = total_loss_dict.get(
                key, torch.cuda.FloatTensor([0.0]))
            total_loss_dict[key] = running_total + loss_dict[key]

    total_loss_dict[got_nan_key] = \
        total_loss_dict.get(got_nan_key, 0) + int(got_nan)

    # Only timers that have actually been created are logged.
    candidate_timers = ('forward',
                        'backward',
                        'backward-backward',
                        'backward-allreduce',
                        'backward-master-grad',
                        'backward-clip-grad',
                        'optimizer',
                        'batch generator')
    timers_to_log = [name for name in candidate_timers
                     if name in timers.timers]

    # Tensorboard values.
    write_to_tensorboard = writer and torch.distributed.get_rank() == 0
    if write_to_tensorboard:
        writer.add_scalar('learning_rate', learning_rate, iteration)
        for key in loss_dict:
            writer.add_scalar(key, loss_dict[key], iteration)
        if args.fp16:
            writer.add_scalar('loss_scale', loss_scale, iteration)
        # Iterations since the last multiple of the log interval; a full
        # interval when `iteration` is itself a multiple.
        normalizer = iteration % args.log_interval or args.log_interval
        timers.write(timers_to_log, writer, iteration,
                     normalizer=normalizer)

    # Nothing more to do between logging intervals.
    if iteration % args.log_interval != 0:
        return report_memory_flag

    elapsed_time = timers('interval time').elapsed()
    if write_to_tensorboard:
        writer.add_scalar('iteration_time',
                          elapsed_time / args.log_interval, iteration)

    pieces = [
        ' iteration {:8d}/{:8d} |'.format(iteration, args.train_iters),
        ' elapsed time per iteration (ms): {:.1f} |'.format(
            elapsed_time * 1000.0 / args.log_interval),
        ' learning rate: {:.3E} |'.format(learning_rate),
    ]

    # Average over the iterations that were not skipped (at least one).
    num_iterations = max(
        1, args.log_interval - total_loss_dict[skipped_iters_key])
    bookkeeping_keys = [skipped_iters_key, got_nan_key]
    for key in total_loss_dict:
        if key in bookkeeping_keys:
            continue
        avg = total_loss_dict[key].item() / float(num_iterations)
        if avg > 0.0:
            pieces.append(' {}: {:.6E} |'.format(key, avg))
        # Start a new accumulation window for this loss.
        total_loss_dict[key] = torch.cuda.FloatTensor([0.0])

    if args.fp16:
        pieces.append(' loss scale: {:.1f} |'.format(loss_scale))
    pieces.append(' number of skipped iterations: {:3d} |'.format(
        total_loss_dict[skipped_iters_key]))
    pieces.append(' number of nan iterations: {:3d} |'.format(
        total_loss_dict[got_nan_key]))

    # Reset the counters for the next window.
    total_loss_dict[skipped_iters_key] = 0
    total_loss_dict[got_nan_key] = 0

    print_rank_last(''.join(pieces))
    if report_memory_flag:
        report_memory('after {} iterations'.format(iteration))
        report_memory_flag = False
    timers.log(timers_to_log, normalizer=args.log_interval)

    return report_memory_flag


526
def train(forward_step_func, model, optimizer, lr_scheduler,
          train_data_iterator, valid_data_iterator):
    """Train the model function.

    Runs `train_step` from `args.iteration` until `args.train_iters`,
    logging after every step and handling autoresume, checkpointing,
    evaluation and early exit at their configured intervals.

    Returns the iteration count reached.
    """
    args = get_args()
    timers = get_timers()

    # Turn on training mode which enables dropout.
    model.train()

    # Running loss totals, updated by training_log().
    total_loss_dict = {}

    # Continue counting from the stored starting iteration.
    iteration = args.iteration

    timers('interval time').start()
    report_memory_flag = True
    while iteration < args.train_iters:
        loss_dict, skipped_iter = train_step(
            forward_step_func, train_data_iterator,
            model, optimizer, lr_scheduler)
        iteration += 1

        # Logging.
        loss_scale = optimizer.loss_scale if args.fp16 else None
        current_lr = optimizer.param_groups[0]['lr']
        report_memory_flag = training_log(
            loss_dict, total_loss_dict, current_lr, iteration,
            loss_scale, report_memory_flag, skipped_iter)

        # Autoresume
        if args.adlr_autoresume and \
                iteration % args.adlr_autoresume_interval == 0:
            check_adlr_autoresume_termination(iteration, model, optimizer,
                                              lr_scheduler)

        # Checkpointing
        if args.save and args.save_interval and \
                iteration % args.save_interval == 0:
            save_checkpoint(iteration, model, optimizer, lr_scheduler)

        # Evaluation
        if args.eval_interval and iteration % args.eval_interval == 0 and \
                args.do_valid:
            evaluate_and_print_results('iteration {}'.format(iteration),
                                       forward_step_func,
                                       valid_data_iterator, model,
                                       iteration, False)

        # Exit once the exit interval is hit, after all ranks have reached
        # the barrier.
        if args.exit_interval and iteration % args.exit_interval == 0:
            torch.distributed.barrier()
            time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            rank = torch.distributed.get_rank()
            print_rank_0('rank: {} | time: {} | exiting the program at '
                         'iteration {}'.format(rank, time_str, iteration))
            sys.exit()

    return iteration
588
589


Mohammad's avatar
Mohammad committed
590
def evaluate(forward_step_func, data_iterator, model, verbose=False):
    """Evaluation.

    Runs `args.eval_iters` forward passes with gradients disabled and returns
    a dictionary with each loss averaged over those iterations. The
    dictionary is only filled on the last pipeline stage. The model is put
    back into train mode before returning.
    """
    args = get_args()

    # Turn on evaluation mode which disables dropout.
    model.eval()

    total_loss_dict = {}

    with torch.no_grad():
        for iteration in range(1, args.eval_iters + 1):
            if verbose and iteration % args.log_interval == 0:
                print_rank_0('Evaluating iter {}/{}'.format(iteration,
                                                            args.eval_iters))

            # Every stage but the first receives its input from the
            # previous stage.
            input_tensor = None
            if not mpu.is_pipeline_first_stage():
                input_tensor, _ = communicate(
                    tensor_send_next=None,
                    tensor_send_prev=None,
                    recv_forward=True,
                    recv_backward=False)

            # Forward evaluation.
            output_tensor = forward_step_func(data_iterator, model, input_tensor)

            if not mpu.is_pipeline_last_stage():
                # Hand the output on to the next stage.
                communicate(
                    tensor_send_next=output_tensor,
                    tensor_send_prev=None,
                    recv_forward=False,
                    recv_backward=False)
                continue

            # Last stage: accumulate the losses.
            _, loss_dict = output_tensor
            for key in loss_dict:
                total_loss_dict[key] = total_loss_dict.get(key, 0.) + \
                    loss_dict[key]

    # Move model back to the train mode.
    model.train()

    # Turn the sums into averages.
    for key in total_loss_dict:
        total_loss_dict[key] /= args.eval_iters

    return total_loss_dict


def evaluate_and_print_results(prefix, forward_step_func,
                               data_iterator, model,
                               iteration, verbose=False):
    """Helper function to evaluate and dump results on screen.

    For every loss returned by `evaluate` it reports the value and a
    perplexity, `exp(min(20, value))`, and also writes both to Tensorboard
    when a writer is available on rank 0.
    """
    writer = get_tensorboard_writer()

    total_loss_dict = evaluate(forward_step_func, data_iterator, model, verbose)

    string = ' validation loss at {} | '.format(prefix)
    for key in total_loss_dict:
        value = total_loss_dict[key].item()
        # The exponent is capped at 20 before computing the perplexity.
        ppl = math.exp(min(20, value))
        string += '{} value: {:.6E} | '.format(key, value)
        string += '{} PPL: {:.6E} | '.format(key, ppl)
        if writer and torch.distributed.get_rank() == 0:
            writer.add_scalar('{} value'.format(key), value, iteration)
            writer.add_scalar('{} ppl'.format(key), ppl, iteration)

    # Frame the report with dashed rules one character longer than the text.
    rule = '-' * (len(string) + 1)
    print_rank_last(rule)
    print_rank_last(string)
    print_rank_last(rule)
663
664


665
666
667
def build_train_valid_test_data_iterators(
        build_train_valid_test_datasets_provider):
    """Build the train, validation and test data iterators.

    Datasets and data loaders are only built on tensor model parallel
    rank 0; the other ranks of the group get None iterators. The flags
    `args.do_train`, `args.do_valid` and `args.do_test` are computed on that
    rank and broadcast to the rest of the tensor model parallel group.

    Arguments:
        build_train_valid_test_datasets_provider: a function that takes the
            list `[train, valid, test]` of required sample counts and returns
            the `train, valid, test` datasets.

    Returns `(train_data_iterator, valid_data_iterator, test_data_iterator)`;
    an entry is None when no data loader was built for it.
    """
    args = get_args()

    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

    print_rank_0('> building train, validation, and test datasets ...')
    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_tensor_model_parallel_rank() == 0:
        # Rank, size, and global batch size.
        data_parallel_size = mpu.get_data_parallel_world_size()
        global_batch_size = args.batch_size * data_parallel_size

        # Number of train/valid/test samples.
        train_iters = args.train_iters
        if args.eval_interval:
            # One evaluation per interval plus the one at the end of training.
            eval_iters = (train_iters // args.eval_interval + 1) * \
                args.eval_iters
        else:
            # No periodic evaluation (train() skips it for a falsy interval),
            # so only the end-of-training evaluation needs samples. This also
            # avoids dividing by zero.
            eval_iters = args.eval_iters
        test_iters = args.eval_iters
        train_val_test_num_samples = [train_iters * global_batch_size,
                                      eval_iters * global_batch_size,
                                      test_iters * global_batch_size]
        print_rank_0(' > datasets target sizes (minimum size):')
        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))

        # Build the datasets.
        train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
            train_val_test_num_samples)

        # Build dataloaders.
        train_dataloader = make_data_loader(train_ds)
        valid_dataloader = make_data_loader(valid_ds)
        test_dataloader = make_data_loader(test_ds)

        # Flags to know if we need to do training/validation/testing.
        do_train = train_dataloader is not None and args.train_iters > 0
        do_valid = valid_dataloader is not None and args.eval_iters > 0
        do_test = test_dataloader is not None and args.eval_iters > 0
        # Pack the flags into a tensor so that they can be broadcast.
        flags = torch.cuda.LongTensor(
            [int(do_train), int(do_valid), int(do_test)])
    else:
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Broadcast the flags within the tensor model parallel group.
    torch.distributed.broadcast(flags,
                                mpu.get_tensor_model_parallel_src_rank(),
                                group=mpu.get_tensor_model_parallel_group())
    args.do_train = flags[0].item()
    args.do_valid = flags[1].item()
    args.do_test = flags[2].item()

    # Shift the start iterations.
    if train_dataloader is not None:
        train_dataloader.batch_sampler.start_iter = args.iteration % \
            len(train_dataloader)
        print_rank_0('setting training data start iteration to {}'.
                     format(train_dataloader.batch_sampler.start_iter))
    if valid_dataloader is not None:
        # Number of periodic evaluations assumed to have run already; none
        # when periodic evaluation is disabled.
        if args.eval_interval:
            completed_evals = args.iteration // args.eval_interval
        else:
            completed_evals = 0
        start_iter_val = completed_evals * args.eval_iters
        valid_dataloader.batch_sampler.start_iter = start_iter_val % \
            len(valid_dataloader)
        print_rank_0('setting validation data start iteration to {}'.
                     format(valid_dataloader.batch_sampler.start_iter))

    # Build iterators.
    if train_dataloader is not None:
        train_data_iterator = iter(train_dataloader)
    else:
        train_data_iterator = None

    if valid_dataloader is not None:
        valid_data_iterator = iter(valid_dataloader)
    else:
        valid_data_iterator = None

    if test_dataloader is not None:
        test_data_iterator = iter(test_dataloader)
    else:
        test_data_iterator = None

    return train_data_iterator, valid_data_iterator, test_data_iterator