training.py 34.4 KB
Newer Older
1
# coding=utf-8
Mohammad's avatar
Mohammad committed
2
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
3
4
5
6
7
8
9
10
11
12
13
14
15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Mohammad's avatar
Mohammad committed
16
"""Pretrain utilities."""
17
18
19

from datetime import datetime
import math
Mohammad's avatar
Mohammad committed
20
import sys
21
22
23
24
import time
# The earliest we can measure the start time.
_TRAIN_START_TIME = time.time()

25
26
27
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

Neel Kant's avatar
Neel Kant committed
28
from megatron import get_args
Mohammad's avatar
Mohammad committed
29
30
from megatron import get_timers
from megatron import get_tensorboard_writer
31
from megatron import get_current_global_batch_size
mohammad's avatar
mohammad committed
32
from megatron import get_num_microbatches
mohammad's avatar
mohammad committed
33
from megatron import is_last_rank
mohammad's avatar
mohammad committed
34
from megatron import update_num_microbatches
35
from megatron import mpu
Neel Kant's avatar
Neel Kant committed
36
from megatron import print_rank_0
37
from megatron import print_rank_last
Mohammad's avatar
Mohammad committed
38
39
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
40
from megatron.model import FP16Module
mohammad's avatar
mohammad committed
41
from megatron.optimizer import get_megatron_optimizer
mohammad's avatar
mohammad committed
42

Mohammad's avatar
Mohammad committed
43
from megatron.initialize import initialize_megatron
44
from megatron.initialize import write_args_to_tensorboard
45
46
from megatron.learning_rates import AnnealingLR
from megatron.model import DistributedDataParallel as LocalDDP
Neel Kant's avatar
Neel Kant committed
47
from megatron.model.realm_model import ICTBertModel
48
from megatron.utils import check_adlr_autoresume_termination
49
from megatron.utils import unwrap_model
Vijay Korthikanti's avatar
Vijay Korthikanti committed
50
from megatron.data.data_samplers import build_pretraining_data_loader
mohammad's avatar
mohammad committed
51
from megatron.utils import calc_params_l2_norm
52
from megatron.schedules import forward_backward_no_pipelining
53
from megatron.schedules import forward_backward_pipelining_without_interleaving
54
from megatron.schedules import forward_backward_pipelining_with_interleaving
Mostofa Patwary's avatar
Mostofa Patwary committed
55
from megatron.utils import report_memory
56
57


58
59
60
61
62
63
64
def print_datetime(string):
    """Print ``string`` together with the current wall-clock time on rank 0.

    Note that this call will sync across all ranks: it begins with a
    ``torch.distributed.barrier()``.
    """
    torch.distributed.barrier()
    stamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    message = '[' + string + '] datetime: {} '.format(stamp)
    print_rank_0(message)


65
66
67
68
def pretrain(train_valid_test_dataset_provider,
             model_provider,
             forward_step_func,
             extra_args_provider=None,
             args_defaults=None):
    """Main training program.

    This function will run the following in the order provided:
        1) initialize Megatron.
        2) setup model, optimizer and lr schedule using the model_provider.
        3) call train_valid_test_dataset_provider to get train/val/test
           datasets.
        4) train the model using the forward_step_func.

    Arguments:
        train_valid_test_dataset_provider: a function that takes the size of
            train/valid/test dataset and returns `train, valid, test` datasets.
        model_provider: a function that returns a vanilla version of the
            model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
        forward_step_func: a function that takes a `data iterator` and `model`,
            and returns a `loss` scalar with a dictionary with key:values being
            the info we would like to monitor during training, for example
            `lm-loss: value`. We also require that this function add
            `batch generator` to the timers class.
        extra_args_provider: a function that takes a parser and adds arguments
            to it. It is used for programs to add their own arguments.
        args_defaults: a dictionary from argument-name to argument-value. It
            is used to set already-parsed arguments. ``None`` (the default)
            is treated as an empty dictionary.
    """
    # Use a fresh dict per call; a `{}` default would be one object shared
    # by every call of this function.
    if args_defaults is None:
        args_defaults = {}

    # Initialize Megatron; the global args/timers are read back below via
    # get_args()/get_timers().
    initialize_megatron(extra_args_provider=extra_args_provider,
                        args_defaults=args_defaults)

    # Take the earliest start time across all ranks (MIN all-reduce) so the
    # reported initialization time is the largest one. This is closer to
    # what a job scheduler would see.
    global _TRAIN_START_TIME
    start_time_tensor = torch.cuda.FloatTensor([_TRAIN_START_TIME])
    torch.distributed.all_reduce(start_time_tensor,
                                 op=torch.distributed.ReduceOp.MIN)
    _TRAIN_START_TIME = start_time_tensor.item()
    print_rank_0('time to initialize megatron (seconds): {:.3f}'.format(
        time.time() - _TRAIN_START_TIME))
    print_datetime('after megatron is initialized')

    args = get_args()
    timers = get_timers()

    # Model, optimizer, and learning rate.
    timers('model-and-optimizer-setup').start()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
    timers('model-and-optimizer-setup').stop()
    print_datetime('after model, optimizer, and learning rate '
                   'scheduler are built')

    # Data stuff. With a virtual pipeline, one set of train/valid/test
    # iterators is built per entry of `model`.
    timers('train/valid/test-data-iterators-setup').start()
    if args.virtual_pipeline_model_parallel_size is not None:
        all_data_iterators = [
            build_train_valid_test_data_iterators(train_valid_test_dataset_provider)
            for _ in range(len(model))
        ]
        train_data_iterator = [data_iterators[0] for data_iterators in all_data_iterators]
        valid_data_iterator = [data_iterators[1] for data_iterators in all_data_iterators]
        test_data_iterator = [data_iterators[2] for data_iterators in all_data_iterators]
    else:
        train_data_iterator, valid_data_iterator, test_data_iterator \
            = build_train_valid_test_data_iterators(
                train_valid_test_dataset_provider)
    timers('train/valid/test-data-iterators-setup').stop()
    print_datetime('after dataloaders are built')

    # Print setup timing.
    print_rank_0('done with setup ...')
    timers.log(['model-and-optimizer-setup', 'train/valid/test-data-iterators-setup'])
    print_rank_0('training ...')

    iteration = 0
    if args.do_train and args.train_iters > 0:
        iteration = train(forward_step_func,
                          model, optimizer, lr_scheduler,
                          train_data_iterator, valid_data_iterator)
    print_datetime('after training is done')

    if args.do_valid:
        prefix = 'the end of training for val data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   valid_data_iterator, model,
                                   iteration, False)

    # Only save if training actually advanced past iteration 0.
    if args.save and iteration != 0:
        save_checkpoint(iteration, model, optimizer, lr_scheduler)

    if args.do_test:
        # Run on test data.
        prefix = 'the end of training for test data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   test_data_iterator, model,
                                   0, True)
164

165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
def update_train_iters(args):
    """Fill in ``args.train_iters`` for sample-based training.

    Does nothing when ``args.train_iters`` is already set (iteration-based
    training). Otherwise it derives the count from ``args.train_samples``.
    With a constant batch size, this is whole global batches. With a batch
    size ramp-up, the ramp phase is simulated through the microbatch
    calculator, and the remaining samples are then divided by the full
    global batch size. A trailing partial batch is always dropped.
    """
    if args.train_iters:
        return

    if args.rampup_batch_size is None:
        args.train_iters = args.train_samples // args.global_batch_size
    else:
        rampup_samples = int(args.rampup_batch_size[2])
        iters = 0
        samples_seen = 0
        # Walk through the ramp-up phase one global batch at a time.
        while samples_seen <= rampup_samples:
            update_num_microbatches(samples_seen, consistency_check=False)
            samples_seen += get_current_global_batch_size()
            iters += 1
        # Put the microbatch calculator back to its starting point.
        update_num_microbatches(0, consistency_check=False)
        # Constant phase; any partial last batch is thrown away.
        iters += (args.train_samples - samples_seen) // args.global_batch_size
        args.train_iters = iters

    print_rank_0('setting training iterations to {}'.format(args.train_iters))

194

Mohammad's avatar
Mohammad committed
195
def get_model(model_provider_func):
    """Build the model and wrap it for GPU, fp16 and data parallelism.

    ``model_provider_func`` may return a single module or a list of modules
    (presumably one per virtual pipeline chunk; confirm with callers). The
    result is always a list of modules:
    - moved to the current CUDA device;
    - wrapped in ``FP16Module`` when ``args.fp16`` is set;
    - wrapped in torch or local DDP according to ``args.DDP_impl``.

    Raises:
        NotImplementedError: if ``args.DDP_impl`` is neither 'torch' nor
            'local'.
    """
    args = get_args()

    # Build model on cpu.
    modules = model_provider_func()
    if not isinstance(modules, list):
        modules = [modules]

    # Only parameters that are already tensor model parallel carry the
    # tensor-model-parallel attributes; give every other parameter the
    # defaults so the optimizer can rely on them.
    for module in modules:
        for param in module.parameters():
            mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    # Report the parameter count once per model-parallel group.
    if mpu.get_data_parallel_rank() == 0:
        print(' > number of parameters on (tensor, pipeline) '
              'model parallel rank ({}, {}): {}'.format(
                  mpu.get_tensor_model_parallel_rank(),
                  mpu.get_pipeline_model_parallel_rank(),
                  sum(sum(p.nelement() for p in module.parameters())
                      for module in modules)), flush=True)

    # GPU allocation.
    device = torch.cuda.current_device()
    for module in modules:
        module.cuda(device)

    # Fp16 conversion.
    if args.fp16:
        modules = [FP16Module(module) for module in modules]

    if args.DDP_impl == 'torch':
        return [torchDDP(module, device_ids=[device], output_device=device,
                         process_group=mpu.get_data_parallel_group())
                for module in modules]
    if args.DDP_impl == 'local':
        return [LocalDDP(module) for module in modules]

    raise NotImplementedError('Unknown DDP implementation specified: {}. '
                              'Exiting.'.format(args.DDP_impl))
241
242


Mohammad's avatar
Mohammad committed
243
def get_learning_rate_scheduler(optimizer):
    """Build the ``AnnealingLR`` scheduler for ``optimizer``.

    Warmup and decay lengths are expressed in samples. For iteration-based
    training, the iteration counts are multiplied by the global batch size.
    For sample-based training, ``args.train_iters`` is also derived via
    ``update_train_iters``. In both modes ``lr_warmup_fraction``, when set,
    overrides the explicit warmup length. Unset decay lengths default to the
    full training length and are written back onto ``args``.

    Raises:
        Exception: if neither ``train_iters`` nor ``train_samples`` is set.
    """
    args = get_args()

    if args.train_iters:
        # Iteration-based training.
        if args.lr_decay_iters is None:
            args.lr_decay_iters = args.train_iters
        decay_steps = args.lr_decay_iters * args.global_batch_size
        if args.lr_warmup_fraction is None:
            warmup_steps = args.lr_warmup_iters * args.global_batch_size
        else:
            warmup_steps = args.lr_warmup_fraction * decay_steps
    elif args.train_samples:
        # Sample-based training. Training iters are needed later on; the
        # sample count itself is not trimmed for the incomplete last batch.
        update_train_iters(args)
        if args.lr_decay_samples is None:
            args.lr_decay_samples = args.train_samples
        decay_steps = args.lr_decay_samples
        if args.lr_warmup_fraction is None:
            warmup_steps = args.lr_warmup_samples
        else:
            warmup_steps = args.lr_warmup_fraction * decay_steps
    else:
        raise Exception(
            'either train-iters or train-samples should be provided.')

    return AnnealingLR(
        optimizer,
        max_lr=args.lr,
        min_lr=args.min_lr,
        warmup_steps=warmup_steps,
        decay_steps=decay_steps,
        decay_style=args.lr_decay_style,
        use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
        override_lr_scheduler=args.override_lr_scheduler)


Mohammad's avatar
Mohammad committed
286
def setup_model_and_optimizer(model_provider_func):
    """Build the model, optimizer and LR scheduler, loading a checkpoint if requested.

    Sets ``args.iteration``: the value returned by ``load_checkpoint`` when
    ``args.load`` is given, else 0. On a fresh start with a single model
    chunk that provides ``init_state_dict_from_bert``, the weights are
    initialized from that BERT state. Under fp16, the optimizer then reloads
    its model params.

    Returns:
        ``(model, optimizer, lr_scheduler)``; ``model`` is the wrapped list
        returned by ``get_model``.
    """
    args = get_args()

    model = get_model(model_provider_func)
    # The optimizer works on the modules without fp16/DDP wrappers.
    unwrapped_model = unwrap_model(model,
                                   (torchDDP, LocalDDP, FP16Module))
    optimizer = get_megatron_optimizer(unwrapped_model)
    lr_scheduler = get_learning_rate_scheduler(optimizer)

    if args.load is None:
        args.iteration = 0
    else:
        timers = get_timers()
        # Barriers on both sides so that every rank reports the max time.
        torch.distributed.barrier()
        timers('load-checkpoint').start()
        args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
        torch.distributed.barrier()
        timers('load-checkpoint').stop()
        timers.log(['load-checkpoint'])

    # Multiple model chunks and pipeline parallelism are only supported
    # with the local DDP implementation.
    if len(model) > 1:
        assert args.DDP_impl == 'local'
    if mpu.get_pipeline_model_parallel_world_size() > 1:
        assert args.DDP_impl == 'local'

    starting_fresh = args.iteration == 0
    can_init_from_bert = len(unwrapped_model) == 1 and \
        hasattr(unwrapped_model[0], 'init_state_dict_from_bert')
    if starting_fresh and can_init_from_bert:
        print_rank_0("Initializing ICT from pretrained BERT model")
        unwrapped_model[0].init_state_dict_from_bert()
        if args.fp16:
            optimizer.reload_model_params()

    return model, optimizer, lr_scheduler


328
329
330
331
332
333
334
def train_step(forward_step_func, data_iterator,
               model, optimizer, lr_scheduler):
    """Run a single training step: forward/backward, gradient sync, update.

    Returns:
        ``(loss_reduced, skipped_iter, grad_norm)``.
        - ``loss_reduced``: on the last pipeline stage, maps each loss key to
          its mean over the microbatch losses; ``{}`` on every other stage.
        - ``skipped_iter``: 1 when ``optimizer.step()`` reports an
          unsuccessful update (the LR scheduler is then not stepped), else 0.
        - ``grad_norm``: passed through from ``optimizer.step()``.
    """
    args = get_args()
    timers = get_timers()

    # Set grad to zero.
    optimizer.zero_grad()

    # Choose the forward/backward schedule.
    if not mpu.get_pipeline_model_parallel_world_size() > 1:
        forward_backward_func = forward_backward_no_pipelining
    elif args.virtual_pipeline_model_parallel_size is None:
        forward_backward_func = forward_backward_pipelining_without_interleaving
    else:
        forward_backward_func = forward_backward_pipelining_with_interleaving
        assert get_num_microbatches() % args.pipeline_model_parallel_size == 0, \
            'number of microbatches is not divisible by pipeline-parallel ' \
            'size when using interleaved schedule'
    losses_reduced = forward_backward_func(
        forward_step_func, data_iterator, model,
        optimizer, timers, forward_only=False)

    # All-reduce if needed.
    if args.DDP_impl == 'local':
        timers('backward-params-all-reduce').start()
        for chunk in model:
            chunk.allreduce_params(reduce_after=False,
                                   fp32_allreduce=args.fp32_allreduce)
        timers('backward-params-all-reduce').stop()

    # Barrier to measure backward stall.
    timers('backward-pipeline-stall').start()
    torch.distributed.barrier(group=mpu.get_pipeline_model_parallel_group())
    timers('backward-pipeline-stall').stop()

    # All-reduce word_embeddings' grad across first and last stages so the
    # word_embeddings parameters stay in sync. This applies only to models
    # that support pipelined model parallelism (BERT and GPT-2).
    timers('backward-embedding-all-reduce').start()
    on_edge_stage = (mpu.is_pipeline_first_stage(ignore_virtual=True) or
                     mpu.is_pipeline_last_stage(ignore_virtual=True))
    if on_edge_stage and mpu.get_pipeline_model_parallel_world_size() > 1:
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            edge_module = model[0]
        else:
            edge_module = model[-1]
        edge_module = unwrap_model(
            edge_module, (torchDDP, LocalDDP, FP16Module))
        if edge_module.share_word_embeddings:
            embedding_weight = edge_module.word_embeddings_weight()
            torch.distributed.all_reduce(embedding_weight.grad,
                                         group=mpu.get_embedding_group())
    timers('backward-embedding-all-reduce').stop()

    # Update parameters.
    timers('optimizer').start()
    update_successful, grad_norm = optimizer.step()
    timers('optimizer').stop()

    # Advance the learning rate by the number of samples consumed, but only
    # when the optimizer actually applied the update.
    skipped_iter = 1
    if update_successful:
        lr_scheduler.step(increment=get_num_microbatches() *
                          args.micro_batch_size *
                          args.data_parallel_size)
        skipped_iter = 0

    if not mpu.is_pipeline_last_stage(ignore_virtual=True):
        return {}, skipped_iter, grad_norm

    # Average loss across microbatches.
    num_microbatch_losses = len(losses_reduced)
    loss_reduced = {
        key: sum(entry[key] for entry in losses_reduced) / num_microbatch_losses
        for key in losses_reduced[0]
    }
    return loss_reduced, skipped_iter, grad_norm
408
409


Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
410
def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                 loss_scale, report_memory_flag, skipped_iter,
                 grad_norm, params_norm):
    """Log training information such as losses, timing, ....

    Accumulates into ``total_loss_dict`` (mutated in place):
    - 'advanced iterations' and 'skipped iterations' counters;
    - per-key losses from ``loss_dict`` on non-skipped iterations;
    - a 'nan iterations' counter, bumped when a skipped iteration has a
      non-finite loss.

    TensorBoard scalars are written on the last rank every
    ``args.tensorboard_log_interval`` iterations. Every ``args.log_interval``
    iterations a summary line is printed on the last rank and the
    accumulators are reset. Memory is reported once, the first time this
    happens with a positive learning rate.

    Returns:
        The updated ``report_memory_flag`` (False once memory was reported).
    """
    args = get_args()
    timers = get_timers()
    writer = get_tensorboard_writer()

    # Advanced, skipped, and Nan iterations.
    advanced_iters_key = 'advanced iterations'
    skipped_iters_key = 'skipped iterations'
    nan_iters_key = 'nan iterations'
    # Advanced iterations.
    if not skipped_iter:
        total_loss_dict[advanced_iters_key] = total_loss_dict.get(
            advanced_iters_key, 0) + 1
    else:
        if advanced_iters_key not in total_loss_dict:
            total_loss_dict[advanced_iters_key] = 0
    # Skipped iterations.
    total_loss_dict[skipped_iters_key] = total_loss_dict.get(
        skipped_iters_key, 0) + skipped_iter
    # Update losses and set nan iterations.
    got_nan = False
    for key in loss_dict:
        if not skipped_iter:
            total_loss_dict[key] = total_loss_dict.get(
                key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]
        else:
            value = loss_dict[key].float().sum().item()
            # Any non-finite value (+inf, -inf or NaN) counts as a nan
            # iteration.
            is_nan = not math.isfinite(value)
            got_nan = got_nan or is_nan
    total_loss_dict[nan_iters_key] = total_loss_dict.get(
        nan_iters_key, 0) + int(got_nan)

    # Logging: keep only the timers that have actually been created.
    timer_names = (
        'forward-compute',
        'forward-pipeline-stall',
        'forward-recv',
        'forward-send',
        'forward-backward-send-forward-backward-recv',
        'backward-compute',
        'backward-pipeline-stall',
        'backward-recv',
        'backward-send',
        'backward-send-forward-recv',
        'backward-send-backward-recv',
        'backward-params-all-reduce',
        'backward-embedding-all-reduce',
        'optimizer-copy-to-main-grad',
        'optimizer-unscale-and-check-inf',
        'optimizer-clip-main-grad',
        'optimizer-copy-main-to-model-params',
        'optimizer',
        'batch-generator',
    )
    timers_to_log = [name for name in timer_names if name in timers.timers]

    # Calculate batch size.
    batch_size = args.micro_batch_size * args.data_parallel_size * \
        get_num_microbatches()

    total_iterations = total_loss_dict[advanced_iters_key] + \
                       total_loss_dict[skipped_iters_key]

    # Tensorboard values.
    if writer and (iteration % args.tensorboard_log_interval == 0) and \
       is_last_rank():
        if args.log_learning_rate_to_tensorboard:
            writer.add_scalar('learning-rate', learning_rate, iteration)
            writer.add_scalar('learning-rate vs samples', learning_rate,
                              args.consumed_train_samples)
        if args.log_batch_size_to_tensorboard:
            writer.add_scalar('batch-size', batch_size, iteration)
            writer.add_scalar('batch-size vs samples', batch_size,
                              args.consumed_train_samples)
        for key in loss_dict:
            writer.add_scalar(key, loss_dict[key], iteration)
            writer.add_scalar(key + ' vs samples', loss_dict[key],
                              args.consumed_train_samples)
        if args.log_loss_scale_to_tensorboard:
            writer.add_scalar('loss-scale', loss_scale, iteration)
            writer.add_scalar('loss-scale vs samples', loss_scale,
                              args.consumed_train_samples)
        if grad_norm is not None:
            writer.add_scalar('grad-norm', grad_norm, iteration)
            writer.add_scalar('grad-norm vs samples', grad_norm,
                              args.consumed_train_samples)
        if params_norm is not None:
            writer.add_scalar('params-norm', params_norm, iteration)
            writer.add_scalar('params-norm vs samples', params_norm,
                              args.consumed_train_samples)
        if args.log_timers_to_tensorboard:
            timers.write(timers_to_log, writer, iteration,
                         normalizer=total_iterations)

    if iteration % args.log_interval == 0:
        elapsed_time = timers('interval-time').elapsed()
        elapsed_time_per_iteration = elapsed_time / total_iterations
        # Use the same rank guard as the other TensorBoard writes above;
        # a rank-0 check would never fire when the writer lives on the
        # last rank.
        if writer and is_last_rank():
            if args.log_timers_to_tensorboard:
                writer.add_scalar('iteration-time',
                                  elapsed_time_per_iteration, iteration)
        log_string = ' iteration {:8d}/{:8d} |'.format(
            iteration, args.train_iters)
        log_string += ' consumed samples: {:12d} |'.format(
            args.consumed_train_samples)
        log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
            elapsed_time_per_iteration * 1000.0)
        log_string += ' learning rate: {:.3E} |'.format(learning_rate)
        log_string += ' global batch size: {:5d} |'.format(batch_size)
        for key in total_loss_dict:
            if key not in [advanced_iters_key, skipped_iters_key,
                           nan_iters_key]:
                # Average over advanced (non-skipped) iterations only.
                avg = total_loss_dict[key].item() / \
                      float(max(1, total_loss_dict[advanced_iters_key]))
                if avg > 0.0:
                    log_string += ' {}: {:.6E} |'.format(key, avg)
                total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
        log_string += ' loss scale: {:.1f} |'.format(loss_scale)
        if grad_norm is not None:
            log_string += ' grad norm: {:.3f} |'.format(grad_norm)
        if params_norm is not None:
            log_string += ' params norm: {:.3f} |'.format(params_norm)
        log_string += ' number of skipped iterations: {:3d} |'.format(
            total_loss_dict[skipped_iters_key])
        log_string += ' number of nan iterations: {:3d} |'.format(
            total_loss_dict[nan_iters_key])
        total_loss_dict[advanced_iters_key] = 0
        total_loss_dict[skipped_iters_key] = 0
        total_loss_dict[nan_iters_key] = 0
        print_rank_last(log_string)
        if report_memory_flag and learning_rate > 0.:
            # Report memory after optimizer state has been initialized.
            report_memory('(after {} iterations)'.format(iteration))
            report_memory_flag = False
        timers.log(timers_to_log, normalizer=args.log_interval)

    return report_memory_flag


556
557
558
559
560
def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):
    """Save a checkpoint and log the time it took under 'save-checkpoint'.

    Barriers on both sides of the save make every rank report the max time.
    """
    timers = get_timers()
    timer_name = 'save-checkpoint'

    torch.distributed.barrier()
    timers(timer_name).start()
    save_checkpoint(iteration, model, optimizer, lr_scheduler)
    torch.distributed.barrier()
    timers(timer_name).stop()

    timers.log([timer_name])
566
567


568
def train(forward_step_func, model, optimizer, lr_scheduler,
          train_data_iterator, valid_data_iterator):
    """Run the main training loop, starting from ``args.iteration``.

    Each pass of the loop:

    - updates the microbatch count;
    - runs one ``train_step``;
    - advances ``args.consumed_train_samples`` by one global batch;
    - logs progress via ``training_log``.

    Depending on ``args``, a pass may also check for ADLR autoresume
    termination, run validation, save a checkpoint, or terminate the process
    with ``sys.exit()``. Termination happens either once the wall time since
    module import exceeds ``args.exit_duration_in_mins`` on any rank, or
    every ``args.exit_interval`` iterations.

    Returns:
        The iteration counter reached when ``args.train_iters`` is hit.
    """
    args = get_args()
    timers = get_timers()

    write_args_to_tensorboard()

    # Training mode enables dropout.
    for module in model:
        module.train()

    # Dict handed to training_log() on every step to track losses between
    # log intervals.
    loss_accumulator = {}
    step = args.iteration
    should_report_memory = True

    timers('interval-time').start()
    print_datetime('before the start of training step')

    while step < args.train_iters:
        update_num_microbatches(args.consumed_train_samples)
        loss_dict, skipped_iter, grad_norm = train_step(
            forward_step_func, train_data_iterator, model, optimizer,
            lr_scheduler)
        step += 1

        # One global batch = data-parallel size x micro batch x microbatches.
        samples_this_step = (mpu.get_data_parallel_world_size()
                             * args.micro_batch_size
                             * get_num_microbatches())
        args.consumed_train_samples += samples_this_step

        # Logging.
        loss_scale = optimizer.get_loss_scale().item()
        params_norm = (calc_params_l2_norm(model)
                       if args.log_params_norm else None)
        current_lr = optimizer.param_groups[0]['lr']
        should_report_memory = training_log(
            loss_dict, loss_accumulator, current_lr, step, loss_scale,
            should_report_memory, skipped_iter, grad_norm, params_norm)

        # Autoresume.
        if args.adlr_autoresume and \
                step % args.adlr_autoresume_interval == 0:
            check_adlr_autoresume_termination(step, model, optimizer,
                                              lr_scheduler)

        # Periodic validation.
        if args.eval_interval and step % args.eval_interval == 0 and \
                args.do_valid:
            evaluate_and_print_results('iteration {}'.format(step),
                                       forward_step_func,
                                       valid_data_iterator, model,
                                       step, False)

        # Periodic checkpointing; remembered so an exit below on this same
        # step does not save twice.
        saved_checkpoint = bool(args.save and args.save_interval and
                                step % args.save_interval == 0)
        if saved_checkpoint:
            save_checkpoint_and_time(step, model, optimizer, lr_scheduler)

        # Wall-clock exit: a MAX all-reduce makes every rank stop as soon as
        # any single rank is over the limit.
        if args.exit_duration_in_mins:
            train_time = (time.time() - _TRAIN_START_TIME) / 60.0
            over_limit = torch.cuda.IntTensor(
                [train_time > args.exit_duration_in_mins])
            torch.distributed.all_reduce(
                over_limit, op=torch.distributed.ReduceOp.MAX)
            if over_limit.item():
                if not saved_checkpoint:
                    save_checkpoint_and_time(step, model, optimizer,
                                             lr_scheduler)
                print_datetime('exiting program after {} minutes'.format(train_time))
                sys.exit()

        # Iteration-count exit.
        if args.exit_interval and step % args.exit_interval == 0:
            if not saved_checkpoint:
                save_checkpoint_and_time(step, model, optimizer,
                                         lr_scheduler)
            torch.distributed.barrier()
            print_datetime('exiting program at iteration {}'.format(step))
            sys.exit()

    return step
661
662


Mohammad's avatar
Mohammad committed
663
def evaluate(forward_step_func, data_iterator, model, verbose=False):
    """Run ``args.eval_iters`` forward-only steps and return averaged losses.

    Args:
        forward_step_func: per-microbatch forward function passed to the
            selected forward/backward schedule.
        data_iterator: iterator the schedule pulls evaluation batches from.
        model: iterable of model modules. They are switched to eval mode for
            the duration and back to train mode before returning.
        verbose: if True, print progress every ``args.log_interval`` iterations.

    Returns:
        Dict mapping each loss key to its sum over all eval iterations and
        microbatches, divided by ``args.eval_iters * get_num_microbatches()``.
        Only ranks on the last pipeline stage accumulate losses; on other
        ranks the dict is empty.

    Side effects:
        Advances ``args.consumed_valid_samples`` by one global batch per
        eval iteration.
    """
    args = get_args()

    # Turn on evaluation mode which disables dropout.
    for model_module in model:
        model_module.eval()

    # The schedule depends only on the parallel configuration, which does
    # not change while evaluating, so select it once rather than per step.
    if mpu.get_pipeline_model_parallel_world_size() > 1:
        if args.virtual_pipeline_model_parallel_size is not None:
            forward_backward_func = forward_backward_pipelining_with_interleaving
        else:
            forward_backward_func = forward_backward_pipelining_without_interleaving
    else:
        forward_backward_func = forward_backward_no_pipelining

    total_loss_dict = {}

    with torch.no_grad():
        iteration = 0
        while iteration < args.eval_iters:
            iteration += 1
            if verbose and iteration % args.log_interval == 0:
                print_rank_0('Evaluating iter {}/{}'.format(iteration,
                                                            args.eval_iters))

            loss_dicts = forward_backward_func(
                forward_step_func, data_iterator, model, optimizer=None,
                timers=None, forward_only=True)

            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                # Accumulate losses over microbatches and iterations.
                for loss_dict in loss_dicts:
                    for key in loss_dict:
                        if key in total_loss_dict:
                            total_loss_dict[key] = \
                                total_loss_dict[key] + loss_dict[key]
                        else:
                            # Seed with the same [1]-shaped float32 zero the
                            # accumulator always started from, but create it
                            # only on first sight of the key (a `.get` default
                            # would build a new CUDA tensor on every call).
                            total_loss_dict[key] = \
                                torch.cuda.FloatTensor([0.0]) + loss_dict[key]

            args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
                                           * args.micro_batch_size \
                                           * get_num_microbatches()

    # Move model back to the train mode.
    for model_module in model:
        model_module.train()

    # Average over every microbatch of every eval iteration.
    for key in total_loss_dict:
        total_loss_dict[key] /= args.eval_iters * get_num_microbatches()

    return total_loss_dict

def evaluate_and_print_results(prefix, forward_step_func,
                               data_iterator, model,
                               iteration, verbose=False):
    """Evaluate the model and report the validation losses.

    Calls ``evaluate`` and builds a one-line summary with each loss value and
    its perplexity (``exp`` of the loss, with the exponent capped at 20). The
    summary is printed between dashed rules via ``print_rank_last``.

    When a TensorBoard writer exists and this is the last rank, each loss is
    also logged against ``iteration`` and ``args.consumed_train_samples``.
    Perplexity is logged the same way if
    ``args.log_validation_ppl_to_tensorboard`` is set.

    Args:
        prefix: label inserted into the summary line (e.g. 'iteration 100').
        forward_step_func, data_iterator, model, verbose: forwarded to
            ``evaluate``.
        iteration: step used as the x-axis for the TensorBoard scalars.
    """
    args = get_args()
    writer = get_tensorboard_writer()

    total_loss_dict = evaluate(forward_step_func, data_iterator, model, verbose)
    string = ' validation loss at {} | '.format(prefix)
    for key in total_loss_dict:
        # `.item()` copies the value to the host; do it once per key rather
        # than once per use.
        loss_value = total_loss_dict[key].item()
        ppl = math.exp(min(20, loss_value))
        string += '{} value: {:.6E} | '.format(key, loss_value)
        string += '{} PPL: {:.6E} | '.format(key, ppl)
        if writer and is_last_rank():
            writer.add_scalar('{} validation'.format(key),
                              loss_value,
                              iteration)
            writer.add_scalar('{} validation vs samples'.format(key),
                              loss_value,
                              args.consumed_train_samples)
            if args.log_validation_ppl_to_tensorboard:
                writer.add_scalar('{} validation ppl'.format(key), ppl,
                                  iteration)
                writer.add_scalar('{} validation ppl vs samples'.format(key),
                                  ppl, args.consumed_train_samples)

    length = len(string) + 1
    print_rank_last('-' * length)
    print_rank_last(string)
    print_rank_last('-' * length)
741
742


Vijay Korthikanti's avatar
Vijay Korthikanti committed
743
def cyclic_iter(iter):
    """Yield the items of ``iter`` endlessly, restarting it after each pass.

    ``iter`` is re-iterated with a fresh ``for`` loop on every pass, so it
    should be a re-iterable object such as a DataLoader. Each pass then
    yields whatever a new iteration produces, rather than a cached copy.

    If a complete pass yields nothing (an empty iterable, or a one-shot
    iterator that has been used up), the generator returns instead of
    spinning forever without ever yielding.
    """
    while True:
        yielded_any = False
        for x in iter:
            yielded_any = True
            yield x
        if not yielded_any:
            return

748
749
750
def build_train_valid_test_data_iterators(
        build_train_valid_test_datasets_provider):
    """Build the train/validation/test data iterators and set ``args.do_*``.

    Datasets and data loaders are built only on tensor-model-parallel rank 0.
    That rank decides whether training, validation and testing are possible.
    The three flags are then broadcast to the rest of its tensor-model-parallel
    group and stored as ``args.do_train``, ``args.do_valid`` and
    ``args.do_test`` on every rank.

    Args:
        build_train_valid_test_datasets_provider: callable that takes a list
            ``[train, valid, test]`` of minimum sample counts and returns the
            three datasets ``(train_ds, valid_ds, test_ds)``.

    Returns:
        ``(train_data_iterator, valid_data_iterator, test_data_iterator)``.
        Each entry is None where no data loader was built: on ranks other than
        tensor-model-parallel rank 0, or when the loader itself came back None.

    Side effects:
        When resuming from an iteration-based run that did not record consumed
        samples, back-fills ``args.consumed_train_samples`` and
        ``args.consumed_valid_samples`` from ``args.iteration``.
    """
    args = get_args()

    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

    print_rank_0('> building train, validation, and test datasets ...')

    # Backward compatibility, assume fixed batch size.
    if args.iteration > 0 and args.consumed_train_samples == 0:
        assert args.train_samples is None, \
            'only backward compatiblity support for iteration-based training'
        args.consumed_train_samples = args.iteration * args.global_batch_size
    if args.iteration > 0 and args.consumed_valid_samples == 0:
        assert args.train_samples is None, \
            'only backward compatiblity support for iteration-based training'
        # A falsy eval_interval means no periodic evaluation (train() checks
        # it before using it as a modulus), so no validation samples were
        # consumed; guard the division instead of crashing.
        completed_evals = (args.iteration // args.eval_interval
                           if args.eval_interval else 0)
        args.consumed_valid_samples = completed_evals * \
            args.eval_iters * args.global_batch_size

    # Data loaders are built only on tensor-model-parallel rank 0.
    if mpu.get_tensor_model_parallel_rank() == 0:

        # Number of train/valid/test samples.
        if args.train_samples:
            train_samples = args.train_samples
        else:
            train_samples = args.train_iters * args.global_batch_size
        # One validation round per eval_interval during training, plus one
        # extra round; with no eval_interval only that extra round remains.
        eval_rounds = (args.train_iters // args.eval_interval
                       if args.eval_interval else 0) + 1
        eval_iters = eval_rounds * args.eval_iters
        test_iters = args.eval_iters
        train_val_test_num_samples = [train_samples,
                                      eval_iters * args.global_batch_size,
                                      test_iters * args.global_batch_size]
        print_rank_0(' > datasets target sizes (minimum size):')
        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))

        # Build the datasets.
        train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
            train_val_test_num_samples)

        # Build data loaders, resuming train/valid from the samples already
        # consumed.
        train_dataloader = build_pretraining_data_loader(
            train_ds, args.consumed_train_samples)
        valid_dataloader = build_pretraining_data_loader(
            valid_ds, args.consumed_valid_samples)
        test_dataloader = build_pretraining_data_loader(test_ds, 0)

        # Flags to know if we need to do training/validation/testing.
        do_train = train_dataloader is not None and args.train_iters > 0
        do_valid = valid_dataloader is not None and args.eval_iters > 0
        do_test = test_dataloader is not None and args.eval_iters > 0
        # Pack the flags into a tensor so they can be broadcast.
        flags = torch.cuda.LongTensor(
            [int(do_train), int(do_valid), int(do_test)])
    else:
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Broadcast the do_train/do_valid/do_test flags from the group's source
    # rank (which built the loaders) to the other tensor-model-parallel ranks.
    torch.distributed.broadcast(flags,
                                mpu.get_tensor_model_parallel_src_rank(),
                                group=mpu.get_tensor_model_parallel_group())
    args.do_train = flags[0].item()
    args.do_valid = flags[1].item()
    args.do_test = flags[2].item()

    # Build iterators.
    dl_type = args.dataloader_type
    assert dl_type in ['single', 'cyclic']

    train_data_iterator = _build_data_iterator(train_dataloader, dl_type)
    valid_data_iterator = _build_data_iterator(valid_dataloader, dl_type)
    test_data_iterator = _build_data_iterator(test_dataloader, dl_type)

    return train_data_iterator, valid_data_iterator, test_data_iterator


def _build_data_iterator(dataloader, dl_type):
    """Wrap ``dataloader`` in an iterator according to ``dl_type``.

    'single' iterates the loader once; any other value (the caller only
    allows 'cyclic') restarts it every time it is exhausted, via
    ``cyclic_iter``. Returns None when ``dataloader`` is None.
    """
    if dataloader is None:
        return None
    if dl_type == 'single':
        return iter(dataloader)
    return iter(cyclic_iter(dataloader))