training.py 34.1 KB
Newer Older
1
# coding=utf-8
Mohammad's avatar
Mohammad committed
2
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
3
4
5
6
7
8
9
10
11
12
13
14
15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Mohammad's avatar
Mohammad committed
16
"""Pretrain utilities."""
17
18
19

from datetime import datetime
import math
Mohammad's avatar
Mohammad committed
20
import sys
21
22
23
24
import time
# The earliest we can measure the start time.
_TRAIN_START_TIME = time.time()

25
26
27
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

Neel Kant's avatar
Neel Kant committed
28
from megatron import get_args
Mohammad's avatar
Mohammad committed
29
30
from megatron import get_timers
from megatron import get_tensorboard_writer
31
from megatron import get_current_global_batch_size
mohammad's avatar
mohammad committed
32
from megatron import get_num_microbatches
mohammad's avatar
mohammad committed
33
from megatron import is_last_rank
mohammad's avatar
mohammad committed
34
from megatron import update_num_microbatches
35
from megatron import mpu
Neel Kant's avatar
Neel Kant committed
36
from megatron import print_rank_0
37
from megatron import print_rank_last
Mohammad's avatar
Mohammad committed
38
39
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
40
from megatron.model import FP16Module
mohammad's avatar
mohammad committed
41
from megatron.optimizer import get_megatron_optimizer
mohammad's avatar
mohammad committed
42

Mohammad's avatar
Mohammad committed
43
from megatron.initialize import initialize_megatron
44
from megatron.initialize import write_args_to_tensorboard
45
46
from megatron.learning_rates import AnnealingLR
from megatron.model import DistributedDataParallel as LocalDDP
Neel Kant's avatar
Neel Kant committed
47
from megatron.model.realm_model import ICTBertModel
48
from megatron.utils import check_adlr_autoresume_termination
Vijay Korthikanti's avatar
Vijay Korthikanti committed
49
from megatron.data.data_samplers import build_pretraining_data_loader
mohammad's avatar
mohammad committed
50
from megatron.utils import calc_params_l2_norm
51
52
53
from megatron.schedules import forward_backward_no_pipelining
from megatron.schedules import forward_backward_pipelining
from megatron.schedules import forward_backward_pipelining_with_interleaving
54
from megatron.utils import report_memory
55
56


57
58
59
60
61
62
63
def print_datetime(string):
    """Print `string` on rank 0, tagged with the current local wall-clock time.

    A ``torch.distributed.barrier()`` is issued first, so this call
    synchronizes all ranks and must be reached by every rank.
    """
    torch.distributed.barrier()
    stamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    message = '[' + string + '] datetime: {} '.format(stamp)
    print_rank_0(message)


64
65
66
67
def pretrain(train_valid_test_dataset_provider,
             model_provider,
             forward_step_func,
             extra_args_provider=None,
             args_defaults=None):
    """Main training program.

    This function runs the following steps in the order given:
        1) initialize Megatron.
        2) set up the model, optimizer and lr schedule using model_provider.
        3) call train_valid_test_dataset_provider to get the train/valid/test
           datasets.
        4) train the model using forward_step_func.

    Arguments:
        train_valid_test_dataset_provider: a function that takes the size of
            train/valid/test dataset and returns `train, valid, test` datasets.
        model_provider: a function that returns a vanilla version of the
            model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
        forward_step_func: a function that takes a `data iterator` and `model`,
            and returns a `loss` scalar with a dictionary with key:values being
            the info we would like to monitor during training, for example
            `lm-loss: value`. We also require that this function add
            `batch generator` to the timers class.
        extra_args_provider: a function that takes a parser and adds arguments
            to it. It is used for programs to add their own arguments.
        args_defaults: optional dictionary from argument-name to
            argument-value, used to set already-parsed arguments. ``None``
            (the default) is treated as an empty dictionary.
    """
    # Use a fresh dict per call instead of a shared mutable default.
    if args_defaults is None:
        args_defaults = {}

    # Initialize and get arguments, timers, and Tensorboard writer.
    initialize_megatron(extra_args_provider=extra_args_provider,
                        args_defaults=args_defaults)

    # Use the earliest start time across all ranks (MIN reduction) so the
    # reported elapsed time is the largest one; this is closer to what the
    # job scheduler sees from outside the launched processes.
    global _TRAIN_START_TIME
    # float64 is required: a Unix timestamp (~1.7e9 s) stored in float32 has
    # a spacing of 128 s, which would shift the start time by up to ~64 s
    # and skew both the init-time print and the exit-duration check.
    start_time_tensor = torch.tensor([_TRAIN_START_TIME],
                                     dtype=torch.double,
                                     device='cuda')
    torch.distributed.all_reduce(start_time_tensor,
                                 op=torch.distributed.ReduceOp.MIN)
    _TRAIN_START_TIME = start_time_tensor.item()
    print_rank_0('time to initialize megatron (seconds): {:.3f}'.format(
        time.time() - _TRAIN_START_TIME))
    print_datetime('after megatron is initialized')

    args = get_args()
    timers = get_timers()

    # Model, optimizer, and learning rate.
    timers('model-and-optimizer-setup').start()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
    timers('model-and-optimizer-setup').stop()
    print_datetime('after model, optimizer, and learning rate '
                   'scheduler are built')

    # Data stuff. With a virtual pipeline, `model` holds several modules and
    # each one gets its own set of iterators.
    timers('train/valid/test-data-iterators-setup').start()
    if args.virtual_pipeline_model_parallel_size is not None:
        data_iterators = [
            build_train_valid_test_data_iterators(train_valid_test_dataset_provider)
            for _ in range(len(model))
        ]
        train_data_iterator = [x[0] for x in data_iterators]
        valid_data_iterator = [x[1] for x in data_iterators]
        test_data_iterator = [x[2] for x in data_iterators]
    else:
        train_data_iterator, valid_data_iterator, test_data_iterator \
            = build_train_valid_test_data_iterators(
                train_valid_test_dataset_provider)
    timers('train/valid/test-data-iterators-setup').stop()
    print_datetime('after dataloaders are built')

    # Print setup timing.
    print_rank_0('done with setup ...')
    timers.log(['model-and-optimizer-setup', 'train/valid/test-data-iterators-setup'])
    print_rank_0('training ...')

    iteration = 0
    if args.do_train and args.train_iters > 0:
        iteration = train(forward_step_func,
                          model, optimizer, lr_scheduler,
                          train_data_iterator, valid_data_iterator)
    print_datetime('after training is done')

    if args.do_valid:
        prefix = 'the end of training for val data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   valid_data_iterator, model,
                                   iteration, False)

    if args.save and iteration != 0:
        save_checkpoint(iteration, model, optimizer, lr_scheduler)

    if args.do_test:
        # Run on test data.
        # NOTE(review): test results are reported at iteration 0 rather than
        # at the final training iteration — confirm this is intended.
        prefix = 'the end of training for test data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   test_data_iterator, model,
                                   0, True)
163

164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
def update_train_iters(args):
    """Derive ``args.train_iters`` for sample-based training.

    Nothing is done when ``args.train_iters`` is already truthy
    (iteration-based training). Otherwise:

    * without batch-size ramp-up, ``train_samples // global_batch_size``;
    * with ramp-up, the ramp-up phase is replayed step by step (while the
      consumed-sample count is <= ``int(rampup_batch_size[2])``) using the
      current global batch size at each step. The remaining samples are then
      divided by ``global_batch_size``, discarding any partial last batch.

    The microbatch calculator is reset to 0 consumed samples after the
    ramp-up replay.
    """
    if args.train_iters:
        return

    if args.rampup_batch_size is not None:
        steps = 0
        samples_seen = 0
        # Replay the ramp-up phase one step at a time.
        while samples_seen <= int(args.rampup_batch_size[2]):
            update_num_microbatches(samples_seen, consistency_check=False)
            samples_seen += get_current_global_batch_size()
            steps += 1
        # Reset the microbatch calculator to its starting state.
        update_num_microbatches(0, consistency_check=False)
        # Constant phase; the partial last batch is thrown away.
        remaining = args.train_samples - samples_seen
        steps += remaining // args.global_batch_size
        args.train_iters = steps
    else:
        args.train_iters = args.train_samples // args.global_batch_size

    print_rank_0('setting training iterations to {}'.format(args.train_iters))

193

Mohammad's avatar
Mohammad committed
194
def get_model(model_provider_func):
    """Build the model and wrap it for the current device, fp16 and DDP.

    Arguments:
        model_provider_func: callable returning either a single module or a
            list of modules, built on cpu.

    Returns:
        A list of modules. Each one is moved to the current CUDA device,
        wrapped in FP16Module when ``args.fp16`` is set, and then wrapped in
        torch's DistributedDataParallel (``args.DDP_impl == 'torch'``) or
        Megatron's local DDP (``args.DDP_impl == 'local'``).

    Raises:
        NotImplementedError: if ``args.DDP_impl`` is any other value.
    """
    args = get_args()

    # Build on cpu and normalize to a list.
    built = model_provider_func()
    modules = built if isinstance(built, list) else [built]

    # Only params that are already tensor-model-parallel carry the TP
    # attributes; give every other param the defaults so the optimizer can
    # rely on them being present.
    for module in modules:
        for param in module.parameters():
            mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    # Report the parameter count once per model-parallel group.
    if mpu.get_data_parallel_rank() == 0:
        tp_rank = mpu.get_tensor_model_parallel_rank()
        pp_rank = mpu.get_pipeline_model_parallel_rank()
        num_params = sum(param.nelement()
                         for module in modules
                         for param in module.parameters())
        print(' > number of parameters on (tensor, pipeline) '
              'model parallel rank ({}, {}): {}'.format(
                  tp_rank, pp_rank, num_params), flush=True)

    # GPU allocation.
    device = torch.cuda.current_device()
    for module in modules:
        module.cuda(device)

    # Fp16 conversion.
    if args.fp16:
        modules = [FP16Module(module) for module in modules]

    ddp_impl = args.DDP_impl
    if ddp_impl == 'torch':
        return [torchDDP(module, device_ids=[device], output_device=device,
                         process_group=mpu.get_data_parallel_group())
                for module in modules]
    if ddp_impl == 'local':
        return [LocalDDP(module) for module in modules]

    raise NotImplementedError('Unknown DDP implementation specified: {}. '
                              'Exiting.'.format(args.DDP_impl))
240
241


Mohammad's avatar
Mohammad committed
242
def get_learning_rate_scheduler(optimizer):
    """Create the AnnealingLR scheduler for `optimizer`.

    Warmup and decay lengths are expressed in samples:

    * iteration-based training (``args.train_iters`` set): iteration counts
      are multiplied by ``args.global_batch_size``; ``args.lr_decay_iters``
      defaults to ``args.train_iters`` and is written back to ``args``.
    * sample-based training (``args.train_samples`` set): also derives
      ``args.train_iters`` via update_train_iters; ``args.lr_decay_samples``
      defaults to ``args.train_samples`` and is written back to ``args``.

    In both modes ``args.lr_warmup_fraction``, when not None, takes
    precedence over the absolute warmup length.

    Raises:
        Exception: if neither ``train_iters`` nor ``train_samples`` is set.
    """
    args = get_args()

    if args.train_iters:
        # Iteration-based training.
        if args.lr_decay_iters is None:
            args.lr_decay_iters = args.train_iters
        decay_steps = args.lr_decay_iters * args.global_batch_size
        warmup_steps = (args.lr_warmup_fraction * decay_steps
                        if args.lr_warmup_fraction is not None
                        else args.lr_warmup_iters * args.global_batch_size)
    elif args.train_samples:
        # Sample-based training. train_iters is needed later; the training
        # samples are not adjusted for an incomplete last batch.
        update_train_iters(args)
        if args.lr_decay_samples is None:
            args.lr_decay_samples = args.train_samples
        decay_steps = args.lr_decay_samples
        warmup_steps = (args.lr_warmup_fraction * decay_steps
                        if args.lr_warmup_fraction is not None
                        else args.lr_warmup_samples)
    else:
        raise Exception(
            'either train-iters or train-samples should be provided.')

    return AnnealingLR(
        optimizer,
        max_lr=args.lr,
        min_lr=args.min_lr,
        warmup_steps=warmup_steps,
        decay_steps=decay_steps,
        decay_style=args.lr_decay_style,
        use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
        override_lr_scheduler=args.override_lr_scheduler)


Mohammad's avatar
Mohammad committed
285
def setup_model_and_optimizer(model_provider_func):
    """Build the model, its optimizer and LR scheduler; optionally load a checkpoint.

    Arguments:
        model_provider_func: forwarded to get_model.

    Returns:
        ``(model, optimizer, lr_scheduler)``, where ``model`` is the list of
        wrapped modules returned by get_model.

    Side effects:
        Sets ``args.iteration`` to the value returned by load_checkpoint
        when ``args.load`` is set, otherwise to 0.
    """
    args = get_args()

    model = get_model(model_provider_func)

    # get_model returns a list, so unwrap each element individually. An
    # isinstance() check on the list itself can never match a wrapper class.
    unwrapped_model = []
    for model_module in model:
        while isinstance(model_module, (torchDDP, LocalDDP, FP16Module)):
            model_module = model_module.module
        unwrapped_model.append(model_module)
    optimizer = get_megatron_optimizer(unwrapped_model)

    lr_scheduler = get_learning_rate_scheduler(optimizer)

    if args.load is not None:
        timers = get_timers()
        # Extra barrier is added to make sure all ranks report the
        # max time.
        torch.distributed.barrier()
        timers('load-checkpoint').start()
        args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
        torch.distributed.barrier()
        timers('load-checkpoint').stop()
        timers.log(['load-checkpoint'])
    else:
        args.iteration = 0

    # We only support local DDP with multiple micro-batches.
    if get_num_microbatches() > 1:
        assert args.DDP_impl == 'local', \
            'multiple micro-batches are only supported with local DDP'

    # Strip every `.module` wrapper (FP16, DDP) to reach the raw model. On a
    # fresh start (iteration 0), models exposing init_state_dict_from_bert
    # (ICT) are initialized from a pretrained BERT.
    for module in model:
        unwrapped_module = module
        while hasattr(unwrapped_module, 'module'):
            unwrapped_module = unwrapped_module.module

        if args.iteration == 0 and hasattr(unwrapped_module,
                                           'init_state_dict_from_bert'):
            print("Initializing ICT from pretrained BERT model", flush=True)
            unwrapped_module.init_state_dict_from_bert()

    return model, optimizer, lr_scheduler


329
330
331
332
333
334
335
def train_step(forward_step_func, data_iterator,
               model, optimizer, lr_scheduler):
    """Run one optimizer step over all micro-batches of a global batch.

    In order, it:
    1. zeroes the gradients;
    2. runs forward/backward with the schedule matching the pipeline setup;
    3. all-reduces gradients for local DDP;
    4. waits at a pipeline-group barrier (timed as the backward stall);
    5. all-reduces the shared word-embedding gradient between the first and
       last pipeline stages;
    6. steps the optimizer and, when the update succeeded, the LR scheduler.

    Returns:
        ``(loss_reduced, skipped_iter, grad_norm)``. ``loss_reduced`` maps
        each loss key to its mean over micro-batches on the last pipeline
        stage and is ``{}`` on every other stage. ``skipped_iter`` is 1 when
        the optimizer reported an unsuccessful update, else 0.
    """
    args = get_args()
    timers = get_timers()

    optimizer.zero_grad()

    # Pick the micro-batch schedule.
    pipelined = mpu.get_pipeline_model_parallel_world_size() > 1
    if not pipelined:
        forward_backward_func = forward_backward_no_pipelining
    elif args.virtual_pipeline_model_parallel_size is None:
        forward_backward_func = forward_backward_pipelining
    else:
        forward_backward_func = forward_backward_pipelining_with_interleaving

    losses_reduced = forward_backward_func(
        forward_step_func, data_iterator, model,
        optimizer, timers, forward_only=False)

    # Gradient all-reduce across data-parallel ranks (local DDP only).
    if args.DDP_impl == 'local':
        timers('backward-params-all-reduce').start()
        for chunk in model:
            chunk.allreduce_params(reduce_after=False,
                                   fp32_allreduce=args.fp32_allreduce)
        timers('backward-params-all-reduce').stop()

    # Barrier to measure backward stall.
    timers('backward-pipeline-stall').start()
    torch.distributed.barrier(group=mpu.get_pipeline_model_parallel_group())
    timers('backward-pipeline-stall').stop()

    # Sum word_embeddings' grads across the first and last stages so their
    # copies of the parameters stay in sync. Only models that support
    # pipelined model parallelism (BERT and GPT-2) reach this path.
    timers('backward-embedding-all-reduce').start()
    is_first = mpu.is_pipeline_first_stage(ignore_virtual=True)
    is_last = mpu.is_pipeline_last_stage(ignore_virtual=True)
    if pipelined and (is_first or is_last):
        owner = model[0] if is_first else model[-1]
        while isinstance(owner, (torchDDP, LocalDDP, FP16Module)):
            owner = owner.module
        if owner.share_word_embeddings:
            embedding_weight = owner.word_embeddings_weight()
            torch.distributed.all_reduce(embedding_weight.grad,
                                         group=mpu.get_embedding_group())
    timers('backward-embedding-all-reduce').stop()

    # Update parameters.
    timers('optimizer').start()
    update_successful, grad_norm = optimizer.step()
    timers('optimizer').stop()

    # Advance the LR schedule by the number of samples consumed.
    if update_successful:
        samples = (get_num_microbatches() * args.micro_batch_size *
                   args.data_parallel_size)
        lr_scheduler.step(increment=samples)
        skipped_iter = 0
    else:
        skipped_iter = 1

    if not is_last:
        return {}, skipped_iter, grad_norm

    # Average each reported loss over the micro-batches.
    loss_reduced = {}
    for key in losses_reduced[0]:
        per_microbatch = [losses[key] for losses in losses_reduced]
        loss_reduced[key] = sum(per_microbatch) / len(per_microbatch)
    return loss_reduced, skipped_iter, grad_norm
405
406


Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
407
def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                 loss_scale, report_memory_flag, skipped_iter,
                 grad_norm, params_norm):
    """Accumulate per-iteration statistics and periodically log them.

    Arguments:
        loss_dict: losses for this iteration (key -> tensor). train_step
            returns an empty dict on non-last pipeline stages.
        total_loss_dict: running accumulator, mutated in place. Besides the
            loss keys it holds the counters 'advanced iterations',
            'skipped iterations' and 'nan iterations'. Everything is printed
            and reset every ``args.log_interval`` iterations.
        learning_rate: current learning rate.
        iteration: current iteration number.
        loss_scale: current loss scale.
        report_memory_flag: if True, memory usage is reported at the next
            log interval where the learning rate is positive.
        skipped_iter: 1 if the optimizer skipped this iteration, else 0.
        grad_norm: gradient norm, or None.
        params_norm: parameter norm, or None.

    Returns:
        The updated ``report_memory_flag`` (False once memory was reported).
    """
    args = get_args()
    timers = get_timers()
    writer = get_tensorboard_writer()

    # Advanced, skipped, and Nan iterations.
    advanced_iters_key = 'advanced iterations'
    skipped_iters_key = 'skipped iterations'
    nan_iters_key = 'nan iterations'
    # Advanced iterations.
    if not skipped_iter:
        total_loss_dict[advanced_iters_key] = total_loss_dict.get(
            advanced_iters_key, 0) + 1
    else:
        total_loss_dict.setdefault(advanced_iters_key, 0)
    # Skipped iterations.
    total_loss_dict[skipped_iters_key] = total_loss_dict.get(
        skipped_iters_key, 0) + skipped_iter
    # Accumulate losses of successful iterations. For skipped iterations,
    # only record whether any loss was non-finite (inf, -inf or nan).
    got_nan = False
    for key in loss_dict:
        if not skipped_iter:
            total_loss_dict[key] = total_loss_dict.get(
                key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]
        else:
            value = loss_dict[key].float().sum().item()
            got_nan = got_nan or not math.isfinite(value)
    total_loss_dict[nan_iters_key] = total_loss_dict.get(
        nan_iters_key, 0) + int(got_nan)

    # Logging: only timers that have actually been created are reported.
    timer_names = (
        'forward-compute',
        'forward-pipeline-stall',
        'forward-recv',
        'forward-send',
        'forward-backward-send-forward-backward-recv',
        'backward-compute',
        'backward-pipeline-stall',
        'backward-recv',
        'backward-send',
        'backward-send-forward-recv',
        'backward-send-backward-recv',
        'backward-params-all-reduce',
        'backward-embedding-all-reduce',
        'optimizer-copy-to-main-grad',
        'optimizer-unscale-and-check-inf',
        'optimizer-clip-main-grad',
        'optimizer-copy-main-to-model-params',
        'optimizer',
        'batch-generator',
    )
    timers_to_log = [name for name in timer_names if name in timers.timers]

    # Calculate batch size.
    batch_size = args.micro_batch_size * args.data_parallel_size * \
        get_num_microbatches()

    # Always >= 1: every call increments either the advanced or the
    # skipped counter.
    total_iterations = total_loss_dict[advanced_iters_key] + \
                       total_loss_dict[skipped_iters_key]

    # Tensorboard values.
    if writer and (iteration % args.tensorboard_log_interval == 0) and \
       is_last_rank():
        if args.log_learning_rate_to_tensorboard:
            writer.add_scalar('learning-rate', learning_rate, iteration)
            writer.add_scalar('learning-rate vs samples', learning_rate,
                              args.consumed_train_samples)
        if args.log_batch_size_to_tensorboard:
            writer.add_scalar('batch-size', batch_size, iteration)
            writer.add_scalar('batch-size vs samples', batch_size,
                              args.consumed_train_samples)
        for key in loss_dict:
            writer.add_scalar(key, loss_dict[key], iteration)
            writer.add_scalar(key + ' vs samples', loss_dict[key],
                              args.consumed_train_samples)
        if args.log_loss_scale_to_tensorboard:
            writer.add_scalar('loss-scale', loss_scale, iteration)
            writer.add_scalar('loss-scale vs samples', loss_scale,
                              args.consumed_train_samples)
        if grad_norm is not None:
            writer.add_scalar('grad-norm', grad_norm, iteration)
            writer.add_scalar('grad-norm vs samples', grad_norm,
                              args.consumed_train_samples)
        if params_norm is not None:
            writer.add_scalar('params-norm', params_norm, iteration)
            writer.add_scalar('params-norm vs samples', params_norm,
                              args.consumed_train_samples)
        if args.log_timers_to_tensorboard:
            timers.write(timers_to_log, writer, iteration,
                         normalizer=total_iterations)

    if iteration % args.log_interval == 0:
        elapsed_time = timers('interval-time').elapsed()
        elapsed_time_per_iteration = elapsed_time / total_iterations
        # Gate on is_last_rank() like every other TensorBoard write in this
        # function. The former `get_rank() == 0` check disagreed with them
        # whenever world size > 1, so this scalar was not written by the
        # rank that logs everything else.
        if writer and is_last_rank():
            if args.log_timers_to_tensorboard:
                writer.add_scalar('iteration-time',
                                  elapsed_time_per_iteration, iteration)
        log_string = ' iteration {:8d}/{:8d} |'.format(
            iteration, args.train_iters)
        log_string += ' consumed samples: {:12d} |'.format(
            args.consumed_train_samples)
        log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
            elapsed_time_per_iteration * 1000.0)
        log_string += ' learning rate: {:.3E} |'.format(learning_rate)
        log_string += ' global batch size: {:5d} |'.format(batch_size)
        for key in total_loss_dict:
            if key not in [advanced_iters_key, skipped_iters_key,
                           nan_iters_key]:
                avg = total_loss_dict[key].item() / \
                      float(max(1, total_loss_dict[advanced_iters_key]))
                if avg > 0.0:
                    log_string += ' {}: {:.6E} |'.format(key, avg)
                total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
        log_string += ' loss scale: {:.1f} |'.format(loss_scale)
        if grad_norm is not None:
            log_string += ' grad norm: {:.3f} |'.format(grad_norm)
        if params_norm is not None:
            log_string += ' params norm: {:.3f} |'.format(params_norm)
        log_string += ' number of skipped iterations: {:3d} |'.format(
            total_loss_dict[skipped_iters_key])
        log_string += ' number of nan iterations: {:3d} |'.format(
            total_loss_dict[nan_iters_key])
        total_loss_dict[advanced_iters_key] = 0
        total_loss_dict[skipped_iters_key] = 0
        total_loss_dict[nan_iters_key] = 0
        print_rank_last(log_string)
        if report_memory_flag and learning_rate > 0.:
            # Report memory after optimizer state has been initialized.
            report_memory('(after {} iterations)'.format(iteration))
            report_memory_flag = False
        timers.log(timers_to_log, normalizer=args.log_interval)

    return report_memory_flag


553
554
555
556
557
def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):
    """Save a checkpoint and log the time it took under 'save-checkpoint'.

    Barriers before and after the save make the timer span the slowest
    rank, so every rank must call this together.
    """
    timers = get_timers()
    barrier = torch.distributed.barrier
    timer_name = 'save-checkpoint'

    # Line all ranks up first so the reported time is the maximum.
    barrier()
    timers(timer_name).start()
    save_checkpoint(iteration, model, optimizer, lr_scheduler)
    barrier()
    timers(timer_name).stop()
    timers.log([timer_name])
563
564


565
def train(forward_step_func, model, optimizer, lr_scheduler,
          train_data_iterator, valid_data_iterator):
    """Run the main training loop from ``args.iteration`` to ``args.train_iters``.

    Each iteration does the following:
    - updates the microbatch count for the consumed samples;
    - runs one train_step;
    - advances ``args.consumed_train_samples``;
    - logs via training_log.

    On the configured intervals it also checks ADLR autoresume, runs
    validation, and saves a checkpoint.

    The process exits via ``sys.exit()`` after ensuring a checkpoint for the
    current iteration exists, when either of these holds:
    - any rank exceeded ``args.exit_duration_in_mins`` since the job start;
    - the iteration is a multiple of ``args.exit_interval``.

    Returns:
        The iteration count reached.
    """
    args = get_args()
    timers = get_timers()

    # Write args to tensorboard
    write_args_to_tensorboard()

    # Turn on training mode which enables dropout.
    for chunk in model:
        chunk.train()

    total_loss_dict = {}
    iteration = args.iteration

    timers('interval-time').start()
    print_datetime('before the start of training step')
    report_memory_flag = True
    while iteration < args.train_iters:
        update_num_microbatches(args.consumed_train_samples)
        loss_dict, skipped_iter, grad_norm = train_step(
            forward_step_func, train_data_iterator, model, optimizer,
            lr_scheduler)
        iteration += 1
        step_samples = (mpu.get_data_parallel_world_size() *
                        args.micro_batch_size *
                        get_num_microbatches())
        args.consumed_train_samples += step_samples

        # Logging.
        loss_scale = optimizer.get_loss_scale().item()
        if args.log_params_norm:
            params_norm = calc_params_l2_norm(model)
        else:
            params_norm = None
        report_memory_flag = training_log(
            loss_dict, total_loss_dict, optimizer.param_groups[0]['lr'],
            iteration, loss_scale, report_memory_flag, skipped_iter,
            grad_norm, params_norm)

        # Autoresume
        if args.adlr_autoresume and \
                iteration % args.adlr_autoresume_interval == 0:
            check_adlr_autoresume_termination(iteration, model, optimizer,
                                              lr_scheduler)

        # Evaluation
        if args.eval_interval and iteration % args.eval_interval == 0 and \
                args.do_valid:
            evaluate_and_print_results('iteration {}'.format(iteration),
                                       forward_step_func,
                                       valid_data_iterator, model,
                                       iteration, False)

        # Checkpointing
        checkpoint_due = bool(args.save and args.save_interval and
                              iteration % args.save_interval == 0)
        if checkpoint_due:
            save_checkpoint_and_time(iteration, model, optimizer,
                                     lr_scheduler)
        saved_checkpoint = checkpoint_due

        # Exiting based on duration: stop when any rank ran out of time.
        if args.exit_duration_in_mins:
            train_time = (time.time() - _TRAIN_START_TIME) / 60.0
            done_cuda = torch.cuda.IntTensor(
                [train_time > args.exit_duration_in_mins])
            torch.distributed.all_reduce(
                done_cuda, op=torch.distributed.ReduceOp.MAX)
            if done_cuda.item():
                if not saved_checkpoint:
                    save_checkpoint_and_time(iteration, model, optimizer,
                                             lr_scheduler)
                print_datetime('exiting program after {} minutes'.format(train_time))
                sys.exit()

        # Exiting based on iterations
        if args.exit_interval and iteration % args.exit_interval == 0:
            if not saved_checkpoint:
                save_checkpoint_and_time(iteration, model, optimizer,
                                         lr_scheduler)
            torch.distributed.barrier()
            print_datetime('exiting program at iteration {}'.format(iteration))
            sys.exit()

    return iteration
658
659


Mohammad's avatar
Mohammad committed
660
def evaluate(forward_step_func, data_iterator, model, verbose=False):
    """Run forward-only evaluation and return the averaged losses.

    Performs ``args.eval_iters`` forward-only passes using the schedule that
    matches the current pipeline configuration. Every entry of the returned
    loss dicts is accumulated, then each total is divided by
    ``args.eval_iters * get_num_microbatches()``.

    Args:
        forward_step_func: passed unchanged to the forward/backward schedule.
        data_iterator: iterator the schedule draws evaluation batches from.
        model: iterable of model modules. Each is put in ``eval()`` mode for
            the duration of evaluation and back in ``train()`` mode afterwards
            (even if evaluation raises).
        verbose: if True, print progress every ``args.log_interval`` iterations.

    Returns:
        dict mapping loss name to averaged loss tensor. Losses are only
        accumulated on the last pipeline stage, so the dict is empty on every
        other stage.

    Side effects:
        Increments ``args.consumed_valid_samples`` once per iteration.
    """
    args = get_args()

    # The schedule depends only on the parallel configuration, which does not
    # change while evaluating, so select it once rather than every iteration.
    if mpu.get_pipeline_model_parallel_world_size() > 1:
        if args.virtual_pipeline_model_parallel_size is not None:
            forward_backward_func = forward_backward_pipelining_with_interleaving
        else:
            forward_backward_func = forward_backward_pipelining
    else:
        forward_backward_func = forward_backward_no_pipelining

    # Turn on evaluation mode which disables dropout.
    for model_module in model:
        model_module.eval()

    total_loss_dict = {}

    try:
        with torch.no_grad():
            for iteration in range(1, args.eval_iters + 1):
                if verbose and iteration % args.log_interval == 0:
                    print_rank_0('Evaluating iter {}/{}'.format(iteration,
                                                                args.eval_iters))

                loss_dicts = forward_backward_func(
                    forward_step_func, data_iterator, model, optimizer=None,
                    timers=None, forward_only=True)

                # Accumulate losses on the last pipeline stage only. No
                # cross-process reduction happens here.
                if mpu.is_pipeline_last_stage(ignore_virtual=True):
                    for loss_dict in loss_dicts:
                        for key in loss_dict:
                            if key in total_loss_dict:
                                total_loss_dict[key] = \
                                    total_loss_dict[key] + loss_dict[key]
                            else:
                                # Only allocate the zero seed tensor the first
                                # time a key is seen (dict.get would build it
                                # on every call).
                                total_loss_dict[key] = \
                                    torch.cuda.FloatTensor([0.0]) + loss_dict[key]

                args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
                                               * args.micro_batch_size \
                                               * get_num_microbatches()
    finally:
        # Move model back to the train mode, even if evaluation failed.
        for model_module in model:
            model_module.train()

    for key in total_loss_dict:
        total_loss_dict[key] /= args.eval_iters * get_num_microbatches()

    return total_loss_dict

def evaluate_and_print_results(prefix, forward_step_func,
                               data_iterator, model,
                               iteration, verbose=False):
    """Evaluate the model, print a validation summary, and log to TensorBoard.

    Args:
        prefix: label inserted into the printed summary line.
        forward_step_func, data_iterator, model, verbose: passed straight
            through to ``evaluate``.
        iteration: x-axis value for the per-iteration TensorBoard scalars.

    For each loss key, the printed line shows the averaged loss and its
    perplexity ``exp(min(20, loss))``. The loss is clamped at 20 before
    exponentiating. TensorBoard scalars are written only when a writer exists
    and this is the last rank. Perplexity scalars additionally require
    ``args.log_validation_ppl_to_tensorboard``.
    """
    args = get_args()
    writer = get_tensorboard_writer()

    total_loss_dict = evaluate(forward_step_func, data_iterator, model, verbose)

    string = ' validation loss at {} | '.format(prefix)
    for key in total_loss_dict:
        # Fetch the scalar once per key. For GPU tensors each .item() call is
        # a device-to-host copy.
        loss = total_loss_dict[key].item()
        ppl = math.exp(min(20, loss))
        string += '{} value: {:.6E} | '.format(key, loss)
        string += '{} PPL: {:.6E} | '.format(key, ppl)
        if writer and is_last_rank():
            writer.add_scalar('{} validation'.format(key),
                              loss,
                              iteration)
            writer.add_scalar('{} validation vs samples'.format(key),
                              loss,
                              args.consumed_train_samples)
            if args.log_validation_ppl_to_tensorboard:
                writer.add_scalar('{} validation ppl'.format(key), ppl,
                                  iteration)
                writer.add_scalar('{} validation ppl vs samples'.format(key),
                                  ppl, args.consumed_train_samples)

    length = len(string) + 1
    print_rank_last('-' * length)
    print_rank_last(string)
    print_rank_last('-' * length)
738
739


Vijay Korthikanti's avatar
Vijay Korthikanti committed
740
def cyclic_iter(iter):
    """Yield items from ``iter`` repeatedly, restarting it whenever it is exhausted.

    Each pass runs ``for x in iter`` again, so a re-iterable object (for
    example a DataLoader) starts a fresh pass every time. ``itertools.cycle``
    is deliberately not used: it would keep a copy of every item from the
    first pass and replay those copies.

    If a complete pass yields nothing, the generator stops instead of spinning
    forever without yielding. This covers an empty iterable, or a one-shot
    iterator that is already exhausted.

    Args:
        iter: the iterable to cycle over. The name shadows the builtin but is
            kept so keyword callers keep working.
    """
    while True:
        produced_any = False
        for x in iter:
            produced_any = True
            yield x
        if not produced_any:
            return

745
746
747
def build_train_valid_test_data_iterators(
        build_train_valid_test_datasets_provider):
    """Build the train, validation, and test data iterators.

    Datasets and dataloaders are built only on tensor-model-parallel rank 0.
    The resulting do_train / do_valid / do_test flags are broadcast from the
    tensor-model-parallel source rank to the rest of that group and stored on
    ``args``.

    Args:
        build_train_valid_test_datasets_provider: callable that takes a list
            ``[num_train_samples, num_valid_samples, num_test_samples]`` and
            returns ``(train_ds, valid_ds, test_ds)``.

    Returns:
        ``(train_data_iterator, valid_data_iterator, test_data_iterator)``.
        An entry is None whenever its dataloader is None. That is always the
        case on tensor-model-parallel ranks other than 0.
    """
    args = get_args()

    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

    print_rank_0('> building train, validation, and test datasets ...')

    # Backward compatibility: if we resume at a non-zero iteration but the
    # consumed-sample counters are still zero, they are presumably missing
    # from the checkpoint. Reconstruct them assuming a fixed global batch size.
    if args.iteration > 0 and args.consumed_train_samples == 0:
        assert args.train_samples is None, \
            'only backward compatiblity support for iteration-based training'
        args.consumed_train_samples = args.iteration * args.global_batch_size
    if args.iteration > 0 and args.consumed_valid_samples == 0:
        assert args.train_samples is None, \
            'only backward compatiblity support for iteration-based training'
        args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
            args.eval_iters * args.global_batch_size

    # Build data loaders only on rank 0 of each tensor-model-parallel group.
    if mpu.get_tensor_model_parallel_rank() == 0:

        # Number of train/valid/test samples.
        if args.train_samples:
            train_samples = args.train_samples
        else:
            train_samples = args.train_iters * args.global_batch_size
        # One validation run of eval_iters per eval_interval, plus one more.
        eval_iters = (args.train_iters // args.eval_interval + 1) * \
                     args.eval_iters
        test_iters = args.eval_iters
        train_val_test_num_samples = [train_samples,
                                      eval_iters * args.global_batch_size,
                                      test_iters * args.global_batch_size]
        print_rank_0(' > datasets target sizes (minimum size):')
        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))

        # Build the datasets.
        train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
            train_val_test_num_samples)

        # Build dataloaders, resuming train/valid from the consumed counts.
        train_dataloader = build_pretraining_data_loader(
            train_ds, args.consumed_train_samples)
        valid_dataloader = build_pretraining_data_loader(
            valid_ds, args.consumed_valid_samples)
        test_dataloader = build_pretraining_data_loader(test_ds, 0)

        # Flags to know if we need to do training/validation/testing.
        do_train = train_dataloader is not None and args.train_iters > 0
        do_valid = valid_dataloader is not None and args.eval_iters > 0
        do_test = test_dataloader is not None and args.eval_iters > 0
        # Pack the flags into a tensor so they can be broadcast.
        flags = torch.cuda.LongTensor(
            [int(do_train), int(do_valid), int(do_test)])
    else:
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Share the flags with every rank of the tensor-model-parallel group.
    torch.distributed.broadcast(flags,
                                mpu.get_tensor_model_parallel_src_rank(),
                                group=mpu.get_tensor_model_parallel_group())
    args.do_train = flags[0].item()
    args.do_valid = flags[1].item()
    args.do_test = flags[2].item()

    # Build iterators.
    dl_type = args.dataloader_type
    assert dl_type in ['single', 'cyclic']

    train_data_iterator = _build_data_iterator(train_dataloader, dl_type)
    valid_data_iterator = _build_data_iterator(valid_dataloader, dl_type)
    test_data_iterator = _build_data_iterator(test_dataloader, dl_type)

    return train_data_iterator, valid_data_iterator, test_data_iterator


def _build_data_iterator(dataloader, dl_type):
    """Wrap ``dataloader`` in an iterator according to ``dl_type``.

    'single' iterates the dataloader once. Any other value ('cyclic') wraps
    it in ``cyclic_iter`` so iteration restarts when a pass ends. Returns None
    when ``dataloader`` is None.
    """
    if dataloader is None:
        return None
    if dl_type == 'single':
        return iter(dataloader)
    return iter(cyclic_iter(dataloader))