training.py 40 KB
Newer Older
1
# coding=utf-8
Mohammad's avatar
Mohammad committed
2
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
3
4
5
6
7
8
9
10
11
12
13
14
15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Mohammad's avatar
Mohammad committed
16
"""Pretrain utilities."""
17
18
19

from datetime import datetime
import math
Mohammad's avatar
Mohammad committed
20
import sys
21
22
23
import time
# The earliest we can measure the start time.
_TRAIN_START_TIME = time.time()
24
25
26
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

Neel Kant's avatar
Neel Kant committed
27
from megatron import get_args
28
from megatron import get_signal_handler
Mohammad's avatar
Mohammad committed
29
30
from megatron import get_timers
from megatron import get_tensorboard_writer
31
from megatron import get_current_global_batch_size
mohammad's avatar
mohammad committed
32
from megatron import get_num_microbatches
mohammad's avatar
mohammad committed
33
from megatron import is_last_rank
mohammad's avatar
mohammad committed
34
from megatron import update_num_microbatches
35
from megatron import mpu
Neel Kant's avatar
Neel Kant committed
36
from megatron import print_rank_0
37
from megatron import print_rank_last
Mohammad's avatar
Mohammad committed
38
39
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
40
from megatron.model import Float16Module
41
from megatron.model import ModelType
mohammad's avatar
mohammad committed
42
from megatron.optimizer import get_megatron_optimizer
Mohammad's avatar
Mohammad committed
43
from megatron.initialize import initialize_megatron
44
from megatron.initialize import write_args_to_tensorboard
45
from megatron.initialize import set_jit_fusion_options
46
from megatron.optimizer_param_scheduler import OptimizerParamScheduler
47
48
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.utils import check_adlr_autoresume_termination
49
from megatron.utils import unwrap_model
Vijay Korthikanti's avatar
Vijay Korthikanti committed
50
from megatron.data.data_samplers import build_pretraining_data_loader
mohammad's avatar
mohammad committed
51
from megatron.utils import calc_params_l2_norm
52
from megatron.schedules import get_forward_backward_func
53
from megatron.utils import report_memory
54
from megatron.model.vision.knn_monitor import compute_feature_bank
55

Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
56

57
58
59
60
61
62
63
def print_datetime(string):
    """Print ``string`` followed by the current local date and time on rank 0.

    Every rank first waits on ``torch.distributed.barrier()``, so calling
    this function synchronizes all processes.
    """
    torch.distributed.barrier()
    stamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print_rank_0('[' + string + '] datetime: ' + stamp + ' ')


64
def pretrain(train_valid_test_dataset_provider,
             model_provider,
             model_type,
             forward_step_func,
             process_non_loss_data_func=None,
             extra_args_provider=None,
             args_defaults=None):
    """Main training program.

    This function runs the following steps in order:
        1) initialize Megatron.
        2) set up the model, optimizer and learning-rate scheduler using
           the model_provider.
        3) call train_valid_test_dataset_provider to get the
           train/valid/test datasets.
        4) train the model using the forward_step_func.
    Afterwards it optionally evaluates on the validation data
    (``args.do_valid``), saves a checkpoint (``args.save``) and evaluates on
    the test data (``args.do_test``).

    Arguments:
        train_valid_test_dataset_provider: a function that takes the size of
            train/valid/test dataset and returns `train, valid, test` datasets.
        model_provider: a function that returns a vanilla version of the
            model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
        model_type: an enum that specifies the type of model being trained.
        forward_step_func: a function that takes a `data iterator` and `model`,
            and returns a `loss` scalar with a dictionary with key:values being
            the info we would like to monitor during training, for example
            `lm-loss: value`. We also require that this function add
            `batch generator` to the timers class.
        process_non_loss_data_func: a function to post process outputs of the
            network. It can be used for dumping output tensors (e.g images) to
            tensorboard. It takes `collected data`(list of tensors),
            `current iteration index` and `tensorboard writer` as arguments.
        extra_args_provider: a function that takes a parser and adds arguments
            to it. It is used for programs to add their own arguments.
        args_defaults: a dictionary from argument-name to argument-value,
            forwarded to ``initialize_megatron`` as default argument values.
            ``None`` (the default) is treated as an empty dictionary.
    """
    # Avoid a shared mutable default: build a fresh dict for every call.
    if args_defaults is None:
        args_defaults = {}

    # Initialize and get arguments, timers, and Tensorboard writer.
    initialize_megatron(extra_args_provider=extra_args_provider,
                        args_defaults=args_defaults)
    # Set pytorch JIT layer fusion options and warmup JIT functions.
    set_jit_fusion_options()

    # Use the earliest start time across all ranks (MIN reduction), so the
    # reported initialization time is the largest one. This is closer to
    # what the job scheduler observes.
    global _TRAIN_START_TIME
    start_time_tensor = torch.cuda.DoubleTensor([_TRAIN_START_TIME])
    torch.distributed.all_reduce(start_time_tensor,
                                 op=torch.distributed.ReduceOp.MIN)
    _TRAIN_START_TIME = start_time_tensor.item()
    print_rank_0('time to initialize megatron (seconds): {:.3f}'.format(
        time.time() - _TRAIN_START_TIME))
    print_datetime('after megatron is initialized')

    args = get_args()
    timers = get_timers()

    # Model, optimizer, and learning rate.
    timers('model-and-optimizer-setup').start()
    model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider,
                                                               model_type)
    timers('model-and-optimizer-setup').stop()
    print_datetime('after model, optimizer, and learning rate '
                   'scheduler are built')

    # Data stuff.
    timers('train/valid/test-data-iterators-setup').start()
    if args.virtual_pipeline_model_parallel_size is not None:
        # One set of iterators per model chunk (virtual pipeline stage).
        all_data_iterators = [
            build_train_valid_test_data_iterators(train_valid_test_dataset_provider)
            for _ in range(len(model))
        ]
        train_data_iterator = [data_iterators[0] for data_iterators in all_data_iterators]
        valid_data_iterator = [data_iterators[1] for data_iterators in all_data_iterators]
        test_data_iterator = [data_iterators[2] for data_iterators in all_data_iterators]
    else:
        train_data_iterator, valid_data_iterator, test_data_iterator \
            = build_train_valid_test_data_iterators(
                train_valid_test_dataset_provider)
    timers('train/valid/test-data-iterators-setup').stop()
    print_datetime('after dataloaders are built')

    # Print setup timing.
    print_rank_0('done with setup ...')
    timers.log(['model-and-optimizer-setup', 'train/valid/test-data-iterators-setup'])
    print_rank_0('training ...')

    iteration = 0
    if args.do_train and args.train_iters > 0:
        iteration = train(forward_step_func,
                          model, optimizer, opt_param_scheduler,
                          train_data_iterator, valid_data_iterator,
                          process_non_loss_data_func)
    print_datetime('after training is done')

    if args.do_valid:
        prefix = 'the end of training for val data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   valid_data_iterator, model,
                                   iteration, process_non_loss_data_func,
                                   False)

    # Only save if training actually advanced.
    if args.save and iteration != 0:
        save_checkpoint(iteration, model, optimizer, opt_param_scheduler)

    if args.do_test:
        # Run on test data. Note: results are reported with iteration 0,
        # not the final training iteration.
        prefix = 'the end of training for test data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   test_data_iterator, model,
                                   0, process_non_loss_data_func,
                                   True)
176

177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
def update_train_iters(args):
    """Derive ``args.train_iters`` from ``args.train_samples`` when it is unset.

    Iteration-based runs (``args.train_iters`` already truthy) are left
    untouched. Otherwise the count is the number of full global batches that
    fit into ``args.train_samples``. When a batch-size ramp-up is configured,
    the ramp-up phase is replayed step by step with the microbatch
    calculator. The remaining samples are then divided by
    ``args.global_batch_size``, and any partial final batch is dropped.
    """
    if args.train_iters:
        # Iteration-based training: nothing to derive.
        return

    if args.rampup_batch_size is not None:
        ramp_samples = int(args.rampup_batch_size[2])
        num_iters = 0
        samples_seen = 0
        # Replay the ramp-up phase one global batch at a time.
        while samples_seen <= ramp_samples:
            update_num_microbatches(samples_seen, consistency_check=False)
            samples_seen += get_current_global_batch_size()
            num_iters += 1
        # Put the microbatch calculator back into its initial state.
        update_num_microbatches(0, consistency_check=False)
        # Constant-batch-size phase (partial last batch discarded).
        leftover = args.train_samples - samples_seen
        args.train_iters = num_iters + leftover // args.global_batch_size
    else:
        # Constant batch size with sample-based training.
        args.train_iters = args.train_samples // args.global_batch_size

    print_rank_0('setting training iterations to {}'.format(args.train_iters))

206

207
def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True):
    """Build the model chunks, move them to the GPU and wrap them.

    Always returns a list of modules. With the interleaved (virtual pipeline)
    schedule there is one module per virtual stage; otherwise the list has a
    single element. Each chunk:
      * has ``model_type`` recorded on it,
      * gets default tensor-model-parallel attributes on its parameters,
      * is moved to the current CUDA device,
      * is wrapped in ``Float16Module`` when fp16 or bf16 is enabled,
      * is wrapped in the DDP implementation named by ``args.DDP_impl``
        when ``wrap_with_ddp`` is true.

    Raises:
        NotImplementedError: if ``wrap_with_ddp`` is true and
            ``args.DDP_impl`` is neither 'torch' nor 'local'.
    """
    args = get_args()
    args.model_type = model_type

    pipeline_parallel = mpu.get_pipeline_model_parallel_world_size() > 1
    interleaved = (pipeline_parallel and
                   args.virtual_pipeline_model_parallel_size is not None)

    if interleaved:
        assert model_type != ModelType.encoder_and_decoder, \
            "Interleaved schedule not supported for model with both encoder and decoder"
        model = []
        for vp_rank in range(args.virtual_pipeline_model_parallel_size):
            mpu.set_virtual_pipeline_model_parallel_rank(vp_rank)
            # First/last-stage flags depend on the virtual rank set above.
            first_stage = mpu.is_pipeline_first_stage()
            last_stage = mpu.is_pipeline_last_stage()
            chunk = model_provider_func(
                pre_process=first_stage,
                post_process=last_stage
            )
            chunk.model_type = model_type
            model.append(chunk)
    else:
        pre_process = mpu.is_pipeline_first_stage()
        post_process = mpu.is_pipeline_last_stage()
        if model_type == ModelType.encoder_and_decoder:
            add_encoder = True
            add_decoder = True
            if pipeline_parallel:
                assert args.pipeline_model_parallel_split_rank is not None, \
                    "Split rank needs to be specified for model with both encoder and decoder"
                rank = mpu.get_pipeline_model_parallel_rank()
                split_rank = args.pipeline_model_parallel_split_rank
                last_rank = mpu.get_pipeline_model_parallel_world_size() - 1
                # Encoder stages start at 0, decoder stages start at split_rank.
                pre_process = rank in (0, split_rank)
                post_process = rank in (split_rank - 1, last_rank)
                add_encoder = mpu.is_pipeline_stage_before_split()
                add_decoder = mpu.is_pipeline_stage_after_split()
            model = model_provider_func(
                pre_process=pre_process,
                post_process=post_process,
                add_encoder=add_encoder,
                add_decoder=add_decoder)
        else:
            model = model_provider_func(
                pre_process=pre_process,
                post_process=post_process
            )
        model.model_type = model_type

    if not isinstance(model, list):
        model = [model]

    # Only parameters that are already tensor-model-parallel carry these
    # attributes; make sure every parameter has them so the optimizer can
    # rely on them.
    for chunk in model:
        for param in chunk.parameters():
            mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    # Report parameter count once per (tensor, pipeline) model-parallel rank.
    if mpu.get_data_parallel_rank() == 0:
        num_params = sum(p.nelement()
                         for chunk in model
                         for p in chunk.parameters())
        print(' > number of parameters on (tensor, pipeline) '
              'model parallel rank ({}, {}): {}'.format(
                  mpu.get_tensor_model_parallel_rank(),
                  mpu.get_pipeline_model_parallel_rank(),
                  num_params), flush=True)

    # GPU allocation.
    for chunk in model:
        chunk.cuda(torch.cuda.current_device())

    # Fp16 / bf16 conversion.
    if args.fp16 or args.bf16:
        model = [Float16Module(chunk, args) for chunk in model]

    if not wrap_with_ddp:
        return model

    if args.DDP_impl == 'torch':
        device = torch.cuda.current_device()
        model = [torchDDP(chunk, device_ids=[device], output_device=device,
                          process_group=mpu.get_data_parallel_group())
                 for chunk in model]
    elif args.DDP_impl == 'local':
        model = [LocalDDP(chunk,
                          args.accumulate_allreduce_grads_in_fp32,
                          args.use_contiguous_buffers_in_local_ddp)
                 for chunk in model]
        # Broadcast params from the data-parallel source rank to the other
        # data-parallel ranks.
        if args.data_parallel_random_init:
            for chunk in model:
                chunk.broadcast_params()
    else:
        raise NotImplementedError('Unknown DDP implementation specified: '
                                  '{}. Exiting.'.format(args.DDP_impl))
    return model
307
308


309
def get_optimizer_param_scheduler(optimizer):
    """Build the learning-rate / weight-decay scheduler for ``optimizer``.

    Every step count passed to ``OptimizerParamScheduler`` is in samples.
    - Iteration-based runs (``args.train_iters``): iteration counts are
      multiplied by ``args.global_batch_size``. ``args.lr_decay_iters``
      defaults to ``args.train_iters``.
    - Sample-based runs (``args.train_samples``): ``args.train_iters`` is
      first derived with ``update_train_iters``. ``args.lr_decay_samples``
      defaults to ``args.train_samples``.
    Warmup is ``args.lr_warmup_fraction`` of the decay length when that is
    set, otherwise the explicit warmup iters/samples.

    Raises:
        Exception: if neither train-iters nor train-samples is provided.
    """
    args = get_args()

    if args.train_iters:
        # Iteration-based training.
        if args.lr_decay_iters is None:
            args.lr_decay_iters = args.train_iters
        samples_per_iter = args.global_batch_size
        lr_decay_steps = args.lr_decay_iters * samples_per_iter
        wd_incr_steps = args.train_iters * samples_per_iter
        lr_warmup_steps = (args.lr_warmup_fraction * lr_decay_steps
                           if args.lr_warmup_fraction is not None
                           else args.lr_warmup_iters * samples_per_iter)
    elif args.train_samples:
        # Sample-based training. train_iters is needed later; the incomplete
        # last batch is intentionally not subtracted from train_samples.
        update_train_iters(args)
        if args.lr_decay_samples is None:
            args.lr_decay_samples = args.train_samples
        lr_decay_steps = args.lr_decay_samples
        wd_incr_steps = args.train_samples
        lr_warmup_steps = (args.lr_warmup_fraction * lr_decay_steps
                           if args.lr_warmup_fraction is not None
                           else args.lr_warmup_samples)
    else:
        raise Exception(
            'either train-iters or train-samples should be provided.')

    return OptimizerParamScheduler(
        optimizer,
        max_lr=args.lr,
        min_lr=args.min_lr,
        lr_warmup_steps=lr_warmup_steps,
        lr_decay_steps=lr_decay_steps,
        lr_decay_style=args.lr_decay_style,
        start_wd=args.start_weight_decay,
        end_wd=args.end_weight_decay,
        wd_incr_steps=wd_incr_steps,
        wd_incr_style=args.weight_decay_incr_style,
        use_checkpoint_opt_param_scheduler=args.use_checkpoint_opt_param_scheduler,
        override_opt_param_scheduler=args.override_opt_param_scheduler)
356
357


358
359
360
361
362
def setup_model_and_optimizer(model_provider_func,
                              model_type,
                              no_wd_decay_cond=None,
                              scale_lr_cond=None,
                              lr_mult=1.0):
    """Setup model, optimizer and optimizer-parameter scheduler.

    Builds the model with ``get_model``, then the optimizer and its
    scheduler. If ``args.load`` is set, a checkpoint is loaded and
    ``args.iteration`` is set to the loaded iteration; otherwise
    ``args.iteration`` is 0.

    Arguments:
        model_provider_func: passed to ``get_model``.
        model_type: passed to ``get_model``.
        no_wd_decay_cond, scale_lr_cond, lr_mult: passed through to
            ``get_megatron_optimizer``.

    Returns:
        ``(model, optimizer, opt_param_scheduler)`` where ``model`` is the
        list of model chunks returned by ``get_model``.
    """
    args = get_args()

    model = get_model(model_provider_func, model_type)
    # Model chunks without the DDP / Float16Module wrappers.
    unwrapped_model = unwrap_model(model,
                                   (torchDDP, LocalDDP, Float16Module))

    optimizer = get_megatron_optimizer(model, no_wd_decay_cond,
                                       scale_lr_cond, lr_mult)
    opt_param_scheduler = get_optimizer_param_scheduler(optimizer)

    if args.load is not None:
        timers = get_timers()
        # Extra barrier is added to make sure all ranks report the
        # max time.
        torch.distributed.barrier()
        timers('load-checkpoint').start()
        args.iteration = load_checkpoint(model, optimizer, opt_param_scheduler)
        torch.distributed.barrier()
        timers('load-checkpoint').stop()
        timers.log(['load-checkpoint'])
    else:
        args.iteration = 0

    # We only support local DDP with multiple micro-batches.
    if len(model) > 1 or mpu.get_pipeline_model_parallel_world_size() > 1:
        assert args.DDP_impl == 'local'

    # On a fresh run, a model exposing init_state_dict_from_bert (e.g. ICT)
    # is initialized from a pretrained BERT model.
    if args.iteration == 0 and len(unwrapped_model) == 1 \
        and hasattr(unwrapped_model[0], 'init_state_dict_from_bert'):
        print_rank_0("Initializing ICT from pretrained BERT model")
        unwrapped_model[0].init_state_dict_from_bert()
        # The model params changed after the optimizer was built, so the
        # optimizer must re-read them. Use the same fp16-or-bf16 condition
        # that get_model uses for Float16Module; previously bf16 was missed.
        # NOTE(review): assumes the bf16 optimizer, like fp16, keeps separate
        # main copies of the params; reload_model_params refreshes them.
        if args.fp16 or args.bf16:
            optimizer.reload_model_params()

    return model, optimizer, opt_param_scheduler
400
401


402
def train_step(forward_step_func, data_iterator,
               model, optimizer, opt_param_scheduler):
    """Run one training iteration over all microbatches.

    Steps: zero grads, forward/backward, reduce grads, optimizer step,
    gather params, then advance the scheduler by the number of samples
    consumed.

    Returns:
        ``(loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad)``.
        ``loss_reduced`` maps each loss key to its mean over microbatches on
        the last pipeline stage, and is ``{}`` on every other stage.
        ``skipped_iter`` is 0 when the optimizer update succeeded, else 1.
    """
    args = get_args()
    timers = get_timers()
    dino = args.vision_pretraining and args.vision_pretraining_type == "dino"

    # Clear gradients: contiguous local-DDP buffers first, then the optimizer.
    if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_local_ddp:
        for chunk in model:
            chunk.zero_grad_buffer()
    optimizer.zero_grad()

    # Forward and backward passes.
    forward_backward_func = get_forward_backward_func()
    losses_reduced = forward_backward_func(
        forward_step_func, data_iterator, model,
        optimizer, timers, forward_only=False)

    if args.empty_unused_memory_level >= 1:
        torch.cuda.empty_cache()

    # Reduce gradients.
    timers('backward-reduce-model-grads').start()
    optimizer.reduce_model_grads(args, timers)
    timers('backward-reduce-model-grads').stop()

    if dino:
        bare_model = unwrap_model(model[0], (torchDDP, LocalDDP, Float16Module))
        bare_model.cancel_gradients_last_layer(args.curr_iteration)

    # Parameter update.
    timers('optimizer').start()
    update_successful, grad_norm, num_zeros_in_grad = optimizer.step(args, timers)
    timers('optimizer').stop()

    if update_successful:
        timers('backward-gather-model-params').start()
        optimizer.gather_model_params(args, timers)
        timers('backward-gather-model-params').stop()

    if dino:
        bare_model = unwrap_model(model[0], (torchDDP, LocalDDP, Float16Module))
        bare_model.update_momentum(args.curr_iteration)

    # Advance the scheduler by the number of samples in this global batch.
    if update_successful:
        samples_this_step = (get_num_microbatches() *
                             args.micro_batch_size *
                             args.data_parallel_size)
        opt_param_scheduler.step(increment=samples_this_step)
    skipped_iter = 0 if update_successful else 1

    if args.empty_unused_memory_level >= 2:
        torch.cuda.empty_cache()

    if not mpu.is_pipeline_last_stage(ignore_virtual=True):
        return {}, skipped_iter, grad_norm, num_zeros_in_grad

    # Average each loss over the microbatches.
    num_microbatches = len(losses_reduced)
    loss_reduced = {
        key: sum(mb_losses[key] for mb_losses in losses_reduced) / num_microbatches
        for key in losses_reduced[0]
    }
    return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad
474
475


Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
476
def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                 loss_scale, report_memory_flag, skipped_iter,
                 grad_norm, params_norm, num_zeros_in_grad):
    """Accumulate per-iteration statistics and periodically report them.

    ``total_loss_dict`` is updated in place. It holds counters for advanced,
    skipped and nan iterations, plus a running sum of each loss in
    ``loss_dict``. Every ``args.tensorboard_log_interval`` iterations, on the
    last rank, scalars are written to tensorboard. Every
    ``args.log_interval`` iterations, a summary line is printed and the
    accumulators are reset.

    Returns:
        ``report_memory_flag``. It is cleared once memory has been reported
        after a positive learning rate is seen.
    """
    args = get_args()
    timers = get_timers()
    writer = get_tensorboard_writer()

    advanced_iters_key = 'advanced iterations'
    skipped_iters_key = 'skipped iterations'
    nan_iters_key = 'nan iterations'
    counter_keys = (advanced_iters_key, skipped_iters_key, nan_iters_key)

    # Advanced / skipped iteration counters.
    if skipped_iter:
        total_loss_dict.setdefault(advanced_iters_key, 0)
    else:
        total_loss_dict[advanced_iters_key] = \
            total_loss_dict.get(advanced_iters_key, 0) + 1
    total_loss_dict[skipped_iters_key] = \
        total_loss_dict.get(skipped_iters_key, 0) + skipped_iter

    # Accumulate losses on advanced iterations; on skipped ones only check
    # whether the loss is inf/nan.
    got_nan = False
    for key, value in loss_dict.items():
        if skipped_iter:
            scalar = value.float().sum().item()
            got_nan = got_nan or math.isinf(scalar) or math.isnan(scalar)
        else:
            total_loss_dict[key] = total_loss_dict.get(
                key, torch.cuda.FloatTensor([0.0])) + value
    total_loss_dict[nan_iters_key] = \
        total_loss_dict.get(nan_iters_key, 0) + int(got_nan)

    # Timers to report, in fixed order; only those that exist are kept.
    timer_names = (
        'forward-compute',
        'forward-recv',
        'forward-send',
        'forward-backward-send-forward-backward-recv',
        'backward-compute',
        'backward-recv',
        'backward-send',
        'backward-send-forward-recv',
        'backward-send-backward-recv',
        'backward-params-all-reduce',
        'backward-layernorm-all-reduce',
        'backward-embedding-all-reduce',
        'backward-reduce-model-grads',
        'backward-gather-model-params',
        'optimizer-copy-to-main-grad',
        'optimizer-unscale-and-check-inf',
        'optimizer-clip-main-grad',
        'optimizer-count-zeros',
        'optimizer-inner-step',
        'optimizer-copy-main-to-model-params',
        'optimizer',
        'batch-generator',
    )
    timers_to_log = [name for name in timer_names if name in timers.timers]

    batch_size = (args.micro_batch_size * args.data_parallel_size *
                  get_num_microbatches())
    total_iterations = (total_loss_dict[advanced_iters_key] +
                        total_loss_dict[skipped_iters_key])

    # Tensorboard values.
    if writer and iteration % args.tensorboard_log_interval == 0 and \
       is_last_rank():

        def log_pair(tag, value):
            # Each scalar is recorded against iterations and against samples.
            writer.add_scalar(tag, value, iteration)
            writer.add_scalar(tag + ' vs samples', value,
                              args.consumed_train_samples)

        if args.log_learning_rate_to_tensorboard:
            log_pair('learning-rate', learning_rate)
        if args.log_batch_size_to_tensorboard:
            log_pair('batch-size', batch_size)
        for key, value in loss_dict.items():
            log_pair(key, value)
        if args.log_loss_scale_to_tensorboard:
            log_pair('loss-scale', loss_scale)
        if args.log_world_size_to_tensorboard:
            log_pair('world-size', args.world_size)
        if grad_norm is not None:
            log_pair('grad-norm', grad_norm)
        if num_zeros_in_grad is not None:
            log_pair('num-zeros', num_zeros_in_grad)
        if params_norm is not None:
            log_pair('params-norm', params_norm)
        if args.log_timers_to_tensorboard:
            timers.write(timers_to_log, writer, iteration,
                         normalizer=total_iterations)
        if args.log_memory_to_tensorboard:
            mem_stats = torch.cuda.memory_stats()
            memory_tags = (
                ("mem-reserved-bytes", "reserved_bytes.all.current"),
                ("mem-allocated-bytes", "allocated_bytes.all.current"),
                ("mem-allocated-count", "allocation.all.current"),
            )
            for tag, stat_key in memory_tags:
                writer.add_scalar(tag, mem_stats[stat_key], iteration)

    if iteration % args.log_interval != 0:
        return report_memory_flag

    elapsed_time = timers('interval-time').elapsed()
    elapsed_time_per_iteration = elapsed_time / total_iterations
    if writer and args.log_timers_to_tensorboard:
        writer.add_scalar('iteration-time',
                          elapsed_time_per_iteration, iteration)

    parts = [
        ' iteration {:8d}/{:8d} |'.format(iteration, args.train_iters),
        ' consumed samples: {:12d} |'.format(args.consumed_train_samples),
        ' elapsed time per iteration (ms): {:.1f} |'.format(
            elapsed_time_per_iteration * 1000.0),
        ' learning rate: {:.3E} |'.format(learning_rate),
        ' global batch size: {:5d} |'.format(batch_size),
    ]
    # Average each accumulated loss over advanced iterations, then reset it.
    denominator = float(max(1, total_loss_dict[advanced_iters_key]))
    for key in total_loss_dict:
        if key in counter_keys:
            continue
        avg = total_loss_dict[key].item() / denominator
        if avg > 0.0:
            parts.append(' {}: {:.6E} |'.format(key, avg))
        total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
    parts.append(' loss scale: {:.1f} |'.format(loss_scale))
    if grad_norm is not None:
        parts.append(' grad norm: {:.3f} |'.format(grad_norm))
    if num_zeros_in_grad is not None:
        parts.append(' num zeros: {:.1f} |'.format(num_zeros_in_grad))
    if params_norm is not None:
        parts.append(' params norm: {:.3f} |'.format(params_norm))
    parts.append(' number of skipped iterations: {:3d} |'.format(
        total_loss_dict[skipped_iters_key]))
    parts.append(' number of nan iterations: {:3d} |'.format(
        total_loss_dict[nan_iters_key]))
    for counter_key in counter_keys:
        total_loss_dict[counter_key] = 0
    print_rank_last(''.join(parts))

    if report_memory_flag and learning_rate > 0.:
        # Report memory after optimizer state has been initialized.
        report_memory('(after {} iterations)'.format(iteration))
        report_memory_flag = False
    timers.log(timers_to_log, normalizer=args.log_interval)

    return report_memory_flag


652
def save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler):
    """Save a checkpoint for ``iteration`` and log how long the save took.

    The call to ``save_checkpoint`` is bracketed by two
    ``torch.distributed.barrier()`` calls, so the 'save-checkpoint' timer
    spans from the point all ranks are in sync until every rank has finished
    saving. The timer is logged through ``timers.log`` afterwards.
    """
    all_timers = get_timers()
    timer_name = 'save-checkpoint'

    # Sync first so no rank starts its clock early; together with the
    # second barrier this makes every rank report the max save time.
    torch.distributed.barrier()
    all_timers(timer_name).start()
    save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
    # Wait for the slowest rank before stopping the clock.
    torch.distributed.barrier()
    all_timers(timer_name).stop()
    all_timers.log([timer_name])
662
663


664
def train(forward_step_func, model, optimizer, opt_param_scheduler,
          train_data_iterator, valid_data_iterator,
          process_non_loss_data_func):
    """Run the main training loop from ``args.iteration`` to ``args.train_iters``.

    Each iteration runs one ``train_step`` and advances
    ``args.consumed_train_samples``. It then logs via ``training_log`` and
    handles, in this order:

    - ADLR autoresume;
    - periodic validation;
    - a checkpoint-and-exit when a signal was received;
    - periodic checkpointing;
    - exit on elapsed wall-clock time or on ``args.exit_interval``.

    The exit paths terminate the process with ``sys.exit()``. Checkpoints
    are only written when ``args.save`` is set, on every path.

    Args:
        forward_step_func: per-microbatch forward function, passed to
            ``train_step`` and ``evaluate_and_print_results``.
        model: iterable of model modules; each is put into train mode first.
        optimizer: optimizer stepped by ``train_step``. It also supplies the
            logged loss scale and learning rate (``param_groups[0]['lr']``).
        opt_param_scheduler: scheduler passed to ``train_step`` and saved
            with checkpoints.
        train_data_iterator: iterator consumed by ``train_step``.
        valid_data_iterator: iterator used for periodic validation.
        process_non_loss_data_func: forwarded to
            ``evaluate_and_print_results``.

    Returns:
        The iteration count reached when the loop ends normally.
    """
    args = get_args()
    timers = get_timers()

    # Write args to tensorboard
    write_args_to_tensorboard()

    # Turn on training mode which enables dropout.
    for model_module in model:
        model_module.train()

    # Tracking loss.
    total_loss_dict = {}

    # Iterations.
    iteration = args.iteration

    timers('interval-time').start()
    print_datetime('before the start of training step')
    report_memory_flag = True
    while iteration < args.train_iters:
        update_num_microbatches(args.consumed_train_samples)
        # Index of the iteration about to run (before the increment below).
        args.curr_iteration = iteration
        loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \
            train_step(forward_step_func,
                       train_data_iterator,
                       model,
                       optimizer,
                       opt_param_scheduler)
        iteration += 1
        args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
                                       args.micro_batch_size * \
                                       get_num_microbatches()

        # Logging.
        loss_scale = optimizer.get_loss_scale().item()
        params_norm = None
        if args.log_params_norm:
            params_norm = calc_params_l2_norm(model)
        report_memory_flag = training_log(loss_dict, total_loss_dict,
                                          optimizer.param_groups[0]['lr'],
                                          iteration, loss_scale,
                                          report_memory_flag, skipped_iter,
                                          grad_norm, params_norm, num_zeros_in_grad)

        # Autoresume
        if args.adlr_autoresume and \
           (iteration % args.adlr_autoresume_interval == 0):
            check_adlr_autoresume_termination(iteration, model, optimizer,
                                              opt_param_scheduler)

        # Evaluation
        if args.eval_interval and iteration % args.eval_interval == 0 and \
           args.do_valid:
            prefix = 'iteration {}'.format(iteration)
            evaluate_and_print_results(prefix, forward_step_func,
                                       valid_data_iterator, model,
                                       iteration, process_non_loss_data_func,
                                       False)

        # Checkpointing
        saved_checkpoint = False
        if args.exit_signal_handler:
            signal_handler = get_signal_handler()
            if any(signal_handler.signals_received()):
                # Like the periodic save below, only save when a save
                # directory was configured; otherwise just exit.
                if args.save:
                    save_checkpoint_and_time(iteration, model, optimizer,
                                             opt_param_scheduler)
                print_datetime('exiting program after receiving SIGTERM.')
                sys.exit()

        if args.save and args.save_interval and \
           iteration % args.save_interval == 0:
            save_checkpoint_and_time(iteration, model, optimizer,
                                     opt_param_scheduler)
            saved_checkpoint = True

        # Exiting based on duration
        if args.exit_duration_in_mins:
            train_time = (time.time() - _TRAIN_START_TIME) / 60.0
            # MAX-reduce the per-rank flag so all ranks exit on the same
            # iteration as soon as any one of them is over the limit.
            done_cuda = torch.cuda.IntTensor(
                [train_time > args.exit_duration_in_mins])
            torch.distributed.all_reduce(
                done_cuda, op=torch.distributed.ReduceOp.MAX)
            done = done_cuda.item()
            if done:
                # Skip the save if this iteration was already saved above,
                # or if no save directory was configured.
                if args.save and not saved_checkpoint:
                    save_checkpoint_and_time(iteration, model, optimizer,
                                             opt_param_scheduler)
                print_datetime('exiting program after {} minutes'.format(train_time))
                sys.exit()

        # Exiting based on iterations
        if args.exit_interval and iteration % args.exit_interval == 0:
            if args.save and not saved_checkpoint:
                save_checkpoint_and_time(iteration, model, optimizer,
                                         opt_param_scheduler)
            torch.distributed.barrier()
            print_datetime('exiting program at iteration {}'.format(iteration))
            sys.exit()

    return iteration
769
770


771
772
773
774
775
def evaluate(forward_step_func,
             data_iterator,
             model,
             process_non_loss_data_func,
             verbose=False):
    """Run ``args.eval_iters`` forward-only evaluation steps and average the losses.

    Args:
        forward_step_func: per-microbatch forward function, passed to the
            schedule returned by ``get_forward_backward_func``.
        data_iterator: iterator the schedule draws evaluation batches from.
        model: iterable of model modules. Each is switched to eval mode for
            the duration of the call and back to train mode afterwards.
        process_non_loss_data_func: if not None, one extra forward-only pass
            with ``collect_non_loss_data=True`` is run on the last rank. Its
            output is returned as the second element.
        verbose: if True, print progress every ``args.log_interval``
            iterations.

    Returns:
        A tuple ``(total_loss_dict, collected_non_loss_data)``.

        ``total_loss_dict`` maps each loss name to its sum over all
        evaluation microbatches, divided by
        ``args.eval_iters * get_num_microbatches()``. It is populated only
        on the last pipeline stage.

        ``collected_non_loss_data`` is None unless the extra pass ran.

    Side effects:
        Advances ``args.consumed_valid_samples`` once per evaluation
        iteration.
    """
    args = get_args()

    if args.vision_pretraining and args.vision_pretraining_type == "dino":
        compute_feature_bank(model)

    # Turn on evaluation mode which disables dropout.
    for model_module in model:
        model_module.eval()

    # Look the schedule up once, before the loop; it takes no per-iteration
    # input. Binding it only inside the loop left the name undefined when
    # args.eval_iters <= 0, so the non-loss-data pass below hit a NameError.
    forward_backward_func = get_forward_backward_func()

    total_loss_dict = {}

    with torch.no_grad():
        iteration = 0
        while iteration < args.eval_iters:
            iteration += 1
            if verbose and iteration % args.log_interval == 0:
                print_rank_0('Evaluating iter {}/{}'.format(iteration,
                                                            args.eval_iters))

            loss_dicts = forward_backward_func(
                forward_step_func, data_iterator, model, optimizer=None,
                timers=None, forward_only=True)

            # Empty unused memory
            if args.empty_unused_memory_level >= 1:
                torch.cuda.empty_cache()

            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                # Accumulate the per-microbatch losses locally; only the
                # last pipeline stage accumulates them.
                for loss_dict in loss_dicts:
                    for key in loss_dict:
                        total_loss_dict[key] = total_loss_dict.get(
                            key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]

            args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
                                           * args.micro_batch_size \
                                           * get_num_microbatches()
        collected_non_loss_data = None
        if process_non_loss_data_func is not None and is_last_rank():
            collected_non_loss_data = forward_backward_func(
                forward_step_func, data_iterator, model, optimizer=None,
                timers=None, forward_only=True, collect_non_loss_data=True)

    # Move model back to the train mode.
    for model_module in model:
        model_module.train()

    # Turn the sums into means over all evaluated microbatches.
    for key in total_loss_dict:
        total_loss_dict[key] /= args.eval_iters * get_num_microbatches()

    return total_loss_dict, collected_non_loss_data
829
830
831

def evaluate_and_print_results(prefix, forward_step_func,
                               data_iterator, model,
                               iteration, process_non_loss_data_func,
                               verbose=False):
    """Run ``evaluate`` and report the validation results.

    Printed output: for every loss returned by ``evaluate``, the value and
    its perplexity (``exp`` of the loss, with the exponent capped at 20) are
    appended to a one-line summary. The summary is printed on the last rank
    between two dashed rulers.

    Tensorboard output: if a writer exists, the loss values are written as
    scalars against both ``iteration`` and ``args.consumed_train_samples``.
    Perplexity is written too, but only when
    ``args.log_validation_ppl_to_tensorboard`` is set.

    Non-loss data: if ``process_non_loss_data_func`` is given, and this is
    the last rank with a writer, it is called with the collected non-loss
    data.
    """
    args = get_args()
    writer = get_tensorboard_writer()

    total_loss_dict, collected_non_loss_data = evaluate(
        forward_step_func, data_iterator, model,
        process_non_loss_data_func, verbose)

    parts = [' validation loss at {} | '.format(prefix)]
    for key in total_loss_dict:
        loss_val = total_loss_dict[key].item()
        parts.append('{} value: {:.6E} | '.format(key, loss_val))
        # Cap the exponent at 20 so a large loss gives a bounded perplexity.
        ppl = math.exp(min(20, loss_val))
        parts.append('{} PPL: {:.6E} | '.format(key, ppl))

        if not writer:
            continue
        samples = args.consumed_train_samples
        writer.add_scalar('{} validation'.format(key), loss_val, iteration)
        writer.add_scalar('{} validation vs samples'.format(key),
                          loss_val, samples)
        if args.log_validation_ppl_to_tensorboard:
            writer.add_scalar('{} validation ppl'.format(key), ppl,
                              iteration)
            writer.add_scalar('{} validation ppl vs samples'.format(key),
                              ppl, samples)

    if process_non_loss_data_func is not None and writer and is_last_rank():
        process_non_loss_data_func(collected_non_loss_data, iteration, writer)

    summary = ''.join(parts)
    ruler = '-' * (len(summary) + 1)
    print_rank_last(ruler)
    print_rank_last(summary)
    print_rank_last(ruler)
866
867


Vijay Korthikanti's avatar
Vijay Korthikanti committed
868
def cyclic_iter(iter):
    """Yield the items of ``iter`` over and over, forever.

    A fresh pass over ``iter`` starts each time the previous pass ends, so
    ``iter`` should be re-iterable (e.g. a list or a DataLoader).

    If a pass yields nothing, the generator loops forever without producing
    a value. This happens for an empty iterable, or for a one-shot iterator
    that is already exhausted.
    """
    while 1:
        for element in iter:
            yield element

873
874
875
def build_train_valid_test_data_iterators(
        build_train_valid_test_datasets_provider):
    """Build the train, validation and test data iterators.

    Datasets and dataloaders are built only on tensor-model-parallel rank 0.
    Other ranks get ``None`` iterators.

    Rank 0 decides whether training, validation and testing should run and
    broadcasts that decision to the rest of its tensor-model-parallel group.
    The result is stored in ``args.do_train``, ``args.do_valid`` and
    ``args.do_test``.

    When resuming (``args.iteration > 0``) from a run that did not record
    consumed samples, ``args.consumed_train_samples`` and
    ``args.consumed_valid_samples`` are back-filled assuming a fixed global
    batch size.

    Args:
        build_train_valid_test_datasets_provider: callable taking a list of
            three target sample counts ``[train, valid, test]``. It returns
            the ``(train_ds, valid_ds, test_ds)`` datasets.

    Returns:
        A tuple ``(train_data_iterator, valid_data_iterator,
        test_data_iterator)``. Each entry is ``None`` when its dataloader is
        ``None``. With ``args.dataloader_type == 'cyclic'``, the iterators
        restart their dataloader whenever it is exhausted.
    """
    args = get_args()

    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

    print_rank_0('> building train, validation, and test datasets ...')

    # Backward compatibility, assume fixed batch size.
    if args.iteration > 0 and args.consumed_train_samples == 0:
        assert args.train_samples is None, \
            'only backward compatiblity support for iteration-based training'
        args.consumed_train_samples = args.iteration * args.global_batch_size
    if args.iteration > 0 and args.consumed_valid_samples == 0:
        if args.train_samples is None:
            args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
                args.eval_iters * args.global_batch_size

    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_tensor_model_parallel_rank() == 0:

        # Number of train/valid/test samples.
        if args.train_samples:
            train_samples = args.train_samples
        else:
            train_samples = args.train_iters * args.global_batch_size
        # One validation run every eval_interval iterations, plus one more.
        eval_iters = (args.train_iters // args.eval_interval + 1) * \
                     args.eval_iters
        test_iters = args.eval_iters
        train_val_test_num_samples = [train_samples,
                                      eval_iters * args.global_batch_size,
                                      test_iters * args.global_batch_size]
        print_rank_0(' > datasets target sizes (minimum size):')
        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))

        # Build the datasets.
        train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
            train_val_test_num_samples)

        # Build dataloaders, resuming train/valid from the consumed counts.
        train_dataloader = build_pretraining_data_loader(
            train_ds, args.consumed_train_samples)
        valid_dataloader = build_pretraining_data_loader(
            valid_ds, args.consumed_valid_samples)
        test_dataloader = build_pretraining_data_loader(test_ds, 0)

        # Flags to know if we need to do training/validation/testing.
        do_train = train_dataloader is not None and args.train_iters > 0
        do_valid = valid_dataloader is not None and args.eval_iters > 0
        do_test = test_dataloader is not None and args.eval_iters > 0
        # Pack the flags into a tensor so they can be broadcast.
        flags = torch.cuda.LongTensor(
            [int(do_train), int(do_valid), int(do_test)])
    else:
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Broadcast the flags from tensor-parallel rank 0 to the rest of its group.
    torch.distributed.broadcast(flags,
                                mpu.get_tensor_model_parallel_src_rank(),
                                group=mpu.get_tensor_model_parallel_group())
    args.do_train = flags[0].item()
    args.do_valid = flags[1].item()
    args.do_test = flags[2].item()

    # Build iterators.
    dl_type = args.dataloader_type
    assert dl_type in ['single', 'cyclic']

    def _make_iterator(dataloader):
        """Wrap ``dataloader`` according to ``dl_type``; ``None`` passes through."""
        if dataloader is None:
            return None
        if dl_type == 'single':
            return iter(dataloader)
        # cyclic_iter is a generator function, so its result is already an
        # iterator; wrapping it in iter() again would be a no-op.
        return cyclic_iter(dataloader)

    train_data_iterator = _make_iterator(train_dataloader)
    valid_data_iterator = _make_iterator(valid_dataloader)
    test_data_iterator = _make_iterator(test_dataloader)

    return train_data_iterator, valid_data_iterator, test_data_iterator