Commit d3dd8642 authored by Rayyyyy's avatar Rayyyyy
Browse files

First add

parents
Pipeline #1259 failed with stages
in 0 seconds
#!/bin/bash
#split checkpoint along the tensor
LOAD_CHECKPOINT_PATH=$1
#<Specify the loaded ckpt path>
SAVE_CHECKPOINT_PATH=$2
#<Specify the stored ckpt path>
TOKENIZER_MODEL_PATH=$3
#<Specify tokenizer model path>
export CUDA_DEVICE_MAX_CONNECTIONS=1
if [ ! -d $SAVE_CHECKPOINT_PATH ]; then
mkdir $SAVE_CHECKPOINT_PATH
fi
python tools/split_tp_partitions.py \
--tokenizer-model-path $TOKENIZER_MODEL_PATH \
--tensor-model-parallel-size 1 \
--pipeline-model-parallel-size 8 \
--target-pipeline-model-parallel-size 8 \
--target-tensor-model-parallel-size 4 \
--pipeline-generate-layer 0,1,2,3,4,5,6,7 \
--tokenizer-type YuanTokenizer \
--num-layers 24 \
--target-num-layers 24 \
--hidden-size 2048 \
--num-attention-heads 16 \
--kv-channels 256 \
--seq-length 4096 \
--max-position-embeddings 4096 \
--use-lf-gate \
--lf-conv2d-group 1 \
--lf-conv2d-num-pad 1 \
--position-embedding-type rope \
--flash-attn-drop 0.1 \
--fim-rate 0.5 \
--fim-spm-rate 0.5 \
--attention-dropout 0 \
--hidden-dropout 0 \
--norm-dtype RMSNorm \
--disable-bias-linear \
--reset-position-ids \
--use-flash-attn \
--swiglu \
--adam-beta1 0.9 \
--adam-beta2 0.95 \
--bf16 \
--num-attention-router-heads 4096 \
--rotary-percent 0.5 \
--use-attention-router \
--no-masked-softmax-fusion \
--use-fp32-router \
--num-experts 32 \
--moe-router-load-balancing-type none \
--moe-router-topk 2 \
--moe-grouped-gemm \
--save-interval 1 \
--recompute-method block \
--recompute-granularity full \
--recompute-num-layers 1 \
--load $LOAD_CHECKPOINT_PATH \
--save $SAVE_CHECKPOINT_PATH \
--micro-batch-size 1 \
--global-batch-size 1152 \
--use-distributed-optimizer \
--lr 0.00009 \
--train-iters 63578 \
--lr-decay-iters 63578 \
--lr-decay-style cosine \
--min-lr 1.0e-5 \
--weight-decay 1e-1 \
--no-load-optim \
--use-cpu-initialization \
--process-checkpoint \
--data-impl mmap --DDP-impl local
du -sh $SAVE_CHECKPOINT_PATH
#!/bin/bash
#split checkpoint along the tensor
LOAD_CHECKPOINT_PATH=$1
#<Specify the loaded ckpt path>
SAVE_CHECKPOINT_PATH=$2
#<Specify the stored ckpt path>
TOKENIZER_MODEL_PATH=$3
#<Specify tokenizer model path>
export CUDA_DEVICE_MAX_CONNECTIONS=1
if [ ! -d $SAVE_CHECKPOINT_PATH ]; then
mkdir $SAVE_CHECKPOINT_PATH
fi
python tools/split_tp_partitions.py \
--tokenizer-model-path $TOKENIZER_MODEL_PATH \
--tensor-model-parallel-size 1 \
--pipeline-model-parallel-size 16 \
--target-pipeline-model-parallel-size 16 \
--target-tensor-model-parallel-size 4 \
--pipeline-model-parallel-method block \
--pipeline-model-parallel-blocks 2,2,2,2,2,3,3,3,3,3,3,3,3,3,3,2 \
--pipeline-generate-layer 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 \
--tokenizer-type YuanTokenizer \
--num-layers 42 \
--target-num-layers 42 \
--hidden-size 8192 \
--num-attention-heads 64 \
--seq-length 4096 \
--max-position-embeddings 4096 \
--use-lf-gate \
--lf-conv2d-group 1 \
--lf-conv2d-num-pad 1 \
--position-embedding-type rope \
--flash-attn-drop 0.1 \
--fim-rate 0.5 \
--fim-spm-rate 0.5 \
--attention-dropout 0 \
--hidden-dropout 0 \
--norm-dtype RMSNorm \
--disable-bias-linear \
--reset-position-ids \
--use-flash-attn \
--swiglu \
--adam-beta1 0.9 \
--adam-beta2 0.95 \
--bf16 \
--save-interval 1 \
--recompute-method block \
--recompute-granularity full \
--recompute-num-layers 1 \
--load $LOAD_CHECKPOINT_PATH \
--save $SAVE_CHECKPOINT_PATH \
--micro-batch-size 1 \
--global-batch-size 1152 \
--use-distributed-optimizer \
--lr 0.00009 \
--train-iters 63578 \
--lr-decay-iters 63578 \
--lr-decay-style cosine \
--min-lr 1.0e-5 \
--weight-decay 1e-1 \
--no-load-optim \
--use-cpu-initialization \
--process-checkpoint \
--data-impl mmap --DDP-impl local
du -sh $SAVE_CHECKPOINT_PATH
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import torch
from .global_vars import get_args, get_retro_args
from .global_vars import get_current_global_batch_size
from .global_vars import get_num_microbatches
from .global_vars import get_signal_handler
from .global_vars import update_num_microbatches
from .global_vars import get_tokenizer
from .global_vars import get_tensorboard_writer
from .global_vars import get_adlr_autoresume
from .global_vars import get_timers
from .initialize import initialize_megatron
from .utils import (print_rank_0,
is_last_rank,
print_rank_last)
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Megatron arguments."""
import argparse
import dataclasses
import json
import os
import torch
import types
import torch.nn.functional as F
from megatron.global_vars import set_retro_args, get_retro_args
from tools.retro.utils import get_args_path as get_retro_args_path
from megatron.core.transformer import TransformerConfig
def parse_args(extra_args_provider=None, ignore_unknown_args=False):
"""Parse all arguments."""
parser = argparse.ArgumentParser(description='Megatron-LM Arguments',
allow_abbrev=False)
# Standard arguments.
parser = _add_network_size_args(parser)
parser = _add_regularization_args(parser)
parser = _add_training_args(parser)
parser = _add_initialization_args(parser)
parser = _add_learning_rate_args(parser)
parser = _add_checkpointing_args(parser)
parser = _add_mixed_precision_args(parser)
parser = _add_distributed_args(parser)
parser = _add_validation_args(parser)
parser = _add_data_args(parser)
parser = _add_autoresume_args(parser)
parser = _add_biencoder_args(parser)
parser = _add_vision_args(parser)
parser = _add_moe_args(parser)
parser = _add_logging_args(parser)
parser = _add_inference_args(parser)
parser = _add_transformer_engine_args(parser)
parser = _add_retro_args(parser)
# Custom arguments.
if extra_args_provider is not None:
parser = extra_args_provider(parser)
# Parse.
if ignore_unknown_args:
args, _ = parser.parse_known_args()
else:
args = parser.parse_args()
# Args from environment
args.rank = int(os.getenv('RANK', '0'))
args.world_size = int(os.getenv("WORLD_SIZE", '1'))
return args
def validate_args(args, defaults={}):
# Tensor model parallel size.
args.tensor_model_parallel_size = min(
args.tensor_model_parallel_size, args.world_size)
assert args.world_size % args.tensor_model_parallel_size == 0, 'world size'\
' ({}) is not divisible by tensor model parallel size ({})'.format(
args.world_size, args.tensor_model_parallel_size)
# Pipeline model parallel size.
args.pipeline_model_parallel_size = min(
args.pipeline_model_parallel_size,
(args.world_size // args.tensor_model_parallel_size))
args.transformer_pipeline_model_parallel_size = (
args.pipeline_model_parallel_size - 1
if args.standalone_embedding_stage else
args.pipeline_model_parallel_size
)
# Checks.
model_parallel_size = args.pipeline_model_parallel_size * \
args.tensor_model_parallel_size
assert args.world_size % model_parallel_size == 0, 'world size is not'\
' divisible by tensor parallel size ({}) times pipeline parallel ' \
'size ({})'.format(args.world_size, args.tensor_model_parallel_size,
args.pipeline_model_parallel_size)
assert args.expert_model_parallel_size == 1, 'expert model parallel sizeb must be 1,The current version does not support expert parallel methods'
args.data_parallel_size = args.world_size // model_parallel_size
if args.rank == 0:
print('using world size: {}, data-parallel-size: {}, '
'tensor-model-parallel size: {}, '
'pipeline-model-parallel size: {} '.format(
args.world_size, args.data_parallel_size,
args.tensor_model_parallel_size,
args.pipeline_model_parallel_size), flush=True)
if args.pipeline_model_parallel_size > 1:
if args.pipeline_model_parallel_split_rank is not None:
assert args.pipeline_model_parallel_split_rank < \
args.pipeline_model_parallel_size, 'split rank needs'\
' to be less than pipeline model parallel size ({})'.format(
args.pipeline_model_parallel_size)
# Deprecated arguments
assert args.batch_size is None, '--batch-size argument is no longer ' \
'valid, use --micro-batch-size instead'
del args.batch_size
assert args.warmup is None, '--warmup argument is no longer valid, use ' \
'--lr-warmup-fraction instead'
del args.warmup
assert args.model_parallel_size is None, '--model-parallel-size is no ' \
'longer valid, use --tensor-model-parallel-size instead'
del args.model_parallel_size
if args.checkpoint_activations:
if args.rank == 0:
print('--checkpoint-activations is no longer valid, use --recompute-activations, '
'or, for more control, --recompute-granularity and --recompute-method.')
exit()
del args.checkpoint_activations
if args.recompute_activations:
args.recompute_granularity = 'selective'
if args.pipeline_model_parallel_method == 'block':
assert args.num_layers_per_virtual_pipeline_stage == None, 'pipeline interleavinig not support pipeline_model_parallel_method=block now'
args.pipeline_blocks = [int(i) for i in args.pipeline_model_parallel_blocks.split(',')]
assert sum(args.pipeline_blocks) == args.num_layers, 'sum of args.pipeline_model_parallel_blocks must eq args.num_layers'
assert len(args.pipeline_blocks) == args.pipeline_model_parallel_size, 'num of args.pipeline_model_parallel_blocks must eq args.pipeline_model_parallel_size'
elif not args.standalone_embedding_stage:
assert args.num_layers % args.pipeline_model_parallel_size == 0
args.pipeline_blocks = [args.num_layers // args.pipeline_model_parallel_size] * args.pipeline_model_parallel_size
else:
assert args.num_layers % (args.pipeline_model_parallel_size - 1) == 0
args.pipeline_blocks = [0] + [args.num_layers // (args.pipeline_model_parallel_size - 1)] * (args.pipeline_model_parallel_size - 1)
del args.recompute_activations
# Set input defaults.
for key in defaults:
# For default to be valid, it should not be provided in the
# arguments that are passed to the program. We check this by
# ensuring the arg is set to None.
if getattr(args, key) is not None:
if args.rank == 0:
print('WARNING: overriding default arguments for {key}:{v} \
with {key}:{v2}'.format(key=key, v=defaults[key],
v2=getattr(args, key)),
flush=True)
else:
setattr(args, key, defaults[key])
# Batch size.
assert args.micro_batch_size is not None
assert args.micro_batch_size > 0
if args.global_batch_size is None:
args.global_batch_size = args.micro_batch_size * args.data_parallel_size
if args.rank == 0:
print('setting global batch size to {}'.format(
args.global_batch_size), flush=True)
assert args.global_batch_size > 0
if args.num_layers_per_virtual_pipeline_stage is not None:
assert args.pipeline_model_parallel_size > 2, \
'pipeline-model-parallel size should be greater than 2 with ' \
'interleaved schedule'
assert args.num_layers % args.num_layers_per_virtual_pipeline_stage == 0, \
'number of layers is not divisible by number of layers per virtual ' \
'pipeline stage'
args.virtual_pipeline_model_parallel_size = \
(args.num_layers // args.transformer_pipeline_model_parallel_size) // \
args.num_layers_per_virtual_pipeline_stage
else:
args.virtual_pipeline_model_parallel_size = None
# Parameters dtype.
args.params_dtype = torch.float
if args.fp16:
assert not args.bf16
args.params_dtype = torch.half
if args.bf16:
assert not args.fp16
args.params_dtype = torch.bfloat16
# bfloat16 requires gradient accumulation and all-reduce to
# be done in fp32.
if not args.accumulate_allreduce_grads_in_fp32:
args.accumulate_allreduce_grads_in_fp32 = True
if args.rank == 0:
print('accumulate and all-reduce gradients in fp32 for '
'bfloat16 data type.', flush=True)
if args.rank == 0:
print('using {} for parameters ...'.format(args.params_dtype),
flush=True)
# If we do accumulation and all-reduces in fp32, we need to have local DDP
# and we should make sure use-contiguous-buffers-in-local-ddp is not off.
if args.accumulate_allreduce_grads_in_fp32:
assert args.DDP_impl == 'local'
assert args.use_contiguous_buffers_in_local_ddp
# If we use the distributed optimizer, we need to have local DDP
# and we should make sure use-contiguous-buffers-in-local-ddp is on.
if args.use_distributed_optimizer:
assert args.DDP_impl == 'local'
assert args.use_contiguous_buffers_in_local_ddp
# For torch DDP, we do not use contiguous buffer
if args.DDP_impl == 'torch':
args.use_contiguous_buffers_in_local_ddp = False
if args.dataloader_type is None:
args.dataloader_type = 'single'
# Consumed tokens.
args.consumed_train_samples = 0
args.consumed_valid_samples = 0
# Support for variable sequence lengths across batches/microbatches.
# set it if the dataloader supports generation of variable sequence lengths
# across batches/microbatches. Due to additional communication overhead
# during pipeline parallelism, it should not be set if sequence length
# is constant during training.
args.variable_seq_lengths = False
# Iteration-based training.
if args.train_iters:
# If we use iteration-based training, make sure the
# sample-based options are off.
assert args.train_samples is None, \
'expected iteration-based training'
assert args.lr_decay_samples is None, \
'expected iteration-based learning rate decay'
assert args.lr_warmup_samples == 0, \
'expected iteration-based learning rate warmup'
assert args.rampup_batch_size is None, \
'expected no batch-size rampup for iteration-based training'
if args.lr_warmup_fraction is not None:
assert args.lr_warmup_iters == 0, \
'can only specify one of lr-warmup-fraction and lr-warmup-iters'
# Sample-based training.
if args.train_samples:
# If we use sample-based training, make sure the
# iteration-based options are off.
assert args.train_iters is None, \
'expected sample-based training'
assert args.lr_decay_iters is None, \
'expected sample-based learning rate decay'
assert args.lr_warmup_iters == 0, \
'expected sample-based learnig rate warmup'
if args.lr_warmup_fraction is not None:
assert args.lr_warmup_samples == 0, \
'can only specify one of lr-warmup-fraction ' \
'and lr-warmup-samples'
if args.num_layers is not None:
assert args.encoder_num_layers is None, \
'cannot have both num-layers and encoder-num-layers specified'
args.encoder_num_layers = args.num_layers
else:
assert args.encoder_num_layers is not None, \
'either num-layers or encoder-num-layers should be specified'
args.num_layers = args.encoder_num_layers
# Check required arguments.
required_args = ['num_layers', 'hidden_size', 'num_attention_heads',
'max_position_embeddings']
for req_arg in required_args:
_check_arg_is_not_none(args, req_arg)
# Checks.
if args.ffn_hidden_size is None:
args.ffn_hidden_size = 4 * args.hidden_size
if args.kv_channels is None:
assert args.hidden_size % args.num_attention_heads == 0
args.kv_channels = args.hidden_size // args.num_attention_heads
if args.seq_length is not None:
assert args.encoder_seq_length is None
args.encoder_seq_length = args.seq_length
else:
assert args.encoder_seq_length is not None
args.seq_length = args.encoder_seq_length
if args.seq_length is not None:
assert args.max_position_embeddings >= args.seq_length
if args.decoder_seq_length is not None:
assert args.max_position_embeddings >= args.decoder_seq_length
if args.lr is not None:
assert args.min_lr <= args.lr
if args.save is not None:
assert args.save_interval is not None
# Mixed precision checks.
if args.fp16_lm_cross_entropy:
assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.'
if args.fp32_residual_connection:
assert args.fp16 or args.bf16, \
'residual connection in fp32 only supported when using fp16 or bf16.'
if args.moe_grouped_gemm:
assert args.bf16, 'Currently GroupedGEMM for MoE only supports bf16 dtype.'
dc = torch.cuda.get_device_capability()
assert dc[0] >= 8, "Unsupported compute capability for GroupedGEMM kernels."
if args.weight_decay_incr_style == 'constant':
assert args.start_weight_decay is None
assert args.end_weight_decay is None
args.start_weight_decay = args.weight_decay
args.end_weight_decay = args.weight_decay
else:
assert args.start_weight_decay is not None
assert args.end_weight_decay is not None
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
# Persistent fused layer norm.
if TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 11):
args.no_persist_layer_norm = True
if args.rank == 0:
print('Persistent fused layer norm kernel is supported from '
'pytorch v1.11 (nvidia pytorch container paired with v1.11). '
'Defaulting to no_persist_layer_norm=True')
# Activation recomputing.
if args.distribute_saved_activations:
assert args.tensor_model_parallel_size > 1, 'can distribute ' \
'recomputed activations only across tensor model ' \
'parallel groups'
assert args.recompute_granularity == 'full', \
'distributed recompute activations is only '\
'application to full recompute granularity'
assert args.recompute_method is not None, \
'for distributed recompute activations to work you '\
'need to use a recompute method '
assert TORCH_MAJOR >= 1 and TORCH_MINOR >= 10, \
'distributed recompute activations are supported for pytorch ' \
'v1.10 and above (Nvidia Pytorch container >= 21.07). Current ' \
'pytorch version is v%s.%s.' % (TORCH_MAJOR, TORCH_MINOR)
# Tranformer-Engine/FP8 related checking
if args.fp8_e4m3 or args.fp8_hybrid:
assert args.transformer_impl == 'transformer_engine', \
'transformer-engine required for fp8 training and inference'
assert not (args.fp8_e4m3 and args.fp8_hybrid), \
'cannot train with both fp8 e4m3 and hybrid formatting'
if args.recompute_granularity == 'selective':
assert args.recompute_method is None, \
'recompute method is not yet supported for ' \
'selective recomputing granularity'
# disable sequence parallelism when tp=1
# to avoid change in numerics when
# sequence_parallelism is enabled.
if args.tensor_model_parallel_size == 1:
args.sequence_parallel = False
# disable async_tensor_model_parallel_allreduce when
# model parallel memory optimization is enabled
if args.sequence_parallel:
args.async_tensor_model_parallel_allreduce = False
if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
if args.sequence_parallel:
raise RuntimeError(
"Using sequence parallelism requires setting the environment variable "
"CUDA_DEVICE_MAX_CONNECTIONS to 1")
if args.async_tensor_model_parallel_allreduce:
raise RuntimeError(
"Using async gradient all reduce requires setting the environment "
"variable CUDA_DEVICE_MAX_CONNECTIONS to 1")
# Disable bias gelu fusion if we are disabling bias altogether
if not args.add_bias_linear:
args.bias_gelu_fusion = False
# Retro checks.
if args.retro_add_retriever:
# Sequence parallelism unsupported.
assert not args.sequence_parallel, \
"retro currently does not support sequence parallelism."
# Pipeline parallelism unsupported.
assert args.pipeline_model_parallel_size == 1, \
"retro currently does not support pipeline parallelism."
# Load retro args.
retro_args_path = get_retro_args_path(args.retro_workdir)
assert os.path.exists(retro_args_path), "retro workdir missing args.json"
with open(retro_args_path) as f:
retro_args = types.SimpleNamespace(**json.load(f))
retro_args.retro_return_doc_ids = args.retro_return_doc_ids
retro_args.retro_gpt_retrieved_length = \
args.retro_num_retrieved_chunks * \
retro_args.retro_gpt_chunk_length
set_retro_args(retro_args)
# Legacy RoPE arguments
if args.use_rotary_position_embeddings:
args.position_embedding_type = 'rope'
# Would just need to add 'NoPE' as a position_embedding_type to support this, but for now
# don't allow it to keep things simple
if not args.add_position_embedding and args.position_embedding_type != 'rope':
raise RuntimeError('--no-position-embedding is deprecated, use --position-embedding-type')
# Print arguments.
_print_args("arguments", args)
retro_args = get_retro_args()
if retro_args and args != retro_args:
_print_args("retro arguments", types.SimpleNamespace(**{k:v for k,v in vars(retro_args).items() if k.startswith("retro")}, rank=args.rank))
return args
def _print_args(title, args):
"""Print arguments."""
if args.rank == 0:
print(f'------------------------ {title} ------------------------',
flush=True)
str_list = []
for arg in vars(args):
dots = '.' * (48 - len(arg))
str_list.append(' {} {} {}'.format(arg, dots, getattr(args, arg)))
for arg in sorted(str_list, key=lambda x: x.lower()):
print(arg, flush=True)
print(f'-------------------- end of {title} ---------------------',
flush=True)
def _check_arg_is_not_none(args, arg):
assert getattr(args, arg) is not None, '{} argument is None'.format(arg)
def core_transformer_config_from_args(args):
# Translate args to core transformer configuration
kw_args = {}
for f in dataclasses.fields(TransformerConfig):
if hasattr(args, f.name):
kw_args[f.name] = getattr(args, f.name)
kw_args['persist_layer_norm'] = not args.no_persist_layer_norm
kw_args['layernorm_zero_centered_gamma'] = args.apply_layernorm_1p
kw_args['deallocate_pipeline_outputs'] = True
kw_args['pipeline_dtype'] = args.params_dtype
kw_args['batch_p2p_comm'] = not args.overlap_p2p_comm
kw_args['num_moe_experts'] = args.num_experts
if args.swiglu:
kw_args['activation_func'] = F.silu
kw_args['gated_linear_unit'] = True
kw_args['bias_activation_fusion'] = args.bias_swiglu_fusion
else:
kw_args['bias_activation_fusion'] = args.bias_gelu_fusion
if args.init_method_xavier_uniform:
kw_args['init_method'] = torch.nn.init.xavier_uniform_
kw_args['scaled_init_method'] = torch.nn.init.xavier_uniform_
if args.group_query_attention:
kw_args['num_query_groups'] = args.num_query_groups
else:
kw_args['num_query_groups'] = None
if args.num_experts != None and args.num_shared_experts != None:
assert args.num_shared_experts < args.num_experts, 'args.num_shared_experts must less than args.num_experts'
assert args.seq_length == args.num_attention_router_heads, 'args.seq_length and args.num_attention_router_heads must be equal in the current version'
return TransformerConfig(**kw_args)
def _add_transformer_engine_args(parser):
group = parser.add_argument_group(title='Transformer-Engine')
group.add_argument('--fp8-e4m3', action='store_true',
help='E4M3 TransformerLayer', dest='fp8_e4m3')
group.add_argument('--fp8-hybrid', action='store_true',
help='Hybrid FP8 TransformerLayer', dest='fp8_hybrid')
group.add_argument('--no-fp8-wgrad', action='store_false',
help='Execute wgrad in higher precision even for FP8 runs', dest='fp8_wgrad')
group.add_argument('--fp8-margin', type=int, default=0,
help='Scaling margin for fp8', dest='fp8_margin')
group.add_argument('--fp8-interval', type=int, default=1,
help='Scaling update interval for fp8', dest='fp8_interval')
group.add_argument('--transformer-impl', default='local',
choices=['local', 'transformer_engine'],
help='Which Transformer implementation to use.',
dest='transformer_impl')
group.add_argument('--fp8-amax-history-len', type=int, default=1,
help='Number of steps for which amax history is recorded per tensor',
dest='fp8_amax_history_len')
group.add_argument('--fp8-amax-compute-algo', default='most_recent',
choices=['most_recent', 'max'],
help='Algorithm for computing amax from history',
dest='fp8_amax_compute_algo')
return parser
def _add_inference_args(parser):
group = parser.add_argument_group(title='inference')
group.add_argument('--inference-batch-times-seqlen-threshold',
type=int, default=512,
help='During inference, if batch-size times '
'sequence-length is smaller than this threshold '
'then we will not use pipelining, otherwise we will.')
group.add_argument('--max-tokens-to-oom',
type=int, default=12000,
help='Maximum number of tokens during inference'
'tokens here is # in prompt + # to generate'
'Allows us to throw an error before OOM crashes server')
group.add_argument('--output-bert-embeddings', action='store_true',
help='Output Bert embeddings (via mean pooling) from '
'model, rather than its binary head output or entire '
'hidden batch.')
group.add_argument('--bert-embedder-type', default="megatron",
choices=["megatron", "huggingface"],
help='Select either Megatron or Huggingface as the '
'Bert embedder.')
group.add_argument('--inference-server', action='store_true',
help='whether to use inference-server.')
group.add_argument('--repetition-penalty',
type=float,
default=1.0,
help='Repetition_penalty.')
return parser
def _add_retro_args(parser):
group = parser.add_argument_group(title='retro')
group.add_argument('--retro-workdir', default=None,
help='Retro working directory, which contains the '
'preprocessed data for for pretraining. This directory '
'is built during preprocessing (see '
'tools/retro/README.md), and contains subdirectories '
'for the chunk database and pretraining neighbors.')
group.add_argument('--retro-add-retriever',
action='store_true', default=False,
help='Add a retriever to the transformer, for use in '
'pretraining a Retro model.')
group.add_argument('--retro-cyclic-train-iters', type=int, default=None,
help='Set number of training iterations for cyclic '
'Retro training.')
group.add_argument('--retro-encoder-layers', type=int, default=2,
help='Number of layers to use for the retrieval '
'encoder.')
group.add_argument('--retro-encoder-hidden-dropout',
type=float, default=0.1, help='Hidden dropout for '
'retrieval encoder.')
group.add_argument('--retro-encoder-attention-dropout',
type=float, default=0.1, help='Attention dropout for '
'retrieval encoder.')
group.add_argument("--retro-num-neighbors", type=int, default=2,
help='Number of neighbors to retrieve during '
'pretraining.')
group.add_argument("--retro-num-retrieved-chunks", type=int, default=2,
help='Number of chunks to retrieve from the retrieval '
'database.')
group.add_argument("--retro-return-doc-ids", action="store_true",
help="Turn this on when preprocessing retro data.")
# Enforce argument naming convention.
for action in group._group_actions:
prefix = action.dest.split("_")[0]
assert prefix == "retro", \
"Retro args must be prefixed with '--retro-*', for consistent " \
"styling. Please fix '%s'." % ", ".join(action.option_strings)
return parser
def _add_network_size_args(parser):
group = parser.add_argument_group(title='network size')
group.add_argument('--num-layers', type=int, default=None,
help='Number of transformer layers.')
group.add_argument('--encoder-num-layers', type=int, default=None,
help='Number of encoder transformer layers.')
group.add_argument('--decoder-num-layers', type=int, default=None,
help='Number of decoder transformer layers.')
group.add_argument('--hidden-size', type=int, default=None,
help='Tansformer hidden size.')
group.add_argument('--ffn-hidden-size', type=int, default=None,
help='Transformer Feed-Forward Network hidden size. '
'This is set to 4*hidden-size if not provided')
group.add_argument('--num-attention-heads', type=int, default=None,
help='Number of transformer attention heads.')
group.add_argument('--kv-channels', type=int, default=None,
help='Projection weights dimension in multi-head '
'attention. This is set to '
' args.hidden_size // args.num_attention_heads '
'if not provided.')
group.add_argument('--group-query-attention', action='store_true',
help='Use group-query attention.')
group.add_argument('--num-query-groups', type=int, default=1)
group.add_argument('--max-position-embeddings', type=int, default=None,
help='Maximum number of position embeddings to use. '
'This is the size of position embedding.')
group.add_argument('--position-embedding-type', type=str, default='learned_absolute',
choices=['learned_absolute', 'rope'],
help='Position embedding type.')
group.add_argument('--use-rotary-position-embeddings', action='store_true',
help='Use rotary positional embeddings or not. '
'Deprecated: use --position-embedding-type')
group.add_argument('--rotary-base', type=int, default=10000,
help='base number for rope')
group.add_argument('--rotary-percent', type=float, default=1.0,
help='Percent of rotary dimension to use, default 100%')
group.add_argument('--no-position-embedding',
action='store_false',
help='Disable position embedding. Deprecated: use --position-embedding-type',
dest='add_position_embedding')
group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
help='Pad the vocab size to be divisible by this value.'
'This is added for computational efficieny reasons.')
group.add_argument('--layernorm-epsilon', type=float, default=1e-5,
help='Layer norm epsilon.')
group.add_argument('--apply-layernorm-1p', action='store_true',
help='Adjust LayerNorm weights such that they are centered '
'around zero. This improves numerical stability.')
group.add_argument('--apply-residual-connection-post-layernorm',
action='store_true',
help='If set, use original BERT residula connection '
'ordering.')
group.add_argument('--openai-gelu', action='store_true',
help='Use OpenAIs GeLU implementation. This option'
'should not be used unless for backward compatibility'
'reasons.')
group.add_argument('--squared-relu', action='store_true',
help='Use squared relu activation instead of default gelu')
group.add_argument('--swiglu', action='store_true',
help='Use gated linear units and SiLU activation instead of default gelu')
group.add_argument('--onnx-safe', type=bool, required=False,
help='Use workarounds for known problems with '
'Torch ONNX exporter')
group.add_argument('--bert-no-binary-head', action='store_false',
help='Disable BERT binary head.',
dest='bert_binary_head')
group.add_argument('--untie-embeddings-and-output-weights', action='store_true',
help='Untie embeddings and output weights.'),
group.add_argument('--embedding-weights-in-fp32', action='store_true',
help='Cast word embedding weights to fp32 before embedding fwd.'),
group.add_argument('--no-embedding-dropout', action='store_false',
help='Disable embedding dropout.')
group.add_argument('--use-lf-gate', action='store_true',
help='whether to use Localized Filtering.')
group.add_argument('--lf-conv2d-group', type=int, default=1,
help='the number of Localized Filtering conv2d group.')
group.add_argument('--lf-conv2d-num-pad', type=int, default=1, choices=[0, 1],
help='the number of lf conv2d pad, training pad is 1, eval pad is 0')
return parser
def _add_logging_args(parser):
group = parser.add_argument_group(title='logging')
group.add_argument('--log-params-norm', action='store_true',
help='If set, calculate and log parameters norm.')
group.add_argument('--log-num-zeros-in-grad', action='store_true',
help='If set, calculate and log the number of zeros in gradient.')
group.add_argument('--timing-log-level', type=int,
default=0, choices=range(0,3),
help='Granularity level to measure and report timing. '
' 0: report only iteration time and make sure timing '
' does not introduce extra overhead.'
' 1: report timing for operations that are executed '
' very limited times (basically once) during '
' each iteration (such as gradient all-reduce) '
' 2: report timing for operations that migh be '
' executed numerous times during each iteration. '
'Note that setting the level to 1 or 2 might '
'cause increase in iteration time.')
group.add_argument('--no-barrier-with-level-1-timing', action='store_false',
help='If not set, use barrier with level 1 time '
'measurements. Note that this is up to the user '
'to make sure calling barrier with their timers '
'will not result in hangs. This can happen if for '
'example the user adds a level 1 timer that is not '
'called by all ranks.',
dest='barrier_with_L1_time')
group.add_argument('--timing-log-option', type=str, default='minmax',
choices=['max', 'minmax', 'all'],
help='Options for logging timing:'
' max: report the max timing across all ranks'
' minmax: report min and max timings across all ranks'
' all: report timings of all ranks.')
group.add_argument('--tensorboard-log-interval', type=int, default=1,
help='Report to tensorboard interval.')
group.add_argument('--tensorboard-queue-size', type=int, default=1000,
help='Size of the tensorboard queue for pending events '
'and summaries before one of the ‘add’ calls forces a '
'flush to disk.')
group.add_argument('--log-timers-to-tensorboard', action='store_true',
help='If set, write timers to tensorboard.')
group.add_argument('--log-batch-size-to-tensorboard', action='store_true',
help='If set, write batch-size to tensorboard.')
group.add_argument('--no-log-learnig-rate-to-tensorboard',
action='store_false',
help='Disable learning rate logging to tensorboard.',
dest='log_learning_rate_to_tensorboard')
group.add_argument('--no-log-loss-scale-to-tensorboard',
action='store_false',
help='Disable loss-scale logging to tensorboard.',
dest='log_loss_scale_to_tensorboard')
group.add_argument('--log-validation-ppl-to-tensorboard',
action='store_true',
help='If set, write validation perplexity to '
'tensorboard.')
group.add_argument('--log-memory-to-tensorboard',
action='store_true',
help='Enable memory logging to tensorboard.')
group.add_argument('--log-world-size-to-tensorboard',
action='store_true',
help='Enable world size logging to tensorboard.')
return parser
def _add_regularization_args(parser):
group = parser.add_argument_group(title='regularization')
group.add_argument('--attention-dropout', type=float, default=0.1,
help='Post attention dropout probability.')
group.add_argument('--flash-attn-drop', type=float, default=0.1,
help='flash_attn dropout probability.')
group.add_argument('--hidden-dropout', type=float, default=0.1,
help='Dropout probability for hidden state transformer.')
group.add_argument('--weight-decay', type=float, default=0.01,
help='Weight decay coefficient for L2 regularization.')
group.add_argument('--start-weight-decay', type=float,
help='Initial weight decay coefficient for L2 regularization.')
group.add_argument('--end-weight-decay', type=float,
help='End of run weight decay coefficient for L2 regularization.')
group.add_argument('--weight-decay-incr-style', type=str, default='constant',
choices=['constant', 'linear', 'cosine'],
help='Weight decay increment function.')
group.add_argument('--clip-grad', type=float, default=1.0,
help='Gradient clipping based on global L2 norm.')
group.add_argument('--adam-beta1', type=float, default=0.9,
help='First coefficient for computing running averages '
'of gradient and its square')
group.add_argument('--adam-beta2', type=float, default=0.999,
help='Second coefficient for computing running averages '
'of gradient and its square')
group.add_argument('--adam-eps', type=float, default=1e-08,
help='Term added to the denominator to improve'
'numerical stability')
group.add_argument('--sgd-momentum', type=float, default=0.9,
help='Momentum factor for sgd')
return parser
def _add_training_args(parser):
group = parser.add_argument_group(title='training')
group.add_argument('--micro-batch-size', type=int, default=None,
help='Batch size per model instance (local batch size). '
'Global batch size is local batch size times data '
'parallel size times number of micro batches.')
group.add_argument('--batch-size', type=int, default=None,
help='Old batch size parameter, do not use. '
'Use --micro-batch-size instead')
group.add_argument('--global-batch-size', type=int, default=None,
help='Training batch size. If set, it should be a '
'multiple of micro-batch-size times data-parallel-size. '
'If this value is None, then '
'use micro-batch-size * data-parallel-size as the '
'global batch size. This choice will result in 1 for '
'number of micro-batches.')
group.add_argument('--rampup-batch-size', nargs='*', default=None,
help='Batch size ramp up with the following values:'
' --rampup-batch-size <start batch size> '
' <batch size incerement> '
' <ramp-up samples> '
'For example:'
' --rampup-batch-size 16 8 300000 \ '
' --global-batch-size 1024'
'will start with global batch size 16 and over '
' (1024 - 16) / 8 = 126 intervals will increase'
'the batch size linearly to 1024. In each interval'
'we will use approximately 300000 / 126 = 2380 samples.')
group.add_argument('--recompute-activations', action='store_true',
help='recompute activation to allow for training '
'with larger models, sequences, and batch sizes.')
group.add_argument('--recompute-granularity', type=str, default=None,
choices=['full', 'selective'],
help='Checkpoint activations to allow for training '
'with larger models, sequences, and batch sizes. '
'It is supported at two granularities 1) full: '
'whole transformer layer is recomputed, '
'2) selective: core attention part of the transformer '
'layer is recomputed.')
group.add_argument('--distribute-saved-activations',
action='store_true',
help='If set, distribute recomputed activations '
'across model parallel group.')
group.add_argument('--pipeline-model-parallel-method', type=str, default='uniform',
choices=['uniform', 'block'],
help='1) uniform: uniformly divide the total number of '
'Transformer layers to each pipeline stage, '
'2) custom split transformers to differnt pipeline stages')
group.add_argument('--pipeline-model-parallel-blocks', type=str, default=None,
help='The number of transformer layers specified by the user for each pipeline stage,'
'Separate each number with a comma')
group.add_argument('--recompute-method', type=str, default=None,
choices=['uniform', 'block'],
help='1) uniform: uniformly divide the total number of '
'Transformer layers and recompute the input activation of '
'each divided chunk at specified granularity, '
'2) recompute the input activations of only a set number of '
'individual Transformer layers per pipeline stage and do the '
'rest without any recomputing at specified granularity'
'default) do not apply activations recompute to any layers')
group.add_argument('--recompute-num-layers', type=int, default=1,
help='1) uniform: the number of Transformer layers in each '
'uniformly divided recompute unit, '
'2) block: the number of individual Transformer layers '
'to recompute within each pipeline stage.')
group.add_argument('--profile', action='store_true',
help='Enable nsys profiling. When using this option, nsys '
'options should be specified in commandline. An example '
'nsys commandline is `nsys profile -s none -t nvtx,cuda '
'-o <path/to/output_file> --force-overwrite true '
'--capture-range=cudaProfilerApi '
'--capture-range-end=stop`.')
group.add_argument('--profile-step-start', type=int, default=10,
help='Gloable step to start profiling.')
group.add_argument('--profile-step-end', type=int, default=12,
help='Gloable step to stop profiling.')
group.add_argument('--profile-ranks', nargs='+', type=int, default=[0],
help='Global ranks to profile.')
# deprecated
group.add_argument('--checkpoint-activations', action='store_true',
help='Checkpoint activation to allow for training '
'with larger models, sequences, and batch sizes.')
group.add_argument('--train-iters', type=int, default=None,
help='Total number of iterations to train over all '
'training runs. Note that either train-iters or '
'train-samples should be provided.')
group.add_argument('--train-samples', type=int, default=None,
help='Total number of samples to train over all '
'training runs. Note that either train-iters or '
'train-samples should be provided.')
group.add_argument('--log-interval', type=int, default=100,
help='Report loss and timing interval.')
group.add_argument('--exit-interval', type=int, default=None,
help='Exit the program after the iteration is divisible '
'by this value.')
group.add_argument('--exit-duration-in-mins', type=int, default=None,
help='Exit the program after this many minutes.')
group.add_argument('--exit-signal-handler', action='store_true',
help='Dynamically save the checkpoint and shutdown the '
'training if SIGTERM is received')
group.add_argument('--tensorboard-dir', type=str, default=None,
help='Write TensorBoard logs to this directory.')
group.add_argument('--no-masked-softmax-fusion',
action='store_false',
help='Disable fusion of query_key_value scaling, '
'masking, and softmax.',
dest='masked_softmax_fusion')
group.add_argument('--no-bias-gelu-fusion', action='store_false',
help='Disable bias and gelu fusion.',
dest='bias_gelu_fusion')
group.add_argument('--no-bias-swiglu-fusion', action='store_false',
help='Disable bias and swiglu fusion, the fusion is '
'available only when using megatron-core.',
dest='bias_swiglu_fusion')
group.add_argument('--no-bias-dropout-fusion', action='store_false',
help='Disable bias and dropout fusion.',
dest='bias_dropout_fusion')
group.add_argument('--use-flash-attn', action='store_true',
help='use FlashAttention implementation of attention. '
'https://arxiv.org/abs/2205.14135')
group.add_argument('--disable-bias-linear', action='store_false',
help='Disable bias in the linear layers',
dest='add_bias_linear')
group.add_argument('--optimizer', type=str, default='adam',
choices=['adam', 'sgd'],
help='Optimizer function')
group.add_argument('--dataloader-type', type=str, default=None,
choices=['single', 'cyclic'],
help='Single pass vs multiple pass data loader')
group.add_argument('--no-async-tensor-model-parallel-allreduce',
action='store_false',
help='Disable asynchronous execution of '
'tensor-model-parallel all-reduce with weight '
'gradient compuation of a column-linear layer.',
dest='async_tensor_model_parallel_allreduce')
group.add_argument('--no-persist-layer-norm', action='store_true',
help='Disable using persistent fused layer norm kernel. '
'This kernel supports only a set of hidden sizes. Please '
'check persist_ln_hidden_sizes if your hidden '
'size is supported.')
group.add_argument('--sequence-parallel', action='store_true',
help='Enable sequence parallel optimization.')
group.add_argument('--no-gradient-accumulation-fusion',
action='store_false',
help='Disable fusing gradient accumulation to weight '
'gradient computation of linear layers',
dest='gradient_accumulation_fusion')
return parser
def _add_initialization_args(parser):
group = parser.add_argument_group(title='initialization')
group.add_argument('--seed', type=int, default=1234,
help='Random seed used for python, numpy, '
'pytorch, and cuda.')
group.add_argument('--data-parallel-random-init', action='store_true',
help='Enable random initialization of params '
'across data parallel ranks')
group.add_argument('--init-method-std', type=float, default=0.02,
help='Standard deviation of the zero mean normal '
'distribution used for weight initialization.')
group.add_argument('--init-method-xavier-uniform', action='store_true',
help='Enable Xavier uniform parameter initialization')
return parser
def _add_learning_rate_args(parser):
group = parser.add_argument_group(title='learning rate')
group.add_argument('--lr', type=float, default=None,
help='Initial learning rate. Depending on decay style '
'and initial warmup, the learing rate at each '
'iteration would be different.')
group.add_argument('--lr-decay-style', type=str, default='linear',
choices=['constant', 'linear', 'cosine', 'inverse-square-root'],
help='Learning rate decay function.')
group.add_argument('--lr-decay-iters', type=int, default=None,
help='number of iterations to decay learning rate over,'
' If None defaults to `--train-iters`')
group.add_argument('--lr-decay-samples', type=int, default=None,
help='number of samples to decay learning rate over,'
' If None defaults to `--train-samples`')
group.add_argument('--lr-warmup-fraction', type=float, default=None,
help='fraction of lr-warmup-(iters/samples) to use '
'for warmup (as a float)')
group.add_argument('--lr-warmup-iters', type=int, default=0,
help='number of iterations to linearly warmup '
'learning rate over.')
group.add_argument('--lr-warmup-samples', type=int, default=0,
help='number of samples to linearly warmup '
'learning rate over.')
group.add_argument('--warmup', type=int, default=None,
help='Old lr warmup argument, do not use. Use one of the'
'--lr-warmup-* arguments above')
group.add_argument('--min-lr', type=float, default=0.0,
help='Minumum value for learning rate. The scheduler'
'clip values below this threshold.')
group.add_argument('--override-opt-param-scheduler', action='store_true',
help='Reset the values of the scheduler (learning rate,'
'warmup iterations, minimum learning rate, maximum '
'number of iterations, and decay style from input '
'arguments and ignore values from checkpoints. Note'
'that all the above values will be reset.')
group.add_argument('--use-checkpoint-opt_param-scheduler', action='store_true',
help='Use checkpoint to set the values of the scheduler '
'(learning rate, warmup iterations, minimum learning '
'rate, maximum number of iterations, and decay style '
'from checkpoint and ignore input arguments.')
group.add_argument('--process-checkpoint', action='store_true', default=None,
help='This parameter sets device=None when processing checkpoint.')
group.add_argument('--lr-2nd-period-scaler', type=float, default=1,
help='scale learning rate in 2nd period')
return parser
def _add_checkpointing_args(parser):
group = parser.add_argument_group(title='checkpointing')
group.add_argument('--save', type=str, default=None,
help='Output directory to save checkpoints to.')
group.add_argument('--save-interval', type=int, default=None,
help='Number of iterations between checkpoint saves.')
group.add_argument('--no-save-optim', action='store_true', default=None,
help='Do not save current optimizer.')
group.add_argument('--no-save-rng', action='store_true', default=None,
help='Do not save current rng state.')
group.add_argument('--load', type=str, default=None,
help='Directory containing a model checkpoint.')
group.add_argument('--no-load-optim', action='store_true', default=None,
help='Do not load optimizer when loading checkpoint.')
group.add_argument('--no-load-args', action='store_true', default=None,
help='Do not load args when loading checkpoint.')
group.add_argument('--no-load-rng', action='store_true', default=None,
help='Do not load rng state when loading checkpoint.')
group.add_argument('--train-reset', action='store_true', default=None,
help='set itertion to 0, dataset consumed samples to 0.')
group.add_argument('--finetune', action='store_true',
help='Load model for finetuning. Do not load optimizer '
'or rng state from checkpoint and set iteration to 0. '
'Assumed when loading a release checkpoint.')
group.add_argument('--no-initialization', action='store_false',
help='Do not perform initialization when building model, '
'can reduce startup time when definitely loading from a '
'checkpoint',
dest='perform_initialization')
group.add_argument('--use-checkpoint-args', action='store_true',
help='Override any command line arguments with arguments '
'from the checkpoint')
group.add_argument('--exit-on-missing-checkpoint', action='store_true',
help="If '--load' is set, but checkpoint is not found "
"(e.g., path typo), then exit instead of random "
"initialization.")
return parser
def _add_mixed_precision_args(parser):
group = parser.add_argument_group(title='mixed precision')
group.add_argument('--fp16', action='store_true',
help='Run model in fp16 mode.')
group.add_argument('--bf16', action='store_true',
help='Run model in bfloat16 mode.')
group.add_argument('--loss-scale', type=float, default=None,
help='Static loss scaling, positive power of 2 '
'values can improve fp16 convergence. If None, dynamic'
'loss scaling is used.')
group.add_argument('--initial-loss-scale', type=float, default=2**32,
help='Initial loss-scale for dynamic loss scaling.')
group.add_argument('--min-loss-scale', type=float, default=1.0,
help='Minimum loss scale for dynamic loss scale.')
group.add_argument('--loss-scale-window', type=float, default=1000,
help='Window over which to raise/lower dynamic scale.')
group.add_argument('--hysteresis', type=int, default=2,
help='hysteresis for dynamic loss scaling')
group.add_argument('--fp32-residual-connection', action='store_true',
help='Move residual connections to fp32.')
group.add_argument('--no-query-key-layer-scaling', action='store_false',
help='Do not scale Q * K^T by 1 / layer-number.',
dest='apply_query_key_layer_scaling')
group.add_argument('--attention-softmax-in-fp32', action='store_true',
help='Run attention masking and softmax in fp32. '
'This flag is ignored unless '
'--no-query-key-layer-scaling is specified.')
group.add_argument('--accumulate-allreduce-grads-in-fp32',
action='store_true',
help='Gradient accumulation and all-reduce in fp32.')
group.add_argument('--fp16-lm-cross-entropy', action='store_true',
help='Move the cross entropy unreduced loss calculation'
'for lm head to fp16.')
return parser
def _add_distributed_args(parser):
group = parser.add_argument_group(title='distributed')
group.add_argument('--tensor-model-parallel-size', type=int, default=1,
help='Degree of tensor model parallelism.')
group.add_argument('--pipeline-model-parallel-size', type=int, default=1,
help='Degree of pipeline model parallelism.')
group.add_argument('--pipeline-model-parallel-split-rank',
type=int, default=None,
help='Rank where encoder and decoder should be split.')
group.add_argument('--model-parallel-size', type=int, default=None,
help='Old model parallel argument, do not use. Use '
'--tensor-model-parallel-size instead.')
group.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None,
help='Number of layers per virtual pipeline stage')
group.add_argument('--overlap-p2p-communication',
action='store_true',
help='overlap pipeline parallel communication with forward and backward chunks',
dest='overlap_p2p_comm')
group.add_argument('--distributed-backend', default='nccl',
choices=['nccl', 'gloo'],
help='Which backend to use for distributed training.')
group.add_argument('--distributed-timeout-minutes', type=int, default=10,
help='Timeout minutes for torch.distributed.')
group.add_argument('--DDP-impl', default='local',
choices=['local', 'torch'],
help='which DistributedDataParallel implementation '
'to use.')
group.add_argument('--no-contiguous-buffers-in-local-ddp',
action='store_false', help='If set, dont use '
'contiguous buffer in local DDP.',
dest='use_contiguous_buffers_in_local_ddp')
group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false',
help='Use scatter/gather to optimize communication of tensors in pipeline',
dest='scatter_gather_tensors_in_pipeline')
group.add_argument('--use-ring-exchange-p2p', action='store_true',
default=False, help='If set, use custom-built ring exchange '
'for p2p communications. Note that this option will require '
'a custom built image that support ring-exchange p2p.')
group.add_argument('--local_rank', type=int, default=None,
help='local rank passed from distributed launcher.')
group.add_argument('--lazy-mpu-init', type=bool, required=False,
help='If set to True, initialize_megatron() '
'skips DDP initialization and returns function to '
'complete it instead.Also turns on '
'--use-cpu-initialization flag. This is for '
'external DDP manager.' )
group.add_argument('--use-cpu-initialization', action='store_true',
default=None, help='If set, affine parallel weights '
'initialization uses CPU' )
group.add_argument('--empty-unused-memory-level', default=0, type=int,
choices=[0, 1, 2],
help='Call torch.cuda.empty_cache() each iteration '
'(training and eval), to reduce fragmentation.'
'0=off, 1=moderate, 2=aggressive.')
group.add_argument('--standalone-embedding-stage', action='store_true',
default=False, help='If set, *input* embedding layer '
'is placed on its own pipeline stage, without any '
'transformer layers. (For T5, this flag currently only '
'affects the encoder embedding.)')
group.add_argument('--use-distributed-optimizer', action='store_true',
help='Use distributed optimizer.')
return parser
def _add_validation_args(parser):
group = parser.add_argument_group(title='validation')
group.add_argument('--eval-iters', type=int, default=100,
help='Number of iterations to run for evaluation'
'validation/test for.')
group.add_argument('--eval-interval', type=int, default=1000,
help='Interval between running evaluation on '
'validation set.')
group.add_argument('--skip-train', action='store_true',
default=False, help='If set, bypass the training loop, '
'optionally do evaluation for validation/test, and exit.')
return parser
def _add_data_args(parser):
group = parser.add_argument_group(title='data and dataloader')
group.add_argument('--data-path', nargs='*', default=None,
help='Path to the training dataset. Accepted format:'
'1) a single data path, 2) multiple datasets in the'
'form: dataset1-weight dataset1-path dataset2-weight '
'dataset2-path ... It is used with --split when a '
'single dataset used for all three: train, valid '
'and test. It is exclusive to the other '
'--*-data-path args')
group.add_argument('--split', type=str, default='969, 30, 1',
help='Comma-separated list of proportions for training,'
' validation, and test split. For example the split '
'`90,5,5` will use 90%% of data for training, 5%% for '
'validation and 5%% for test.')
group.add_argument('--train-data-path', nargs='*', default=None,
help='Path to the training dataset. Accepted format:'
'1) a single data path, 2) multiple datasets in the'
'form: dataset1-weight dataset1-path dataset2-weight '
'dataset2-path ...')
group.add_argument('--valid-data-path', nargs='*', default=None,
help='Path to the validation dataset. Accepted format:'
'1) a single data path, 2) multiple datasets in the'
'form: dataset1-weight dataset1-path dataset2-weight '
'dataset2-path ...')
group.add_argument('--test-data-path', nargs='*', default=None,
help='Path to the test dataset. Accepted format:'
'1) a single data path, 2) multiple datasets in the'
'form: dataset1-weight dataset1-path dataset2-weight '
'dataset2-path ...')
group.add_argument('--data-cache-path', default=None,
help='Path to a directory to hold cached index files.')
group.add_argument('--vocab-size', type=int, default=None,
help='Size of vocab before EOD or padding.')
group.add_argument('--vocab-file', type=str, default=None,
help='Path to the vocab file.')
group.add_argument('--merge-file', type=str, default=None,
help='Path to the BPE merge file.')
group.add_argument('--vocab-extra-ids', type=int, default=0,
help='Number of additional vocabulary tokens. '
'They are used for span masking in the T5 model')
group.add_argument('--seq-length', type=int, default=None,
help='Maximum sequence length to process.')
group.add_argument('--encoder-seq-length', type=int, default=None,
help='Maximum encoder sequence length to process.'
'This should be exclusive of --seq-length')
group.add_argument('--decoder-seq-length', type=int, default=None,
help="Maximum decoder sequence length to process.")
group.add_argument('--retriever-seq-length', type=int, default=256,
help='Maximum sequence length for the biencoder model '
'for retriever')
group.add_argument('--sample-rate', type=float, default=1.0,
help='sample rate for training data. Supposed to be 0 '
' < sample_rate < 1')
group.add_argument('--mask-prob', type=float, default=0.15,
help='Probability of replacing a token with mask.')
group.add_argument('--fim-rate', type=float, default=0.0,
help='Fill in the middle rate.')
group.add_argument('--fim-spm-rate', type=float, default=0.5,
help='Fill in the middle mode: SPM, the rate of SPM')
group.add_argument('--norm-dtype', type=str,
default='RMSNorm',
choices=['LayerNorm',
'RMSNorm'],
help='the type of norm')
group.add_argument('--short-seq-prob', type=float, default=0.1,
help='Probability of producing a short sequence.')
group.add_argument('--mmap-warmup', action='store_true',
help='Warm up mmap files.')
group.add_argument('--num-workers', type=int, default=2,
help="Dataloader number of workers.")
group.add_argument('--tokenizer-type', type=str,
default=None,
choices=['BertWordPieceLowerCase',
'BertWordPieceCase',
'GPT2BPETokenizer',
'SentencePieceTokenizer',
'GPTSentencePieceTokenizer',
'EncDecTokenizer',
'YuanTokenizer',
'NullTokenizer'],
help='What type of tokenizer to use.')
group.add_argument('--tokenizer-model', type=str, default=None,
help='Sentencepiece tokenizer model.')
group.add_argument('--tokenizer-model-path', type=str, default=None,
help='the path of tokenizer model.')
group.add_argument('--data-impl', type=str, default='infer',
choices=['mmap', 'infer'],
help='Implementation of indexed datasets.')
group.add_argument('--reset-position-ids', action='store_true',
help='Reset posistion ids after end-of-document token.')
group.add_argument('--reset-attention-mask', action='store_true',
help='Reset self attention maske after '
'end-of-document token.')
group.add_argument('--eod-mask-loss', action='store_true',
help='Mask loss for the end of document tokens.')
group.add_argument('--sft-stage', action='store_true',
help='Whether it is in the SFT stage.')
return parser
def _add_autoresume_args(parser):
group = parser.add_argument_group(title='autoresume')
group.add_argument('--adlr-autoresume', action='store_true',
help='Enable autoresume on adlr cluster.')
group.add_argument('--adlr-autoresume-interval', type=int, default=1000,
help='Intervals over which check for autoresume'
'termination signal')
return parser
def _add_biencoder_args(parser):
group = parser.add_argument_group(title='biencoder')
# network size
group.add_argument('--ict-head-size', type=int, default=None,
help='Size of block embeddings to be used in ICT and '
'REALM (paper default: 128)')
group.add_argument('--biencoder-projection-dim', type=int, default=0,
help='Size of projection head used in biencoder (paper'
' default: 128)')
group.add_argument('--biencoder-shared-query-context-model', action='store_true',
help='Whether to share the parameters of the query '
'and context models or not')
# checkpointing
group.add_argument('--ict-load', type=str, default=None,
help='Directory containing an ICTBertModel checkpoint')
group.add_argument('--bert-load', type=str, default=None,
help='Directory containing an BertModel checkpoint '
'(needed to start ICT and REALM)')
# data
group.add_argument('--titles-data-path', type=str, default=None,
help='Path to titles dataset used for ICT')
group.add_argument('--query-in-block-prob', type=float, default=0.1,
help='Probability of keeping query in block for '
'ICT dataset')
group.add_argument('--use-one-sent-docs', action='store_true',
help='Whether to use one sentence documents in ICT')
group.add_argument('--evidence-data-path', type=str, default=None,
help='Path to Wikipedia Evidence frm DPR paper')
# training
group.add_argument('--retriever-report-topk-accuracies', nargs='+', type=int,
default=[], help="Which top-k accuracies to report "
"(e.g. '1 5 20')")
group.add_argument('--retriever-score-scaling', action='store_true',
help='Whether to scale retriever scores by inverse '
'square root of hidden size')
# faiss index
group.add_argument('--block-data-path', type=str, default=None,
help='Where to save/load BlockData to/from')
group.add_argument('--embedding-path', type=str, default=None,
help='Where to save/load Open-Retrieval Embedding'
' data to/from')
# indexer
group.add_argument('--indexer-batch-size', type=int, default=128,
help='How large of batches to use when doing indexing '
'jobs')
group.add_argument('--indexer-log-interval', type=int, default=1000,
help='After how many batches should the indexer '
'report progress')
return parser
def _add_vision_args(parser):
group = parser.add_argument_group(title="vision")
# general vision arguements
group.add_argument('--num-classes', type=int, default=1000,
help='num of classes in vision classificaiton task')
group.add_argument('--img-h', type=int, default=224,
help='Image height for vision classification task')
group.add_argument('--img-w', type=int, default=224,
help='Image height for vision classification task')
group.add_argument('--num-channels', type=int, default=3,
help='Number of channels in input image data')
group.add_argument('--patch-dim', type=int, default=16,
help='patch dimension')
group.add_argument('--classes-fraction', type=float, default=1.0,
help='training with fraction of classes.')
group.add_argument('--data-per-class-fraction', type=float, default=1.0,
help='training with fraction of data per class.')
group.add_argument('--no-data-sharding', action='store_false',
help='Disable data sharding.',
dest='data_sharding')
group.add_argument('--head-lr-mult', type=float, default=1.0,
help='learning rate multiplier for head during finetuning')
# pretraining type and backbone selection`
group.add_argument('--vision-pretraining', action='store_true',
help='flag to indicate vision pretraining')
group.add_argument('--vision-pretraining-type', type=str, default='classify',
choices=['classify', 'inpaint', 'dino'],
help='pretraining objectives')
group.add_argument('--vision-backbone-type', type=str, default='vit',
choices=['vit', 'mit', 'swin'],
help='backbone types types')
group.add_argument('--swin-backbone-type', type=str, default='tiny',
choices=['tiny', 'base', 'h3'],
help='pretraining objectives')
# inpainting arguments
group.add_argument('--mask-type', type=str, default='random',
choices=['random', 'row'],
help='mask types')
group.add_argument('--mask-factor', type=float, default=1.0,
help='mask size scaling parameter')
# dino arguments
group.add_argument('--iter-per-epoch', type=int, default=1250,
help='iterations per epoch')
group.add_argument('--dino-local-img-size', type=int, default=96,
help='Image size for vision classification task')
group.add_argument('--dino-local-crops-number', type=int, default=10,
help='Number of local crops')
group.add_argument('--dino-head-hidden-size', type=int, default=2048,
help='Hidden dimension size in dino head')
group.add_argument('--dino-bottleneck-size', type=int, default=256,
help='Bottle neck dimension in dino head ')
group.add_argument('--dino-freeze-last-layer', type=float, default=1,
help='Freezing last layer weights')
group.add_argument('--dino-norm-last-layer', action='store_true',
help='Disable Norm in last layer.')
group.add_argument('--dino-warmup-teacher-temp', type=float, default=0.04,
help='warump teacher temperature')
group.add_argument('--dino-teacher-temp', type=float, default=0.07,
help='teacher temperature')
group.add_argument('--dino-warmup-teacher-temp-epochs', type=int, default=30,
help='warmup teacher temperaure epochs')
return parser
def _add_moe_args(parser):
group = parser.add_argument_group(title="moe")
group.add_argument('--expert-model-parallel-size', type=int, default=1,
help='Degree of expert model parallelism.The current version does not support expert parallel methods')
group.add_argument('--num-attention-router-heads', type=int, default=4096,
help='The number of headers in the attention router method.It is currently consistent with seq length')
group.add_argument('--num-experts', type=int, default=None,
help='Number of Experts in MoE (None means no MoE)')
group.add_argument('--num-shared-experts', type=int, default=None,
help='The number of shared experts in the moe method')
group.add_argument('--isolate-shared-experts', action='store_true',
help='Use isolate Shared Experts in moe.Using this option divides shared experts into more fine-grained experts.ffn_hidden_size = ffn_hidden_size / num-shared-experts')
group.add_argument('--use-fp32-router', action='store_true',
help='Use fp32 in attention_router way')
group.add_argument('--use-attention-router', action='store_true',
help='use attention-router way')
group.add_argument('--moe-router-load-balancing-type', type=str,
choices=['aux_loss', 'sinkhorn', "none"],
default='aux_loss',
help='Determines the load balancing strategy for the router. "aux_loss" corresponds to the load balancing loss used in GShard and SwitchTransformer, "sinkhorn" corresponds to the balancing algorithm used in S-BASE, and "none" implies no load balancing. The default is "aux_loss".')
group.add_argument('--moe-router-topk', type=int, default=2,
help='Number of experts to route to for each token. The default is 2.')
group.add_argument('--moe-grouped-gemm', action='store_true',
help='When there are multiple experts per rank, compress multiple local (potentially small) gemms in a single kernel launch to improve the utilization and performance by leveraging the Grouped GEMM feature introduced since CUTLASS 2.8 (https://github.com/fanshiqing/grouped_gemm).')
group.add_argument('--moe-aux-loss-coeff', type=float, default=0.0,
help='Scaling coefficient for the aux loss: a starting value of 1e-2 is recommended.')
group.add_argument('--moe-z-loss-coeff', type=float, default=None,
help='Scaling coefficient for the z-loss: a starting value of 1e-3 is recommended.')
group.add_argument('--moe-input-jitter-eps', type=float, default=None,
help='Add noise to the input tensor by applying jitter with a specified epsilon value.')
group.add_argument('--moe-token-dropping', action='store_true',
help='This feature involves selectively dropping and padding tokens for each expert to achieve a specified capacity, similar to GShard, Switch-Transformer, and DeepSpeed-MoE. Note: Currently unsupported.')
return parser
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
"""Input/output checkpointing."""
import os
import random
import sys
import numpy as np
import torch
from megatron import update_num_microbatches
from megatron.core import mpu, tensor_parallel
from .global_vars import get_args
from .utils import (unwrap_model,
print_rank_0)
_CHECKPOINT_VERSION = None
def set_checkpoint_version(value):
global _CHECKPOINT_VERSION
if _CHECKPOINT_VERSION is not None:
assert _CHECKPOINT_VERSION == value, \
"checkpoint versions do not match"
_CHECKPOINT_VERSION = value
def get_checkpoint_version():
global _CHECKPOINT_VERSION
return _CHECKPOINT_VERSION
def check_checkpoint_args(checkpoint_args):
"""Ensure fixed arguments for a model are the same for the input
arguments and the one retrieved from checkpoint."""
args = get_args()
def _compare(arg_name, old_arg_name=None, default=None):
if old_arg_name is not None:
ckpt_arg_name = old_arg_name
else:
ckpt_arg_name = arg_name
if default is not None:
checkpoint_value = getattr(checkpoint_args, ckpt_arg_name, default)
else:
checkpoint_value = getattr(checkpoint_args, ckpt_arg_name)
args_value = getattr(args, arg_name)
error_message = '{} value from checkpoint ({}) is not equal to the ' \
'input argument value ({}).'.format(
arg_name, checkpoint_value, args_value)
assert checkpoint_value == args_value, error_message
_compare('num_layers')
_compare('hidden_size')
_compare('num_attention_heads')
_compare('add_position_embedding', default=True)
if args.vocab_file:
_compare('max_position_embeddings')
_compare('make_vocab_size_divisible_by')
_compare('padded_vocab_size')
_compare('tokenizer_type')
if args.data_parallel_random_init:
_compare('data_parallel_random_init')
if get_checkpoint_version() < 3.0:
_compare('tensor_model_parallel_size',
old_arg_name='model_parallel_size')
if get_checkpoint_version() >= 3.0:
_compare('tensor_model_parallel_size')
_compare('pipeline_model_parallel_size')
def ensure_directory_exists(filename):
"""Build filename's path if it does not already exists."""
dirname = os.path.dirname(filename)
os.makedirs(dirname, exist_ok = True)
def get_checkpoint_name(checkpoints_path, iteration, release=False,
pipeline_parallel=None,
tensor_rank=None, pipeline_rank=None):
"""Determine the directory name for this rank's checkpoint."""
if release:
directory = 'release'
else:
directory = 'iter_{:07d}'.format(iteration)
# Use both the tensor and pipeline MP rank.
if pipeline_parallel is None:
pipeline_parallel = (mpu.get_pipeline_model_parallel_world_size() > 1)
if tensor_rank is None:
tensor_rank = mpu.get_tensor_model_parallel_rank()
if pipeline_rank is None:
pipeline_rank = mpu.get_pipeline_model_parallel_rank()
# Use both the tensor and pipeline MP rank. If using the distributed
# optimizer, then the optimizer's path must additionally include the
# data parallel rank.
if not pipeline_parallel:
common_path = os.path.join(checkpoints_path, directory,
f'mp_rank_{tensor_rank:02d}')
else:
common_path = os.path.join(checkpoints_path, directory,
f'mp_rank_{tensor_rank:02d}_{pipeline_rank:03d}')
return os.path.join(common_path, "model_optim_rng.pt")
def get_distributed_optimizer_checkpoint_name(model_checkpoint_name):
return os.path.join(os.path.dirname(model_checkpoint_name),
"distrib_optim.pt")
def find_checkpoint_rank_0(checkpoints_path, iteration, release=False):
"""Finds the checkpoint for rank 0 without knowing if we are using
pipeline parallelism or not.
Since the checkpoint naming scheme changes if pipeline parallelism
is present, we need to look for both naming schemes if we don't
know if the checkpoint has pipeline parallelism.
"""
# Look for checkpoint with no pipelining
filename = get_checkpoint_name(checkpoints_path, iteration, release,
pipeline_parallel=False,
tensor_rank=0, pipeline_rank=0)
if os.path.isfile(filename):
return filename
# Look for checkpoint with pipelining
filename = get_checkpoint_name(checkpoints_path, iteration, release,
pipeline_parallel=True,
tensor_rank=0, pipeline_rank=0)
if os.path.isfile(filename):
return filename
return None, None
def get_checkpoint_tracker_filename(checkpoints_path):
"""Tracker file rescords the latest chckpoint during
training to restart from."""
return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt')
def read_metadata(tracker_filename):
# Read the tracker file and either set the iteration or
# mark it as a release checkpoint.
iteration = 0
release = False
with open(tracker_filename, 'r') as f:
metastring = f.read().strip()
try:
iteration = int(metastring)
except ValueError:
release = metastring == 'release'
if not release:
print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format(
tracker_filename))
sys.exit()
assert iteration > 0 or release, 'error parsing metadata file {}'.format(
tracker_filename)
# Get the max iteration retrieved across the ranks.
if torch.distributed.is_initialized():
iters_cuda = torch.cuda.LongTensor([iteration])
torch.distributed.all_reduce(iters_cuda, op=torch.distributed.ReduceOp.MAX)
max_iter = iters_cuda[0].item()
# We should now have all the same iteration.
# If not, print a warning and chose the maximum
# iteration across all ranks.
if iteration != max_iter:
rank = torch.distributed.get_rank()
print('WARNING: on rank {} found iteration {} in the '
'metadata while max iteration across the ranks '
'is {}, replacing it with max iteration.'.format(
rank, iteration, max_iter), flush=True)
else:
# When loading a checkpoint outside of training (for example,
# when editing it), we might not have torch distributed
# initialized, in this case, just assume we have the latest
max_iter = iteration
return max_iter, release
def get_rng_state():
""" collect rng state across data parallel ranks """
args = get_args()
rng_state = {
'random_rng_state': random.getstate(),
'np_rng_state': np.random.get_state(),
'torch_rng_state': torch.get_rng_state(),
'cuda_rng_state': torch.cuda.get_rng_state(),
'rng_tracker_states': tensor_parallel.get_cuda_rng_tracker().get_states()}
rng_state_list = None
if torch.distributed.is_initialized() and \
mpu.get_data_parallel_world_size() > 1 and \
args.data_parallel_random_init:
rng_state_list = \
[None for i in range(mpu.get_data_parallel_world_size())]
torch.distributed.all_gather_object(
rng_state_list,
rng_state,
group=mpu.get_data_parallel_group())
else:
rng_state_list = [rng_state]
return rng_state_list
def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
"""Save a model checkpoint."""
args = get_args()
# Only rank zero of the data parallel writes to the disk.
model = unwrap_model(model)
print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
iteration, args.save))
# Collect rng state across data parallel ranks.
rng_state = get_rng_state()
# Checkpoint name.
checkpoint_name = get_checkpoint_name(args.save, iteration)
# Save distributed optimizer's custom parameter state.
if args.use_distributed_optimizer:
optim_checkpoint_name = \
get_distributed_optimizer_checkpoint_name(checkpoint_name)
ensure_directory_exists(optim_checkpoint_name)
optimizer.save_parameter_state(optim_checkpoint_name)
# Collect args, model, RNG.
if not torch.distributed.is_initialized() \
or mpu.get_data_parallel_rank() == 0:
# Arguments, iteration, and model.
state_dict = {}
state_dict['args'] = args
state_dict['checkpoint_version'] = 3.0
state_dict['iteration'] = iteration
if len(model) == 1:
state_dict['model'] = model[0].state_dict_for_save_checkpoint()
else:
for i in range(len(model)):
mpu.set_virtual_pipeline_model_parallel_rank(i)
state_dict['model%d' % i] = \
model[i].state_dict_for_save_checkpoint()
# Optimizer stuff.
if not args.no_save_optim:
if optimizer is not None:
state_dict['optimizer'] = optimizer.state_dict()
if opt_param_scheduler is not None:
state_dict['opt_param_scheduler'] = \
opt_param_scheduler.state_dict()
# RNG states.
if not args.no_save_rng:
state_dict["rng_state"] = rng_state
# Save.
ensure_directory_exists(checkpoint_name)
torch.save(state_dict, checkpoint_name)
# Wait so everyone is done (necessary)
if torch.distributed.is_initialized():
torch.distributed.barrier()
print_rank_0(' successfully saved checkpoint at iteration {:7d} to {}' \
.format(iteration, args.save))
# And update the latest iteration
if not torch.distributed.is_initialized() \
or torch.distributed.get_rank() == 0:
tracker_filename = get_checkpoint_tracker_filename(args.save)
with open(tracker_filename, 'w') as f:
f.write(str(iteration))
# Wait so everyone is done (not necessary)
if torch.distributed.is_initialized():
torch.distributed.barrier()
def _transpose_first_dim(t, num_splits, num_splits_first, model):
input_shape = t.size()
# We use a self_attention module but the values extracted aren't
# specific to self attention so should work for cross attention as well
while hasattr(model, 'module'):
model = model.module
attention_module = model.language_model.encoder.layers[0].self_attention
hidden_size_per_attention_head = attention_module.hidden_size_per_attention_head
num_attention_heads_per_partition = attention_module.num_attention_heads_per_partition
if num_splits_first:
"""[num_splits * np * hn, h]
-->(view) [num_splits, np, hn, h]
-->(tranpose) [np, num_splits, hn, h]
-->(view) [np * num_splits * hn, h] """
intermediate_shape = \
(num_splits, num_attention_heads_per_partition,
hidden_size_per_attention_head) + input_shape[1:]
t = t.view(*intermediate_shape)
t = t.transpose(0, 1).contiguous()
else:
"""[np * hn * num_splits, h]
-->(view) [np, hn, num_splits, h]
-->(tranpose) [np, num_splits, hn, h]
-->(view) [np * num_splits * hn, h] """
intermediate_shape = \
(num_attention_heads_per_partition,
hidden_size_per_attention_head, num_splits) +\
input_shape[1:]
t = t.view(*intermediate_shape)
t = t.transpose(1, 2).contiguous()
t = t.view(*input_shape)
return t
def fix_query_key_value_ordering(model, checkpoint_version):
"""Fix up query/key/value matrix ordering if checkpoint
version is smaller than 2.0
"""
if checkpoint_version < 2.0:
if isinstance(model, list):
assert len(model)==1
model = model[0]
for name, param in model.named_parameters():
if name.endswith(('.query_key_value.weight', '.query_key_value.bias')):
if checkpoint_version == 0:
fixed_param = _transpose_first_dim(param.data, 3, True, model)
elif checkpoint_version == 1.0:
fixed_param = _transpose_first_dim(param.data, 3, False, model)
else:
print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
sys.exit()
param.data.copy_(fixed_param)
if name.endswith(('.key_value.weight', '.key_value.bias')):
if checkpoint_version == 0:
fixed_param = _transpose_first_dim(param.data, 2, True, model)
elif checkpoint_version == 1.0:
fixed_param = _transpose_first_dim(param.data, 2, False, model)
else:
print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
sys.exit()
param.data.copy_(fixed_param)
print_rank_0(" succesfully fixed query-key-values ordering for"
" checkpoint version {}".format(checkpoint_version))
def _load_base_checkpoint(load_dir, rank0=False):
""" Load the base state_dict from the given directory
If rank0 is true, just loads rank 0 checkpoint, ignoring arguments.
"""
# Read the tracker file and set the iteration.
tracker_filename = get_checkpoint_tracker_filename(load_dir)
# If no tracker file, return nothing
if not os.path.isfile(tracker_filename):
if not rank0:
print_rank_0('WARNING: could not find the metadata file {} '.format(
tracker_filename))
print_rank_0(' will not load any checkpoints and will start from '
'random')
return None, "", False
# Otherwise, read the tracker file and either set the iteration or
# mark it as a release checkpoint.
iteration, release = read_metadata(tracker_filename)
# Checkpoint.
if rank0:
checkpoint_name = find_checkpoint_rank_0(load_dir, iteration, release)
else:
checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
if release:
print_rank_0(f' loading release checkpoint from {load_dir}')
else:
print_rank_0(f' loading checkpoint from {load_dir} at iteration {iteration}')
# Load the checkpoint.
try:
state_dict = torch.load(checkpoint_name, map_location='cpu')
except ModuleNotFoundError:
from megatron.fp16_deprecated import loss_scaler
# For backward compatibility.
if not rank0:
print_rank_0(' > deserializing using the old code structure ...')
sys.modules['fp16.loss_scaler'] = sys.modules[
'megatron.fp16_deprecated.loss_scaler']
sys.modules['megatron.fp16.loss_scaler'] = sys.modules[
'megatron.fp16_deprecated.loss_scaler']
state_dict = torch.load(checkpoint_name, map_location='cpu')
sys.modules.pop('fp16.loss_scaler', None)
sys.modules.pop('megatron.fp16.loss_scaler', None)
except BaseException as e:
print_rank_0('could not load the checkpoint')
print_rank_0(e)
sys.exit()
return state_dict, checkpoint_name, release
def load_args_from_checkpoint(args, load_arg='load'):
"""Set required arguments from the checkpoint specified in the
arguments.
Will overwrite arguments that have a non-None default value, but
will leave any arguments that default to None as set.
Returns the same args NameSpace with the new values added/updated.
If no checkpoint is specified in args, or if the checkpoint is
there but invalid, the arguments will not be modified
"""
load_dir = getattr(args, load_arg)
if load_dir is None:
print_rank_0('No load directory specified, using provided arguments.')
return args
state_dict, checkpoint_name, release = _load_base_checkpoint(load_dir, rank0=True)
# Args.
if not state_dict:
print_rank_0('Checkpoint not found to provide arguments, using provided arguments.')
return args
if 'args' not in state_dict:
print_rank_0('Checkpoint provided does not have arguments saved, using provided arguments.')
return args
checkpoint_args = state_dict['args']
checkpoint_version = state_dict.get('checkpoint_version', 0)
args.iteration = state_dict['iteration']
# One-off conversion for foundation models
if hasattr(checkpoint_args, 'disable_bias_linear'):
setattr(checkpoint_args, 'add_bias_linear', not getattr(checkpoint_args, 'disable_bias_linear'))
def _set_arg(arg_name, old_arg_name=None, force=False):
if not force and getattr(args, arg_name, None) is not None:
return
if old_arg_name is not None:
checkpoint_value = getattr(checkpoint_args, old_arg_name, None)
else:
checkpoint_value = getattr(checkpoint_args, arg_name, None)
if checkpoint_value is not None:
print_rank_0(f"Setting {arg_name} to {checkpoint_value} from checkpoint")
setattr(args, arg_name, checkpoint_value)
else:
print_rank_0(f"Checkpoint did not provide arguments {arg_name}")
_set_arg('num_layers')
_set_arg('hidden_size')
_set_arg('ffn_hidden_size')
_set_arg('seq_length')
_set_arg('num_attention_heads')
_set_arg('kv_channels')
_set_arg('max_position_embeddings')
_set_arg('position_embedding_type', force=True)
_set_arg('add_position_embedding', force=True)
_set_arg('use_rotary_position_embeddings', force=True)
_set_arg('rotary_percent', force=True)
_set_arg('add_bias_linear', force=True)
_set_arg('swiglu', force=True)
_set_arg('untie_embeddings_and_output_weights', force=True)
_set_arg('apply_layernorm_1p', force=True)
_set_arg('tokenizer_type')
_set_arg('padded_vocab_size')
if checkpoint_version < 3.0:
_set_arg('tensor_model_parallel_size',
'model_parallel_size')
else:
_set_arg('tensor_model_parallel_size', force=True)
_set_arg('pipeline_model_parallel_size', force=True)
_set_arg('virtual_pipeline_model_parallel_size', force=True)
_set_arg('num_layers_per_virtual_pipeline_stage')
return args, checkpoint_args
def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True):
"""Load a model checkpoint and return the iteration.
strict (bool): whether to strictly enforce that the keys in
:attr:`state_dict` of the checkpoint match the names of
parameters and buffers in model.
"""
args = get_args()
load_dir = getattr(args, load_arg)
model = unwrap_model(model)
state_dict, checkpoint_name, release = _load_base_checkpoint(load_dir, rank0=False)
# Checkpoint not loaded.
if state_dict is None:
# Conditionally exit at this point.
if args.exit_on_missing_checkpoint:
print_rank_0(">> '--exit-on-missing-checkpoint' set ... exiting. <<")
torch.distributed.barrier()
sys.exit()
# Iteration defaults to 0.
return 0
# Set checkpoint version.
set_checkpoint_version(state_dict.get('checkpoint_version', 0))
# Set iteration.
if args.finetune or release:
iteration = 0
else:
try:
iteration = state_dict['iteration']
except KeyError:
try: # Backward compatible with older checkpoints
iteration = state_dict['total_iters']
except KeyError:
print_rank_0('A metadata file exists but unable to load '
'iteration from checkpoint {}, exiting'.format(
checkpoint_name))
sys.exit()
# Check arguments.
assert args.consumed_train_samples == 0
assert args.consumed_valid_samples == 0
if 'args' in state_dict and not args.finetune and not args.no_load_args:
checkpoint_args = state_dict['args']
check_checkpoint_args(checkpoint_args)
args.consumed_train_samples = getattr(checkpoint_args,
'consumed_train_samples', 0)
update_num_microbatches(consumed_samples=args.consumed_train_samples)
args.consumed_valid_samples = getattr(checkpoint_args,
'consumed_valid_samples', 0)
else:
print_rank_0('could not find arguments in the checkpoint ...')
# Model.
if len(model) == 1:
model[0].load_state_dict(state_dict['model'], strict=strict)
else:
for i in range(len(model)):
mpu.set_virtual_pipeline_model_parallel_rank(i)
model[i].load_state_dict(state_dict['model%d' % i], strict=strict)
# Fix up query/key/value matrix ordering if needed.
checkpoint_version = get_checkpoint_version()
print_rank_0(f' checkpoint version {checkpoint_version}')
fix_query_key_value_ordering(model, checkpoint_version)
# Optimizer.
if not release and not args.finetune and not args.no_load_optim:
try:
# Load state dict.
if optimizer is not None:
optimizer.load_state_dict(state_dict['optimizer'])
# Load distributed optimizer's custom parameter state.
if args.use_distributed_optimizer:
tracker_filename = get_checkpoint_tracker_filename(load_dir)
iteration, release = read_metadata(tracker_filename)
model_checkpoint_name = \
get_checkpoint_name(load_dir, iteration, release)
optim_checkpoint_name = \
get_distributed_optimizer_checkpoint_name(
model_checkpoint_name)
optimizer.load_parameter_state(optim_checkpoint_name)
# Load scheduler.
if opt_param_scheduler is not None:
if 'lr_scheduler' in state_dict: # backward compatbility
opt_param_scheduler.load_state_dict(state_dict['lr_scheduler'])
else:
opt_param_scheduler.load_state_dict(state_dict['opt_param_scheduler'])
except KeyError:
print_rank_0('Unable to load optimizer from checkpoint {}. '
'Specify --no-load-optim or --finetune to prevent '
'attempting to load the optimizer state, '
'exiting ...'.format(checkpoint_name))
sys.exit()
else:
if (args.fp16 or args.bf16) and optimizer is not None:
optimizer.reload_model_params()
# rng states.
if not release and not args.finetune and not args.no_load_rng:
try:
if 'rng_state' in state_dict:
# access rng_state for data parallel rank
if args.data_parallel_random_init:
rng_state = state_dict['rng_state'][mpu.get_data_parallel_rank()]
else:
rng_state = state_dict['rng_state'][0]
random.setstate(rng_state['random_rng_state'])
np.random.set_state(rng_state['np_rng_state'])
torch.set_rng_state(rng_state['torch_rng_state'])
torch.cuda.set_rng_state(rng_state['cuda_rng_state'])
# Check for empty states array
if not rng_state['rng_tracker_states']:
raise KeyError
tensor_parallel.get_cuda_rng_tracker().set_states(
rng_state['rng_tracker_states'])
else: # backward compatability
random.setstate(state_dict['random_rng_state'])
np.random.set_state(state_dict['np_rng_state'])
torch.set_rng_state(state_dict['torch_rng_state'])
torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
# Check for empty states array
if not state_dict['rng_tracker_states']:
raise KeyError
tensor_parallel.get_cuda_rng_tracker().set_states(
state_dict['rng_tracker_states'])
except KeyError:
print_rank_0('Unable to load rng state from checkpoint {}. '
'Specify --no-load-rng or --finetune to prevent '
'attempting to load the rng state, '
'exiting ...'.format(checkpoint_name))
sys.exit()
if args.train_reset:
iteration = 0
args.consumed_train_samples = 0
args.consumed_valid_samples = 0
# Some utilities want to load a checkpoint without distributed being initialized
if torch.distributed.is_initialized():
torch.distributed.barrier()
print_rank_0(f' successfully loaded checkpoint from {args.load} '
f'at iteration {iteration}')
return iteration
def load_biencoder_checkpoint(model, only_query_model=False,
only_context_model=False, custom_load_path=None):
"""
selectively load retrieval models for indexing/retrieving
from saved checkpoints
"""
args = get_args()
model = unwrap_model(model)
load_path = custom_load_path if custom_load_path is not None else args.load
tracker_filename = get_checkpoint_tracker_filename(load_path)
with open(tracker_filename, 'r') as f:
iteration = int(f.read().strip())
checkpoint_name = get_checkpoint_name(load_path, iteration,
args.use_distributed_optimizer,
release=False)
if mpu.get_data_parallel_rank() == 0:
print('global rank {} is loading checkpoint {}'.format(
torch.distributed.get_rank(), checkpoint_name))
state_dict = torch.load(checkpoint_name, map_location='cpu')
ret_state_dict = state_dict['model']
if only_query_model:
ret_state_dict.pop('context_model')
if only_context_model:
ret_state_dict.pop('query_model')
assert len(model) == 1
model[0].load_state_dict(ret_state_dict)
torch.distributed.barrier()
if mpu.get_data_parallel_rank() == 0:
print(' successfully loaded {}'.format(checkpoint_name))
return model
Megatron Core is a library for efficient and scalable training of transformer based models.
import megatron.core.parallel_state
import megatron.core.tensor_parallel
import megatron.core.utils
from .model_parallel_config import ModelParallelConfig
# Alias parallel_state as mpu, its legacy name
mpu = parallel_state
__all__ = [
"parallel_state",
"tensor_parallel",
"utils",
"ModelParallelConfig"
]
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import enum
class ModelType(enum.Enum):
encoder_or_decoder = 1
encoder_and_decoder = 2
retro_encoder = 3
retro_decoder = 4
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import torch
from typing import Tuple, Optional
def _bias_dropout_add_func(x, bias, residual, prob, training):
# type: (Tensor, Optional[Tensor], Tensor, float, bool) -> Tensor
# NOTE: Previously, the argument `bias` used to be passed as
# `bias.expand_as(residual)` when the `bias_dropout_func` is called from the
# transformer layer but broadcasting should automatically take care of that.
# Also, looking at broadcasting semantics, `expand_as` and broadcasting
# seem to be identical performance-wise (both just change the view).
if bias is not None:
x = x + bias
out = torch.nn.functional.dropout(x, p=prob, training=training)
out = residual + out
return out
def get_bias_dropout_add(training, fused):
def unfused_bias_dropout_add(x_with_bias, residual, prob):
x, bias = x_with_bias # unpack
return _bias_dropout_add_func(x, bias, residual, prob, training)
@torch.jit.script
def bias_dropout_add_fused_train(
x_with_bias: Tuple[torch.Tensor, Optional[torch.Tensor]],
residual: torch.Tensor,
prob: float
) -> torch.Tensor:
x, bias = x_with_bias # unpack
return _bias_dropout_add_func(x, bias, residual, prob, True)
@torch.jit.script
def bias_dropout_add_fused_inference(
x_with_bias: Tuple[torch.Tensor, Optional[torch.Tensor]],
residual: torch.Tensor,
prob: float
) -> torch.Tensor:
x, bias = x_with_bias # unpack
return _bias_dropout_add_func(x, bias, residual, prob, False)
if fused:
# jit scripting for a nn.module (with dropout) is not
# triggering the fusion kernel. For now, we use two
# different nn.functional routines to account for varying
# dropout semantics during training and inference phases.
if training:
return bias_dropout_add_fused_train
else:
return bias_dropout_add_fused_inference
else:
return unfused_bias_dropout_add
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import torch
###### BIAS GELU FUSION/ NO AUTOGRAD ################
# 1/sqrt(2*pi)-> 0.3989423
# 1/sqrt(2) -> 0.70710678
# sqrt(2/pi) -> 0.79788456
# this function is tanh approximation of gelu
# actual gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
@torch.jit.script
def bias_gelu(bias, y):
x = bias + y
return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
def bias_gelu_back(g, bias, y):
x = bias + y
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
return ff*g
class GeLUFunction(torch.autograd.Function):
@staticmethod
# bias is an optional argument
def forward(ctx, input, bias):
ctx.save_for_backward(input, bias)
return bias_gelu(bias, input)
@staticmethod
def backward(ctx, grad_output):
input, bias = ctx.saved_tensors
tmp = bias_gelu_back(grad_output, bias, input)
return tmp, tmp
bias_gelu_impl = GeLUFunction.apply
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import numbers
import torch
from torch.nn.parameter import Parameter
from torch.nn import init
import importlib
from megatron.core.utils import make_viewless_tensor
try:
from apex.contrib.layer_norm.layer_norm import FastLayerNormFN
HAVE_PERSIST_LAYER_NORM = True
except:
HAVE_PERSIST_LAYER_NORM = False
try:
from apex.normalization.fused_layer_norm import FusedLayerNormAffineFunction
HAVE_FUSED_LAYER_NORM = True
except:
HAVE_FUSED_LAYER_NORM = False
class FusedLayerNorm(torch.nn.Module):
def __init__(self, hidden_size, eps=1e-5,
persist_layer_norm=True,
sequence_parallel=False,
zero_centered_gamma=False):
super().__init__()
self.zero_centered_gamma = zero_centered_gamma
# List of hiddens sizes supported in the persistent layer norm kernel
# If the hidden size is not supported, fall back to the non-persistent
# kernel.
persist_ln_hidden_sizes = [1024, 1536, 2048, 2304, 3072, 3840, 4096,
5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480,
24576, 25600, 30720, 32768, 40960, 49152, 65536]
if hidden_size not in persist_ln_hidden_sizes or not HAVE_PERSIST_LAYER_NORM:
persist_layer_norm = False
if not persist_layer_norm and not HAVE_FUSED_LAYER_NORM:
# TODO: Add pytorch only layer norm
raise ValueError(f'Apex must currently be installed to use megatron core.')
if isinstance(hidden_size, numbers.Integral):
hidden_size = (hidden_size,)
self.hidden_size = torch.Size(hidden_size)
self.eps = eps
self.weight = Parameter(torch.Tensor(*hidden_size))
self.bias = Parameter(torch.Tensor(*hidden_size))
self.reset_parameters()
self.persist_layer_norm = persist_layer_norm
self.sequence_parallel = sequence_parallel
# set sequence parallelism flag on weight and bias parameters
setattr(self.weight, 'sequence_parallel', self.sequence_parallel)
setattr(self.bias, 'sequence_parallel', self.sequence_parallel)
def reset_parameters(self):
if self.zero_centered_gamma:
init.zeros_(self.weight)
init.zeros_(self.bias)
else:
init.ones_(self.weight)
init.zeros_(self.bias)
def forward(self, input):
weight = self.weight + 1 if self.zero_centered_gamma else self.weight
if self.persist_layer_norm:
output = FastLayerNormFN.apply(input, weight, self.bias, self.eps)
# Apex's fast layer norm function outputs a 'view' tensor (i.e., has
# a populated '_base' field). This will result in schedule.py's
# deallocate_output_tensor() throwing an error, so a viewless tensor is
# created to prevent this.
output = make_viewless_tensor(inp = output,
requires_grad = input.requires_grad,
keep_graph = True)
else:
output = FusedLayerNormAffineFunction.apply(input, weight, self.bias, self.hidden_size, self.eps)
return output
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import torch
import torch.nn as nn
from megatron.core.transformer.enums import AttnMaskType
class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
"""
Fused operation which performs following three operations in sequence
1. Scale the tensor.
2. Apply upper triangular mask (typically used in gpt models).
3. Perform softmax.
"""
@staticmethod
def forward(ctx, inputs, scale):
import scaled_upper_triang_masked_softmax_cuda
scale_t = torch.tensor([scale])
softmax_results = scaled_upper_triang_masked_softmax_cuda.forward(
inputs, scale_t[0]
)
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@staticmethod
def backward(ctx, output_grads):
import scaled_upper_triang_masked_softmax_cuda
softmax_results, scale_t = ctx.saved_tensors
input_grads = scaled_upper_triang_masked_softmax_cuda.backward(
output_grads, softmax_results, scale_t[0]
)
return input_grads, None
class ScaledMaskedSoftmax(torch.autograd.Function):
"""
Fused operation which performs following three operations in sequence
1. Scale the tensor.
2. Apply the mask.
3. Perform softmax.
"""
@staticmethod
def forward(ctx, inputs, mask, scale):
import scaled_masked_softmax_cuda
scale_t = torch.tensor([scale])
softmax_results = scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0])
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@staticmethod
def backward(ctx, output_grads):
import scaled_masked_softmax_cuda
softmax_results, scale_t = ctx.saved_tensors
input_grads = scaled_masked_softmax_cuda.backward(
output_grads, softmax_results, scale_t[0]
)
return input_grads, None, None
class ScaledSoftmax(torch.autograd.Function):
"""
Fused operation which performs following two operations in sequence
1. Scale the tensor.
2. Perform softmax.
"""
@staticmethod
def forward(ctx, inputs, scale):
import scaled_softmax_cuda
scale_t = torch.tensor([scale])
softmax_results = scaled_softmax_cuda.forward(
inputs, scale_t[0]
)
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@staticmethod
def backward(ctx, output_grads):
import scaled_softmax_cuda
softmax_results, scale_t = ctx.saved_tensors
input_grads = scaled_softmax_cuda.backward(
output_grads, softmax_results, scale_t[0]
)
return input_grads, None, None
class FusedScaleMaskSoftmax(nn.Module):
"""
fused operation: scaling + mask + softmax
Arguments:
input_in_fp16: flag to indicate if input in fp16 data format.
input_in_bf16: flag to indicate if input in bf16 data format.
attn_mask_type: attention mask type (pad or causal)
scaled_masked_softmax_fusion: flag to indicate user want to use softmax fusion
mask_func: mask function to be applied.
softmax_in_fp32: if true, softmax in performed at fp32 precision.
scale: scaling factor used in input tensor scaling.
"""
def __init__(
self,
input_in_fp16,
input_in_bf16,
attn_mask_type,
scaled_masked_softmax_fusion,
mask_func,
softmax_in_fp32,
scale,
):
super(FusedScaleMaskSoftmax, self).__init__()
self.input_in_fp16 = input_in_fp16
self.input_in_bf16 = input_in_bf16
assert not (
self.input_in_fp16 and self.input_in_bf16
), "both fp16 and bf16 flags cannot be active at the same time."
self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
self.attn_mask_type = attn_mask_type
self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
self.mask_func = mask_func
self.softmax_in_fp32 = softmax_in_fp32
self.scale = scale
assert (
self.scale is None or softmax_in_fp32
), "softmax should be in fp32 when scaled"
def forward(self, input, mask):
# [b, np, sq, sk]
assert input.dim() == 4
if self.is_kernel_available(mask, *input.size()):
return self.forward_fused_softmax(input, mask)
else:
return self.forward_torch_softmax(input, mask)
def is_kernel_available(self, mask, b, np, sq, sk):
attn_batches = b * np
if (
self.scaled_masked_softmax_fusion # user want to fuse
and self.input_in_float16 # input must be fp16
and 16 < sk <= 4096 # sk must be 16 ~ 2048
and sq % 4 == 0 # sq must be divisor of 4
and sk % 4 == 0 # sk must be divisor of 4
and attn_batches % 4 == 0 # np * b must be divisor of 4
):
if 0 <= sk <= 4096:
batch_per_block = self.get_batch_per_block(sq, sk, b, np)
if self.attn_mask_type == AttnMaskType.causal:
if attn_batches % batch_per_block == 0:
return True
else:
if sq % batch_per_block == 0:
return True
return False
def forward_fused_softmax(self, input, mask):
b, np, sq, sk = input.size()
scale = self.scale if self.scale is not None else 1.0
if self.attn_mask_type == AttnMaskType.causal:
assert sq == sk, "causal mask is only for self attention"
# input is 3D tensor (attn_batches, sq, sk)
input = input.view(-1, sq, sk)
probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
return probs.view(b, np, sq, sk)
else:
# input is 4D tensor (b, np, sq, sk)
if mask is not None:
return ScaledMaskedSoftmax.apply(input, mask, scale)
else:
return ScaledSoftmax.apply(input, scale)
def forward_torch_softmax(self, input, mask):
if self.input_in_float16 and self.softmax_in_fp32:
input = input.float()
if self.scale is not None:
input = input * self.scale
mask_output = self.mask_func(input, mask) if mask is not None else input
probs = torch.nn.Softmax(dim=-1)(mask_output)
if self.input_in_float16 and self.softmax_in_fp32:
if self.input_in_fp16:
probs = probs.half()
else:
probs = probs.bfloat16()
return probs
@staticmethod
def get_batch_per_block(sq, sk, b, np):
import scaled_masked_softmax_cuda
return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np)
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from dataclasses import dataclass
from typing import Callable
import torch
@dataclass
class ModelParallelConfig:
"""Base configuration for Megatron Core
Model Parallelism
-----------------
tensor_model_parallel_size (int): Intra-layer model parallelism. Splits tensors across GPU ranks. Defaults to 1.
pipeline_model_parallel_size (int): Inter-layer model parallelism. Splits transformer layers across GPU
ranks. Defaults to 1.
virtual_pipeline_model_parallel_size (int): Interleaved pipeline parallelism is used to improve performance by
reducing the pipeline bubble. Considers a transformer block as a list of smaller transformer (virtual) blocks.
The number of virtual blocks per pipeline model parallel rank is the virtual model parallel size. See Efficient
Large-Scale Language Model Training on GPU Clusters Using Megatron-LM: https://arxiv.org/pdf/2104.04473.pdf for
more details. Defaults to None.
sequence_parallel (bool): Makes tensor parallelism more memory efficient for LLMs (20B+) by
parallelizing layer norms and dropout sequentially. See Reducing Activation Recomputation in Large Transformer
Models: https://arxiv.org/abs/2205.05198 for more details. Defaults to False.
Initialization
--------------
perform_initialization (bool, default=True): If true, weights are initialized. This option can be useful when you
know you are going to load values from a checkpoint.
use_cpu_initialization: (bool, default=False): When set to False, we initialize the weights directly on the GPU.
Transferring weights from CPU to GPU can take a significant amount of time for large models. Defaults to False.
Training
--------
fp16 (bool): If true, train with fp16 mixed precision training. Defaults to False.
bf16 (bool): If true, train with bf16 mixed precision training. Defaults to False.
params_dtype (torch.dtype): dtype used when intializing the weights. Defaults to torch.float32
timers (optional, default=None): TODO
Optimizations
-------------
gradient_accumulation_fusion (bool): If true, fuses weight gradient accumulation to GEMMs. Requires the custom CUDA
extension fused_weight_gradient_mlp_cuda module. To use gradient_accumulation_fusion you must install APEX with
--cpp_ext and --cuda_ext. For example: "pip install --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\"
". Note that the extension requires CUDA>=11. Otherwise, you must turn off gradient accumulation fusion.
Defaults to False.
async_tensor_model_parallel_allreduce (bool, default=True): If true, enables asynchronous execution of
tensor-model-parallel all-reduce with weight gradient compuation of a column-linear layer. Defaults to False.
Pipeline Parallelism
--------------------
pipeline_dtype (required): dtype used in p2p communication, usually params_dtype
grad_scale_func (optional, default=None): If using loss scaling, this function should take the loss and return the
scaled loss. If None, no function is called on the loss.
enable_autocast (bool): If true runs the forward step function inside torch.autocast context. Default is False.
autocast_dtype (torch.dtype): dtype to pass to torch.amp.autocast when enabled. Default is pipeline_dtype.
variable_seq_lengths (bool, default=False): Support for variable sequence lengths across microbatches. Setting this
communicates the size of tensors during pipeline parallelism communication, because of this extra overhead it
should only be set if the sequence length varies by microbatch within a global batch.
num_microbatches_with_partial_activation_checkpoints (int, default=None): If int, set the number of microbatches
where not all of the layers will be checkpointed and recomputed. The rest of the microbatches within the window
of maximum outstanding microbatches will recompute all layers (either full recompute or selective recompute). If
None, the checkpoint and recompute will be left up to the forward_step function.
overlap_p2p_comm (bool, optional, default=False): When True some of the peer to peer communication for pipeline
parallelism will overlap with computation. Must be False if batch_p2p_comm is true.
batch_p2p_comm (bool, default=True): Use batch_isend_irecv instead of individual isend/irecv calls. Must be False
if overlap_p2p_comm is True.
batch_p2p_sync (bool, default=True): When using batch_isend_irecv, do a cuda.device.synchronize afterward to work
around a bug in older version of PyTorch.
use_ring_exchange_p2p (bool, default = False): Use custom ring_exchange kernel instead of
torch.distributed.batch_isend_irecv(). Requires custom built torch with torch.distributed.ring_exchange.
deallocate_pipeline_outputs (optional, default=False): If True, output data is deallocated after the tensor is sent
to the next pipeline stage. Helps with saving memory, does nothing when pipeline parallel is not used.
no_sync_func (optional): Function that creates a context that suppresses asynchronous data-parallel
communication. If the model is an instance of torch.nn.DistributedDataParallel, the default is to use
torch.nn.DistributedDataParallel.no_sync.
grad_sync_func (optional): Function that launches asynchronous gradient reductions (e.g. distributed optimizer
gradient reduce-scatters). The function should take one argument: an iterable of parameters whose gradients are
to be synchronized.
param_sync_func (optional): Function that launches asynchronous parameter synchronizations (e.g. distributed
optimizer parameter all-gathers). The function should take one argument: an iterable of parameters to be
synchronized.
"""
# Model parallelism
tensor_model_parallel_size: int = 1
pipeline_model_parallel_size: int = 1
virtual_pipeline_model_parallel_size: int = None
sequence_parallel: bool = False
# Initialization
perform_initialization: bool = True
use_cpu_initialization: bool = False
# Training
fp16: bool = False
bf16: bool = False
params_dtype: torch.dtype = torch.float32
timers: Callable = None
# Optimizations
gradient_accumulation_fusion: bool = False
async_tensor_model_parallel_allreduce: bool = False
# Pipeline Parallel
pipeline_dtype: torch.dtype = None
grad_scale_func: Callable = None
enable_autocast: bool = False
autocast_dtype: torch.dtype = None
variable_seq_lengths: bool = False
num_microbatches_with_partial_activation_checkpoints: int = None
overlap_p2p_comm: bool = False
batch_p2p_comm: bool = True
batch_p2p_sync: bool = True
use_ring_exchange_p2p: bool = False
deallocate_pipeline_outputs: bool = False
no_sync_func: Callable = None
grad_sync_func: Callable = None
param_sync_func: Callable = None
def __post_init__(self):
""" Python dataclass method that is used to modify attributes after initialization.
See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more details.
"""
if self.sequence_parallel:
if self.tensor_model_parallel_size <= 1:
raise ValueError("Can not use sequence paralllelism without tensor parallelism")
if self.async_tensor_model_parallel_allreduce:
# sequence_parallelism already does this async
self.async_tensor_model_parallel_allreduce = False
if self.pipeline_model_parallel_size > 1:
if self.pipeline_dtype is None:
raise ValueError("When using pipeline parallelism, pipeline_dtype must be specified")
if self.autocast_dtype is None:
self.autocast_dtype = self.params_dtype
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import importlib.util
import torch
import pdb
from megatron import get_args
from torch import einsum, nn
__all__ = ['RotaryEmbedding', 'apply_rotary_pos_emb']
class RotaryEmbedding(nn.Module):
def __init__(self, dim, base=10000):
super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
def forward(self, max_seq_len, offset=0):
self.inv_freq = self.inv_freq.to(torch.float32)
seq = torch.arange(max_seq_len, device=self.inv_freq.device) + offset
freqs = einsum('i , j -> i j', seq.type_as(self.inv_freq), self.inv_freq)
# first part even vector components, second part odd vector components,
# 2 * dim in dimension size
emb = torch.cat((freqs, freqs), dim=-1)
# emb [seq_length, .., dim]
return emb[:, None, None, :]
def _rotate_half(x):
"""
change sign so the last dimension becomes [-odd, +even]
"""
x1, x2 = torch.chunk(x, 2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(t, freqs, position_ids):
args = get_args()
"""
input tensor t is of shape [seq_length, ..., dim]
rotary positional embeding tensor freqs is of shape [seq_length, ..., dim]
check https://kexue.fm/archives/8265 for detailed formulas
"""
rot_dim = freqs.shape[-1]
if position_ids.shape[1] > 1:
freqs = freqs[position_ids]
freqs = freqs.view(t.shape[1],freqs.shape[1],freqs.shape[2],freqs.shape[4]).transpose(0,1)
# ideally t_pass is empty so rotary pos embedding is applied to all tensor t
t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
# first part is cosine component
# second part is sine component, need to change signs with _rotate_half method
t = (t * freqs.cos()) + (_rotate_half(t) * freqs.sin())
if args.bf16 and t.dtype != torch.bfloat16:
t = t.bfloat16()
return torch.cat((t, t_pass), dim=-1)
from .gpt_model import GPTModel
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment