training.py 36.9 KB
Newer Older
1
# coding=utf-8
Mohammad's avatar
Mohammad committed
2
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
3
4
5
6
7
8
9
10
11
12
13
14
15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Mohammad's avatar
Mohammad committed
16
"""Pretrain utilities."""
17
18
19

from datetime import datetime
import math
Mohammad's avatar
Mohammad committed
20
import sys
21
22
23
24
import time
# The earliest we can measure the start time.
_TRAIN_START_TIME = time.time()

25
26
27
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

Neel Kant's avatar
Neel Kant committed
28
from megatron import get_args
Mohammad's avatar
Mohammad committed
29
30
from megatron import get_timers
from megatron import get_tensorboard_writer
31
from megatron import get_current_global_batch_size
mohammad's avatar
mohammad committed
32
from megatron import get_num_microbatches
mohammad's avatar
mohammad committed
33
from megatron import is_last_rank
mohammad's avatar
mohammad committed
34
from megatron import update_num_microbatches
35
from megatron import mpu
Neel Kant's avatar
Neel Kant committed
36
from megatron import print_rank_0
37
from megatron import print_rank_last
Mohammad's avatar
Mohammad committed
38
39
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
40
from megatron.model import Float16Module
41
from megatron.model import ModelType
mohammad's avatar
mohammad committed
42
from megatron.optimizer import get_megatron_optimizer
Mohammad's avatar
Mohammad committed
43
from megatron.initialize import initialize_megatron
44
from megatron.initialize import write_args_to_tensorboard
45
46
47
from megatron.learning_rates import AnnealingLR
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.utils import check_adlr_autoresume_termination
48
from megatron.utils import unwrap_model
Vijay Korthikanti's avatar
Vijay Korthikanti committed
49
from megatron.data.data_samplers import build_pretraining_data_loader
mohammad's avatar
mohammad committed
50
from megatron.utils import calc_params_l2_norm
51
from megatron.schedules import get_forward_backward_func
52
from megatron.utils import report_memory
53
54


Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
55

56
57
58
59
60
61
62
def print_datetime(string):
    """Print ``[<string>] datetime: <local time>`` from rank 0.

    A ``torch.distributed.barrier()`` is issued before the timestamp is
    taken, so this call synchronizes all ranks and must be made by every
    rank.
    """
    torch.distributed.barrier()
    stamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    message = '[' + string + '] datetime: {} '.format(stamp)
    print_rank_0(message)


63
def pretrain(train_valid_test_dataset_provider,
             model_provider,
             model_type,
             forward_step_func,
             extra_args_provider=None,
             args_defaults=None):
    """Main training program.

    This function runs the following steps in order:
        1) initialize Megatron.
        2) set up the model, optimizer and lr schedule using model_provider.
        3) call train_valid_test_dataset_provider to get the train/val/test
           datasets.
        4) train the model using forward_step_func.

    Arguments:
        train_valid_test_dataset_provider: a function that takes the size of
            train/valid/test dataset and returns `train, valid, test` datasets.
        model_provider: a function that returns a vanilla version of the
            model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
        model_type: an enum that specifies the type of model being trained.
        forward_step_func: a function that takes a `data iterator` and `model`,
            and returns a `loss` scalar with a dictionary with key:values being
            the info we would like to monitor during training, for example
            `lm-loss: value`. We also require that this function add
            `batch generator` to the timers class.
        extra_args_provider: a function that takes a parser and adds arguments
            to it. It is used for programs to add their own arguments.
        args_defaults: a dictionary from argument-name to argument-value that
            is forwarded to `initialize_megatron`. `None` (the default) is
            treated as an empty dictionary; a fresh dict is created on every
            call so that no state is shared between calls.
    """
    # Avoid a shared mutable default: build a fresh dict per call.
    if args_defaults is None:
        args_defaults = {}

    # Initialize Megatron; arguments, timers and tensorboard writer are
    # fetched afterwards through the global getters.
    initialize_megatron(extra_args_provider=extra_args_provider,
                        args_defaults=args_defaults)

    # Adjust the startup time so it reflects the largest value: taking the
    # minimum start time over all ranks makes the reported initialization
    # time the longest one, which is closer to what the job scheduler sees.
    global _TRAIN_START_TIME
    start_time_tensor = torch.cuda.DoubleTensor([_TRAIN_START_TIME])
    torch.distributed.all_reduce(start_time_tensor,
                                 op=torch.distributed.ReduceOp.MIN)
    _TRAIN_START_TIME = start_time_tensor.item()
    print_rank_0('time to initialize megatron (seconds): {:.3f}'.format(
        time.time() - _TRAIN_START_TIME))
    print_datetime('after megatron is initialized')

    args = get_args()
    timers = get_timers()

    # Model, optimizer, and learning rate.
    timers('model-and-optimizer-setup').start()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider,
                                                               model_type)
    timers('model-and-optimizer-setup').stop()
    print_datetime('after model, optimizer, and learning rate '
                   'scheduler are built')

    # Data iterators. With a virtual pipeline (interleaved schedule) each
    # model chunk gets its own train/valid/test iterator triple.
    timers('train/valid/test-data-iterators-setup').start()
    if args.virtual_pipeline_model_parallel_size is not None:
        all_data_iterators = [
            build_train_valid_test_data_iterators(train_valid_test_dataset_provider)
            for _ in range(len(model))
        ]
        train_data_iterator = [data_iterators[0] for data_iterators in all_data_iterators]
        valid_data_iterator = [data_iterators[1] for data_iterators in all_data_iterators]
        test_data_iterator = [data_iterators[2] for data_iterators in all_data_iterators]
    else:
        train_data_iterator, valid_data_iterator, test_data_iterator \
            = build_train_valid_test_data_iterators(
                train_valid_test_dataset_provider)
    timers('train/valid/test-data-iterators-setup').stop()
    print_datetime('after dataloaders are built')

    # Print setup timing.
    print_rank_0('done with setup ...')
    timers.log(['model-and-optimizer-setup', 'train/valid/test-data-iterators-setup'])
    print_rank_0('training ...')

    iteration = 0
    if args.do_train and args.train_iters > 0:
        iteration = train(forward_step_func,
                          model, optimizer, lr_scheduler,
                          train_data_iterator, valid_data_iterator)
    print_datetime('after training is done')

    if args.do_valid:
        prefix = 'the end of training for val data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   valid_data_iterator, model,
                                   iteration, False)

    # Only save if training actually advanced past iteration 0.
    if args.save and iteration != 0:
        save_checkpoint(iteration, model, optimizer, lr_scheduler)

    if args.do_test:
        # Run on test data.
        prefix = 'the end of training for test data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   test_data_iterator, model,
                                   0, True)
165

166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
def update_train_iters(args):
    """Fill in ``args.train_iters`` for sample-based training.

    If ``args.train_iters`` is already truthy (iteration-based training)
    nothing happens. Otherwise the iteration count is derived from
    ``args.train_samples``:

    * without batch-size rampup, ``train_samples // global_batch_size``;
    * with rampup, the rampup schedule is simulated step by step through the
      microbatch calculator, then the remaining samples are divided by the
      full global batch size.

    Any partial last batch is thrown away.
    """
    if args.train_iters:
        # Iteration-based training: nothing to derive.
        return

    if args.rampup_batch_size is not None:
        rampup_samples = int(args.rampup_batch_size[2])
        num_iters = 0
        samples_seen = 0
        # Rampup phase: ask the microbatch calculator for the global batch
        # size in effect after each step.
        while samples_seen <= rampup_samples:
            update_num_microbatches(samples_seen, consistency_check=False)
            samples_seen += get_current_global_batch_size()
            num_iters += 1
        # Reset the calculator back to its starting state.
        update_num_microbatches(0, consistency_check=False)
        # Constant phase at the full global batch size.
        remaining_samples = args.train_samples - samples_seen
        args.train_iters = num_iters + \
            remaining_samples // args.global_batch_size
    else:
        # Constant batch size with sample-based training.
        args.train_iters = args.train_samples // args.global_batch_size

    print_rank_0('setting training iterations to {}'.format(args.train_iters))

195

196
def get_model(model_provider_func, model_type, wrap_with_ddp=True):
    """Build the model chunks for this rank, move them to GPU and wrap them.

    Args:
        model_provider_func: callable returning a bare model. It is called
            with ``pre_process``/``post_process`` keyword arguments, plus
            ``add_encoder``/``add_decoder`` when ``model_type`` is
            ``ModelType.encoder_and_decoder``.
        model_type: a ``ModelType`` value; stored on ``args.model_type`` and
            on every built module's ``model_type`` attribute.
        wrap_with_ddp: when True, wrap every chunk with the data-parallel
            implementation selected by ``args.DDP_impl``.

    Returns:
        A list of model chunks (more than one only when a virtual pipeline
        size is configured together with pipeline parallelism), placed on the
        current CUDA device, wrapped in ``Float16Module`` for fp16/bf16 and
        in DDP when requested.

    Raises:
        AssertionError: interleaved schedule requested for an
            encoder-decoder model, or a pipelined encoder-decoder model
            without ``args.pipeline_model_parallel_split_rank``.
        NotImplementedError: ``wrap_with_ddp`` is set and ``args.DDP_impl``
            is neither 'torch' nor 'local'.
    """
    args = get_args()
    args.model_type = model_type

    interleaved = (mpu.get_pipeline_model_parallel_world_size() > 1 and
                   args.virtual_pipeline_model_parallel_size is not None)
    if interleaved:
        assert model_type != ModelType.encoder_and_decoder, \
            "Interleaved schedule not supported for model with both encoder and decoder"
        model = []
        for virtual_rank in range(args.virtual_pipeline_model_parallel_size):
            mpu.set_virtual_pipeline_model_parallel_rank(virtual_rank)
            # The first/last stage queries depend on the virtual rank, so they
            # must run after it has been set.
            chunk = model_provider_func(
                pre_process=mpu.is_pipeline_first_stage(),
                post_process=mpu.is_pipeline_last_stage())
            chunk.model_type = model_type
            model.append(chunk)
    else:
        pre_process = mpu.is_pipeline_first_stage()
        post_process = mpu.is_pipeline_last_stage()
        if model_type == ModelType.encoder_and_decoder:
            add_encoder = add_decoder = True
            if mpu.get_pipeline_model_parallel_world_size() > 1:
                assert args.pipeline_model_parallel_split_rank is not None, \
                    "Split rank needs to be specified for model with both encoder and decoder"
                rank = mpu.get_pipeline_model_parallel_rank()
                split_rank = args.pipeline_model_parallel_split_rank
                world_size = mpu.get_pipeline_model_parallel_world_size()
                # Encoder and decoder each have their own first/last stage.
                pre_process = rank in (0, split_rank)
                post_process = rank in (split_rank - 1, world_size - 1)
                add_encoder = mpu.is_pipeline_stage_before_split()
                add_decoder = mpu.is_pipeline_stage_after_split()
            model = model_provider_func(
                pre_process=pre_process,
                post_process=post_process,
                add_encoder=add_encoder,
                add_decoder=add_decoder)
        else:
            model = model_provider_func(
                pre_process=pre_process,
                post_process=post_process)
        model.model_type = model_type

    if not isinstance(model, list):
        model = [model]

    # Only parameters that are already tensor model parallel have the
    # tensor-parallel attributes; give every other parameter the defaults so
    # the optimizer can rely on them.
    for param in (p for chunk in model for p in chunk.parameters()):
        mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    # Print number of parameters (one report per model-parallel group).
    if mpu.get_data_parallel_rank() == 0:
        print(' > number of parameters on (tensor, pipeline) '
              'model parallel rank ({}, {}): {}'.format(
                  mpu.get_tensor_model_parallel_rank(),
                  mpu.get_pipeline_model_parallel_rank(),
                  sum(p.nelement() for chunk in model
                      for p in chunk.parameters())), flush=True)

    # GPU allocation.
    device = torch.cuda.current_device()
    for chunk in model:
        chunk.cuda(device)

    # Half-precision wrapper (fp16 or bf16).
    if args.fp16 or args.bf16:
        model = [Float16Module(chunk, args) for chunk in model]

    if not wrap_with_ddp:
        return model

    if args.DDP_impl == 'torch':
        return [torchDDP(chunk, device_ids=[device], output_device=device,
                         process_group=mpu.get_data_parallel_group())
                for chunk in model]
    if args.DDP_impl == 'local':
        return [LocalDDP(chunk,
                         args.accumulate_allreduce_grads_in_fp32,
                         args.use_contiguous_buffers_in_local_ddp)
                for chunk in model]
    raise NotImplementedError('Unknown DDP implementation specified: '
                              '{}. Exiting.'.format(args.DDP_impl))
293
294


Mohammad's avatar
Mohammad committed
295
def get_learning_rate_scheduler(optimizer):
    """Build the ``AnnealingLR`` scheduler for ``optimizer``.

    Warmup and decay lengths are expressed in samples:

    * iteration-based training (``args.train_iters`` truthy): iteration
      counts are multiplied by ``args.global_batch_size``;
    * sample-based training (``args.train_samples`` truthy): the sample
      arguments are used directly, after ``update_train_iters`` has filled
      in ``args.train_iters``.

    A missing decay length defaults to the full training length and is
    written back to ``args``. When ``args.lr_warmup_fraction`` is not None
    it takes precedence over the explicit warmup length.

    Raises:
        Exception: if neither train-iters nor train-samples is provided.
    """
    args = get_args()

    if args.train_iters:
        # Iteration-based training.
        if args.lr_decay_iters is None:
            args.lr_decay_iters = args.train_iters
        decay_steps = args.lr_decay_iters * args.global_batch_size
        if args.lr_warmup_fraction is None:
            warmup_steps = args.lr_warmup_iters * args.global_batch_size
        else:
            warmup_steps = args.lr_warmup_fraction * decay_steps
    elif args.train_samples:
        # Sample-based training. args.train_iters is needed later; the
        # sample count itself is not trimmed for the incomplete last batch.
        update_train_iters(args)
        if args.lr_decay_samples is None:
            args.lr_decay_samples = args.train_samples
        decay_steps = args.lr_decay_samples
        if args.lr_warmup_fraction is None:
            warmup_steps = args.lr_warmup_samples
        else:
            warmup_steps = args.lr_warmup_fraction * decay_steps
    else:
        raise Exception(
            'either train-iters or train-samples should be provided.')

    return AnnealingLR(
        optimizer,
        max_lr=args.lr,
        min_lr=args.min_lr,
        warmup_steps=warmup_steps,
        decay_steps=decay_steps,
        decay_style=args.lr_decay_style,
        use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
        override_lr_scheduler=args.override_lr_scheduler)


338
def setup_model_and_optimizer(model_provider_func, model_type):
    """Setup model and optimizer.

    Builds the wrapped model chunks (``get_model``), a Megatron optimizer over
    the unwrapped modules and the learning-rate scheduler. If ``args.load``
    is set, a checkpoint is loaded (timed, with barriers so the max time over
    ranks is reported) and ``args.iteration`` is taken from it; otherwise
    ``args.iteration`` is set to 0.

    Returns:
        ``(model, optimizer, lr_scheduler)`` where ``model`` is the list
        returned by ``get_model``.
    """
    args = get_args()

    model = get_model(model_provider_func, model_type)

    # The optimizer operates on the bare modules (no DDP / Float16 wrappers).
    unwrapped_model = unwrap_model(model,
                                   (torchDDP, LocalDDP, Float16Module))
    optimizer = get_megatron_optimizer(unwrapped_model)

    lr_scheduler = get_learning_rate_scheduler(optimizer)

    if args.load is not None:
        timers = get_timers()
        # Extra barrier is added to make sure all ranks report the
        # max time.
        torch.distributed.barrier()
        timers('load-checkpoint').start()
        args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
        torch.distributed.barrier()
        timers('load-checkpoint').stop()
        timers.log(['load-checkpoint'])
    else:
        args.iteration = 0

    # Multiple model chunks (virtual pipeline) or pipeline parallelism are
    # only supported with the local DDP implementation.
    if len(model) > 1 or mpu.get_pipeline_model_parallel_world_size() > 1:
        assert args.DDP_impl == 'local'

    # Fresh ICT run: initialize the (single, unwrapped) model from BERT.
    if args.iteration == 0 and len(unwrapped_model) == 1 \
        and hasattr(unwrapped_model[0], 'init_state_dict_from_bert'):
        print_rank_0("Initializing ICT from pretrained BERT model")
        unwrapped_model[0].init_state_dict_from_bert()
        # The weights were changed outside the optimizer. get_model applies
        # the half-precision wrapper for fp16 *and* bf16, so the optimizer's
        # internal parameter copies must be refreshed in both cases (the
        # original only did this for fp16).
        # NOTE(review): assumes reload_model_params is the refresh hook for
        # the mixed-precision optimizer — confirm in megatron.optimizer.
        if args.fp16 or args.bf16:
            optimizer.reload_model_params()

    return model, optimizer, lr_scheduler


378
379
380
381
382
383
384
def train_step(forward_step_func, data_iterator,
               model, optimizer, lr_scheduler):
    """Run a single training step over one global batch.

    Zeroes gradients, runs the forward/backward schedule from
    ``get_forward_backward_func``, performs the local-DDP gradient
    all-reduce and the shared word-embedding gradient all-reduce when
    applicable, then steps the optimizer and, on success, the scheduler.

    Returns:
        ``(loss_dict, skipped_iter, grad_norm, num_zeros_in_grad)``.
        ``loss_dict`` maps each reported key to its mean over microbatches on
        the last pipeline stage and is ``{}`` on other stages;
        ``skipped_iter`` is 0 when the optimizer update succeeded, else 1.
    """
    args = get_args()
    timers = get_timers()
    local_ddp = args.DDP_impl == 'local'

    # Set grad to zero.
    if local_ddp and args.use_contiguous_buffers_in_local_ddp:
        for partition in model:
            partition.zero_grad_buffer()
    optimizer.zero_grad()

    # Forward and backward over all microbatches.
    forward_backward_func = get_forward_backward_func()
    losses_reduced = forward_backward_func(
        forward_step_func, data_iterator, model,
        optimizer, timers, forward_only=False)

    # Empty unused memory
    if args.empty_unused_memory_level >= 1:
        torch.cuda.empty_cache()

    # Data-parallel gradient all-reduce for the local DDP implementation.
    if local_ddp:
        timers('backward-params-all-reduce').start()
        for model_module in model:
            model_module.allreduce_gradients()
        timers('backward-params-all-reduce').stop()

    # Keep word_embeddings in sync between the first and last pipeline
    # stages by all-reducing their grads (BERT / GPT-2 style models).
    timers('backward-embedding-all-reduce').start()
    if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
            mpu.get_pipeline_model_parallel_world_size() > 1:
        # Last stage owns the final chunk; the first stage (and any other
        # stage — interleaved schedule not supported for T5 yet) uses chunk 0.
        use_last_chunk = (not mpu.is_pipeline_first_stage(ignore_virtual=True)
                          and mpu.is_pipeline_last_stage(ignore_virtual=True))
        owner = unwrap_model(model[-1] if use_last_chunk else model[0],
                             (torchDDP, LocalDDP, Float16Module))
        if owner.share_word_embeddings:
            embedding_weight = owner.word_embeddings_weight()
            grad = (embedding_weight.main_grad if local_ddp
                    else embedding_weight.grad)
            torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
    timers('backward-embedding-all-reduce').stop()

    # Update parameters.
    timers('optimizer').start()
    update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
    timers('optimizer').stop()

    # Advance the LR schedule by the number of samples consumed.
    if update_successful:
        lr_scheduler.step(increment=get_num_microbatches() *
                          args.micro_batch_size *
                          args.data_parallel_size)
        skipped_iter = 0
    else:
        skipped_iter = 1

    # Empty unused memory
    if args.empty_unused_memory_level >= 2:
        torch.cuda.empty_cache()

    if not mpu.is_pipeline_last_stage(ignore_virtual=True):
        return {}, skipped_iter, grad_norm, num_zeros_in_grad

    # Average each reported loss across microbatches.
    loss_reduced = {}
    for key in losses_reduced[0]:
        values = [entry[key] for entry in losses_reduced]
        loss_reduced[key] = sum(values) / len(values)
    return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad
458
459


Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
460
def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                 loss_scale, report_memory_flag, skipped_iter,
                 grad_norm, params_norm, num_zeros_in_grad):
    """Log training information such as losses, timing, etc.

    ``total_loss_dict`` is updated in place. It accumulates losses over
    advanced (non-skipped) iterations and keeps counters for advanced,
    skipped and nan/inf iterations.

    * Every ``args.tensorboard_log_interval`` iterations, if a tensorboard
      writer exists and this is the last rank, the current values are
      written to tensorboard.
    * Every ``args.log_interval`` iterations, a summary line is printed, the
      timers are logged, and the accumulators are reset.

    Returns:
        ``report_memory_flag``, cleared once memory usage has been reported.
    """
    args = get_args()
    timers = get_timers()
    writer = get_tensorboard_writer()

    advanced_iters_key = 'advanced iterations'
    skipped_iters_key = 'skipped iterations'
    nan_iters_key = 'nan iterations'
    counter_keys = (advanced_iters_key, skipped_iters_key, nan_iters_key)

    # Iteration counters.
    if skipped_iter:
        # Make sure the advanced counter exists even if nothing advanced yet.
        total_loss_dict.setdefault(advanced_iters_key, 0)
    else:
        total_loss_dict[advanced_iters_key] = \
            total_loss_dict.get(advanced_iters_key, 0) + 1
    total_loss_dict[skipped_iters_key] = \
        total_loss_dict.get(skipped_iters_key, 0) + skipped_iter

    # Accumulate losses on advanced steps; on skipped steps only check
    # whether any loss is inf/nan.
    got_nan = False
    for key, loss in loss_dict.items():
        if skipped_iter:
            scalar = loss.float().sum().item()
            got_nan = got_nan or not math.isfinite(scalar)
        else:
            total_loss_dict[key] = total_loss_dict.get(
                key, torch.cuda.FloatTensor([0.0])) + loss
    total_loss_dict[nan_iters_key] = \
        total_loss_dict.get(nan_iters_key, 0) + int(got_nan)

    # Timers that were actually used this run, in reporting order.
    timer_names = (
        'forward-compute',
        'forward-recv',
        'forward-send',
        'forward-backward-send-forward-backward-recv',
        'backward-compute',
        'backward-recv',
        'backward-send',
        'backward-send-forward-recv',
        'backward-send-backward-recv',
        'backward-params-all-reduce',
        'backward-embedding-all-reduce',
        'optimizer-copy-to-main-grad',
        'optimizer-unscale-and-check-inf',
        'optimizer-clip-main-grad',
        'optimizer-copy-main-to-model-params',
        'optimizer',
        'batch-generator',
    )
    timers_to_log = [name for name in timer_names if name in timers.timers]

    batch_size = args.micro_batch_size * args.data_parallel_size * \
        get_num_microbatches()
    total_iterations = (total_loss_dict[advanced_iters_key] +
                        total_loss_dict[skipped_iters_key])

    # Tensorboard values.
    if writer and iteration % args.tensorboard_log_interval == 0 and \
            is_last_rank():
        samples = args.consumed_train_samples
        if args.log_learning_rate_to_tensorboard:
            writer.add_scalar('learning-rate', learning_rate, iteration)
            writer.add_scalar('learning-rate vs samples', learning_rate,
                              samples)
        if args.log_batch_size_to_tensorboard:
            writer.add_scalar('batch-size', batch_size, iteration)
            writer.add_scalar('batch-size vs samples', batch_size, samples)
        for key, loss in loss_dict.items():
            writer.add_scalar(key, loss, iteration)
            writer.add_scalar(key + ' vs samples', loss, samples)
        if args.log_loss_scale_to_tensorboard:
            writer.add_scalar('loss-scale', loss_scale, iteration)
            writer.add_scalar('loss-scale vs samples', loss_scale, samples)
        for tag, value in (('grad-norm', grad_norm),
                           ('num-zeros', num_zeros_in_grad),
                           ('params-norm', params_norm)):
            if value is not None:
                writer.add_scalar(tag, value, iteration)
                writer.add_scalar(tag + ' vs samples', value, samples)
        if args.log_timers_to_tensorboard:
            timers.write(timers_to_log, writer, iteration,
                         normalizer=total_iterations)
        if args.log_memory_to_tensorboard:
            mem_stats = torch.cuda.memory_stats()
            for tag, stat_key in (
                    ("mem-reserved-bytes", "reserved_bytes.all.current"),
                    ("mem-allocated-bytes", "allocated_bytes.all.current"),
                    ("mem-allocated-count", "allocation.all.current")):
                writer.add_scalar(tag, mem_stats[stat_key], iteration)

    if iteration % args.log_interval == 0:
        elapsed_time = timers('interval-time').elapsed()
        elapsed_time_per_iteration = elapsed_time / total_iterations
        if writer and args.log_timers_to_tensorboard:
            writer.add_scalar('iteration-time',
                              elapsed_time_per_iteration, iteration)

        parts = [
            ' iteration {:8d}/{:8d} |'.format(iteration, args.train_iters),
            ' consumed samples: {:12d} |'.format(args.consumed_train_samples),
            ' elapsed time per iteration (ms): {:.1f} |'.format(
                elapsed_time_per_iteration * 1000.0),
            ' learning rate: {:.3E} |'.format(learning_rate),
            ' global batch size: {:5d} |'.format(batch_size),
        ]
        # Average accumulated losses over advanced iterations, then reset.
        num_advanced = float(max(1, total_loss_dict[advanced_iters_key]))
        for key in total_loss_dict:
            if key in counter_keys:
                continue
            avg = total_loss_dict[key].item() / num_advanced
            if avg > 0.0:
                parts.append(' {}: {:.6E} |'.format(key, avg))
            total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
        parts.append(' loss scale: {:.1f} |'.format(loss_scale))
        if grad_norm is not None:
            parts.append(' grad norm: {:.3f} |'.format(grad_norm))
        if num_zeros_in_grad is not None:
            parts.append(' num zeros: {:.1f} |'.format(num_zeros_in_grad))
        if params_norm is not None:
            parts.append(' params norm: {:.3f} |'.format(params_norm))
        parts.append(' number of skipped iterations: {:3d} |'.format(
            total_loss_dict[skipped_iters_key]))
        parts.append(' number of nan iterations: {:3d} |'.format(
            total_loss_dict[nan_iters_key]))
        for counter in counter_keys:
            total_loss_dict[counter] = 0
        print_rank_last(''.join(parts))

        if report_memory_flag and learning_rate > 0.:
            # Report memory after optimizer state has been initialized.
            report_memory('(after {} iterations)'.format(iteration))
            report_memory_flag = False
        timers.log(timers_to_log, normalizer=args.log_interval)

    return report_memory_flag


627
628
629
630
631
def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):
    """Save a checkpoint for `iteration` and log how long the save took.

    The save is bracketed by distributed barriers so every rank enters and
    leaves the timed region together; the logged 'save-checkpoint' time
    therefore reflects the maximum across ranks.
    """
    timer_name = 'save-checkpoint'
    timers = get_timers()

    # Barrier before starting the clock: no rank begins timing early.
    torch.distributed.barrier()
    timers(timer_name).start()

    save_checkpoint(iteration, model, optimizer, lr_scheduler)

    # Barrier before stopping the clock: every rank waits for the slowest.
    torch.distributed.barrier()
    timers(timer_name).stop()

    timers.log([timer_name])
637
638


639
def train(forward_step_func, model, optimizer, lr_scheduler,
          train_data_iterator, valid_data_iterator):
    """Run the main training loop until `args.train_iters` is reached.

    Each iteration runs one `train_step`, advances
    `args.consumed_train_samples` and logs. It then optionally performs an
    ADLR autoresume check, validation, periodic checkpointing, and early
    exit (by wall-clock duration or by iteration count).

    Args:
        forward_step_func: forward-step callable passed to `train_step` and
            to validation.
        model: iterable of model modules (each is switched to train mode).
        optimizer: optimizer; its loss scale and first param-group lr are
            logged every iteration.
        lr_scheduler: learning-rate scheduler passed to `train_step` and to
            checkpoint saving.
        train_data_iterator: iterator consumed by `train_step`.
        valid_data_iterator: iterator used for periodic validation.

    Returns:
        The iteration count at which the loop stopped. The early-exit paths
        call `sys.exit()` instead of returning.
    """
    args = get_args()
    timers = get_timers()

    # Write args to tensorboard
    write_args_to_tensorboard()

    # Turn on training mode which enables dropout.
    for model_module in model:
        model_module.train()

    # Tracking loss.
    total_loss_dict = {}

    # Iterations.
    iteration = args.iteration

    timers('interval-time').start()
    print_datetime('before the start of training step')
    report_memory_flag = True
    while iteration < args.train_iters:
        # The microbatch count is recomputed from the samples consumed so
        # far (presumably to support a batch-size schedule; see
        # update_num_microbatches).
        update_num_microbatches(args.consumed_train_samples)
        loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \
            train_step(forward_step_func,
                       train_data_iterator,
                       model,
                       optimizer,
                       lr_scheduler)
        iteration += 1
        # Samples consumed this iteration, summed over data-parallel ranks.
        args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
                                       args.micro_batch_size * \
                                       get_num_microbatches()

        # Logging.
        loss_scale = optimizer.get_loss_scale().item()
        params_norm = None
        if args.log_params_norm:
            params_norm = calc_params_l2_norm(model)
        report_memory_flag = training_log(loss_dict, total_loss_dict,
                                          optimizer.param_groups[0]['lr'],
                                          iteration, loss_scale,
                                          report_memory_flag, skipped_iter,
                                          grad_norm, params_norm, num_zeros_in_grad)

        # Autoresume
        if args.adlr_autoresume and \
           (iteration % args.adlr_autoresume_interval == 0):
            check_adlr_autoresume_termination(iteration, model, optimizer,
                                              lr_scheduler)

        # Evaluation
        if args.eval_interval and iteration % args.eval_interval == 0 and \
           args.do_valid:
            prefix = 'iteration {}'.format(iteration)
            evaluate_and_print_results(prefix, forward_step_func,
                                       valid_data_iterator, model,
                                       iteration, False)

        # Checkpointing
        saved_checkpoint = False
        if args.save and args.save_interval and \
           iteration % args.save_interval == 0:
            save_checkpoint_and_time(iteration, model, optimizer,
                                     lr_scheduler)
            saved_checkpoint = True

        # Exiting based on duration
        if args.exit_duration_in_mins:
            train_time = (time.time() - _TRAIN_START_TIME) / 60.0
            # All ranks must agree on whether to stop, so take the MAX of
            # the per-rank "time is up" flags.
            done_cuda = torch.cuda.IntTensor(
                [train_time > args.exit_duration_in_mins])
            torch.distributed.all_reduce(
                done_cuda, op=torch.distributed.ReduceOp.MAX)
            done = done_cuda.item()
            if done:
                # Only save when a save directory is configured, matching
                # the `args.save` guard on periodic checkpointing above.
                if args.save and not saved_checkpoint:
                    save_checkpoint_and_time(iteration, model, optimizer,
                                             lr_scheduler)
                print_datetime('exiting program after {} minutes'.format(train_time))
                sys.exit()

        # Exiting based on iterations
        if args.exit_interval and iteration % args.exit_interval == 0:
            # Same `args.save` guard as the periodic checkpoint.
            if args.save and not saved_checkpoint:
                save_checkpoint_and_time(iteration, model, optimizer,
                                         lr_scheduler)
            torch.distributed.barrier()
            print_datetime('exiting program at iteration {}'.format(iteration))
            sys.exit()

    return iteration
733
734


Mohammad's avatar
Mohammad committed
735
def evaluate(forward_step_func, data_iterator, model, verbose=False):
    """Run `args.eval_iters` forward-only steps and return averaged losses.

    Every model module is put in eval mode for the duration of the
    evaluation. Forward passes run under `torch.no_grad()`. The modules are
    put back in train mode afterwards, even if an evaluation step raises.

    Args:
        forward_step_func: forward-step callable handed to the
            forward/backward schedule.
        data_iterator: iterator the schedule draws evaluation batches from.
        model: iterable of model modules.
        verbose: if True, print progress every `args.log_interval` steps.

    Returns:
        dict mapping each loss key to its accumulated sum divided by
        `args.eval_iters * get_num_microbatches()`. Losses are only
        accumulated on the last pipeline stage, so other stages get an
        empty dict.
    """
    args = get_args()

    # Turn on evaluation mode which disables dropout.
    for model_module in model:
        model_module.eval()

    total_loss_dict = {}

    # try/finally ensures the modules are always returned to train mode, so a
    # caller that recovers from an evaluation error does not keep training
    # with dropout disabled.
    try:
        with torch.no_grad():
            iteration = 0
            while iteration < args.eval_iters:
                iteration += 1
                if verbose and iteration % args.log_interval == 0:
                    print_rank_0('Evaluating iter {}/{}'.format(iteration,
                                                                args.eval_iters))

                forward_backward_func = get_forward_backward_func()
                loss_dicts = forward_backward_func(
                    forward_step_func, data_iterator, model, optimizer=None,
                    timers=None, forward_only=True)

                # Empty unused memory
                if args.empty_unused_memory_level >= 1:
                    torch.cuda.empty_cache()

                if mpu.is_pipeline_last_stage(ignore_virtual=True):
                    # Sum each loss over the dicts returned by the schedule
                    # (presumably one per microbatch, given the division by
                    # get_num_microbatches() below); averaged after the loop.
                    for loss_dict in loss_dicts:
                        for key in loss_dict:
                            total_loss_dict[key] = total_loss_dict.get(
                                key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]

                args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
                                               * args.micro_batch_size \
                                               * get_num_microbatches()
    finally:
        # Move model back to the train mode.
        for model_module in model:
            model_module.train()

    for key in total_loss_dict:
        total_loss_dict[key] /= args.eval_iters * get_num_microbatches()

    return total_loss_dict

def evaluate_and_print_results(prefix, forward_step_func,
                               data_iterator, model,
                               iteration, verbose=False):
    """Evaluate the model, print validation losses and log them to TensorBoard.

    Args:
        prefix: label inserted into the printed header line.
        forward_step_func: forwarded to `evaluate`.
        data_iterator: forwarded to `evaluate`.
        model: forwarded to `evaluate`.
        iteration: x-axis value for the per-iteration TensorBoard scalars.
        verbose: forwarded to `evaluate`.

    For each loss key, the summary (printed via `print_rank_last`) contains
    the loss value and its perplexity. The perplexity is exp(loss) with the
    exponent capped at 20.

    When a TensorBoard writer is available, the same values are logged
    against both `iteration` and `args.consumed_train_samples`. Perplexity
    is logged only when `args.log_validation_ppl_to_tensorboard` is set.
    """
    args = get_args()
    writer = get_tensorboard_writer()

    total_loss_dict = evaluate(forward_step_func, data_iterator, model, verbose)
    string = ' validation loss at {} | '.format(prefix)
    for key in total_loss_dict:
        # Read the scalar once and reuse it: `.item()` copies the value to
        # the host (a device synchronization for CUDA tensors).
        loss_value = total_loss_dict[key].item()
        string += '{} value: {:.6E} | '.format(key, loss_value)
        # Cap the exponent so a diverged loss cannot overflow math.exp.
        ppl = math.exp(min(20, loss_value))
        string += '{} PPL: {:.6E} | '.format(key, ppl)
        if writer:
            writer.add_scalar('{} validation'.format(key),
                              loss_value,
                              iteration)
            writer.add_scalar('{} validation vs samples'.format(key),
                              loss_value,
                              args.consumed_train_samples)
            if args.log_validation_ppl_to_tensorboard:
                writer.add_scalar('{} validation ppl'.format(key), ppl,
                                  iteration)
                writer.add_scalar('{} validation ppl vs samples'.format(key),
                                  ppl, args.consumed_train_samples)

    length = len(string) + 1
    print_rank_last('-' * length)
    print_rank_last(string)
    print_rank_last('-' * length)
811
812


Vijay Korthikanti's avatar
Vijay Korthikanti committed
813
def cyclic_iter(iter):
    """Yield the items of *iter* repeatedly, starting a new pass when exhausted.

    *iter* should be re-iterable (e.g. a DataLoader or a list): each pass
    iterates over it afresh. If a pass produces no items, the generator
    stops instead of looping forever without yielding. This happens for an
    empty iterable, or for a one-shot iterator that has already been used
    up.
    """
    # NOTE: the parameter name shadows the builtin `iter`; it is kept so
    # keyword callers keep working.
    while True:
        produced_any = False
        for item in iter:
            produced_any = True
            yield item
        if not produced_any:
            return

818
819
820
def build_train_valid_test_data_iterators(
        build_train_valid_test_datasets_provider):
    """Build the train, validation and test data iterators.

    For checkpoints that predate sample tracking, this first reconstructs
    `args.consumed_train_samples` and `args.consumed_valid_samples` from
    `args.iteration`.

    Datasets and dataloaders are then built only on tensor-model-parallel
    rank 0. Its do_train / do_valid / do_test flags are broadcast to the
    rest of the tensor-model-parallel group and stored on `args`. On the
    other ranks the dataloaders stay None, so the returned iterators are
    None there too.

    Args:
        build_train_valid_test_datasets_provider: callable that takes a list
            of target sample counts [train, valid, test] and returns the
            (train, valid, test) datasets.

    Returns:
        (train_data_iterator, valid_data_iterator, test_data_iterator);
        each entry is None when the corresponding dataloader is None.
    """
    args = get_args()

    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

    print_rank_0('> building train, validation, and test datasets ...')

    # Backward compatibility, assume fixed batch size.
    if args.iteration > 0 and args.consumed_train_samples == 0:
        assert args.train_samples is None, \
            'only backward compatiblity support for iteration-based training'
        args.consumed_train_samples = args.iteration * args.global_batch_size
    if args.iteration > 0 and args.consumed_valid_samples == 0:
        if args.train_samples is None:
            args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
                args.eval_iters * args.global_batch_size

    # Data loader only on rank 0 of each tensor model parallel group.
    if mpu.get_tensor_model_parallel_rank() == 0:

        # Number of train/valid/test samples.
        if args.train_samples:
            train_samples = args.train_samples
        else:
            train_samples = args.train_iters * args.global_batch_size
        # One validation run every eval_interval iterations, plus one more
        # (presumably the final evaluation after training).
        eval_iters = (args.train_iters // args.eval_interval + 1) * \
                     args.eval_iters
        test_iters = args.eval_iters
        train_val_test_num_samples = [train_samples,
                                      eval_iters * args.global_batch_size,
                                      test_iters * args.global_batch_size]
        print_rank_0(' > datasets target sizes (minimum size):')
        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))

        # Build the datasets.
        train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
            train_val_test_num_samples)

        # Build dataloaders, resuming train/valid from the samples already
        # consumed; the test loader always starts from the beginning.
        train_dataloader = build_pretraining_data_loader(
            train_ds, args.consumed_train_samples)
        valid_dataloader = build_pretraining_data_loader(
            valid_ds, args.consumed_valid_samples)
        test_dataloader = build_pretraining_data_loader(test_ds, 0)

        # Flags to know if we need to do training/validation/testing.
        do_train = train_dataloader is not None and args.train_iters > 0
        do_valid = valid_dataloader is not None and args.eval_iters > 0
        do_test = test_dataloader is not None and args.eval_iters > 0
        # Pack the flags into a tensor so they can be broadcast.
        flags = torch.cuda.LongTensor(
            [int(do_train), int(do_valid), int(do_test)])
    else:
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Broadcast the flags from the tensor-model-parallel source rank so that
    # every rank in the group agrees on what to run.
    torch.distributed.broadcast(flags,
                                mpu.get_tensor_model_parallel_src_rank(),
                                group=mpu.get_tensor_model_parallel_group())
    args.do_train = flags[0].item()
    args.do_valid = flags[1].item()
    args.do_test = flags[2].item()

    # Build iterators.
    dl_type = args.dataloader_type
    assert dl_type in ['single', 'cyclic']

    train_data_iterator = _build_data_iterator(train_dataloader, dl_type)
    valid_data_iterator = _build_data_iterator(valid_dataloader, dl_type)
    test_data_iterator = _build_data_iterator(test_dataloader, dl_type)

    return train_data_iterator, valid_data_iterator, test_data_iterator


def _build_data_iterator(dataloader, dl_type):
    """Wrap *dataloader* in an iterator according to *dl_type*.

    Returns None when *dataloader* is None. For 'single' the iterator makes
    one pass over the data. For 'cyclic' it is a `cyclic_iter` generator,
    which starts a new pass over *dataloader* each time it is exhausted.
    """
    if dataloader is None:
        return None
    if dl_type == 'single':
        return iter(dataloader)
    # A generator is already its own iterator, so no iter() wrapper is needed.
    return cyclic_iter(dataloader)