arguments.py 62.7 KB
Newer Older
Jared Casper's avatar
Jared Casper committed
1
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
Raul Puri's avatar
Raul Puri committed
2

Mohammad's avatar
Mohammad committed
3
"""Megatron arguments."""
Raul Puri's avatar
Raul Puri committed
4
5

import argparse
Lawrence McAfee's avatar
Retro  
Lawrence McAfee committed
6
import json
Raul Puri's avatar
Raul Puri committed
7
import os
8
import torch
Lawrence McAfee's avatar
Retro  
Lawrence McAfee committed
9
10
11
12
13
import types

from megatron.global_vars import set_retro_args, get_retro_args
from tools.retro.utils import get_args_path as get_retro_args_path

Raul Puri's avatar
Raul Puri committed
14

15
def parse_args(extra_args_provider=None, ignore_unknown_args=False):
    """Build the Megatron argument parser and parse the command line.

    Args:
        extra_args_provider: optional callable that receives the parser after
            all built-in argument groups are registered and returns the
            (possibly extended) parser to use.
        ignore_unknown_args: when True, unrecognized options are discarded via
            ``parse_known_args`` instead of making argparse error out.

    Returns:
        The parsed namespace, augmented with ``rank`` and ``world_size`` read
        from the ``RANK`` / ``WORLD_SIZE`` environment variables (defaulting
        to 0 and 1 respectively).
    """
    parser = argparse.ArgumentParser(description='Megatron-LM Arguments',
                                     allow_abbrev=False)

    # Built-in argument groups; registration order determines --help layout.
    group_adders = (
        _add_network_size_args,
        _add_regularization_args,
        _add_training_args,
        _add_initialization_args,
        _add_learning_rate_args,
        _add_checkpointing_args,
        _add_mixed_precision_args,
        _add_distributed_args,
        _add_validation_args,
        _add_data_args,
        _add_autoresume_args,
        _add_biencoder_args,
        _add_vision_args,
        _add_logging_args,
        _add_inference_args,
        _add_transformer_engine_args,
        _add_retro_args,
    )
    for add_group in group_adders:
        parser = add_group(parser)

    # Caller-specific (model/task) arguments go last.
    if extra_args_provider is not None:
        parser = extra_args_provider(parser)

    if ignore_unknown_args:
        args = parser.parse_known_args()[0]
    else:
        args = parser.parse_args()

    # Distributed launch information is taken from the environment.
    args.rank = int(os.getenv('RANK', '0'))
    args.world_size = int(os.getenv("WORLD_SIZE", '1'))

    return args

def validate_args(args, defaults=None):
    """Validate parsed arguments and derive dependent settings in place.

    Resolves parallelism sizes, rejects deprecated flags, applies caller
    defaults, checks batch-size / schedule / precision / recompute
    consistency, optionally loads Retro arguments from ``args.retro_workdir``
    and finally prints the resulting arguments on rank 0.

    Args:
        args: namespace produced by ``parse_args`` (must carry ``rank`` and
            ``world_size``).
        defaults: optional mapping of argument name -> value. A default is
            applied only when the user left that argument as ``None``;
            otherwise the user's value wins and a warning is printed on rank 0.

    Returns:
        The same ``args`` namespace, mutated.

    Raises:
        AssertionError: on invalid or inconsistent argument combinations.
        RuntimeError: if sequence parallelism or async tensor-model-parallel
            all-reduce is enabled without CUDA_DEVICE_MAX_CONNECTIONS=1.
    """
    # Avoid a shared mutable default; the mapping is only read below.
    if defaults is None:
        defaults = {}

    # Tensor model parallel size.
    args.tensor_model_parallel_size = min(
        args.tensor_model_parallel_size, args.world_size)
    assert args.world_size % args.tensor_model_parallel_size == 0, 'world size'\
        ' ({}) is not divisible by tensor model parallel size ({})'.format(
            args.world_size, args.tensor_model_parallel_size)
    # Pipeline model parallel size.
    args.pipeline_model_parallel_size = min(
        args.pipeline_model_parallel_size,
        (args.world_size // args.tensor_model_parallel_size))
    # One pipeline stage is excluded from the transformer pipeline size when
    # a standalone embedding stage is requested.
    args.transformer_pipeline_model_parallel_size = (
        args.pipeline_model_parallel_size - 1
        if args.standalone_embedding_stage else
        args.pipeline_model_parallel_size
    )
    # Checks.
    model_parallel_size = args.pipeline_model_parallel_size * \
                          args.tensor_model_parallel_size
    assert args.world_size % model_parallel_size == 0, 'world size ({}) is ' \
        'not divisible by tensor parallel size ({}) times pipeline parallel ' \
        'size ({})'.format(args.world_size, args.tensor_model_parallel_size,
                           args.pipeline_model_parallel_size)
    args.data_parallel_size = args.world_size // model_parallel_size
    if args.rank == 0:
        print('using world size: {}, data-parallel-size: {}, '
              'tensor-model-parallel size: {}, '
              'pipeline-model-parallel size: {} '.format(
                  args.world_size, args.data_parallel_size,
                  args.tensor_model_parallel_size,
                  args.pipeline_model_parallel_size), flush=True)
    if args.pipeline_model_parallel_size > 1:
        if args.pipeline_model_parallel_split_rank is not None:
            assert args.pipeline_model_parallel_split_rank < \
                    args.pipeline_model_parallel_size, 'split rank needs'\
                    ' to be less than pipeline model parallel size ({})'.format(
                            args.pipeline_model_parallel_size)

    # Deprecated arguments: must be unset, then removed from the namespace.
    assert args.batch_size is None, '--batch-size argument is no longer ' \
        'valid, use --micro-batch-size instead'
    del args.batch_size
    assert args.warmup is None, '--warmup argument is no longer valid, use ' \
        '--lr-warmup-fraction instead'
    del args.warmup
    assert args.model_parallel_size is None, '--model-parallel-size is no ' \
        'longer valid, use --tensor-model-parallel-size instead'
    del args.model_parallel_size

    # Map the legacy --checkpoint-activations flag onto the recompute options.
    if args.checkpoint_activations:
        args.recompute_granularity = 'full'
        args.recompute_method = 'uniform'
        if args.rank == 0:
            print('--checkpoint-activations is no longer valid, '
                  'use --recompute-granularity and --recompute-method instead. '
                  'Defaulting to recompute-granularity=full and recompute-method=uniform.')
    del args.checkpoint_activations

    if args.recompute_activations:
        args.recompute_granularity = 'selective'
    del args.recompute_activations

    # Set input defaults.
    for key in defaults:
        # For default to be valid, it should not be provided in the
        # arguments that are passed to the program. We check this by
        # ensuring the arg is set to None.
        if getattr(args, key) is not None:
            if args.rank == 0:
                print('WARNING: overriding default arguments for {key}:{v} '
                      'with {key}:{v2}'.format(key=key, v=defaults[key],
                                               v2=getattr(args, key)),
                      flush=True)
        else:
            setattr(args, key, defaults[key])

    # Resolve num_layers / encoder_num_layers before anything reads
    # args.num_layers (the virtual-pipeline check below does).
    if args.num_layers is not None:
        assert args.encoder_num_layers is None, \
            'cannot have both num-layers and encoder-num-layers specified'
        args.encoder_num_layers = args.num_layers
    else:
        assert args.encoder_num_layers is not None, \
            'either num-layers or encoder-num-layers should be specified'
        args.num_layers = args.encoder_num_layers

    # Batch size.
    assert args.micro_batch_size is not None
    assert args.micro_batch_size > 0
    if args.global_batch_size is None:
        args.global_batch_size = args.micro_batch_size * args.data_parallel_size
        if args.rank == 0:
            print('setting global batch size to {}'.format(
                args.global_batch_size), flush=True)
    assert args.global_batch_size > 0
    if args.num_layers_per_virtual_pipeline_stage is not None:
        assert args.pipeline_model_parallel_size > 2, \
            'pipeline-model-parallel size should be greater than 2 with ' \
            'interleaved schedule'
        assert args.num_layers % args.num_layers_per_virtual_pipeline_stage == 0, \
            'number of layers is not divisible by number of layers per virtual ' \
            'pipeline stage'
        args.virtual_pipeline_model_parallel_size = \
            (args.num_layers // args.transformer_pipeline_model_parallel_size) // \
            args.num_layers_per_virtual_pipeline_stage
    else:
        args.virtual_pipeline_model_parallel_size = None

    # Parameters dtype.
    args.params_dtype = torch.float
    if args.fp16:
        assert not args.bf16
        args.params_dtype = torch.half
    if args.bf16:
        assert not args.fp16
        args.params_dtype = torch.bfloat16
        # bfloat16 requires gradient accumulation and all-reduce to
        # be done in fp32.
        if not args.accumulate_allreduce_grads_in_fp32:
            args.accumulate_allreduce_grads_in_fp32 = True
            if args.rank == 0:
                print('accumulate and all-reduce gradients in fp32 for '
                      'bfloat16 data type.', flush=True)

    if args.rank == 0:
        print('using {} for parameters ...'.format(args.params_dtype),
              flush=True)

    # If we do accumulation and all-reduces in fp32, we need to have local DDP
    # and we should make sure use-contiguous-buffers-in-local-ddp is not off.
    if args.accumulate_allreduce_grads_in_fp32:
        assert args.DDP_impl == 'local'
        assert args.use_contiguous_buffers_in_local_ddp

    # If we use the distributed optimizer, we need to have local DDP
    # and we should make sure use-contiguous-buffers-in-local-ddp is on.
    if args.use_distributed_optimizer:
        assert args.DDP_impl == 'local'
        assert args.use_contiguous_buffers_in_local_ddp

    # For torch DDP, we do not use contiguous buffer
    if args.DDP_impl == 'torch':
        args.use_contiguous_buffers_in_local_ddp = False

    if args.dataloader_type is None:
        args.dataloader_type = 'single'

    # Consumed-sample counters, initialized to zero here.
    args.consumed_train_samples = 0
    args.consumed_valid_samples = 0

    # Support for variable sequence lengths across batches/microbatches.
    # set it if the dataloader supports generation of variable sequence lengths
    # across batches/microbatches. Due to additional communication overhead
    # during pipeline parallelism, it should not be set if sequence length
    # is constant during training.
    args.variable_seq_lengths = False

    # Iteration-based training.
    if args.train_iters:
        # If we use iteration-based training, make sure the
        # sample-based options are off.
        assert args.train_samples is None, \
            'expected iteration-based training'
        assert args.lr_decay_samples is None, \
            'expected iteration-based learning rate decay'
        assert args.lr_warmup_samples == 0, \
            'expected iteration-based learning rate warmup'
        assert args.rampup_batch_size is None, \
            'expected no batch-size rampup for iteration-based training'
        if args.lr_warmup_fraction is not None:
            assert args.lr_warmup_iters == 0, \
                'can only specify one of lr-warmup-fraction and lr-warmup-iters'

    # Sample-based training.
    if args.train_samples:
        # If we use sample-based training, make sure the
        # iteration-based options are off.
        assert args.train_iters is None, \
            'expected sample-based training'
        assert args.lr_decay_iters is None, \
            'expected sample-based learning rate decay'
        assert args.lr_warmup_iters == 0, \
            'expected sample-based learning rate warmup'
        if args.lr_warmup_fraction is not None:
            assert args.lr_warmup_samples == 0, \
                'can only specify one of lr-warmup-fraction ' \
                'and lr-warmup-samples'

    # Check required arguments.
    required_args = ['num_layers', 'hidden_size', 'num_attention_heads',
                     'max_position_embeddings']
    for req_arg in required_args:
        _check_arg_is_not_none(args, req_arg)

    # Derive the MLP width only when the user did not provide one.
    if args.ffn_hidden_size is None:
        if args.swiglu:
            # reduce the dimension for MLP since projections happens on
            # two linear layers. this keeps the number of parameters in
            # the same ballpark as the counterpart with 4*h size
            # we keep it a multiple of 64, which means the actual tensor size
            # will be a multiple of 64 / tp_size
            args.ffn_hidden_size = int((4 * args.hidden_size * 2 / 3) / 64) * 64
        else:
            args.ffn_hidden_size = 4 * args.hidden_size

    if args.kv_channels is None:
        assert args.hidden_size % args.num_attention_heads == 0
        args.kv_channels = args.hidden_size // args.num_attention_heads

    if args.seq_length is not None:
        assert args.encoder_seq_length is None
        args.encoder_seq_length = args.seq_length
    else:
        assert args.encoder_seq_length is not None
        args.seq_length = args.encoder_seq_length

    if args.seq_length is not None:
        assert args.max_position_embeddings >= args.seq_length
    if args.decoder_seq_length is not None:
        assert args.max_position_embeddings >= args.decoder_seq_length
    if args.lr is not None:
        assert args.min_lr <= args.lr
    if args.save is not None:
        assert args.save_interval is not None
    # Mixed precision checks.
    if args.fp16_lm_cross_entropy:
        assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.'
    if args.fp32_residual_connection:
        assert args.fp16 or args.bf16, \
            'residual connection in fp32 only supported when using fp16 or bf16.'

    if args.weight_decay_incr_style == 'constant':
        assert args.start_weight_decay is None
        assert args.end_weight_decay is None
        args.start_weight_decay = args.weight_decay
        args.end_weight_decay = args.weight_decay
    else:
        assert args.start_weight_decay is not None
        assert args.end_weight_decay is not None

    TORCH_MAJOR = int(torch.__version__.split('.')[0])
    TORCH_MINOR = int(torch.__version__.split('.')[1])
    # Persistent fused layer norm.
    if TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 11):
        args.no_persist_layer_norm = True
        if args.rank == 0:
            print('Persistent fused layer norm kernel is supported from '
                  'pytorch v1.11 (nvidia pytorch container paired with v1.11). '
                  'Defaulting to no_persist_layer_norm=True')

    # Activation recomputing.
    if args.distribute_saved_activations:
        assert args.tensor_model_parallel_size > 1, 'can distribute ' \
            'recomputed activations only across tensor model ' \
            'parallel groups'
        assert args.recompute_granularity == 'full', \
            'distributed recompute activations is only '\
            'application to full recompute granularity'
        assert args.recompute_method is not None, \
            'for distributed recompute activations to work you '\
            'need to use a recompute method '
        # Compare (major, minor) properly so that e.g. torch 2.0 passes.
        assert TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10), \
            'distributed recompute activations are supported for pytorch ' \
            'v1.10 and above (Nvidia Pytorch container >= 21.07). Current ' \
            'pytorch version is v%s.%s.' % (TORCH_MAJOR, TORCH_MINOR)

    # Transformer-Engine/FP8 related checking
    if args.fp8_e4m3 or args.fp8_hybrid:
        assert args.transformer_impl == 'transformer_engine', \
            'transformer-engine required for fp8 training and inference'

    assert not (args.fp8_e4m3 and args.fp8_hybrid), \
        'cannot train with both fp8 e4m3 and hybrid formatting'

    if args.fp16:
        assert args.transformer_impl == 'local', \
            'transformer-engine not yet approved for fp16 training and inference'

    if args.recompute_granularity == 'selective':
        assert args.recompute_method is None, \
            'recompute method is not yet supported for ' \
            'selective recomputing granularity'

    # disable sequence parallelism when tp=1
    # to avoid change in numerics when
    # sequence_parallelism is enabled.
    if args.tensor_model_parallel_size == 1:
        args.sequence_parallel = False

    # disable async_tensor_model_parallel_allreduce when
    # model parallel memory optimization is enabled
    if args.sequence_parallel:
        args.async_tensor_model_parallel_allreduce = False

    if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
        if args.sequence_parallel:
            raise RuntimeError(
                "Using sequence parallelism requires setting the environment variable "
                "CUDA_DEVICE_MAX_CONNECTIONS to 1")
        if args.async_tensor_model_parallel_allreduce:
            raise RuntimeError(
                "Using async gradient all reduce requires setting the environment "
                "variable CUDA_DEVICE_MAX_CONNECTIONS to 1")

    # Disable bias gelu fusion if we are disabling bias altogether
    if not args.add_bias_linear:
        args.bias_gelu_fusion = False

    # Load retro args (only if the workdir contains a saved args file).
    if args.retro_workdir:
        retro_args_path = get_retro_args_path(args.retro_workdir)
        if os.path.exists(retro_args_path):
            with open(retro_args_path) as f:
                retro_args = types.SimpleNamespace(**json.load(f))
            retro_args.retro_return_doc_ids = args.retro_return_doc_ids
            retro_args.retro_gpt_retrieved_length = \
                args.retro_num_retrieved_chunks * \
                retro_args.retro_gpt_chunk_length
            set_retro_args(retro_args)

    # Print arguments.
    _print_args("arguments", args)
    retro_args = get_retro_args()
    if retro_args and args != retro_args:
        # Only the retro_* entries are printed, plus rank so _print_args can
        # decide whether to print.
        _print_args("retro arguments", types.SimpleNamespace(
            **{k: v for k, v in vars(retro_args).items()
               if k.startswith("retro")},
            rank=args.rank))

    return args
Mohammad's avatar
Mohammad committed
383
384


Lawrence McAfee's avatar
Retro  
Lawrence McAfee committed
385
def _print_args(title, args):
Mohammad's avatar
Mohammad committed
386
387
    """Print arguments."""
    if args.rank == 0:
Lawrence McAfee's avatar
Retro  
Lawrence McAfee committed
388
        print(f'------------------------ {title} ------------------------',
mohammad's avatar
mohammad committed
389
              flush=True)
Mohammad's avatar
Mohammad committed
390
391
        str_list = []
        for arg in vars(args):
mohammad's avatar
mohammad committed
392
            dots = '.' * (48 - len(arg))
Mohammad's avatar
Mohammad committed
393
394
395
            str_list.append('  {} {} {}'.format(arg, dots, getattr(args, arg)))
        for arg in sorted(str_list, key=lambda x: x.lower()):
            print(arg, flush=True)
Lawrence McAfee's avatar
Retro  
Lawrence McAfee committed
396
        print(f'-------------------- end of {title} ---------------------',
mohammad's avatar
mohammad committed
397
              flush=True)
Mohammad's avatar
Mohammad committed
398
399


400
401
402
403
def _check_arg_is_not_none(args, arg):
    assert getattr(args, arg) is not None, '{} argument is None'.format(arg)


404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
def _add_transformer_engine_args(parser):
    group = parser.add_argument_group(title='Transformer-Engine')

    group.add_argument('--fp8-e4m3', action='store_true',
                        help='E4M3 TransformerLayer', dest='fp8_e4m3')
    group.add_argument('--fp8-hybrid', action='store_true',
                        help='Hybrid FP8 TransformerLayer', dest='fp8_hybrid')
    group.add_argument('--no-fp8-wgrad', action='store_false',
                        help='Execute wgrad in higher precision even for FP8 runs', dest='fp8_wgrad')
    group.add_argument('--fp8-margin', type=int, default=0,
                        help='Scaling margin for fp8', dest='fp8_margin')
    group.add_argument('--fp8-interval', type=int, default=1,
                        help='Scaling update interval for fp8', dest='fp8_interval')
    group.add_argument('--transformer-impl', default='local',
                       choices=['local', 'transformer_engine'],
                       help='Which Transformer implementation to use.',
                       dest='transformer_impl')
    group.add_argument('--fp8-amax-history-len', type=int, default=1,
                        help='Number of steps for which amax history is recorded per tensor',
                        dest='fp8_amax_history_len')
    group.add_argument('--fp8-amax-compute-algo', default='most_recent',
                       choices=['most_recent', 'max'],
                       help='Algorithm for computing amax from history',
                       dest='fp8_amax_compute_algo')

    return parser

mshoeybi's avatar
mshoeybi committed
431
432
433
434
435
436
437
438
def _add_inference_args(parser):
    group = parser.add_argument_group(title='inference')

    group.add_argument('--inference-batch-times-seqlen-threshold',
                       type=int, default=512,
                       help='During inference, if batch-size times '
                       'sequence-length is smaller than this threshold '
                       'then we will not use pipelining, otherwise we will.')
439
440
441
442
443
    group.add_argument('--max-tokens-to-oom',
                       type=int, default=12000,
                       help='Maximum number of tokens during inference'
                       'tokens here is # in prompt + # to generate'
                       'Allows us to throw an error before OOM crashes server')
Lawrence McAfee's avatar
Retro  
Lawrence McAfee committed
444
445
446
447
448
449
450
451
452
    group.add_argument('--output-bert-embeddings', action='store_true',
                       help='Output Bert embeddings (via mean pooling) from '
                       'model, rather than its binary head output or entire '
                       'hidden batch.')
    group.add_argument('--bert-embedder-type', default="megatron",
                       choices=["megatron", "huggingface"],
                       help='Select either Megatron or Huggingface as the '
                       'Bert embedder.')

mshoeybi's avatar
mshoeybi committed
453
454
    return parser

Lawrence McAfee's avatar
Retro  
Lawrence McAfee committed
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499

def _add_retro_args(parser):
    group = parser.add_argument_group(title='retro')

    group.add_argument('--retro-workdir', default=None,
                       help='Retro working directory, which contains the '
                       'preprocessed data for for pretraining. This directory '
                       'is built during preprocessing (see '
                       'tools/retro/README.md), and contains subdirectories '
                       'for the chunk database and pretraining neighbors.')
    group.add_argument('--retro-add-retriever',
                       action='store_true', default=False,
                       help='Add a retriever to the transformer, for use in '
                       'pretraining a Retro model.')
    group.add_argument('--retro-cyclic-train-iters', type=int, default=None,
                       help='Set number of training iterations for cyclic '
                       'Retro training.')
    group.add_argument('--retro-encoder-layers', type=int, default=2,
                       help='Number of layers to use for the retrieval '
                       'encoder.')
    group.add_argument('--retro-encoder-hidden-dropout',
                       type=float, default=0.1, help='Hidden dropout for '
                       'retrieval encoder.')
    group.add_argument('--retro-encoder-attention-dropout',
                       type=float, default=0.1, help='Attention dropout for '
                       'retrieval encoder.')
    group.add_argument("--retro-num-neighbors", type=int, default=2,
                       help='Number of neighbors to retrieve during '
                       'pretraining.')
    group.add_argument("--retro-num-retrieved-chunks", type=int, default=2,
                       help='Number of chunks to retrieve from the retrieval '
                       'database.')
    group.add_argument("--retro-return-doc-ids", action="store_true",
                       help="Turn this on when preprocessing retro data.")

    # Enforce argument naming convention.
    for action in group._group_actions:
        prefix = action.dest.split("_")[0]
        assert prefix == "retro", \
            "Retro args must be prefixed with '--retro-*', for consistent " \
            "styling. Please fix '%s'." % ", ".join(action.option_strings)

    return parser


Mohammad's avatar
Mohammad committed
500
def _add_network_size_args(parser):
Mohammad's avatar
Mohammad committed
501
    group = parser.add_argument_group(title='network size')
Mohammad's avatar
Mohammad committed
502

503
    group.add_argument('--num-layers', type=int, default=None,
Mohammad's avatar
Mohammad committed
504
                       help='Number of transformer layers.')
505
506
507
508
    group.add_argument('--encoder-num-layers', type=int, default=None,
                       help='Number of encoder transformer layers.')
    group.add_argument('--decoder-num-layers', type=int, default=None,
                       help='Number of decoder transformer layers.')
509
    group.add_argument('--hidden-size', type=int, default=None,
Mohammad's avatar
Mohammad committed
510
                       help='Tansformer hidden size.')
511
    group.add_argument('--ffn-hidden-size', type=int, default=None,
512
513
                       help='Transformer Feed-Forward Network hidden size. '
                       'This is set to 4*hidden-size if not provided')
514
    group.add_argument('--num-attention-heads', type=int, default=None,
Mohammad's avatar
Mohammad committed
515
                       help='Number of transformer attention heads.')
516
    group.add_argument('--kv-channels', type=int, default=None,
517
518
519
520
                       help='Projection weights dimension in multi-head '
                       'attention. This is set to '
                       '   args.hidden_size // args.num_attention_heads '
                       'if not provided.')
521
    group.add_argument('--max-position-embeddings', type=int, default=None,
Mohammad's avatar
Mohammad committed
522
523
524
525
526
                       help='Maximum number of position embeddings to use. '
                       'This is the size of position embedding.')
    group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
                       help='Pad the vocab size to be divisible by this value.'
                       'This is added for computational efficieny reasons.')
Mohammad's avatar
Mohammad committed
527
528
    group.add_argument('--layernorm-epsilon', type=float, default=1e-5,
                       help='Layer norm epsilon.')
Mostofa Patwary's avatar
Mostofa Patwary committed
529
    group.add_argument('--apply-layernorm-1p', action='store_true',
530
531
                       help='Adjust LayerNorm weights such that they are centered '
                       'around zero. This improves numerical stability.')
Mohammad's avatar
Mohammad committed
532
533
534
535
    group.add_argument('--apply-residual-connection-post-layernorm',
                       action='store_true',
                       help='If set, use original BERT residula connection '
                       'ordering.')
536
537
538
539
    group.add_argument('--openai-gelu', action='store_true',
                       help='Use OpenAIs GeLU implementation. This option'
                       'should not be used unless for backward compatibility'
                       'reasons.')
540
541
542
543
    group.add_argument('--squared-relu', action='store_true',
                       help='Use squared relu activation instead of default gelu')
    group.add_argument('--swiglu', action='store_true',
                       help='Use gated linear units and SiLU activation instead of default gelu')
544
    group.add_argument('--onnx-safe', type=bool, required=False,
545
546
                       help='Use workarounds for known problems with '
                       'Torch ONNX exporter')
547
548
549
    group.add_argument('--bert-no-binary-head', action='store_false',
                       help='Disable BERT binary head.',
                       dest='bert_binary_head')
rprenger's avatar
rprenger committed
550
551
    group.add_argument('--num-experts', type=int, default=None,
                       help='Number of Experts in Switch Transformer (None means no Switch)')
Mohammad's avatar
Mohammad committed
552
553
554
    return parser


555
556
557
558
559
def _add_logging_args(parser):
    """Register logging, timing and TensorBoard options on ``parser``.

    Args:
        parser: ``argparse.ArgumentParser`` that receives a new 'logging'
            argument group.

    Returns:
        The same ``parser`` so calls can be chained.
    """
    group = parser.add_argument_group(title='logging')

    group.add_argument('--log-params-norm', action='store_true',
                       help='If set, calculate and log parameters norm.')
    group.add_argument('--log-num-zeros-in-grad', action='store_true',
                       help='If set, calculate and log the number of zeros in gradient.')
    group.add_argument('--timing-log-level', type=int,
                       default=0, choices=range(0, 3),
                       help='Granularity level to measure and report timing. '
                       '   0: report only iteration time and make sure timing '
                       '      does not introduce extra overhead.'
                       '   1: report timing for operations that are executed '
                       '      very limited times (basically once) during '
                       '      each iteration (such as gradient all-reduce) '
                       '   2: report timing for operations that might be '
                       '      executed numerous times during each iteration. '
                       'Note that setting the level to 1 or 2 might '
                       'cause increase in iteration time.')
    # store_false: barrier_with_L1_time defaults to True; the flag disables it.
    group.add_argument('--no-barrier-with-level-1-timing', action='store_false',
                       help='If not set, use barrier with level 1 time '
                       'measurements. Note that this is up to the user '
                       'to make sure calling barrier with their timers '
                       'will not result in hangs. This can happen if for '
                       'example the user adds a level 1 timer that is not '
                       'called by all ranks.',
                       dest='barrier_with_L1_time')
    group.add_argument('--timing-log-option', type=str, default='minmax',
                       choices=['max', 'minmax', 'all'],
                       help='Options for logging timing:'
                       '  max: report the max timing across all ranks'
                       '  minmax: report min and max timings across all ranks'
                       '  all: report timings of all ranks.')
    group.add_argument('--tensorboard-log-interval', type=int, default=1,
                       help='Report to tensorboard interval.')
    group.add_argument('--tensorboard-queue-size', type=int, default=1000,
                       help='Size of the tensorboard queue for pending events '
                       'and summaries before one of the ‘add’ calls forces a '
                       'flush to disk.')
    group.add_argument('--log-timers-to-tensorboard', action='store_true',
                       help='If set, write timers to tensorboard.')
    group.add_argument('--log-batch-size-to-tensorboard', action='store_true',
                       help='If set, write batch-size to tensorboard.')
    # The historical flag name contains a typo ("learnig"). It is kept for
    # backward compatibility; the correctly spelled flag is an alias.
    group.add_argument('--no-log-learnig-rate-to-tensorboard',
                       '--no-log-learning-rate-to-tensorboard',
                       action='store_false',
                       help='Disable learning rate logging to tensorboard.',
                       dest='log_learning_rate_to_tensorboard')
    group.add_argument('--no-log-loss-scale-to-tensorboard',
                       action='store_false',
                       help='Disable loss-scale logging to tensorboard.',
                       dest='log_loss_scale_to_tensorboard')
    group.add_argument('--log-validation-ppl-to-tensorboard',
                       action='store_true',
                       help='If set, write validation perplexity to '
                       'tensorboard.')
    group.add_argument('--log-memory-to-tensorboard',
                       action='store_true',
                       help='Enable memory logging to tensorboard.')
    group.add_argument('--log-world-size-to-tensorboard',
                       action='store_true',
                       help='Enable world size logging to tensorboard.')

    return parser


Mohammad's avatar
Mohammad committed
620
def _add_regularization_args(parser):
    """Register dropout, weight-decay, gradient-clipping and optimizer
    hyper-parameter options on ``parser``.

    Args:
        parser: ``argparse.ArgumentParser`` that receives a new
            'regularization' argument group.

    Returns:
        The same ``parser`` so calls can be chained.
    """
    group = parser.add_argument_group(title='regularization')

    group.add_argument('--attention-dropout', type=float, default=0.1,
                       help='Post attention dropout probability.')
    group.add_argument('--hidden-dropout', type=float, default=0.1,
                       help='Dropout probability for hidden state transformer.')
    group.add_argument('--weight-decay', type=float, default=0.01,
                       help='Weight decay coefficient for L2 regularization.')
    # No default: these stay None unless given on the command line.
    group.add_argument('--start-weight-decay', type=float,
                       help='Initial weight decay coefficient for L2 regularization.')
    group.add_argument('--end-weight-decay', type=float,
                       help='End of run weight decay coefficient for L2 regularization.')
    group.add_argument('--weight-decay-incr-style', type=str, default='constant',
                       choices=['constant', 'linear', 'cosine'],
                       help='Weight decay increment function.')
    group.add_argument('--clip-grad', type=float, default=1.0,
                       help='Gradient clipping based on global L2 norm.')
    group.add_argument('--adam-beta1', type=float, default=0.9,
                       help='First coefficient for computing running averages '
                       'of gradient and its square')
    group.add_argument('--adam-beta2', type=float, default=0.999,
                       help='Second coefficient for computing running averages '
                       'of gradient and its square')
    group.add_argument('--adam-eps', type=float, default=1e-08,
                       help='Term added to the denominator to improve '
                       'numerical stability')
    group.add_argument('--sgd-momentum', type=float, default=0.9,
                       help='Momentum factor for sgd')

    return parser

Mohammad's avatar
Mohammad committed
652
653

def _add_training_args(parser):
    """Register training-loop options on ``parser``.

    Covers batch sizes, activation recomputation, run length, exit
    conditions, kernel-fusion switches, optimizer choice and parallelism
    toggles.

    Args:
        parser: ``argparse.ArgumentParser`` that receives a new 'training'
            argument group.

    Returns:
        The same ``parser`` so calls can be chained.
    """
    group = parser.add_argument_group(title='training')

    group.add_argument('--micro-batch-size', type=int, default=None,
                       help='Batch size per model instance (local batch size). '
                       'Global batch size is local batch size times data '
                       'parallel size times number of micro batches.')
    group.add_argument('--batch-size', type=int, default=None,
                       help='Old batch size parameter, do not use. '
                       'Use --micro-batch-size instead')
    group.add_argument('--global-batch-size', type=int, default=None,
                       help='Training batch size. If set, it should be a '
                       'multiple of micro-batch-size times data-parallel-size. '
                       'If this value is None, then '
                       'use micro-batch-size * data-parallel-size as the '
                       'global batch size. This choice will result in 1 for '
                       'number of micro-batches.')
    # '\\ ' yields the same literal backslash the help always displayed,
    # without relying on the invalid '\ ' escape sequence.
    group.add_argument('--rampup-batch-size', nargs='*', default=None,
                       help='Batch size ramp up with the following values:'
                       '  --rampup-batch-size <start batch size> '
                       '                      <batch size increment> '
                       '                      <ramp-up samples> '
                       'For example:'
                       '   --rampup-batch-size 16 8 300000 \\ '
                       '   --global-batch-size 1024 '
                       'will start with global batch size 16 and over '
                       ' (1024 - 16) / 8 = 126 intervals will increase '
                       'the batch size linearly to 1024. In each interval '
                       'we will use approximately 300000 / 126 = 2380 samples.')

    group.add_argument('--recompute-activations', action='store_true',
                       help='recompute activation to allow for training '
                       'with larger models, sequences, and batch sizes.')
    group.add_argument('--recompute-granularity', type=str, default=None,
                       choices=['full', 'selective'],
                       help='Checkpoint activations to allow for training '
                       'with larger models, sequences, and batch sizes. '
                       'It is supported at two granularities 1) full: '
                       'whole transformer layer is recomputed, '
                       '2) selective: core attention part of the transformer '
                       'layer is recomputed.')
    group.add_argument('--distribute-saved-activations',
                       action='store_true',
                       help='If set, distribute recomputed activations '
                       'across model parallel group.')
    group.add_argument('--recompute-method', type=str, default=None,
                       choices=['uniform', 'block'],
                       help='1) uniform: uniformly divide the total number of '
                       'Transformer layers and recompute the input activation of '
                       'each divided chunk at specified granularity, '
                       '2) block: recompute the input activations of only a set '
                       'number of individual Transformer layers per pipeline '
                       'stage and do the rest without any recomputing at '
                       'specified granularity, '
                       'default) do not apply activations recompute to any layers')
    group.add_argument('--recompute-num-layers', type=int, default=1,
                       help='1) uniform: the number of Transformer layers in each '
                       'uniformly divided recompute unit, '
                       '2) block: the number of individual Transformer layers '
                       'to recompute within each pipeline stage.')

    # deprecated
    group.add_argument('--checkpoint-activations', action='store_true',
                       help='Checkpoint activation to allow for training '
                       'with larger models, sequences, and batch sizes.')
    group.add_argument('--train-iters', type=int, default=None,
                       help='Total number of iterations to train over all '
                       'training runs. Note that either train-iters or '
                       'train-samples should be provided.')
    group.add_argument('--train-samples', type=int, default=None,
                       help='Total number of samples to train over all '
                       'training runs. Note that either train-iters or '
                       'train-samples should be provided.')
    group.add_argument('--log-interval', type=int, default=100,
                       help='Report loss and timing interval.')
    group.add_argument('--exit-interval', type=int, default=None,
                       help='Exit the program after the iteration is divisible '
                       'by this value.')
    group.add_argument('--exit-duration-in-mins', type=int, default=None,
                       help='Exit the program after this many minutes.')
    group.add_argument('--exit-signal-handler', action='store_true',
                       help='Dynamically save the checkpoint and shutdown the '
                       'training if SIGTERM is received')
    group.add_argument('--tensorboard-dir', type=str, default=None,
                       help='Write TensorBoard logs to this directory.')

    # The --no-* switches below use store_false, so their destinations
    # default to True and passing the flag disables the feature.
    group.add_argument('--no-masked-softmax-fusion',
                       action='store_false',
                       help='Disable fusion of query_key_value scaling, '
                       'masking, and softmax.',
                       dest='masked_softmax_fusion')
    group.add_argument('--no-bias-gelu-fusion', action='store_false',
                       help='Disable bias and gelu fusion.',
                       dest='bias_gelu_fusion')
    group.add_argument('--no-bias-dropout-fusion', action='store_false',
                       help='Disable bias and dropout fusion.',
                       dest='bias_dropout_fusion')
    group.add_argument('--use-flash-attn', action='store_true',
                       help='use FlashAttention implementation of attention. '
                       'https://arxiv.org/abs/2205.14135')
    group.add_argument('--disable-bias-linear', action='store_false',
                       help='Disable bias in the linear layers',
                       dest='add_bias_linear')
    group.add_argument('--optimizer', type=str, default='adam',
                       choices=['adam', 'sgd'],
                       help='Optimizer function')
    group.add_argument('--dataloader-type', type=str, default=None,
                       choices=['single', 'cyclic'],
                       help='Single pass vs multiple pass data loader')
    group.add_argument('--no-async-tensor-model-parallel-allreduce',
                       action='store_false',
                       help='Disable asynchronous execution of '
                       'tensor-model-parallel all-reduce with weight '
                       'gradient computation of a column-linear layer.',
                       dest='async_tensor_model_parallel_allreduce')
    # Unlike the switches above, this one is store_true and keeps the
    # no_persist_layer_norm destination.
    group.add_argument('--no-persist-layer-norm', action='store_true',
                       help='Disable using persistent fused layer norm kernel. '
                       'This kernel supports only a set of hidden sizes. Please '
                       'check persist_ln_hidden_sizes if your hidden '
                       'size is supported.')
    group.add_argument('--sequence-parallel', action='store_true',
                       help='Enable sequence parallel optimization.')
    group.add_argument('--no-gradient-accumulation-fusion',
                       action='store_false',
                       help='Disable fusing gradient accumulation to weight '
                       'gradient computation of linear layers',
                       dest='gradient_accumulation_fusion')

    return parser


Mohammad's avatar
Mohammad committed
780
def _add_initialization_args(parser):
    """Attach the 'initialization' argument group to ``parser``.

    The group holds the random seed and the weight-initialization options.
    Options are registered in the order listed in ``option_specs``.

    Returns:
        The same ``parser`` so calls can be chained.
    """
    init_group = parser.add_argument_group(title='initialization')

    # (flag, keyword arguments forwarded to add_argument)
    option_specs = (
        ('--seed',
         dict(type=int, default=1234,
              help='Random seed used for python, numpy, '
                   'pytorch, and cuda.')),
        ('--data-parallel-random-init',
         dict(action='store_true',
              help='Enable random initialization of params '
                   'across data parallel ranks')),
        ('--init-method-std',
         dict(type=float, default=0.02,
              help='Standard deviation of the zero mean normal '
                   'distribution used for weight initialization.')),
        ('--init-method-xavier-uniform',
         dict(action='store_true',
              help='Enable Xavier uniform parameter initialization')),
    )
    for flag, spec in option_specs:
        init_group.add_argument(flag, **spec)

    return parser


Mohammad's avatar
Mohammad committed
798
def _add_learning_rate_args(parser):
    """Register learning-rate schedule options on ``parser``.

    Covers the initial LR, decay style and length, warmup, minimum LR, and
    whether scheduler values come from the command line or the checkpoint.

    Args:
        parser: ``argparse.ArgumentParser`` that receives a new
            'learning rate' argument group.

    Returns:
        The same ``parser`` so calls can be chained.
    """
    group = parser.add_argument_group(title='learning rate')

    group.add_argument('--lr', type=float, default=None,
                       help='Initial learning rate. Depending on decay style '
                       'and initial warmup, the learning rate at each '
                       'iteration would be different.')
    group.add_argument('--lr-decay-style', type=str, default='linear',
                       choices=['constant', 'linear', 'cosine', 'inverse-square-root'],
                       help='Learning rate decay function.')
    group.add_argument('--lr-decay-iters', type=int, default=None,
                       help='number of iterations to decay learning rate over,'
                       ' If None defaults to `--train-iters`')
    group.add_argument('--lr-decay-samples', type=int, default=None,
                       help='number of samples to decay learning rate over,'
                       ' If None defaults to `--train-samples`')
    group.add_argument('--lr-warmup-fraction', type=float, default=None,
                       help='fraction of lr-warmup-(iters/samples) to use '
                       'for warmup (as a float)')
    group.add_argument('--lr-warmup-iters', type=int, default=0,
                       help='number of iterations to linearly warmup '
                       'learning rate over.')
    group.add_argument('--lr-warmup-samples', type=int, default=0,
                       help='number of samples to linearly warmup '
                       'learning rate over.')
    group.add_argument('--warmup', type=int, default=None,
                       help='Old lr warmup argument, do not use. Use one of the '
                       '--lr-warmup-* arguments above')
    group.add_argument('--min-lr', type=float, default=0.0,
                       help='Minimum value for learning rate. The scheduler '
                       'clips values below this threshold.')
    # Flag names mix '-' and '_'; argparse maps both to
    # override_opt_param_scheduler / use_checkpoint_opt_param_scheduler.
    group.add_argument('--override-opt_param-scheduler', action='store_true',
                       help='Reset the values of the scheduler (learning rate, '
                       'warmup iterations, minimum learning rate, maximum '
                       'number of iterations, and decay style) from input '
                       'arguments and ignore values from checkpoints. Note '
                       'that all the above values will be reset.')
    group.add_argument('--use-checkpoint-opt_param-scheduler', action='store_true',
                       help='Use checkpoint to set the values of the scheduler '
                       '(learning rate, warmup iterations, minimum learning '
                       'rate, maximum number of iterations, and decay style) '
                       'from checkpoint and ignore input arguments.')

    return parser


Mohammad's avatar
Mohammad committed
844
def _add_checkpointing_args(parser):
    """Attach the 'checkpointing' argument group to ``parser``.

    Declares where checkpoints are saved to and loaded from, which state is
    skipped on save/load, and how a missing checkpoint is treated. Options
    are registered in the order listed in ``option_specs``.

    Returns:
        The same ``parser`` so calls can be chained.
    """
    ckpt_group = parser.add_argument_group(title='checkpointing')

    # (flag, keyword arguments forwarded to add_argument). The --no-save-*
    # and --no-load-* switches deliberately default to None, not False.
    option_specs = (
        ('--save', dict(type=str, default=None,
                        help='Output directory to save checkpoints to.')),
        ('--save-interval', dict(type=int, default=None,
                                 help='Number of iterations between checkpoint saves.')),
        ('--no-save-optim', dict(action='store_true', default=None,
                                 help='Do not save current optimizer.')),
        ('--no-save-rng', dict(action='store_true', default=None,
                               help='Do not save current rng state.')),
        ('--load', dict(type=str, default=None,
                        help='Directory containing a model checkpoint.')),
        ('--no-load-optim', dict(action='store_true', default=None,
                                 help='Do not load optimizer when loading checkpoint.')),
        ('--no-load-rng', dict(action='store_true', default=None,
                               help='Do not load rng state when loading checkpoint.')),
        ('--finetune', dict(action='store_true',
                            help='Load model for finetuning. Do not load optimizer '
                                 'or rng state from checkpoint and set iteration to 0. '
                                 'Assumed when loading a release checkpoint.')),
        # store_false into perform_initialization: defaults to True.
        ('--no-initialization', dict(action='store_false',
                                     help='Do not perform initialization when building model, '
                                          'can reduce startup time when definitely loading from a '
                                          'checkpoint',
                                     dest='perform_initialization')),
        ('--use-checkpoint-args', dict(action='store_true',
                                       help='Override any command line arguments with arguments '
                                            'from the checkpoint')),
        ('--exit-on-missing-checkpoint', dict(action='store_true',
                                              help="If '--load' is set, but checkpoint is not found "
                                                   "(e.g., path typo), then exit instead of random "
                                                   "initialization.")),
    )
    for flag, spec in option_specs:
        ckpt_group.add_argument(flag, **spec)

    return parser


Mohammad's avatar
Mohammad committed
881
def _add_mixed_precision_args(parser):
    """Register fp16/bf16 and loss-scaling options on ``parser``.

    Args:
        parser: ``argparse.ArgumentParser`` that receives a new
            'mixed precision' argument group.

    Returns:
        The same ``parser`` so calls can be chained.
    """
    group = parser.add_argument_group(title='mixed precision')

    group.add_argument('--fp16', action='store_true',
                       help='Run model in fp16 mode.')
    group.add_argument('--bf16', action='store_true',
                       help='Run model in bfloat16 mode.')
    group.add_argument('--loss-scale', type=float, default=None,
                       help='Static loss scaling, positive power of 2 '
                       'values can improve fp16 convergence. If None, dynamic '
                       'loss scaling is used.')
    group.add_argument('--initial-loss-scale', type=float, default=2**32,
                       help='Initial loss-scale for dynamic loss scaling.')
    group.add_argument('--min-loss-scale', type=float, default=1.0,
                       help='Minimum loss scale for dynamic loss scale.')
    group.add_argument('--loss-scale-window', type=float, default=1000,
                       help='Window over which to raise/lower dynamic scale.')
    group.add_argument('--hysteresis', type=int, default=2,
                       help='hysteresis for dynamic loss scaling')
    group.add_argument('--fp32-residual-connection', action='store_true',
                       help='Move residual connections to fp32.')
    # store_false: apply_query_key_layer_scaling defaults to True.
    group.add_argument('--no-query-key-layer-scaling', action='store_false',
                       help='Do not scale Q * K^T by 1 / layer-number.',
                       dest='apply_query_key_layer_scaling')
    group.add_argument('--attention-softmax-in-fp32', action='store_true',
                       help='Run attention masking and softmax in fp32. '
                       'This flag is ignored unless '
                       '--no-query-key-layer-scaling is specified.')
    group.add_argument('--accumulate-allreduce-grads-in-fp32',
                       action='store_true',
                       help='Gradient accumulation and all-reduce in fp32.')
    group.add_argument('--fp16-lm-cross-entropy', action='store_true',
                       help='Move the cross entropy unreduced loss calculation '
                       'for lm head to fp16.')

    return parser


Mohammad's avatar
Mohammad committed
919
def _parse_bool_arg(value):
    """Convert a command-line string to a bool for use as argparse ``type=``.

    ``type=bool`` is an argparse pitfall: ``bool('False')`` is True, so any
    non-empty value would enable the option. This accepts common spellings
    explicitly and rejects anything else.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', '1', 'yes', 'y', 't', 'on'):
        return True
    if lowered in ('false', '0', 'no', 'n', 'f', 'off'):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


def _add_distributed_args(parser):
    """Register model/pipeline parallelism and distributed-runtime options.

    Args:
        parser: ``argparse.ArgumentParser`` that receives a new
            'distributed' argument group.

    Returns:
        The same ``parser`` so calls can be chained.
    """
    group = parser.add_argument_group(title='distributed')

    group.add_argument('--tensor-model-parallel-size', type=int, default=1,
                       help='Degree of tensor model parallelism.')
    group.add_argument('--pipeline-model-parallel-size', type=int, default=1,
                       help='Degree of pipeline model parallelism.')
    group.add_argument('--pipeline-model-parallel-split-rank',
                       type=int, default=None,
                       help='Rank where encoder and decoder should be split.')
    group.add_argument('--model-parallel-size', type=int, default=None,
                       help='Old model parallel argument, do not use. Use '
                       '--tensor-model-parallel-size instead.')
    group.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None,
                       help='Number of layers per virtual pipeline stage')
    group.add_argument('--distributed-backend', default='nccl',
                       choices=['nccl', 'gloo'],
                       help='Which backend to use for distributed training.')
    group.add_argument('--distributed-timeout-minutes', type=int, default=10,
                       help='Timeout minutes for torch.distributed.')
    group.add_argument('--DDP-impl', default='local',
                       choices=['local', 'torch'],
                       help='which DistributedDataParallel implementation '
                       'to use.')
    group.add_argument('--no-contiguous-buffers-in-local-ddp',
                       action='store_false', help='If set, do not use '
                       'contiguous buffers in local DDP.',
                       dest='use_contiguous_buffers_in_local_ddp')
    group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false',
                       help='If set, do not use scatter/gather to optimize '
                       'communication of tensors in pipeline.',
                       dest='scatter_gather_tensors_in_pipeline')
    group.add_argument('--use-ring-exchange-p2p', action='store_true',
                       default=False, help='If set, use custom-built ring exchange '
                       'for p2p communications. Note that this option will require '
                       'a custom built image that supports ring-exchange p2p.')
    group.add_argument('--local_rank', type=int, default=None,
                       help='local rank passed from distributed launcher.')
    # _parse_bool_arg instead of type=bool, so '--lazy-mpu-init False'
    # really yields False. Default stays None when the flag is absent.
    group.add_argument('--lazy-mpu-init', type=_parse_bool_arg, required=False,
                       help='If set to True, initialize_megatron() '
                       'skips DDP initialization and returns function to '
                       'complete it instead. Also turns on '
                       '--use-cpu-initialization flag. This is for '
                       'external DDP manager.')
    group.add_argument('--use-cpu-initialization', action='store_true',
                       default=None, help='If set, affine parallel weights '
                       'initialization uses CPU')
    group.add_argument('--empty-unused-memory-level', default=0, type=int,
                       choices=[0, 1, 2],
                       help='Call torch.cuda.empty_cache() each iteration '
                       '(training and eval), to reduce fragmentation. '
                       '0=off, 1=moderate, 2=aggressive.')
    group.add_argument('--standalone-embedding-stage', action='store_true',
                       default=False, help='If set, *input* embedding layer '
                       'is placed on its own pipeline stage, without any '
                       'transformer layers. (For T5, this flag currently only '
                       'affects the encoder embedding.)')
    group.add_argument('--use-distributed-optimizer', action='store_true',
                       help='Use distributed optimizer.')

    return parser


Mohammad's avatar
Mohammad committed
981
def _add_validation_args(parser):
    """Register validation/evaluation arguments on *parser*.

    Args:
        parser: an ``argparse.ArgumentParser`` to extend.

    Returns:
        The same parser, with a 'validation' argument group added.
    """
    group = parser.add_argument_group(title='validation')

    group.add_argument('--eval-iters', type=int, default=100,
                       help='Number of iterations to run for evaluation on '
                       'the validation/test sets.')
    group.add_argument('--eval-interval', type=int, default=1000,
                       help='Interval between running evaluation on '
                       'validation set.')

    return parser


Mohammad's avatar
Mohammad committed
994
def _add_data_args(parser):
    """Register dataset, tokenizer and dataloader arguments on *parser*.

    Args:
        parser: an ``argparse.ArgumentParser`` to extend.

    Returns:
        The same parser, with a 'data and dataloader' argument group added.
    """
    group = parser.add_argument_group(title='data and dataloader')

    # Shared description of the (optionally weighted) multi-dataset syntax
    # accepted by every --*-data-path argument below.
    path_format = ('Accepted format: 1) a single data path, 2) multiple '
                   'datasets in the form: dataset1-weight dataset1-path '
                   'dataset2-weight dataset2-path ...')

    group.add_argument('--data-path', nargs='*', default=None,
                       help='Path to the training dataset. ' + path_format +
                       ' It is used with --split when a single dataset is '
                       'used for all three: train, valid and test. It is '
                       'exclusive to the other --*-data-path args')
    # argparse %-formats help strings, so a literal percent sign is written
    # as '%%'.
    group.add_argument('--split', type=str, default='969, 30, 1',
                       help='Comma-separated list of proportions for training,'
                       ' validation, and test split. For example the split '
                       '`90,5,5` will use 90%% of data for training, 5%% for '
                       'validation and 5%% for test.')
    group.add_argument('--train-data-path', nargs='*', default=None,
                       help='Path to the training dataset. ' + path_format)
    group.add_argument('--valid-data-path', nargs='*', default=None,
                       help='Path to the validation dataset. ' + path_format)
    group.add_argument('--test-data-path', nargs='*', default=None,
                       help='Path to the test dataset. ' + path_format)

    group.add_argument('--vocab-file', type=str, default=None,
                       help='Path to the vocab file.')
    group.add_argument('--merge-file', type=str, default=None,
                       help='Path to the BPE merge file.')
    group.add_argument('--vocab-extra-ids', type=int, default=0,
                       help='Number of additional vocabulary tokens. '
                       'They are used for span masking in the T5 model')
    group.add_argument('--seq-length', type=int, default=None,
                       help='Maximum sequence length to process.')
    group.add_argument('--encoder-seq-length', type=int, default=None,
                       help='Maximum encoder sequence length to process. '
                       'This should be exclusive of --seq-length')
    group.add_argument('--decoder-seq-length', type=int, default=None,
                       help="Maximum decoder sequence length to process.")
    group.add_argument('--retriever-seq-length', type=int, default=256,
                       help='Maximum sequence length for the biencoder model '
                       'for retriever')
    # The default of 1.0 is itself valid, so the documented range is
    # closed at the top.
    group.add_argument('--sample-rate', type=float, default=1.0,
                       help='sample rate for training data. Supposed to be '
                       '0 < sample_rate <= 1')
    group.add_argument('--mask-prob', type=float, default=0.15,
                       help='Probability of replacing a token with mask.')
    group.add_argument('--short-seq-prob', type=float, default=0.1,
                       help='Probability of producing a short sequence.')
    group.add_argument('--mmap-warmup', action='store_true',
                       help='Warm up mmap files.')
    group.add_argument('--num-workers', type=int, default=2,
                       help="Dataloader number of workers.")
    group.add_argument('--tokenizer-type', type=str,
                       default=None,
                       choices=['BertWordPieceLowerCase',
                                'BertWordPieceCase',
                                'GPT2BPETokenizer',
                                'SentencePieceTokenizer'],
                       help='What type of tokenizer to use.')
    group.add_argument('--tokenizer-model', type=str, default=None,
                       help='Sentencepiece tokenizer model.')
    group.add_argument('--data-impl', type=str, default='infer',
                       choices=['lazy', 'cached', 'mmap', 'infer'],
                       help='Implementation of indexed datasets.')
    group.add_argument('--reset-position-ids', action='store_true',
                       help='Reset position ids after end-of-document token.')
    group.add_argument('--reset-attention-mask', action='store_true',
                       help='Reset self attention mask after '
                       'end-of-document token.')
    group.add_argument('--eod-mask-loss', action='store_true',
                       help='Mask loss for the end of document tokens.')

    return parser

Raul Puri's avatar
Raul Puri committed
1076

Mohammad's avatar
Mohammad committed
1077
1078
def _add_autoresume_args(parser):
    """Register ADLR autoresume arguments on *parser*.

    Args:
        parser: an ``argparse.ArgumentParser`` to extend.

    Returns:
        The same parser, with an 'autoresume' argument group added.
    """
    group = parser.add_argument_group(title='autoresume')

    group.add_argument('--adlr-autoresume', action='store_true',
                       help='Enable autoresume on adlr cluster.')
    group.add_argument('--adlr-autoresume-interval', type=int, default=1000,
                       help='Interval at which to check for the autoresume '
                       'termination signal')

    return parser
Neel Kant's avatar
Neel Kant committed
1087
1088


Mostofa Patwary's avatar
Mostofa Patwary committed
1089
1090
def _add_biencoder_args(parser):
    """Register biencoder / ICT / retriever arguments on *parser*.

    The options are grouped into network size, checkpointing, data,
    training, faiss index and indexer settings.

    Args:
        parser: an ``argparse.ArgumentParser`` to extend.

    Returns:
        The same parser, with a 'biencoder' argument group added.
    """
    group = parser.add_argument_group(title='biencoder')

    # network size
    group.add_argument('--ict-head-size', type=int, default=None,
                       help='Size of block embeddings to be used in ICT and '
                       'REALM (paper default: 128)')
    group.add_argument('--biencoder-projection-dim', type=int, default=0,
                       help='Size of projection head used in biencoder (paper'
                       ' default: 128)')
    group.add_argument('--biencoder-shared-query-context-model',
                       action='store_true',
                       help='Whether to share the parameters of the query '
                       'and context models or not')

    # checkpointing
    group.add_argument('--ict-load', type=str, default=None,
                       help='Directory containing an ICTBertModel checkpoint')
    group.add_argument('--bert-load', type=str, default=None,
                       help='Directory containing a BertModel checkpoint '
                       '(needed to start ICT and REALM)')

    # data
    group.add_argument('--titles-data-path', type=str, default=None,
                       help='Path to titles dataset used for ICT')
    group.add_argument('--query-in-block-prob', type=float, default=0.1,
                       help='Probability of keeping query in block for '
                       'ICT dataset')
    group.add_argument('--use-one-sent-docs', action='store_true',
                       help='Whether to use one sentence documents in ICT')
    group.add_argument('--evidence-data-path', type=str, default=None,
                       help='Path to Wikipedia Evidence from DPR paper')

    # training
    group.add_argument('--retriever-report-topk-accuracies', nargs='+',
                       type=int, default=[],
                       help="Which top-k accuracies to report "
                       "(e.g. '1 5 20')")
    group.add_argument('--retriever-score-scaling', action='store_true',
                       help='Whether to scale retriever scores by inverse '
                       'square root of hidden size')

    # faiss index
    group.add_argument('--block-data-path', type=str, default=None,
                       help='Where to save/load BlockData to/from')
    group.add_argument('--embedding-path', type=str, default=None,
                       help='Where to save/load Open-Retrieval Embedding'
                       ' data to/from')

    # indexer
    group.add_argument('--indexer-batch-size', type=int, default=128,
                       help='How large of batches to use when doing indexing '
                       'jobs')
    group.add_argument('--indexer-log-interval', type=int, default=1000,
                       help='After how many batches should the indexer '
                       'report progress')
    return parser
1144
1145


1146
1147
def _add_vision_args(parser):
    """Register vision-model arguments on *parser*.

    Covers general image/classification settings, pretraining-objective and
    backbone selection, inpainting mask options and DINO hyperparameters.

    Args:
        parser: an ``argparse.ArgumentParser`` to extend.

    Returns:
        The same parser, with a 'vision' argument group added.
    """
    group = parser.add_argument_group(title="vision")

    # general vision arguments
    group.add_argument('--num-classes', type=int, default=1000,
                       help='num of classes in vision classification task')
    group.add_argument('--img-h', type=int, default=224,
                       help='Image height for vision classification task')
    group.add_argument('--img-w', type=int, default=224,
                       help='Image width for vision classification task')
    group.add_argument('--num-channels', type=int, default=3,
                       help='Number of channels in input image data')
    group.add_argument('--patch-dim', type=int, default=16,
                       help='patch dimension')
    group.add_argument('--classes-fraction', type=float, default=1.0,
                       help='training with fraction of classes.')
    group.add_argument('--data-per-class-fraction', type=float, default=1.0,
                       help='training with fraction of data per class.')
    # store_false: args.data_sharding defaults to True and this flag turns
    # it off.
    group.add_argument('--no-data-sharding', action='store_false',
                       help='Disable data sharding.',
                       dest='data_sharding')
    group.add_argument('--head-lr-mult', type=float, default=1.0,
                       help='learning rate multiplier for head during finetuning')

    # pretraining type and backbone selection
    group.add_argument('--vision-pretraining', action='store_true',
                       help='flag to indicate vision pretraining')
    group.add_argument('--vision-pretraining-type', type=str, default='classify',
                       choices=['classify', 'inpaint', 'dino'],
                       help='pretraining objectives')
    group.add_argument('--vision-backbone-type', type=str, default='vit',
                       choices=['vit', 'mit', 'swin'],
                       help='backbone types')
    group.add_argument('--swin-backbone-type', type=str, default='tiny',
                       choices=['tiny', 'base', 'h3'],
                       help='swin backbone variant')

    # inpainting arguments
    group.add_argument('--mask-type', type=str, default='random',
                       choices=['random', 'row'],
                       help='mask types')
    group.add_argument('--mask-factor', type=float, default=1.0,
                       help='mask size scaling parameter')

    # dino arguments
    group.add_argument('--iter-per-epoch', type=int, default=1250,
                       help='iterations per epoch')
    group.add_argument('--dino-local-img-size', type=int, default=96,
                       help='Image size of DINO local crops')
    group.add_argument('--dino-local-crops-number', type=int, default=10,
                       help='Number of local crops')
    group.add_argument('--dino-head-hidden-size', type=int, default=2048,
                       help='Hidden dimension size in dino head')
    group.add_argument('--dino-bottleneck-size', type=int, default=256,
                       help='Bottleneck dimension in dino head')
    group.add_argument('--dino-freeze-last-layer', type=float, default=1,
                       help='Freezing last layer weights')
    # NOTE(review): this is a store_true flag (dino_norm_last_layer defaults
    # to False and passing the flag sets it to True), yet the help text says
    # "Disable". Confirm the intended polarity against the DINO head code.
    group.add_argument('--dino-norm-last-layer', action='store_true',
                       help='Disable Norm in last layer.')
    group.add_argument('--dino-warmup-teacher-temp', type=float, default=0.04,
                       help='warmup teacher temperature')
    group.add_argument('--dino-teacher-temp', type=float, default=0.07,
                       help='teacher temperature')
    group.add_argument('--dino-warmup-teacher-temp-epochs', type=int, default=30,
                       help='warmup teacher temperature epochs')

    return parser