arguments.py 43.6 KB
Newer Older
Raul Puri's avatar
Raul Puri committed
1
# coding=utf-8
Mohammad's avatar
Mohammad committed
2
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
Raul Puri's avatar
Raul Puri committed
3
4
5
6
7
8
9
10
11
12
13
14
15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Mohammad's avatar
Mohammad committed
16
"""Megatron arguments."""
Raul Puri's avatar
Raul Puri committed
17
18
19
20

import argparse
import os

21
import torch
Raul Puri's avatar
Raul Puri committed
22

23
24
def parse_args(extra_args_provider=None, defaults=None,
               ignore_unknown_args=False):
    """Parse all arguments.

    Builds the Megatron-LM argument parser from the standard argument
    groups (plus any extra ones), parses the command line, then derives and
    validates the parallelism layout, precision settings and the
    iteration- vs sample-based training schedule.

    Args:
        extra_args_provider: optional callable that receives the parser and
            returns it with additional arguments registered.
        defaults: optional mapping of argument ``dest`` name -> value. A
            value is applied only when that argument is still ``None`` after
            parsing; otherwise a warning is printed on rank 0 and the
            command-line value is kept.
        ignore_unknown_args: if True, unrecognized command-line arguments
            are ignored (``parse_known_args``) instead of being rejected.

    Returns:
        The populated ``argparse.Namespace``.

    Raises:
        AssertionError: when a consistency check on the arguments fails.
    """
    # ``None`` instead of ``{}`` avoids a shared mutable default argument.
    if defaults is None:
        defaults = {}

    parser = argparse.ArgumentParser(description='Megatron-LM Arguments',
                                     allow_abbrev=False)

    # Standard arguments.
    parser = _add_network_size_args(parser)
    parser = _add_regularization_args(parser)
    parser = _add_training_args(parser)
    parser = _add_initialization_args(parser)
    parser = _add_learning_rate_args(parser)
    parser = _add_checkpointing_args(parser)
    parser = _add_mixed_precision_args(parser)
    parser = _add_distributed_args(parser)
    parser = _add_validation_args(parser)
    parser = _add_data_args(parser)
    parser = _add_autoresume_args(parser)
    parser = _add_biencoder_args(parser)
    parser = _add_vit_args(parser)
    parser = _add_logging_args(parser)
    parser = _add_inference_args(parser)

    # Custom arguments.
    if extra_args_provider is not None:
        parser = extra_args_provider(parser)

    # Parse.
    if ignore_unknown_args:
        args, _ = parser.parse_known_args()
    else:
        args = parser.parse_args()

    # Distributed args, read from the environment (single-process defaults).
    args.rank = int(os.getenv('RANK', '0'))
    args.world_size = int(os.getenv("WORLD_SIZE", '1'))
    # Tensor model parallel size, clamped to the world size.
    args.tensor_model_parallel_size = min(
        args.tensor_model_parallel_size, args.world_size)
    assert args.world_size % args.tensor_model_parallel_size == 0, 'world size'\
        ' ({}) is not divisible by tensor model parallel size ({})'.format(
            args.world_size, args.tensor_model_parallel_size)
    # Pipeline model parallel size, clamped to the ranks left over after
    # tensor parallelism.
    args.pipeline_model_parallel_size = min(
        args.pipeline_model_parallel_size,
        (args.world_size // args.tensor_model_parallel_size))
    # With a standalone embedding stage one pipeline stage is excluded from
    # the transformer stage count (presumably it holds no transformer layers).
    args.transformer_pipeline_model_parallel_size = (
        args.pipeline_model_parallel_size - 1
        if args.standalone_embed_stage else
        args.pipeline_model_parallel_size
    )
    # Checks.
    model_parallel_size = args.pipeline_model_parallel_size * \
                          args.tensor_model_parallel_size
    # Three placeholders for the three values passed to format().
    assert args.world_size % model_parallel_size == 0, 'world size ({}) is ' \
        'not divisible by tensor parallel size ({}) times pipeline parallel ' \
        'size ({})'.format(args.world_size, args.tensor_model_parallel_size,
                           args.pipeline_model_parallel_size)
    args.data_parallel_size = args.world_size // model_parallel_size
    if args.rank == 0:
        print('using world size: {}, data-parallel-size: {}, '
              'tensor-model-parallel size: {}, '
              'pipeline-model-parallel size: {} '.format(
                  args.world_size, args.data_parallel_size,
                  args.tensor_model_parallel_size,
                  args.pipeline_model_parallel_size), flush=True)
    if args.pipeline_model_parallel_size > 1:
        if args.pipeline_model_parallel_split_rank is not None:
            assert args.pipeline_model_parallel_split_rank < \
                    args.pipeline_model_parallel_size, 'split rank needs'\
                    ' to be less than pipeline model parallel size ({})'.format(
                            args.pipeline_model_parallel_size)

    # Deprecated arguments: reject them, then drop them from the namespace.
    assert args.batch_size is None, '--batch-size argument is no longer ' \
        'valid, use --micro-batch-size instead'
    del args.batch_size
    assert args.warmup is None, '--warmup argument is no longer valid, use ' \
        '--lr-warmup-fraction instead'
    del args.warmup
    assert args.model_parallel_size is None, '--model-parallel-size is no ' \
        'longer valid, use --tensor-model-parallel-size instead'
    del args.model_parallel_size
    # --checkpoint-activations is still accepted as an alias for the
    # 'uniform' activations checkpoint method.
    if args.checkpoint_activations:
        args.activations_checkpoint_method = 'uniform'
        if args.rank == 0:
            print('--checkpoint-activations is no longer valid, '
                  'use --activations-checkpoint-method instead. '
                  'Defaulting to activations-checkpoint-method=uniform.')
    del args.checkpoint_activations

    # Set input defaults.
    for key in defaults:
        # For default to be valid, it should not be provided in the
        # arguments that are passed to the program. We check this by
        # ensuring the arg is set to None.
        if getattr(args, key) is not None:
            if args.rank == 0:
                print('WARNING: overriding default arguments for {key}:{v} '
                      'with {key}:{v2}'.format(key=key, v=defaults[key],
                                               v2=getattr(args, key)),
                      flush=True)
        else:
            setattr(args, key, defaults[key])

    # Batch size.
    assert args.micro_batch_size is not None
    assert args.micro_batch_size > 0
    if args.global_batch_size is None:
        args.global_batch_size = args.micro_batch_size * args.data_parallel_size
        if args.rank == 0:
            print('setting global batch size to {}'.format(
                args.global_batch_size), flush=True)
    assert args.global_batch_size > 0
    if args.num_layers_per_virtual_pipeline_stage is not None:
        assert args.pipeline_model_parallel_size > 2, \
            'pipeline-model-parallel size should be greater than 2 with ' \
            'interleaved schedule'
        assert args.num_layers % args.num_layers_per_virtual_pipeline_stage == 0, \
            'number of layers is not divisible by number of layers per virtual ' \
            'pipeline stage'
        args.virtual_pipeline_model_parallel_size = \
            (args.num_layers // args.transformer_pipeline_model_parallel_size) // \
            args.num_layers_per_virtual_pipeline_stage
    else:
        args.virtual_pipeline_model_parallel_size = None

    # Parameters dtype.
    args.params_dtype = torch.float
    if args.fp16:
        assert not args.bf16
        args.params_dtype = torch.half
    if args.bf16:
        assert not args.fp16
        args.params_dtype = torch.bfloat16
        # bfloat16 requires gradient accumulation and all-reduce to
        # be done in fp32.
        if not args.accumulate_allreduce_grads_in_fp32:
            args.accumulate_allreduce_grads_in_fp32 = True
            if args.rank == 0:
                print('accumulate and all-reduce gradients in fp32 for '
                      'bfloat16 data type.', flush=True)

    if args.rank == 0:
        print('using {} for parameters ...'.format(args.params_dtype),
              flush=True)

    # If we do accumulation and all-reduces in fp32, we need to have local DDP
    # and we should make sure use-contiguous-buffers-in-local-ddp is not off.
    if args.accumulate_allreduce_grads_in_fp32:
        assert args.DDP_impl == 'local'
        assert args.use_contiguous_buffers_in_local_ddp

    # For torch DDP, we do not use contiguous buffer
    if args.DDP_impl == 'torch':
        args.use_contiguous_buffers_in_local_ddp = False

    if args.dataloader_type is None:
        args.dataloader_type = 'single'

    # Consumed-sample counters, initialized to zero here.
    args.consumed_train_samples = 0
    args.consumed_valid_samples = 0

    # Iteration-based training.
    if args.train_iters:
        # If we use iteration-based training, make sure the
        # sample-based options are off.
        assert args.train_samples is None, \
            'expected iteration-based training'
        assert args.lr_decay_samples is None, \
            'expected iteration-based learning rate decay'
        assert args.lr_warmup_samples == 0, \
            'expected iteration-based learning rate warmup'
        assert args.rampup_batch_size is None, \
            'expected no batch-size rampup for iteration-based training'
        if args.lr_warmup_fraction is not None:
            assert args.lr_warmup_iters == 0, \
                'can only specify one of lr-warmup-fraction and lr-warmup-iters'

    # Sample-based training.
    if args.train_samples:
        # If we use sample-based training, make sure the
        # iteration-based options are off.
        assert args.train_iters is None, \
            'expected sample-based training'
        assert args.lr_decay_iters is None, \
            'expected sample-based learning rate decay'
        assert args.lr_warmup_iters == 0, \
            'expected sample-based learning rate warmup'
        if args.lr_warmup_fraction is not None:
            assert args.lr_warmup_samples == 0, \
                'can only specify one of lr-warmup-fraction ' \
                'and lr-warmup-samples'

    # Check required arguments.
    required_args = ['num_layers', 'hidden_size', 'num_attention_heads',
                     'max_position_embeddings']
    for req_arg in required_args:
        _check_arg_is_not_none(args, req_arg)

    # Checks.
    if args.ffn_hidden_size is None:
        args.ffn_hidden_size = 4 * args.hidden_size

    if args.kv_channels is None:
        assert args.hidden_size % args.num_attention_heads == 0
        args.kv_channels = args.hidden_size // args.num_attention_heads

    # --seq-length and --encoder-seq-length are mutually exclusive aliases;
    # exactly one must be given and both attributes end up equal.
    if args.seq_length is not None:
        assert args.encoder_seq_length is None
        args.encoder_seq_length = args.seq_length
    else:
        assert args.encoder_seq_length is not None
        args.seq_length = args.encoder_seq_length

    if args.seq_length is not None:
        assert args.max_position_embeddings >= args.seq_length
    if args.decoder_seq_length is not None:
        assert args.max_position_embeddings >= args.decoder_seq_length
    if args.lr is not None:
        assert args.min_lr <= args.lr
    if args.save is not None:
        assert args.save_interval is not None
    # Mixed precision checks.
    if args.fp16_lm_cross_entropy:
        assert args.fp16, 'lm cross entropy in fp16 only supported in fp16 mode.'
    if args.fp32_residual_connection:
        assert args.fp16 or args.bf16, \
            'residual connection in fp32 only supported when using fp16 or bf16.'

    torch_major = int(torch.__version__.split('.')[0])
    torch_minor = int(torch.__version__.split('.')[1])
    # Compare (major, minor) as a tuple so that e.g. 2.0 counts as newer
    # than 1.11.
    torch_version = (torch_major, torch_minor)
    # Persistent fused layer norm.
    if torch_version < (1, 11):
        args.no_persist_layer_norm = True
        if args.rank == 0:
            print('Persistent fused layer norm kernel is supported from '
                  'pytorch v1.11 (nvidia pytorch container paired with v1.11). '
                  'Defaulting to no_persist_layer_norm=True')

    # Activation checkpointing.
    if args.distribute_checkpointed_activations:
        assert args.tensor_model_parallel_size > 1, 'can distribute ' \
            'checkpointed activations only across tensor model ' \
            'parallel groups'
        assert args.activations_checkpoint_method is not None, \
            'for distributed checkpoint activations to work you '\
            'need to use a activation-checkpoint method '
        # A per-component test ('major >= 1 and minor >= 10') would reject
        # PyTorch 2.x, whose minor version is below 10.
        assert torch_version >= (1, 10), \
            'distributed checkpoint activations are supported for pytorch ' \
            'v1.10 and above (Nvidia Pytorch container >= 21.07). Current ' \
            'pytorch version is v%s.%s.' % (torch_major, torch_minor)

    _print_args(args)
    return args
Mohammad's avatar
Mohammad committed
279
280


Mohammad's avatar
Mohammad committed
281
282
283
def _print_args(args):
    """Dump every argument on rank 0, one dotted-aligned line per argument.

    Lines are sorted case-insensitively and framed by a header and footer.
    Ranks other than 0 print nothing.
    """
    if args.rank != 0:
        return

    print('------------------------ arguments ------------------------',
          flush=True)
    # Dots pad each name to a common width (names longer than 48 get none).
    lines = ['  {} {} {}'.format(name, '.' * (48 - len(name)), value)
             for name, value in vars(args).items()]
    lines.sort(key=str.lower)
    for line in lines:
        print(line, flush=True)
    print('-------------------- end of arguments ---------------------',
          flush=True)
Mohammad's avatar
Mohammad committed
294
295


296
297
298
299
def _check_arg_is_not_none(args, arg):
    """Assert that attribute ``arg`` of ``args`` is set (i.e. not ``None``)."""
    value = getattr(args, arg)
    assert value is not None, '{} argument is None'.format(arg)


mshoeybi's avatar
mshoeybi committed
300
301
302
303
304
305
306
307
308
309
310
311
def _add_inference_args(parser):
    """Register the ``inference`` argument group on ``parser``.

    Returns the same ``parser`` so calls can be chained.
    """
    inference_group = parser.add_argument_group(title='inference')
    inference_group.add_argument(
        '--inference-batch-times-seqlen-threshold',
        type=int,
        default=512,
        help='During inference, if batch-size times '
             'sequence-length is smaller than this threshold '
             'then we will not use pipelining, otherwise we will.')
    return parser

    
Mohammad's avatar
Mohammad committed
312
def _add_network_size_args(parser):
    """Register the ``network size`` argument group on ``parser``.

    Args:
        parser: the ``argparse.ArgumentParser`` to extend.

    Returns:
        The same ``parser``, with the group added.
    """

    def _str2bool(value):
        # ``type=bool`` is an argparse pitfall: bool('False') is True, so any
        # non-empty string used to turn the option on. Parse the text instead.
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError(
            'expected a boolean value, got {!r}'.format(value))

    group = parser.add_argument_group(title='network size')

    group.add_argument('--num-layers', type=int, default=None,
                       help='Number of transformer layers.')
    group.add_argument('--hidden-size', type=int, default=None,
                       help='Transformer hidden size.')
    group.add_argument('--ffn-hidden-size', type=int, default=None,
                       help='Transformer Feed-Forward Network hidden size. '
                       'This is set to 4*hidden-size if not provided')
    group.add_argument('--num-attention-heads', type=int, default=None,
                       help='Number of transformer attention heads.')
    group.add_argument('--kv-channels', type=int, default=None,
                       help='Projection weights dimension in multi-head '
                       'attention. This is set to '
                       '   args.hidden_size // args.num_attention_heads '
                       'if not provided.')
    group.add_argument('--max-position-embeddings', type=int, default=None,
                       help='Maximum number of position embeddings to use. '
                       'This is the size of position embedding.')
    group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
                       help='Pad the vocab size to be divisible by this value. '
                       'This is added for computational efficiency reasons.')
    group.add_argument('--layernorm-epsilon', type=float, default=1e-5,
                       help='Layer norm epsilon.')
    group.add_argument('--apply-residual-connection-post-layernorm',
                       action='store_true',
                       help='If set, use original BERT residual connection '
                       'ordering.')
    group.add_argument('--openai-gelu', action='store_true',
                       help="Use OpenAI's GeLU implementation. This option "
                       'should not be used unless for backward compatibility '
                       'reasons.')
    group.add_argument('--onnx-safe', type=_str2bool, required=False,
                       help='Use workarounds for known problems with '
                       'Torch ONNX exporter (accepts true/false).')
    group.add_argument('--bert-no-binary-head', action='store_false',
                       help='Disable BERT binary head.',
                       dest='bert_binary_head')

    return parser


355
356
357
358
359
def _add_logging_args(parser):
    """Register the ``logging`` argument group on ``parser``.

    Args:
        parser: the ``argparse.ArgumentParser`` to extend.

    Returns:
        The same ``parser``, with the group added.
    """
    group = parser.add_argument_group(title='logging')

    group.add_argument('--log-params-norm', action='store_true',
                       help='If set, calculate and log parameters norm.')
    group.add_argument('--log-num-zeros-in-grad', action='store_true',
                       help='If set, calculate and log the number of zeros in gradient.')
    group.add_argument('--tensorboard-log-interval', type=int, default=1,
                       help='Report to tensorboard interval.')
    group.add_argument('--tensorboard-queue-size', type=int, default=1000,
                       help='Size of the tensorboard queue for pending events '
                       'and summaries before one of the ‘add’ calls forces a '
                       'flush to disk.')
    group.add_argument('--log-timers-to-tensorboard', action='store_true',
                       help='If set, write timers to tensorboard.')
    group.add_argument('--log-batch-size-to-tensorboard', action='store_true',
                       help='If set, write batch-size to tensorboard.')
    # The historical misspelled flag ('learnig') is kept as an alias so
    # existing launch scripts keep working.
    group.add_argument('--no-log-learning-rate-to-tensorboard',
                       '--no-log-learnig-rate-to-tensorboard',
                       action='store_false',
                       help='Disable learning rate logging to tensorboard.',
                       dest='log_learning_rate_to_tensorboard')
    group.add_argument('--no-log-loss-scale-to-tensorboard',
                       action='store_false',
                       help='Disable loss-scale logging to tensorboard.',
                       dest='log_loss_scale_to_tensorboard')
    group.add_argument('--log-validation-ppl-to-tensorboard',
                       action='store_true',
                       help='If set, write validation perplexity to '
                       'tensorboard.')
    group.add_argument('--log-memory-to-tensorboard',
                       action='store_true',
                       help='Enable memory logging to tensorboard.')
    group.add_argument('--log-world-size-to-tensorboard',
                       action='store_true',
                       help='Enable world size logging to tensorboard.')

    return parser


Mohammad's avatar
Mohammad committed
394
def _add_regularization_args(parser):
    """Register the ``regularization`` argument group on ``parser``.

    Covers dropout, weight decay, gradient clipping and optimizer
    hyper-parameters (Adam betas/epsilon, SGD momentum).

    Args:
        parser: the ``argparse.ArgumentParser`` to extend.

    Returns:
        The same ``parser``, with the group added.
    """
    group = parser.add_argument_group(title='regularization')

    group.add_argument('--attention-dropout', type=float, default=0.1,
                       help='Post attention dropout probability.')
    group.add_argument('--hidden-dropout', type=float, default=0.1,
                       help='Dropout probability for hidden state transformer.')
    group.add_argument('--weight-decay', type=float, default=0.01,
                       help='Weight decay coefficient for L2 regularization.')
    group.add_argument('--clip-grad', type=float, default=1.0,
                       help='Gradient clipping based on global L2 norm.')
    group.add_argument('--adam-beta1', type=float, default=0.9,
                       help='First coefficient for computing running averages '
                       'of gradient and its square')
    group.add_argument('--adam-beta2', type=float, default=0.999,
                       help='Second coefficient for computing running averages '
                       'of gradient and its square')
    group.add_argument('--adam-eps', type=float, default=1e-08,
                       help='Term added to the denominator to improve '
                       'numerical stability')
    group.add_argument('--sgd-momentum', type=float, default=0.9,
                       help='Momentum factor for sgd')

    return parser

Mohammad's avatar
Mohammad committed
419
420

def _add_training_args(parser):
    """Register the ``training`` argument group on ``parser``.

    Covers batch sizes, activation checkpointing, training length, exit
    conditions, kernel-fusion switches, optimizer and data-loader choice.

    Args:
        parser: the ``argparse.ArgumentParser`` to extend.

    Returns:
        The same ``parser``, with the group added.
    """
    group = parser.add_argument_group(title='training')

    group.add_argument('--micro-batch-size', type=int, default=None,
                       help='Batch size per model instance (local batch size). '
                       'Global batch size is local batch size times data '
                       'parallel size times number of micro batches.')
    group.add_argument('--batch-size', type=int, default=None,
                       help='Old batch size parameter, do not use. '
                       'Use --micro-batch-size instead')
    group.add_argument('--global-batch-size', type=int, default=None,
                       help='Training batch size. If set, it should be a '
                       'multiple of micro-batch-size times data-parallel-size. '
                       'If this value is None, then '
                       'use micro-batch-size * data-parallel-size as the '
                       'global batch size. This choice will result in 1 for '
                       'number of micro-batches.')
    # The literal '\\' renders as a shell line-continuation backslash; a bare
    # '\ ' is an invalid escape sequence (SyntaxWarning on Python 3.12+).
    group.add_argument('--rampup-batch-size', nargs='*', default=None,
                       help='Batch size ramp up with the following values:'
                       '  --rampup-batch-size <start batch size> '
                       '                      <batch size increment> '
                       '                      <ramp-up samples> '
                       'For example:'
                       '   --rampup-batch-size 16 8 300000 \\ '
                       '   --global-batch-size 1024 '
                       'will start with global batch size 16 and over '
                       ' (1024 - 16) / 8 = 126 intervals will increase '
                       'the batch size linearly to 1024. In each interval '
                       'we will use approximately 300000 / 126 = 2380 samples.')
    group.add_argument('--checkpoint-activations', action='store_true',
                       help='Deprecated, use --activations-checkpoint-method '
                       'instead. If set, it is treated as '
                       '--activations-checkpoint-method uniform.')
    group.add_argument('--distribute-checkpointed-activations',
                       action='store_true',
                       help='If set, distribute checkpointed activations '
                       'across model parallel group.')
    group.add_argument('--activations-checkpoint-method', type=str, default=None,
                       choices=['uniform', 'block'],
                       help='1) uniform: uniformly divide the total number of '
                       'Transformer layers and checkpoint the input activation of '
                       'each divided chunk, '
                       '2) block: checkpoint the input activations of only a set '
                       'number of individual Transformer layers per pipeline '
                       'stage and do the rest without any checkpointing, '
                       '(default) do not apply activations checkpoint to any layers')
    group.add_argument('--activations-checkpoint-num-layers', type=int, default=1,
                       help='1) uniform: the number of Transformer layers in each '
                       'uniformly divided checkpoint unit, '
                       '2) block: the number of individual Transformer layers '
                       'to checkpoint within each pipeline stage.')
    group.add_argument('--train-iters', type=int, default=None,
                       help='Total number of iterations to train over all '
                       'training runs. Note that either train-iters or '
                       'train-samples should be provided.')
    group.add_argument('--train-samples', type=int, default=None,
                       help='Total number of samples to train over all '
                       'training runs. Note that either train-iters or '
                       'train-samples should be provided.')
    group.add_argument('--log-interval', type=int, default=100,
                       help='Report loss and timing interval.')
    group.add_argument('--exit-interval', type=int, default=None,
                       help='Exit the program after the iteration is divisible '
                       'by this value.')
    group.add_argument('--exit-duration-in-mins', type=int, default=None,
                       help='Exit the program after this many minutes.')
    group.add_argument('--exit-signal-handler', action='store_true',
                       help='Dynamically save the checkpoint and shutdown the '
                       'training if SIGTERM is received')
    group.add_argument('--tensorboard-dir', type=str, default=None,
                       help='Write TensorBoard logs to this directory.')
    group.add_argument('--no-masked-softmax-fusion',
                       action='store_false',
                       help='Disable fusion of query_key_value scaling, '
                       'masking, and softmax.',
                       dest='masked_softmax_fusion')
    group.add_argument('--no-bias-gelu-fusion', action='store_false',
                       help='Disable bias and gelu fusion.',
                       dest='bias_gelu_fusion')
    group.add_argument('--no-bias-dropout-fusion', action='store_false',
                       help='Disable bias and dropout fusion.',
                       dest='bias_dropout_fusion')
    group.add_argument('--optimizer', type=str, default='adam',
                       choices=['adam', 'sgd'],
                       help='Optimizer function')
    group.add_argument('--dataloader-type', type=str, default=None,
                       choices=['single', 'cyclic'],
                       help='Single pass vs multiple pass data loader')
    group.add_argument('--no-async-tensor-model-parallel-allreduce',
                       action='store_true',
                       help='Disable asynchronous execution of '
                       'tensor-model-parallel all-reduce with weight '
                       'gradient computation of a column-linear layer.')
    group.add_argument('--no-persist-layer-norm', action='store_true',
                       help='Disable using persistent fused layer norm kernel. '
                       'This kernel supports only a set of hidden sizes. Please '
                       'check persist_ln_hidden_sizes if your hidden '
                       'size is supported.')

    return parser


Mohammad's avatar
Mohammad committed
520
def _add_initialization_args(parser):
    """Register the ``initialization`` argument group and return ``parser``.

    Covers the global random seed and how model weights are initialized.
    """
    init_group = parser.add_argument_group(title='initialization')

    init_group.add_argument(
        '--seed', type=int, default=1234,
        help='Random seed used for python, numpy, '
             'pytorch, and cuda.')
    init_group.add_argument(
        '--data-parallel-random-init', action='store_true',
        help='Enable random initialization of params '
             'across data parallel ranks')
    init_group.add_argument(
        '--init-method-std', type=float, default=0.02,
        help='Standard deviation of the zero mean normal '
             'distribution used for weight initialization.')
    init_group.add_argument(
        '--init-method-xavier-uniform', action='store_true',
        help='Enable Xavier uniform parameter initialization')

    return parser


Mohammad's avatar
Mohammad committed
538
def _add_learning_rate_args(parser):
    """Register learning-rate schedule options on ``parser``.

    Adds an argument group titled 'learning rate' covering the initial rate,
    the decay function and horizon, warmup length, the minimum rate, and how
    scheduler values interact with checkpoints.

    Args:
        parser: an ``argparse.ArgumentParser`` to extend.

    Returns:
        The same ``parser`` object, to allow chaining.
    """
    group = parser.add_argument_group(title='learning rate')

    group.add_argument('--lr', type=float, default=None,
                       help='Initial learning rate. Depending on decay style '
                       'and initial warmup, the learning rate at each '
                       'iteration would be different.')
    group.add_argument('--lr-decay-style', type=str, default='linear',
                       choices=['constant', 'linear', 'cosine'],
                       help='Learning rate decay function.')
    group.add_argument('--lr-decay-iters', type=int, default=None,
                       help='number of iterations to decay learning rate over.'
                       ' If None defaults to `--train-iters`')
    group.add_argument('--lr-decay-samples', type=int, default=None,
                       help='number of samples to decay learning rate over.'
                       ' If None defaults to `--train-samples`')
    group.add_argument('--lr-warmup-fraction', type=float, default=None,
                       help='fraction of lr-warmup-(iters/samples) to use '
                       'for warmup (as a float)')
    group.add_argument('--lr-warmup-iters', type=int, default=0,
                       help='number of iterations to linearly warmup '
                       'learning rate over.')
    group.add_argument('--lr-warmup-samples', type=int, default=0,
                       help='number of samples to linearly warmup '
                       'learning rate over.')
    # Legacy flag. Its help text says "do not use". It is presumably kept
    # so older launch scripts still parse (TODO confirm with callers).
    group.add_argument('--warmup', type=int, default=None,
                       help='Old lr warmup argument, do not use. Use one of the '
                       '--lr-warmup-* arguments above')
    group.add_argument('--min-lr', type=float, default=0.0,
                       help='Minimum value for learning rate. The scheduler '
                       'clips values below this threshold.')
    # The next two flags choose which source wins for scheduler settings:
    # the command line or the checkpoint.
    group.add_argument('--override-lr-scheduler', action='store_true',
                       help='Reset the values of the scheduler (learning rate, '
                       'warmup iterations, minimum learning rate, maximum '
                       'number of iterations, and decay style) from input '
                       'arguments and ignore values from checkpoints. Note '
                       'that all the above values will be reset.')
    group.add_argument('--use-checkpoint-lr-scheduler', action='store_true',
                       help='Use checkpoint to set the values of the scheduler '
                       '(learning rate, warmup iterations, minimum learning '
                       'rate, maximum number of iterations, and decay style) '
                       'from checkpoint and ignore input arguments.')

    return parser


Mohammad's avatar
Mohammad committed
584
def _add_checkpointing_args(parser):
    """Register checkpoint save/load options on ``parser``.

    Options are grouped under the 'checkpointing' title. Save and load
    directories default to None.

    Args:
        parser: an ``argparse.ArgumentParser`` to extend.

    Returns:
        The same ``parser`` object, to allow chaining.
    """
    ckpt = parser.add_argument_group(title='checkpointing')

    def none_default_flag(flag, text):
        # store_true with default=None: the attribute stays None (rather
        # than False) when the flag is absent, and becomes True when given.
        ckpt.add_argument(flag, action='store_true', default=None, help=text)

    ckpt.add_argument('--save', type=str, default=None,
                      help='Output directory to save checkpoints to.')
    ckpt.add_argument('--save-interval', type=int, default=None,
                      help='Number of iterations between checkpoint saves.')
    none_default_flag('--no-save-optim', 'Do not save current optimizer.')
    none_default_flag('--no-save-rng', 'Do not save current rng state.')
    ckpt.add_argument('--load', type=str, default=None,
                      help='Directory containing a model checkpoint.')
    none_default_flag('--no-load-optim',
                      'Do not load optimizer when loading checkpoint.')
    none_default_flag('--no-load-rng',
                      'Do not load rng state when loading checkpoint.')
    ckpt.add_argument('--finetune', action='store_true',
                      help='Load model for finetuning. Do not load optimizer '
                      'or rng state from checkpoint and set iteration to 0. '
                      'Assumed when loading a release checkpoint.')

    return parser


Mohammad's avatar
Mohammad committed
609
def _add_mixed_precision_args(parser):
    """Register reduced-precision training options on ``parser``.

    Adds an argument group titled 'mixed precision' covering fp16/bf16
    selection, static/dynamic loss-scaling parameters, and which
    computations are kept in fp32.

    Args:
        parser: an ``argparse.ArgumentParser`` to extend.

    Returns:
        The same ``parser`` object, to allow chaining.
    """
    group = parser.add_argument_group(title='mixed precision')

    group.add_argument('--fp16', action='store_true',
                       help='Run model in fp16 mode.')
    group.add_argument('--bf16', action='store_true',
                       help='Run model in bfloat16 mode.')
    group.add_argument('--loss-scale', type=float, default=None,
                       help='Static loss scaling, positive power of 2 '
                       'values can improve fp16 convergence. If None, dynamic '
                       'loss scaling is used.')
    group.add_argument('--initial-loss-scale', type=float, default=2**32,
                       help='Initial loss-scale for dynamic loss scaling.')
    group.add_argument('--min-loss-scale', type=float, default=1.0,
                       help='Minimum loss scale for dynamic loss scale.')
    group.add_argument('--loss-scale-window', type=float, default=1000,
                       help='Window over which to raise/lower dynamic scale.')
    group.add_argument('--hysteresis', type=int, default=2,
                       help='hysteresis for dynamic loss scaling')
    group.add_argument('--fp32-residual-connection', action='store_true',
                       help='Move residual connections to fp32.')
    # store_false with an explicit dest: apply_query_key_layer_scaling is
    # True by default and this flag turns it off.
    group.add_argument('--no-query-key-layer-scaling', action='store_false',
                       help='Do not scale Q * K^T by 1 / layer-number.',
                       dest='apply_query_key_layer_scaling')
    group.add_argument('--attention-softmax-in-fp32', action='store_true',
                       help='Run attention masking and softmax in fp32. '
                       'This flag is ignored unless '
                       '--no-query-key-layer-scaling is specified.')
    group.add_argument('--accumulate-allreduce-grads-in-fp32',
                       action='store_true',
                       help='Gradient accumulation and all-reduce in fp32.')
    group.add_argument('--fp16-lm-cross-entropy', action='store_true',
                       help='Move the cross entropy unreduced loss calculation '
                       'for lm head to fp16.')

    return parser


Mohammad's avatar
Mohammad committed
647
def _add_distributed_args(parser):
    """Register parallelism and distributed-runtime options on ``parser``.

    Adds an argument group titled 'distributed' covering tensor/pipeline
    model-parallel degrees, the communication backend, the
    DistributedDataParallel implementation, and initialization/memory knobs.

    Args:
        parser: an ``argparse.ArgumentParser`` to extend.

    Returns:
        The same ``parser`` object, to allow chaining.
    """
    group = parser.add_argument_group(title='distributed')

    group.add_argument('--tensor-model-parallel-size', type=int, default=1,
                       help='Degree of tensor model parallelism.')
    group.add_argument('--pipeline-model-parallel-size', type=int, default=1,
                       help='Degree of pipeline model parallelism.')
    group.add_argument('--pipeline-model-parallel-split-rank',
                       type=int, default=None,
                       help='Rank where encoder and decoder should be split.')
    group.add_argument('--model-parallel-size', type=int, default=None,
                       help='Old model parallel argument, do not use. Use '
                       '--tensor-model-parallel-size instead.')
    group.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None,
                       help='Number of layers per virtual pipeline stage')
    group.add_argument('--distributed-backend', default='nccl',
                       choices=['nccl', 'gloo'],
                       help='Which backend to use for distributed training.')
    group.add_argument('--DDP-impl', default='local',
                       choices=['local', 'torch'],
                       help='which DistributedDataParallel implementation '
                       'to use.')
    # The two store_false flags below use an explicit dest so the attribute
    # reads positively: True by default, False when the flag is passed.
    group.add_argument('--no-contiguous-buffers-in-local-ddp',
                       action='store_false', help="If set, don't use "
                       'contiguous buffer in local DDP.',
                       dest='use_contiguous_buffers_in_local_ddp')
    group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false',
                       help='Disable the use of scatter/gather to optimize '
                       'communication of tensors in pipeline',
                       dest='scatter_gather_tensors_in_pipeline')
    group.add_argument('--local_rank', type=int, default=None,
                       help='local rank passed from distributed launcher.')
    # NOTE: with type=bool, argparse calls bool() on the raw string, so any
    # non-empty value (even "False") parses as True. This is kept as-is for
    # compatibility with existing `--lazy-mpu-init <value>` invocations.
    group.add_argument('--lazy-mpu-init', type=bool, required=False,
                       help='If set to True, initialize_megatron() '
                       'skips DDP initialization and returns function to '
                       'complete it instead. Also turns on '
                       '--use-cpu-initialization flag. This is for '
                       'external DDP manager.')
    group.add_argument('--use-cpu-initialization', action='store_true',
                       default=None, help='If set, affine parallel weights '
                       'initialization uses CPU')
    group.add_argument('--empty-unused-memory-level', default=0, type=int,
                       choices=[0, 1, 2],
                       help='Call torch.cuda.empty_cache() each iteration '
                       '(training and eval), to reduce fragmentation. '
                       '0=off, 1=moderate, 2=aggressive.')
    group.add_argument('--standalone-embed-stage', action='store_true',
                       default=False, help='If set, *input* embedding layer '
                       'is placed on its own pipeline stage, without any '
                       'transformer layers. (For T5, this flag currently only '
                       'affects the encoder embedding.)')

    return parser


Mohammad's avatar
Mohammad committed
700
def _add_validation_args(parser):
    """Register evaluation-cadence options on ``parser``.

    Adds an argument group titled 'validation' with the number of
    evaluation iterations and the interval between evaluations.

    Args:
        parser: an ``argparse.ArgumentParser`` to extend.

    Returns:
        The same ``parser`` object, to allow chaining.
    """
    group = parser.add_argument_group(title='validation')

    group.add_argument('--eval-iters', type=int, default=100,
                       help='Number of iterations to run for evaluation '
                       'on validation/test data.')
    group.add_argument('--eval-interval', type=int, default=1000,
                       help='Interval between running evaluation on '
                       'validation set.')

    return parser


Mohammad's avatar
Mohammad committed
713
def _add_data_args(parser):
    """Register dataset, tokenizer and dataloader options on ``parser``.

    Adds an argument group titled 'data and dataloader' covering dataset
    paths and split, vocabulary/merge files, sequence lengths, masking
    probabilities, tokenizer choice, indexed-dataset implementation, and
    end-of-document handling.

    Args:
        parser: an ``argparse.ArgumentParser`` to extend.

    Returns:
        The same ``parser`` object, to allow chaining.
    """
    group = parser.add_argument_group(title='data and dataloader')

    group.add_argument('--data-path', nargs='*', default=None,
                       help='Path to the training dataset. Accepted format: '
                       '1) a single data path, 2) multiple datasets in the '
                       'form: dataset1-weight dataset1-path dataset2-weight '
                       'dataset2-path ...')
    # argparse %-formats help strings, hence the doubled '%%' below.
    group.add_argument('--split', type=str, default='969, 30, 1',
                       help='Comma-separated list of proportions for training,'
                       ' validation, and test split. For example the split '
                       '`90,5,5` will use 90%% of data for training, 5%% for '
                       'validation and 5%% for test.')
    group.add_argument('--vocab-file', type=str, default=None,
                       help='Path to the vocab file.')
    group.add_argument('--merge-file', type=str, default=None,
                       help='Path to the BPE merge file.')
    group.add_argument('--vocab-extra-ids', type=int, default=0,
                       help='Number of additional vocabulary tokens. '
                            'They are used for span masking in the T5 model')
    group.add_argument('--seq-length', type=int, default=None,
                       help='Maximum sequence length to process.')
    group.add_argument('--encoder-seq-length', type=int, default=None,
                       help='Maximum encoder sequence length to process. '
                       'This should be exclusive of --seq-length')
    group.add_argument('--decoder-seq-length', type=int, default=None,
                       help="Maximum decoder sequence length to process.")
    group.add_argument('--retriever-seq-length', type=int, default=256,
                       help='Maximum sequence length for the biencoder model '
                       'for retriever')
    group.add_argument('--sample-rate', type=float, default=1.0,
                       help='sample rate for training data. Supposed to be '
                       '0 < sample_rate < 1')
    group.add_argument('--mask-prob', type=float, default=0.15,
                       help='Probability of replacing a token with mask.')
    group.add_argument('--short-seq-prob', type=float, default=0.1,
                       help='Probability of producing a short sequence.')
    group.add_argument('--mmap-warmup', action='store_true',
                       help='Warm up mmap files.')
    group.add_argument('--num-workers', type=int, default=2,
                       help="Dataloader number of workers.")
    group.add_argument('--tokenizer-type', type=str,
                       default=None,
                       choices=['BertWordPieceLowerCase',
                                'BertWordPieceCase',
                                'GPT2BPETokenizer'],
                       help='What type of tokenizer to use.')
    group.add_argument('--data-impl', type=str, default='infer',
                       choices=['lazy', 'cached', 'mmap', 'infer'],
                       help='Implementation of indexed datasets.')
    group.add_argument('--reset-position-ids', action='store_true',
                       help='Reset position ids after end-of-document token.')
    group.add_argument('--reset-attention-mask', action='store_true',
                       help='Reset self attention mask after '
                       'end-of-document token.')
    group.add_argument('--eod-mask-loss', action='store_true',
                       help='Mask loss for the end of document tokens.')

    return parser

Raul Puri's avatar
Raul Puri committed
773

Mohammad's avatar
Mohammad committed
774
775
def _add_autoresume_args(parser):
    """Register ADLR-cluster autoresume options on ``parser``.

    Adds an argument group titled 'autoresume' with an enable flag and the
    polling interval for the termination signal.

    Args:
        parser: an ``argparse.ArgumentParser`` to extend.

    Returns:
        The same ``parser`` object, to allow chaining.
    """
    group = parser.add_argument_group(title='autoresume')

    group.add_argument('--adlr-autoresume', action='store_true',
                       help='Enable autoresume on adlr cluster.')
    group.add_argument('--adlr-autoresume-interval', type=int, default=1000,
                       help='Intervals over which to check for autoresume '
                       'termination signal')

    return parser
Neel Kant's avatar
Neel Kant committed
784
785


Mostofa Patwary's avatar
Mostofa Patwary committed
786
787
def _add_biencoder_args(parser):
    """Register biencoder / ICT / retriever options on ``parser``.

    Adds an argument group titled 'biencoder' covering network size,
    checkpoint locations, ICT data options, retriever training and
    reporting, and indexer/embedding storage settings.

    Args:
        parser: an ``argparse.ArgumentParser`` to extend.

    Returns:
        The same ``parser`` object, to allow chaining.
    """
    group = parser.add_argument_group(title='biencoder')

    # network size
    group.add_argument('--ict-head-size', type=int, default=None,
                       help='Size of block embeddings to be used in ICT and '
                       'REALM (paper default: 128)')
    group.add_argument('--biencoder-projection-dim', type=int, default=0,
                       help='Size of projection head used in biencoder (paper'
                       ' default: 128)')
    group.add_argument('--biencoder-shared-query-context-model', action='store_true',
                       help='Whether to share the parameters of the query '
                       'and context models or not')

    # checkpointing
    group.add_argument('--ict-load', type=str, default=None,
                       help='Directory containing an ICTBertModel checkpoint')
    group.add_argument('--bert-load', type=str, default=None,
                       help='Directory containing a BertModel checkpoint '
                       '(needed to start ICT and REALM)')

    # data
    group.add_argument('--titles-data-path', type=str, default=None,
                       help='Path to titles dataset used for ICT')
    group.add_argument('--query-in-block-prob', type=float, default=0.1,
                       help='Probability of keeping query in block for '
                       'ICT dataset')
    group.add_argument('--use-one-sent-docs', action='store_true',
                       help='Whether to use one sentence documents in ICT')
    group.add_argument('--evidence-data-path', type=str, default=None,
                       help='Path to Wikipedia Evidence from DPR paper')

    # training
    group.add_argument('--retriever-report-topk-accuracies', nargs='+', type=int,
                       default=[], help="Which top-k accuracies to report "
                       "(e.g. '1 5 20')")
    group.add_argument('--retriever-score-scaling', action='store_true',
                       help='Whether to scale retriever scores by inverse '
                       'square root of hidden size')

    # faiss index
    group.add_argument('--block-data-path', type=str, default=None,
                       help='Where to save/load BlockData to/from')
    group.add_argument('--embedding-path', type=str, default=None,
                       help='Where to save/load Open-Retrieval Embedding'
                       ' data to/from')

    # indexer
    group.add_argument('--indexer-batch-size', type=int, default=128,
                       help='How large of batches to use when doing indexing '
                       'jobs')
    group.add_argument('--indexer-log-interval', type=int, default=1000,
                       help='After how many batches should the indexer '
                       'report progress')

    return parser
841
842
843
844
845
846
847


def _add_vit_args(parser):
    """Register vision-transformer (ViT) options on ``parser``.

    Adds an argument group titled 'vit' covering the number of classes,
    input image size and channels, patch size, class/data subsampling
    fractions, and data sharding.

    Args:
        parser: an ``argparse.ArgumentParser`` to extend.

    Returns:
        The same ``parser`` object, to allow chaining.
    """
    group = parser.add_argument_group(title="vit")

    group.add_argument('--num-classes', type=int, default=1000,
                       help='num of classes in vision classification task')
    group.add_argument('--img-h', type=int, default=224,
                       help='Image height for vision classification task')
    group.add_argument('--img-w', type=int, default=224,
                       help='Image width for vision classification task')
    group.add_argument('--num-channels', type=int, default=3,
                       help='Number of channels in input image data')
    group.add_argument('--patch-dim', type=int, default=16,
                       help='patch dimension used in vit')
    group.add_argument('--classes-fraction', type=float, default=1.0,
                       help='training with fraction of classes.')
    group.add_argument('--data-per-class-fraction', type=float, default=1.0,
                       help='training with fraction of data per class.')
    # store_false with an explicit dest: data_sharding is True by default
    # and this flag turns it off.
    group.add_argument('--no-data-sharding', action='store_false',
                       help='Disable data sharding.',
                       dest='data_sharding')

    return parser