arguments.py 51 KB
Newer Older
Raul Puri's avatar
Raul Puri committed
1
# coding=utf-8
Mohammad's avatar
Mohammad committed
2
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
Raul Puri's avatar
Raul Puri committed
3
4
5
6
7
8
9
10
11
12
13
14
15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Mohammad's avatar
Mohammad committed
16
"""Megatron arguments."""
Raul Puri's avatar
Raul Puri committed
17
18
19
20

import argparse
import os

21
import torch
Raul Puri's avatar
Raul Puri committed
22

23
def parse_args(extra_args_provider=None, ignore_unknown_args=False):
    """Build the Megatron argument parser and parse the command line.

    Args:
        extra_args_provider: optional callable that receives the parser,
            registers extra arguments on it and returns it.
        ignore_unknown_args: when True, unrecognized command-line arguments
            are discarded instead of making the parser exit with an error.

    Returns:
        The parsed ``argparse.Namespace``, extended with ``rank`` and
        ``world_size`` read from the RANK / WORLD_SIZE environment variables
        (defaulting to 0 and 1 respectively).
    """
    parser = argparse.ArgumentParser(description='Megatron-LM Arguments',
                                     allow_abbrev=False)

    # Standard argument groups. They are registered in this fixed order,
    # which is also the order the groups are listed in --help.
    standard_group_adders = (
        _add_network_size_args,
        _add_regularization_args,
        _add_training_args,
        _add_initialization_args,
        _add_learning_rate_args,
        _add_checkpointing_args,
        _add_mixed_precision_args,
        _add_distributed_args,
        _add_validation_args,
        _add_data_args,
        _add_autoresume_args,
        _add_biencoder_args,
        _add_vision_args,
        _add_logging_args,
        _add_inference_args,
    )
    for add_group in standard_group_adders:
        parser = add_group(parser)

    # Caller-supplied (e.g. model-specific) arguments.
    if extra_args_provider is not None:
        parser = extra_args_provider(parser)

    # parse_known_args returns (namespace, leftover_args); leftovers are dropped.
    if ignore_unknown_args:
        args = parser.parse_known_args()[0]
    else:
        args = parser.parse_args()

    # Process placement comes from the launcher's environment.
    args.rank = int(os.environ.get('RANK', '0'))
    args.world_size = int(os.environ.get("WORLD_SIZE", '1'))

    return args

def validate_args(args, defaults=None):
    """Validate and complete the parsed Megatron arguments, in place.

    Enforces consistency between related options (parallel sizes, batch
    sizes, precision, iteration- vs sample-based schedules, activation
    recompute), translates deprecated flags into their replacements, fills
    in values derived from other arguments, and finally prints the argument
    table (``_print_args`` only prints on rank 0).

    Args:
        args: namespace produced by ``parse_args``; it is modified in place.
        defaults: optional mapping from argument name to a fallback value.
            A fallback is applied only when the argument is ``None`` in
            ``args``; otherwise the command-line value is kept and rank 0
            prints a warning.

    Returns:
        The same ``args`` object.

    Raises:
        AssertionError: when a consistency check fails.
    """
    # Avoid a shared mutable default argument.
    if defaults is None:
        defaults = {}

    # Tensor model parallel size (capped at the world size).
    args.tensor_model_parallel_size = min(
        args.tensor_model_parallel_size, args.world_size)
    assert args.world_size % args.tensor_model_parallel_size == 0, 'world size'\
        ' ({}) is not divisible by tensor model parallel size ({})'.format(
            args.world_size, args.tensor_model_parallel_size)
    # Pipeline model parallel size (capped at the ranks left per TP group).
    args.pipeline_model_parallel_size = min(
        args.pipeline_model_parallel_size,
        (args.world_size // args.tensor_model_parallel_size))
    # Pipeline stages that hold transformer layers: one stage fewer when
    # args.standalone_embedding_stage is set (presumably the stage that
    # holds only the embedding).
    args.transformer_pipeline_model_parallel_size = (
        args.pipeline_model_parallel_size - 1
        if args.standalone_embedding_stage else
        args.pipeline_model_parallel_size
    )
    # Checks.
    model_parallel_size = args.pipeline_model_parallel_size * \
                          args.tensor_model_parallel_size
    # One placeholder per formatted value so each number is reported in the
    # right slot (str.format silently ignores surplus arguments).
    assert args.world_size % model_parallel_size == 0, 'world size ({}) is ' \
        'not divisible by tensor parallel size ({}) times pipeline parallel ' \
        'size ({})'.format(args.world_size, args.tensor_model_parallel_size,
                           args.pipeline_model_parallel_size)
    args.data_parallel_size = args.world_size // model_parallel_size
    if args.rank == 0:
        print('using world size: {}, data-parallel-size: {}, '
              'tensor-model-parallel size: {}, '
              'pipeline-model-parallel size: {} '.format(
                  args.world_size, args.data_parallel_size,
                  args.tensor_model_parallel_size,
                  args.pipeline_model_parallel_size), flush=True)
    if args.pipeline_model_parallel_size > 1:
        if args.pipeline_model_parallel_split_rank is not None:
            assert args.pipeline_model_parallel_split_rank < \
                    args.pipeline_model_parallel_size, 'split rank needs'\
                    ' to be less than pipeline model parallel size ({})'.format(
                            args.pipeline_model_parallel_size)

    # Deprecated arguments
    assert args.batch_size is None, '--batch-size argument is no longer ' \
        'valid, use --micro-batch-size instead'
    del args.batch_size
    assert args.warmup is None, '--warmup argument is no longer valid, use ' \
        '--lr-warmup-fraction instead'
    del args.warmup
    assert args.model_parallel_size is None, '--model-parallel-size is no ' \
        'longer valid, use --tensor-model-parallel-size instead'
    del args.model_parallel_size

    # --checkpoint-activations maps onto full/uniform recomputation.
    if args.checkpoint_activations:
        args.recompute_granularity = 'full'
        args.recompute_method = 'uniform'
        if args.rank == 0:
            print('--checkpoint-activations is no longer valid, '
                  'use --recompute-granularity and --recompute-method  instead. '
                  'Defaulting to recompute-granularity=full and recompute-method=uniform.')
    del args.checkpoint_activations

    # --recompute-activations maps onto selective recomputation.
    if args.recompute_activations:
        args.recompute_granularity = 'selective'
    del args.recompute_activations

    # Set input defaults.
    for key, default_value in defaults.items():
        # For default to be valid, it should not be provided in the
        # arguments that are passed to the program. We check this by
        # ensuring the arg is set to None.
        current_value = getattr(args, key)
        if current_value is not None:
            if args.rank == 0:
                print('WARNING: overriding default arguments for '
                      '{key}:{v} with {key}:{v2}'.format(
                          key=key, v=default_value, v2=current_value),
                      flush=True)
        else:
            setattr(args, key, default_value)

    # Batch size.
    assert args.micro_batch_size is not None
    assert args.micro_batch_size > 0
    if args.global_batch_size is None:
        args.global_batch_size = args.micro_batch_size * args.data_parallel_size
        if args.rank == 0:
            print('setting global batch size to {}'.format(
                args.global_batch_size), flush=True)
    assert args.global_batch_size > 0
    if args.num_layers_per_virtual_pipeline_stage is not None:
        assert args.pipeline_model_parallel_size > 2, \
            'pipeline-model-parallel size should be greater than 2 with ' \
            'interleaved schedule'
        assert args.num_layers % args.num_layers_per_virtual_pipeline_stage == 0, \
            'number of layers is not divisible by number of layers per virtual ' \
            'pipeline stage'
        args.virtual_pipeline_model_parallel_size = \
            (args.num_layers // args.transformer_pipeline_model_parallel_size) // \
            args.num_layers_per_virtual_pipeline_stage
    else:
        args.virtual_pipeline_model_parallel_size = None

    # Parameters dtype.
    args.params_dtype = torch.float
    if args.fp16:
        assert not args.bf16
        args.params_dtype = torch.half
    if args.bf16:
        assert not args.fp16
        args.params_dtype = torch.bfloat16
        # bfloat16 requires gradient accumulation and all-reduce to
        # be done in fp32.
        if not args.accumulate_allreduce_grads_in_fp32:
            args.accumulate_allreduce_grads_in_fp32 = True
            if args.rank == 0:
                print('accumulate and all-reduce gradients in fp32 for '
                      'bfloat16 data type.', flush=True)

    if args.rank == 0:
        print('using {} for parameters ...'.format(args.params_dtype),
              flush=True)

    # If we do accumulation and all-reduces in fp32, we need to have local DDP
    # and we should make sure use-contiguous-buffers-in-local-ddp is not off.
    if args.accumulate_allreduce_grads_in_fp32:
        assert args.DDP_impl == 'local'
        assert args.use_contiguous_buffers_in_local_ddp
    else:
        if args.gradient_accumulation_fusion:
            args.gradient_accumulation_fusion = False
            if args.rank == 0:
                print('Gradient accumulation fusion to linear layer weight '
                      'gradient computation is supported only with fp32 '
                      'gradient accumulation. Setting gradient_accumulation_fusion '
                      'to False', flush=True)

    # If we use the distributed optimizer, we need to have local DDP
    # and we should make sure use-contiguous-buffers-in-local-ddp is on.
    if args.use_distributed_optimizer:
        assert args.DDP_impl == 'local'
        assert args.use_contiguous_buffers_in_local_ddp

    # For torch DDP, we do not use contiguous buffer
    if args.DDP_impl == 'torch':
        args.use_contiguous_buffers_in_local_ddp = False

    if args.dataloader_type is None:
        args.dataloader_type = 'single'

    # Consumed tokens.
    args.consumed_train_samples = 0
    args.consumed_valid_samples = 0

    # Iteration-based training.
    if args.train_iters:
        # If we use iteration-based training, make sure the
        # sample-based options are off.
        assert args.train_samples is None, \
            'expected iteration-based training'
        assert args.lr_decay_samples is None, \
            'expected iteration-based learning rate decay'
        assert args.lr_warmup_samples == 0, \
            'expected iteration-based learning rate warmup'
        assert args.rampup_batch_size is None, \
            'expected no batch-size rampup for iteration-based training'
        if args.lr_warmup_fraction is not None:
            assert args.lr_warmup_iters == 0, \
                'can only specify one of lr-warmup-fraction and lr-warmup-iters'

    # Sample-based training.
    if args.train_samples:
        # If we use sample-based training, make sure the
        # iteration-based options are off.
        assert args.train_iters is None, \
            'expected sample-based training'
        assert args.lr_decay_iters is None, \
            'expected sample-based learning rate decay'
        assert args.lr_warmup_iters == 0, \
            'expected sample-based learning rate warmup'
        if args.lr_warmup_fraction is not None:
            assert args.lr_warmup_samples == 0, \
                'can only specify one of lr-warmup-fraction ' \
                'and lr-warmup-samples'

    # Check required arguments.
    required_args = ['num_layers', 'hidden_size', 'num_attention_heads',
                     'max_position_embeddings']
    for req_arg in required_args:
        _check_arg_is_not_none(args, req_arg)

    # Checks.
    if args.ffn_hidden_size is None:
        args.ffn_hidden_size = 4 * args.hidden_size

    if args.kv_channels is None:
        assert args.hidden_size % args.num_attention_heads == 0
        args.kv_channels = args.hidden_size // args.num_attention_heads

    # seq_length and encoder_seq_length are aliases; exactly one is given.
    if args.seq_length is not None:
        assert args.encoder_seq_length is None
        args.encoder_seq_length = args.seq_length
    else:
        assert args.encoder_seq_length is not None
        args.seq_length = args.encoder_seq_length

    if args.seq_length is not None:
        assert args.max_position_embeddings >= args.seq_length
    if args.decoder_seq_length is not None:
        assert args.max_position_embeddings >= args.decoder_seq_length
    if args.lr is not None:
        assert args.min_lr <= args.lr
    if args.save is not None:
        assert args.save_interval is not None
    # Mixed precision checks.
    if args.fp16_lm_cross_entropy:
        assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.'
    if args.fp32_residual_connection:
        assert args.fp16 or args.bf16, \
            'residual connection in fp32 only supported when using fp16 or bf16.'

    # Weight decay schedule: 'constant' pins start/end to --weight-decay,
    # any other style needs both endpoints explicitly.
    if args.weight_decay_incr_style == 'constant':
        assert args.start_weight_decay is None
        assert args.end_weight_decay is None
        args.start_weight_decay = args.weight_decay
        args.end_weight_decay = args.weight_decay
    else:
        assert args.start_weight_decay is not None
        assert args.end_weight_decay is not None

    TORCH_MAJOR = int(torch.__version__.split('.')[0])
    TORCH_MINOR = int(torch.__version__.split('.')[1])
    # Compare (major, minor) as a tuple so that e.g. 2.0 is correctly
    # treated as newer than 1.10 / 1.11.
    torch_version = (TORCH_MAJOR, TORCH_MINOR)
    # Persistent fused layer norm.
    if torch_version < (1, 11):
        args.no_persist_layer_norm = True
        if args.rank == 0:
            print('Persistent fused layer norm kernel is supported from '
                  'pytorch v1.11 (nvidia pytorch container paired with v1.11). '
                  'Defaulting to no_persist_layer_norm=True')

    # Activation recomputing.
    if args.distribute_saved_activations:
        assert args.tensor_model_parallel_size > 1, 'can distribute ' \
            'recomputed activations only across tensor model ' \
            'parallel groups'
        assert args.recompute_granularity == 'full', \
            'distributed recompute activations is only '\
            'applicable to full recompute granularity'
        assert args.recompute_method is not None, \
            'for distributed recompute activations to work you '\
            'need to use a recompute method '
        assert torch_version >= (1, 10), \
            'distributed recompute activations are supported for pytorch ' \
            'v1.10 and above (Nvidia Pytorch container >= 21.07). Current ' \
            'pytorch version is v%s.%s.' % (TORCH_MAJOR, TORCH_MINOR)

    if args.recompute_granularity == 'selective':
        assert args.recompute_method is None, \
            'recompute method is not yet supported for ' \
            'selective recomputing granularity'

    # disable sequence parallelism when tp=1
    # to avoid change in numerics when
    # sequence_parallelism is enabled.
    if args.tensor_model_parallel_size == 1:
        args.sequence_parallel = False

    # disable async_tensor_model_parallel_allreduce when
    # model parallel memory optimization is enabled
    if args.sequence_parallel:
        args.async_tensor_model_parallel_allreduce = False

    _print_args(args)
    return args
Mohammad's avatar
Mohammad committed
330
331


Mohammad's avatar
Mohammad committed
332
333
334
def _print_args(args):
    """Print arguments."""
    if args.rank == 0:
mohammad's avatar
mohammad committed
335
336
        print('------------------------ arguments ------------------------',
              flush=True)
Mohammad's avatar
Mohammad committed
337
338
        str_list = []
        for arg in vars(args):
mohammad's avatar
mohammad committed
339
            dots = '.' * (48 - len(arg))
Mohammad's avatar
Mohammad committed
340
341
342
            str_list.append('  {} {} {}'.format(arg, dots, getattr(args, arg)))
        for arg in sorted(str_list, key=lambda x: x.lower()):
            print(arg, flush=True)
mohammad's avatar
mohammad committed
343
344
        print('-------------------- end of arguments ---------------------',
              flush=True)
Mohammad's avatar
Mohammad committed
345
346


347
348
349
350
def _check_arg_is_not_none(args, arg):
    assert getattr(args, arg) is not None, '{} argument is None'.format(arg)


mshoeybi's avatar
mshoeybi committed
351
352
353
354
355
356
357
358
359
360
361
362
def _add_inference_args(parser):
    group = parser.add_argument_group(title='inference')

    group.add_argument('--inference-batch-times-seqlen-threshold',
                       type=int, default=512,
                       help='During inference, if batch-size times '
                       'sequence-length is smaller than this threshold '
                       'then we will not use pipelining, otherwise we will.')

    return parser

    
Mohammad's avatar
Mohammad committed
363
def _add_network_size_args(parser):
Mohammad's avatar
Mohammad committed
364
    group = parser.add_argument_group(title='network size')
Mohammad's avatar
Mohammad committed
365

366
    group.add_argument('--num-layers', type=int, default=None,
Mohammad's avatar
Mohammad committed
367
                       help='Number of transformer layers.')
368
    group.add_argument('--hidden-size', type=int, default=None,
Mohammad's avatar
Mohammad committed
369
                       help='Tansformer hidden size.')
370
    group.add_argument('--ffn-hidden-size', type=int, default=None,
371
372
                       help='Transformer Feed-Forward Network hidden size. '
                       'This is set to 4*hidden-size if not provided')
373
    group.add_argument('--num-attention-heads', type=int, default=None,
Mohammad's avatar
Mohammad committed
374
                       help='Number of transformer attention heads.')
375
    group.add_argument('--kv-channels', type=int, default=None,
376
377
378
379
                       help='Projection weights dimension in multi-head '
                       'attention. This is set to '
                       '   args.hidden_size // args.num_attention_heads '
                       'if not provided.')
380
    group.add_argument('--max-position-embeddings', type=int, default=None,
Mohammad's avatar
Mohammad committed
381
382
383
384
385
                       help='Maximum number of position embeddings to use. '
                       'This is the size of position embedding.')
    group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
                       help='Pad the vocab size to be divisible by this value.'
                       'This is added for computational efficieny reasons.')
Mohammad's avatar
Mohammad committed
386
387
    group.add_argument('--layernorm-epsilon', type=float, default=1e-5,
                       help='Layer norm epsilon.')
Mohammad's avatar
Mohammad committed
388
389
390
391
    group.add_argument('--apply-residual-connection-post-layernorm',
                       action='store_true',
                       help='If set, use original BERT residula connection '
                       'ordering.')
392
393
394
395
    group.add_argument('--openai-gelu', action='store_true',
                       help='Use OpenAIs GeLU implementation. This option'
                       'should not be used unless for backward compatibility'
                       'reasons.')
396
    group.add_argument('--onnx-safe', type=bool, required=False,
397
398
                       help='Use workarounds for known problems with '
                       'Torch ONNX exporter')
399
400
401
    group.add_argument('--bert-no-binary-head', action='store_false',
                       help='Disable BERT binary head.',
                       dest='bert_binary_head')
rprenger's avatar
rprenger committed
402
403
    group.add_argument('--num-experts', type=int, default=None,
                       help='Number of Experts in Switch Transformer (None means no Switch)')
Mohammad's avatar
Mohammad committed
404
405
406
    return parser


407
408
409
410
411
def _add_logging_args(parser):
    group = parser.add_argument_group(title='logging')

    group.add_argument('--log-params-norm', action='store_true',
                       help='If set, calculate and log parameters norm.')
412
    group.add_argument('--log-num-zeros-in-grad', action='store_true',
Rewon Child's avatar
Rewon Child committed
413
                       help='If set, calculate and log the number of zeros in gradient.')
414
415
    group.add_argument('--tensorboard-log-interval', type=int, default=1,
                       help='Report to tensorboard interval.')
416
417
418
419
    group.add_argument('--tensorboard-queue-size', type=int, default=1000,
                       help='Size of the tensorboard queue for pending events '
                       'and summaries before one of the ‘add’ calls forces a '
                       'flush to disk.')
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
    group.add_argument('--log-timers-to-tensorboard', action='store_true',
                       help='If set, write timers to tensorboard.')
    group.add_argument('--log-batch-size-to-tensorboard', action='store_true',
                       help='If set, write batch-size to tensorboard.')
    group.add_argument('--no-log-learnig-rate-to-tensorboard',
                       action='store_false',
                       help='Disable learning rate logging to tensorboard.',
                       dest='log_learning_rate_to_tensorboard')
    group.add_argument('--no-log-loss-scale-to-tensorboard',
                       action='store_false',
                       help='Disable loss-scale logging to tensorboard.',
                       dest='log_loss_scale_to_tensorboard')
    group.add_argument('--log-validation-ppl-to-tensorboard',
                       action='store_true',
                       help='If set, write validation perplexity to '
                       'tensorboard.')
436
437
    group.add_argument('--log-memory-to-tensorboard',
                       action='store_true',
438
                       help='Enable memory logging to tensorboard.')
439
440
441
    group.add_argument('--log-world-size-to-tensorboard',
                       action='store_true',
                       help='Enable world size logging to tensorboard.')
442
443
444
445

    return parser


Mohammad's avatar
Mohammad committed
446
def _add_regularization_args(parser):
Mohammad's avatar
Mohammad committed
447
448
449
    group = parser.add_argument_group(title='regularization')

    group.add_argument('--attention-dropout', type=float, default=0.1,
450
                       help='Post attention dropout probability.')
Mohammad's avatar
Mohammad committed
451
452
453
454
    group.add_argument('--hidden-dropout', type=float, default=0.1,
                       help='Dropout probability for hidden state transformer.')
    group.add_argument('--weight-decay', type=float, default=0.01,
                       help='Weight decay coefficient for L2 regularization.')
Vijay Korthikanti's avatar
Vijay Korthikanti committed
455
    group.add_argument('--start-weight-decay', type=float,
456
                       help='Initial weight decay coefficient for L2 regularization.')
Vijay Korthikanti's avatar
Vijay Korthikanti committed
457
    group.add_argument('--end-weight-decay', type=float,
458
                       help='End of run weight decay coefficient for L2 regularization.')
Vijay Korthikanti's avatar
Vijay Korthikanti committed
459
    group.add_argument('--weight-decay-incr-style', type=str, default='constant',
460
461
                       choices=['constant', 'linear', 'cosine'],
                       help='Weight decay increment function.')
Mohammad's avatar
Mohammad committed
462
463
    group.add_argument('--clip-grad', type=float, default=1.0,
                       help='Gradient clipping based on global L2 norm.')
464
    group.add_argument('--adam-beta1', type=float, default=0.9,
465
466
                       help='First coefficient for computing running averages '
                       'of gradient and its square')
467
    group.add_argument('--adam-beta2', type=float, default=0.999,
468
469
                       help='Second coefficient for computing running averages '
                       'of gradient and its square')
470
    group.add_argument('--adam-eps', type=float, default=1e-08,
471
                       help='Term added to the denominator to improve'
472
                       'numerical stability')
473
474
    group.add_argument('--sgd-momentum', type=float, default=0.9,
                       help='Momentum factor for sgd')
Mohammad's avatar
Mohammad committed
475
476
477

    return parser

Mohammad's avatar
Mohammad committed
478
479

def _add_training_args(parser):
Mohammad's avatar
Mohammad committed
480
481
    group = parser.add_argument_group(title='training')

482
    group.add_argument('--micro-batch-size', type=int, default=None,
Mohammad's avatar
Mohammad committed
483
484
                       help='Batch size per model instance (local batch size). '
                       'Global batch size is local batch size times data '
mohammad's avatar
mohammad committed
485
                       'parallel size times number of micro batches.')
486
487
488
    group.add_argument('--batch-size', type=int, default=None,
                       help='Old batch size parameter, do not use. '
                       'Use --micro-batch-size instead')
mohammad's avatar
mohammad committed
489
    group.add_argument('--global-batch-size', type=int, default=None,
mohammad's avatar
mohammad committed
490
491
492
                       help='Training batch size. If set, it should be a '
                       'multiple of micro-batch-size times data-parallel-size. '
                       'If this value is None, then '
mohammad's avatar
mohammad committed
493
                       'use micro-batch-size * data-parallel-size as the '
mohammad's avatar
mohammad committed
494
495
                       'global batch size. This choice will result in 1 for '
                       'number of micro-batches.')
mohammad's avatar
mohammad committed
496
497
498
499
500
501
502
503
504
505
506
507
    group.add_argument('--rampup-batch-size', nargs='*', default=None,
                       help='Batch size ramp up with the following values:'
                       '  --rampup-batch-size <start batch size> '
                       '                      <batch size incerement> '
                       '                      <ramp-up samples> '
                       'For example:'
                       '   --rampup-batch-size 16 8 300000 \ '
                       '   --global-batch-size 1024'
                       'will start with global batch size 16 and over '
                       ' (1024 - 16) / 8 = 126 intervals will increase'
                       'the batch size linearly to 1024. In each interval'
                       'we will use approximately 300000 / 126 = 2380 samples.')
Vijay Korthikanti's avatar
Vijay Korthikanti committed
508
509
    group.add_argument('--recompute-activations', action='store_true',
                       help='recompute activation to allow for training '
Mohammad's avatar
Mohammad committed
510
                       'with larger models, sequences, and batch sizes.')
Vijay Korthikanti's avatar
Vijay Korthikanti committed
511
    group.add_argument('--recompute-granularity', type=str, default=None,
Vijay Korthikanti's avatar
Vijay Korthikanti committed
512
                       choices=['full', 'selective'],
Vijay Korthikanti's avatar
Vijay Korthikanti committed
513
                       help='Checkpoint activations to allow for training '
Vijay Korthikanti's avatar
Vijay Korthikanti committed
514
515
                       'with larger models, sequences, and batch sizes. '
                       'It is supported at two granularities 1) full: '
Vijay Korthikanti's avatar
Vijay Korthikanti committed
516
                       'whole transformer layer is recomputed, '
Vijay Korthikanti's avatar
Vijay Korthikanti committed
517
                       '2) selective: core attention part of the transformer '
Vijay Korthikanti's avatar
Vijay Korthikanti committed
518
                       'layer is recomputed.')
Vijay Korthikanti's avatar
Vijay Korthikanti committed
519
    group.add_argument('--distribute-saved-activations',
520
                       action='store_true',
Vijay Korthikanti's avatar
Vijay Korthikanti committed
521
                       help='If set, distribute recomputed activations '
522
                       'across model parallel group.')
Vijay Korthikanti's avatar
Vijay Korthikanti committed
523
    group.add_argument('--recompute-method', type=str, default=None,
524
525
                       choices=['uniform', 'block'],
                       help='1) uniform: uniformly divide the total number of '
Vijay Korthikanti's avatar
Vijay Korthikanti committed
526
                       'Transformer layers and recompute the input activation of '
Vijay Korthikanti's avatar
Vijay Korthikanti committed
527
                       'each divided chunk at specified granularity, '
Vijay Korthikanti's avatar
Vijay Korthikanti committed
528
                       '2) recompute the input activations of only a set number of '
slym's avatar
slym committed
529
                       'individual Transformer layers per pipeline stage and do the '
Vijay Korthikanti's avatar
Vijay Korthikanti committed
530
531
532
                       'rest without any recomputing at specified granularity'
                       'default) do not apply activations recompute to any layers')
    group.add_argument('--recompute-num-layers', type=int, default=1,
533
                       help='1) uniform: the number of Transformer layers in each '
Vijay Korthikanti's avatar
Vijay Korthikanti committed
534
                       'uniformly divided recompute unit, '
535
                       '2) block: the number of individual Transformer layers '
Vijay Korthikanti's avatar
Vijay Korthikanti committed
536
                       'to recompute within each pipeline stage.')
Vijay Korthikanti's avatar
Vijay Korthikanti committed
537
538
539
540
541

    # deprecated
    group.add_argument('--checkpoint-activations', action='store_true',
                       help='Checkpoint activation to allow for training '
                       'with larger models, sequences, and batch sizes.')
Mohammad's avatar
Mohammad committed
542
    group.add_argument('--train-iters', type=int, default=None,
Mohammad's avatar
Mohammad committed
543
                       help='Total number of iterations to train over all '
544
545
546
547
548
549
                       'training runs. Note that either train-iters or '
                       'train-samples should be provided.')
    group.add_argument('--train-samples', type=int, default=None,
                       help='Total number of samples to train over all '
                       'training runs. Note that either train-iters or '
                       'train-samples should be provided.')
Mohammad's avatar
Mohammad committed
550
551
552
553
554
    group.add_argument('--log-interval', type=int, default=100,
                       help='Report loss and timing interval.')
    group.add_argument('--exit-interval', type=int, default=None,
                       help='Exit the program after the iteration is divisible '
                       'by this value.')
555
556
    group.add_argument('--exit-duration-in-mins', type=int, default=None,
                       help='Exit the program after this many minutes.')
557
558
559
    group.add_argument('--exit-signal-handler', action='store_true',
                       help='Dynamically save the checkpoint and shutdown the '
                       'training if SIGTERM is received')
Mohammad's avatar
Mohammad committed
560
561
    group.add_argument('--tensorboard-dir', type=str, default=None,
                       help='Write TensorBoard logs to this directory.')
562
    group.add_argument('--no-masked-softmax-fusion',
563
564
565
                       action='store_false',
                       help='Disable fusion of query_key_value scaling, '
                       'masking, and softmax.',
566
                       dest='masked_softmax_fusion')
567
568
569
570
571
572
    group.add_argument('--no-bias-gelu-fusion', action='store_false',
                       help='Disable bias and gelu fusion.',
                       dest='bias_gelu_fusion')
    group.add_argument('--no-bias-dropout-fusion', action='store_false',
                       help='Disable bias and dropout fusion.',
                       dest='bias_dropout_fusion')
573
574
575
    group.add_argument('--optimizer', type=str, default='adam',
                       choices=['adam', 'sgd'],
                       help='Optimizer function')
576
    group.add_argument('--dataloader-type', type=str, default=None,
Vijay Korthikanti's avatar
Vijay Korthikanti committed
577
578
                       choices=['single', 'cyclic'],
                       help='Single pass vs multiple pass data loader')
slym's avatar
slym committed
579
    group.add_argument('--no-async-tensor-model-parallel-allreduce',
Sangkug Lym's avatar
Sangkug Lym committed
580
                       action='store_false',
slym's avatar
slym committed
581
582
                       help='Disable asynchronous execution of '
                       'tensor-model-parallel all-reduce with weight '
Sangkug Lym's avatar
Sangkug Lym committed
583
584
                       'gradient compuation of a column-linear layer.',
                       dest='async_tensor_model_parallel_allreduce')
Sangkug Lym's avatar
Sangkug Lym committed
585
586
587
588
589
    group.add_argument('--no-persist-layer-norm', action='store_true',
                       help='Disable using persistent fused layer norm kernel. '
                       'This kernel supports only a set of hidden sizes. Please '
                       'check persist_ln_hidden_sizes if your hidden '
                       'size is supported.')
Vijay Korthikanti's avatar
Vijay Korthikanti committed
590
    group.add_argument('--sequence-parallel', action='store_true',
Vijay Korthikanti's avatar
Vijay Korthikanti committed
591
                       help='Enable sequence parallel optimization.')
Sangkug Lym's avatar
Sangkug Lym committed
592
593
    group.add_argument('--no-gradient-accumulation-fusion',
                       action='store_false',
594
                       help='Disable fusing gradient accumulation to weight '
Sangkug Lym's avatar
Sangkug Lym committed
595
596
                       'gradient computation of linear layers',
                       dest='gradient_accumulation_fusion')
Mohammad's avatar
Mohammad committed
597
598
599
    return parser


Mohammad's avatar
Mohammad committed
600
def _add_initialization_args(parser):
Mohammad's avatar
Mohammad committed
601
602
603
604
605
    group = parser.add_argument_group(title='initialization')

    group.add_argument('--seed', type=int, default=1234,
                       help='Random seed used for python, numpy, '
                       'pytorch, and cuda.')
606
607
608
    group.add_argument('--data-parallel-random-init', action='store_true',
                       help='Enable random initialization of params '
                       'across data parallel ranks')
Mohammad's avatar
Mohammad committed
609
610
611
    group.add_argument('--init-method-std', type=float, default=0.02,
                       help='Standard deviation of the zero mean normal '
                       'distribution used for weight initialization.')
612
613
    group.add_argument('--init-method-xavier-uniform', action='store_true',
                       help='Enable Xavier uniform parameter initialization')
Mohammad's avatar
Mohammad committed
614

Mohammad's avatar
Mohammad committed
615
616
617
    return parser


Mohammad's avatar
Mohammad committed
618
def _add_learning_rate_args(parser):
Mohammad's avatar
Mohammad committed
619
620
    group = parser.add_argument_group(title='learning rate')

Mohammad's avatar
Mohammad committed
621
    group.add_argument('--lr', type=float, default=None,
Mohammad's avatar
Mohammad committed
622
623
624
625
                       help='Initial learning rate. Depending on decay style '
                       'and initial warmup, the learing rate at each '
                       'iteration would be different.')
    group.add_argument('--lr-decay-style', type=str, default='linear',
mohammad's avatar
mohammad committed
626
                       choices=['constant', 'linear', 'cosine'],
Mohammad's avatar
Mohammad committed
627
628
629
630
                       help='Learning rate decay function.')
    group.add_argument('--lr-decay-iters', type=int, default=None,
                       help='number of iterations to decay learning rate over,'
                       ' If None defaults to `--train-iters`')
631
632
633
    group.add_argument('--lr-decay-samples', type=int, default=None,
                       help='number of samples to decay learning rate over,'
                       ' If None defaults to `--train-samples`')
634
635
636
    group.add_argument('--lr-warmup-fraction', type=float, default=None,
                       help='fraction of lr-warmup-(iters/samples) to use '
                       'for warmup (as a float)')
637
638
639
640
641
642
    group.add_argument('--lr-warmup-iters', type=int, default=0,
                       help='number of iterations to linearly warmup '
                       'learning rate over.')
    group.add_argument('--lr-warmup-samples', type=int, default=0,
                       help='number of samples to linearly warmup '
                       'learning rate over.')
643
    group.add_argument('--warmup', type=int, default=None,
644
                       help='Old lr warmup argument, do not use. Use one of the'
645
                       '--lr-warmup-* arguments above')
Mohammad's avatar
Mohammad committed
646
647
648
    group.add_argument('--min-lr', type=float, default=0.0,
                       help='Minumum value for learning rate. The scheduler'
                       'clip values below this threshold.')
649
    group.add_argument('--override-opt_param-scheduler', action='store_true',
Mohammad's avatar
Mohammad committed
650
651
652
653
654
                       help='Reset the values of the scheduler (learning rate,'
                       'warmup iterations, minimum learning rate, maximum '
                       'number of iterations, and decay style from input '
                       'arguments and ignore values from checkpoints. Note'
                       'that all the above values will be reset.')
655
    group.add_argument('--use-checkpoint-opt_param-scheduler', action='store_true',
Mohammad's avatar
Mohammad committed
656
657
658
659
660
661
662
663
                       help='Use checkpoint to set the values of the scheduler '
                       '(learning rate, warmup iterations, minimum learning '
                       'rate, maximum number of iterations, and decay style '
                       'from checkpoint and ignore input arguments.')

    return parser


Mohammad's avatar
Mohammad committed
664
def _add_checkpointing_args(parser):
Mohammad's avatar
Mohammad committed
665
666
667
668
669
670
    group = parser.add_argument_group(title='checkpointing')

    group.add_argument('--save', type=str, default=None,
                       help='Output directory to save checkpoints to.')
    group.add_argument('--save-interval', type=int, default=None,
                       help='Number of iterations between checkpoint saves.')
671
    group.add_argument('--no-save-optim', action='store_true', default=None,
Mohammad's avatar
Mohammad committed
672
                       help='Do not save current optimizer.')
673
    group.add_argument('--no-save-rng', action='store_true', default=None,
Mohammad's avatar
Mohammad committed
674
675
676
                       help='Do not save current rng state.')
    group.add_argument('--load', type=str, default=None,
                       help='Directory containing a model checkpoint.')
Jared Casper's avatar
Jared Casper committed
677
    group.add_argument('--no-load-optim', action='store_true', default=None,
Mohammad's avatar
Mohammad committed
678
                       help='Do not load optimizer when loading checkpoint.')
Jared Casper's avatar
Jared Casper committed
679
    group.add_argument('--no-load-rng', action='store_true', default=None,
Mohammad's avatar
Mohammad committed
680
681
682
683
684
                       help='Do not load rng state when loading checkpoint.')
    group.add_argument('--finetune', action='store_true',
                       help='Load model for finetuning. Do not load optimizer '
                       'or rng state from checkpoint and set iteration to 0. '
                       'Assumed when loading a release checkpoint.')
685
686
687
688
689
    group.add_argument('--no-initialization', action='store_false',
                       help='Do not perform initialization when building model, '
                       'can reduce startup time when definitely loading from a '
                       'checkpoint',
                       dest='perform_initialization')
690
691
692
    group.add_argument('--use-checkpoint-args', action='store_true',
                       help='Override any command line arguments with arguments '
                       'from the checkpoint')
Mohammad's avatar
Mohammad committed
693
694
695
696

    return parser


Mohammad's avatar
Mohammad committed
697
def _add_mixed_precision_args(parser):
Mohammad's avatar
Mohammad committed
698
699
700
701
    group = parser.add_argument_group(title='mixed precision')

    group.add_argument('--fp16', action='store_true',
                       help='Run model in fp16 mode.')
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
702
703
    group.add_argument('--bf16', action='store_true',
                       help='Run model in bfloat16 mode.')
mohammad's avatar
mohammad committed
704
705
706
707
708
709
710
711
712
713
714
715
    group.add_argument('--loss-scale', type=float, default=None,
                       help='Static loss scaling, positive power of 2 '
                       'values can improve fp16 convergence. If None, dynamic'
                       'loss scaling is used.')
    group.add_argument('--initial-loss-scale', type=float, default=2**32,
                       help='Initial loss-scale for dynamic loss scaling.')
    group.add_argument('--min-loss-scale', type=float, default=1.0,
                       help='Minimum loss scale for dynamic loss scale.')
    group.add_argument('--loss-scale-window', type=float, default=1000,
                       help='Window over which to raise/lower dynamic scale.')
    group.add_argument('--hysteresis', type=int, default=2,
                       help='hysteresis for dynamic loss scaling')
716
717
    group.add_argument('--fp32-residual-connection', action='store_true',
                       help='Move residual connections to fp32.')
718
719
720
    group.add_argument('--no-query-key-layer-scaling', action='store_false',
                       help='Do not scale Q * K^T by 1 / layer-number.',
                       dest='apply_query_key_layer_scaling')
Mohammad's avatar
Mohammad committed
721
    group.add_argument('--attention-softmax-in-fp32', action='store_true',
722
723
724
                       help='Run attention masking and softmax in fp32. '
                       'This flag is ignored unless '
                       '--no-query-key-layer-scaling is specified.')
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
725
726
727
    group.add_argument('--accumulate-allreduce-grads-in-fp32',
                       action='store_true',
                       help='Gradient accumulation and all-reduce in fp32.')
728
729
730
731
    group.add_argument('--fp16-lm-cross-entropy', action='store_true',
                       help='Move the cross entropy unreduced loss calculation'
                       'for lm head to fp16.')

Mohammad's avatar
Mohammad committed
732
733
734
    return parser


Mohammad's avatar
Mohammad committed
735
def _add_distributed_args(parser):
736
737
    group = parser.add_argument_group(title='distributed')

738
739
740
741
    group.add_argument('--tensor-model-parallel-size', type=int, default=1,
                       help='Degree of tensor model parallelism.')
    group.add_argument('--pipeline-model-parallel-size', type=int, default=1,
                       help='Degree of pipeline model parallelism.')
742
743
744
    group.add_argument('--pipeline-model-parallel-split-rank',
                       type=int, default=None,
                       help='Rank where encoder and decoder should be split.')
745
746
747
    group.add_argument('--model-parallel-size', type=int, default=None,
                       help='Old model parallel argument, do not use. Use '
                       '--tensor-model-parallel-size instead.')
748
749
    group.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None,
                       help='Number of layers per virtual pipeline stage')
Mohammad's avatar
Mohammad committed
750
751
752
753
    group.add_argument('--distributed-backend', default='nccl',
                       choices=['nccl', 'gloo'],
                       help='Which backend to use for distributed training.')
    group.add_argument('--DDP-impl', default='local',
Mohammad's avatar
Mohammad committed
754
                       choices=['local', 'torch'],
Mohammad's avatar
Mohammad committed
755
756
                       help='which DistributedDataParallel implementation '
                       'to use.')
757
758
759
760
    group.add_argument('--no-contiguous-buffers-in-local-ddp',
                       action='store_false', help='If set, dont use '
                       'contiguous buffer in local DDP.',
                       dest='use_contiguous_buffers_in_local_ddp')
761
762
763
    group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false',
                       help='Use scatter/gather to optimize communication of tensors in pipeline',
                       dest='scatter_gather_tensors_in_pipeline')
764
765
766
767
    group.add_argument('--use-ring-exchange-p2p', action='store_true',
                       default=False, help='If set, use custom-built ring exchange '
                       'for p2p communications. Note that this option will require '
                       'a custom built image that support ring-exchange p2p.')
Mohammad's avatar
Mohammad committed
768
769
    group.add_argument('--local_rank', type=int, default=None,
                       help='local rank passed from distributed launcher.')
770
    group.add_argument('--lazy-mpu-init', type=bool, required=False,
771
772
773
774
775
776
777
778
                       help='If set to True, initialize_megatron() '
                       'skips DDP initialization and returns function to '
                       'complete it instead.Also turns on '
                       '--use-cpu-initialization flag. This is for '
                       'external DDP manager.' )
    group.add_argument('--use-cpu-initialization', action='store_true',
                       default=None, help='If set, affine parallel weights '
                       'initialization uses CPU' )
Lawrence McAfee's avatar
Lawrence McAfee committed
779
    group.add_argument('--empty-unused-memory-level', default=0, type=int,
780
781
782
783
                       choices=[0, 1, 2],
                       help='Call torch.cuda.empty_cache() each iteration '
                       '(training and eval), to reduce fragmentation.'
                       '0=off, 1=moderate, 2=aggressive.')
784
    group.add_argument('--standalone-embedding-stage', action='store_true',
Lawrence McAfee's avatar
Lawrence McAfee committed
785
786
                       default=False, help='If set, *input* embedding layer '
                       'is placed on its own pipeline stage, without any '
Lawrence McAfee's avatar
Lawrence McAfee committed
787
788
                       'transformer layers. (For T5, this flag currently only '
                       'affects the encoder embedding.)')
789
790
    group.add_argument('--use-distributed-optimizer', action='store_true',
                       help='Use distributed optimizer.')
791

Mohammad's avatar
Mohammad committed
792
793
794
    return parser


Mohammad's avatar
Mohammad committed
795
def _add_validation_args(parser):
Mohammad's avatar
Mohammad committed
796
797
798
799
800
801
802
803
804
    group = parser.add_argument_group(title='validation')

    group.add_argument('--eval-iters', type=int, default=100,
                       help='Number of iterations to run for evaluation'
                       'validation/test for.')
    group.add_argument('--eval-interval', type=int, default=1000,
                       help='Interval between running evaluation on '
                       'validation set.')

Mohammad's avatar
Mohammad committed
805
806
807
    return parser


Mohammad's avatar
Mohammad committed
808
def _add_data_args(parser):
Mohammad's avatar
Mohammad committed
809
810
    group = parser.add_argument_group(title='data and dataloader')

mohammad's avatar
mohammad committed
811
    group.add_argument('--data-path', nargs='*', default=None,
mohammad's avatar
mohammad committed
812
813
814
815
                       help='Path to the training dataset. Accepted format:'
                       '1) a single data path, 2) multiple datasets in the'
                       'form: dataset1-weight dataset1-path dataset2-weight '
                       'dataset2-path ...')
Mohammad's avatar
Mohammad committed
816
    group.add_argument('--split', type=str, default='969, 30, 1',
Mohammad's avatar
Mohammad committed
817
818
                       help='Comma-separated list of proportions for training,'
                       ' validation, and test split. For example the split '
819
820
                       '`90,5,5` will use 90%% of data for training, 5%% for '
                       'validation and 5%% for test.')
Mohammad's avatar
Mohammad committed
821
    group.add_argument('--vocab-file', type=str, default=None,
Mohammad's avatar
Mohammad committed
822
                       help='Path to the vocab file.')
Mohammad's avatar
Mohammad committed
823
824
    group.add_argument('--merge-file', type=str, default=None,
                       help='Path to the BPE merge file.')
825
826
827
    group.add_argument('--vocab-extra-ids', type=int, default=0,
                       help='Number of additional vocabulary tokens. '
                            'They are used for span masking in the T5 model')
Mohammad's avatar
Mohammad committed
828
    group.add_argument('--seq-length', type=int, default=None,
829
                       help='Maximum sequence length to process.')
830
    group.add_argument('--encoder-seq-length', type=int, default=None,
831
832
                       help='Maximum encoder sequence length to process.'
                       'This should be exclusive of --seq-length')
833
834
    group.add_argument('--decoder-seq-length', type=int, default=None,
                       help="Maximum decoder sequence length to process.")
Mostofa Patwary's avatar
Mostofa Patwary committed
835
836
    group.add_argument('--retriever-seq-length', type=int, default=256,
                       help='Maximum sequence length for the biencoder model '
Mostofa Patwary's avatar
Mostofa Patwary committed
837
                        ' for retriever')
838
839
840
    group.add_argument('--sample-rate', type=float, default=1.0,
                       help='sample rate for training data. Supposed to be 0 '
                            ' < sample_rate < 1')
Mohammad's avatar
Mohammad committed
841
842
843
844
845
846
847
848
    group.add_argument('--mask-prob', type=float, default=0.15,
                       help='Probability of replacing a token with mask.')
    group.add_argument('--short-seq-prob', type=float, default=0.1,
                       help='Probability of producing a short sequence.')
    group.add_argument('--mmap-warmup', action='store_true',
                       help='Warm up mmap files.')
    group.add_argument('--num-workers', type=int, default=2,
                       help="Dataloader number of workers.")
Mohammad's avatar
Mohammad committed
849
850
851
    group.add_argument('--tokenizer-type', type=str,
                       default=None,
                       choices=['BertWordPieceLowerCase',
Raul Puri's avatar
Raul Puri committed
852
                                'BertWordPieceCase',
Mohammad's avatar
Mohammad committed
853
854
                                'GPT2BPETokenizer'],
                       help='What type of tokenizer to use.')
855
856
857
858
859
860
861
862
863
864
    group.add_argument('--data-impl', type=str, default='infer',
                       choices=['lazy', 'cached', 'mmap', 'infer'],
                       help='Implementation of indexed datasets.')
    group.add_argument('--reset-position-ids', action='store_true',
                       help='Reset posistion ids after end-of-document token.')
    group.add_argument('--reset-attention-mask', action='store_true',
                       help='Reset self attention maske after '
                       'end-of-document token.')
    group.add_argument('--eod-mask-loss', action='store_true',
                       help='Mask loss for the end of document tokens.')
Mohammad's avatar
Mohammad committed
865

Mohammad's avatar
Mohammad committed
866
867
    return parser

Raul Puri's avatar
Raul Puri committed
868

Mohammad's avatar
Mohammad committed
869
870
def _add_autoresume_args(parser):
    group = parser.add_argument_group(title='autoresume')
Raul Puri's avatar
Raul Puri committed
871

Mohammad's avatar
Mohammad committed
872
873
874
875
876
    group.add_argument('--adlr-autoresume', action='store_true',
                       help='Enable autoresume on adlr cluster.')
    group.add_argument('--adlr-autoresume-interval', type=int, default=1000,
                       help='Intervals over which check for autoresume'
                       'termination signal')
Raul Puri's avatar
Raul Puri committed
877

Mohammad's avatar
Mohammad committed
878
    return parser
Neel Kant's avatar
Neel Kant committed
879
880


Mostofa Patwary's avatar
Mostofa Patwary committed
881
882
def _add_biencoder_args(parser):
    group = parser.add_argument_group(title='biencoder')
Neel Kant's avatar
Neel Kant committed
883
884
885

    # network size
    group.add_argument('--ict-head-size', type=int, default=None,
886
                       help='Size of block embeddings to be used in ICT and '
Mostofa Patwary's avatar
Mostofa Patwary committed
887
                        'REALM (paper default: 128)')
888
    group.add_argument('--biencoder-projection-dim', type=int, default=0,
Mostofa Patwary's avatar
Mostofa Patwary committed
889
890
                       help='Size of projection head used in biencoder (paper'
                        ' default: 128)')
891
    group.add_argument('--biencoder-shared-query-context-model', action='store_true',
Mostofa Patwary's avatar
Mostofa Patwary committed
892
893
                        help='Whether to share the parameters of the query '
                        'and context models or not')
Neel Kant's avatar
Neel Kant committed
894
895
896
897
898

    # checkpointing
    group.add_argument('--ict-load', type=str, default=None,
                       help='Directory containing an ICTBertModel checkpoint')
    group.add_argument('--bert-load', type=str, default=None,
899
900
                       help='Directory containing an BertModel checkpoint '
                       '(needed to start ICT and REALM)')
Neel Kant's avatar
Neel Kant committed
901
902
903
904
905

    # data
    group.add_argument('--titles-data-path', type=str, default=None,
                       help='Path to titles dataset used for ICT')
    group.add_argument('--query-in-block-prob', type=float, default=0.1,
906
907
                       help='Probability of keeping query in block for '
                       'ICT dataset')
Neel Kant's avatar
Neel Kant committed
908
    group.add_argument('--use-one-sent-docs', action='store_true',
Neel Kant's avatar
Neel Kant committed
909
                       help='Whether to use one sentence documents in ICT')
910
911
    group.add_argument('--evidence-data-path', type=str, default=None,
                       help='Path to Wikipedia Evidence frm DPR paper')
Neel Kant's avatar
Neel Kant committed
912

913
    # training
914
    group.add_argument('--retriever-report-topk-accuracies', nargs='+', type=int,
Mostofa Patwary's avatar
Mostofa Patwary committed
915
916
                        default=[], help="Which top-k accuracies to report "
                        "(e.g. '1 5 20')")
Mostofa Patwary's avatar
Mostofa Patwary committed
917
    group.add_argument('--retriever-score-scaling', action='store_true',
Mostofa Patwary's avatar
Mostofa Patwary committed
918
919
                       help='Whether to scale retriever scores by inverse '
                        'square root of hidden size')
920

Neel Kant's avatar
Neel Kant committed
921
    # faiss index
Neel Kant's avatar
Neel Kant committed
922
    group.add_argument('--block-data-path', type=str, default=None,
Neel Kant's avatar
Neel Kant committed
923
                       help='Where to save/load BlockData to/from')
Mostofa Patwary's avatar
Mostofa Patwary committed
924
925
926
    group.add_argument('--embedding-path', type=str, default=None,
                       help='Where to save/load Open-Retrieval Embedding'
                        ' data to/from')
Neel Kant's avatar
Neel Kant committed
927
928
929

    # indexer
    group.add_argument('--indexer-batch-size', type=int, default=128,
930
931
                       help='How large of batches to use when doing indexing '
                       'jobs')
Neel Kant's avatar
Neel Kant committed
932
    group.add_argument('--indexer-log-interval', type=int, default=1000,
933
934
                       help='After how many batches should the indexer '
                       'report progress')
Neel Kant's avatar
Neel Kant committed
935
    return parser
936
937


938
939
def _add_vision_args(parser):
    group = parser.add_argument_group(title="vision")
940

941
    # general vision arguements
942
943
    group.add_argument('--num-classes', type=int, default=1000,
                       help='num of classes in vision classificaiton task')
944
945
946
947
    group.add_argument('--img-h', type=int, default=224,
                       help='Image height for vision classification task')
    group.add_argument('--img-w', type=int, default=224,
                       help='Image height for vision classification task')
948
949
950
    group.add_argument('--num-channels', type=int, default=3,
                       help='Number of channels in input image data')
    group.add_argument('--patch-dim', type=int, default=16,
951
                       help='patch dimension')
952
953
954
955
956
957
958
    group.add_argument('--classes-fraction', type=float, default=1.0,
                       help='training with fraction of classes.')
    group.add_argument('--data-per-class-fraction', type=float, default=1.0,
                       help='training with fraction of data per class.')
    group.add_argument('--no-data-sharding', action='store_false',
                       help='Disable data sharding.',
                       dest='data_sharding')
959
960
961
962
    group.add_argument('--head-lr-mult', type=float, default=1.0,
                       help='learning rate multiplier for head during finetuning')

    # pretraining type and backbone selection`
Vijay Korthikanti's avatar
Vijay Korthikanti committed
963
964
    group.add_argument('--vision-pretraining', action='store_true',
                       help='flag to indicate vision pretraining')
965
    group.add_argument('--vision-pretraining-type', type=str, default='classify',
Vijay Korthikanti's avatar
Vijay Korthikanti committed
966
                       choices=['classify', 'inpaint', 'dino'],
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
                       help='pretraining objectives')
    group.add_argument('--vision-backbone-type', type=str, default='vit',
                       choices=['vit', 'mit', 'swin'],
                       help='backbone types types')
    group.add_argument('--swin-backbone-type', type=str, default='tiny',
                       choices=['tiny', 'base', 'h3'],
                       help='pretraining objectives')
    
    # inpainting arguments
    group.add_argument('--mask-type', type=str, default='random',
                       choices=['random', 'row'],
                       help='mask types')
    group.add_argument('--mask-factor', type=float, default=1.0,
                       help='mask size scaling parameter')
 
    # dino arguments
    group.add_argument('--iter-per-epoch', type=int, default=1250,
                       help='iterations per epoch')
    group.add_argument('--dino-local-img-size', type=int, default=96,
                       help='Image size for vision classification task')
    group.add_argument('--dino-local-crops-number', type=int, default=10,
                       help='Number of local crops')
    group.add_argument('--dino-head-hidden-size', type=int, default=2048,
                       help='Hidden dimension size in dino head')
    group.add_argument('--dino-bottleneck-size', type=int, default=256,
                       help='Bottle neck dimension in dino head ')
    group.add_argument('--dino-freeze-last-layer', type=float, default=1,
                       help='Freezing last layer weights')
    group.add_argument('--dino-norm-last-layer', action='store_true',
                       help='Disable Norm in last layer.')
    group.add_argument('--dino-warmup-teacher-temp', type=float, default=0.04,
                       help='warump teacher temperature')
    group.add_argument('--dino-teacher-temp', type=float, default=0.07,
                       help='teacher temperature')
    group.add_argument('--dino-warmup-teacher-temp-epochs', type=int, default=30,
                       help='warmup teacher temperaure epochs')
1003
1004

    return parser