training.py 39.3 KB
Newer Older
Jared Casper's avatar
Jared Casper committed
1
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2

Mohammad's avatar
Mohammad committed
3
"""Pretrain utilities."""
4
5
6

from datetime import datetime
import math
Mohammad's avatar
Mohammad committed
7
import sys
8
9
10
import time
# The earliest we can measure the start time.
_TRAIN_START_TIME = time.time()
11
12
13
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

Neel Kant's avatar
Neel Kant committed
14
from megatron import get_args
15
from megatron import get_signal_handler
Mohammad's avatar
Mohammad committed
16
17
from megatron import get_timers
from megatron import get_tensorboard_writer
18
from megatron import get_current_global_batch_size
mohammad's avatar
mohammad committed
19
from megatron import get_num_microbatches
mohammad's avatar
mohammad committed
20
from megatron import is_last_rank
mohammad's avatar
mohammad committed
21
from megatron import update_num_microbatches
22
from megatron import mpu
23
from megatron import core
Neel Kant's avatar
Neel Kant committed
24
from megatron import print_rank_0
25
from megatron import print_rank_last
Mohammad's avatar
Mohammad committed
26
27
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
28
from megatron.model import Float16Module
29
from megatron.model import ModelType
mohammad's avatar
mohammad committed
30
from megatron.optimizer import get_megatron_optimizer
Mohammad's avatar
Mohammad committed
31
from megatron.initialize import initialize_megatron
32
from megatron.initialize import write_args_to_tensorboard
33
from megatron.initialize import set_jit_fusion_options
34
from megatron.optimizer_param_scheduler import OptimizerParamScheduler
35
36
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.utils import check_adlr_autoresume_termination
37
from megatron.utils import unwrap_model
Vijay Korthikanti's avatar
Vijay Korthikanti committed
38
from megatron.data.data_samplers import build_pretraining_data_loader
mohammad's avatar
mohammad committed
39
from megatron.utils import calc_params_l2_norm
40
from megatron.schedules import get_forward_backward_func
41
from megatron.utils import report_memory
42
from megatron.model.vision.knn_monitor import compute_feature_bank
43

Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
44

45
46
47
48
49
50
51
def print_datetime(string):
    """Print ``string`` followed by the current local date and time (rank 0 only).

    All ranks must call this function, because it starts with a
    ``torch.distributed.barrier()`` that synchronizes them.
    """
    torch.distributed.barrier()
    stamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    suffix = '] datetime: {} '.format(stamp)
    print_rank_0('[' + string + suffix)


52
def pretrain(train_valid_test_dataset_provider,
             model_provider,
             model_type,
             forward_step_func,
             process_non_loss_data_func=None,
             extra_args_provider=None,
             args_defaults=None):
    """Main training program.

    This function runs the following steps in order:
        1) initialize Megatron.
        2) set up the model, optimizer and lr schedule using model_provider.
        3) call train_valid_test_dataset_provider to get train/val/test
           datasets.
        4) train the model using forward_step_func.
    Afterwards it optionally evaluates on the validation data
    (``args.do_valid``), saves a checkpoint (``args.save`` and at least one
    iteration was trained) and evaluates on the test data (``args.do_test``).

    Arguments:
        train_valid_test_dataset_provider: a function that takes the size of
            train/valid/test dataset and returns `train, valid, test` datasets.
        model_provider: a function that returns a vanilla version of the
            model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
        model_type: an enum that specifies the type of model being trained.
        forward_step_func: a function that takes a `data iterator` and `model`,
            and returns a `loss` scalar with a dictionary with key:values being
            the info we would like to monitor during training, for example
            `lm-loss: value`. We also require that this function add
            `batch generator` to the timers class.
        process_non_loss_data_func: a function to post process outputs of the
            network. It can be used for dumping output tensors (e.g images) to
            tensorboard. It takes `collected data`(list of tensors),
            `current iteration index` and `tensorboard writer` as arguments.
        extra_args_provider: a function that takes a parser and adds arguments
            to it. It is used for programs to add their own arguments.
        args_defaults: a dictionary from argument-name to argument-value,
            forwarded to ``initialize_megatron``. It is used to set already
            parsed arguments. ``None`` (the default) means an empty dict.
    """
    # Build a fresh dict per call instead of sharing one mutable default
    # object across every call of this function.
    if args_defaults is None:
        args_defaults = {}

    # Initialize and get arguments, timers, and Tensorboard writer.
    initialize_megatron(extra_args_provider=extra_args_provider,
                        args_defaults=args_defaults)
    # Set pytorch JIT layer fusion options and warmup JIT functions.
    set_jit_fusion_options()

    # Take the earliest start time over all ranks (MIN all-reduce) so the
    # reported initialization time is the largest one. This will be closer
    # to what the scheduler will see (outside of image ... launches).
    global _TRAIN_START_TIME
    start_time_tensor = torch.cuda.DoubleTensor([_TRAIN_START_TIME])
    torch.distributed.all_reduce(start_time_tensor,
                                 op=torch.distributed.ReduceOp.MIN)
    _TRAIN_START_TIME = start_time_tensor.item()
    print_rank_0('time to initialize megatron (seconds): {:.3f}'.format(
        time.time() - _TRAIN_START_TIME))
    print_datetime('after megatron is initialized')

    args = get_args()
    timers = get_timers()

    # Model, optimizer, and learning rate.
    timers('model-and-optimizer-setup', log_level=0).start(barrier=True)
    model, optimizer, opt_param_scheduler = setup_model_and_optimizer(
        model_provider, model_type)
    timers('model-and-optimizer-setup').stop()
    print_datetime('after model, optimizer, and learning rate '
                   'scheduler are built')

    # Data stuff.
    timers('train/valid/test-data-iterators-setup', log_level=0).start(
        barrier=True)
    if args.virtual_pipeline_model_parallel_size is not None:
        # One (train, valid, test) iterator triple per model chunk.
        all_data_iterators = [
            build_train_valid_test_data_iterators(
                train_valid_test_dataset_provider)
            for _ in range(len(model))
        ]
        train_data_iterator = [data_iterators[0]
                               for data_iterators in all_data_iterators]
        valid_data_iterator = [data_iterators[1]
                               for data_iterators in all_data_iterators]
        test_data_iterator = [data_iterators[2]
                              for data_iterators in all_data_iterators]
    else:
        train_data_iterator, valid_data_iterator, test_data_iterator \
            = build_train_valid_test_data_iterators(
                train_valid_test_dataset_provider)
    timers('train/valid/test-data-iterators-setup').stop()
    print_datetime('after dataloaders are built')

    # Print setup timing.
    print_rank_0('done with setup ...')
    timers.log(['model-and-optimizer-setup',
                'train/valid/test-data-iterators-setup'], barrier=True)
    print_rank_0('training ...')

    iteration = 0
    if args.do_train and args.train_iters > 0:
        iteration = train(forward_step_func,
                          model, optimizer, opt_param_scheduler,
                          train_data_iterator, valid_data_iterator,
                          process_non_loss_data_func)
    print_datetime('after training is done')

    if args.do_valid:
        prefix = 'the end of training for val data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   valid_data_iterator, model,
                                   iteration, process_non_loss_data_func,
                                   False)

    # Only save when at least one training iteration actually ran.
    if args.save and iteration != 0:
        save_checkpoint(iteration, model, optimizer, opt_param_scheduler)

    if args.do_test:
        # Run on test data (reported with iteration index 0).
        prefix = 'the end of training for test data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   test_data_iterator, model,
                                   0, process_non_loss_data_func,
                                   True)
170

171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
def update_train_iters(args):
    """Derive ``args.train_iters`` for sample-based training (in place).

    If ``args.train_iters`` is already set (iteration-based training) this
    is a no-op. Otherwise the number of iterations is computed from
    ``args.train_samples``:
      * without batch-size rampup: ``train_samples // global_batch_size``;
      * with rampup: iterations are counted while stepping through the
        rampup schedule (up to ``rampup_batch_size[2]`` samples), then the
        remaining samples are divided by ``global_batch_size``. Any partial
        last batch is dropped.
    """
    # For iteration-based training, we don't need to do anything
    if args.train_iters:
        return

    if args.rampup_batch_size is None:
        # Constant batch size with sample-based training.
        args.train_iters = args.train_samples // args.global_batch_size
    else:
        # Sample based training with rampup batch size.
        iterations = 0
        consumed_samples = 0
        # Rampup phase: replay the microbatch schedule to count iterations.
        while consumed_samples <= int(args.rampup_batch_size[2]):
            update_num_microbatches(consumed_samples, consistency_check=False)
            consumed_samples += get_current_global_batch_size()
            iterations += 1
        # Reset the microbatch calculator to its initial state.
        update_num_microbatches(0, consistency_check=False)
        # Constant phase
        # Note that we throw away any partial last batch.
        iterations += (args.train_samples - consumed_samples) // \
                      args.global_batch_size
        args.train_iters = iterations

    print_rank_0('setting training iterations to {}'.format(args.train_iters))

200

201
def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True):
    """Build the model.

    Calls ``model_provider_func`` once per virtual pipeline stage (interleaved
    schedule) or once otherwise. It then sets default tensor-parallel
    attributes on every parameter and moves each chunk to the current CUDA
    device. Chunks are wrapped in ``Float16Module`` when fp16/bf16 is enabled
    and, if ``wrap_with_ddp``, in torch or local DDP according to
    ``args.DDP_impl``.

    Side effect: sets ``args.model_type`` to ``model_type``.

    Returns:
        A list of model chunks (a single-element list when no virtual
        pipeline parallelism is used).

    Raises:
        NotImplementedError: if ``wrap_with_ddp`` and ``args.DDP_impl`` is
            neither ``'torch'`` nor ``'local'``.
    """
    args = get_args()
    args.model_type = model_type

    # Build model.
    if mpu.get_pipeline_model_parallel_world_size() > 1 and \
       args.virtual_pipeline_model_parallel_size is not None:
        assert model_type != ModelType.encoder_and_decoder, \
            "Interleaved schedule not supported for model with both encoder and decoder"
        model = []
        for i in range(args.virtual_pipeline_model_parallel_size):
            mpu.set_virtual_pipeline_model_parallel_rank(i)
            # Set pre_process and post_process only after virtual rank is set.
            pre_process = mpu.is_pipeline_first_stage()
            post_process = mpu.is_pipeline_last_stage()
            this_model = model_provider_func(
                pre_process=pre_process,
                post_process=post_process
            )
            this_model.model_type = model_type
            model.append(this_model)
    else:
        pre_process = mpu.is_pipeline_first_stage()
        post_process = mpu.is_pipeline_last_stage()
        add_encoder = True
        add_decoder = True
        if model_type == ModelType.encoder_and_decoder:
            if mpu.get_pipeline_model_parallel_world_size() > 1:
                assert args.pipeline_model_parallel_split_rank is not None, \
                    "Split rank needs to be specified for model with both encoder and decoder"
                rank = mpu.get_pipeline_model_parallel_rank()
                split_rank = args.pipeline_model_parallel_split_rank
                world_size = mpu.get_pipeline_model_parallel_world_size()
                # The first stage of the encoder and of the decoder both
                # pre-process; the last stage of each post-processes.
                pre_process = rank == 0 or rank == split_rank
                post_process = (rank == (split_rank - 1)) or (
                        rank == (world_size - 1))
                add_encoder = mpu.is_pipeline_stage_before_split()
                add_decoder = mpu.is_pipeline_stage_after_split()
            model = model_provider_func(
                pre_process=pre_process,
                post_process=post_process,
                add_encoder=add_encoder,
                add_decoder=add_decoder)
        else:
            model = model_provider_func(
                pre_process=pre_process,
                post_process=post_process
            )
        model.model_type = model_type

    if not isinstance(model, list):
        model = [model]

    # Set tensor model parallel attributes if not set.
    # Only parameters that are already tensor model parallel have these
    # attributes set for them. We should make sure the default attributes
    # are set for all params so the optimizer can use them.
    for model_module in model:
        for param in model_module.parameters():
            core.tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    # Print number of parameters.
    if mpu.get_data_parallel_rank() == 0:
        num_params = sum(p.nelement()
                         for model_module in model
                         for p in model_module.parameters())
        print(' > number of parameters on (tensor, pipeline) '
              'model parallel rank ({}, {}): {}'.format(
                  mpu.get_tensor_model_parallel_rank(),
                  mpu.get_pipeline_model_parallel_rank(),
                  num_params), flush=True)

    # GPU allocation.
    for model_module in model:
        model_module.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16 or args.bf16:
        model = [Float16Module(model_module, args) for model_module in model]

    if wrap_with_ddp:
        if args.DDP_impl == 'torch':
            i = torch.cuda.current_device()
            model = [torchDDP(model_module, device_ids=[i], output_device=i,
                              process_group=mpu.get_data_parallel_group())
                     for model_module in model]

        elif args.DDP_impl == 'local':
            model = [LocalDDP(model_module,
                              args.accumulate_allreduce_grads_in_fp32,
                              args.use_contiguous_buffers_in_local_ddp)
                     for model_module in model]
            # broad cast params from data parallel src rank to other data parallel ranks
            if args.data_parallel_random_init:
                for model_module in model:
                    model_module.broadcast_params()
        else:
            raise NotImplementedError('Unknown DDP implementation specified: '
                                      '{}. Exiting.'.format(args.DDP_impl))

    return model
301
302


303
def get_optimizer_param_scheduler(optimizer):
    """Build the learning rate scheduler.

    All scheduler step counts are expressed in samples.
      * Iteration-based training (``args.train_iters`` set): iteration
        counts are converted to samples by multiplying by
        ``args.global_batch_size``.
      * Sample-based training (``args.train_samples`` set): sample counts
        are used directly, and ``args.train_iters`` is derived via
        ``update_train_iters``.
    Warmup is ``lr_warmup_fraction * lr_decay_steps`` when a fraction is
    given, otherwise the explicit warmup iters/samples.

    Side effects: fills in ``args.lr_decay_iters`` / ``args.lr_decay_samples``
    when they are ``None``.

    Raises:
        Exception: if neither ``args.train_iters`` nor ``args.train_samples``
            is set.
    """
    args = get_args()

    if args.train_iters:
        # Iteration-based training.
        if args.lr_decay_iters is None:
            args.lr_decay_iters = args.train_iters
        lr_decay_steps = args.lr_decay_iters * args.global_batch_size
        wd_incr_steps = args.train_iters * args.global_batch_size
        if args.lr_warmup_fraction is not None:
            lr_warmup_steps = args.lr_warmup_fraction * lr_decay_steps
        else:
            lr_warmup_steps = args.lr_warmup_iters * args.global_batch_size
    elif args.train_samples:
        # Sample-based training.
        # We need to set training iters for later use. Technically
        # we need to adjust the training samples too (due to last
        # batch being incomplete) but we leave it as is for now.
        update_train_iters(args)
        if args.lr_decay_samples is None:
            args.lr_decay_samples = args.train_samples
        lr_decay_steps = args.lr_decay_samples
        wd_incr_steps = args.train_samples
        if args.lr_warmup_fraction is not None:
            lr_warmup_steps = args.lr_warmup_fraction * lr_decay_steps
        else:
            lr_warmup_steps = args.lr_warmup_samples
    else:
        raise Exception(
            'either train-iters or train-samples should be provided.')

    opt_param_scheduler = OptimizerParamScheduler(
        optimizer,
        max_lr=args.lr,
        min_lr=args.min_lr,
        lr_warmup_steps=lr_warmup_steps,
        lr_decay_steps=lr_decay_steps,
        lr_decay_style=args.lr_decay_style,
        start_wd=args.start_weight_decay,
        end_wd=args.end_weight_decay,
        wd_incr_steps=wd_incr_steps,
        wd_incr_style=args.weight_decay_incr_style,
        use_checkpoint_opt_param_scheduler=args.use_checkpoint_opt_param_scheduler,
        override_opt_param_scheduler=args.override_opt_param_scheduler)

    return opt_param_scheduler
350
351


352
353
354
355
356
def setup_model_and_optimizer(model_provider_func,
                              model_type,
                              no_wd_decay_cond=None,
                              scale_lr_cond=None,
                              lr_mult=1.0):
    """Setup model and optimizer.

    Builds the model chunks with ``get_model``, then the Megatron optimizer
    and the optimizer-param scheduler. If ``args.load`` is set, it loads a
    checkpoint and sets ``args.iteration`` from it; otherwise
    ``args.iteration`` is set to 0.

    Returns:
        ``(model, optimizer, opt_param_scheduler)`` where ``model`` is the
        list of model chunks returned by ``get_model``.
    """
    args = get_args()

    model = get_model(model_provider_func, model_type)
    # Model chunks without FP16 and/or DDP wrappers.
    unwrapped_model = unwrap_model(model,
                                   (torchDDP, LocalDDP, Float16Module))

    optimizer = get_megatron_optimizer(model, no_wd_decay_cond,
                                       scale_lr_cond, lr_mult)
    opt_param_scheduler = get_optimizer_param_scheduler(optimizer)

    if args.load is not None:
        timers = get_timers()
        timers('load-checkpoint', log_level=0).start(barrier=True)
        args.iteration = load_checkpoint(model, optimizer, opt_param_scheduler)
        timers('load-checkpoint').stop(barrier=True)
        timers.log(['load-checkpoint'])
    else:
        args.iteration = 0

    # We only support local DDP with multiple micro-batches.
    if len(model) > 1 or mpu.get_pipeline_model_parallel_world_size() > 1:
        assert args.DDP_impl == 'local'

    # Fresh (non-resumed) ICT models are initialized from a pretrained BERT.
    if args.iteration == 0 and len(unwrapped_model) == 1 \
            and hasattr(unwrapped_model[0], 'init_state_dict_from_bert'):
        print_rank_0("Initializing ICT from pretrained BERT model")
        unwrapped_model[0].init_state_dict_from_bert()
        # The optimizer was built before the weights were overwritten, so
        # re-sync its copies of the model params. bf16 is handled like fp16
        # here, matching get_model, which wraps both in Float16Module.
        if args.fp16 or args.bf16:
            optimizer.reload_model_params()

    return model, optimizer, opt_param_scheduler
390
391


392
def train_step(forward_step_func, data_iterator,
               model, optimizer, opt_param_scheduler):
    """Single training step.

    Zeroes gradients, runs forward/backward over all microbatches, reduces
    gradients, takes an optimizer step and, if the step succeeded, gathers
    params and advances the optimizer-param scheduler.

    Returns:
        ``(loss_dict, skipped_iter, grad_norm, num_zeros_in_grad)``.
        On the last pipeline stage ``loss_dict`` maps each loss key to its
        mean over microbatches; on other stages it is empty.
        ``skipped_iter`` is 1 when the optimizer step was not successful,
        else 0.
    """
    args = get_args()
    timers = get_timers()
    is_dino = (args.vision_pretraining and
               args.vision_pretraining_type == "dino")

    # Set grad to zero.
    if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_local_ddp:
        for partition in model:
            partition.zero_grad_buffer()
    optimizer.zero_grad()

    # Forward pass.
    timers('forward-backward', log_level=1).start(
        barrier=args.barrier_with_L1_time)
    forward_backward_func = get_forward_backward_func()
    # Fine-grained timers are only passed down at timing level > 1.
    fwd_bwd_timers = timers if args.timing_log_level > 1 else None
    losses_reduced = forward_backward_func(
        forward_step_func, data_iterator, model,
        optimizer, fwd_bwd_timers, forward_only=False)
    timers('forward-backward').stop()

    # Empty unused memory.
    if args.empty_unused_memory_level >= 1:
        torch.cuda.empty_cache()

    # Reduce gradients.
    optimizer.reduce_model_grads(args, timers)

    # Vision gradients.
    if is_dino:
        unwrapped_model = unwrap_model(model[0],
                                       (torchDDP, LocalDDP, Float16Module))
        unwrapped_model.cancel_gradients_last_layer(args.curr_iteration)

    # Update parameters.
    timers('optimizer', log_level=1).start(barrier=args.barrier_with_L1_time)
    update_successful, grad_norm, num_zeros_in_grad = optimizer.step(args, timers)
    timers('optimizer').stop()

    # Gather params.
    if update_successful:
        optimizer.gather_model_params(args, timers)

    # Vision momentum.
    if is_dino:
        unwrapped_model = unwrap_model(model[0],
                                       (torchDDP, LocalDDP, Float16Module))
        unwrapped_model.update_momentum(args.curr_iteration)

    # Update learning rate.
    if update_successful:
        # Samples consumed by this iteration across all data-parallel ranks.
        increment = get_num_microbatches() * \
                    args.micro_batch_size * \
                    args.data_parallel_size
        opt_param_scheduler.step(increment=increment)
        skipped_iter = 0
    else:
        skipped_iter = 1

    # Empty unused memory.
    if args.empty_unused_memory_level >= 2:
        torch.cuda.empty_cache()

    if mpu.is_pipeline_last_stage(ignore_virtual=True):
        # Average loss across microbatches. Guard against an empty list so a
        # degenerate schedule yields an empty dict instead of an IndexError.
        loss_reduced = {}
        if losses_reduced:
            for key in losses_reduced[0]:
                losses_reduced_for_key = [x[key] for x in losses_reduced]
                loss_reduced[key] = sum(losses_reduced_for_key) / len(losses_reduced_for_key)
        return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad
    return {}, skipped_iter, grad_norm, num_zeros_in_grad
464
465


Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
466
def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                 loss_scale, report_memory_flag, skipped_iter,
                 grad_norm, params_norm, num_zeros_in_grad):
    """Log training information such as losses, timing, ....

    Accumulates this iteration's statistics into ``total_loss_dict``
    (mutated in place). It writes TensorBoard scalars every
    ``args.tensorboard_log_interval`` iterations. Every ``args.log_interval``
    iterations it prints a summary line and resets the accumulated counters
    and losses.

    Returns:
        The updated ``report_memory_flag``. It is cleared after memory has
        been reported once (on a log iteration with a positive learning rate).
    """
    args = get_args()
    timers = get_timers()
    writer = get_tensorboard_writer()

    # Advanced, skipped, and Nan iterations.
    advanced_iters_key = 'advanced iterations'
    skipped_iters_key = 'skipped iterations'
    nan_iters_key = 'nan iterations'
    # Advanced iterations (the key is created even on a skipped iteration).
    total_loss_dict[advanced_iters_key] = total_loss_dict.get(
        advanced_iters_key, 0) + int(not skipped_iter)
    # Skipped iterations.
    total_loss_dict[skipped_iters_key] = total_loss_dict.get(
        skipped_iters_key, 0) + skipped_iter
    # Update losses and set nan iterations. Losses of skipped iterations are
    # not accumulated; instead they are checked for inf/nan values.
    got_nan = False
    for key in loss_dict:
        if not skipped_iter:
            total_loss_dict[key] = total_loss_dict.get(
                key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]
        else:
            value = loss_dict[key].float().sum().item()
            got_nan = got_nan or not math.isfinite(value)
    total_loss_dict[nan_iters_key] = total_loss_dict.get(
        nan_iters_key, 0) + int(got_nan)

    # Logging.
    timers_to_log = [
        'forward-backward',
        'forward-compute',
        'backward-compute',
        'batch-generator',
        'forward-recv',
        'forward-send',
        'backward-recv',
        'backward-send',
        'forward-send-forward-recv',
        'forward-send-backward-recv',
        'backward-send-forward-recv',
        'backward-send-backward-recv',
        'forward-backward-send-forward-backward-recv',
        'layernorm-grads-all-reduce',
        'embedding-grads-all-reduce',
        'grads-all-reduce',
        'grads-reduce-scatter',
        'params-all-gather',
        'optimizer-copy-to-main-grad',
        'optimizer-unscale-and-check-inf',
        'optimizer-clip-main-grad',
        'optimizer-count-zeros',
        'optimizer-inner-step',
        'optimizer-copy-main-to-model-params',
        'optimizer']

    # Calculate batch size.
    batch_size = args.micro_batch_size * args.data_parallel_size * \
        get_num_microbatches()

    # Iterations since the counters were last reset (always >= 1 here,
    # because this call has just incremented one of the two counters).
    total_iterations = total_loss_dict[advanced_iters_key] + \
                       total_loss_dict[skipped_iters_key]

    # Tensorboard values.
    # Timer requires all the ranks to call.
    if args.log_timers_to_tensorboard and \
       (iteration % args.tensorboard_log_interval == 0):
        timers.write(timers_to_log, writer, iteration,
                     normalizer=total_iterations)
    if writer and (iteration % args.tensorboard_log_interval == 0):
        if args.log_learning_rate_to_tensorboard:
            writer.add_scalar('learning-rate', learning_rate, iteration)
            writer.add_scalar('learning-rate vs samples', learning_rate,
                              args.consumed_train_samples)
        if args.log_batch_size_to_tensorboard:
            writer.add_scalar('batch-size', batch_size, iteration)
            writer.add_scalar('batch-size vs samples', batch_size,
                              args.consumed_train_samples)
        for key in loss_dict:
            writer.add_scalar(key, loss_dict[key], iteration)
            writer.add_scalar(key + ' vs samples', loss_dict[key],
                              args.consumed_train_samples)
        if args.log_loss_scale_to_tensorboard:
            writer.add_scalar('loss-scale', loss_scale, iteration)
            writer.add_scalar('loss-scale vs samples', loss_scale,
                              args.consumed_train_samples)
        if args.log_world_size_to_tensorboard:
            writer.add_scalar('world-size', args.world_size, iteration)
            writer.add_scalar('world-size vs samples', args.world_size,
                              args.consumed_train_samples)
        if grad_norm is not None:
            writer.add_scalar('grad-norm', grad_norm, iteration)
            writer.add_scalar('grad-norm vs samples', grad_norm,
                              args.consumed_train_samples)
        if num_zeros_in_grad is not None:
            writer.add_scalar('num-zeros', num_zeros_in_grad, iteration)
            writer.add_scalar('num-zeros vs samples', num_zeros_in_grad,
                              args.consumed_train_samples)
        if params_norm is not None:
            writer.add_scalar('params-norm', params_norm, iteration)
            writer.add_scalar('params-norm vs samples', params_norm,
                              args.consumed_train_samples)
        if args.log_memory_to_tensorboard:
            mem_stats = torch.cuda.memory_stats()
            writer.add_scalar(
                "mem-reserved-bytes",
                mem_stats["reserved_bytes.all.current"],
                iteration,
            )
            writer.add_scalar(
                "mem-allocated-bytes",
                mem_stats["allocated_bytes.all.current"],
                iteration,
            )
            writer.add_scalar(
                "mem-allocated-count",
                mem_stats["allocation.all.current"],
                iteration,
            )

    if iteration % args.log_interval == 0:
        elapsed_time = timers('interval-time').elapsed(barrier=True)
        elapsed_time_per_iteration = elapsed_time / total_iterations
        if writer:
            if args.log_timers_to_tensorboard:
                writer.add_scalar('iteration-time',
                                  elapsed_time_per_iteration, iteration)
        log_string = ' iteration {:8d}/{:8d} |'.format(
            iteration, args.train_iters)
        log_string += ' consumed samples: {:12d} |'.format(
            args.consumed_train_samples)
        log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
            elapsed_time_per_iteration * 1000.0)
        log_string += ' learning rate: {:.3E} |'.format(learning_rate)
        log_string += ' global batch size: {:5d} |'.format(batch_size)
        for key in total_loss_dict:
            if key not in [advanced_iters_key, skipped_iters_key,
                           nan_iters_key]:
                avg = total_loss_dict[key].item() / \
                      float(max(1, total_loss_dict[advanced_iters_key]))
                if avg > 0.0:
                    log_string += ' {}: {:.6E} |'.format(key, avg)
                total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
        log_string += ' loss scale: {:.1f} |'.format(loss_scale)
        if grad_norm is not None:
            log_string += ' grad norm: {:.3f} |'.format(grad_norm)
        if num_zeros_in_grad is not None:
            log_string += ' num zeros: {:.1f} |'.format(num_zeros_in_grad)
        if params_norm is not None:
            log_string += ' params norm: {:.3f} |'.format(params_norm)
        log_string += ' number of skipped iterations: {:3d} |'.format(
            total_loss_dict[skipped_iters_key])
        log_string += ' number of nan iterations: {:3d} |'.format(
            total_loss_dict[nan_iters_key])
        total_loss_dict[advanced_iters_key] = 0
        total_loss_dict[skipped_iters_key] = 0
        total_loss_dict[nan_iters_key] = 0
        print_rank_last(log_string)
        if report_memory_flag and learning_rate > 0.:
            # Report memory after optimizer state has been initialized.
            report_memory('(after {} iterations)'.format(iteration))
            report_memory_flag = False
        # Normalize by the number of iterations actually in this window. The
        # window can be shorter than args.log_interval, e.g. right after
        # resuming at an iteration that is not a multiple of the interval.
        timers.log(timers_to_log, normalizer=total_iterations)

    return report_memory_flag


642
def save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler):
    """Save a checkpoint and log the time the save took.

    Both the start and the stop of the 'save-checkpoint' timer use a
    barrier, so all ranks report the max time across ranks.
    """
    timer_name = 'save-checkpoint'
    timers = get_timers()
    timers(timer_name, log_level=0).start(barrier=True)
    save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
    timers(timer_name).stop(barrier=True)
    timers.log([timer_name])
650
651


652
def train(forward_step_func, model, optimizer, opt_param_scheduler,
          train_data_iterator, valid_data_iterator,
          process_non_loss_data_func):
    """Run the main training loop until ``args.train_iters`` is reached.

    Every step runs ``train_step`` and advances
    ``args.consumed_train_samples``, then logs. It then optionally:
      - triggers ADLR autoresume,
      - runs validation,
      - saves a checkpoint,
      - exits the process on a received signal, a wall-clock limit, or an
        iteration interval.

    Returns:
        The iteration counter once the loop ends without exiting.
    """
    args = get_args()
    timers = get_timers()

    write_args_to_tensorboard()

    # Put every model chunk in training mode (enables dropout).
    for chunk in model:
        chunk.train()

    running_losses = {}
    step = args.iteration

    timers('interval-time', log_level=0).start(barrier=True)
    print_datetime('before the start of training step')
    memory_report_pending = True

    while step < args.train_iters:
        update_num_microbatches(args.consumed_train_samples)
        args.curr_iteration = step
        step_losses, skipped, grad_norm, num_zeros_in_grad = train_step(
            forward_step_func, train_data_iterator, model, optimizer,
            opt_param_scheduler)
        step += 1
        samples_this_step = (mpu.get_data_parallel_world_size()
                             * args.micro_batch_size
                             * get_num_microbatches())
        args.consumed_train_samples += samples_this_step

        # --- Logging -------------------------------------------------------
        loss_scale = optimizer.get_loss_scale().item()
        params_norm = calc_params_l2_norm(model) if args.log_params_norm else None
        memory_report_pending = training_log(
            step_losses, running_losses, optimizer.param_groups[0]['lr'],
            step, loss_scale, memory_report_pending, skipped,
            grad_norm, params_norm, num_zeros_in_grad)

        # --- Autoresume ----------------------------------------------------
        if args.adlr_autoresume and step % args.adlr_autoresume_interval == 0:
            check_adlr_autoresume_termination(step, model, optimizer,
                                              opt_param_scheduler)

        # --- Periodic validation -------------------------------------------
        if args.eval_interval and step % args.eval_interval == 0 and args.do_valid:
            evaluate_and_print_results('iteration {}'.format(step),
                                       forward_step_func, valid_data_iterator,
                                       model, step,
                                       process_non_loss_data_func, False)

        # --- Checkpointing and exit conditions -----------------------------
        checkpointed = False
        if args.exit_signal_handler and \
           any(get_signal_handler().signals_received()):
            save_checkpoint_and_time(step, model, optimizer,
                                     opt_param_scheduler)
            print_datetime('exiting program after receiving SIGTERM.')
            sys.exit()

        if args.save and args.save_interval and step % args.save_interval == 0:
            save_checkpoint_and_time(step, model, optimizer,
                                     opt_param_scheduler)
            checkpointed = True

        if args.exit_duration_in_mins:
            elapsed_mins = (time.time() - _TRAIN_START_TIME) / 60.0
            over_limit = torch.cuda.IntTensor(
                [elapsed_mins > args.exit_duration_in_mins])
            # MAX-reduce so that if any rank is over the limit, all ranks exit.
            torch.distributed.all_reduce(
                over_limit, op=torch.distributed.ReduceOp.MAX)
            if over_limit.item():
                if not checkpointed:
                    save_checkpoint_and_time(step, model, optimizer,
                                             opt_param_scheduler)
                print_datetime('exiting program after {} minutes'.format(elapsed_mins))
                sys.exit()

        if args.exit_interval and step % args.exit_interval == 0:
            if not checkpointed:
                save_checkpoint_and_time(step, model, optimizer,
                                         opt_param_scheduler)
            torch.distributed.barrier()
            print_datetime('exiting program at iteration {}'.format(step))
            sys.exit()

    return step
757
758


759
760
761
762
763
def evaluate(forward_step_func,
             data_iterator,
             model,
             process_non_loss_data_func,
             verbose=False):
    """Evaluate the model with forward-only passes.

    Runs ``args.eval_iters`` forward-only iterations under ``torch.no_grad()``
    with every model chunk in eval mode. Returned loss dicts are accumulated
    on the last pipeline stage. Train mode is restored afterwards.

    Args:
        forward_step_func: forward step handed to the forward/backward
            schedule.
        data_iterator: iterator over validation batches.
        model: list of model chunks.
        process_non_loss_data_func: if not None, an extra forward-only pass
            with ``collect_non_loss_data=True`` is run on the last rank.
        verbose: print progress every ``args.log_interval`` iterations.

    Returns:
        A tuple ``(total_loss_dict, collected_non_loss_data)``:
          - ``total_loss_dict`` maps each loss key to its accumulated sum
            divided by ``args.eval_iters * get_num_microbatches()``. It is
            empty on ranks that are not the last pipeline stage.
          - ``collected_non_loss_data`` is None unless the extra pass ran
            on this rank.

    Side effects:
        Advances ``args.consumed_valid_samples``.
    """
    args = get_args()

    if args.vision_pretraining and args.vision_pretraining_type == "dino":
        compute_feature_bank(model)

    # Turn on evaluation mode which disables dropout.
    for model_module in model:
        model_module.eval()

    # Look the schedule up once, outside the loop. It is loop-invariant, and
    # it must be bound even when args.eval_iters == 0 because the
    # non-loss-data pass below uses it too (otherwise: NameError).
    forward_backward_func = get_forward_backward_func()

    total_loss_dict = {}

    with torch.no_grad():
        for iteration in range(1, args.eval_iters + 1):
            if verbose and iteration % args.log_interval == 0:
                print_rank_0('Evaluating iter {}/{}'.format(iteration,
                                                            args.eval_iters))

            loss_dicts = forward_backward_func(
                forward_step_func, data_iterator, model, optimizer=None,
                timers=None, forward_only=True)

            # Empty unused memory
            if args.empty_unused_memory_level >= 1:
                torch.cuda.empty_cache()

            # Losses are accumulated only on the last pipeline stage: sum
            # every returned loss dict into the running totals.
            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                for loss_dict in loss_dicts:
                    for key in loss_dict:
                        total_loss_dict[key] = total_loss_dict.get(
                            key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]

            args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
                                           * args.micro_batch_size \
                                           * get_num_microbatches()

        collected_non_loss_data = None
        if process_non_loss_data_func is not None and is_last_rank():
            collected_non_loss_data = forward_backward_func(
                forward_step_func, data_iterator, model, optimizer=None,
                timers=None, forward_only=True, collect_non_loss_data=True)

    # Move model back to the train mode.
    for model_module in model:
        model_module.train()

    # Average the sums over eval_iters * num_microbatches.
    for key in total_loss_dict:
        total_loss_dict[key] /= args.eval_iters * get_num_microbatches()

    return total_loss_dict, collected_non_loss_data
817
818
819

def evaluate_and_print_results(prefix, forward_step_func,
                               data_iterator, model,
                               iteration, process_non_loss_data_func,
                               verbose=False):
    """Run ``evaluate`` and report its results.

    For every loss key, the validation loss and a perplexity are printed on
    the last rank. If a TensorBoard writer exists, the same values are also
    written to it, indexed both by ``iteration`` and by
    ``args.consumed_train_samples``.

    Args:
        prefix: label inserted into the printed header line.
        forward_step_func, data_iterator, model, process_non_loss_data_func,
        verbose: forwarded to ``evaluate``.
        iteration: TensorBoard step for the iteration-indexed scalars and
            for ``process_non_loss_data_func``.
    """
    args = get_args()
    writer = get_tensorboard_writer()

    total_loss_dict, collected_non_loss_data = evaluate(
        forward_step_func, data_iterator, model,
        process_non_loss_data_func, verbose)

    string = ' validation loss at {} | '.format(prefix)
    for key in total_loss_dict:
        # Read the scalar once; each .item() copies the value to the host.
        loss = total_loss_dict[key].item()
        string += '{} value: {:.6E} | '.format(key, loss)
        # The exponent is capped at 20, so the reported perplexity never
        # exceeds exp(20).
        ppl = math.exp(min(20, loss))
        string += '{} PPL: {:.6E} | '.format(key, ppl)
        if writer:
            writer.add_scalar('{} validation'.format(key), loss, iteration)
            writer.add_scalar('{} validation vs samples'.format(key),
                              loss, args.consumed_train_samples)
            if args.log_validation_ppl_to_tensorboard:
                writer.add_scalar('{} validation ppl'.format(key), ppl,
                                  iteration)
                writer.add_scalar('{} validation ppl vs samples'.format(key),
                                  ppl, args.consumed_train_samples)

    if process_non_loss_data_func is not None and writer and is_last_rank():
        process_non_loss_data_func(collected_non_loss_data, iteration, writer)

    length = len(string) + 1
    print_rank_last('-' * length)
    print_rank_last(string)
    print_rank_last('-' * length)
854
855


Vijay Korthikanti's avatar
Vijay Korthikanti committed
856
def cyclic_iter(iter):
    """Yield the items of ``iter`` endlessly.

    ``iter`` is re-iterated from the start each time it is exhausted, so it
    should be re-iterable (a list, a DataLoader, ...). If it yields nothing,
    this generator never produces a value.
    """
    while True:
        yield from iter

861
862
863
def _build_data_iterator(dataloader, dl_type):
    """Return an iterator over ``dataloader`` as selected by ``dl_type``.

    Returns None for a None dataloader, a single-pass iterator for
    'single', and an endlessly repeating iterator for 'cyclic'.
    """
    if dataloader is None:
        return None
    if dl_type == 'single':
        return iter(dataloader)
    return iter(cyclic_iter(dataloader))


def build_train_valid_test_data_iterators(
        build_train_valid_test_datasets_provider):
    """Build the train, validation, and test data iterators.

    Datasets and dataloaders are built only on tensor-model-parallel rank 0.
    The resulting do_train / do_valid / do_test flags are then broadcast
    across the tensor-model-parallel group and stored on ``args``.

    Args:
        build_train_valid_test_datasets_provider: callable that takes
            ``[train_samples, valid_samples, test_samples]`` and returns
            ``(train_ds, valid_ds, test_ds)``.

    Returns:
        ``(train_data_iterator, valid_data_iterator, test_data_iterator)``.
        An entry is None when its dataloader is None on this rank.

    Side effects:
        - Sets ``args.do_train``, ``args.do_valid`` and ``args.do_test``.
        - When resuming with zero consumed-sample counters, backfills
          ``args.consumed_train_samples`` / ``args.consumed_valid_samples``.
    """
    args = get_args()

    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

    print_rank_0('> building train, validation, and test datasets ...')

    # Backward compatibility: when resuming (iteration > 0) with zero
    # consumed-sample counters, reconstruct them from the iteration count,
    # assuming a fixed global batch size.
    if args.iteration > 0 and args.consumed_train_samples == 0:
        assert args.train_samples is None, \
            'only backward compatiblity support for iteration-based training'
        args.consumed_train_samples = args.iteration * args.global_batch_size
    if args.iteration > 0 and args.consumed_valid_samples == 0:
        if args.train_samples is None:
            args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
                args.eval_iters * args.global_batch_size

    # Data loaders are built only on rank 0 of each tensor-model-parallel group.
    if mpu.get_tensor_model_parallel_rank() == 0:

        # Number of train/valid/test samples.
        if args.train_samples:
            train_samples = args.train_samples
        else:
            train_samples = args.train_iters * args.global_batch_size
        # (train_iters // eval_interval + 1) validation rounds of eval_iters each.
        eval_iters = (args.train_iters // args.eval_interval + 1) * \
                     args.eval_iters
        test_iters = args.eval_iters
        train_val_test_num_samples = [train_samples,
                                      eval_iters * args.global_batch_size,
                                      test_iters * args.global_batch_size]
        print_rank_0(' > datasets target sizes (minimum size):')
        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))

        # Build the datasets.
        train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
            train_val_test_num_samples)

        # Build dataloaders, resuming train/valid from the consumed counts.
        train_dataloader = build_pretraining_data_loader(
            train_ds, args.consumed_train_samples)
        valid_dataloader = build_pretraining_data_loader(
            valid_ds, args.consumed_valid_samples)
        test_dataloader = build_pretraining_data_loader(test_ds, 0)

        # Flags to know if we need to do training/validation/testing.
        do_train = train_dataloader is not None and args.train_iters > 0
        do_valid = valid_dataloader is not None and args.eval_iters > 0
        do_test = test_dataloader is not None and args.eval_iters > 0
        # Pack the three flags into one tensor so a single broadcast suffices.
        flags = torch.cuda.LongTensor(
            [int(do_train), int(do_valid), int(do_test)])
    else:
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Broadcast the flags from the tensor-model-parallel source rank so every
    # rank in the group agrees on which phases to run.
    torch.distributed.broadcast(flags,
                                mpu.get_tensor_model_parallel_src_rank(),
                                group=mpu.get_tensor_model_parallel_group())
    args.do_train = flags[0].item()
    args.do_valid = flags[1].item()
    args.do_test = flags[2].item()

    # Build iterators.
    dl_type = args.dataloader_type
    assert dl_type in ['single', 'cyclic']

    train_data_iterator = _build_data_iterator(train_dataloader, dl_type)
    valid_data_iterator = _build_data_iterator(valid_dataloader, dl_type)
    test_data_iterator = _build_data_iterator(test_dataloader, dl_type)

    return train_data_iterator, valid_data_iterator, test_data_iterator