# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Megatron arguments."""

import argparse
import os
import sys

import torch

from megatron_mini.model.module import PositionEmbeddingType


def parse_args(extra_args_provider=None, ignore_unknown_args=False):
    """Parse all arguments."""
    parser = argparse.ArgumentParser(description='Megatron-LM Arguments',
                                     allow_abbrev=False)

    # Standard arguments.
    parser = _add_network_size_args(parser)
    parser = _add_training_args(parser)
    parser = _add_mixed_precision_args(parser)
    parser = _add_distributed_args(parser)
    parser = _add_inference_args(parser)

    # Custom arguments.
    if extra_args_provider is not None:
        parser = extra_args_provider(parser)

    # Parse.
    if ignore_unknown_args:
        args, _ = parser.parse_known_args()
    else:
        args = parser.parse_args()

    # Args from environment.
    args.rank = int(os.getenv('RANK', '0'))
    args.world_size = int(os.getenv('WORLD_SIZE', '1'))

    return args
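# Usage sketch (illustrative, not part of this module): `my_extra_args` is a
# hypothetical hook showing how `extra_args_provider` extends the parser
# before parsing.
#
#     def my_extra_args(parser):
#         group = parser.add_argument_group(title='my project')
#         group.add_argument('--my-flag', action='store_true')
#         return parser
#
#     args = parse_args(extra_args_provider=my_extra_args,
#                       ignore_unknown_args=True)
#     args = validate_args(args)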
def validate_args(args, defaults={}, aix_config=None):
    # Tensor model parallel size.
    args.tensor_model_parallel_size = min(
        args.tensor_model_parallel_size, args.world_size)
    assert args.world_size % args.tensor_model_parallel_size == 0, 'world size'\
        ' ({}) is not divisible by tensor model parallel size ({})'.format(
            args.world_size, args.tensor_model_parallel_size)

    # Pipeline model parallel size.
    args.pipeline_model_parallel_size = min(
        args.pipeline_model_parallel_size,
        (args.world_size // args.tensor_model_parallel_size))
    args.transformer_pipeline_model_parallel_size = \
        args.pipeline_model_parallel_size

    # Checks.
    model_parallel_size = args.pipeline_model_parallel_size * \
        args.tensor_model_parallel_size
    assert args.world_size % model_parallel_size == 0, 'world size is not'\
        ' divisible by tensor parallel size ({}) times pipeline parallel ' \
        'size ({})'.format(args.world_size, args.tensor_model_parallel_size,
                           args.pipeline_model_parallel_size)
    args.data_parallel_size = args.world_size // model_parallel_size
    if args.rank == 0:
        print('using world size: {}, data-parallel-size: {}, '
              'tensor-model-parallel size: {}, '
              'pipeline-model-parallel size: {} '.format(
                  args.world_size, args.data_parallel_size,
                  args.tensor_model_parallel_size,
                  args.pipeline_model_parallel_size), flush=True)

    if args.pipeline_model_parallel_size > 1:
        if args.pipeline_model_parallel_split_rank is not None:
            assert args.pipeline_model_parallel_split_rank < \
                args.pipeline_model_parallel_size, 'split rank needs'\
                ' to be less than pipeline model parallel size ({})'.format(
                    args.pipeline_model_parallel_size)

    # Deprecated arguments.
    assert args.model_parallel_size is None, '--model-parallel-size is no ' \
        'longer valid, use --tensor-model-parallel-size instead'
    del args.model_parallel_size

    # Set input defaults. For a default to be valid, it must not already be
    # provided in the arguments passed to the program; we check this by
    # ensuring the arg is still set to None.
    for key in defaults:
        if getattr(args, key) is not None:
            if args.rank == 0:
                print('WARNING: overriding default arguments for '
                      '{key}:{v} with {key}:{v2}'.format(
                          key=key, v=defaults[key], v2=getattr(args, key)),
                      flush=True)
        else:
            setattr(args, key, defaults[key])

    if aix_config is not None:
        # Set args for loading an Aix model; the config always wins.
        for key in aix_config:
            if getattr(args, key) is not None:
                if args.rank == 0:
                    print('Overriding arguments for "{key}:{v2}" with '
                          'AixConfig "{key}:{v}"'.format(
                              key=key, v=aix_config[key],
                              v2=getattr(args, key)),
                          flush=True, file=sys.stderr)
            else:
                if args.rank == 0:
                    print('Setting arguments with AixConfig '
                          '"{key}:{v}"'.format(key=key, v=aix_config[key]),
                          flush=True, file=sys.stderr)
            setattr(args, key, aix_config[key])

    # Batch size.
    assert args.micro_batch_size is not None
    assert args.micro_batch_size > 0

    # Parameters dtype.
    args.params_dtype = torch.float
    if args.fp16:
        assert not args.bf16
        args.params_dtype = torch.half
    if args.bf16:
        assert not args.fp16
        args.params_dtype = torch.bfloat16
    if args.rank == 0:
        print('using {} for parameters ...'.format(args.params_dtype),
              flush=True)

    # Consumed samples.
    args.consumed_train_samples = 0
    args.consumed_valid_samples = 0

    # Support for variable sequence lengths across batches/microbatches.
    # Set it if the dataloader supports generation of variable sequence
    # lengths across batches/microbatches. Because of the additional
    # communication overhead during pipeline parallelism, it should not be
    # set if the sequence length is constant during training.
    args.variable_seq_lengths = False

    if args.num_layers is not None:
        assert args.encoder_num_layers is None, \
            'cannot have both num-layers and encoder-num-layers specified'
        args.encoder_num_layers = args.num_layers
    else:
        assert args.encoder_num_layers is not None, \
            'either num-layers or encoder-num-layers should be specified'
        args.num_layers = args.encoder_num_layers

    # Check required arguments.
    required_args = ['num_layers', 'hidden_size', 'num_attention_heads']
    for req_arg in required_args:
        _check_arg_is_not_none(args, req_arg)

    # Checks.
    if args.ffn_hidden_size is None:
        args.ffn_hidden_size = 4 * args.hidden_size

    if args.kv_channels is None:
        assert args.hidden_size % args.num_attention_heads == 0
        args.kv_channels = args.hidden_size // args.num_attention_heads

    if args.seq_length is not None:
        assert args.max_position_embeddings >= args.seq_length

    TORCH_MAJOR = int(torch.__version__.split('.')[0])
    TORCH_MINOR = int(torch.__version__.split('.')[1])
    # Persistent fused layer norm.
    if TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 11):
        args.no_persist_layer_norm = True
        if args.rank == 0:
            print('Persistent fused layer norm kernel is supported from '
                  'pytorch v1.11 (nvidia pytorch container paired with '
                  'v1.11). Defaulting to no_persist_layer_norm=True')

    if args.fp16:
        assert args.transformer_impl == 'local', \
            'transformer-engine not yet approved for fp16 training and inference'

    # Disable sequence parallelism when tp=1 to avoid a change in numerics
    # when sequence parallelism is enabled.
    if args.tensor_model_parallel_size == 1:
        args.sequence_parallel = False
        args.async_tensor_model_parallel_allreduce = False

    # Disable async_tensor_model_parallel_allreduce when model parallel
    # memory optimization (sequence parallelism) is enabled.
    if args.sequence_parallel:
        args.async_tensor_model_parallel_allreduce = False
        args.no_gradient_accumulation_fusion = True

    if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != '1':
        if args.sequence_parallel:
            raise RuntimeError(
                'Using sequence parallelism requires setting the environment '
                'variable CUDA_DEVICE_MAX_CONNECTIONS to 1')
        if args.async_tensor_model_parallel_allreduce:
            raise RuntimeError(
                'Using async gradient all reduce requires setting the '
                'environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1')

    _print_args(args)
    return args
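# Note on override precedence in validate_args (the values below are
# illustrative): explicit command-line values win over `defaults`, but
# `aix_config` wins over everything. For example, with
# defaults={'num_layers': 24}, aix_config={'num_layers': 32}, and
# --num-layers 16 on the command line, the defaults pass keeps 16 (with a
# warning) and the aix_config pass then leaves args.num_layers == 32.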
def _print_args(args):
    """Print arguments."""
    if args.rank == 0:
        print('------------------------ arguments ------------------------',
              flush=True)
        str_list = []
        for arg in vars(args):
            dots = '.' * (48 - len(arg))
            str_list.append('  {} {} {}'.format(arg, dots, getattr(args, arg)))
        for arg in sorted(str_list, key=lambda x: x.lower()):
            print(arg, flush=True)
        print('-------------------- end of arguments ---------------------',
              flush=True)


def _check_arg_is_not_none(args, arg):
    assert getattr(args, arg) is not None, '{} argument is None'.format(arg)


def _add_inference_args(parser):
    group = parser.add_argument_group(title='inference')

    group.add_argument('--inference-batch-times-seqlen-threshold',
                       type=int, default=512,
                       help='During inference, if batch-size times '
                       'sequence-length is smaller than this threshold '
                       'then we will not use pipelining, otherwise we will.')
    group.add_argument('--max-tokens-to-oom',
                       type=int, default=12000,
                       help='Maximum number of tokens during inference; '
                       'tokens here is the number in the prompt plus the '
                       'number to generate. Allows us to throw an error '
                       'before OOM crashes the server.')

    return parser
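# Note on the enum-valued flag below: `type=lambda x: PositionEmbeddingType[x]`
# converts the raw string by enum-name lookup, so `--position-embedding-type
# rotary` yields PositionEmbeddingType.rotary, while an unknown name raises
# KeyError inside the conversion, before argparse checks `choices`.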
def _add_network_size_args(parser):
    group = parser.add_argument_group(title='network size')

    group.add_argument('--num-layers', type=int, default=None,
                       help='Number of transformer layers.')
    group.add_argument('--encoder-num-layers', type=int, default=None,
                       help='Number of encoder transformer layers.')
    group.add_argument('--decoder-num-layers', type=int, default=None,
                       help='Number of decoder transformer layers.')
    group.add_argument('--hidden-size', type=int, default=None,
                       help='Transformer hidden size.')
    group.add_argument('--no-query-key-layer-scaling', action='store_false',
                       help='Do not scale Q * K^T by 1 / layer-number.',
                       dest='apply_query_key_layer_scaling')
    group.add_argument('--ffn-hidden-size', type=int, default=None,
                       help='Transformer Feed-Forward Network hidden size. '
                       'This is set to 4*hidden-size if not provided.')
    group.add_argument('--num-attention-heads', type=int, default=None,
                       help='Number of transformer attention heads.')
    group.add_argument('--kv-channels', type=int, default=None,
                       help='Projection weights dimension in multi-head '
                       'attention. This is set to '
                       'args.hidden_size // args.num_attention_heads '
                       'if not provided.')
    group.add_argument('--num-kv-heads', type=int, default=8,
                       help='Number of key/value heads for multi-query or '
                       'grouped-query attention.')
    group.add_argument('--position-embedding-type',
                       type=lambda x: PositionEmbeddingType[x],
                       choices=list(PositionEmbeddingType),
                       default=PositionEmbeddingType.absolute,
                       help='Position embedding type ("absolute" | "rotary" '
                       '| "alibi"). "absolute" by default.')
    group.add_argument('--kv-lru-capacity', type=int, default=0,
                       help='Maximum capacity of the KV LRU cache.')
    group.add_argument('--rope-theta', type=int, default=10000,
                       help='Base theta for the rotary position embedding '
                       'frequencies (cos/sin).')
    group.add_argument('--rope-linear-scaling-factor', type=int, default=1,
                       help='Linear scaling factor applied to rotary '
                       'position embedding positions.')
    group.add_argument('--inner-hidden-dim', type=int, default=None,
                       help='Projection weights dimension in the gated '
                       'linear unit. This is set to '
                       'int(2 * (args.hidden_size * ffn_expand_rate) / 3) '
                       'if not provided.')
    group.add_argument('--seq-length', type=int, default=1024,
                       help='Maximum sequence length to process.')
    group.add_argument('--dist-timeout', type=int, default=60 * 24 * 4,
                       help='Timeout for the PyTorch distributed backend '
                       '(in minutes).')
    group.add_argument('--attention-head-type', type=str, default='multiquery',
                       choices=['multihead', 'multiquery', 'groupedquery'])
    group.add_argument('--max-position-embeddings', type=int, default=1024,
                       help='Maximum number of position embeddings to use. '
                       'This is the size of the position embedding.')
    group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
                       help='Pad the vocab size to be divisible by this '
                       'value. This is added for computational efficiency '
                       'reasons.')
    group.add_argument('--layernorm-epsilon', type=float, default=1e-5,
                       help='Layer norm epsilon.')
    group.add_argument('--apply-residual-connection-post-layernorm',
                       action='store_true',
                       help='If set, use the original BERT residual '
                       'connection ordering.')
    group.add_argument('--openai-gelu', action='store_true',
                       help='Use OpenAI\'s GeLU implementation. This option '
                       'should not be used unless for backward '
                       'compatibility reasons.')
    group.add_argument('--onnx-safe', type=bool, required=False,
                       help='Use workarounds for known problems with the '
                       'Torch ONNX exporter.')
    group.add_argument('--bert-no-binary-head', action='store_false',
                       help='Disable BERT binary head.',
                       dest='bert_binary_head')
    group.add_argument('--num-experts', type=int, default=None,
                       help='Number of experts in Switch Transformer '
                       '(None means no Switch).')
    group.add_argument('--is-extend-seq', action='store_true',
                       help='Enable sequence parallel optimization.')

    return parser


def _add_training_args(parser):
    group = parser.add_argument_group(title='training')

    group.add_argument('--micro-batch-size', type=int, default=None,
                       help='Batch size per model instance (local batch '
                       'size). Global batch size is local batch size times '
                       'data parallel size times number of micro batches.')
    group.add_argument('--no-masked-softmax-fusion', action='store_false',
                       help='Disable fusion of query_key_value scaling, '
                       'masking, and softmax.',
                       dest='masked_softmax_fusion')
    group.add_argument('--use-flash-attn', action='store_true',
                       help='Use FlashAttention implementation of attention: '
                       'https://arxiv.org/abs/2205.14135')

    return parser
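# Worked example of the global batch size formula in --micro-batch-size's
# help (numbers are illustrative): with micro_batch_size=4,
# data_parallel_size=8, and 16 micro batches per step, the global batch
# size is 4 * 8 * 16 = 512 samples.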
def _add_mixed_precision_args(parser):
    group = parser.add_argument_group(title='mixed precision')

    group.add_argument('--fp16', action='store_true',
                       help='Run model in fp16 mode.')
    group.add_argument('--bf16', action='store_true',
                       help='Run model in bfloat16 mode.')

    return parser


def _add_distributed_args(parser):
    group = parser.add_argument_group(title='distributed')

    group.add_argument('--tensor-model-parallel-size', type=int, default=1,
                       help='Degree of tensor model parallelism.')
    group.add_argument('--pipeline-model-parallel-size', type=int, default=1,
                       help='Degree of pipeline model parallelism.')
    group.add_argument('--pipeline-model-parallel-split-rank',
                       type=int, default=None,
                       help='Rank where encoder and decoder should be split.')
    group.add_argument('--model-parallel-size', type=int, default=None,
                       help='Old model parallel argument, do not use. Use '
                       '--tensor-model-parallel-size instead.')
    group.add_argument('--distributed-backend', default='nccl',
                       choices=['nccl', 'gloo'],
                       help='Which backend to use for distributed training.')
    group.add_argument('--local_rank', type=int, default=None,
                       help='Local rank passed from the distributed launcher.')
    group.add_argument('--use-cpu-initialization', action='store_true',
                       default=None,
                       help='If set, initialize affine parallel weights '
                       'on the CPU.')

    return parser
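# Minimal smoke test: an illustrative sketch, assuming the megatron_mini
# package is importable and a single process (RANK/WORLD_SIZE default to
# 0/1). All flag values below are examples only, e.g.:
#   python arguments.py --num-layers 2 --hidden-size 64 \
#       --num-attention-heads 4 --micro-batch-size 1
if __name__ == '__main__':
    _args = parse_args()
    _args = validate_args(_args)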