arguments.py 37 KB
Newer Older
Raul Puri's avatar
Raul Puri committed
1
# coding=utf-8
Mohammad's avatar
Mohammad committed
2
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
Raul Puri's avatar
Raul Puri committed
3
4
5
6
7
8
9
10
11
12
13
14
15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Mohammad's avatar
Mohammad committed
16
"""Megatron arguments."""
Raul Puri's avatar
Raul Puri committed
17
18
19
20

import argparse
import os

21
import torch
22
from megatron import fused_kernels
Raul Puri's avatar
Raul Puri committed
23

24
25
def parse_args(extra_args_provider=None, defaults=None,
               ignore_unknown_args=False):
    """Parse all arguments.

    Builds the Megatron-LM argument parser, parses the command line, derives
    the rank, world size and tensor/pipeline/data parallel sizes from the
    ``RANK`` and ``WORLD_SIZE`` environment variables, applies ``defaults``,
    validates argument combinations and loads the fused kernels requested by
    the arguments.

    Args:
        extra_args_provider: Optional callable that takes the parser, adds
            custom arguments to it and returns it.
        defaults: Optional mapping from argument name to value. A value is
            applied only when that argument is still ``None`` after parsing;
            otherwise the command-line value wins and a warning is printed
            on rank 0.
        ignore_unknown_args: If True, unrecognized command-line arguments are
            ignored instead of being rejected by the parser.

    Returns:
        The populated ``argparse.Namespace``.
    """
    # ``None`` sentinel instead of a shared mutable ``{}`` default.
    if defaults is None:
        defaults = {}

    parser = argparse.ArgumentParser(description='Megatron-LM Arguments',
                                     allow_abbrev=False)

    # Standard arguments.
    parser = _add_network_size_args(parser)
    parser = _add_regularization_args(parser)
    parser = _add_training_args(parser)
    parser = _add_initialization_args(parser)
    parser = _add_learning_rate_args(parser)
    parser = _add_checkpointing_args(parser)
    parser = _add_mixed_precision_args(parser)
    parser = _add_distributed_args(parser)
    parser = _add_validation_args(parser)
    parser = _add_data_args(parser)
    parser = _add_autoresume_args(parser)
    parser = _add_biencoder_args(parser)
    parser = _add_vit_args(parser)
    parser = _add_logging_args(parser)

    # Custom arguments.
    if extra_args_provider is not None:
        parser = extra_args_provider(parser)

    # Parse.
    if ignore_unknown_args:
        args, _ = parser.parse_known_args()
    else:
        args = parser.parse_args()

    # Distributed args.
    args.rank = int(os.getenv('RANK', '0'))
    args.world_size = int(os.getenv("WORLD_SIZE", '1'))
    # Tensor model parallel size (capped at the world size).
    args.tensor_model_parallel_size = min(
        args.tensor_model_parallel_size, args.world_size)
    assert args.world_size % args.tensor_model_parallel_size == 0, 'world size'\
        ' ({}) is not divisible by tensor model parallel size ({})'.format(
            args.world_size, args.tensor_model_parallel_size)
    # Pipeline model parallel size (capped at what the tensor split leaves).
    args.pipeline_model_parallel_size = min(
        args.pipeline_model_parallel_size,
        (args.world_size // args.tensor_model_parallel_size))
    # Checks: the data parallel size is whatever remains after model
    # (tensor x pipeline) parallelism.
    model_parallel_size = args.pipeline_model_parallel_size * \
                          args.tensor_model_parallel_size
    assert args.world_size % model_parallel_size == 0, 'world size ({}) is ' \
        'not divisible by tensor parallel size ({}) times pipeline parallel ' \
        'size ({})'.format(args.world_size, args.tensor_model_parallel_size,
                           args.pipeline_model_parallel_size)
    args.data_parallel_size = args.world_size // model_parallel_size
    if args.rank == 0:
        print('using world size: {}, data-parallel-size: {}, '
              'tensor-model-parallel size: {}, '
              'pipeline-model-parallel size: {} '.format(
                  args.world_size, args.data_parallel_size,
                  args.tensor_model_parallel_size,
                  args.pipeline_model_parallel_size), flush=True)

    # Deprecated arguments: reject them with a pointer to the replacement,
    # then drop them so nothing downstream reads them.
    assert args.batch_size is None, '--batch-size argument is no longer ' \
        'valid, use --micro-batch-size instead'
    del args.batch_size
    assert args.warmup is None, '--warmup argument is no longer valid, use ' \
        '--lr-warmup-fraction instead'
    del args.warmup
    assert args.model_parallel_size is None, '--model-parallel-size is no ' \
        'longer valid, use --tensor-model-parallel-size instead'
    del args.model_parallel_size

    # Set input defaults.
    for key in defaults:
        # For default to be valid, it should not be provided in the
        # arguments that are passed to the program. We check this by
        # ensuring the arg is set to None.
        if getattr(args, key) is not None:
            if args.rank == 0:
                print('WARNING: overriding default arguments for {key}:{v} '
                      'with {key}:{v2}'.format(key=key, v=defaults[key],
                                               v2=getattr(args, key)),
                      flush=True)
        else:
            setattr(args, key, defaults[key])

    # Batch size.
    assert args.micro_batch_size is not None
    assert args.micro_batch_size > 0
    if args.global_batch_size is None:
        args.global_batch_size = args.micro_batch_size * args.data_parallel_size
        if args.rank == 0:
            print('setting global batch size to {}'.format(
                args.global_batch_size), flush=True)
    assert args.global_batch_size > 0

    # Virtual (interleaved) pipeline stages.
    # NOTE(review): divisibility is checked against the total layer count,
    # not the per-pipeline-stage count used in the division below; confirm
    # this is sufficient for the interleaved schedule.
    if args.num_layers_per_virtual_pipeline_stage is not None:
        assert args.num_layers % args.num_layers_per_virtual_pipeline_stage == 0, \
            'number of layers is not divisible by number of layers per virtual ' \
            'pipeline stage'
        args.virtual_pipeline_model_parallel_size = \
            (args.num_layers // args.pipeline_model_parallel_size) // \
            args.num_layers_per_virtual_pipeline_stage
    else:
        args.virtual_pipeline_model_parallel_size = None

    # Parameters dtype.
    args.params_dtype = torch.float
    if args.fp16:
        args.params_dtype = torch.half
    if args.rank == 0:
        print('using {} for parameters ...'.format(args.params_dtype),
              flush=True)

    if args.dataloader_type is None:
        args.dataloader_type = 'single'

    # Consumed samples (start from zero).
    args.consumed_train_samples = 0
    args.consumed_valid_samples = 0

    # Iteration-based training.
    if args.train_iters:
        # If we use iteration-based training, make sure the
        # sample-based options are off.
        assert args.train_samples is None, \
            'expected iteration-based training'
        assert args.lr_decay_samples is None, \
            'expected iteration-based learning rate decay'
        assert args.lr_warmup_samples == 0, \
            'expected iteration-based learning rate warmup'
        assert args.rampup_batch_size is None, \
            'expected no batch-size rampup for iteration-based training'
        if args.lr_warmup_fraction is not None:
            assert args.lr_warmup_iters == 0, \
                'can only specify one of lr-warmup-fraction and lr-warmup-iters'

    # Sample-based training.
    if args.train_samples:
        # If we use sample-based training, make sure the
        # iteration-based options are off.
        assert args.train_iters is None, \
            'expected sample-based training'
        assert args.lr_decay_iters is None, \
            'expected sample-based learning rate decay'
        assert args.lr_warmup_iters == 0, \
            'expected sample-based learning rate warmup'
        if args.lr_warmup_fraction is not None:
            assert args.lr_warmup_samples == 0, \
                'can only specify one of lr-warmup-fraction ' \
                'and lr-warmup-samples'

    # Check required arguments.
    required_args = ['num_layers', 'hidden_size', 'num_attention_heads',
                     'max_position_embeddings']
    for req_arg in required_args:
        _check_arg_is_not_none(args, req_arg)

    # Derived sizes.
    if args.ffn_hidden_size is None:
        args.ffn_hidden_size = 4 * args.hidden_size

    if args.kv_channels is None:
        assert args.hidden_size % args.num_attention_heads == 0
        args.kv_channels = args.hidden_size // args.num_attention_heads

    # Exactly one of seq_length / encoder_seq_length must be given; the other
    # is mirrored from it, so seq_length is non-None from here on.
    if args.seq_length is not None:
        assert args.encoder_seq_length is None
        args.encoder_seq_length = args.seq_length
    else:
        assert args.encoder_seq_length is not None
        args.seq_length = args.encoder_seq_length

    # Checks.
    assert args.hidden_size % args.num_attention_heads == 0
    assert args.max_position_embeddings >= args.seq_length
    if args.lr is not None:
        assert args.min_lr <= args.lr
    if args.save is not None:
        assert args.save_interval is not None

    # Mixed precision checks.
    if args.fp16_lm_cross_entropy:
        assert args.fp16, \
            'lm cross entropy in fp16 only supported in fp16 mode.'
    if args.fp32_residual_connection:
        assert args.fp16, \
            'residual connection in fp32 only supported when using fp16.'

    # Activation checkpointing.
    if args.distribute_checkpointed_activations:
        assert args.checkpoint_activations, \
            'for distribute-checkpointed-activations to work you '\
            'need to enable checkpoint-activations'

    # custom kernel constraints check
    seq_len = args.seq_length
    attn_batch_size = \
        (args.num_attention_heads / args.tensor_model_parallel_size) * \
        args.micro_batch_size

    # constraints on sequence length and attn_batch_size to enable warp based
    # optimization and upper triangular optimization (for causal mask)
    custom_kernel_constraint = seq_len > 16 and seq_len <= 2048 and \
        seq_len % 4 == 0 and attn_batch_size % 4 == 0

    # Only rank 0 reports, like every other message in this function.
    if not (args.fp16 and custom_kernel_constraint and
            args.masked_softmax_fusion):
        if args.rank == 0:
            print('WARNING: constraints for invoking optimized'
                  ' fused softmax kernel are not met. We default back to unfused'
                  ' kernel invocations.', flush=True)

    # Load scaled_masked_softmax_fusion_kernels
    if args.masked_softmax_fusion:
        fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()
        fused_kernels.load_scaled_masked_softmax_fusion_kernel()

    # Load mixed precision fused layer norm.
    if args.fp32_residual_connection:
        fused_kernels.load_fused_mix_prec_layer_norm_kernel()

    _print_args(args)
    return args
Mohammad's avatar
Mohammad committed
242
243


Mohammad's avatar
Mohammad committed
244
245
246
def _print_args(args):
    """Print arguments."""
    if args.rank == 0:
mohammad's avatar
mohammad committed
247
248
        print('------------------------ arguments ------------------------',
              flush=True)
Mohammad's avatar
Mohammad committed
249
250
        str_list = []
        for arg in vars(args):
mohammad's avatar
mohammad committed
251
            dots = '.' * (48 - len(arg))
Mohammad's avatar
Mohammad committed
252
253
254
            str_list.append('  {} {} {}'.format(arg, dots, getattr(args, arg)))
        for arg in sorted(str_list, key=lambda x: x.lower()):
            print(arg, flush=True)
mohammad's avatar
mohammad committed
255
256
        print('-------------------- end of arguments ---------------------',
              flush=True)
Mohammad's avatar
Mohammad committed
257
258


259
260
261
262
def _check_arg_is_not_none(args, arg):
    assert getattr(args, arg) is not None, '{} argument is None'.format(arg)


Mohammad's avatar
Mohammad committed
263
def _add_network_size_args(parser):
Mohammad's avatar
Mohammad committed
264
    group = parser.add_argument_group(title='network size')
Mohammad's avatar
Mohammad committed
265

266
    group.add_argument('--num-layers', type=int, default=None,
Mohammad's avatar
Mohammad committed
267
                       help='Number of transformer layers.')
268
    group.add_argument('--hidden-size', type=int, default=None,
Mohammad's avatar
Mohammad committed
269
                       help='Tansformer hidden size.')
270
    group.add_argument('--ffn-hidden-size', type=int, default=None,
271
272
                       help='Transformer Feed-Forward Network hidden size. '
                       'This is set to 4*hidden-size if not provided')
273
    group.add_argument('--num-attention-heads', type=int, default=None,
Mohammad's avatar
Mohammad committed
274
                       help='Number of transformer attention heads.')
275
    group.add_argument('--kv-channels', type=int, default=None,
276
277
278
279
                       help='Projection weights dimension in multi-head '
                       'attention. This is set to '
                       '   args.hidden_size // args.num_attention_heads '
                       'if not provided.')
280
    group.add_argument('--max-position-embeddings', type=int, default=None,
Mohammad's avatar
Mohammad committed
281
282
283
284
285
                       help='Maximum number of position embeddings to use. '
                       'This is the size of position embedding.')
    group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
                       help='Pad the vocab size to be divisible by this value.'
                       'This is added for computational efficieny reasons.')
Mohammad's avatar
Mohammad committed
286
287
    group.add_argument('--layernorm-epsilon', type=float, default=1e-5,
                       help='Layer norm epsilon.')
Mohammad's avatar
Mohammad committed
288
289
290
291
    group.add_argument('--apply-residual-connection-post-layernorm',
                       action='store_true',
                       help='If set, use original BERT residula connection '
                       'ordering.')
292
293
294
295
    group.add_argument('--openai-gelu', action='store_true',
                       help='Use OpenAIs GeLU implementation. This option'
                       'should not be used unless for backward compatibility'
                       'reasons.')
296
    group.add_argument('--onnx-safe', type=bool, required=False,
297
298
                       help='Use workarounds for known problems with '
                       'Torch ONNX exporter')
299
300
301
    group.add_argument('--bert-no-binary-head', action='store_false',
                       help='Disable BERT binary head.',
                       dest='bert_binary_head')
Mohammad's avatar
Mohammad committed
302

Mohammad's avatar
Mohammad committed
303
304
305
    return parser


306
307
308
309
310
def _add_logging_args(parser):
    group = parser.add_argument_group(title='logging')

    group.add_argument('--log-params-norm', action='store_true',
                       help='If set, calculate and log parameters norm.')
311
    group.add_argument('--log-num-zeros-in-grad', action='store_true',
Rewon Child's avatar
Rewon Child committed
312
                       help='If set, calculate and log the number of zeros in gradient.')
313
314
    group.add_argument('--tensorboard-log-interval', type=int, default=1,
                       help='Report to tensorboard interval.')
315
316
317
318
    group.add_argument('--tensorboard-queue-size', type=int, default=1000,
                       help='Size of the tensorboard queue for pending events '
                       'and summaries before one of the ‘add’ calls forces a '
                       'flush to disk.')
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
    group.add_argument('--log-timers-to-tensorboard', action='store_true',
                       help='If set, write timers to tensorboard.')
    group.add_argument('--log-batch-size-to-tensorboard', action='store_true',
                       help='If set, write batch-size to tensorboard.')
    group.add_argument('--no-log-learnig-rate-to-tensorboard',
                       action='store_false',
                       help='Disable learning rate logging to tensorboard.',
                       dest='log_learning_rate_to_tensorboard')
    group.add_argument('--no-log-loss-scale-to-tensorboard',
                       action='store_false',
                       help='Disable loss-scale logging to tensorboard.',
                       dest='log_loss_scale_to_tensorboard')
    group.add_argument('--log-validation-ppl-to-tensorboard',
                       action='store_true',
                       help='If set, write validation perplexity to '
                       'tensorboard.')
335
336
337
338

    return parser


Mohammad's avatar
Mohammad committed
339
def _add_regularization_args(parser):
Mohammad's avatar
Mohammad committed
340
341
342
    group = parser.add_argument_group(title='regularization')

    group.add_argument('--attention-dropout', type=float, default=0.1,
343
                       help='Post attention dropout probability.')
Mohammad's avatar
Mohammad committed
344
345
346
347
348
349
    group.add_argument('--hidden-dropout', type=float, default=0.1,
                       help='Dropout probability for hidden state transformer.')
    group.add_argument('--weight-decay', type=float, default=0.01,
                       help='Weight decay coefficient for L2 regularization.')
    group.add_argument('--clip-grad', type=float, default=1.0,
                       help='Gradient clipping based on global L2 norm.')
350
    group.add_argument('--adam-beta1', type=float, default=0.9,
351
352
                       help='First coefficient for computing running averages '
                       'of gradient and its square')
353
    group.add_argument('--adam-beta2', type=float, default=0.999,
354
355
                       help='Second coefficient for computing running averages '
                       'of gradient and its square')
356
    group.add_argument('--adam-eps', type=float, default=1e-08,
357
                       help='Term added to the denominator to improve'
358
                       'numerical stability')
359
360
    group.add_argument('--sgd-momentum', type=float, default=0.9,
                       help='Momentum factor for sgd')
Mohammad's avatar
Mohammad committed
361
362
363

    return parser

Mohammad's avatar
Mohammad committed
364
365

def _add_training_args(parser):
Mohammad's avatar
Mohammad committed
366
367
    group = parser.add_argument_group(title='training')

368
    group.add_argument('--micro-batch-size', type=int, default=None,
Mohammad's avatar
Mohammad committed
369
370
                       help='Batch size per model instance (local batch size). '
                       'Global batch size is local batch size times data '
mohammad's avatar
mohammad committed
371
                       'parallel size times number of micro batches.')
372
373
374
    group.add_argument('--batch-size', type=int, default=None,
                       help='Old batch size parameter, do not use. '
                       'Use --micro-batch-size instead')
mohammad's avatar
mohammad committed
375
    group.add_argument('--global-batch-size', type=int, default=None,
mohammad's avatar
mohammad committed
376
377
378
                       help='Training batch size. If set, it should be a '
                       'multiple of micro-batch-size times data-parallel-size. '
                       'If this value is None, then '
mohammad's avatar
mohammad committed
379
                       'use micro-batch-size * data-parallel-size as the '
mohammad's avatar
mohammad committed
380
381
                       'global batch size. This choice will result in 1 for '
                       'number of micro-batches.')
mohammad's avatar
mohammad committed
382
383
384
385
386
387
388
389
390
391
392
393
    group.add_argument('--rampup-batch-size', nargs='*', default=None,
                       help='Batch size ramp up with the following values:'
                       '  --rampup-batch-size <start batch size> '
                       '                      <batch size incerement> '
                       '                      <ramp-up samples> '
                       'For example:'
                       '   --rampup-batch-size 16 8 300000 \ '
                       '   --global-batch-size 1024'
                       'will start with global batch size 16 and over '
                       ' (1024 - 16) / 8 = 126 intervals will increase'
                       'the batch size linearly to 1024. In each interval'
                       'we will use approximately 300000 / 126 = 2380 samples.')
Mohammad's avatar
Mohammad committed
394
395
396
    group.add_argument('--checkpoint-activations', action='store_true',
                       help='Checkpoint activation to allow for training '
                       'with larger models, sequences, and batch sizes.')
397
398
399
400
    group.add_argument('--distribute-checkpointed-activations',
                       action='store_true',
                       help='If set, distribute checkpointed activations '
                       'across model parallel group.')
Mohammad's avatar
Mohammad committed
401
402
    group.add_argument('--checkpoint-num-layers', type=int, default=1,
                       help='chunk size (number of layers) for checkpointing.')
Mohammad's avatar
Mohammad committed
403
    group.add_argument('--train-iters', type=int, default=None,
Mohammad's avatar
Mohammad committed
404
                       help='Total number of iterations to train over all '
405
406
407
408
409
410
                       'training runs. Note that either train-iters or '
                       'train-samples should be provided.')
    group.add_argument('--train-samples', type=int, default=None,
                       help='Total number of samples to train over all '
                       'training runs. Note that either train-iters or '
                       'train-samples should be provided.')
Mohammad's avatar
Mohammad committed
411
412
413
414
415
    group.add_argument('--log-interval', type=int, default=100,
                       help='Report loss and timing interval.')
    group.add_argument('--exit-interval', type=int, default=None,
                       help='Exit the program after the iteration is divisible '
                       'by this value.')
416
417
    group.add_argument('--exit-duration-in-mins', type=int, default=None,
                       help='Exit the program after this many minutes.')
Mohammad's avatar
Mohammad committed
418
419
    group.add_argument('--tensorboard-dir', type=str, default=None,
                       help='Write TensorBoard logs to this directory.')
420
    group.add_argument('--no-masked-softmax-fusion',
421
422
423
                       action='store_false',
                       help='Disable fusion of query_key_value scaling, '
                       'masking, and softmax.',
424
                       dest='masked_softmax_fusion')
425
426
427
428
429
430
    group.add_argument('--no-bias-gelu-fusion', action='store_false',
                       help='Disable bias and gelu fusion.',
                       dest='bias_gelu_fusion')
    group.add_argument('--no-bias-dropout-fusion', action='store_false',
                       help='Disable bias and dropout fusion.',
                       dest='bias_dropout_fusion')
431
432
433
    group.add_argument('--optimizer', type=str, default='adam',
                       choices=['adam', 'sgd'],
                       help='Optimizer function')
434
    group.add_argument('--dataloader-type', type=str, default=None,
Vijay Korthikanti's avatar
Vijay Korthikanti committed
435
436
                       choices=['single', 'cyclic'],
                       help='Single pass vs multiple pass data loader')
Mohammad's avatar
Mohammad committed
437
438
439
    return parser


Mohammad's avatar
Mohammad committed
440
def _add_initialization_args(parser):
Mohammad's avatar
Mohammad committed
441
442
443
444
445
446
447
448
    group = parser.add_argument_group(title='initialization')

    group.add_argument('--seed', type=int, default=1234,
                       help='Random seed used for python, numpy, '
                       'pytorch, and cuda.')
    group.add_argument('--init-method-std', type=float, default=0.02,
                       help='Standard deviation of the zero mean normal '
                       'distribution used for weight initialization.')
449
450
    group.add_argument('--init-method-xavier-uniform', action='store_true',
                       help='Enable Xavier uniform parameter initialization')
Mohammad's avatar
Mohammad committed
451

Mohammad's avatar
Mohammad committed
452
453
454
    return parser


Mohammad's avatar
Mohammad committed
455
def _add_learning_rate_args(parser):
Mohammad's avatar
Mohammad committed
456
457
    group = parser.add_argument_group(title='learning rate')

Mohammad's avatar
Mohammad committed
458
    group.add_argument('--lr', type=float, default=None,
Mohammad's avatar
Mohammad committed
459
460
461
462
                       help='Initial learning rate. Depending on decay style '
                       'and initial warmup, the learing rate at each '
                       'iteration would be different.')
    group.add_argument('--lr-decay-style', type=str, default='linear',
mohammad's avatar
mohammad committed
463
                       choices=['constant', 'linear', 'cosine'],
Mohammad's avatar
Mohammad committed
464
465
466
467
                       help='Learning rate decay function.')
    group.add_argument('--lr-decay-iters', type=int, default=None,
                       help='number of iterations to decay learning rate over,'
                       ' If None defaults to `--train-iters`')
468
469
470
    group.add_argument('--lr-decay-samples', type=int, default=None,
                       help='number of samples to decay learning rate over,'
                       ' If None defaults to `--train-samples`')
471
472
473
    group.add_argument('--lr-warmup-fraction', type=float, default=None,
                       help='fraction of lr-warmup-(iters/samples) to use '
                       'for warmup (as a float)')
474
475
476
477
478
479
    group.add_argument('--lr-warmup-iters', type=int, default=0,
                       help='number of iterations to linearly warmup '
                       'learning rate over.')
    group.add_argument('--lr-warmup-samples', type=int, default=0,
                       help='number of samples to linearly warmup '
                       'learning rate over.')
480
    group.add_argument('--warmup', type=int, default=None,
481
                       help='Old lr warmup argument, do not use. Use one of the'
482
                       '--lr-warmup-* arguments above')
Mohammad's avatar
Mohammad committed
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
    group.add_argument('--min-lr', type=float, default=0.0,
                       help='Minumum value for learning rate. The scheduler'
                       'clip values below this threshold.')
    group.add_argument('--override-lr-scheduler', action='store_true',
                       help='Reset the values of the scheduler (learning rate,'
                       'warmup iterations, minimum learning rate, maximum '
                       'number of iterations, and decay style from input '
                       'arguments and ignore values from checkpoints. Note'
                       'that all the above values will be reset.')
    group.add_argument('--use-checkpoint-lr-scheduler', action='store_true',
                       help='Use checkpoint to set the values of the scheduler '
                       '(learning rate, warmup iterations, minimum learning '
                       'rate, maximum number of iterations, and decay style '
                       'from checkpoint and ignore input arguments.')

    return parser


Mohammad's avatar
Mohammad committed
501
def _add_checkpointing_args(parser):
Mohammad's avatar
Mohammad committed
502
503
504
505
506
507
    group = parser.add_argument_group(title='checkpointing')

    group.add_argument('--save', type=str, default=None,
                       help='Output directory to save checkpoints to.')
    group.add_argument('--save-interval', type=int, default=None,
                       help='Number of iterations between checkpoint saves.')
508
    group.add_argument('--no-save-optim', action='store_true', default=None,
Mohammad's avatar
Mohammad committed
509
                       help='Do not save current optimizer.')
510
    group.add_argument('--no-save-rng', action='store_true', default=None,
Mohammad's avatar
Mohammad committed
511
512
513
                       help='Do not save current rng state.')
    group.add_argument('--load', type=str, default=None,
                       help='Directory containing a model checkpoint.')
Jared Casper's avatar
Jared Casper committed
514
    group.add_argument('--no-load-optim', action='store_true', default=None,
Mohammad's avatar
Mohammad committed
515
                       help='Do not load optimizer when loading checkpoint.')
Jared Casper's avatar
Jared Casper committed
516
    group.add_argument('--no-load-rng', action='store_true', default=None,
Mohammad's avatar
Mohammad committed
517
518
519
520
521
522
523
524
525
                       help='Do not load rng state when loading checkpoint.')
    group.add_argument('--finetune', action='store_true',
                       help='Load model for finetuning. Do not load optimizer '
                       'or rng state from checkpoint and set iteration to 0. '
                       'Assumed when loading a release checkpoint.')

    return parser


Mohammad's avatar
Mohammad committed
526
def _add_mixed_precision_args(parser):
    """Register mixed-precision (fp16 / loss-scaling) arguments on *parser*.

    Args:
        parser: an ``argparse.ArgumentParser`` to extend.

    Returns:
        The same *parser* object.
    """
    group = parser.add_argument_group(title='mixed precision')

    group.add_argument('--fp16', action='store_true',
                       help='Run model in fp16 mode.')
    # Loss scaling: a fixed scale when --loss-scale is given; otherwise the
    # help text describes dynamic scaling governed by the next four options.
    group.add_argument('--loss-scale', type=float, default=None,
                       help='Static loss scaling, positive power of 2 '
                       'values can improve fp16 convergence. If None, dynamic '
                       'loss scaling is used.')
    group.add_argument('--initial-loss-scale', type=float, default=2**32,
                       help='Initial loss-scale for dynamic loss scaling.')
    group.add_argument('--min-loss-scale', type=float, default=1.0,
                       help='Minimum loss scale for dynamic loss scale.')
    group.add_argument('--loss-scale-window', type=float, default=1000,
                       help='Window over which to raise/lower dynamic scale.')
    group.add_argument('--hysteresis', type=int, default=2,
                       help='hysteresis for dynamic loss scaling')
    group.add_argument('--fp32-residual-connection', action='store_true',
                       help='Move residual connections to fp32.')
    # Negative flag: query/key layer scaling is on unless this is passed.
    group.add_argument('--no-query-key-layer-scaling', action='store_false',
                       help='Do not scale Q * K^T by 1 / layer-number.',
                       dest='apply_query_key_layer_scaling')
    group.add_argument('--attention-softmax-in-fp32', action='store_true',
                       help='Run attention masking and softmax in fp32. '
                       'This flag is ignored unless '
                       '--no-query-key-layer-scaling is specified.')
    group.add_argument('--fp32-allreduce', action='store_true',
                       help='All-reduce in fp32')
    group.add_argument('--fp16-lm-cross-entropy', action='store_true',
                       help='Move the cross entropy unreduced loss calculation '
                       'for lm head to fp16.')

    return parser


Mohammad's avatar
Mohammad committed
561
def _add_distributed_args(parser):
    """Register distributed / model-parallel arguments on *parser*.

    Args:
        parser: an ``argparse.ArgumentParser`` to extend.

    Returns:
        The same *parser* object.
    """

    def _str_to_bool(value):
        """Parse a textual boolean such as 'true'/'false', 'yes'/'no', '1'/'0'.

        argparse's ``type=bool`` maps every non-empty string (including
        "False") to True, so the text is interpreted explicitly instead.
        """
        lowered = str(value).strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError(
            'expected a boolean value, got {!r}'.format(value))

    group = parser.add_argument_group(title='distributed')

    group.add_argument('--tensor-model-parallel-size', type=int, default=1,
                       help='Degree of tensor model parallelism.')
    group.add_argument('--pipeline-model-parallel-size', type=int, default=1,
                       help='Degree of pipeline model parallelism.')
    group.add_argument('--model-parallel-size', type=int, default=None,
                       help='Old model parallel argument, do not use. Use '
                       '--tensor-model-parallel-size instead.')
    group.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None,
                       help='Number of layers per virtual pipeline stage')
    group.add_argument('--distributed-backend', default='nccl',
                       choices=['nccl', 'gloo'],
                       help='Which backend to use for distributed training.')
    group.add_argument('--DDP-impl', default='local',
                       choices=['local', 'torch'],
                       help='which DistributedDataParallel implementation '
                       'to use.')
    # Negative flag: scatter/gather stays enabled unless this is passed.
    group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false',
                       help='Disable the scatter/gather optimization of tensor '
                       'communication in pipeline',
                       dest='scatter_gather_tensors_in_pipeline')
    group.add_argument('--local_rank', type=int, default=None,
                       help='local rank passed from distributed launcher.')
    group.add_argument('--lazy-mpu-init', type=_str_to_bool, required=False,
                       help='If set to True, initialize_megatron() '
                       'skips DDP initialization and returns function to '
                       'complete it instead. Also turns on '
                       '--use-cpu-initialization flag. This is for '
                       'external DDP manager. Accepts true/false.' )
    group.add_argument('--use-cpu-initialization', action='store_true',
                       default=None, help='If set, affine parallel weights '
                       'initialization uses CPU' )

    return parser


Mohammad's avatar
Mohammad committed
597
def _add_validation_args(parser):
    """Register evaluation-schedule arguments on *parser*.

    Args:
        parser: an ``argparse.ArgumentParser`` to extend.

    Returns:
        The same *parser* object.
    """
    group = parser.add_argument_group(title='validation')

    group.add_argument('--eval-iters', type=int, default=100,
                       help='Number of iterations to run for evaluation '
                       'validation/test for.')
    group.add_argument('--eval-interval', type=int, default=1000,
                       help='Interval between running evaluation on '
                       'validation set.')

    return parser


Mohammad's avatar
Mohammad committed
610
def _add_data_args(parser):
    """Register dataset, tokenizer and dataloader arguments on *parser*.

    Args:
        parser: an ``argparse.ArgumentParser`` to extend.

    Returns:
        The same *parser* object.
    """
    group = parser.add_argument_group(title='data and dataloader')

    # nargs='*': either a single path or alternating weight/path pairs.
    group.add_argument('--data-path', nargs='*', default=None,
                       help='Path to the training dataset. Accepted format: '
                       '1) a single data path, 2) multiple datasets in the '
                       'form: dataset1-weight dataset1-path dataset2-weight '
                       'dataset2-path ...')
    # '%%' is required: argparse %-formats help strings when rendering them.
    group.add_argument('--split', type=str, default='969, 30, 1',
                       help='Comma-separated list of proportions for training,'
                       ' validation, and test split. For example the split '
                       '`90,5,5` will use 90%% of data for training, 5%% for '
                       'validation and 5%% for test.')
    group.add_argument('--vocab-file', type=str, default=None,
                       help='Path to the vocab file.')
    group.add_argument('--merge-file', type=str, default=None,
                       help='Path to the BPE merge file.')
    group.add_argument('--seq-length', type=int, default=None,
                       help='Maximum sequence length to process.')
    group.add_argument('--encoder-seq-length', type=int, default=None,
                       help='Maximum encoder sequence length to process. '
                       'This should be exclusive of --seq-length')
    group.add_argument('--decoder-seq-length', type=int, default=None,
                       help="Maximum decoder sequence length to process.")
    group.add_argument('--retriever-seq-length', type=int, default=256,
                       help='Maximum sequence length for the biencoder model '
                       'for retriever')
    group.add_argument('--sample-rate', type=float, default=1.0,
                       help='sample rate for training data. Supposed to be 0 '
                       '< sample_rate <= 1')
    group.add_argument('--mask-prob', type=float, default=0.15,
                       help='Probability of replacing a token with mask.')
    group.add_argument('--short-seq-prob', type=float, default=0.1,
                       help='Probability of producing a short sequence.')
    group.add_argument('--mmap-warmup', action='store_true',
                       help='Warm up mmap files.')
    group.add_argument('--num-workers', type=int, default=2,
                       help="Dataloader number of workers.")
    group.add_argument('--tokenizer-type', type=str,
                       default=None,
                       choices=['BertWordPieceLowerCase',
                                'BertWordPieceCase',
                                'GPT2BPETokenizer'],
                       help='What type of tokenizer to use.')
    group.add_argument('--data-impl', type=str, default='infer',
                       choices=['lazy', 'cached', 'mmap', 'infer'],
                       help='Implementation of indexed datasets.')
    group.add_argument('--reset-position-ids', action='store_true',
                       help='Reset position ids after end-of-document token.')
    group.add_argument('--reset-attention-mask', action='store_true',
                       help='Reset self attention mask after '
                       'end-of-document token.')
    group.add_argument('--eod-mask-loss', action='store_true',
                       help='Mask loss for the end of document tokens.')

    return parser

Raul Puri's avatar
Raul Puri committed
667

Mohammad's avatar
Mohammad committed
668
669
def _add_autoresume_args(parser):
    """Register ADLR-cluster autoresume arguments on *parser*.

    Args:
        parser: an ``argparse.ArgumentParser`` to extend.

    Returns:
        The same *parser* object.
    """
    group = parser.add_argument_group(title='autoresume')

    group.add_argument('--adlr-autoresume', action='store_true',
                       help='Enable autoresume on adlr cluster.')
    group.add_argument('--adlr-autoresume-interval', type=int, default=1000,
                       help='Intervals over which check for autoresume '
                       'termination signal')

    return parser
Neel Kant's avatar
Neel Kant committed
678
679


Mostofa Patwary's avatar
Mostofa Patwary committed
680
681
def _add_biencoder_args(parser):
    """Register biencoder / ICT / retriever arguments on *parser*.

    Arguments are grouped by concern: network size, checkpointing, data,
    training, faiss index and indexer.

    Args:
        parser: an ``argparse.ArgumentParser`` to extend.

    Returns:
        The same *parser* object.
    """
    group = parser.add_argument_group(title='biencoder')

    # network size
    group.add_argument('--ict-head-size', type=int, default=None,
                       help='Size of block embeddings to be used in ICT and '
                       'REALM (paper default: 128)')
    group.add_argument('--biencoder-projection-dim', type=int, default=0,
                       help='Size of projection head used in biencoder (paper'
                       ' default: 128)')
    group.add_argument('--biencoder-shared-query-context-model', action='store_true',
                       help='Whether to share the parameters of the query '
                       'and context models or not')

    # checkpointing
    group.add_argument('--ict-load', type=str, default=None,
                       help='Directory containing an ICTBertModel checkpoint')
    group.add_argument('--bert-load', type=str, default=None,
                       help='Directory containing a BertModel checkpoint '
                       '(needed to start ICT and REALM)')

    # data
    group.add_argument('--titles-data-path', type=str, default=None,
                       help='Path to titles dataset used for ICT')
    group.add_argument('--query-in-block-prob', type=float, default=0.1,
                       help='Probability of keeping query in block for '
                       'ICT dataset')
    group.add_argument('--use-one-sent-docs', action='store_true',
                       help='Whether to use one sentence documents in ICT')
    group.add_argument('--evidence-data-path', type=str, default=None,
                       help='Path to Wikipedia Evidence from DPR paper')

    # training
    # NOTE(review): default=[] is the same list object for every parse that
    # omits the flag; treat args.retriever_report_topk_accuracies as read-only.
    group.add_argument('--retriever-report-topk-accuracies', nargs='+', type=int,
                       default=[], help="Which top-k accuracies to report "
                       "(e.g. '1 5 20')")
    group.add_argument('--retriever-score-scaling', action='store_true',
                       help='Whether to scale retriever scores by inverse '
                       'square root of hidden size')

    # faiss index
    group.add_argument('--faiss-use-gpu', action='store_true',
                       help='Whether to create the FaissMIPSIndex on GPU')
    group.add_argument('--block-data-path', type=str, default=None,
                       help='Where to save/load BlockData to/from')
    group.add_argument('--embedding-path', type=str, default=None,
                       help='Where to save/load Open-Retrieval Embedding'
                       ' data to/from')

    # indexer
    group.add_argument('--indexer-batch-size', type=int, default=128,
                       help='How large of batches to use when doing indexing '
                       'jobs')
    group.add_argument('--indexer-log-interval', type=int, default=1000,
                       help='After how many batches should the indexer '
                       'report progress')

    return parser
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751


def _add_vit_args(parser):
    """Register vision-transformer (ViT) arguments on *parser*.

    Args:
        parser: an ``argparse.ArgumentParser`` to extend.

    Returns:
        The same *parser* object.
    """
    group = parser.add_argument_group(title="vit")

    group.add_argument('--num-classes', type=int, default=1000,
                       help='num of classes in vision classification task')
    group.add_argument('--img-dim', type=int, default=224,
                       help='Image size for vision classification task')
    group.add_argument('--num-channels', type=int, default=3,
                       help='Number of channels in input image data')
    group.add_argument('--patch-dim', type=int, default=16,
                       help='patch dimension used in vit')

    return parser