training.py 34 KB
Newer Older
1
# coding=utf-8
Mohammad's avatar
Mohammad committed
2
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
3
4
5
6
7
8
9
10
11
12
13
14
15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Mohammad's avatar
Mohammad committed
16
"""Pretrain utilities."""
17
18
19

from datetime import datetime
import math
Mohammad's avatar
Mohammad committed
20
import sys
21
22
23
24
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from apex.optimizers import FusedAdam as Adam

Neel Kant's avatar
Neel Kant committed
25
from megatron import get_args
Mohammad's avatar
Mohammad committed
26
27
from megatron import get_timers
from megatron import get_tensorboard_writer
28
from megatron import mpu
Neel Kant's avatar
Neel Kant committed
29
from megatron import print_rank_0
30
from megatron import print_rank_last
Mohammad's avatar
Mohammad committed
31
32
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
33
34
from megatron.fp16 import FP16_Module
from megatron.fp16 import FP16_Optimizer
Mohammad's avatar
Mohammad committed
35
from megatron.initialize import initialize_megatron
36
37
38
from megatron.learning_rates import AnnealingLR
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import get_params_for_weight_decay_optimization
Neel Kant's avatar
Neel Kant committed
39
from megatron.model.realm_model import ICTBertModel
40
from megatron.utils import check_adlr_autoresume_termination
41
from megatron.utils import make_data_loader
42
from megatron.utils import report_memory
43
44


45
def pretrain(train_valid_test_dataset_provider, model_provider,
             forward_step_func, extra_args_provider=None, args_defaults=None):
    """Main training program.

    This function runs the following steps in order:
        1) initialize Megatron.
        2) set up the model, optimizer and lr schedule using the model_provider.
        3) call train_valid_test_dataset_provider to get train/val/test datasets.
        4) train the model using the forward_step_func.

    Arguments:
        train_valid_test_dataset_provider: a function that takes the size of
            train/valid/test dataset and returns `train, valid, test` datasets.
        model_provider: a function that returns a vanilla version of the
            model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
        forward_step_func: a function called as
            `forward_step_func(data_iterator, model, input_tensor)`, where
            `input_tensor` is None on the first pipeline stage. On the last
            pipeline stage it returns a `(loss, info)` tuple, `info` being a
            dictionary with key:values we would like to monitor during
            training, for example `lm-loss: value`; on other stages it
            returns the output tensor to send to the next stage. We also
            require that this function add `batch generator` to the timers
            class.
        extra_args_provider: a function that takes a parser and adds arguments
            to it. It is used for programs to add their own arguments.
        args_defaults: a dictionary from argument-name to argument-value. It
            is used to set already-parsed arguments. Defaults to an empty
            dictionary.
    """
    # Use a fresh dict per call: a shared `{}` default would let mutations
    # leak from one call into the next.
    if args_defaults is None:
        args_defaults = {}

    # Initialize and get arguments, timers, and Tensorboard writer.
    initialize_megatron(extra_args_provider=extra_args_provider,
                        args_defaults=args_defaults)

    args = get_args()
    timers = get_timers()

    # Model, optimizer, and learning rate.
    timers('model and optimizer').start()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
    timers('model and optimizer').stop()

    # Data stuff.
    timers('train/valid/test data iterators').start()
    train_data_iterator, valid_data_iterator, test_data_iterator \
        = build_train_valid_test_data_iterators(
            train_valid_test_dataset_provider)
    timers('train/valid/test data iterators').stop()

    # Print setup timing.
    print_rank_0('done with setups ...')
    timers.log(['model and optimizer', 'train/valid/test data iterators'])
    print_rank_0('training ...')

    iteration = 0
    if args.do_train and args.train_iters > 0:
        iteration = train(forward_step_func,
                          model, optimizer, lr_scheduler,
                          train_data_iterator, valid_data_iterator)

    if args.do_valid:
        prefix = 'the end of training for val data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   valid_data_iterator, model,
                                   iteration, False)

    # Skip the final save when iteration is still 0 (e.g. training not run).
    if args.save and iteration != 0:
        save_checkpoint(iteration, model, optimizer, lr_scheduler)

    if args.do_test:
        # Run on test data.
        prefix = 'the end of training for test data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   test_data_iterator, model,
                                   0, True)
116
117


Mohammad's avatar
Mohammad committed
118
def get_model(model_provider_func):
    """Build the model and wrap it for fp16 and data-parallel training.

    The model is built on cpu by `model_provider_func`, moved to the current
    cuda device, optionally wrapped in `FP16_Module` (when `args.fp16`), and
    finally wrapped in the data-parallel implementation selected by
    `args.DDP_impl`.

    Raises:
        AssertionError: if more than one microbatch per minibatch is
            requested with a DDP implementation other than 'local'.
        NotImplementedError: if `args.DDP_impl` is neither 'torch' nor 'local'.
    """
    args = get_args()

    # Build model on cpu.
    model = model_provider_func()

    # Print number of parameters.
    if mpu.get_data_parallel_rank() == 0:
        print(' > number of parameters on (tensor, pipeline) '
              'model parallel rank ({}, {}): {}'.format(
            mpu.get_tensor_model_parallel_rank(),
            mpu.get_pipeline_model_parallel_rank(),
            sum(p.nelement() for p in model.parameters())), flush=True)

    # GPU allocation.
    model.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16:
        model = FP16_Module(model)

    # Wrap model for distributed training.
    # This code path only allows several microbatches per minibatch with the
    # local DDP implementation.
    if args.num_microbatches_in_minibatch > 1:
        assert args.DDP_impl == 'local', \
            'num_microbatches_in_minibatch > 1 is only supported with ' \
            'DDP_impl == local, got {}'.format(args.DDP_impl)

    if args.DDP_impl == 'torch':
        i = torch.cuda.current_device()
        model = torchDDP(model, device_ids=[i], output_device=i,
                         process_group=mpu.get_data_parallel_group())
        return model
    if args.DDP_impl == 'local':
        model = LocalDDP(model)
        return model

    raise NotImplementedError('Unknown DDP implementation specified: {}. '
                              'Exiting.'.format(args.DDP_impl))
155
156


Mohammad's avatar
Mohammad committed
157
def get_optimizer(model):
    """Build the Adam optimizer for `model`, fp16-wrapped when requested.

    Any torchDDP / LocalDDP / FP16_Module wrappers are peeled off first so
    the weight-decay parameter groups are computed on the bare module.
    Parameters lacking a `tensor_model_parallel` attribute get it set to
    False.
    """
    args = get_args()

    # Strip every wrapper layer to reach the underlying module.
    bare_module = model
    while isinstance(bare_module, (torchDDP, LocalDDP, FP16_Module)):
        bare_module = bare_module.module
    param_groups = get_params_for_weight_decay_optimization(bare_module)

    # Default the model-parallel flag on parameters that do not carry it.
    for group in param_groups:
        for param in group['params']:
            if not hasattr(param, 'tensor_model_parallel'):
                param.tensor_model_parallel = False

    base_optimizer = Adam(param_groups,
                          lr=args.lr,
                          weight_decay=args.weight_decay,
                          betas=(args.adam_beta1, args.adam_beta2),
                          eps=args.adam_eps)

    if not args.fp16:
        return base_optimizer

    # Wrap into the fp16 optimizer with the configured loss-scale settings.
    loss_scale_options = {
        'scale_window': args.loss_scale_window,
        'min_scale': args.min_scale,
        'delayed_shift': args.hysteresis,
    }
    return FP16_Optimizer(base_optimizer,
                          static_loss_scale=args.loss_scale,
                          dynamic_loss_scale=args.dynamic_loss_scale,
                          dynamic_loss_args=loss_scale_options)


Mohammad's avatar
Mohammad committed
189
def get_learning_rate_scheduler(optimizer):
    """Build the AnnealingLR learning rate scheduler wrapping `optimizer`.

    The decay horizon is `args.lr_decay_iters` when set, otherwise
    `args.train_iters`, clamped to at least 1. The warmup length passed to
    the scheduler is `args.warmup` times that horizon, and the scheduler
    starts at step 0.
    """
    args = get_args()

    horizon = args.train_iters if args.lr_decay_iters is None \
        else args.lr_decay_iters
    horizon = max(1, horizon)

    return AnnealingLR(
        optimizer,
        start_lr=args.lr,
        warmup_iter=args.warmup * horizon,
        total_iters=horizon,
        decay_style=args.lr_decay_style,
        last_iter=0,
        min_lr=args.min_lr,
        use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
        override_lr_scheduler=args.override_lr_scheduler)


Mohammad's avatar
Mohammad committed
215
def setup_model_and_optimizer(model_provider_func):
    """Build model, optimizer and lr scheduler; restore a checkpoint if set.

    Sets `args.iteration` to the iteration returned by `load_checkpoint`
    when `args.load` is given, otherwise to 0. When starting at iteration 0
    and the unwrapped model has `init_state_dict_from_bert`, that method is
    called to initialize it.

    Returns:
        (model, optimizer, lr_scheduler)
    """
    args = get_args()

    model = get_model(model_provider_func)
    optimizer = get_optimizer(model)
    lr_scheduler = get_learning_rate_scheduler(optimizer)

    args.iteration = (load_checkpoint(model, optimizer, lr_scheduler)
                      if args.load is not None else 0)

    # Peel off FP16 and/or DDP wrappers (anything exposing `.module`).
    inner = model
    while hasattr(inner, 'module'):
        inner = inner.module

    fresh_start = args.iteration == 0
    if fresh_start and hasattr(inner, 'init_state_dict_from_bert'):
        print("Initializing ICT from pretrained BERT model", flush=True)
        inner.init_state_dict_from_bert()

    return model, optimizer, lr_scheduler


240
241
242
243
244
245
246
247
def communicate(tensor_send_next, tensor_send_prev, recv_forward, recv_backward):
    """Communicate tensors between stages using torch.distributed.ring_exchange(.) API.

    Arguments:
        tensor_send_next: tensor passed as `tensor_send_next`, or None.
        tensor_send_prev: tensor passed as `tensor_send_prev`, or None.
        recv_forward: if True, allocate the `tensor_recv_prev` buffer.
        recv_backward: if True, allocate the `tensor_recv_next` buffer.

    Receive buffers are uninitialized tensors of shape
    (seq_length, batch_size, hidden_size) with dtype `args.params_dtype`
    and requires_grad=True, on the current cuda device. The exchange uses
    the pipeline model parallel group.

    Returns:
        (tensor_recv_prev, tensor_recv_next); each is None when the matching
        recv flag is False.
    """
    args = get_args()

    buffer_shape = (args.seq_length, args.batch_size, args.hidden_size)

    def _recv_buffer(wanted):
        # Empty receive buffer, or None when nothing is expected.
        if not wanted:
            return None
        return torch.empty(buffer_shape,
                           requires_grad=True,
                           device=torch.cuda.current_device(),
                           dtype=args.params_dtype)

    recv_prev = _recv_buffer(recv_forward)
    recv_next = _recv_buffer(recv_backward)

    # A single exchange covers both directions.
    torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
                                    tensor_recv_prev=recv_prev,
                                    tensor_send_next=tensor_send_next,
                                    tensor_recv_next=recv_next,
                                    group=mpu.get_pipeline_model_parallel_group())

    return recv_prev, recv_next


def backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad):
    """Run the backward pass for one microbatch.

    Arguments:
        optimizer: its `backward` method is used when `args.fp16` is set;
            otherwise unused.
        model: unused here; kept for interface compatibility.
        input_tensor: the forward input received by this stage, or None
            (first pipeline stage / no pipelining). Its grad is retained and
            returned.
        output_tensor: the forward output of this microbatch (the loss on the
            last pipeline stage).
        output_tensor_grad: gradient w.r.t. `output_tensor` received from the
            next stage, or None on the last stage.

    Returns:
        The gradient of `input_tensor`, or None when `input_tensor` is None.
    """
    args = get_args()

    # Retain the grad on the input_tensor so it can be returned below.
    if input_tensor is not None:
        input_tensor.retain_grad()

    # Backward pass.
    if args.fp16:
        optimizer.backward(output_tensor, update_master_grads=False,
                           output_tensor_grad=output_tensor_grad)
    else:
        torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)

    # Collect the grad of the input_tensor.
    if input_tensor is None:
        return None
    return input_tensor.grad


294
295
296
297
def forward_step_with_communication(forward_step_func, data_iterator, model,
                                    input_tensors, output_tensors,
                                    losses_reduced, timers):
    """Run one forward-only microbatch with pipeline communication.

    Non-first stages first receive their input from the previous stage.
    Non-last stages send their output to the next stage. On the last stage
    the `(loss, loss_reduced)` result is split: the loss becomes the output
    tensor and `loss_reduced` is appended to `losses_reduced`.

    The input and output tensors are appended to `input_tensors` and
    `output_tensors` so a later backward step can consume them.
    """
    input_tensor = None
    if not mpu.is_pipeline_first_stage():
        timers('forward-recv').start()
        input_tensor, _ = communicate(
            tensor_send_next=None,
            tensor_send_prev=None,
            recv_forward=True,
            recv_backward=False)
        timers('forward-recv').stop()

    # Forward model for one step.
    timers('forward-compute').start()
    result = forward_step_func(data_iterator, model, input_tensor)
    timers('forward-compute').stop()

    if not mpu.is_pipeline_last_stage():
        timers('forward-send').start()
        communicate(
            tensor_send_next=result,
            tensor_send_prev=None,
            recv_forward=False,
            recv_backward=False)
        timers('forward-send').stop()
    else:
        result, loss_reduced = result
        losses_reduced.append(loss_reduced)

    input_tensors.append(input_tensor)
    output_tensors.append(result)


def backward_step_with_communication(optimizer, model, input_tensors, output_tensors, timers):
    """Run one backward-only microbatch with pipeline communication.

    Pops the oldest queued input/output pair. Non-last stages receive the
    output gradient from the next stage first. Non-first stages send the
    resulting input gradient to the previous stage afterwards.
    """
    oldest_input = input_tensors.pop(0)
    oldest_output = output_tensors.pop(0)

    grad_from_next = None
    if not mpu.is_pipeline_last_stage():
        timers('backward-recv').start()
        _, grad_from_next = communicate(
            tensor_send_next=None,
            tensor_send_prev=None,
            recv_forward=False,
            recv_backward=True)
        timers('backward-recv').stop()

    # Backward pass for one step.
    timers('backward-compute').start()
    grad_for_prev = backward_step(optimizer, model, oldest_input,
                                  oldest_output, grad_from_next)
    timers('backward-compute').stop()

    if mpu.is_pipeline_first_stage():
        return

    timers('backward-send').start()
    communicate(
        tensor_send_next=None,
        tensor_send_prev=grad_for_prev,
        recv_forward=False,
        recv_backward=False)
    timers('backward-send').stop()
360
361


362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
def forward_and_backward_steps_with_communication(forward_step_func, data_iterator, model,
                                                  optimizer,
                                                  input_tensor, last_microbatch,
                                                  input_tensors, output_tensors,
                                                  losses_reduced, timers):
    """Run one steady-state (1F1B) step: a forward pass, then a backward pass.

    Forward phase:
        - The forward pass consumes `input_tensor`.
        - On non-last stages, one exchange sends the output downstream and
          receives the output gradient.
        - On the last stage, the loss is the output, there is no incoming
          gradient, and the reduced-loss dict is appended to `losses_reduced`.

    Backward phase:
        - This microbatch is queued, and the oldest queued one is dequeued.
        - The backward pass runs on that oldest microbatch.
        - On non-first stages, one exchange sends its input gradient upstream
          and receives the next forward input.
        - The receive is skipped when `last_microbatch` is True.

    Returns:
        The input tensor for the next forward pass. It is None on the first
        stage, or when no receive was requested.
    """
    # Forward model for one step.
    timers('forward-compute').start()
    produced = forward_step_func(data_iterator, model, input_tensor)
    timers('forward-compute').stop()

    if mpu.is_pipeline_last_stage():
        produced, loss_reduced = produced
        incoming_grad = None
        losses_reduced.append(loss_reduced)
    else:
        timers('forward-send-backward-recv').start()
        _, incoming_grad = communicate(
            tensor_send_next=produced,
            tensor_send_prev=None,
            recv_forward=False,
            recv_backward=True)
        timers('forward-send-backward-recv').stop()

    # Queue this microbatch, then take the oldest one for the backward pass.
    input_tensors.append(input_tensor)
    output_tensors.append(produced)
    oldest_input = input_tensors.pop(0)
    oldest_output = output_tensors.pop(0)

    # Backward pass for one step.
    timers('backward-compute').start()
    grad_for_prev = backward_step(optimizer, model, oldest_input,
                                  oldest_output, incoming_grad)
    timers('backward-compute').stop()

    if mpu.is_pipeline_first_stage():
        return None

    timers('backward-send-forward-recv').start()
    next_input, _ = communicate(
        tensor_send_next=None,
        tensor_send_prev=grad_for_prev,
        recv_forward=(not last_microbatch),
        recv_backward=False)
    timers('backward-send-forward-recv').stop()
    return next_input


412
413
414
def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
                                   optimizer, timers):
    """Run forward and backward passes without inter-stage communication.

    Each of `args.num_microbatches_in_minibatch` microbatches runs a forward
    pass (expected to return `(loss, loss_reduced)`) immediately followed by
    a backward pass on that loss.

    Returns:
        The list of per-microbatch `loss_reduced` values, in order.
    """
    args = get_args()

    collected = []
    for _ in range(args.num_microbatches_in_minibatch):
        timers('forward-compute').start()
        loss, loss_reduced = forward_step_func(data_iterator, model, input_tensor=None)
        collected.append(loss_reduced)
        timers('forward-compute').stop()

        timers('backward-compute').start()
        backward_step(optimizer, model, input_tensor=None,
                      output_tensor=loss, output_tensor_grad=None)
        timers('backward-compute').stop()

    return collected
432

433
434
435
436
437
438
439

def forward_backward_pipelining(forward_step_func, data_iterator, model,
                                optimizer, timers):
    """Run 1F1B schedule, with communication and warmup + cooldown microbatches as needed.

    Schedule:
        - warmup: `W` forward-only steps, where
          W = min(pipeline world size - pipeline rank - 1, num microbatches).
        - steady state: one combined forward+backward step for each of the
          remaining microbatches.
        - cooldown: `W` backward-only steps.

    Returns:
        The `loss_reduced` values collected on the last pipeline stage
        (empty elsewhere).
    """
    args = get_args()

    total_microbatches = args.num_microbatches_in_minibatch
    stages = mpu.get_pipeline_model_parallel_world_size()
    stage_rank = mpu.get_pipeline_model_parallel_rank()
    warmup_count = min(stages - stage_rank - 1, total_microbatches)
    steady_count = total_microbatches - warmup_count

    input_tensors = []
    output_tensors = []
    losses_reduced = []

    # Warmup: forward passes only.
    for _ in range(warmup_count):
        forward_step_with_communication(
            forward_step_func, data_iterator, model,
            input_tensors, output_tensors,
            losses_reduced, timers)

    # Before running 1F1B, receive the first forward tensor. This is not
    # needed when every microbatch runs in the warmup / cooldown phases.
    if steady_count > 0:
        input_tensor = None
        if not mpu.is_pipeline_first_stage():
            timers('forward-recv').start()
            input_tensor, _ = communicate(tensor_send_next=None,
                                          tensor_send_prev=None,
                                          recv_forward=True,
                                          recv_backward=False)
            timers('forward-recv').stop()

    # Steady state: one forward and one backward per step.
    for step in range(steady_count):
        is_final_step = step == steady_count - 1
        input_tensor = forward_and_backward_steps_with_communication(
            forward_step_func, data_iterator, model, optimizer,
            input_tensor, is_final_step,
            input_tensors, output_tensors,
            losses_reduced, timers)

    # Cooldown: drain the backward passes left over from warmup.
    for _ in range(warmup_count):
        backward_step_with_communication(
            optimizer, model, input_tensors, output_tensors, timers)

    return losses_reduced


def train_step(forward_step_func, data_iterator,
               model, optimizer, lr_scheduler):
    """Single training step.

    The step proceeds as follows:
        - Zero the gradients.
        - Run forward/backward over all microbatches of one minibatch, using
          the pipelined schedule when the pipeline world size is > 1.
        - All-reduce gradients as configured.
        - Optionally clip gradients.
        - Step the optimizer.
        - Step the lr scheduler unless the fp16 optimizer reported an
          overflow.

    Returns:
        (loss_reduced, skipped_iter):
            - loss_reduced: on the last pipeline stage, a dict mapping each
              reported loss key to its average over the microbatches;
              `{}` on other stages.
            - skipped_iter: 1 when the lr scheduler step was skipped because
              of an fp16 overflow, else 0.
    """
    args = get_args()
    timers = get_timers()

    # Set grad to zero.
    if args.fp16:
        optimizer.zero_grad(set_grads_to_None=True)
    else:
        optimizer.zero_grad()

    if mpu.get_pipeline_model_parallel_world_size() > 1:
        losses_reduced = forward_backward_pipelining(
            forward_step_func, data_iterator, model, optimizer, timers)
    else:
        losses_reduced = forward_backward_no_pipelining(
            forward_step_func, data_iterator, model, optimizer, timers)

    # All-reduce if needed.
    if args.DDP_impl == 'local':
        timers('backward-params-all-reduce').start()
        model.allreduce_params(reduce_after=False,
                               fp32_allreduce=args.fp32_allreduce)
        timers('backward-params-all-reduce').stop()

    # All-reduce word_embeddings' grad across first and last stages to ensure
    # that word_embeddings parameters stay in sync.
    # This should only run for models that support pipelined model parallelism
    # (BERT and GPT-2).
    timers('backward-embedding-all-reduce').start()
    if (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()) and \
            mpu.get_pipeline_model_parallel_world_size() > 1:
        unwrapped_model = model
        while isinstance(unwrapped_model, (torchDDP, LocalDDP, FP16_Module)):
            unwrapped_model = unwrapped_model.module

        word_embeddings_weight = unwrapped_model.word_embeddings_weight()
        torch.distributed.all_reduce(word_embeddings_weight.grad,
                                     group=mpu.get_embedding_group())
    timers('backward-embedding-all-reduce').stop()

    # Update master gradients.
    timers('backward-master-grad').start()
    if args.fp16:
        optimizer.update_master_grads()
    timers('backward-master-grad').stop()

    # Clipping gradients helps prevent the exploding gradient.
    timers('backward-clip-grad').start()
    if args.clip_grad > 0.:
        if not args.fp16:
            parameters = []
            parameter_names = []
            for parameter_name, parameter in model.named_parameters():
                parameters.append(parameter)
                parameter_names.append(parameter_name)
            mpu.clip_grad_norm(parameters, args.clip_grad,
                               parameter_names=parameter_names)
        else:
            optimizer.clip_master_grads(args.clip_grad)
    timers('backward-clip-grad').stop()

    # Update parameters.
    timers('optimizer').start()
    optimizer.step()
    timers('optimizer').stop()

    # Update learning rate; an fp16 overflow skips the scheduler step.
    skipped_iter = 0
    if not (args.fp16 and optimizer.overflow):
        lr_scheduler.step()
    else:
        skipped_iter = 1

    if not mpu.is_pipeline_last_stage():
        return {}, skipped_iter

    # Average each reported loss across the microbatches of this minibatch,
    # so the logged value does not scale with the microbatch count.
    num_microbatches = len(losses_reduced)
    loss_reduced = {}
    for key in losses_reduced[0]:
        loss_reduced[key] = \
            sum(x[key] for x in losses_reduced) / num_microbatches
    return loss_reduced, skipped_iter
577
578


Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
579
def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                 loss_scale, report_memory_flag, skipped_iter):
    """Accumulate losses and emit training logs (tensorboard and stdout).

    `total_loss_dict` is updated in place:
        - For non-skipped iterations, each loss in `loss_dict` is added to its
          running sum.
        - The 'skipped iterations' counter is incremented by `skipped_iter`.
        - The 'got nan' counter is incremented when a skipped iteration
          reported an inf/nan loss.

    Every `args.log_interval` iterations, the averaged losses and counters
    are printed via print_rank_last, and the accumulators are reset.

    Returns:
        The updated `report_memory_flag`. It is cleared after the first
        memory report.
    """
    args = get_args()
    timers = get_timers()
    writer = get_tensorboard_writer()

    # Update losses.
    skipped_iters_key = 'skipped iterations'
    got_nan_key = 'got nan'
    total_loss_dict[skipped_iters_key] = total_loss_dict.get(
        skipped_iters_key, 0) + skipped_iter

    saw_nan = False
    for key, loss_value in loss_dict.items():
        if skipped_iter:
            scalar = loss_value.float().sum().item()
            # `scalar != scalar` is only true for NaN.
            if scalar in (float('inf'), -float('inf')) or scalar != scalar:
                saw_nan = True
        else:
            total_loss_dict[key] = total_loss_dict.get(
                key, torch.cuda.FloatTensor([0.0])) + loss_value

    total_loss_dict[got_nan_key] = total_loss_dict.get(
        got_nan_key, 0) + int(saw_nan)

    # Only timers that were actually started are logged, in this order.
    candidate_timers = (
        'forward-compute',
        'forward-recv',
        'forward-send',
        'forward-send-backward-recv',
        'backward-compute',
        'backward-recv',
        'backward-send',
        'backward-send-forward-recv',
        'backward-master-grad',
        'backward-params-all-reduce',
        'backward-embedding-all-reduce',
        'backward-clip-grad',
        'optimizer',
        'batch generator',
    )
    timers_to_log = [name for name in candidate_timers
                     if name in timers.timers]

    writes_tensorboard = writer and torch.distributed.get_rank() == 0

    # Tensorboard values.
    if writes_tensorboard:
        writer.add_scalar('learning_rate', learning_rate, iteration)
        for key in loss_dict:
            writer.add_scalar(key, loss_dict[key], iteration)
        if args.fp16:
            writer.add_scalar('loss_scale', loss_scale, iteration)
        normalizer = iteration % args.log_interval or args.log_interval
        timers.write(timers_to_log, writer, iteration,
                     normalizer=normalizer)

    if iteration % args.log_interval != 0:
        return report_memory_flag

    elapsed_time = timers('interval time').elapsed()
    if writes_tensorboard:
        writer.add_scalar('iteration_time',
                          elapsed_time / args.log_interval, iteration)

    pieces = [
        ' iteration {:8d}/{:8d} |'.format(iteration, args.train_iters),
        ' elapsed time per iteration (ms): {:.1f} |'.format(
            elapsed_time * 1000.0 / args.log_interval),
        ' learning rate: {:.3E} |'.format(learning_rate),
    ]

    # Average over the non-skipped iterations of this interval.
    counted_iters = max(
        1, args.log_interval - total_loss_dict[skipped_iters_key])
    for key in total_loss_dict:
        if key in (skipped_iters_key, got_nan_key):
            continue
        avg = total_loss_dict[key].item() / float(counted_iters)
        if avg > 0.0:
            pieces.append(' {}: {:.6E} |'.format(key, avg))
        total_loss_dict[key] = torch.cuda.FloatTensor([0.0])

    if args.fp16:
        pieces.append(' loss scale: {:.1f} |'.format(loss_scale))
    pieces.append(' number of skipped iterations: {:3d} |'.format(
        total_loss_dict[skipped_iters_key]))
    pieces.append(' number of nan iterations: {:3d} |'.format(
        total_loss_dict[got_nan_key]))
    total_loss_dict[skipped_iters_key] = 0
    total_loss_dict[got_nan_key] = 0
    print_rank_last(''.join(pieces))

    if report_memory_flag:
        report_memory('after {} iterations'.format(iteration))
        report_memory_flag = False
    timers.log(timers_to_log, normalizer=args.log_interval)

    return report_memory_flag


676
def train(forward_step_func, model, optimizer, lr_scheduler,
          train_data_iterator, valid_data_iterator):
    """Run the main training loop.

    Starting from ``args.iteration``, repeatedly calls ``train_step`` until
    ``args.train_iters`` iterations have been completed. After every step it
    logs progress and, on the configured intervals:

    - checks for ADLR autoresume termination,
    - saves a checkpoint,
    - runs validation, and
    - may terminate the whole process via ``sys.exit()``.

    Args:
        forward_step_func: forward-step callable handed to ``train_step`` and
            to periodic validation.
        model: the model being trained.
        optimizer: optimizer stepped inside ``train_step``. Its
            ``param_groups[0]['lr']`` is logged, and so is its ``loss_scale``
            when ``args.fp16`` is set.
        lr_scheduler: learning-rate scheduler handed to the step, the
            autoresume check and checkpointing.
        train_data_iterator: iterator consumed by ``train_step``.
        valid_data_iterator: iterator consumed by periodic validation.

    Returns:
        The iteration count reached when the loop finishes.
    """
    args = get_args()
    timers = get_timers()

    def _every(interval, step):
        # A falsy interval (None / 0) disables the corresponding action.
        return bool(interval) and step % interval == 0

    # Training mode enables dropout.
    model.train()

    # Running loss sums; training_log accumulates into and resets this dict.
    total_loss_dict = {}
    curr_iter = args.iteration

    timers('interval time').start()
    report_memory_flag = True
    while curr_iter < args.train_iters:
        loss_dict, skipped_iter = train_step(
            forward_step_func, train_data_iterator, model, optimizer,
            lr_scheduler)
        curr_iter += 1

        # Logging.
        loss_scale = optimizer.loss_scale if args.fp16 else None
        report_memory_flag = training_log(
            loss_dict, total_loss_dict, optimizer.param_groups[0]['lr'],
            curr_iter, loss_scale, report_memory_flag, skipped_iter)

        # Autoresume.
        if args.adlr_autoresume and \
                curr_iter % args.adlr_autoresume_interval == 0:
            check_adlr_autoresume_termination(curr_iter, model, optimizer,
                                              lr_scheduler)

        # Checkpointing.
        if args.save and _every(args.save_interval, curr_iter):
            save_checkpoint(curr_iter, model, optimizer, lr_scheduler)

        # Periodic validation.
        if _every(args.eval_interval, curr_iter) and args.do_valid:
            evaluate_and_print_results('iteration {}'.format(curr_iter),
                                       forward_step_func,
                                       valid_data_iterator, model,
                                       curr_iter, verbose=False)

        # Early exit: wait for every rank, then terminate the process.
        if _every(args.exit_interval, curr_iter):
            torch.distributed.barrier()
            time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            rank = torch.distributed.get_rank()
            print_rank_0('rank: {} | time: {} | exiting the program at '
                         'iteration {}'.format(rank, time_str, curr_iter))
            sys.exit()

    return curr_iter
738
739


Mohammad's avatar
Mohammad committed
740
def evaluate(forward_step_func, data_iterator, model, verbose=False):
    """Run ``args.eval_iters`` forward-only steps and average the losses.

    The model is put in eval mode for the duration and restored to train
    mode afterwards. All forward passes run under ``torch.no_grad()``.

    With pipeline parallelism, stages other than the first receive their
    input activations from the previous stage. Stages other than the last
    send their outputs to the next stage. Only the last stage unpacks the
    ``(_, loss_dict)`` result and accumulates losses.

    Args:
        forward_step_func: callable ``(data_iterator, model, input_tensor)``
            returning the stage output (a ``(_, loss_dict)`` pair on the
            last pipeline stage).
        data_iterator: iterator handed to ``forward_step_func``.
        model: the model to evaluate.
        verbose: if True, print progress every ``args.log_interval`` steps.

    Returns:
        Dict mapping each loss name to its sum divided by ``args.eval_iters``.
        The dict is empty on ranks that are not the last pipeline stage.
    """
    args = get_args()

    # Eval mode disables dropout.
    model.eval()

    summed_losses = {}
    with torch.no_grad():
        for step in range(1, args.eval_iters + 1):
            if verbose and step % args.log_interval == 0:
                print_rank_0('Evaluating iter {}/{}'.format(step,
                                                            args.eval_iters))

            # Receive activations from the previous pipeline stage, if any.
            input_tensor = None
            if not mpu.is_pipeline_first_stage():
                input_tensor, _ = communicate(
                    tensor_send_next=None,
                    tensor_send_prev=None,
                    recv_forward=True,
                    recv_backward=False)

            # Forward evaluation.
            output_tensor = forward_step_func(data_iterator, model,
                                              input_tensor)

            if not mpu.is_pipeline_last_stage():
                # Hand the activations to the next stage; no loss here.
                communicate(
                    tensor_send_next=output_tensor,
                    tensor_send_prev=None,
                    recv_forward=False,
                    recv_backward=False)
                continue

            _, loss_dict = output_tensor
            for key, loss in loss_dict.items():
                summed_losses[key] = summed_losses.get(key, 0.) + loss

    # Back to training mode.
    model.train()

    for key in summed_losses:
        summed_losses[key] /= args.eval_iters

    return summed_losses


def evaluate_and_print_results(prefix, forward_step_func,
                               data_iterator, model,
                               iteration, verbose=False):
    """Evaluate the model, print the losses and perplexities, and log them.

    Each loss in the dict returned by ``evaluate`` is reported with its
    value and its perplexity ``exp(min(20, loss))``. The exponent is capped
    at 20 to avoid overflow. The summary line is printed on the last rank,
    framed by dashed rules.

    Args:
        prefix: text identifying this evaluation in the printed summary.
        forward_step_func, data_iterator, model, verbose: forwarded to
            ``evaluate``.
        iteration: step index used for the tensorboard scalars.
    """
    writer = get_tensorboard_writer()

    total_loss_dict = evaluate(forward_step_func, data_iterator, model, verbose)

    pieces = [' validation loss at {} | '.format(prefix)]
    for key, loss in total_loss_dict.items():
        value = loss.item()
        ppl = math.exp(min(20, value))
        pieces.append('{} value: {:.6E} | '.format(key, value))
        pieces.append('{} PPL: {:.6E} | '.format(key, ppl))
        # NOTE(review): `evaluate` only fills the dict on the last pipeline
        # stage, so with pipeline parallelism rank 0 may have nothing to
        # write here. Confirm which rank owns the tensorboard writer.
        if writer and torch.distributed.get_rank() == 0:
            writer.add_scalar('{} value'.format(key), value, iteration)
            writer.add_scalar('{} ppl'.format(key), ppl, iteration)
    summary = ''.join(pieces)

    rule = '-' * (len(summary) + 1)
    print_rank_last(rule)
    print_rank_last(summary)
    print_rank_last(rule)
813
814


815
816
817
def build_train_valid_test_data_iterators(
        build_train_valid_test_datasets_provider):
    """Build the train/validation/test data iterators and set ``args.do_*``.

    Only rank 0 of each tensor model parallel group builds the datasets and
    data loaders. The resulting do_train / do_valid / do_test flags are then
    broadcast to the other ranks of that group and stored on ``args``. When
    resuming at ``args.iteration``, the batch samplers' ``start_iter`` is
    shifted forward.

    Args:
        build_train_valid_test_datasets_provider: callable that receives
            ``[train_samples, valid_samples, test_samples]`` (minimum
            dataset sizes) and returns a ``(train, valid, test)`` tuple of
            datasets. Presumably an entry may be None, in which case
            ``make_data_loader`` returns None. Verify against callers.

    Returns:
        ``(train_data_iterator, valid_data_iterator, test_data_iterator)``.
        An entry is None on ranks that did not build loaders, or when the
        corresponding loader is None.
    """
    args = get_args()

    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

    print_rank_0('> building train, validation, and test datasets ...')
    # Data loaders are built only on rank 0 of each tensor model parallel
    # group; the other ranks of the group receive the do_* flags below.
    if mpu.get_tensor_model_parallel_rank() == 0:

        # Samples consumed per training iteration across all data-parallel
        # replicas and micro-batches.
        data_parallel_size = mpu.get_data_parallel_world_size()
        global_batch_size = args.batch_size * data_parallel_size * \
            args.num_microbatches_in_minibatch

        # Number of train/valid/test samples.
        train_iters = args.train_iters
        if args.eval_interval:
            # One validation pass every eval_interval iterations. The extra
            # pass presumably covers a final validation after training.
            num_eval_passes = train_iters // args.eval_interval + 1
        else:
            # A falsy eval_interval disables periodic validation (train()
            # guards on it the same way), so budget a single pass instead of
            # dividing by zero / None.
            num_eval_passes = 1
        eval_iters = num_eval_passes * args.eval_iters
        test_iters = args.eval_iters
        train_val_test_num_samples = [train_iters * global_batch_size,
                                      eval_iters * global_batch_size,
                                      test_iters * global_batch_size]
        print_rank_0(' > datasets target sizes (minimum size):')
        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))

        # Build the datasets.
        train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
            train_val_test_num_samples)

        # Build dataloaders.
        train_dataloader = make_data_loader(train_ds)
        valid_dataloader = make_data_loader(valid_ds)
        test_dataloader = make_data_loader(test_ds)

        # Flags to know if we need to do training/validation/testing.
        do_train = train_dataloader is not None and args.train_iters > 0
        do_valid = valid_dataloader is not None and args.eval_iters > 0
        do_test = test_dataloader is not None and args.eval_iters > 0
        # Pack the flags into a CUDA tensor so they can be broadcast.
        flags = torch.cuda.LongTensor(
            [int(do_train), int(do_valid), int(do_test)])
    else:
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Broadcast the do_* flags from the group's source rank so every rank of
    # the tensor model parallel group agrees on what to run.
    torch.distributed.broadcast(flags,
                                mpu.get_tensor_model_parallel_src_rank(),
                                group=mpu.get_tensor_model_parallel_group())
    args.do_train = flags[0].item()
    args.do_valid = flags[1].item()
    args.do_test = flags[2].item()

    # Shift the start iterations. Presumably this lets a resumed run skip
    # batches already consumed before args.iteration.
    if train_dataloader is not None:
        train_dataloader.batch_sampler.start_iter = args.iteration % \
            len(train_dataloader)
        print_rank_0('setting training data start iteration to {}'.
                     format(train_dataloader.batch_sampler.start_iter))
    if valid_dataloader is not None:
        # Validation iterations used by the periodic evaluations run so far
        # (none when periodic validation is disabled).
        if args.eval_interval:
            start_iter_val = (args.iteration // args.eval_interval) * \
                args.eval_iters
        else:
            start_iter_val = 0
        valid_dataloader.batch_sampler.start_iter = start_iter_val % \
            len(valid_dataloader)
        print_rank_0('setting validation data start iteration to {}'.
                     format(valid_dataloader.batch_sampler.start_iter))

    # Build iterators.
    train_data_iterator = iter(train_dataloader) \
        if train_dataloader is not None else None
    valid_data_iterator = iter(valid_dataloader) \
        if valid_dataloader is not None else None
    test_data_iterator = iter(test_dataloader) \
        if test_dataloader is not None else None

    return train_data_iterator, valid_data_iterator, test_data_iterator