training.py 35.3 KB
Newer Older
1
# coding=utf-8
Mohammad's avatar
Mohammad committed
2
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
3
4
5
6
7
8
9
10
11
12
13
14
15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Mohammad's avatar
Mohammad committed
16
"""Pretrain utilities."""
17
18
19

from datetime import datetime
import math
Mohammad's avatar
Mohammad committed
20
import sys
21
22
23
24
import time
# The earliest we can measure the start time.
_TRAIN_START_TIME = time.time()

25
26
27
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

Neel Kant's avatar
Neel Kant committed
28
from megatron import get_args
Mohammad's avatar
Mohammad committed
29
30
from megatron import get_timers
from megatron import get_tensorboard_writer
31
from megatron import get_current_global_batch_size
mohammad's avatar
mohammad committed
32
from megatron import get_num_microbatches
mohammad's avatar
mohammad committed
33
from megatron import is_last_rank
mohammad's avatar
mohammad committed
34
from megatron import update_num_microbatches
35
from megatron import mpu
Neel Kant's avatar
Neel Kant committed
36
from megatron import print_rank_0
37
from megatron import print_rank_last
Mohammad's avatar
Mohammad committed
38
39
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
40
from megatron.model import Float16Module
mohammad's avatar
mohammad committed
41
from megatron.optimizer import get_megatron_optimizer
Mohammad's avatar
Mohammad committed
42
from megatron.initialize import initialize_megatron
43
from megatron.initialize import write_args_to_tensorboard
44
45
46
from megatron.learning_rates import AnnealingLR
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.utils import check_adlr_autoresume_termination
47
from megatron.utils import unwrap_model
Vijay Korthikanti's avatar
Vijay Korthikanti committed
48
from megatron.data.data_samplers import build_pretraining_data_loader
mohammad's avatar
mohammad committed
49
from megatron.utils import calc_params_l2_norm
50
from megatron.schedules import forward_backward_no_pipelining
51
from megatron.schedules import forward_backward_pipelining_without_interleaving
52
from megatron.schedules import forward_backward_pipelining_with_interleaving
53
from megatron.utils import report_memory
54
55


Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
56

57
58
59
60
61
62
63
def print_datetime(string):
    """Print ``string`` tagged with the current local date and time.

    A ``torch.distributed.barrier()`` is issued first, so every rank must
    call this function; the message itself goes through ``print_rank_0``.
    """
    torch.distributed.barrier()
    stamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    tag = '[' + string + ']'
    print_rank_0(tag + ' datetime: ' + stamp + ' ')


64
65
66
67
def pretrain(train_valid_test_dataset_provider,
             model_provider,
             forward_step_func,
             extra_args_provider=None,
             args_defaults=None):
    """Main training program.

    This function runs the following steps in order:
        1) initialize Megatron.
        2) set up the model, optimizer and learning-rate schedule using
           ``model_provider``.
        3) call ``train_valid_test_dataset_provider`` to get the
           train/valid/test datasets.
        4) train the model using ``forward_step_func``.

    Arguments:
        train_valid_test_dataset_provider: a function that takes the size of
            the train/valid/test datasets and returns `train, valid, test`
            datasets.
        model_provider: a function that returns a vanilla version of the
            model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
        forward_step_func: a function that takes a `data iterator` and `model`,
            and returns a `loss` scalar with a dictionary with key:values being
            the info we would like to monitor during training, for example
            `lm-loss: value`. We also require that this function add
            `batch generator` to the timers class.
        extra_args_provider: a function that takes a parser and adds arguments
            to it. It is used for programs to add their own arguments.
        args_defaults: a dictionary from argument-name to argument-value. It
            is used to set already-parsed arguments. ``None`` (the default)
            means an empty dictionary.
    """
    # Use a fresh dict per call rather than a shared mutable default.
    if args_defaults is None:
        args_defaults = {}

    # Initialize and get arguments, timers, and Tensorboard writer.
    initialize_megatron(extra_args_provider=extra_args_provider,
                        args_defaults=args_defaults)

    # Adjust the startup time so it reflects the largest value: each rank
    # recorded _TRAIN_START_TIME at module import, and the MIN all-reduce
    # keeps the earliest of them, so the reported initialization time is the
    # longest one observed by any rank.
    global _TRAIN_START_TIME
    start_time_tensor = torch.cuda.FloatTensor([_TRAIN_START_TIME])
    torch.distributed.all_reduce(start_time_tensor,
                                 op=torch.distributed.ReduceOp.MIN)
    _TRAIN_START_TIME = start_time_tensor.item()
    print_rank_0('time to initialize megatron (seconds): {:.3f}'.format(
        time.time() - _TRAIN_START_TIME))
    print_datetime('after megatron is initialized')

    args = get_args()
    timers = get_timers()

    # Model, optimizer, and learning rate.
    timers('model-and-optimizer-setup').start()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
    timers('model-and-optimizer-setup').stop()
    print_datetime('after model, optimizer, and learning rate '
                   'scheduler are built')

    # Data stuff.
    timers('train/valid/test-data-iterators-setup').start()
    if args.virtual_pipeline_model_parallel_size is not None:
        # Build one (train, valid, test) iterator triple per model chunk.
        all_data_iterators = [
            build_train_valid_test_data_iterators(train_valid_test_dataset_provider)
            for _ in range(len(model))
        ]
        train_data_iterator = [data_iterators[0] for data_iterators in all_data_iterators]
        valid_data_iterator = [data_iterators[1] for data_iterators in all_data_iterators]
        test_data_iterator = [data_iterators[2] for data_iterators in all_data_iterators]
    else:
        train_data_iterator, valid_data_iterator, test_data_iterator \
            = build_train_valid_test_data_iterators(
                train_valid_test_dataset_provider)
    timers('train/valid/test-data-iterators-setup').stop()
    print_datetime('after dataloaders are built')

    # Print setup timing.
    print_rank_0('done with setup ...')
    timers.log(['model-and-optimizer-setup', 'train/valid/test-data-iterators-setup'])
    print_rank_0('training ...')

    iteration = 0
    if args.do_train and args.train_iters > 0:
        iteration = train(forward_step_func,
                          model, optimizer, lr_scheduler,
                          train_data_iterator, valid_data_iterator)
    print_datetime('after training is done')

    if args.do_valid:
        prefix = 'the end of training for val data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   valid_data_iterator, model,
                                   iteration, False)

    # iteration stays 0 when training is skipped above, so nothing is saved then.
    if args.save and iteration != 0:
        save_checkpoint(iteration, model, optimizer, lr_scheduler)

    if args.do_test:
        # Run on test data.
        prefix = 'the end of training for test data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   test_data_iterator, model,
                                   0, True)
163

164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
def update_train_iters(args):
    """Derive ``args.train_iters`` for sample-based training.

    Nothing happens when ``args.train_iters`` is already set (iteration-based
    training). Otherwise the iteration count is computed from
    ``args.train_samples``:

    * constant batch size (``args.rampup_batch_size is None``): plain floor
      division by ``args.global_batch_size``;
    * batch-size ramp-up: the ramp-up phase is simulated one global batch at
      a time through the microbatch calculator, which is then reset to 0
      consumed samples, and the remaining samples are divided by
      ``args.global_batch_size``.

    In both cases a trailing partial batch is discarded.
    """
    if args.train_iters:
        return

    if args.rampup_batch_size is not None:
        rampup_samples = int(args.rampup_batch_size[2])
        steps = 0
        samples_seen = 0
        # Ramp-up phase: advance one (growing) global batch per step.
        while samples_seen <= rampup_samples:
            update_num_microbatches(samples_seen, consistency_check=False)
            samples_seen += get_current_global_batch_size()
            steps += 1
        # Put the microbatch calculator back to its starting point.
        update_num_microbatches(0, consistency_check=False)
        # Constant phase; any partial last batch is thrown away.
        steps += (args.train_samples - samples_seen) // args.global_batch_size
        args.train_iters = steps
    else:
        args.train_iters = args.train_samples // args.global_batch_size

    print_rank_0('setting training iterations to {}'.format(args.train_iters))

193

Mohammad's avatar
Mohammad committed
194
def get_model(model_provider_func):
    """Build the model and wrap it for mixed precision and data parallelism.

    Arguments:
        model_provider_func: callable returning either a single module or a
            list of modules (model chunks), built on the CPU.

    Returns:
        A list of modules that have been moved to the current CUDA device,
        wrapped in ``Float16Module`` when fp16 or bf16 is enabled, and then
        wrapped in the DDP implementation selected by ``args.DDP_impl``.

    Raises:
        NotImplementedError: if ``args.DDP_impl`` is neither 'torch' nor
            'local'.
    """
    args = get_args()

    # Build on the CPU and normalize to a list of model chunks.
    built = model_provider_func()
    chunks = built if isinstance(built, list) else [built]

    # Only parameters that are already tensor-model-parallel carry the
    # tensor-model-parallel attributes; give every other parameter the
    # defaults so the optimizer can rely on them being present.
    for chunk in chunks:
        for param in chunk.parameters():
            mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    # Report the parameter count of this (tensor, pipeline) rank, once per
    # data-parallel group.
    if mpu.get_data_parallel_rank() == 0:
        n_params = sum(p.nelement()
                       for chunk in chunks
                       for p in chunk.parameters())
        print(' > number of parameters on (tensor, pipeline) '
              'model parallel rank ({}, {}): {}'.format(
                  mpu.get_tensor_model_parallel_rank(),
                  mpu.get_pipeline_model_parallel_rank(),
                  n_params), flush=True)

    # Move every chunk to the current GPU.
    for chunk in chunks:
        chunk.cuda(torch.cuda.current_device())

    # Mixed precision wrapping.
    if args.fp16 or args.bf16:
        chunks = [Float16Module(chunk, args) for chunk in chunks]
        # The layer norm does not support fp32 input with bf16 output, so for
        # bf16 with fp32 residual connections its parameters are moved to
        # fp32; the layer-norm output is then cast back to bf16.
        # NOTE(review): that cast happens inside the class returned by
        # import_layernorm, which is not visible here.
        if args.bf16 and args.fp32_residual_connection:
            from megatron.model import import_layernorm
            layer_norm_cls = import_layernorm(args.fp32_residual_connection,
                                              args.bf16)
            for chunk in chunks:
                for submodule in chunk.modules():
                    if isinstance(submodule, layer_norm_cls):
                        submodule.float()

    # Data-parallel wrapping.
    if args.DDP_impl == 'torch':
        device = torch.cuda.current_device()
        return [torchDDP(chunk, device_ids=[device], output_device=device,
                         process_group=mpu.get_data_parallel_group())
                for chunk in chunks]

    if args.DDP_impl == 'local':
        return [LocalDDP(chunk,
                         args.accumulate_allreduce_grads_in_fp32,
                         args.use_contiguous_buffers_in_ddp)
                for chunk in chunks]

    raise NotImplementedError('Unknown DDP implementation specified: {}. '
                              'Exiting.'.format(args.DDP_impl))
254
255


Mohammad's avatar
Mohammad committed
256
def get_learning_rate_scheduler(optimizer):
    """Build the ``AnnealingLR`` scheduler for ``optimizer``.

    Warmup and decay lengths are expressed in samples.

    * Iteration-based training (``args.train_iters`` set): iteration counts
      are multiplied by ``args.global_batch_size``.
    * Sample-based training (``args.train_samples`` set):
      ``args.train_iters`` is derived first via ``update_train_iters``.

    ``args.lr_warmup_fraction``, when given, takes precedence over the
    explicit warmup length. A missing ``args.lr_decay_iters`` /
    ``args.lr_decay_samples`` defaults to the full training length and is
    written back to ``args``.

    Raises:
        Exception: if neither ``args.train_iters`` nor ``args.train_samples``
            is provided.
    """
    args = get_args()

    if not (args.train_iters or args.train_samples):
        raise Exception(
            'either train-iters or train-samples should be provided.')

    if args.train_iters:
        # Iteration-based training: convert iterations to samples.
        if args.lr_decay_iters is None:
            args.lr_decay_iters = args.train_iters
        decay_steps = args.lr_decay_iters * args.global_batch_size
        if args.lr_warmup_fraction is None:
            warmup_steps = args.lr_warmup_iters * args.global_batch_size
        else:
            warmup_steps = args.lr_warmup_fraction * decay_steps
    else:
        # Sample-based training. args.train_iters is needed later, so derive
        # it now; the sample count itself is deliberately left unadjusted for
        # an incomplete last batch.
        update_train_iters(args)
        if args.lr_decay_samples is None:
            args.lr_decay_samples = args.train_samples
        decay_steps = args.lr_decay_samples
        if args.lr_warmup_fraction is None:
            warmup_steps = args.lr_warmup_samples
        else:
            warmup_steps = args.lr_warmup_fraction * decay_steps

    return AnnealingLR(
        optimizer,
        max_lr=args.lr,
        min_lr=args.min_lr,
        warmup_steps=warmup_steps,
        decay_steps=decay_steps,
        decay_style=args.lr_decay_style,
        use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
        override_lr_scheduler=args.override_lr_scheduler)


Mohammad's avatar
Mohammad committed
299
def setup_model_and_optimizer(model_provider_func):
    """Build the model, optimizer and learning-rate scheduler.

    If ``args.load`` is set, model/optimizer/scheduler state is restored with
    ``load_checkpoint`` and ``args.iteration`` is set to the value it
    returns; otherwise ``args.iteration`` is set to 0.

    Returns:
        ``(model, optimizer, lr_scheduler)``, where ``model`` is the list of
        wrapped model chunks returned by ``get_model``.
    """
    args = get_args()

    model = get_model(model_provider_func)

    # Get the model without FP16 and/or DDP wrappers; the optimizer is built
    # over these bare modules.
    unwrapped_model = unwrap_model(model,
                                   (torchDDP, LocalDDP, Float16Module))
    optimizer = get_megatron_optimizer(unwrapped_model)

    lr_scheduler = get_learning_rate_scheduler(optimizer)

    if args.load is not None:
        timers = get_timers()
        # Extra barrier is added to make sure all ranks report the
        # max time.
        torch.distributed.barrier()
        timers('load-checkpoint').start()
        args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
        torch.distributed.barrier()
        timers('load-checkpoint').stop()
        timers.log(['load-checkpoint'])
    else:
        args.iteration = 0

    # We only support local DDP with multiple micro-batches.
    if len(model) > 1 or mpu.get_pipeline_model_parallel_world_size() > 1:
        assert args.DDP_impl == 'local'

    # Fresh (non-resumed) run of a model that supports it: initialize its
    # weights from a pretrained BERT.
    if args.iteration == 0 and len(unwrapped_model) == 1 \
        and hasattr(unwrapped_model[0], 'init_state_dict_from_bert'):
        print_rank_0("Initializing ICT from pretrained BERT model")
        unwrapped_model[0].init_state_dict_from_bert()
        # The model weights were just overwritten, so the optimizer's view of
        # them is refreshed. get_model wraps the model in Float16Module for
        # fp16 *and* bf16, so both take the same refresh path.
        # NOTE(review): assumes get_megatron_optimizer builds the same
        # mixed-precision optimizer (with reload_model_params) for bf16 as
        # for fp16 -- confirm.
        if args.fp16 or args.bf16:
            optimizer.reload_model_params()

    return model, optimizer, lr_scheduler


339
340
341
342
343
344
345
def train_step(forward_step_func, data_iterator,
               model, optimizer, lr_scheduler):
    """Run one training iteration.

    The iteration covers gradient reset, forward/backward over all
    microbatches, gradient synchronization, optimizer step and LR update.

    Returns:
        ``(loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad)``.

        * ``loss_reduced`` maps each loss key to its mean over the
          microbatches on the last pipeline stage, and is ``{}`` on every
          other stage.
        * ``skipped_iter`` is 1 when ``optimizer.step()`` reports an
          unsuccessful update (the LR scheduler is then not advanced) and 0
          otherwise.
    """
    args = get_args()
    timers = get_timers()

    # Reset gradients: the local DDP with contiguous buffers owns its own
    # gradient buffer; otherwise the optimizer clears them.
    if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_ddp:
        for partition in model:
            partition.zero_grad_buffer()
    else:
        optimizer.zero_grad()

    # Choose the forward/backward schedule.
    if mpu.get_pipeline_model_parallel_world_size() <= 1:
        forward_backward_func = forward_backward_no_pipelining
    elif args.virtual_pipeline_model_parallel_size is None:
        forward_backward_func = forward_backward_pipelining_without_interleaving
    else:
        forward_backward_func = forward_backward_pipelining_with_interleaving
        assert get_num_microbatches() % args.pipeline_model_parallel_size == 0, \
            'number of microbatches is not divisible by pipeline-parallel ' \
            'size when using interleaved schedule'

    losses_reduced = forward_backward_func(
        forward_step_func, data_iterator, model,
        optimizer, timers, forward_only=False)

    # Data-parallel gradient all-reduce for the local DDP implementation.
    if args.DDP_impl == 'local':
        timers('backward-params-all-reduce').start()
        for model_module in model:
            model_module.allreduce_gradients()
        timers('backward-params-all-reduce').stop()

    # All-reduce the word-embedding gradient across the first and last
    # pipeline stages so the shared embedding weights stay in sync. This
    # only applies to models that support pipelined model parallelism
    # (BERT and GPT-2).
    timers('backward-embedding-all-reduce').start()
    is_first = mpu.is_pipeline_first_stage(ignore_virtual=True)
    is_last = mpu.is_pipeline_last_stage(ignore_virtual=True)
    if (is_first or is_last) and \
            mpu.get_pipeline_model_parallel_world_size() > 1:
        stage_module = model[0] if is_first else model[-1]
        stage_module = unwrap_model(
            stage_module, (torchDDP, LocalDDP, Float16Module))
        if stage_module.share_word_embeddings:
            word_embeddings_weight = stage_module.word_embeddings_weight()
            grad = (word_embeddings_weight.main_grad
                    if args.DDP_impl == 'local'
                    else word_embeddings_weight.grad)
            torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
    timers('backward-embedding-all-reduce').stop()

    # Update parameters.
    timers('optimizer').start()
    update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
    timers('optimizer').stop()

    # Advance the learning rate by the number of samples consumed, but only
    # when the update actually happened.
    skipped_iter = 0 if update_successful else 1
    if update_successful:
        increment = (get_num_microbatches()
                     * args.micro_batch_size
                     * args.data_parallel_size)
        lr_scheduler.step(increment=increment)

    if not mpu.is_pipeline_last_stage(ignore_virtual=True):
        return {}, skipped_iter, grad_norm, num_zeros_in_grad

    # Average each loss over the microbatches.
    loss_reduced = {
        key: sum(microbatch[key] for microbatch in losses_reduced)
        / len(losses_reduced)
        for key in losses_reduced[0]
    }
    return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad
420
421


Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
422
def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                 loss_scale, report_memory_flag, skipped_iter,
                 grad_norm, params_norm, num_zeros_in_grad):
    """Log training information such as losses, timing, etc.

    ``total_loss_dict`` is a caller-owned accumulator that is updated in
    place on every call:

    * it sums the per-key losses of non-skipped iterations;
    * it counts advanced, skipped and nan iterations.

    Every ``args.log_interval`` iterations a summary line with the averaged
    values is emitted through ``print_rank_last`` and the accumulator is
    reset. Tensorboard scalars are written only when ``is_last_rank()`` is
    true; most are written every ``args.tensorboard_log_interval``
    iterations, and iteration-time is written every ``args.log_interval``.

    Returns:
        The (possibly cleared) ``report_memory_flag``. Memory is reported
        once, on the first summary line with a positive learning rate.
    """
    args = get_args()
    timers = get_timers()
    writer = get_tensorboard_writer()

    # Advanced, skipped, and nan iteration counters.
    advanced_iters_key = 'advanced iterations'
    skipped_iters_key = 'skipped iterations'
    nan_iters_key = 'nan iterations'
    # Advanced iterations (the counter is created even on a skipped step).
    if not skipped_iter:
        total_loss_dict[advanced_iters_key] = total_loss_dict.get(
            advanced_iters_key, 0) + 1
    else:
        total_loss_dict.setdefault(advanced_iters_key, 0)
    # Skipped iterations.
    total_loss_dict[skipped_iters_key] = total_loss_dict.get(
        skipped_iters_key, 0) + skipped_iter
    # Accumulate losses of successful iterations; for skipped iterations
    # only check whether any loss is inf, -inf or nan.
    got_nan = False
    for key in loss_dict:
        if not skipped_iter:
            total_loss_dict[key] = total_loss_dict.get(
                key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]
        else:
            value = loss_dict[key].float().sum().item()
            got_nan = got_nan or not math.isfinite(value)
    total_loss_dict[nan_iters_key] = total_loss_dict.get(
        nan_iters_key, 0) + int(got_nan)

    # Timers to report, in this order, if they have been created.
    timers_to_log = [
        name for name in (
            'forward-compute',
            'forward-recv',
            'forward-send',
            'forward-backward-send-forward-backward-recv',
            'backward-compute',
            'backward-recv',
            'backward-send',
            'backward-send-forward-recv',
            'backward-send-backward-recv',
            'backward-params-all-reduce',
            'backward-embedding-all-reduce',
            'optimizer-copy-to-main-grad',
            'optimizer-unscale-and-check-inf',
            'optimizer-clip-main-grad',
            'optimizer-copy-main-to-model-params',
            'optimizer',
            'batch-generator',
        )
        if name in timers.timers
    ]

    # Calculate batch size.
    batch_size = args.micro_batch_size * args.data_parallel_size * \
        get_num_microbatches()

    total_iterations = total_loss_dict[advanced_iters_key] + \
                       total_loss_dict[skipped_iters_key]

    # Tensorboard values.
    if writer and (iteration % args.tensorboard_log_interval == 0) and \
       is_last_rank():
        if args.log_learning_rate_to_tensorboard:
            writer.add_scalar('learning-rate', learning_rate, iteration)
            writer.add_scalar('learning-rate vs samples', learning_rate,
                              args.consumed_train_samples)
        if args.log_batch_size_to_tensorboard:
            writer.add_scalar('batch-size', batch_size, iteration)
            writer.add_scalar('batch-size vs samples', batch_size,
                              args.consumed_train_samples)
        for key in loss_dict:
            writer.add_scalar(key, loss_dict[key], iteration)
            writer.add_scalar(key + ' vs samples', loss_dict[key],
                              args.consumed_train_samples)
        if args.log_loss_scale_to_tensorboard:
            writer.add_scalar('loss-scale', loss_scale, iteration)
            writer.add_scalar('loss-scale vs samples', loss_scale,
                              args.consumed_train_samples)
        if grad_norm is not None:
            writer.add_scalar('grad-norm', grad_norm, iteration)
            writer.add_scalar('grad-norm vs samples', grad_norm,
                              args.consumed_train_samples)
        if num_zeros_in_grad is not None:
            writer.add_scalar('num-zeros', num_zeros_in_grad, iteration)
            writer.add_scalar('num-zeros vs samples', num_zeros_in_grad,
                              args.consumed_train_samples)
        if params_norm is not None:
            writer.add_scalar('params-norm', params_norm, iteration)
            writer.add_scalar('params-norm vs samples', params_norm,
                              args.consumed_train_samples)
        if args.log_timers_to_tensorboard:
            timers.write(timers_to_log, writer, iteration,
                         normalizer=total_iterations)

    if iteration % args.log_interval == 0:
        elapsed_time = timers('interval-time').elapsed()
        elapsed_time_per_iteration = elapsed_time / total_iterations
        # Written from the last rank, like every other tensorboard scalar
        # in this function.
        if writer and is_last_rank():
            if args.log_timers_to_tensorboard:
                writer.add_scalar('iteration-time',
                                  elapsed_time_per_iteration, iteration)
        log_string = ' iteration {:8d}/{:8d} |'.format(
            iteration, args.train_iters)
        log_string += ' consumed samples: {:12d} |'.format(
            args.consumed_train_samples)
        log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
            elapsed_time_per_iteration * 1000.0)
        log_string += ' learning rate: {:.3E} |'.format(learning_rate)
        log_string += ' global batch size: {:5d} |'.format(batch_size)
        for key in total_loss_dict:
            if key not in [advanced_iters_key, skipped_iters_key,
                           nan_iters_key]:
                avg = total_loss_dict[key].item() / \
                      float(max(1, total_loss_dict[advanced_iters_key]))
                if avg > 0.0:
                    log_string += ' {}: {:.6E} |'.format(key, avg)
                total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
        log_string += ' loss scale: {:.1f} |'.format(loss_scale)
        if grad_norm is not None:
            log_string += ' grad norm: {:.3f} |'.format(grad_norm)
        if num_zeros_in_grad is not None:
            log_string += ' num zeros: {:.1f} |'.format(num_zeros_in_grad)
        if params_norm is not None:
            log_string += ' params norm: {:.3f} |'.format(params_norm)
        log_string += ' number of skipped iterations: {:3d} |'.format(
            total_loss_dict[skipped_iters_key])
        log_string += ' number of nan iterations: {:3d} |'.format(
            total_loss_dict[nan_iters_key])
        total_loss_dict[advanced_iters_key] = 0
        total_loss_dict[skipped_iters_key] = 0
        total_loss_dict[nan_iters_key] = 0
        print_rank_last(log_string)
        if report_memory_flag and learning_rate > 0.:
            # Report memory after optimizer state has been initialized.
            report_memory('(after {} iterations)'.format(iteration))
            report_memory_flag = False
        timers.log(timers_to_log, normalizer=args.log_interval)

    return report_memory_flag


572
573
574
575
576
def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):
    """Save a checkpoint at ``iteration`` and log how long it took.

    The save is bracketed by ``torch.distributed.barrier()`` calls, so every
    rank waits for the slowest one and all ranks report the max time under
    the 'save-checkpoint' timer.
    """
    timer_name = 'save-checkpoint'
    timers = get_timers()
    torch.distributed.barrier()
    timers(timer_name).start()
    save_checkpoint(iteration, model, optimizer, lr_scheduler)
    torch.distributed.barrier()
    timers(timer_name).stop()
    timers.log([timer_name])
582
583


584
def train(forward_step_func, model, optimizer, lr_scheduler,
          train_data_iterator, valid_data_iterator):
    """Run the main training loop.

    Starting from ``args.iteration``, repeatedly calls ``train_step`` until
    ``args.train_iters`` is reached. After every step it logs progress and,
    at their configured intervals, runs ADLR autoresume checks, validation
    and checkpointing. The process may be terminated early (via
    ``sys.exit``) when the wall-clock budget ``args.exit_duration_in_mins``
    or the iteration interval ``args.exit_interval`` is hit; a checkpoint is
    saved first unless one was already saved on that iteration.

    Args:
        forward_step_func: forward function passed through to ``train_step``
            and to validation.
        model: iterable of model modules; each is put in train mode.
        optimizer: optimizer; also queried for the loss scale and the
            learning rate of its first param group for logging.
        lr_scheduler: learning-rate scheduler passed to ``train_step`` and
            saved with checkpoints.
        train_data_iterator: iterator over training batches.
        valid_data_iterator: iterator over validation batches.

    Returns:
        The iteration number at which the loop stopped.
    """
    args = get_args()
    timers = get_timers()

    # Write args to tensorboard
    write_args_to_tensorboard()

    # Turn on training mode which enables dropout.
    for model_module in model:
        model_module.train()

    # Tracking loss.
    total_loss_dict = {}

    # Iterations.
    iteration = args.iteration

    timers('interval-time').start()
    print_datetime('before the start of training step')
    # training_log returns the updated flag; it is set False once memory
    # has been reported.
    report_memory_flag = True
    while iteration < args.train_iters:
        # The number of micro-batches may change with the consumed samples.
        update_num_microbatches(args.consumed_train_samples)
        loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \
            train_step(forward_step_func,
                       train_data_iterator,
                       model,
                       optimizer,
                       lr_scheduler)
        iteration += 1
        # Samples consumed this step, summed over all data-parallel ranks.
        args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
                                       args.micro_batch_size * \
                                       get_num_microbatches()

        # Logging.
        loss_scale = optimizer.get_loss_scale().item()
        params_norm = None
        if args.log_params_norm:
            params_norm = calc_params_l2_norm(model)
        report_memory_flag = training_log(loss_dict, total_loss_dict,
                                          optimizer.param_groups[0]['lr'],
                                          iteration, loss_scale,
                                          report_memory_flag, skipped_iter,
                                          grad_norm, params_norm, num_zeros_in_grad)

        # Autoresume
        if args.adlr_autoresume and \
           (iteration % args.adlr_autoresume_interval == 0):
            check_adlr_autoresume_termination(iteration, model, optimizer,
                                              lr_scheduler)

        # Evaluation
        if args.eval_interval and iteration % args.eval_interval == 0 and \
           args.do_valid:
            prefix = 'iteration {}'.format(iteration)
            evaluate_and_print_results(prefix, forward_step_func,
                                       valid_data_iterator, model,
                                       iteration, False)

        # Checkpointing
        # saved_checkpoint lets the exit paths below skip a redundant save
        # when a checkpoint was already written on this iteration.
        saved_checkpoint = False
        if args.save and args.save_interval and \
           iteration % args.save_interval == 0:
            save_checkpoint_and_time(iteration, model, optimizer,
                                     lr_scheduler)
            saved_checkpoint = True

        # Exiting based on duration
        if args.exit_duration_in_mins:
            # Minutes elapsed since this module was imported.
            train_time = (time.time() - _TRAIN_START_TIME) / 60.0
            done_cuda = torch.cuda.IntTensor(
                [train_time > args.exit_duration_in_mins])
            # MAX-reduce so that if any rank is over budget, every rank sees
            # done == 1 and they all exit on the same iteration.
            torch.distributed.all_reduce(
                done_cuda, op=torch.distributed.ReduceOp.MAX)
            done = done_cuda.item()
            if done:
                if not saved_checkpoint:
                    save_checkpoint_and_time(iteration, model, optimizer,
                                             lr_scheduler)
                print_datetime('exiting program after {} minutes'.format(train_time))
                sys.exit()

        # Exiting based on iterations
        if args.exit_interval and iteration % args.exit_interval == 0:
            if not saved_checkpoint:
                save_checkpoint_and_time(iteration, model, optimizer,
                                         lr_scheduler)
            torch.distributed.barrier()
            print_datetime('exiting program at iteration {}'.format(iteration))
            sys.exit()

    return iteration
678
679


Mohammad's avatar
Mohammad committed
680
def evaluate(forward_step_func, data_iterator, model, verbose=False):
    """Evaluate the model on ``args.eval_iters`` batches.

    Every model module is switched to eval mode for the duration of the call
    and is always switched back to train mode afterwards, even if evaluation
    raises. Each iteration also advances ``args.consumed_valid_samples``.

    Args:
        forward_step_func: per-micro-batch forward function, passed to the
            selected forward/backward schedule.
        data_iterator: iterator yielding evaluation batches.
        model: iterable of model modules.
        verbose: if True, print progress every ``args.log_interval``
            iterations.

    Returns:
        dict mapping each loss key to its accumulated value divided by
        ``args.eval_iters * get_num_microbatches()``. Losses are only
        accumulated on the last pipeline stage; other stages return an
        empty dict.
    """
    args = get_args()

    # The schedule depends only on the parallel configuration, so select it
    # once instead of on every evaluation iteration.
    if mpu.get_pipeline_model_parallel_world_size() > 1:
        if args.virtual_pipeline_model_parallel_size is not None:
            forward_backward_func = forward_backward_pipelining_with_interleaving
        else:
            forward_backward_func = forward_backward_pipelining_without_interleaving
    else:
        forward_backward_func = forward_backward_no_pipelining

    # Turn on evaluation mode which disables dropout.
    for model_module in model:
        model_module.eval()

    total_loss_dict = {}

    try:
        with torch.no_grad():
            for iteration in range(1, args.eval_iters + 1):
                if verbose and iteration % args.log_interval == 0:
                    print_rank_0('Evaluating iter {}/{}'.format(iteration,
                                                                args.eval_iters))

                loss_dicts = forward_backward_func(
                    forward_step_func, data_iterator, model, optimizer=None,
                    timers=None, forward_only=True)

                if mpu.is_pipeline_last_stage(ignore_virtual=True):
                    # Accumulate (locally) the losses of every micro-batch.
                    for loss_dict in loss_dicts:
                        for key in loss_dict:
                            total_loss_dict[key] = total_loss_dict.get(
                                key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]

                # Samples consumed this iteration, over all data-parallel ranks.
                args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
                                               * args.micro_batch_size \
                                               * get_num_microbatches()
    finally:
        # Move model back to the train mode, even if evaluation failed.
        for model_module in model:
            model_module.train()

    for key in total_loss_dict:
        total_loss_dict[key] /= args.eval_iters * get_num_microbatches()

    return total_loss_dict

def evaluate_and_print_results(prefix, forward_step_func,
                               data_iterator, model,
                               iteration, verbose=False):
    """Helper function to evaluate and dump results on screen.

    Runs ``evaluate`` and prints a framed summary line (via
    ``print_rank_last``) with each loss value and its perplexity. The
    perplexity is ``exp(min(20, loss))``, so it is capped at ``exp(20)``.
    When a TensorBoard writer exists and this is the last rank, the same
    values are also logged against ``iteration`` and against
    ``args.consumed_train_samples``. Perplexity is logged only if
    ``args.log_validation_ppl_to_tensorboard`` is set.
    """
    args = get_args()
    writer = get_tensorboard_writer()

    total_loss_dict = evaluate(forward_step_func, data_iterator, model, verbose)
    string = ' validation loss at {} | '.format(prefix)
    for key in total_loss_dict:
        # `.item()` copies the value to the host (a device sync for CUDA
        # tensors), so fetch it once per key instead of once per use.
        loss_value = total_loss_dict[key].item()
        ppl = math.exp(min(20, loss_value))
        string += '{} value: {:.6E} | '.format(key, loss_value)
        string += '{} PPL: {:.6E} | '.format(key, ppl)
        if writer and is_last_rank():
            writer.add_scalar('{} validation'.format(key),
                              loss_value,
                              iteration)
            writer.add_scalar('{} validation vs samples'.format(key),
                              loss_value,
                              args.consumed_train_samples)
            if args.log_validation_ppl_to_tensorboard:
                writer.add_scalar('{} validation ppl'.format(key), ppl,
                                  iteration)
                writer.add_scalar('{} validation ppl vs samples'.format(key),
                                  ppl, args.consumed_train_samples)

    length = len(string) + 1
    print_rank_last('-' * length)
    print_rank_last(string)
    print_rank_last('-' * length)
758
759


Vijay Korthikanti's avatar
Vijay Korthikanti committed
760
def cyclic_iter(iter):
    """Yield the items of ``iter`` repeatedly, starting a new pass each time it ends.

    Every pass re-iterates ``iter``. Nothing is cached, unlike
    ``itertools.cycle``, so a re-iterable object is traversed afresh on each
    pass. If a pass yields nothing, the generator stops instead of looping
    forever without ever yielding. That happens for an empty iterable, or a
    one-shot iterator that has already been exhausted.

    Note: the parameter name shadows the ``iter`` builtin. It is kept
    unchanged for compatibility with callers passing it by keyword.
    """
    while True:
        produced_any = False
        for x in iter:
            produced_any = True
            yield x
        if not produced_any:
            # An empty pass means further passes will also be empty;
            # stopping avoids a silent infinite busy-loop.
            return

765
766
767
def build_train_valid_test_data_iterators(
        build_train_valid_test_datasets_provider):
    """Build the train, validation and test data iterators.

    Datasets and dataloaders are built only on tensor-model-parallel rank 0.
    That rank computes do_train/do_valid/do_test flags and broadcasts them
    to the rest of its tensor-model-parallel group; every rank then stores
    them on ``args.do_train``, ``args.do_valid`` and ``args.do_test``.

    Args:
        build_train_valid_test_datasets_provider: callable that receives a
            list ``[train, valid, test]`` of target sample counts and
            returns the ``(train_ds, valid_ds, test_ds)`` datasets.

    Returns:
        ``(train_data_iterator, valid_data_iterator, test_data_iterator)``.
        An entry is ``None`` when its dataloader is ``None``, which is
        always the case on tensor-model-parallel ranks other than 0.
    """
    args = get_args()

    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

    print_rank_0('> building train, validation, and test datasets ...')

    # Backward compatibility, assume fixed batch size: when resuming at a
    # nonzero iteration with no recorded consumed-sample counts, reconstruct
    # them from the iteration number and the global batch size.
    if args.iteration > 0 and args.consumed_train_samples == 0:
        assert args.train_samples is None, \
            'only backward compatiblity support for iteration-based training'
        args.consumed_train_samples = args.iteration * args.global_batch_size
    if args.iteration > 0 and args.consumed_valid_samples == 0:
        assert args.train_samples is None, \
            'only backward compatiblity support for iteration-based training'
        # args.eval_iters batches per completed evaluation interval.
        args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
            args.eval_iters * args.global_batch_size

    # Data loader only on rank 0 of each tensor-model-parallel group.
    if mpu.get_tensor_model_parallel_rank() == 0:

        # Number of train/valid/test samples.
        if args.train_samples:
            train_samples = args.train_samples
        else:
            train_samples = args.train_iters * args.global_batch_size
        # Budget for (train_iters // eval_interval + 1) evaluations of
        # args.eval_iters batches each.
        eval_iters = (args.train_iters // args.eval_interval + 1) * \
                     args.eval_iters
        test_iters = args.eval_iters
        train_val_test_num_samples = [train_samples,
                                      eval_iters * args.global_batch_size,
                                      test_iters * args.global_batch_size]
        print_rank_0(' > datasets target sizes (minimum size):')
        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))

        # Build the datasets.
        train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
            train_val_test_num_samples)

        # Build dataloaders, passing the number of samples already consumed
        # for train/valid (0 for test).
        train_dataloader = build_pretraining_data_loader(
            train_ds, args.consumed_train_samples)
        valid_dataloader = build_pretraining_data_loader(
            valid_ds, args.consumed_valid_samples)
        test_dataloader = build_pretraining_data_loader(test_ds, 0)

        # Flags to know if we need to do training/validation/testing.
        do_train = train_dataloader is not None and args.train_iters > 0
        do_valid = valid_dataloader is not None and args.eval_iters > 0
        do_test = test_dataloader is not None and args.eval_iters > 0
        # Pack the flags into a CUDA tensor so they can be broadcast.
        flags = torch.cuda.LongTensor(
            [int(do_train), int(do_valid), int(do_test)])
    else:
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Broadcast the flags from the tensor-model-parallel source rank to the
    # rest of its group, so ranks without dataloaders agree on what to run.
    torch.distributed.broadcast(flags,
                                mpu.get_tensor_model_parallel_src_rank(),
                                group=mpu.get_tensor_model_parallel_group())
    args.do_train = flags[0].item()
    args.do_valid = flags[1].item()
    args.do_test = flags[2].item()

    # Build iterators.
    dl_type = args.dataloader_type
    assert dl_type in ['single', 'cyclic']

    train_data_iterator = _build_data_iterator(train_dataloader, dl_type)
    valid_data_iterator = _build_data_iterator(valid_dataloader, dl_type)
    test_data_iterator = _build_data_iterator(test_dataloader, dl_type)

    return train_data_iterator, valid_data_iterator, test_data_iterator


def _build_data_iterator(dataloader, dl_type):
    """Return an iterator over ``dataloader`` for ``dl_type``, or None if there is no dataloader.

    'single' iterates the dataloader directly; any other type (callers pass
    'cyclic') wraps it in ``cyclic_iter`` so it restarts when exhausted.
    """
    if dataloader is None:
        return None
    if dl_type == 'single':
        return iter(dataloader)
    return iter(cyclic_iter(dataloader))