training.py 35.3 KB
Newer Older
1
# coding=utf-8
Mohammad's avatar
Mohammad committed
2
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Mohammad's avatar
Mohammad committed
16
"""Pretrain utilities."""
17
18
19

from datetime import datetime
import math
Mohammad's avatar
Mohammad committed
20
import sys
21
22
23
24
import time
# The earliest we can measure the start time: this is recorded at import
# time, before the heavy `torch`/`megatron` imports below, so that
# `pretrain` can report the full time spent initializing Megatron (it
# replaces this value with its minimum across all ranks).
_TRAIN_START_TIME = time.time()

25
26
27
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

Neel Kant's avatar
Neel Kant committed
28
from megatron import get_args
Mohammad's avatar
Mohammad committed
29
30
from megatron import get_timers
from megatron import get_tensorboard_writer
31
from megatron import get_current_global_batch_size
mohammad's avatar
mohammad committed
32
from megatron import get_num_microbatches
mohammad's avatar
mohammad committed
33
from megatron import is_last_rank
mohammad's avatar
mohammad committed
34
from megatron import update_num_microbatches
35
from megatron import mpu
Neel Kant's avatar
Neel Kant committed
36
from megatron import print_rank_0
37
from megatron import print_rank_last
Mohammad's avatar
Mohammad committed
38
39
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
40
from megatron.model import Float16Module
mohammad's avatar
mohammad committed
41
from megatron.optimizer import get_megatron_optimizer
Mohammad's avatar
Mohammad committed
42
from megatron.initialize import initialize_megatron
43
from megatron.initialize import write_args_to_tensorboard
44
45
46
from megatron.learning_rates import AnnealingLR
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.utils import check_adlr_autoresume_termination
47
from megatron.utils import unwrap_model
Vijay Korthikanti's avatar
Vijay Korthikanti committed
48
from megatron.data.data_samplers import build_pretraining_data_loader
mohammad's avatar
mohammad committed
49
from megatron.utils import calc_params_l2_norm
50
from megatron.schedules import get_forward_backward_func
51
from megatron.utils import report_memory
52
53


Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
54

55
56
57
58
59
60
61
def print_datetime(string):
    """Print `string` tagged with the current local date and time on rank 0.

    Note that this call will sync across all ranks: it begins with
    `torch.distributed.barrier()`, so every rank must call it.
    """
    torch.distributed.barrier()
    stamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    # Plain concatenation (not str.format on the caller's text) so braces in
    # `string` are printed verbatim.
    print_rank_0('[' + string + '] datetime: ' + stamp + ' ')


62
def pretrain(train_valid_test_dataset_provider,
             model_provider,
             forward_step_func,
             extra_args_provider=None,
             args_defaults=None):
    """Main training program.

    This function runs the following steps in order:
        1) initialize Megatron.
        2) set up the model, optimizer and lr schedule using the
           model_provider.
        3) call train_valid_test_dataset_provider to get the
           train/val/test datasets.
        4) train the model using the forward_step_func.

    Arguments:
        train_valid_test_dataset_provider: a function that takes the size of
            train/valid/test dataset and returns `train, valid, test` datasets.
        model_provider: a function that returns a vanilla version of the
            model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
        forward_step_func: a function that takes a `data iterator` and `model`,
            and returns a `loss` scalar with a dictionary with key:values being
            the info we would like to monitor during training, for example
            `lm-loss: value`. We also require that this function add
            `batch generator` to the timers class.
        extra_args_provider: a function that takes a parser and adds arguments
            to it. It is used for programs to add their own arguments.
        args_defaults: a dictionary from argument-name to argument-value. It
            is passed to `initialize_megatron` to set already-parsed
            arguments. Defaults to an empty dict; a fresh dict is created on
            every call so no state is shared between calls.
    """
    # Avoid a shared mutable default: the dict is handed to
    # initialize_megatron, which could otherwise leak changes across calls.
    if args_defaults is None:
        args_defaults = {}

    # Initialize and get arguments, timers, and Tensorboard writer.
    initialize_megatron(extra_args_provider=extra_args_provider,
                        args_defaults=args_defaults)

    # Use the earliest start time across all ranks (MIN all-reduce) so the
    # reported initialization time is the largest one; this is closer to
    # what the job scheduler sees.
    global _TRAIN_START_TIME
    start_time_tensor = torch.cuda.DoubleTensor([_TRAIN_START_TIME])
    torch.distributed.all_reduce(start_time_tensor,
                                 op=torch.distributed.ReduceOp.MIN)
    _TRAIN_START_TIME = start_time_tensor.item()
    print_rank_0('time to initialize megatron (seconds): {:.3f}'.format(
        time.time() - _TRAIN_START_TIME))
    print_datetime('after megatron is initialized')

    args = get_args()
    timers = get_timers()

    # Model, optimizer, and learning rate.
    timers('model-and-optimizer-setup').start()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
    timers('model-and-optimizer-setup').stop()
    print_datetime('after model, optimizer, and learning rate '
                   'scheduler are built')

    # Data stuff.
    timers('train/valid/test-data-iterators-setup').start()
    if args.virtual_pipeline_model_parallel_size is not None:
        # One set of iterators per virtual pipeline model chunk.
        all_data_iterators = [
            build_train_valid_test_data_iterators(train_valid_test_dataset_provider)
            for _ in range(len(model))
        ]
        train_data_iterator = [data_iterators[0] for data_iterators in all_data_iterators]
        valid_data_iterator = [data_iterators[1] for data_iterators in all_data_iterators]
        test_data_iterator = [data_iterators[2] for data_iterators in all_data_iterators]
    else:
        train_data_iterator, valid_data_iterator, test_data_iterator \
            = build_train_valid_test_data_iterators(
                train_valid_test_dataset_provider)
    timers('train/valid/test-data-iterators-setup').stop()
    print_datetime('after dataloaders are built')

    # Print setup timing.
    print_rank_0('done with setup ...')
    timers.log(['model-and-optimizer-setup', 'train/valid/test-data-iterators-setup'])
    print_rank_0('training ...')

    iteration = 0
    if args.do_train and args.train_iters > 0:
        iteration = train(forward_step_func,
                          model, optimizer, lr_scheduler,
                          train_data_iterator, valid_data_iterator)
    print_datetime('after training is done')

    if args.do_valid:
        prefix = 'the end of training for val data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   valid_data_iterator, model,
                                   iteration, False)

    if args.save and iteration != 0:
        save_checkpoint(iteration, model, optimizer, lr_scheduler)

    if args.do_test:
        # Run on test data.
        prefix = 'the end of training for test data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   test_data_iterator, model,
                                   0, True)
161

162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
def update_train_iters(args):
    """Derive `args.train_iters` for sample-based training.

    Returns immediately when `args.train_iters` is already set. Otherwise:
      * with no batch-size rampup, the count is
        `train_samples // global_batch_size`;
      * with rampup, the rampup phase (bounded by the sample count in
        `args.rampup_batch_size[2]`) is replayed step by step through the
        microbatch calculator, which is then reset, and the remaining samples
        are divided by `args.global_batch_size`.
    A trailing partial batch is dropped in both cases.
    """
    if args.train_iters:
        # Iteration-based training: nothing to derive.
        return

    if args.rampup_batch_size is not None:
        steps = 0
        samples_seen = 0
        # Replay the rampup phase to count how many steps it takes.
        while samples_seen <= int(args.rampup_batch_size[2]):
            update_num_microbatches(samples_seen, consistency_check=False)
            samples_seen += get_current_global_batch_size()
            steps += 1
        # Put the microbatch calculator back to its starting state.
        update_num_microbatches(0, consistency_check=False)
        # Constant phase; any partial last batch is thrown away.
        remaining = args.train_samples - samples_seen
        args.train_iters = steps + remaining // args.global_batch_size
    else:
        # Constant batch size with sample-based training.
        args.train_iters = args.train_samples // args.global_batch_size

    print_rank_0('setting training iterations to {}'.format(args.train_iters))

191

192
def get_model(model_provider_func, wrap_with_ddp=True):
    """Build the model.

    Calls `model_provider_func(pre_process=..., post_process=...)` once, or
    once per virtual pipeline rank when both the pipeline world size is > 1
    and `args.virtual_pipeline_model_parallel_size` is set. Every resulting
    chunk then gets default tensor-model-parallel attributes on its params,
    is moved to the current CUDA device, is wrapped in `Float16Module` for
    fp16/bf16, and, when `wrap_with_ddp` is true, in the DDP implementation
    chosen by `args.DDP_impl`.

    Returns:
        A list of model chunks.

    Raises:
        NotImplementedError: if `args.DDP_impl` is neither 'torch' nor 'local'.
    """
    args = get_args()

    # Build model.
    interleaved = (mpu.get_pipeline_model_parallel_world_size() > 1
                   and args.virtual_pipeline_model_parallel_size is not None)
    if interleaved:
        chunks = []
        for virtual_rank in range(args.virtual_pipeline_model_parallel_size):
            mpu.set_virtual_pipeline_model_parallel_rank(virtual_rank)
            # The stage flags are queried only after the virtual rank is set.
            chunks.append(model_provider_func(
                pre_process=mpu.is_pipeline_first_stage(),
                post_process=mpu.is_pipeline_last_stage()))
    else:
        chunks = model_provider_func(
            pre_process=mpu.is_pipeline_first_stage(),
            post_process=mpu.is_pipeline_last_stage())

    if not isinstance(chunks, list):
        chunks = [chunks]

    # Only parameters that are already tensor model parallel carry these
    # attributes; give every other param the defaults so the optimizer can
    # rely on them.
    for chunk in chunks:
        for param in chunk.parameters():
            mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    # Print number of parameters.
    if mpu.get_data_parallel_rank() == 0:
        num_params = sum(p.nelement()
                         for chunk in chunks
                         for p in chunk.parameters())
        print(' > number of parameters on (tensor, pipeline) '
              'model parallel rank ({}, {}): {}'.format(
                  mpu.get_tensor_model_parallel_rank(),
                  mpu.get_pipeline_model_parallel_rank(),
                  num_params), flush=True)

    # GPU allocation.
    for chunk in chunks:
        chunk.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16 or args.bf16:
        chunks = [Float16Module(chunk, args) for chunk in chunks]

    if not wrap_with_ddp:
        return chunks

    if args.DDP_impl == 'torch':
        device = torch.cuda.current_device()
        return [torchDDP(chunk, device_ids=[device], output_device=device,
                         process_group=mpu.get_data_parallel_group())
                for chunk in chunks]

    if args.DDP_impl == 'local':
        return [LocalDDP(chunk,
                         args.accumulate_allreduce_grads_in_fp32,
                         args.use_contiguous_buffers_in_local_ddp)
                for chunk in chunks]

    raise NotImplementedError('Unknown DDP implementation specified: '
                              '{}. Exiting.'.format(args.DDP_impl))
264
265


Mohammad's avatar
Mohammad committed
266
def get_learning_rate_scheduler(optimizer):
    """Build the learning rate scheduler.

    Warmup and decay lengths handed to `AnnealingLR` are measured in samples.
    For iteration-based training, iteration counts are multiplied by
    `args.global_batch_size`. For sample-based training, sample counts are
    used directly. When `args.lr_warmup_fraction` is set, warmup is that
    fraction of the decay length and takes precedence over the explicit
    warmup setting.

    Side effects: may fill in `args.lr_decay_iters` / `args.lr_decay_samples`
    and, for sample-based training, `args.train_iters`.

    Raises:
        Exception: if neither `args.train_iters` nor `args.train_samples` is
            set.
    """
    args = get_args()

    if not (args.train_iters or args.train_samples):
        raise Exception(
            'either train-iters or train-samples should be provided.')

    if args.train_iters:
        # Iteration-based training.
        if args.lr_decay_iters is None:
            args.lr_decay_iters = args.train_iters
        decay_steps = args.lr_decay_iters * args.global_batch_size
        if args.lr_warmup_fraction is None:
            warmup_steps = args.lr_warmup_iters * args.global_batch_size
        else:
            warmup_steps = args.lr_warmup_fraction * decay_steps
    else:
        # Sample-based training. Training iters are still needed later, so
        # derive them here (the incomplete last batch is not accounted for
        # in the sample count).
        update_train_iters(args)
        if args.lr_decay_samples is None:
            args.lr_decay_samples = args.train_samples
        decay_steps = args.lr_decay_samples
        if args.lr_warmup_fraction is None:
            warmup_steps = args.lr_warmup_samples
        else:
            warmup_steps = args.lr_warmup_fraction * decay_steps

    return AnnealingLR(
        optimizer,
        max_lr=args.lr,
        min_lr=args.min_lr,
        warmup_steps=warmup_steps,
        decay_steps=decay_steps,
        decay_style=args.lr_decay_style,
        use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
        override_lr_scheduler=args.override_lr_scheduler)


Mohammad's avatar
Mohammad committed
309
def setup_model_and_optimizer(model_provider_func):
    """Setup model and optimizer.

    Builds the wrapped model chunks with `get_model`, creates the optimizer
    over the unwrapped modules, and builds the learning-rate scheduler. When
    `args.load` is set, a checkpoint is loaded and `args.iteration` is set
    from it; otherwise `args.iteration` is 0.

    Returns:
        (model, optimizer, lr_scheduler), where `model` is the list of
        wrapped model chunks returned by `get_model`.
    """
    args = get_args()

    model = get_model(model_provider_func)

    # The optimizer is built over the raw modules, not the DDP/Float16
    # wrappers.
    unwrapped_model = unwrap_model(model,
                                   (torchDDP, LocalDDP, Float16Module))
    optimizer = get_megatron_optimizer(unwrapped_model)

    lr_scheduler = get_learning_rate_scheduler(optimizer)

    if args.load is not None:
        timers = get_timers()
        # Extra barrier is added to make sure all ranks report the
        # max time.
        torch.distributed.barrier()
        timers('load-checkpoint').start()
        args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
        torch.distributed.barrier()
        timers('load-checkpoint').stop()
        timers.log(['load-checkpoint'])
    else:
        args.iteration = 0

    # We only support local DDP with multiple micro-batches.
    if len(model) > 1 or mpu.get_pipeline_model_parallel_world_size() > 1:
        assert args.DDP_impl == 'local'

    # Fresh ICT runs start from a pretrained BERT (unwrapped model only).
    if args.iteration == 0 and len(unwrapped_model) == 1 \
        and hasattr(unwrapped_model[0], 'init_state_dict_from_bert'):
        print_rank_0("Initializing ICT from pretrained BERT model")
        unwrapped_model[0].init_state_dict_from_bert()
        # The model params were just changed outside the optimizer, so the
        # optimizer must refresh its view of them. This applies to every
        # half-precision run -- fp16 *and* bf16, the same condition under
        # which get_model wraps chunks in Float16Module. NOTE(review): relies
        # on the megatron optimizer keeping separate main-param copies in
        # both modes; confirm in megatron.optimizer.
        if args.fp16 or args.bf16:
            optimizer.reload_model_params()

    return model, optimizer, lr_scheduler


349
350
351
352
353
354
355
def train_step(forward_step_func, data_iterator,
               model, optimizer, lr_scheduler):
    """Single training step.

    The step proceeds as follows:
      1. Zero gradients.
      2. Run the forward/backward schedule returned by
         `get_forward_backward_func()`.
      3. All-reduce gradients for local DDP.
      4. All-reduce the shared word-embedding gradient between the first and
         last pipeline stages.
      5. Step the optimizer and, only if the update succeeded, the
         learning-rate scheduler.

    Returns:
        (loss_dict, skipped_iter, grad_norm, num_zeros_in_grad).
        `loss_dict` maps each loss key to its mean over microbatches on the
        last pipeline stage and is empty on all other stages. `skipped_iter`
        is 0 when the optimizer update succeeded and 1 otherwise.
    """
    args = get_args()
    timers = get_timers()
    local_ddp = args.DDP_impl == 'local'

    # Set grad to zero.
    if local_ddp and args.use_contiguous_buffers_in_local_ddp:
        for chunk in model:
            chunk.zero_grad_buffer()
    optimizer.zero_grad()

    # Forward and backward passes over all microbatches.
    run_schedule = get_forward_backward_func()
    microbatch_losses = run_schedule(
        forward_step_func, data_iterator, model,
        optimizer, timers, forward_only=False)

    # Empty unused memory
    if args.empty_unused_memory_level >= 1:
        torch.cuda.empty_cache()

    # All-reduce if needed.
    if local_ddp:
        timers('backward-params-all-reduce').start()
        for chunk in model:
            chunk.allreduce_gradients()
        timers('backward-params-all-reduce').stop()

    # Keep word_embeddings in sync by all-reducing their gradient across the
    # first and last pipeline stages. Only models that support pipelined
    # model parallelism (BERT and GPT-2) take this path.
    timers('backward-embedding-all-reduce').start()
    on_first_stage = mpu.is_pipeline_first_stage(ignore_virtual=True)
    on_last_stage = mpu.is_pipeline_last_stage(ignore_virtual=True)
    if (on_first_stage or on_last_stage) and \
            mpu.get_pipeline_model_parallel_world_size() > 1:
        stage_chunk = model[0] if on_first_stage else model[-1]
        stage_chunk = unwrap_model(
            stage_chunk, (torchDDP, LocalDDP, Float16Module))
        if stage_chunk.share_word_embeddings:
            embedding_weight = stage_chunk.word_embeddings_weight()
            embedding_grad = (embedding_weight.main_grad if local_ddp
                              else embedding_weight.grad)
            torch.distributed.all_reduce(embedding_grad,
                                         group=mpu.get_embedding_group())
    timers('backward-embedding-all-reduce').stop()

    # Update parameters.
    timers('optimizer').start()
    update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
    timers('optimizer').stop()

    # Update learning rate (by the number of samples consumed this step).
    if update_successful:
        samples_this_step = get_num_microbatches() * \
                            args.micro_batch_size * \
                            args.data_parallel_size
        lr_scheduler.step(increment=samples_this_step)
    skipped_iter = 0 if update_successful else 1

    # Empty unused memory
    if args.empty_unused_memory_level >= 2:
        torch.cuda.empty_cache()

    loss_reduced = {}
    if mpu.is_pipeline_last_stage(ignore_virtual=True):
        # Average loss across microbatches.
        num_microbatches = len(microbatch_losses)
        loss_reduced = {
            key: sum(losses[key] for losses in microbatch_losses) / num_microbatches
            for key in microbatch_losses[0]
        }
    return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad
428
429


Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
430
def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                 loss_scale, report_memory_flag, skipped_iter,
                 grad_norm, params_norm, num_zeros_in_grad):
    """Log training information such as losses, timing, ....

    `total_loss_dict` is mutated in place:
      * it accumulates the losses of advanced (non-skipped) iterations;
      * it tracks counters of advanced, skipped and nan iterations;
      * all of these are reset every `args.log_interval` iterations, after
        the summary line is printed.
    TensorBoard scalars are written every `args.tensorboard_log_interval`
    iterations, on the last rank only.

    Returns:
        The updated `report_memory_flag` (cleared once memory usage has been
        reported).
    """
    args = get_args()
    timers = get_timers()
    writer = get_tensorboard_writer()

    # Advanced, skipped, and Nan iterations.
    advanced_iters_key = 'advanced iterations'
    skipped_iters_key = 'skipped iterations'
    nan_iters_key = 'nan iterations'
    # Advanced iterations.
    if not skipped_iter:
        total_loss_dict[advanced_iters_key] = total_loss_dict.get(
            advanced_iters_key, 0) + 1
    else:
        total_loss_dict.setdefault(advanced_iters_key, 0)
    # Skipped iterations.
    total_loss_dict[skipped_iters_key] = total_loss_dict.get(
        skipped_iters_key, 0) + skipped_iter
    # Update losses and set nan iterations: losses of advanced iterations are
    # accumulated; a skipped iteration counts as "nan" if any of its losses
    # is +/-inf or NaN.
    got_nan = False
    for key in loss_dict:
        if not skipped_iter:
            total_loss_dict[key] = total_loss_dict.get(
                key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]
        else:
            value = loss_dict[key].float().sum().item()
            got_nan = got_nan or not math.isfinite(value)
    total_loss_dict[nan_iters_key] = total_loss_dict.get(
        nan_iters_key, 0) + int(got_nan)

    # Logging: only timers that have actually been created are reported.
    candidate_timers = (
        'forward-compute',
        'forward-recv',
        'forward-send',
        'forward-backward-send-forward-backward-recv',
        'backward-compute',
        'backward-recv',
        'backward-send',
        'backward-send-forward-recv',
        'backward-send-backward-recv',
        'backward-params-all-reduce',
        'backward-embedding-all-reduce',
        'optimizer-copy-to-main-grad',
        'optimizer-unscale-and-check-inf',
        'optimizer-clip-main-grad',
        'optimizer-copy-main-to-model-params',
        'optimizer',
        'batch-generator',
    )
    timers_to_log = [name for name in candidate_timers
                     if name in timers.timers]

    # Calculate batch size.
    batch_size = args.micro_batch_size * args.data_parallel_size * \
        get_num_microbatches()

    # Iterations accumulated since the counters were last reset.
    total_iterations = total_loss_dict[advanced_iters_key] + \
                       total_loss_dict[skipped_iters_key]

    # Tensorboard values.
    if writer and (iteration % args.tensorboard_log_interval == 0) and \
       is_last_rank():
        if args.log_learning_rate_to_tensorboard:
            writer.add_scalar('learning-rate', learning_rate, iteration)
            writer.add_scalar('learning-rate vs samples', learning_rate,
                              args.consumed_train_samples)
        if args.log_batch_size_to_tensorboard:
            writer.add_scalar('batch-size', batch_size, iteration)
            writer.add_scalar('batch-size vs samples', batch_size,
                              args.consumed_train_samples)
        for key in loss_dict:
            writer.add_scalar(key, loss_dict[key], iteration)
            writer.add_scalar(key + ' vs samples', loss_dict[key],
                              args.consumed_train_samples)
        if args.log_loss_scale_to_tensorboard:
            writer.add_scalar('loss-scale', loss_scale, iteration)
            writer.add_scalar('loss-scale vs samples', loss_scale,
                              args.consumed_train_samples)
        if grad_norm is not None:
            writer.add_scalar('grad-norm', grad_norm, iteration)
            writer.add_scalar('grad-norm vs samples', grad_norm,
                              args.consumed_train_samples)
        if num_zeros_in_grad is not None:
            writer.add_scalar('num-zeros', num_zeros_in_grad, iteration)
            writer.add_scalar('num-zeros vs samples', num_zeros_in_grad,
                              args.consumed_train_samples)
        if params_norm is not None:
            writer.add_scalar('params-norm', params_norm, iteration)
            writer.add_scalar('params-norm vs samples', params_norm,
                              args.consumed_train_samples)
        if args.log_timers_to_tensorboard:
            timers.write(timers_to_log, writer, iteration,
                         normalizer=total_iterations)
        if args.log_memory_to_tensorboard:
            mem_stats = torch.cuda.memory_stats()
            writer.add_scalar(
                "mem-reserved-bytes",
                mem_stats["reserved_bytes.all.current"],
                iteration,
            )
            writer.add_scalar(
                "mem-allocated-bytes",
                mem_stats["allocated_bytes.all.current"],
                iteration,
            )
            writer.add_scalar(
                "mem-allocated-count",
                mem_stats["allocation.all.current"],
                iteration,
            )

    if iteration % args.log_interval == 0:
        elapsed_time = timers('interval-time').elapsed()
        elapsed_time_per_iteration = elapsed_time / total_iterations
        if writer:
            if args.log_timers_to_tensorboard:
                writer.add_scalar('iteration-time',
                                  elapsed_time_per_iteration, iteration)
        log_string = ' iteration {:8d}/{:8d} |'.format(
            iteration, args.train_iters)
        log_string += ' consumed samples: {:12d} |'.format(
            args.consumed_train_samples)
        log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
            elapsed_time_per_iteration * 1000.0)
        log_string += ' learning rate: {:.3E} |'.format(learning_rate)
        log_string += ' global batch size: {:5d} |'.format(batch_size)
        for key in total_loss_dict:
            if key not in [advanced_iters_key, skipped_iters_key,
                           nan_iters_key]:
                avg = total_loss_dict[key].item() / \
                      float(max(1, total_loss_dict[advanced_iters_key]))
                if avg > 0.0:
                    log_string += ' {}: {:.6E} |'.format(key, avg)
                total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
        log_string += ' loss scale: {:.1f} |'.format(loss_scale)
        if grad_norm is not None:
            log_string += ' grad norm: {:.3f} |'.format(grad_norm)
        if num_zeros_in_grad is not None:
            log_string += ' num zeros: {:.1f} |'.format(num_zeros_in_grad)
        if params_norm is not None:
            log_string += ' params norm: {:.3f} |'.format(params_norm)
        log_string += ' number of skipped iterations: {:3d} |'.format(
            total_loss_dict[skipped_iters_key])
        log_string += ' number of nan iterations: {:3d} |'.format(
            total_loss_dict[nan_iters_key])
        total_loss_dict[advanced_iters_key] = 0
        total_loss_dict[skipped_iters_key] = 0
        total_loss_dict[nan_iters_key] = 0
        print_rank_last(log_string)
        if report_memory_flag and learning_rate > 0.:
            # Report memory after optimizer state has been initialized.
            report_memory('(after {} iterations)'.format(iteration))
            report_memory_flag = False
        # Normalize by the iterations actually accumulated since the last
        # log, the same divisor used for elapsed_time_per_iteration and the
        # TensorBoard timers. It differs from args.log_interval e.g. on the
        # first interval after resuming at an iteration that is not a
        # multiple of log_interval.
        timers.log(timers_to_log, normalizer=total_iterations)

    return report_memory_flag


597
598
599
600
601
def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):
    """Save a checkpoint and log the time it took under 'save-checkpoint'.

    Barriers bracket the save so that every rank starts timing together and
    the logged time reflects the slowest rank.
    """
    timer_name = 'save-checkpoint'
    all_timers = get_timers()
    synchronize = torch.distributed.barrier

    synchronize()
    all_timers(timer_name).start()
    save_checkpoint(iteration, model, optimizer, lr_scheduler)
    synchronize()
    all_timers(timer_name).stop()
    all_timers.log([timer_name])
607
608


609
def train(forward_step_func, model, optimizer, lr_scheduler,
          train_data_iterator, valid_data_iterator):
    """Train the model function.

    Runs training steps starting at ``args.iteration`` until
    ``args.train_iters`` is reached. After every step it logs progress, and
    when the corresponding interval hits it triggers ADLR autoresume, a
    validation pass and a checkpoint save. It calls ``sys.exit()`` when the
    wall-clock budget (``args.exit_duration_in_mins``) or the exit interval
    (``args.exit_interval``) is reached, saving a checkpoint first unless one
    was just saved in the same step.

    Args:
        forward_step_func: forwarded to ``train_step`` and to evaluation.
        model: iterable of model modules (each is switched to train mode).
        optimizer: optimizer; its loss scale and first param group's lr are
            logged every step.
        lr_scheduler: forwarded to ``train_step`` and checkpointing.
        train_data_iterator: source of training batches.
        valid_data_iterator: source of validation batches.

    Returns:
        The iteration count reached when the loop finishes.
    """
    args = get_args()
    timers = get_timers()

    write_args_to_tensorboard()

    # Train mode enables dropout in every model chunk.
    for chunk in model:
        chunk.train()

    # Running loss accumulator that training_log updates across steps.
    loss_accumulator = {}
    step = args.iteration

    timers('interval-time').start()
    print_datetime('before the start of training step')
    report_memory_flag = True

    while step < args.train_iters:
        update_num_microbatches(args.consumed_train_samples)
        loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = train_step(
            forward_step_func, train_data_iterator, model, optimizer,
            lr_scheduler)
        step += 1
        args.consumed_train_samples += (mpu.get_data_parallel_world_size()
                                        * args.micro_batch_size
                                        * get_num_microbatches())

        # ---- logging ----
        loss_scale = optimizer.get_loss_scale().item()
        params_norm = (calc_params_l2_norm(model) if args.log_params_norm
                       else None)
        report_memory_flag = training_log(
            loss_dict, loss_accumulator, optimizer.param_groups[0]['lr'],
            step, loss_scale, report_memory_flag, skipped_iter,
            grad_norm, params_norm, num_zeros_in_grad)

        # ---- autoresume ----
        if args.adlr_autoresume and step % args.adlr_autoresume_interval == 0:
            check_adlr_autoresume_termination(step, model, optimizer,
                                              lr_scheduler)

        # ---- periodic validation ----
        if args.eval_interval and step % args.eval_interval == 0 \
                and args.do_valid:
            evaluate_and_print_results('iteration {}'.format(step),
                                       forward_step_func,
                                       valid_data_iterator, model,
                                       step, False)

        # ---- periodic checkpoint ----
        saved_checkpoint = bool(args.save and args.save_interval
                                and step % args.save_interval == 0)
        if saved_checkpoint:
            save_checkpoint_and_time(step, model, optimizer, lr_scheduler)

        # ---- exit on wall-clock budget ----
        if args.exit_duration_in_mins:
            elapsed_mins = (time.time() - _TRAIN_START_TIME) / 60.0
            over_budget = torch.cuda.IntTensor(
                [elapsed_mins > args.exit_duration_in_mins])
            # MAX-reduce so that if any rank is over budget, all ranks stop
            # together.
            torch.distributed.all_reduce(
                over_budget, op=torch.distributed.ReduceOp.MAX)
            if over_budget.item():
                if not saved_checkpoint:
                    save_checkpoint_and_time(step, model, optimizer,
                                             lr_scheduler)
                print_datetime(
                    'exiting program after {} minutes'.format(elapsed_mins))
                sys.exit()

        # ---- exit on iteration interval ----
        if args.exit_interval and step % args.exit_interval == 0:
            if not saved_checkpoint:
                save_checkpoint_and_time(step, model, optimizer,
                                         lr_scheduler)
            torch.distributed.barrier()
            print_datetime('exiting program at iteration {}'.format(step))
            sys.exit()

    return step


def evaluate(forward_step_func, data_iterator, model, verbose=False):
    """Evaluation.

    Runs ``args.eval_iters`` forward-only passes under ``torch.no_grad()``
    with every model chunk in eval mode. Afterwards it restores train mode
    and returns a dict mapping each loss key to its sum divided by
    ``args.eval_iters * get_num_microbatches()``.

    Losses are only accumulated on the last pipeline stage. On other stages
    the returned dict is empty. ``args.consumed_valid_samples`` is advanced
    on every rank by ``data_parallel_world_size * micro_batch_size *
    num_microbatches`` per pass.
    """
    args = get_args()

    # Eval mode disables dropout.
    for chunk in model:
        chunk.eval()

    summed_losses = {}

    with torch.no_grad():
        for step in range(1, args.eval_iters + 1):
            if verbose and step % args.log_interval == 0:
                print_rank_0('Evaluating iter {}/{}'.format(step,
                                                            args.eval_iters))

            forward_backward_func = get_forward_backward_func()
            loss_dicts = forward_backward_func(
                forward_step_func, data_iterator, model, optimizer=None,
                timers=None, forward_only=True)

            if args.empty_unused_memory_level >= 1:
                torch.cuda.empty_cache()

            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                # Sum every returned loss dict into the running totals.
                for loss_dict in loss_dicts:
                    for key in loss_dict:
                        running = summed_losses.get(
                            key, torch.cuda.FloatTensor([0.0]))
                        summed_losses[key] = running + loss_dict[key]

            args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
                                           * args.micro_batch_size \
                                           * get_num_microbatches()

    # Back to train mode for the caller's training loop.
    for chunk in model:
        chunk.train()

    for key in summed_losses:
        summed_losses[key] /= args.eval_iters * get_num_microbatches()

    return summed_losses

def evaluate_and_print_results(prefix, forward_step_func,
                               data_iterator, model,
                               iteration, verbose=False):
    """Helper function to evaluate and dump results on screen.

    Runs ``evaluate``, then prints a dashed-framed one-line summary of each
    loss value and its perplexity via ``print_rank_last``. When a
    TensorBoard writer exists, it also records each loss (and, if
    ``args.log_validation_ppl_to_tensorboard`` is set, each perplexity)
    against both ``iteration`` and ``args.consumed_train_samples``.
    """
    args = get_args()
    writer = get_tensorboard_writer()

    total_loss_dict = evaluate(forward_step_func, data_iterator, model, verbose)

    summary = ' validation loss at {} | '.format(prefix)
    for key, loss in total_loss_dict.items():
        loss_value = loss.item()
        # Clamp the exponent at 20 so a diverged loss cannot overflow exp().
        ppl = math.exp(min(20, loss_value))
        summary += '{} value: {:.6E} | '.format(key, loss_value)
        summary += '{} PPL: {:.6E} | '.format(key, ppl)

        if not writer:
            continue
        writer.add_scalar('{} validation'.format(key), loss_value, iteration)
        writer.add_scalar('{} validation vs samples'.format(key),
                          loss_value, args.consumed_train_samples)
        if args.log_validation_ppl_to_tensorboard:
            writer.add_scalar('{} validation ppl'.format(key), ppl,
                              iteration)
            writer.add_scalar('{} validation ppl vs samples'.format(key),
                              ppl, args.consumed_train_samples)

    rule = '-' * (len(summary) + 1)
    print_rank_last(rule)
    print_rank_last(summary)
    print_rank_last(rule)


def cyclic_iter(iter):
    """Yield items from ``iter`` endlessly, restarting it after each pass.

    Each pass re-iterates ``iter`` from scratch, so it must be re-iterable
    (a list, or a DataLoader-like object). Unlike ``itertools.cycle``,
    nothing is cached, so large or reshuffled data sources are not held in
    memory.

    The generator stops if a complete pass yields nothing. That happens for
    an empty iterable, or for a one-shot iterator that is already exhausted.
    Previously such inputs made the generator spin forever without yielding.
    """
    # NOTE: the parameter name shadows the builtin ``iter``. It is kept for
    # compatibility with callers that pass it by keyword.
    while True:
        produced_any = False
        for item in iter:
            produced_any = True
            yield item
        if not produced_any:
            return


def _build_data_iterator(dataloader, dl_type):
    """Wrap ``dataloader`` in an iterator per ``dl_type``; None passes through.

    'single' makes one pass over the dataloader. Any other value (only
    'cyclic' is allowed by the caller) restarts it via ``cyclic_iter``.
    """
    if dataloader is None:
        return None
    if dl_type == 'single':
        return iter(dataloader)
    # cyclic_iter is already a generator; no extra iter() wrapper needed.
    return cyclic_iter(dataloader)


def build_train_valid_test_data_iterators(
        build_train_valid_test_datasets_provider):
    """Build train/validation/test data iterators and set ``args.do_*``.

    Only tensor-model-parallel rank 0 builds datasets and dataloaders. It
    decides ``do_train``/``do_valid``/``do_test`` and broadcasts them to the
    rest of its tensor-model-parallel group, which then store them on
    ``args``.

    Args:
        build_train_valid_test_datasets_provider: callable that takes
            ``[num_train_samples, num_valid_samples, num_test_samples]``
            and returns ``(train_ds, valid_ds, test_ds)``.

    Returns:
        ``(train_data_iterator, valid_data_iterator, test_data_iterator)``.
        An entry is None wherever no dataloader was built, including every
        entry on tensor-model-parallel ranks other than 0.
    """
    args = get_args()

    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

    print_rank_0('> building train, validation, and test datasets ...')

    # Backward compatibility, assume fixed batch size.
    if args.iteration > 0 and args.consumed_train_samples == 0:
        assert args.train_samples is None, \
            'only backward compatiblity support for iteration-based training'
        args.consumed_train_samples = args.iteration * args.global_batch_size
    if args.iteration > 0 and args.consumed_valid_samples == 0:
        if args.train_samples is None:
            # train() skips periodic validation when eval_interval is falsy,
            # so no validation samples were consumed in that case (and we
            # must not divide by it).
            evals_done = (args.iteration // args.eval_interval
                          if args.eval_interval else 0)
            args.consumed_valid_samples = evals_done * \
                args.eval_iters * args.global_batch_size

    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_tensor_model_parallel_rank() == 0:

        # Number of train/valid/test samples.
        if args.train_samples:
            train_samples = args.train_samples
        else:
            train_samples = args.train_iters * args.global_batch_size
        # One validation run per eval_interval during training, plus one
        # extra run (presumably the post-training validation; confirm).
        periodic_evals = (args.train_iters // args.eval_interval
                          if args.eval_interval else 0)
        eval_iters = (periodic_evals + 1) * args.eval_iters
        test_iters = args.eval_iters
        train_val_test_num_samples = [train_samples,
                                      eval_iters * args.global_batch_size,
                                      test_iters * args.global_batch_size]
        print_rank_0(' > datasets target sizes (minimum size):')
        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))

        # Build the datasets.
        train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
            train_val_test_num_samples)

        # Build dataloaders. Train/valid receive the number of samples
        # already consumed (presumably so sampling resumes from there).
        train_dataloader = build_pretraining_data_loader(
            train_ds, args.consumed_train_samples)
        valid_dataloader = build_pretraining_data_loader(
            valid_ds, args.consumed_valid_samples)
        test_dataloader = build_pretraining_data_loader(test_ds, 0)

        # Flags to know if we need to do training/validation/testing.
        do_train = train_dataloader is not None and args.train_iters > 0
        do_valid = valid_dataloader is not None and args.eval_iters > 0
        do_test = test_dataloader is not None and args.eval_iters > 0
        # Packed into a tensor so they can be broadcast below.
        flags = torch.cuda.LongTensor(
            [int(do_train), int(do_valid), int(do_test)])
    else:
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Broadcast rank 0's do_train/do_valid/do_test flags to the rest of the
    # tensor-model-parallel group.
    torch.distributed.broadcast(flags,
                                mpu.get_tensor_model_parallel_src_rank(),
                                group=mpu.get_tensor_model_parallel_group())
    args.do_train = flags[0].item()
    args.do_valid = flags[1].item()
    args.do_test = flags[2].item()

    # Build iterators.
    dl_type = args.dataloader_type
    assert dl_type in ['single', 'cyclic']

    train_data_iterator = _build_data_iterator(train_dataloader, dl_type)
    valid_data_iterator = _build_data_iterator(valid_dataloader, dl_type)
    test_data_iterator = _build_data_iterator(test_dataloader, dl_type)

    return train_data_iterator, valid_data_iterator, test_data_iterator