# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Megatron arguments."""

import argparse
import os

import torch
from megatron import fused_kernels


def parse_args(extra_args_provider=None, defaults=None,
               ignore_unknown_args=False):
    """Parse, complete and validate all Megatron arguments.

    Command-line arguments are parsed with ``argparse`` (from ``sys.argv``)
    and then post-processed:
    - the rank and world size are read from the ``RANK`` and ``WORLD_SIZE``
      environment variables;
    - the tensor/pipeline/data parallel sizes are derived;
    - caller-supplied ``defaults`` are filled in;
    - batch-size, loss-scaling and parameter-dtype fields are computed;
    - consistency checks are run;
    - the requested fused softmax kernels are loaded.
    Finally the arguments are printed (rank 0 only).

    Args:
        extra_args_provider: optional callable that takes the parser and
            returns it with extra arguments registered.
        defaults: optional mapping ``{dest_name: value}``. A value is used
            only when the corresponding argument parsed as ``None``.
            Otherwise a warning is printed on rank 0 and the command-line
            value is kept.
        ignore_unknown_args: when True, unrecognized command-line arguments
            are ignored (``parse_known_args``) instead of being rejected.

    Returns:
        argparse.Namespace: the populated arguments.

    Raises:
        AssertionError: if required arguments are missing or inconsistent.
        AttributeError: if ``defaults`` names an attribute that does not
            exist at the point where defaults are applied.
        Exception: if pipeline parallelism is requested but
            ``torch.distributed.ring_exchange`` is not available.
    """
    # Avoid a shared mutable default argument; the mapping is only read.
    if defaults is None:
        defaults = {}

    parser = argparse.ArgumentParser(description='Megatron-LM Arguments',
                                     allow_abbrev=False)

    # Standard arguments.
    parser = _add_network_size_args(parser)
    parser = _add_regularization_args(parser)
    parser = _add_training_args(parser)
    parser = _add_initialization_args(parser)
    parser = _add_learning_rate_args(parser)
    parser = _add_checkpointing_args(parser)
    parser = _add_mixed_precision_args(parser)
    parser = _add_distributed_args(parser)
    parser = _add_validation_args(parser)
    parser = _add_data_args(parser)
    parser = _add_autoresume_args(parser)
    parser = _add_realm_args(parser)

    # Custom arguments.
    if extra_args_provider is not None:
        parser = extra_args_provider(parser)

    # Parse.
    if ignore_unknown_args:
        args, _ = parser.parse_known_args()
    else:
        args = parser.parse_args()

    # Distributed args.
    args.rank = int(os.getenv('RANK', '0'))
    args.world_size = int(os.getenv("WORLD_SIZE", '1'))
    # Tensor model parallel size (capped at the world size).
    args.tensor_model_parallel_size = min(
        args.tensor_model_parallel_size, args.world_size)
    assert args.world_size % args.tensor_model_parallel_size == 0, 'world size'\
        ' ({}) is not divisible by tensor model parallel size ({})'.format(
            args.world_size, args.tensor_model_parallel_size)
    # Pipeline model parallel size (capped at what the tensor-parallel split
    # leaves available).
    args.pipeline_model_parallel_size = min(
        args.pipeline_model_parallel_size,
        (args.world_size // args.tensor_model_parallel_size))
    if args.pipeline_model_parallel_size > 1:
        if "ring_exchange" not in dir(torch.distributed):
            raise Exception('PyTorch with torch.distributed.ring_exchange '
                            'needed to run pipeline MP!')
    # Checks.
    model_parallel_size = args.pipeline_model_parallel_size * \
                          args.tensor_model_parallel_size
    # The message has one placeholder per value passed to format().
    assert args.world_size % model_parallel_size == 0, \
        'world size ({}) is not divisible by tensor parallel size ({}) ' \
        'times pipeline parallel size ({})'.format(
            args.world_size, args.tensor_model_parallel_size,
            args.pipeline_model_parallel_size)
    args.data_parallel_size = args.world_size // model_parallel_size
    if args.rank == 0:
        print('using world size: {}, data-parallel-size: {}, '
              'tensor-model-parallel size: {}, '
              'pipeline-model-parallel size: {} '.format(
                  args.world_size, args.data_parallel_size,
                  args.tensor_model_parallel_size,
                  args.pipeline_model_parallel_size), flush=True)

    # Set input defaults. This must run before anything below derives values
    # from the parsed arguments (batch sizes, loss scaling, sanity checks).
    # Otherwise a default for e.g. micro_batch_size or loss_scale would be
    # applied too late, or be ignored because a derived value already
    # replaced the None.
    for key in defaults:
        # For default to be valid, it should not be provided in the
        # arguments that are passed to the program. We check this by
        # ensuring the arg is set to None. (Flags declared with
        # action='store_true' and no explicit default parse to False, not
        # None, so they are never filled in from `defaults`.)
        if getattr(args, key) is not None:
            if args.rank == 0:
                print('WARNING: overriding default arguments for {key}:{v} '
                      'with {key}:{v2}'.format(key=key, v=defaults[key],
                                               v2=getattr(args, key)),
                      flush=True)
        else:
            setattr(args, key, defaults[key])

    # Batch size.
    assert args.micro_batch_size is not None, \
        'micro-batch-size must be provided'
    assert args.micro_batch_size > 0, 'micro-batch-size must be positive'
    if args.global_batch_size is None:
        args.global_batch_size = args.micro_batch_size * args.data_parallel_size
        if args.rank == 0:
            print('setting global batch size to {}'.format(
                args.global_batch_size), flush=True)
    assert args.global_batch_size > 0, 'global-batch-size must be positive'

    # Fp16 loss scaling: dynamic unless a static --loss-scale was given.
    args.dynamic_loss_scale = args.loss_scale is None

    # Parameters dtype.
    args.params_dtype = torch.half if args.fp16 else torch.float
    if args.rank == 0:
        print('using {} for parameters ...'.format(args.params_dtype),
              flush=True)

    # Consumed samples.
    args.consumed_train_samples = 0
    args.consumed_valid_samples = 0

    # Iteration-based training.
    if args.train_iters:
        # If we use iteration-based training, make sure the
        # sample-based options are off.
        assert args.train_samples is None, \
            'expected iteration-based training'
        assert args.lr_decay_samples is None, \
            'expected iteration-based learning rate decay'
        assert args.lr_warmup_samples == 0, \
            'expected iteration-based learning rate warmup'
        assert args.rampup_batch_size is None, \
            'expected no batch-size rampup for iteration-based training'

    # Sample-based training.
    if args.train_samples:
        # If we use sample-based training, make sure the
        # iteration-based options are off.
        assert args.train_iters is None, \
            'expected sample-based training'
        assert args.lr_decay_iters is None, \
            'expected sample-based learning rate decay'
        assert args.lr_warmup_iters == 0, \
            'expected sample-based learning rate warmup'

    # Check required arguments.
    required_args = ['num_layers', 'hidden_size', 'num_attention_heads',
                     'max_position_embeddings']
    for req_arg in required_args:
        _check_arg_is_not_none(args, req_arg)

    # Checks.
    assert args.hidden_size % args.num_attention_heads == 0, \
        'hidden size must be divisible by the number of attention heads'
    if args.seq_length is not None:
        assert args.max_position_embeddings >= args.seq_length
    if args.lr is not None:
        assert args.min_lr <= args.lr
    if args.save is not None:
        assert args.save_interval is not None
    # Mixed precision checks.
    if args.fp16_lm_cross_entropy:
        assert args.fp16, \
            'lm cross entropy in fp16 is only supported in fp16 mode.'
    # Activation checkpointing.
    if args.distribute_checkpointed_activations:
        assert args.checkpoint_activations, \
            'for distribute-checkpointed-activations to work you '\
            'need to enable checkpoint-activations'

    # load scaled_upper_triang_masked_softmax_fusion kernel
    if args.scaled_upper_triang_masked_softmax_fusion:
        fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()

    # load scaled_masked_softmax_fusion kernel
    if args.scaled_masked_softmax_fusion:
        fused_kernels.load_scaled_masked_softmax_fusion_kernel()

    _print_args(args)
    return args


def _print_args(args):
    """Dump every argument and its value to stdout, on rank 0 only.

    One line is printed per attribute of ``args``, in the form
    ``  <name> <dots> <value>``. Names shorter than 48 characters are padded
    with dots out to 48. Lines are ordered case-insensitively. Every other
    rank prints nothing.
    """
    if args.rank != 0:
        return

    print('------------------------ arguments ------------------------',
          flush=True)
    lines = sorted(
        ('  {} {} {}'.format(name, '.' * (48 - len(name)), value)
         for name, value in vars(args).items()),
        key=str.lower)
    for line in lines:
        print(line, flush=True)
    print('-------------------- end of arguments ---------------------',
          flush=True)


def _check_arg_is_not_none(args, arg):
    """Assert that the attribute named ``arg`` on ``args`` is not None.

    Args:
        args: namespace holding the parsed arguments.
        arg: attribute (argparse ``dest``) name to check.

    Raises:
        AssertionError: with message ``"<arg> argument is None"``.
    """
    assert getattr(args, arg) is not None, arg + ' argument is None'


def _add_network_size_args(parser):
    """Register the transformer network-size options on ``parser``.

    Covers the layer count, hidden size, attention heads, position
    embeddings, vocab padding, layer-norm epsilon and a few architecture
    toggles.

    Returns:
        The same ``parser``, so registrations can be chained.
    """

    def _str_to_bool(value):
        # ``type=bool`` is an argparse pitfall: bool('False') is True, so
        # any non-empty value used to enable the option. Parse the text
        # explicitly instead.
        normalized = value.strip().lower()
        if normalized in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        if normalized in ('0', 'false', 'f', 'no', 'n', 'off'):
            return False
        raise argparse.ArgumentTypeError(
            'expected a boolean value, got {!r}'.format(value))

    group = parser.add_argument_group(title='network size')

    group.add_argument('--num-layers', type=int, default=None,
                       help='Number of transformer layers.')
    group.add_argument('--hidden-size', type=int, default=None,
                       help='Transformer hidden size.')
    group.add_argument('--num-attention-heads', type=int, default=None,
                       help='Number of transformer attention heads.')
    group.add_argument('--max-position-embeddings', type=int, default=None,
                       help='Maximum number of position embeddings to use. '
                       'This is the size of position embedding.')
    group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
                       help='Pad the vocab size to be divisible by this value. '
                       'This is added for computational efficiency reasons.')
    group.add_argument('--layernorm-epsilon', type=float, default=1e-5,
                       help='Layer norm epsilon.')
    group.add_argument('--apply-residual-connection-post-layernorm',
                       action='store_true',
                       help='If set, use original BERT residual connection '
                       'ordering.')
    group.add_argument('--openai-gelu', action='store_true',
                       help="Use OpenAI's GeLU implementation. This option "
                       'should not be used unless for backward compatibility '
                       'reasons.')
    # Default stays None when the flag is absent (no `default=` given).
    group.add_argument('--onnx-safe', type=_str_to_bool, required=False,
                       help='Use workarounds for known problems with '
                       'Torch ONNX exporter')

    return parser


def _add_regularization_args(parser):
    """Register dropout, weight-decay, gradient-clipping and Adam options.

    Returns:
        The same ``parser``, so registrations can be chained.
    """
    group = parser.add_argument_group(title='regularization')

    group.add_argument('--attention-dropout', type=float, default=0.1,
                       help='Post attention dropout probability.')
    group.add_argument('--hidden-dropout', type=float, default=0.1,
                       help='Dropout probability for hidden state transformer.')
    group.add_argument('--weight-decay', type=float, default=0.01,
                       help='Weight decay coefficient for L2 regularization.')
    group.add_argument('--clip-grad', type=float, default=1.0,
                       help='Gradient clipping based on global L2 norm.')
    group.add_argument('--adam-beta1', type=float, default=0.9,
                       help='First coefficient for computing running averages '
                       'of gradient and its square')
    group.add_argument('--adam-beta2', type=float, default=0.999,
                       help='Second coefficient for computing running averages '
                       'of gradient and its square')
    group.add_argument('--adam-eps', type=float, default=1e-08,
                       help='Term added to the denominator to improve '
                       'numerical stability')

    return parser


def _add_training_args(parser):
    """Register batch-size, training-length, logging and fusion options.

    Returns:
        The same ``parser``, so registrations can be chained.
    """
    group = parser.add_argument_group(title='training')

    group.add_argument('--micro-batch-size', type=int, default=None,
                       help='Batch size per model instance (local batch size). '
                       'Global batch size is local batch size times data '
                       'parallel size times number of micro batches.')
    group.add_argument('--global-batch-size', type=int, default=None,
                       help='Training batch size. If set, it should be a '
                       'multiple of micro-batch-size times data-parallel-size. '
                       'If this value is None, then '
                       'use micro-batch-size * data-parallel-size as the '
                       'global batch size. This choice will result in 1 for '
                       'number of micro-batches.')
    group.add_argument('--rampup-batch-size', nargs='*', default=None,
                       help='Batch size ramp up with the following values: '
                       '  --rampup-batch-size <start batch size> '
                       '                      <batch size increment> '
                       '                      <ramp-up samples> '
                       'For example: '
                       # The backslash is a literal shell line continuation
                       # in the example; escape it to avoid an invalid
                       # escape sequence warning.
                       '   --rampup-batch-size 16 8 300000 \\ '
                       '   --global-batch-size 1024 '
                       'will start with global batch size 16 and over '
                       ' (1024 - 16) / 8 = 126 intervals will increase '
                       'the batch size linearly to 1024. In each interval '
                       'we will use approximately 300000 / 126 = 2380 samples.')
    group.add_argument('--checkpoint-activations', action='store_true',
                       help='Checkpoint activation to allow for training '
                       'with larger models, sequences, and batch sizes.')
    group.add_argument('--distribute-checkpointed-activations',
                       action='store_true',
                       help='If set, distribute checkpointed activations '
                       'across model parallel group.')
    group.add_argument('--checkpoint-num-layers', type=int, default=1,
                       help='chunk size (number of layers) for checkpointing.')
    group.add_argument('--train-iters', type=int, default=None,
                       help='Total number of iterations to train over all '
                       'training runs. Note that either train-iters or '
                       'train-samples should be provided.')
    group.add_argument('--train-samples', type=int, default=None,
                       help='Total number of samples to train over all '
                       'training runs. Note that either train-iters or '
                       'train-samples should be provided.')
    group.add_argument('--log-interval', type=int, default=100,
                       help='Report loss and timing interval.')
    group.add_argument('--exit-interval', type=int, default=None,
                       help='Exit the program after the iteration is divisible '
                       'by this value.')
    group.add_argument('--tensorboard-dir', type=str, default=None,
                       help='Write TensorBoard logs to this directory.')
    group.add_argument('--scaled-upper-triang-masked-softmax-fusion',
                       action='store_true',
                       help='Enable fusion of query_key_value_scaling '
                       'time (upper diagonal) masking and softmax.')
    group.add_argument('--scaled-masked-softmax-fusion',
                       action='store_true',
                       help='Enable fusion of query_key_value_scaling '
                       'general masking and softmax.')
    group.add_argument('--bias-gelu-fusion', action='store_true',
                       help='Enable bias and gelu fusion.')
    group.add_argument('--bias-dropout-fusion', action='store_true',
                       help='Enable bias and dropout fusion.')

    return parser


def _add_initialization_args(parser):
    """Add the weight-initialization options (random seed, init std).

    Returns:
        ``parser`` itself, so the caller can keep chaining registrations.
    """
    init_group = parser.add_argument_group(title='initialization')
    init_group.add_argument(
        '--seed', type=int, default=1234,
        help='Random seed used for python, numpy, pytorch, and cuda.')
    init_group.add_argument(
        '--init-method-std', type=float, default=0.02,
        help='Standard deviation of the zero mean normal distribution used '
             'for weight initialization.')
    return parser


def _add_learning_rate_args(parser):
    """Register learning-rate schedule options (value, decay, warmup, minimum).

    Returns:
        The same ``parser``, so registrations can be chained.
    """
    group = parser.add_argument_group(title='learning rate')

    group.add_argument('--lr', type=float, default=None,
                       help='Initial learning rate. Depending on decay style '
                       'and initial warmup, the learning rate at each '
                       'iteration would be different.')
    group.add_argument('--lr-decay-style', type=str, default='linear',
                       choices=['constant', 'linear', 'cosine'],
                       help='Learning rate decay function.')
    group.add_argument('--lr-decay-iters', type=int, default=None,
                       help='number of iterations to decay learning rate over.'
                       ' If None defaults to `--train-iters`')
    group.add_argument('--lr-decay-samples', type=int, default=None,
                       help='number of samples to decay learning rate over.'
                       ' If None defaults to `--train-samples`')
    group.add_argument('--lr-warmup-iters', type=int, default=0,
                       help='number of iterations to linearly warmup '
                       'learning rate over.')
    group.add_argument('--lr-warmup-samples', type=int, default=0,
                       help='number of samples to linearly warmup '
                       'learning rate over.')
    group.add_argument('--min-lr', type=float, default=0.0,
                       help='Minimum value for learning rate. The scheduler '
                       'clips values below this threshold.')
    group.add_argument('--override-lr-scheduler', action='store_true',
                       help='Reset the values of the scheduler (learning rate, '
                       'warmup iterations, minimum learning rate, maximum '
                       'number of iterations, and decay style) from input '
                       'arguments and ignore values from checkpoints. Note '
                       'that all the above values will be reset.')
    group.add_argument('--use-checkpoint-lr-scheduler', action='store_true',
                       help='Use checkpoint to set the values of the scheduler '
                       '(learning rate, warmup iterations, minimum learning '
                       'rate, maximum number of iterations, and decay style) '
                       'from checkpoint and ignore input arguments.')

    return parser


def _add_checkpointing_args(parser):
    """Add the checkpoint save/load options to ``parser`` and return it."""
    ckpt = parser.add_argument_group(title='checkpointing')

    # -- saving --
    ckpt.add_argument(
        '--save', type=str, default=None,
        help='Output directory to save checkpoints to.')
    ckpt.add_argument(
        '--save-interval', type=int, default=None,
        help='Number of iterations between checkpoint saves.')
    for flag, text in (('--no-save-optim', 'Do not save current optimizer.'),
                       ('--no-save-rng', 'Do not save current rng state.')):
        ckpt.add_argument(flag, action='store_true', help=text)

    # -- loading --
    ckpt.add_argument(
        '--load', type=str, default=None,
        help='Directory containing a model checkpoint.')
    for flag, text in (
            ('--no-load-optim',
             'Do not load optimizer when loading checkpoint.'),
            ('--no-load-rng',
             'Do not load rng state when loading checkpoint.')):
        ckpt.add_argument(flag, action='store_true', help=text)
    ckpt.add_argument(
        '--finetune', action='store_true',
        help='Load model for finetuning. Do not load optimizer or rng state '
             'from checkpoint and set iteration to 0. Assumed when loading a '
             'release checkpoint.')

    return parser


def _add_mixed_precision_args(parser):
    """Register fp16 and loss-scaling options.

    Returns:
        The same ``parser``, so registrations can be chained.
    """
    group = parser.add_argument_group(title='mixed precision')

    group.add_argument('--fp16', action='store_true',
                       help='Run model in fp16 mode.')
    group.add_argument('--apply-query-key-layer-scaling', action='store_true',
                       help='Scale Q * K^T by 1 / layer-number. If this flag '
                       'is set, then it will automatically set '
                       'attention-softmax-in-fp32 to true')
    group.add_argument('--attention-softmax-in-fp32', action='store_true',
                       help='Run attention masking and softmax in fp32.')
    group.add_argument('--fp32-allreduce', action='store_true',
                       help='All-reduce in fp32')
    group.add_argument('--hysteresis', type=int, default=2,
                       help='hysteresis for dynamic loss scaling')
    group.add_argument('--loss-scale', type=float, default=None,
                       help='Static loss scaling, positive power of 2 '
                       'values can improve fp16 convergence. If None, dynamic '
                       'loss scaling is used.')
    group.add_argument('--loss-scale-window', type=float, default=1000,
                       help='Window over which to raise/lower dynamic scale.')
    group.add_argument('--min-scale', type=float, default=1,
                       help='Minimum loss scale for dynamic loss scale.')
    group.add_argument('--fp16-lm-cross-entropy', action='store_true',
                       help='Move the cross entropy unreduced loss calculation '
                       'for lm head to fp16.')

    return parser


def _add_distributed_args(parser):
    """Register parallelism sizes, backend selection and DDP-related options.

    Returns:
        The same ``parser``, so registrations can be chained.
    """

    def _str_to_bool(value):
        # ``type=bool`` is an argparse pitfall: bool('False') is True, so
        # any non-empty value used to enable the option. Parse the text
        # explicitly instead.
        normalized = value.strip().lower()
        if normalized in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        if normalized in ('0', 'false', 'f', 'no', 'n', 'off'):
            return False
        raise argparse.ArgumentTypeError(
            'expected a boolean value, got {!r}'.format(value))

    group = parser.add_argument_group(title='distributed')

    group.add_argument('--tensor-model-parallel-size', type=int, default=1,
                       help='Degree of tensor model parallelism.')
    group.add_argument('--pipeline-model-parallel-size', type=int, default=1,
                       help='Degree of pipeline model parallelism.')
    group.add_argument('--distributed-backend', default='nccl',
                       choices=['nccl', 'gloo'],
                       help='Which backend to use for distributed training.')
    group.add_argument('--DDP-impl', default='local',
                       choices=['local', 'torch'],
                       help='which DistributedDataParallel implementation '
                       'to use.')
    group.add_argument('--local_rank', type=int, default=None,
                       help='local rank passed from distributed launcher.')
    # Default stays None when the flag is absent (no `default=` given).
    group.add_argument('--lazy-mpu-init', type=_str_to_bool, required=False,
                       help='If set to True, initialize_megatron() skips DDP '
                       'initialization and returns function to complete it '
                       'instead. Also turns on --use-cpu-initialization flag. '
                       'This is for external DDP manager.')
    group.add_argument('--use-cpu-initialization', action='store_true',
                       help='If set, affine parallel weights initialization '
                       'uses CPU')

    return parser


def _add_validation_args(parser):
    """Register evaluation length and frequency options.

    Returns:
        The same ``parser``, so registrations can be chained.
    """
    group = parser.add_argument_group(title='validation')

    group.add_argument('--eval-iters', type=int, default=100,
                       help='Number of iterations to run for evaluation on '
                       'the validation/test set.')
    group.add_argument('--eval-interval', type=int, default=1000,
                       help='Interval between running evaluation on '
                       'validation set.')

    return parser


def _add_data_args(parser):
    """Register dataset, tokenizer and data-loader options.

    Returns:
        The same ``parser``, so registrations can be chained.
    """
    group = parser.add_argument_group(title='data and dataloader')

    group.add_argument('--data-path', nargs='*', default=None,
                       help='Path to the training dataset. Accepted format: '
                       '1) a single data path, 2) multiple datasets in the '
                       'form: dataset1-weight dataset1-path dataset2-weight '
                       'dataset2-path ...')
    # '%%' is required: argparse %-formats help strings.
    group.add_argument('--split', type=str, default='969, 30, 1',
                       help='Comma-separated list of proportions for training,'
                       ' validation, and test split. For example the split '
                       '`90,5,5` will use 90%% of data for training, 5%% for '
                       'validation and 5%% for test.')
    group.add_argument('--vocab-file', type=str, default=None,
                       help='Path to the vocab file.')
    group.add_argument('--merge-file', type=str, default=None,
                       help='Path to the BPE merge file.')
    group.add_argument('--seq-length', type=int, default=None,
                       help="Maximum sequence length to process.")
    group.add_argument('--mask-prob', type=float, default=0.15,
                       help='Probability of replacing a token with mask.')
    group.add_argument('--short-seq-prob', type=float, default=0.1,
                       help='Probability of producing a short sequence.')
    group.add_argument('--mmap-warmup', action='store_true',
                       help='Warm up mmap files.')
    group.add_argument('--num-workers', type=int, default=2,
                       help="Dataloader number of workers.")
    group.add_argument('--tokenizer-type', type=str,
                       default=None,
                       choices=['BertWordPieceLowerCase',
                                'BertWordPieceCase',
                                'GPT2BPETokenizer'],
                       help='What type of tokenizer to use.')
    group.add_argument('--data-impl', type=str, default='infer',
                       choices=['lazy', 'cached', 'mmap', 'infer'],
                       help='Implementation of indexed datasets.')
    group.add_argument('--reset-position-ids', action='store_true',
                       help='Reset position ids after end-of-document token.')
    group.add_argument('--reset-attention-mask', action='store_true',
                       help='Reset self attention mask after '
                       'end-of-document token.')
    group.add_argument('--eod-mask-loss', action='store_true',
                       help='Mask loss for the end of document tokens.')

    return parser

def _add_autoresume_args(parser):
    """Register the ADLR-cluster autoresume options.

    Returns:
        The same ``parser``, so registrations can be chained.
    """
    group = parser.add_argument_group(title='autoresume')

    group.add_argument('--adlr-autoresume', action='store_true',
                       help='Enable autoresume on adlr cluster.')
    group.add_argument('--adlr-autoresume-interval', type=int, default=1000,
                       help='Intervals over which to check for autoresume '
                       'termination signal')

    return parser


def _add_realm_args(parser):
    """Register ICT / REALM options (block embeddings, data, FAISS, indexer).

    Returns:
        The same ``parser``, so registrations can be chained.
    """
    group = parser.add_argument_group(title='realm')

    # network size
    group.add_argument('--ict-head-size', type=int, default=None,
                       help='Size of block embeddings to be used in ICT and '
                       'REALM (paper default: 128)')

    # checkpointing
    group.add_argument('--ict-load', type=str, default=None,
                       help='Directory containing an ICTBertModel checkpoint')
    group.add_argument('--bert-load', type=str, default=None,
                       help='Directory containing a BertModel checkpoint '
                       '(needed to start ICT and REALM)')

    # data
    group.add_argument('--titles-data-path', type=str, default=None,
                       help='Path to titles dataset used for ICT')
    group.add_argument('--query-in-block-prob', type=float, default=0.1,
                       help='Probability of keeping query in block for '
                       'ICT dataset')
    group.add_argument('--use-one-sent-docs', action='store_true',
                       help='Whether to use one sentence documents in ICT')

    # training
    group.add_argument('--report-topk-accuracies', nargs='+', default=[],
                       help="Which top-k accuracies to report (e.g. '1 5 20')")

    # faiss index
    group.add_argument('--faiss-use-gpu', action='store_true',
                       help='Whether to create the FaissMIPSIndex on GPU')
    group.add_argument('--block-data-path', type=str, default=None,
                       help='Where to save/load BlockData to/from')

    # indexer
    group.add_argument('--indexer-batch-size', type=int, default=128,
                       help='How large of batches to use when doing indexing '
                       'jobs')
    group.add_argument('--indexer-log-interval', type=int, default=1000,
                       help='After how many batches should the indexer report '
                       'progress')

    return parser