Commit 2c63b5cd authored by wangxj

Upgrade to version 0.12

parent c271aaae
Pipeline #2451 passed
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
......@@ -18,11 +18,15 @@ from megatron.core.models.retro.utils import (
get_config_path as get_retro_config_path,
get_gpt_data_dir as get_retro_data_dir,
)
from megatron.core.rerun_state_machine import RerunStateMachine
from megatron.core.transformer import TransformerConfig, MLATransformerConfig
from megatron.core.transformer.enums import AttnBackend
from megatron.core.utils import is_torch_min_version
from megatron.core.utils import (
is_torch_min_version,
get_torch_version,
)
from megatron.training.activations import squared_relu
from megatron.training.utils import update_use_dist_ckpt
from megatron.training.utils import update_use_dist_ckpt, get_device_arch_version
def parse_args(extra_args_provider=None, ignore_unknown_args=False):
......@@ -187,21 +191,28 @@ def moe_freq_type(x):
def validate_args(args, defaults={}):
# Temporary
assert args.non_persistent_ckpt_type in ['global', None], \
'Currently only global checkpoints are supported'
assert args.non_persistent_ckpt_type in ['global', 'local', None], \
'Currently only global and local checkpoints are supported'
if args.non_persistent_ckpt_type == 'local':
try:
from nvidia_resiliency_ext.checkpointing.local.ckpt_managers.local_manager import \
LocalCheckpointManager
except ModuleNotFoundError as e:
raise RuntimeError('nvidia_resiliency_ext is required for local checkpointing') from e
# Load saved args from Retro (if applicable).
load_retro_args(args)
# Set args.use_dist_ckpt from args.ckpt_format.
if args.use_legacy_models:
assert args.ckpt_format == "torch", \
"legacy model format only supports the 'torch' checkpoint format."
update_use_dist_ckpt(args)
if args.encoder_pipeline_model_parallel_size == 0 and args.num_experts == 0:
assert args.encoder_tensor_model_parallel_size == args.tensor_model_parallel_size, "If non-MOE encoder shares first decoder pipeline rank it must have the same TP as the decoder."
if args.encoder_tensor_model_parallel_size > 0:
assert args.encoder_pipeline_model_parallel_size > 0, "encoder_pipeline_model_parallel_size must be defined."
assert args.num_attention_heads % args.encoder_tensor_model_parallel_size == 0
assert args.encoder_tensor_model_parallel_size <= args.tensor_model_parallel_size, "We do not support encoders with more TP than the decoder."
......@@ -220,12 +231,8 @@ def validate_args(args, defaults={}):
if args.attention_backend == AttnBackend.local:
assert args.spec[0] == 'local' , '--attention-backend local is only supported with --spec local'
# Pipeline model parallel size.
args.transformer_pipeline_model_parallel_size = (
args.pipeline_model_parallel_size - 1
if args.standalone_embedding_stage else
args.pipeline_model_parallel_size
)
# Pipeline model parallel size.
args.transformer_pipeline_model_parallel_size = args.pipeline_model_parallel_size
args.data_parallel_size = args.world_size // total_model_size
......@@ -329,13 +336,12 @@ def validate_args(args, defaults={}):
print('setting global batch size to {}'.format(
args.global_batch_size), flush=True)
assert args.global_batch_size > 0
if args.decoder_first_pipeline_num_layers is None and args.decoder_last_pipeline_num_layers is None:
# Divisibility check not applicable for T5 models which specify encoder_num_layers
# and decoder_num_layers.
if args.num_layers is not None:
assert args.num_layers % args.transformer_pipeline_model_parallel_size == 0, \
'Number of layers should be divisible by the pipeline-model-parallel size'
if args.num_layers_per_virtual_pipeline_stage is not None:
# Uneven virtual pipeline parallelism
assert args.num_layers_per_virtual_pipeline_stage is None or args.num_virtual_stages_per_pipeline_rank is None, \
'--num-layers-per-virtual-pipeline-stage and --num-virtual-stages-per-pipeline-rank cannot be set at the same time'
if args.num_layers_per_virtual_pipeline_stage is not None or args.num_virtual_stages_per_pipeline_rank is not None:
if args.overlap_p2p_comm:
assert args.pipeline_model_parallel_size > 1, \
'When interleaved schedule is used, pipeline-model-parallel size '\
......@@ -345,15 +351,28 @@ def validate_args(args, defaults={}):
'When interleaved schedule is used and p2p communication overlap is disabled, '\
'pipeline-model-parallel size should be greater than 2 to avoid having multiple '\
'p2p sends and recvs between same 2 ranks per communication batch'
assert args.num_layers is not None
# Double check divisibility check here since check above is if guarded.
assert args.num_layers % args.transformer_pipeline_model_parallel_size == 0, \
'Number of layers should be divisible by the pipeline-model-parallel size'
num_layers_per_pipeline_stage = args.num_layers // args.transformer_pipeline_model_parallel_size
assert num_layers_per_pipeline_stage % args.num_layers_per_virtual_pipeline_stage == 0, \
'Number of layers per pipeline stage must be divisible by number of layers per virtual pipeline stage'
args.virtual_pipeline_model_parallel_size = num_layers_per_pipeline_stage // \
args.num_layers_per_virtual_pipeline_stage
if args.num_virtual_stages_per_pipeline_rank is None:
assert args.decoder_first_pipeline_num_layers is None and args.decoder_last_pipeline_num_layers is None, \
'please use --num-virtual-stages-per-pipeline-rank to specify the virtual pipeline parallel degree when enabling uneven pipeline parallelism'
num_layers = args.num_layers
if args.account_for_embedding_in_pipeline_split:
num_layers += 1
if args.account_for_loss_in_pipeline_split:
num_layers += 1
assert num_layers % args.transformer_pipeline_model_parallel_size == 0, \
'number of layers of the model must be divisible by the pipeline model parallel size'
num_layers_per_pipeline_stage = num_layers // args.transformer_pipeline_model_parallel_size
assert num_layers_per_pipeline_stage % args.num_layers_per_virtual_pipeline_stage == 0, \
'number of layers per pipeline stage must be divisible by the number of layers per virtual pipeline stage'
args.virtual_pipeline_model_parallel_size = num_layers_per_pipeline_stage // \
args.num_layers_per_virtual_pipeline_stage
else:
args.virtual_pipeline_model_parallel_size = args.num_virtual_stages_per_pipeline_rank
else:
args.virtual_pipeline_model_parallel_size = None
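As a worked check of the layer-count arithmetic above, a minimal sketch with hypothetical values (30 transformer layers, embedding and loss each counted as an extra layer, pipeline-parallel size 4, 2 layers per virtual stage):
# Worked example of the virtual-pipeline arithmetic above; all values are illustrative.
num_layers = 30
num_layers += 1                                   # --account-for-embedding-in-pipeline-split
num_layers += 1                                   # --account-for-loss-in-pipeline-split
transformer_pipeline_model_parallel_size = 4
num_layers_per_virtual_pipeline_stage = 2
assert num_layers % transformer_pipeline_model_parallel_size == 0             # 32 % 4 == 0
num_layers_per_pipeline_stage = num_layers // transformer_pipeline_model_parallel_size   # 8
assert num_layers_per_pipeline_stage % num_layers_per_virtual_pipeline_stage == 0
virtual_pipeline_model_parallel_size = (
    num_layers_per_pipeline_stage // num_layers_per_virtual_pipeline_stage    # 4 virtual stages per rank
)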
# Overlap P2P communication is disabled if not using the interleaved schedule.
......@@ -364,6 +383,30 @@ def validate_args(args, defaults={}):
print('WARNING: Setting args.overlap_p2p_comm and args.align_param_gather to False '
'since non-interleaved schedule does not support overlapping p2p communication '
'and aligned param AG')
if args.decoder_first_pipeline_num_layers is None and args.decoder_last_pipeline_num_layers is None:
# Divisibility check not applicable for T5 models which specify encoder_num_layers
# and decoder_num_layers.
if args.num_layers is not None:
num_layers = args.num_layers
if args.account_for_embedding_in_pipeline_split:
num_layers += 1
if args.account_for_loss_in_pipeline_split:
num_layers += 1
assert num_layers % args.transformer_pipeline_model_parallel_size == 0, \
'Number of layers should be divisible by the pipeline-model-parallel size'
if args.rank == 0:
print(f"Number of virtual stages per pipeline stage: {args.virtual_pipeline_model_parallel_size}")
if args.data_parallel_sharding_strategy == "optim_grads_params":
args.overlap_param_gather = True
args.overlap_grad_reduce = True
if args.data_parallel_sharding_strategy == "optim_grads":
args.overlap_grad_reduce = True
if args.overlap_param_gather:
assert args.use_distributed_optimizer, \
......@@ -373,8 +416,8 @@ def validate_args(args, defaults={}):
assert not args.use_legacy_models, \
'--overlap-param-gather only supported with MCore models'
if getattr(args, "use_torch_fsdp2", False):
assert get_torch_version() >= PkgVersion("2.4"), \
if args.use_torch_fsdp2:
assert is_torch_min_version("2.4.0"), \
'FSDP2 requires PyTorch >= 2.4.0 with FSDP 2 support.'
assert args.pipeline_model_parallel_size == 1, \
'--use-torch-fsdp2 is not supported with pipeline parallelism'
......@@ -401,10 +444,33 @@ def validate_args(args, defaults={}):
assert not args.use_dist_ckpt, \
'--overlap-param-gather-with-optimizer-step not supported with distributed checkpointing yet'
dtype_map = {
'fp32': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16, 'fp8': torch.uint8,
}
map_dtype = lambda d: d if isinstance(d, torch.dtype) else dtype_map[d]
args.main_grads_dtype = map_dtype(args.main_grads_dtype)
args.main_params_dtype = map_dtype(args.main_params_dtype)
args.exp_avg_dtype = map_dtype(args.exp_avg_dtype)
args.exp_avg_sq_dtype = map_dtype(args.exp_avg_sq_dtype)
if args.fp8_param_gather:
assert args.use_distributed_optimizer, \
'--fp8-param-gather only supported with distributed optimizer'
if args.use_custom_fsdp:
assert args.use_distributed_optimizer, \
'--use-custom-fsdp only supported with distributed optimizer'
if args.data_parallel_sharding_strategy in ["optim_grads_params", "optim_grads"]:
warnings.warn('Please make sure your TransformerEngine supports FSDP + gradient accumulation fusion')
assert args.gradient_accumulation_fusion is False, \
"optim_grads_params optim_grads are not supported with gradient accumulation fusion"
if args.data_parallel_sharding_strategy == "optim_grads_params":
assert args.check_weight_hash_across_dp_replicas_interval is None, \
'check_weight_hash_across_dp_replicas_interval is not supported with optim_grads_params'
# Parameters dtype.
args.params_dtype = torch.float
if args.fp16:
......@@ -422,7 +488,13 @@ def validate_args(args, defaults={}):
args.params_dtype = torch.bfloat16
# bfloat16 requires gradient accumulation and all-reduce to
# be done in fp32.
if not args.accumulate_allreduce_grads_in_fp32:
if args.accumulate_allreduce_grads_in_fp32:
assert args.main_grads_dtype == torch.float32, \
"--main-grads-dtype can only be fp32 when --accumulate-allreduce-grads-in-fp32 is set"
if args.grad_reduce_in_bf16:
args.accumulate_allreduce_grads_in_fp32 = False
elif not args.accumulate_allreduce_grads_in_fp32 and args.main_grads_dtype == torch.float32:
args.accumulate_allreduce_grads_in_fp32 = True
if args.rank == 0:
print('accumulate and all-reduce gradients in fp32 for '
......@@ -525,7 +597,9 @@ def validate_args(args, defaults={}):
args.seq_length = args.encoder_seq_length
if args.seq_length is not None:
assert args.max_position_embeddings >= args.seq_length
assert args.max_position_embeddings >= args.seq_length, \
f"max_position_embeddings ({args.max_position_embeddings}) must be greater than " \
f"or equal to seq_length ({args.seq_length})."
if args.decoder_seq_length is not None:
assert args.max_position_embeddings >= args.decoder_seq_length
if args.lr is not None:
......@@ -597,7 +671,7 @@ def validate_args(args, defaults={}):
# model parallel memory optimization is enabled
if args.sequence_parallel:
args.async_tensor_model_parallel_allreduce = False
if getattr(args, "use_torch_fsdp2", False):
if args.use_torch_fsdp2:
warnings.warn(
"Using sequence parallelism with FSDP2 together. Try not to using them "
"together since they require different CUDA_MAX_CONNECTIONS settings "
......@@ -605,13 +679,14 @@ def validate_args(args, defaults={}):
"environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 while FSDP2 "
"requires not setting CUDA_DEVICE_MAX_CONNECTIONS=1 for better parallelization.")
if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1" and get_device_arch_version() < 10:
# CUDA_DEVICE_MAX_CONNECTIONS requirement no longer exists since the Blackwell architecture
if args.sequence_parallel:
raise RuntimeError(
warnings.warn(
"Using sequence parallelism requires setting the environment variable "
"CUDA_DEVICE_MAX_CONNECTIONS to 1")
if args.async_tensor_model_parallel_allreduce:
raise RuntimeError(
warnings.warn(
"Using async gradient all reduce requires setting the environment "
"variable CUDA_DEVICE_MAX_CONNECTIONS to 1")
......@@ -642,9 +717,6 @@ def validate_args(args, defaults={}):
assert not args.use_legacy_models, \
'--decoupled-lr and --decoupled-min-lr is not supported in legacy models.'
# FlashAttention
args.use_flash_attn = args.use_flash_attn_cutlass or args.use_flash_attn_triton
# Legacy RoPE arguments
if args.use_rotary_position_embeddings:
args.position_embedding_type = 'rope'
......@@ -660,15 +732,21 @@ def validate_args(args, defaults={}):
if not args.add_position_embedding and args.position_embedding_type != 'rope':
raise RuntimeError('--no-position-embedding is deprecated, use --position-embedding-type')
# Relative position embeddings arguments
if args.position_embedding_type == 'relative':
assert (
args.transformer_impl == "transformer_engine"
), 'Local transformer implementation currently does not support attention bias-based position embeddings.'
# MoE Spec check
if args.num_experts == 0:
args.num_experts = None
if args.num_experts is not None:
assert args.spec is None, "Model Spec must be None when using MoEs"
if args.moe_ffn_hidden_size is None:
args.moe_ffn_hidden_size = args.ffn_hidden_size
# Context parallel
if args.context_parallel_size > 1:
assert not args.use_legacy_models, "Context parallelism is not supported in legacy models."
......@@ -691,10 +769,6 @@ def validate_args(args, defaults={}):
any([args.train_data_path, args.valid_data_path, args.test_data_path]) \
<= 1, "A single data source must be provided in training mode, else None"
if args.use_tp_pp_dp_mapping:
assert args.context_parallel_size * args.expert_model_parallel_size <= 1, \
"context_parallel and expert_model_parallel can't be used with tp-pp-dp mapping."
# Deterministic mode
if args.deterministic_mode:
assert not args.use_flash_attn, "Flash attention can not be used in deterministic mode."
......@@ -710,6 +784,21 @@ def validate_args(args, defaults={}):
if args.apply_query_key_layer_scaling:
args.attention_softmax_in_fp32 = True
if args.result_rejected_tracker_filename is not None:
# Append to passed-in args.iterations_to_skip.
iterations_to_skip_from_file = RerunStateMachine.get_skipped_iterations_from_tracker_file(
args.result_rejected_tracker_filename
)
args.iterations_to_skip.extend(iterations_to_skip_from_file)
# Make sure all functionality that requires Gloo process groups is disabled.
if not args.enable_gloo_process_groups:
if args.use_distributed_optimizer:
# If using distributed optimizer, must use distributed checkpointing.
# Legacy checkpointing uses Gloo process groups to collect full distributed
# optimizer state in the CPU memory of DP rank 0.
assert args.use_dist_ckpt
# Checkpointing
if args.ckpt_fully_parallel_save_deprecated and args.rank == 0:
print('--ckpt-fully-parallel-save flag is deprecated and has no effect.'
......@@ -745,6 +834,31 @@ def validate_args(args, defaults={}):
args.no_load_rng = True
print('Warning: disabling --no-load-rng for upcycling.')
# Optimizer CPU offload check
if args.optimizer_cpu_offload:
assert args.use_precision_aware_optimizer, (
"The optimizer cpu offload must be used in conjunction with `--use-precision-aware-optimizer`, "
"as the hybrid device optimizer reuses the code path of this flag."
)
# MoE loss and include embedding and loss layer check
if args.num_experts is not None:
if args.moe_router_load_balancing_type != "none" or args.moe_z_loss_coeff is not None:
assert not args.account_for_embedding_in_pipeline_split, \
"Cannot support load balancing loss and z loss with --account-for-embedding-in-pipeline-split"
assert not args.account_for_loss_in_pipeline_split, \
"Cannot support load balancing loss and z loss with --account-for-loss-in-pipeline-split"
if args.non_persistent_ckpt_type == "local":
assert args.non_persistent_local_ckpt_dir is not None, "Tried to use local checkpointing without specifying --non-persistent-local-ckpt-dir!"
if args.replication:
assert args.replication_jump is not None, "--replication requires the value of --replication-jump!"
assert args.non_persistent_ckpt_type == "local", f"--replication requires args.non_persistent_ckpt_type == 'local', but got: {args.non_persistent_ckpt_type}"
elif args.replication_jump:
print("Warning: --replication-jump was specified despite not using replication. Ignoring.")
args.replication_jump = None
# Print arguments.
_print_args("arguments", args)
......@@ -791,8 +905,8 @@ def core_transformer_config_from_args(args, config_class=None):
kw_args['batch_p2p_comm'] = not args.overlap_p2p_comm
kw_args['num_moe_experts'] = args.num_experts
kw_args['rotary_interleaved'] = args.rotary_interleaved
kw_args['first_pipeline_num_layers']= args.decoder_first_pipeline_num_layers
kw_args['last_pipeline_num_layers']= args.decoder_last_pipeline_num_layers
kw_args['num_layers_in_first_pipeline_stage']= args.decoder_first_pipeline_num_layers
kw_args['num_layers_in_last_pipeline_stage']= args.decoder_last_pipeline_num_layers
if args.swiglu:
kw_args['activation_func'] = F.silu
kw_args['gated_linear_unit'] = True
......@@ -847,6 +961,11 @@ def _add_transformer_engine_args(parser):
group.add_argument('--fp8-param-gather', action='store_true',
help='Keep the compute param in fp8 (do not use any other intermediate '
'dtype) and perform the param all-gather in fp8.')
group.add_argument('--te-rng-tracker', action='store_true', default=False,
help='Use the Transformer Engine version of the random number generator. '
'Required for CUDA graphs support.')
group.add_argument('--inference-rng-tracker', action='store_true', default=False,
help='Use a random number generator configured for inference.')
return parser
def _add_inference_args(parser):
......@@ -873,8 +992,15 @@ def _add_inference_args(parser):
'Bert embedder.')
group.add_argument('--flash-decode', default=False, action="store_true",
help='Whether to use the flash decoding kernel.')
group.add_argument('--enable-cuda-graph', default=False, action="store_true",
help='Use CUDA graph capture and replay.')
group.add_argument("--cuda-graph-warmup-steps", type=int, default=3,
help="Number of CUDA graph warmup steps")
group.add_argument('--inference-max-requests', type=int, default=8,
help='Maximum number of requests for inference.',
dest='inference_max_batch_size')
group.add_argument('--inference-max-seq-length', type=int, default=2560,
help='Maximum sequence length allocated for prefill during inference.',
help='Maximum sequence length expected for inference (prefill + decode).',
dest='inference_max_seq_length')
return parser
......@@ -957,8 +1083,12 @@ def _add_network_size_args(parser):
help='Maximum number of position embeddings to use. '
'This is the size of position embedding.')
group.add_argument('--position-embedding-type', type=str, default='learned_absolute',
choices=['learned_absolute', 'rope', 'none'],
help='Position embedding type.')
choices=['learned_absolute', 'rope', 'relative', 'none'],
help='Position embedding type.')
group.add_argument('--relative-attention-num-buckets', type=int, default=32,
help='Number of buckets for relative position embeddings.')
group.add_argument('--relative-attention-max-distance', type=int, default=128,
help='Maximum distance for relative position embeddings calculation.')
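These two arguments parameterize T5-style relative position bucketing; a hedged sketch of that scheme follows (it mirrors the published T5 formulation, not necessarily Megatron's exact implementation):
import math
import torch

def relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
    # Half of the buckets cover exact small offsets; the rest cover logarithmically
    # larger offsets up to max_distance.
    relative_buckets = torch.zeros_like(relative_position)
    if bidirectional:
        num_buckets //= 2
        relative_buckets += (relative_position > 0).long() * num_buckets
        relative_position = relative_position.abs()
    else:
        relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
    max_exact = num_buckets // 2
    is_small = relative_position < max_exact
    large = max_exact + (
        torch.log(relative_position.float() / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).long()
    large = torch.min(large, torch.full_like(large, num_buckets - 1))
    return relative_buckets + torch.where(is_small, relative_position, large)

pos = torch.arange(8)
rel = pos[None, :] - pos[:, None]          # [query, key] relative offsets
buckets = relative_position_bucket(rel)    # integer bucket ids indexing a learned bias table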
group.add_argument('--use-rotary-position-embeddings', action='store_true',
help='Use rotary positional embeddings or not. '
'Deprecated: use --position-embedding-type')
......@@ -971,7 +1101,9 @@ def _add_network_size_args(parser):
group.add_argument('--rotary-seq-len-interpolation-factor', type=int, default=None,
help='Sequence length interpolation factor for rotary embeddings.')
group.add_argument('--use-rope-scaling', action='store_true',
help='Apply rope scaling as used in llama3.1')
help='Apply rope scaling as used in llama3.x')
group.add_argument('--rope-scaling-factor', type=float, default=8.0,
help='Rope scaling factor in llama3.x models')
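A hedged sketch of the Llama 3.x-style RoPE frequency scaling that --rope-scaling-factor feeds into; the low/high-frequency factors and original context length below are illustrative assumptions, not values taken from this diff:
import math
import torch

def scale_rope_inv_freq(inv_freq, factor=8.0, low_freq_factor=1.0,
                        high_freq_factor=4.0, original_context_len=8192):
    # Keep short-wavelength components, divide long-wavelength ones by `factor`,
    # and smoothly interpolate in between.
    low_freq_wavelen = original_context_len / low_freq_factor
    high_freq_wavelen = original_context_len / high_freq_factor
    wavelen = 2 * math.pi / inv_freq
    scaled = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
    smooth = (original_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
    smoothed = (1 - smooth) * inv_freq / factor + smooth * inv_freq
    is_medium = (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen)
    return torch.where(is_medium, smoothed, scaled)

inv_freq = 1.0 / (10000 ** (torch.arange(0, 64, 2).float() / 64))
scaled_inv_freq = scale_rope_inv_freq(inv_freq)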
group.add_argument('--no-position-embedding',
action='store_false',
help='Disable position embedding. Deprecated: use --position-embedding-type',
......@@ -1059,6 +1191,9 @@ def _add_ft_package_args(parser):
group.add_argument('--enable-ft-package', action='store_true',
help='If set, Fault Tolerance package is enabled. '
'Note: This feature is for Nvidia internal use only.')
group.add_argument('--calc-ft-timeouts', action='store_true',
help='If set, FT package will try to automatically compute the timeouts. '
'Note: This feature is for Nvidia internal use only.')
return parser
......@@ -1227,6 +1362,9 @@ def _add_training_args(parser):
group.add_argument('--check-for-spiky-loss', action='store_true',
help='Check for spiky loss',
dest='check_for_spiky_loss')
group.add_argument('--check-for-large-grads', action='store_true',
help='Check for unexpectedly large grads',
dest='check_for_large_grads')
group.add_argument('--distribute-saved-activations',
action='store_true',
help='If set, distribute recomputed activations '
......@@ -1259,17 +1397,19 @@ def _add_training_args(parser):
help='Global step to start profiling.')
group.add_argument('--profile-step-end', type=int, default=12,
help='Global step to stop profiling.')
group.add_argument('--iterations-to-skip', nargs='+', type=int, default=[],
help='List of iterations to skip, empty by default.')
group.add_argument('--result-rejected-tracker-filename', type=str, default=None,
help='Optional name of file tracking `result_rejected` events.')
group.add_argument('--disable-gloo-process-groups', action='store_false',
dest='enable_gloo_process_groups',
help='Disables creation and usage of Gloo process groups.')
group.add_argument('--use-pytorch-profiler', action='store_true',
help='Use the built-in pytorch profiler. '
'Useful if you wish to view profiles in tensorboard.',
dest='use_pytorch_profiler')
group.add_argument('--use-hip-profiler', action='store_true',
help='Use HIP PROFILER',
dest='use_hip_profiler')
group.add_argument('--profile-ranks', nargs='+', type=int, default=[0],
help='Global ranks to profile.')
group.add_argument('--profile-dir', type=str, default="./",
help='profile dir to save.')
group.add_argument('--record-memory-history', action="store_true", default=False,
help='Record memory history in last rank.')
group.add_argument('--memory-snapshot-path', type=str, default="snapshot.pickle",
......@@ -1363,11 +1503,9 @@ def _add_training_args(parser):
group.add_argument('--cross-entropy-loss-fusion', action='store_true',
help='Enable fusion of cross entropy loss calculation.',
dest='cross_entropy_loss_fusion')
group.add_argument('--use-flash-attn-cutlass', action='store_true',
group.add_argument('--use-flash-attn', action='store_true',
help='use FlashAttention implementation of attention. '
'https://arxiv.org/abs/2205.14135')
group.add_argument('--use-flash-attn-triton', action='store_true',
help='use FlashAttention implementation of attention using Triton.')
group.add_argument('--disable-bias-linear', action='store_false',
help='Disable bias in the linear layers',
dest='add_bias_linear')
......@@ -1377,6 +1515,18 @@ def _add_training_args(parser):
group.add_argument('--optimizer', type=str, default='adam',
choices=['adam', 'sgd'],
help='Optimizer function')
group.add_argument('--optimizer-cpu-offload', action='store_true',
help='Offload optimizer state to CPU')
group.add_argument('--optimizer-offload-fraction', type=float, default=1.0,
help='Ratio of optimizer state to offload to CPU')
group.add_argument('--use-torch-optimizer-for-cpu-offload', action='store_true',
help="Use torch.optim.Optimizer instead of Megatron's optimizer in optimizer cpu offload mode.")
group.add_argument('--overlap-cpu-optimizer-d2h-h2d', action='store_true', default=False,
help='Overlap CPU optimizer step, gradients D2H and updated parameters H2D.')
group.add_argument('--no-pin-cpu-grads', action='store_false', dest='pin_cpu_grads',
help='Disable pinning of CPU memory for gradients.')
group.add_argument('--no-pin-cpu-params', action='store_false', dest='pin_cpu_params',
help='Disable pinning of CPU memory for parameters.')
group.add_argument('--dataloader-type', type=str, default=None,
choices=['single', 'cyclic', 'external'],
help='Single pass vs multiple pass data loader')
......@@ -1425,6 +1575,10 @@ def _add_training_args(parser):
group.add_argument('--disable-tp-comm-split-rs', action='store_false',
help='Disables the Reduce-Scatter overlap with fprop GEMM.',
dest='tp_comm_split_rs')
group.add_argument('--pipeline-model-parallel-comm-backend', type=str, default=None,
choices=['nccl', 'ucc'],
help='Select a communicator backend for pipeline parallel communication. '
'If None, the default backend will be used.')
return parser
......@@ -1549,8 +1703,7 @@ def _add_checkpointing_args(parser):
choices=['global', 'local', 'in_memory', None],
help='Type of non-persistent model checkpoints. '
'"global" - Saved as a standard checkpoint (e.g., on Lustre) with old checkpoints being removed. '
'"local" - [TBD] Each rank saves a portion of the checkpoint locally (e.g., on SSD/ramdisk). '
'"in_memory" - [TBD] A special kind of local checkpoint that avoids serialization. '
'"local" - Each rank saves a portion of the checkpoint locally (e.g., on SSD/ramdisk). '
'None - No non-persistent checkpointing (default option).')
group.add_argument('--non-persistent-global-ckpt-dir', type=str, default=None,
help='Directory containing global non-persistent model checkpoints.')
......@@ -1586,6 +1739,9 @@ def _add_checkpointing_args(parser):
group.add_argument('--use-dist-ckpt', action='store_true',
dest='use_dist_ckpt_deprecated',
help='Deprecated: see --ckpt-format.')
group.add_argument('--use-persistent-ckpt-worker', action='store_true',
help='Enables a persistent checkpoint worker for async save')
group.add_argument('--auto-detect-ckpt-format', action='store_true',
help='Determine if the checkpoint format is in legacy or distributed format.'
' If False, expects distributed checkpoint iff args.ckpt_format != "torch".'
......@@ -1641,6 +1797,8 @@ def _add_mixed_precision_args(parser):
help='Run model in fp16 mode.')
group.add_argument('--bf16', action='store_true',
help='Run model in bfloat16 mode.')
group.add_argument('--grad-reduce-in-bf16', action='store_true',
help='Reduce gradients in bfloat16.')
group.add_argument('--loss-scale', type=float, default=None,
help='Static loss scaling, positive power of 2 '
'values can improve fp16 convergence. If None, dynamic'
......@@ -1699,6 +1857,8 @@ def _add_distributed_args(parser):
'--tensor-model-parallel-size instead.')
group.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None,
help='Number of layers per virtual pipeline stage')
group.add_argument('--num-virtual-stages-per-pipeline-rank', type=int, default=None,
help='Number of virtual pipeline stages per pipeline parallelism rank')
group.add_argument('--microbatch-group-size-per-virtual-pipeline-stage', type=int, default=None,
help='Number of contiguous microbatches per virtual pipeline stage',
dest='microbatch_group_size_per_vp_stage')
......@@ -1726,8 +1886,15 @@ def _add_distributed_args(parser):
help='If not set, all PP stages will launch gradient reduces simultaneously. '
'Otherwise, each PP stage will independently launch as needed.',
dest='align_grad_reduce')
group.add_argument('--ddp-num-buckets', type=int, default=None,
help='Number of buckets for data-parallel communication')
group.add_argument('--ddp-bucket-size', type=int, default=None,
help='Bucket size for data-parallel communication')
group.add_argument('--ddp-pad-buckets-for-high-nccl-busbw', action='store_true',
default=False, help='If set, make sure the bucket size is divisible by a large power '
'of 2 (2^16) to ensure NCCL collectives have high bus bandwidth at large DP counts, '
'since NCCL message size (which for ring algorithms is bucket_size / dp_size) '
'apparently needs to be divisible by a power of 2 for high busbw.')
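The padding described above amounts to rounding the bucket size up to a multiple of 2^16 elements; a small illustration with a hypothetical bucket size and DP size:
# Illustrative arithmetic for the padding described above; the 2**16 granularity comes
# from the help text, the bucket and DP sizes are hypothetical.
pad_multiple = 2 ** 16
requested_bucket_size = 40_000_000                                            # elements
padded = ((requested_bucket_size + pad_multiple - 1) // pad_multiple) * pad_multiple
dp_size = 64
ring_message_size = padded // dp_size    # bucket_size / dp_size stays divisible by a power of 2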
group.add_argument('--ddp-average-in-collective', action='store_true',
default=False, help='If set, average directly in data-parallel communication collective.')
group.add_argument('--overlap-param-gather', action='store_true',
......@@ -1745,21 +1912,33 @@ def _add_distributed_args(parser):
default=False, help='If set, use custom-built ring exchange '
'for p2p communications. Note that this option will require '
'a custom built image that support ring-exchange p2p.')
group.add_argument('--local-rank', type=int, default=int(os.getenv('LOCAL_RANK', '0')),
help='local rank passed from distributed launcher.')
group.add_argument('--lazy-mpu-init', type=bool, required=False,
help='If set to True, initialize_megatron() '
'skips DDP initialization and returns function to '
'complete it instead.Also turns on '
'complete it instead. Also turns on '
'--use-cpu-initialization flag. This is for '
'external DDP manager.' )
group.add_argument('--standalone-embedding-stage', action='store_true',
default=False, help='If set, *input* embedding layer '
'is placed on its own pipeline stage, without any '
'transformer layers. (For T5, this flag currently only '
'affects the encoder embedding.)')
group.add_argument('--account-for-embedding-in-pipeline-split', action='store_true',
default=False, help='If set, *input* embedding layer will be treated as a standard transformer '
'layer in the context of partition and placement for pipeline parallelism.')
group.add_argument('--account-for-loss-in-pipeline-split', action='store_true',
default=False, help='If set, loss layer will be treated as a standard transformer '
'layer in the context of partition and placement for pipeline parallelism.')
group.add_argument('--use-distributed-optimizer', action='store_true',
help='Use distributed optimizer.')
group.add_argument('--use-custom-fsdp', action='store_true',
help='Use the Megatron FSDP code path in DDP.')
group.add_argument('--init-model-with-meta-device', action='store_true')
group.add_argument('--data-parallel-sharding-strategy', type=str, default='no_shard',
choices=['no_shard', 'optim', 'optim_grads', 'optim_grads_params'],
help='Sharding strategy of data parallelism.')
group.add_argument('--no-gradient-reduce-div-fusion', action='store_false', dest='gradient_reduce_div_fusion',
help='If not set, fuse the division in gradient reduce.')
group.add_argument('--suggested-communication-unit-size', type=int, default=400_000_000,
help='When batch communication is needed across multiple buckets, '
'this value guides the size of each communication unit.')
group.add_argument('--keep-fp8-transpose-cache-when-using-custom-fsdp', action='store_true',
help='If set, keep the fp8 transpose cache when using custom FSDP.')
group.add_argument('--num-distributed-optimizer-instances', type=int, default=1,
help='Number of Distributed Optimizer copies across Data Parallel domain.')
group.add_argument('--use-torch-fsdp2', action='store_true',
......@@ -1786,10 +1965,21 @@ def _add_distributed_args(parser):
'setting `min_ctas`, `max_ctas`, and `cga_cluster_size`.')
group.add_argument('--use-tp-pp-dp-mapping', action='store_true', default=False,
help='If set, distributed ranks initialize order is changed '
'from tp-dp-pp to tp-pp-dp. Make sure EP and CP aren\'t used '
'with this option enabled')
'from tp-cp-ep-dp-pp to tp-cp-ep-pp-dp.')
group.add_argument('--replication', action='store_true', default=False,
help="If set, replication of local checkpoints is enabled. "
"Needs to be enabled on all ranks.")
group.add_argument('--replication-jump', default=None, type=int,
help="Specifies `J`, the spacing between ranks storing replicas of a given rank's data. "
"Replicas for rank `n` may be on ranks `n+J`, `n+2J`, ..., or `n-J`, `n-2J`, etc. "
"This flag has an effect only if --replication is used. "
"and must be consistent across all ranks.")
group.add_argument('--replication-factor', default=2, type=int,
help="Number of machines storing the replica of a given rank's data.")
group.add_argument('--rank', default=-1, type=int,
help='node rank for distributed training')
group.add_argument('--local-rank', type=int, default=int(os.getenv('LOCAL_RANK', '0')),
help='local rank passed from distributed launcher.')
group.add_argument('--world-size', type=int, default=8,
help='number of nodes for distributed training')
group.add_argument('--dist-url',
......@@ -1834,8 +2024,6 @@ def _add_tokenizer_args(parser):
'GPTSentencePieceTokenizer',
'HuggingFaceTokenizer',
'Llama2Tokenizer',
'Llama3Tokenizer',
'QwenTokenizer',
'TikTokenizer',
'MultimodalTokenizer',
'NullTokenizer'],
......@@ -1862,11 +2050,6 @@ def _add_data_args(parser):
'(3) a list of prefixes e.g. prefix1 prefix2. '
'For (3), weights are inferred from the lengths of the contributing datasets. '
'This argument is exclusive to the other independent --*-data-path arguments.')
group.add_argument('--renormalize-blend-weights', action='store_true',
help='Renormalize the blend weights to account for the mid-level dataset '
'oversampling done to ensure fulfillment of the requested number of '
'samples. Use this option if prompted. Defaults to False for backward '
'compatibility in the data sample order.')
group.add_argument('--split', type=str, default=None,
help='Comma-separated list of proportions for training,'
' validation, and test split. For example the split '
......@@ -2078,6 +2261,7 @@ def _add_vision_args(parser):
def _add_moe_args(parser):
group = parser.add_argument_group(title="moe")
# General arguments
group.add_argument('--expert-model-parallel-size', type=int, default=1,
help='Degree of expert model parallelism.')
group.add_argument('--expert-tensor-parallel-size', type=int, default=None,
......@@ -2103,16 +2287,39 @@ def _add_moe_args(parser):
help='Enable overlapping between shared expert computations and dispatcher communications. '
'Without this, the shared experts execute after the routed experts. '
'Only effective when moe-shared-expert-intermediate-size is set.')
group.add_argument('--moe-grouped-gemm', action='store_true',
help='When there are multiple experts per rank, launch multiple local GEMM kernels in multiple streams to improve the utilization and performance with GroupedLinear in TransformerEngine.')
# Router arguments
group.add_argument('--moe-router-load-balancing-type', type=str,
choices=['aux_loss', 'sinkhorn', 'none'],
choices=['aux_loss', 'seq_aux_loss', 'sinkhorn', 'none'],
default='aux_loss',
help='Determines the load balancing strategy for the router. "aux_loss" corresponds to the load balancing loss used in GShard and SwitchTransformer, "sinkhorn" corresponds to the balancing algorithm used in S-BASE, and "none" implies no load balancing. The default is "aux_loss".')
help='Determines the load balancing strategy for the router. "aux_loss" corresponds to the load balancing loss used in GShard and SwitchTransformer; "seq_aux_loss" corresponds to the load balancing loss used in DeepSeekV2, which computes the loss for each individual sample; "sinkhorn" corresponds to the balancing algorithm used in S-BASE, and "none" implies no load balancing. The default is "aux_loss".')
group.add_argument('--moe-router-score-function', type=str,
choices=['softmax', 'sigmoid'],
default='softmax',
help='Score function for MoE TopK routing. Can be "softmax" or "sigmoid".')
group.add_argument('--moe-router-topk', type=int, default=2,
help='Number of experts to route to for each token. The default is 2.')
group.add_argument('--moe-router-pre-softmax', action='store_true',
help='Enable pre-softmax routing for MoE, which means softmax is before the top-k selection. By default, softmax is done after top-k.')
group.add_argument('--moe-grouped-gemm', action='store_true',
help='When there are multiple experts per rank, launch multiple local GEMM kernels in multiple streams to improve the utilization and performance with GroupedLinear in TransformerEngine.')
group.add_argument('--moe-router-num-groups', type=int, default=None,
help='Number of groups to divide experts into for group-limited routing. When using group-limited routing: 1) Experts are divided into equal-sized groups, 2) For each token, a subset of groups are selected based on routing scores (sum of top-2 expert scores within each group), 3) From these selected groups, moe_router_topk experts are chosen.'
'Two common use cases: 1) Device-limited routing: Set equal to expert parallel size (EP) to limit each token to experts on a subset of devices (See DeepSeek-V2: https://arxiv.org/pdf/2405.04434) 2) Node-limited routing: Set equal to number of nodes in EP group to limit each token to experts on a subset of nodes (See DeepSeek-V3: https://arxiv.org/pdf/2412.19437)')
group.add_argument('--moe-router-group-topk', type=int, default=None,
help='Number of selected groups for group-limited routing.')
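A hedged sketch of the group-limited top-k selection these two flags describe, following the scoring rule in the help text (sum of top-2 expert scores per group); this is illustrative, not the exact Megatron routing kernel:
import torch

def group_limited_topk(scores, num_groups, group_topk, topk):
    # scores: [num_tokens, num_experts] router scores (e.g. softmax or sigmoid outputs).
    num_tokens, num_experts = scores.shape
    experts_per_group = num_experts // num_groups
    grouped = scores.view(num_tokens, num_groups, experts_per_group)
    # Rank each group by the sum of its top-2 expert scores, as described above.
    group_scores = grouped.topk(2, dim=-1).values.sum(dim=-1)       # [num_tokens, num_groups]
    top_groups = group_scores.topk(group_topk, dim=-1).indices      # [num_tokens, group_topk]
    # Keep only experts in the selected groups, then take the final top-k over them.
    group_mask = torch.zeros_like(group_scores).scatter_(1, top_groups, 1.0)
    expert_mask = group_mask.unsqueeze(-1).expand(-1, -1, experts_per_group).reshape(num_tokens, num_experts)
    masked = scores.masked_fill(expert_mask == 0, float('-inf'))
    return masked.topk(topk, dim=-1)

values, expert_ids = group_limited_topk(torch.rand(4, 16), num_groups=4, group_topk=2, topk=4)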
group.add_argument('--moe-router-topk-scaling-factor', type=float, default=None,
help='Scaling factor for routing score in top-k selection, only works when --moe-router-pre-softmax enabled. Defaults to None, which means no scaling.')
group.add_argument('--moe-router-enable-expert-bias', action='store_true',
help='TopK routing with dynamic expert bias in the aux-loss-free load balancing strategy. '
'The routing decision is based on the sum of the routing scores and the expert bias. '
'See https://arxiv.org/abs/2408.15664 for details.')
group.add_argument('--moe-router-bias-update-rate', type=float, default=1e-3,
help='Expert bias update rate in the aux-loss-free load balancing strategy. '
'The expert bias is updated based on the number of assigned tokens to each expert in a global batch, '
'where the bias is increased for the experts with less assigned tokens and decreased for the experts with more assigned tokens. '
'The default value 1e-3 is same as that used in DeepSeekV3.')
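A minimal sketch of the aux-loss-free bias update these two flags control, with the update direction taken from the help text; the exact accumulation over the global batch is omitted:
import torch

def update_expert_bias(expert_bias, tokens_per_expert, update_rate=1e-3):
    # Increase the bias of under-loaded experts and decrease it for over-loaded ones,
    # nudging future top-k decisions toward a balanced assignment.
    mean_load = tokens_per_expert.float().mean()
    return expert_bias + update_rate * torch.sign(mean_load - tokens_per_expert.float())

bias = torch.zeros(8)
bias = update_expert_bias(bias, torch.tensor([10, 3, 5, 9, 2, 7, 6, 6]))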
group.add_argument('--moe-use-legacy-grouped-gemm', action='store_true',
help='Use legacy GroupedMLP rather than TEGroupedMLP. Note: The legacy one will be deprecated soon.')
group.add_argument('--moe-aux-loss-coeff', type=float, default=0.0,
help='Scaling coefficient for the aux loss: a starting value of 1e-2 is recommended.')
group.add_argument('--moe-z-loss-coeff', type=float, default=None,
......@@ -2120,9 +2327,11 @@ def _add_moe_args(parser):
group.add_argument('--moe-input-jitter-eps', type=float, default=None,
help='Add noise to the input tensor by applying jitter with a specified epsilon value.')
group.add_argument('--moe-token-dispatcher-type', type=str,
choices=['allgather', 'alltoall', 'alltoall_seq'],
choices=['allgather', 'alltoall', 'flex', 'alltoall_seq'],
default='allgather',
help="The type of token dispatcher to use. The default is 'allgather'. Options are 'allgather', 'alltoall' and 'alltoall_seq'. We recommend using 'alltoall' when applying expert parallelism. For more information, please refer to the documentation in core/moe/README.")
group.add_argument('--moe-enable-deepep', action='store_true',
help='[Experimental] Enable DeepSeek/DeepEP for efficient token dispatching and combine in MoE models. Only works with flex token dispatcher by setting --moe-token-dispatcher-type=flex.')
group.add_argument('--moe-per-layer-logging', action='store_true',
help='Enable per-layer logging for MoE, currently supports auxiliary loss and z loss.')
# Token dropping arguments
......@@ -2139,6 +2348,8 @@ def _add_moe_args(parser):
group.add_argument('--moe-use-upcycling', action='store_true',
help='Load a checkpoint of a dense model, convert it into an MoE model, and save the converted model to the path specified by --save. '
'Upcycling is implemented on top of distributed checkpointing, so it supports parallel modes different from the dense model.')
group.add_argument('--moe-permute-fusion', action='store_true',
help='Fuse token rearrangement ops during token dispatching.')
return parser
......@@ -2156,6 +2367,10 @@ def _add_mla_args(parser):
help="Dimension of the head in the V projection.")
group.add_argument('--rotary-scaling-factor', type=float, default=1.0,
help="Rotary scaling factor for the rotary embeddings.")
group.add_argument('--mscale', type=float, default=1.0,
help="Mscale for YaRN RoPE in multi-latent attention.")
group.add_argument('--mscale-all-dim', type=float, default=1.0,
help="Mscale all dimensions for YaRN RoPE in multi-latent attention.")
return parser
......@@ -2185,4 +2400,18 @@ def _add_experimental_args(parser):
'the overridden pattern')
group.add_argument('--yaml-cfg', type=str, default=None,
help = 'Config file to add additional arguments')
# Args of precision-aware optimizer
group.add_argument('--use-precision-aware-optimizer', action='store_true',
help='Use the precision-aware optimizer in TransformerEngine, which allows '
'setting the main params and optimizer states to lower precision, such as '
'fp16 and fp8.')
group.add_argument('--main-grads-dtype', default='fp32', choices=['fp32', 'bf16'],
help='Dtype of main grads when enabling precision-aware-optimizer')
group.add_argument('--main-params-dtype', default='fp32', choices=['fp32', 'fp16'],
help='Dtype of main params when enabling precision-aware-optimizer')
group.add_argument('--exp-avg-dtype', default='fp32', choices=['fp32', 'fp16', 'fp8'],
help='Dtype of exp_avg when enabling precision-aware-optimizer')
group.add_argument('--exp-avg-sq-dtype', default='fp32', choices=['fp32', 'fp16', 'fp8'],
help='Dtype of exp_avg_sq when enabling precision-aware-optimizer')
return parser
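These string choices are the ones converted to torch dtypes by the dtype_map / map_dtype logic added in validate_args above; a small self-contained illustration:
import torch

# Mirrors the mapping used in validate_args; 'fp8' is stored as torch.uint8.
dtype_map = {'fp32': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16, 'fp8': torch.uint8}
map_dtype = lambda d: d if isinstance(d, torch.dtype) else dtype_map[d]

assert map_dtype('bf16') is torch.bfloat16        # CLI string -> torch dtype
assert map_dtype(torch.float32) is torch.float32  # already a dtype: passed through unchanged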
......@@ -13,11 +13,19 @@ from megatron.training.utils import print_rank_0
logger = logging.getLogger(__name__)
# Singleton manager of async calls
# The default is `TemporalAsyncCaller`
_async_calls_queue = AsyncCallsQueue()
def init_persistent_async_worker():
global _async_calls_queue
# Recreate the async_calls_queue for persistent worker
# This duplicate step is for backward compatibility
_async_calls_queue = AsyncCallsQueue(persistent=True)
def schedule_async_save(async_request: AsyncRequest):
""" Schedule the async save request.
"""Schedule the async save request.
Args:
async_request (AsyncRequest): the async save request.
......@@ -25,19 +33,33 @@ def schedule_async_save(async_request: AsyncRequest):
_async_calls_queue.schedule_async_request(async_request)
def maybe_finalize_async_save(blocking: bool = False):
""" Finalizes active async save calls.
def maybe_finalize_async_save(blocking: bool = False, terminate=False):
"""Finalizes active async save calls.
Args:
blocking (bool, optional): if True, will wait until all active requests
are done. Otherwise, finalizes only the async request that already
finished. Defaults to False.
terminate (bool, optional): if True, the asynchronous queue will
be closed as the last action of this function.
"""
args = get_args()
if not args.async_save:
return
if blocking and _async_calls_queue.get_num_unfinalized_calls() > 0:
if blocking and not is_empty_async_queue():
print_rank_0('Unfinalized async checkpoint saves. Finalizing them synchronously now.')
_async_calls_queue.maybe_finalize_async_calls(blocking)
_async_calls_queue.maybe_finalize_async_calls(blocking, no_dist=False)
if terminate:
_async_calls_queue.close()
def is_empty_async_queue() -> bool:
"""Check if async calls queue is empty. This result is consistent across ranks.
Returns:
bool: True if there are no unfinalized async calls in the queue.
"""
return _async_calls_queue.get_num_unfinalized_calls() == 0
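A hypothetical usage sketch of the helpers above inside a training loop; it assumes Megatron's global args are initialized with --async-save, and everything other than the two helpers is a stand-in:
from megatron.training.async_utils import maybe_finalize_async_save, is_empty_async_queue

def train_loop(train_iters=1000, save_interval=200):
    for iteration in range(train_iters):
        # ... one training step would run here ...
        maybe_finalize_async_save(blocking=False)  # reap any async saves that already finished
        if iteration > 0 and iteration % save_interval == 0:
            pass  # save_checkpoint(...) would schedule an async save at this point
    # Drain outstanding saves and close the queue (and persistent worker, if any) on exit.
    maybe_finalize_async_save(blocking=True, terminate=True)
    assert is_empty_async_queue()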
......@@ -20,21 +20,20 @@ import torch
from megatron.core import mpu, tensor_parallel, dist_checkpointing
from megatron.core.dist_checkpointing.mapping import ShardedObject
from megatron.core.dist_checkpointing.serialization import get_default_load_sharded_strategy
from megatron.core.dist_checkpointing.state_dict_transformation import (
prepare_state_dict_for_save,
recreate_state_dict_after_load,
)
from megatron.core.dist_checkpointing.strategies.fully_parallel import \
FullyParallelSaveStrategyWrapper, FullyParallelLoadStrategyWrapper
from megatron.core.num_microbatches_calculator import update_num_microbatches
from megatron.core.utils import is_float8tensor
from megatron.core.fp8_utils import is_float8tensor
from megatron.core.rerun_state_machine import get_rerun_state_machine
from .async_utils import schedule_async_save
from .async_utils import schedule_async_save, is_empty_async_queue
from .global_vars import get_args, get_one_logger
from .utils import unwrap_model, print_rank_0, append_to_progress_log, is_last_rank
from ..core.dist_checkpointing.serialization import \
get_default_save_sharded_strategy
from .one_logger_utils import on_save_checkpoint_start, on_save_checkpoint_success
from . import wandb_utils
from . import ft_integration
# [ModelOpt]: Import
try:
......@@ -305,7 +304,7 @@ class CheckpointType(Enum):
def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floating_point_operations_so_far,
checkpointing_context=None, pipeline_rank=None, expert_rank=None, tensor_rank=None, pipeline_parallel=None, expert_parallel=None, non_persistent_ckpt=False,
train_data_iterator=None, ft_client=None, preprocess_common_state_dict_fn = None):
train_data_iterator=None, preprocess_common_state_dict_fn = None):
"""Save a model, optimizer and optionally dataloader checkpoint.
Checkpointing context is used to persist some checkpointing state
......@@ -315,8 +314,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati
the checkpoint will be saved with special functionality for removing old checkpoints.
There are several types of non-persistent checkpoints:
"global" - Saved as a standard checkpoint (e.g., on Lustre) with old checkpoints being removed.
"local" - [TBD] Each rank saves a portion of the checkpoint locally (e.g., on SSD/ramdisk).
"in_memory" - [TBD] A special kind of local checkpoint that avoids serialization.
"local" - Each rank saves a portion of the checkpoint locally (e.g., on SSD/ramdisk).
Dataloader checkpoint is only saved if the dataloader supports it. Currently this applies only
to the Megatron Energon dataloader (multimodal) and not the built-in Megatron dataloader (text-only).
......@@ -324,9 +322,15 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati
start_ckpt = time()
args = get_args()
if args.async_save and not is_empty_async_queue():
print_rank_0('WARNING: Starting a checkpoint save before previous has finished. Consider increasing the checkpoint interval.')
# Prepare E2E metrics at start of save checkpoint
productive_metrics = on_save_checkpoint_start(args.async_save)
# Monitor for the checkpointing timeout (no-op if FT is not enabled)
ft_integration.on_checkpointing_start()
# Only rank zero of the data parallel writes to the disk.
model = unwrap_model(model)
......@@ -347,7 +351,6 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati
save_dir, leave_ckpt_num=1, do_async=args.async_save
)
elif args.non_persistent_ckpt_type == 'local':
raise RuntimeError('LocalCheckpointManagers are not yet integrated')
ckpt_type = CheckpointType.LOCAL
save_dir = checkpointing_context['local_checkpoint_manager'].local_ckpt_dir
else:
......@@ -361,13 +364,19 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati
# Collect rng state across data parallel ranks.
rng_state = get_rng_state(ckpt_type != CheckpointType.LEGACY)
# Collect rerun state across all ranks
rerun_state_machine = get_rerun_state_machine()
rerun_state = rerun_state_machine.state_dict(
data_iterator=train_data_iterator, use_dist_ckpt=ckpt_type != CheckpointType.LEGACY
)
# Checkpoint name.
return_base_dir = (ckpt_type != CheckpointType.LEGACY)
checkpoint_name = get_checkpoint_name(save_dir, iteration, release=False, pipeline_parallel=pipeline_parallel,
tensor_rank=tensor_rank, pipeline_rank=pipeline_rank, expert_parallel=expert_parallel, expert_rank=expert_rank, return_base_dir=return_base_dir)
# Save dataloader state if the dataloader supports it (currently only Megatron Energon).
save_dataloader_state(train_data_iterator, iteration, getattr(args, "dataloader_save", None))
maybe_save_dataloader_state(train_data_iterator, iteration, getattr(args, "dataloader_save", None))
# Save distributed optimizer's custom parameter state.
if (
......@@ -379,7 +388,8 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati
optim_checkpoint_name = \
get_distributed_optimizer_checkpoint_name(checkpoint_name)
ensure_directory_exists(optim_checkpoint_name)
optimizer.save_parameter_state(optim_checkpoint_name)
if not optimizer.is_stub_optimizer:
optimizer.save_parameter_state(optim_checkpoint_name)
async_save_request = None
if args.async_save:
......@@ -409,11 +419,9 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati
use_dist_ckpt=ckpt_type != CheckpointType.LEGACY,
iteration=iteration,
optim_sd_kwargs=optim_sd_kwargs,
train_data_iterator=train_data_iterator,
rerun_state=rerun_state,
)
if args.enable_ft_package and ft_client is not None:
state_dict["ft_state"] = ft_client.state_dict()
state_dict['num_floating_point_operations_so_far'] = num_floating_point_operations_so_far
if ckpt_type == CheckpointType.GLOBAL:
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
......@@ -428,6 +436,14 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati
save_strategy = get_default_save_sharded_strategy(args.ckpt_format)
if args.ckpt_assume_constant_structure and args.ckpt_format == 'torch_dist':
save_strategy.use_cached_ckpt_structure = args.ckpt_assume_constant_structure
if checkpointing_context is not None and 'load_strategy' in checkpointing_context:
cached_global_metadata = getattr(checkpointing_context['load_strategy'], 'cached_global_metadata', None)
if cached_global_metadata is not None:
logger.debug("Plugging in the read metadata from the load strategy...")
save_strategy.cached_global_metadata = cached_global_metadata
else:
logger.debug("Failed to plug in the read metadata from the load strategy...")
if args.ckpt_fully_parallel_save:
save_strategy = FullyParallelSaveStrategyWrapper(save_strategy, mpu.get_data_parallel_group(with_context_parallel=True),
args.ckpt_assume_constant_structure)
......@@ -446,26 +462,44 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati
else:
# [ModelOpt]: Inject modelopt_state into state_dict
if has_nvidia_modelopt:
save_modelopt_state(model, state_dict)
if ckpt_type == CheckpointType.LOCAL:
print_rank_0('WARNING: Local checkpointing does not support nvidia_modelopt.')
else:
save_modelopt_state(model, state_dict)
end_ckpt = time()
logger.debug(f"rank: {rank}, takes {end_ckpt - start_ckpt} to prepare state dict for ckpt ")
if ckpt_type == CheckpointType.LOCAL:
state_dict_for_save = prepare_state_dict_for_save(
state_dict, algo=args.non_persistent_local_ckpt_algo
try:
from megatron.core.dist_checkpointing.tensor_aware_state_dict import MCoreTensorAwareStateDict
except ModuleNotFoundError:
raise RuntimeError("The 'nvidia_resiliency_ext' module is required for local "
"checkpointing but was not found. Please ensure it is installed.")
algo = args.non_persistent_local_ckpt_algo
cached_metadata = None
if args.ckpt_assume_constant_structure and 'local_checkpoint_cache' in checkpointing_context:
cached_metadata = checkpointing_context['local_checkpoint_cache']
state_dict_for_save, cacheable_metadata = MCoreTensorAwareStateDict.from_state_dict(
state_dict, algo=algo, cached_metadata=cached_metadata,
parallelization_group=mpu.get_data_parallel_group(with_context_parallel=True)
)
async_save_request = checkpointing_context['local_checkpoint_manager'].save(
state_dict_for_save, iteration, is_async=bool(args.async_save)
)
checkpointing_context['local_checkpoint_cache'] = cacheable_metadata
else:
assert ckpt_type == CheckpointType.LEGACY
# Save.
ensure_directory_exists(checkpoint_name)
torch.save(state_dict, checkpoint_name)
start_misc = time()
if not args.async_save:
assert async_save_request is None
# Wait so everyone is done (necessary)
if torch.distributed.is_initialized():
torch.distributed.barrier()
if ckpt_type != CheckpointType.LOCAL:
if not args.async_save:
assert async_save_request is None
# Wait so everyone is done (necessary)
if torch.distributed.is_initialized():
torch.distributed.barrier()
# And update the latest iteration
if not torch.distributed.is_initialized() \
......@@ -507,6 +541,17 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati
else:
onelogger_finalize_fn()
# Additional callback for wandb (last rank)
if not torch.distributed.is_initialized() \
or is_last_rank():
def wandb_finalize_fn():
wandb_utils.on_save_checkpoint_success(checkpoint_name, get_checkpoint_tracker_filename(save_dir), save_dir, iteration)
if args.async_save:
assert async_save_request is not None
async_save_request.add_finalize_fn(wandb_finalize_fn)
else:
wandb_finalize_fn()
if args.async_save:
schedule_async_save(async_save_request)
print_rank_0(' scheduled an async checkpoint save at iteration {:7d} to {}' \
......@@ -519,6 +564,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati
end_misc = time()
logger.debug(f"rank: {rank}, takes {end_misc - start_misc} to finalize ckpt save ")
ft_integration.on_checkpointing_end(is_async_finalization=False)
def cleanup_old_non_persistent_checkpoint(save_dir, leave_ckpt_num=1, do_async=False):
if torch.distributed.is_initialized() and torch.distributed.get_rank() != 0:
......@@ -543,7 +589,7 @@ def cleanup_old_non_persistent_checkpoint(save_dir, leave_ckpt_num=1, do_async=F
remove_iter_ckpts(rm_iter_ckpts)
def save_dataloader_state(train_iterator, iteration, dataloader_save_path):
def maybe_save_dataloader_state(train_iterator, iteration, dataloader_save_path):
"""Saves dataloader state if the dataloader supports it.
Currently, this is only used by Megatron Energon dataloader (multimodal) to store its state at a
......@@ -558,13 +604,13 @@ def save_dataloader_state(train_iterator, iteration, dataloader_save_path):
iteration (int): Current iteration.
dataloader_save_path (str): Path where the dataloader state is saved.
"""
# If no dataloader or saving path is provided, then exit early.
if train_iterator is None or dataloader_save_path is None:
# If no dataloader or saving path is provided, exit early; otherwise, an unsupported dataloader raises an error below.
if train_iterator is None or dataloader_save_path is None or dataloader_save_path == "":
return
# If dataloader doesn't support saving state, exit early.
if not hasattr(train_iterator, "save_state"):
return
# If dataloader doesn't support saving state, raise an error.
if not hasattr(train_iterator.iterable, "save_state"):
raise RuntimeError(f"Could not find a save_state for the train_iterator of type {type(train_iterator)}")
# Save dataloader state for each data parallel rank only once.
first_rank = mpu.is_pipeline_first_stage(ignore_virtual=True) and mpu.get_tensor_model_parallel_rank() == 0
......@@ -573,7 +619,7 @@ def save_dataloader_state(train_iterator, iteration, dataloader_save_path):
dp_rank = mpu.get_data_parallel_rank()
print(f"saving dataloader checkpoint at iteration {iteration} to {dataloader_save_path}")
train_dataloader_state_dict = train_iterator.save_state()
train_dataloader_state_dict = train_iterator.iterable.save_state()
data_state_save_path = get_checkpoint_name(
dataloader_save_path, iteration,
basename=f'train_dataloader_dprank{dp_rank:03d}.pt'
......@@ -593,7 +639,7 @@ def save_dataloader_state(train_iterator, iteration, dataloader_save_path):
def generate_state_dict(args, model, optimizer, opt_param_scheduler,
rng_state, use_dist_ckpt=False, iteration=None,
optim_sd_kwargs=None, train_data_iterator=None):
optim_sd_kwargs=None, rerun_state=None):
# Arguments, iteration, and model.
state_dict = {}
state_dict['args'] = args
......@@ -614,7 +660,7 @@ def generate_state_dict(args, model, optimizer, opt_param_scheduler,
model[i].state_dict_for_save_checkpoint())
# Optimizer stuff.
if not args.no_save_optim:
if optimizer is not None:
if optimizer is not None and not optimizer.is_stub_optimizer:
state_dict['optimizer'] = (optimizer.sharded_state_dict(state_dict, **(optim_sd_kwargs or {}))
if use_dist_ckpt else
optimizer.state_dict())
......@@ -623,10 +669,7 @@ def generate_state_dict(args, model, optimizer, opt_param_scheduler,
opt_param_scheduler.state_dict()
# Rerun state
rerun_state_machine = get_rerun_state_machine()
state_dict['rerun_state_machine'] = rerun_state_machine.get_checkpoint_state(
train_data_iterator
)
state_dict['rerun_state_machine'] = rerun_state
# RNG states.
if not args.no_save_rng:
......@@ -719,8 +762,7 @@ def _get_non_persistent_iteration(non_persistent_global_dir, args, checkpointing
print_rank_0(' will not load any non-persistent checkpoint')
return iteration
elif args.non_persistent_ckpt_type == "local":
raise RuntimeError('LocalCheckpointManagers are not yet integrated')
return checkpointing_context['local_checkpoint_manager'].get_latest_checkpoint_iteration()
return checkpointing_context['local_checkpoint_manager'].find_latest()
else:
assert False, 'Please use local or global non-persistent checkpoints ' \
f'(got: {args.non_persistent_ckpt_type})'
......@@ -744,17 +786,17 @@ def _load_non_persistent_base_checkpoint(
f'Loading from a non-persistent checkpoint (non-persistent iter {non_persistent_iteration})'
)
return _load_global_dist_base_checkpoint(
non_persistent_global_dir, args, rank0, sharded_state_dict, non_persistent_iteration, False
non_persistent_global_dir, args, rank0, sharded_state_dict, non_persistent_iteration, False,
checkpointing_context=checkpointing_context
)
elif args.non_persistent_ckpt_type == "local":
raise RuntimeError('LocalCheckpointManagers are not yet integrated')
intermediate_state_dict, checkpoint_name = checkpointing_context[
'local_checkpoint_manager'
].load()
state_dict = recreate_state_dict_after_load(
state_dict = intermediate_state_dict.to_state_dict(
sharded_state_dict,
intermediate_state_dict,
algo=args.non_persistent_local_ckpt_algo,
parallelization_group = mpu.get_data_parallel_group(with_context_parallel=True)
)
return state_dict, checkpoint_name, False, CheckpointType.LOCAL
else:
......@@ -763,7 +805,7 @@ def _load_non_persistent_base_checkpoint(
def _load_global_dist_base_checkpoint(
load_dir, args, rank0, sharded_state_dict, iteration, release
load_dir, args, rank0, sharded_state_dict, iteration, release, checkpointing_context=None
):
""" Load the base state_dict from the given directory containing the global distributed checkpoint """
if rank0:
......@@ -787,6 +829,8 @@ def _load_global_dist_base_checkpoint(
load_strategy = FullyParallelLoadStrategyWrapper(
load_strategy, mpu.get_data_parallel_group(with_context_parallel=True)
)
if checkpointing_context is not None:
checkpointing_context["load_strategy"] = load_strategy
state_dict = dist_checkpointing.load(sharded_state_dict, checkpoint_name, load_strategy, strict=args.dist_ckpt_strictness)
return state_dict, checkpoint_name, release, CheckpointType.GLOBAL
......@@ -860,7 +904,7 @@ def _load_base_checkpoint(
# Handle global distributed checkpoint
if is_dist_ckpt:
return _load_global_dist_base_checkpoint(
load_dir, args, rank0, sharded_state_dict, iteration, release
load_dir, args, rank0, sharded_state_dict, iteration, release, checkpointing_context=checkpointing_context
)
# Handle global legacy checkpoint
if rank0:
......@@ -1035,7 +1079,7 @@ def fix_fp8_params_lose_precision_when_loading_dist_ckpt(state_dict):
def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True,
ft_client=None, checkpointing_context=None, skip_load_to_model_and_opt=False):
checkpointing_context=None, skip_load_to_model_and_opt=False):
"""Load a model checkpoint and return the iteration.
strict (bool): whether to strictly enforce that the keys in
:attr:`state_dict` of the checkpoint match the names of
......@@ -1059,7 +1103,8 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
raise FileNotFoundError("No checkpoint found in load directory or pretrained directory")
args.finetune = True
model = unwrap_model(model)
ddp_model = model
model = unwrap_model(ddp_model)
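# Keep both handles: `ddp_model` (the wrapped modules) receives the load_state_dict
# calls below, while the unwrapped `model` is used to generate the sharded state dict.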
load_kwargs = {}
is_dist_ckpt = False
......@@ -1074,11 +1119,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
rank0=True,
checkpointing_context=checkpointing_context,
)
if args.enable_ft_package and ft_client is not None and state_dict is not None:
if 'ft_state' in state_dict:
ft_client.load_state_dict(state_dict['ft_state'])
else:
print_rank_0("ft_state is not present in state_dict")
is_dist_ckpt = (
ckpt_type == CheckpointType.LOCAL
or dist_checkpointing.check_is_distributed_checkpoint(checkpoint_name)
......@@ -1136,6 +1177,32 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
gen_sd_optim = None
gen_sd_opt_param_scheduler = None
# Determine if rerun state will be loaded
if (
ckpt_tp_pp == run_tp_pp
and not release
and not args.finetune
and 'rerun_state_machine' in state_dict
):
rerun_state_machine = get_rerun_state_machine()
gen_sd_rerun_state = rerun_state_machine.state_dict(
data_iterator=None, use_dist_ckpt=True
)
else:
gen_sd_rerun_state = None
if ckpt_tp_pp != run_tp_pp:
print_rank_0("{}: Rerun state will be ignored".format(mismatch_msg))
# [ModelOpt]: IMPORTANT! Restoring modelopt_state (sharded or not) must be performed
# after the model instance has been created and before _load_base_checkpoint is called.
if has_nvidia_modelopt:
if ckpt_type == CheckpointType.LOCAL:
print_rank_0('WARNING: Local checkpointing does not support nvidia_modelopt.')
elif ckpt_type == CheckpointType.GLOBAL:
restore_modelopt_state(model, state_dict)
else:
restore_sharded_modelopt_state(model, checkpoint_name)
# [ModelOpt]: Initial loading from non-resume sharded checkpoint to a Distillation Model
# will result in key mismatch with loss modules potentially containing parameters, since
# it requires generating a state_dict before loading. Here we hide those modules if present.
......@@ -1145,9 +1212,9 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
stack.enter_context(m.hide_loss_modules())
load_kwargs['sharded_state_dict'] = generate_state_dict(
args, model, gen_sd_optim, gen_sd_opt_param_scheduler, gen_sd_rng_state,
use_dist_ckpt=True, optim_sd_kwargs=optim_sd_kwargs, train_data_iterator=None
use_dist_ckpt=True, optim_sd_kwargs=optim_sd_kwargs, rerun_state=gen_sd_rerun_state
)
# When "--fp8-param-gather" is disabled, this function doesn't modify anything.
fix_fp8_params_lose_precision_when_loading_dist_ckpt(load_kwargs['sharded_state_dict'])
......@@ -1156,12 +1223,6 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
**load_kwargs
)
if args.enable_ft_package and ft_client is not None and state_dict is not None:
if 'ft_state' in state_dict:
ft_client.load_state_dict(state_dict['ft_state'])
else:
print_rank_0("ft_state is not present in state_dict")
# Checkpoint not loaded.
if state_dict is None:
# Iteration and num_floating_point_operations_so_far default to 0.
......@@ -1202,24 +1263,15 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
else:
print_rank_0('could not find arguments in the checkpoint ...')
# [ModelOpt]: loading modelopt_state (sharded or not)
if has_nvidia_modelopt:
if ckpt_type == CheckpointType.LOCAL:
raise NotImplementedError('Local checkpointing does not support model opt')
if not args.use_dist_ckpt:
restore_modelopt_state(model, state_dict)
else:
restore_sharded_modelopt_state(model, checkpoint_name)
# Model.
strict = False if args.retro_add_retriever else strict
if not skip_load_to_model_and_opt:
if len(model) == 1:
model[0].load_state_dict(state_dict['model'], strict=strict)
if len(ddp_model) == 1:
ddp_model[0].load_state_dict(state_dict['model'], strict=strict)
else:
for i in range(len(model)):
for i in range(len(ddp_model)):
mpu.set_virtual_pipeline_model_parallel_rank(i)
model[i].load_state_dict(state_dict['model%d' % i], strict=strict)
ddp_model[i].load_state_dict(state_dict['model%d' % i], strict=strict)
# Fix up query/key/value matrix ordering if needed.
checkpoint_version = get_checkpoint_version()
......@@ -1230,7 +1282,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
if not release and not args.finetune and not args.no_load_optim:
try:
# Load state dict.
if not skip_load_to_model_and_opt and optimizer is not None:
if not skip_load_to_model_and_opt and optimizer is not None and not optimizer.is_stub_optimizer:
optimizer.load_state_dict(state_dict['optimizer'])
# Load distributed optimizer's custom parameter state.
......@@ -1268,7 +1320,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
# rerun state
try:
if 'rerun_state_machine' in state_dict:
get_rerun_state_machine().set_checkpoint_state(state_dict['rerun_state_machine'])
get_rerun_state_machine().load_state_dict(state_dict['rerun_state_machine'])
except Exception as e:
print(f"Unable to restore RerunMachine from checkpoint: {e}")
sys.exit()
......@@ -1317,7 +1369,18 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
f'p {mpu.get_pipeline_model_parallel_rank() + 1}/{mpu.get_pipeline_model_parallel_world_size()} ] '
f'at iteration {iteration}')
# Additional callback for wandb (last rank)
if not torch.distributed.is_initialized() \
or is_last_rank():
wandb_utils.on_load_checkpoint_success(checkpoint_name, load_dir)
torch.cuda.empty_cache()
if iteration > 0:
# Notify FT that a checkpoint was loaded.
is_local_chkpt = (ckpt_type == CheckpointType.LOCAL)
ft_integration.on_checkpoint_loaded(is_local_chkpt=is_local_chkpt)
return iteration, num_floating_point_operations_so_far
......
File mode changed from 100755 to 100644
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
"""
FT Package Integration
Fault Tolerance (FT) package integration for Megatron-LM, using the FT section-based API.
This file is part of the integration process for the FT package, a custom heartbeat-based
system developed by NVIDIA. The FT package monitors the ranks to detect hangs, gracefully
terminates the workload, and respawns it from the last checkpoints. It includes an auto
config feature that automatically sets up timeouts based on the observed time of iterations.
The FT package is included in "nvidia-resiliency-ext"
(https://github.com/NVIDIA/nvidia-resiliency-ext).
Note: This tool is an internal NVIDIA tool and is not open source. This file does not
contain the FT package itself but supports its integration.
NOTE: The workload must be run using the `ft_launcher` tool provided by `nvidia-resiliency-ext`.
NOTE: Calls to the public API of this module are no-ops if FT is not initialized
(`ft_integration.setup` was not called).
NOTE: Default distributed process group should be initialized before calling `ft_integration.setup`
The "setup" FT section is opened during FT initialization and closed before the first training or
eval iteration. Training and evaluation steps are wrapped in the "step" section, but only after a
few warmup iterations. This is because the initial iterations may be slower, and we want the "step"
timeout to be short. These warmup steps, which are not wrapped in the "step" section, will fall into
the out-of-section area. All checkpoint-saving-related operations (including asynchronous
checkpointing finalization) are wrapped in the "checkpointing" section.
If timeout calculation is enabled (--calc-ft-timeouts),
FT timeouts are updated after each checkpoint and at the end of the run.
Updated values are based on observed intervals.
`ft_launcher` command example:
```
ft_launcher \
--rdzv_backend=c10d --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} \
--nnodes=${NUM_NODES} --nproc-per-node=${NUM_GPUS_PER_NODE} \
--ft-param-rank_section_timeouts=setup:600,step:180,checkpointing:420 \
--ft-param-rank_out_of_section_timeout=300 \
train_script_with_ft.py
```
"""
import types
from enum import Enum, auto
import argparse
import json
import os
import random
import signal
import sys
import threading
import time
from typing import Any, Optional
import torch
from . import global_vars
from .utils import is_rank0, print_rank_0
class StateMachineActions(Enum):
NONE = auto()
SAVE_CHECKPOINT = auto()
TRAIN_HEARTBEAT = auto()
EVAL_HEARTBEAT = auto()
UPDATE_TIMEOUT = auto()
_GLOBAL_RANK_MONITOR_CLIENT = None
_ft_state_path = None
_is_persistent_chkpt_loaded = False
_is_async_chkpt_enabled = False
_is_calculating_timeouts = False
_is_setup_section_open = False
_seen_checkpoints_cnt = 0
_seen_tr_iters_cnt = 0
_curr_eval_iter_idx = 0
_NUM_WARMUP_ITERS = 1
_MIN_ITERS_FOR_STEP_TIMEOUT_UPDATE = 16
class _TrainingStateMachine:
"""
This class encapsulates logic for determining when:
- FT timeouts can be updated (`.can_update_timeouts` property)
`on_ ...` methods update the state and should be called from the corresponding places.
def get_rank_monitor_client() -> Optional[Any]:
"""Returns the underlying fault tolerance client instance
Returns:
RankMonitorClient: rank monitor client instance, or None if FT was not initialized
"""
return _GLOBAL_RANK_MONITOR_CLIENT
MIN_ITERS_FOR_TIMEOUT_UPDATE = 2
def __init__(self):
self.num_tr_iters_total = 0
self.num_tr_iter_at_last_save = None
self.seen_checkpointing = False
self.timeouts_updated = False
def on_save_checkpoint(self):
self.num_tr_iter_at_last_save = self.num_tr_iters_total
def on_train_heartbeat(self):
self.num_tr_iters_total += 1
if not self.seen_checkpointing and self.num_tr_iter_at_last_save is not None:
# detect mid-epoch checkpointing that makes the heartbeat interval longer
iters_pre_save = self.num_tr_iter_at_last_save
iters_post_save = self.num_tr_iters_total - self.num_tr_iter_at_last_save
self.seen_checkpointing = iters_pre_save > 0 and iters_post_save > 0
def on_eval_heartbeat(self):
pass
def on_timeouts_updated(self):
self.timeouts_updated = True
@property
def can_update_timeouts(self) -> bool:
"""
Returns True if new timeouts can be computed.
`.on_timeouts_updated()` resets this property back to False.
"""
if self.timeouts_updated:
# timeouts are updated at most once per training run
return False
if self.num_tr_iters_total < self.MIN_ITERS_FOR_TIMEOUT_UPDATE:
# need a few training iters
return False
# check whether a checkpoint save was observed;
# it makes the heartbeat interval longer than usual.
return self.seen_checkpointing
def perform_action(self, action: StateMachineActions):
if action == StateMachineActions.TRAIN_HEARTBEAT:
self.on_train_heartbeat()
elif action == StateMachineActions.SAVE_CHECKPOINT:
self.on_save_checkpoint()
elif action == StateMachineActions.EVAL_HEARTBEAT:
self.on_eval_heartbeat()
elif action == StateMachineActions.UPDATE_TIMEOUT:
self.on_timeouts_updated()
assert not self.can_update_timeouts
# No action for StateMachineActions.NONE
def setup(args: argparse.Namespace) -> None:
"""Initialize fault tolerance
_GLOBAL_RANK_MONITOR_CLIENT = None
_GLOBAL_STATE_MACHINE = _TrainingStateMachine()
Args:
args (argparse.Namespace): parsed Megatron-LM command line arguments
def _set_rank_monitor_client():
Raises:
ValueError: if invalid config is provided
"""
from nvidia_resiliency_ext.fault_tolerance import RankMonitorClient
print_rank_0(f"FT: initializing...")
checkpoint_dir = args.save
if not checkpoint_dir:
raise ValueError("checkpointing save dir must be set to enable fault tolerance")
if is_rank0() and not os.path.exists(checkpoint_dir):
# MLM checkpoint dir will be needed for saving FT state.
# it can happen before the checkpointing, so create it in advance
os.makedirs(checkpoint_dir, exist_ok=True)
cli = RankMonitorClient()
global _GLOBAL_RANK_MONITOR_CLIENT
global_vars._ensure_var_is_not_initialized(_GLOBAL_RANK_MONITOR_CLIENT, 'rank monitor client')
_GLOBAL_RANK_MONITOR_CLIENT = cli
def get_rank_monitor_client(action=StateMachineActions.NONE):
global _GLOBAL_RANK_MONITOR_CLIENT, _GLOBAL_STATE_MACHINE
if _GLOBAL_RANK_MONITOR_CLIENT is None:
try:
_set_rank_monitor_client()
except ImportError:
_GLOBAL_RANK_MONITOR_CLIENT = None
_GLOBAL_STATE_MACHINE.perform_action(action)
return _GLOBAL_RANK_MONITOR_CLIENT
global _ft_state_path
_ft_state_path = os.path.join(checkpoint_dir, "ft_state.json")
global _is_async_chkpt_enabled
_is_async_chkpt_enabled = args.async_save
global _is_calculating_timeouts
_is_calculating_timeouts = args.calc_ft_timeouts
cli.init_workload_monitoring()
_load_state_if_exists()
print_rank_0(f"FT: initialized. Timeouts={cli.section_timeouts}")
cli.start_section("setup")
global _is_setup_section_open
_is_setup_section_open = True
def on_training_step_start() -> None:
"""Should be called before each training step"""
rmon_cli = get_rank_monitor_client()
if rmon_cli is not None:
global _is_setup_section_open
if _is_setup_section_open:
rmon_cli.end_section("setup")
_is_setup_section_open = False
if _seen_tr_iters_cnt >= _NUM_WARMUP_ITERS:
rmon_cli.start_section("step")
# reset eval step index. we started training, so evaluation is done
global _curr_eval_iter_idx
_curr_eval_iter_idx = 0
def on_training_step_end() -> None:
"""Should be called after each training step"""
rmon_cli = get_rank_monitor_client()
if rmon_cli is not None:
global _seen_tr_iters_cnt
if _seen_tr_iters_cnt >= _NUM_WARMUP_ITERS:
rmon_cli.end_section("step")
_seen_tr_iters_cnt += 1
def on_eval_step_start() -> None:
"""Should be called before each validation step"""
rmon_cli = get_rank_monitor_client()
if rmon_cli is not None:
global _is_setup_section_open
if _is_setup_section_open:
# setup section can be open if there were no training iters before evaluation
rmon_cli.end_section("setup")
_is_setup_section_open = False
if _curr_eval_iter_idx >= _NUM_WARMUP_ITERS:
rmon_cli.start_section("step")
def on_eval_step_end() -> None:
"""Should be called after each validation step"""
rmon_cli = get_rank_monitor_client()
if rmon_cli is not None:
global _curr_eval_iter_idx
if _curr_eval_iter_idx >= _NUM_WARMUP_ITERS:
rmon_cli.end_section("step")
_curr_eval_iter_idx += 1
def on_checkpointing_start() -> None:
"""Should be called before each checkpoint-saving-related operation."""
rmon_cli = get_rank_monitor_client()
if rmon_cli is not None:
rmon_cli.start_section("checkpointing")
def on_checkpointing_end(is_async_finalization: bool) -> None:
"""Should be called after each checkpoint-saving-related operation.
Args:
is_async_finalization (bool): true if called after an async checkpointing finalization
"""
rmon_cli = get_rank_monitor_client()
if rmon_cli is not None:
rmon_cli.end_section("checkpointing")
# Async checkpointing finalization runs before each training iteration and can be a no-op,
# so only try to update the timeouts on an actual `save_checkpoint`.
if not is_async_finalization:
global _seen_checkpoints_cnt
_seen_checkpoints_cnt += 1
_maybe_update_timeouts()
def on_checkpoint_loaded(is_local_chkpt: bool) -> None:
"""Should be called after a checkpoint was loaded
Args:
is_local_chkpt (bool): true if it was a local checkpoint, false if global
"""
# A checkpoint can be loaded during the "setup" section. Track whether a persistent
# checkpoint was loaded: reading an in-memory (local) checkpoint can be very fast,
# so calibrating the "setup" timeout on it could underestimate the timeout.
global _is_persistent_chkpt_loaded
_is_persistent_chkpt_loaded = not is_local_chkpt
def shutdown() -> None:
"""Shutdowns fault folerance, updates the FT timeouts if possible"""
global _GLOBAL_RANK_MONITOR_CLIENT
rmon_cli = get_rank_monitor_client()
if rmon_cli is not None:
print_rank_0("FT: closing...")
_maybe_update_timeouts(is_closing_ft=True)
rmon_cli.shutdown_workload_monitoring()
print_rank_0("FT: closed.")
_GLOBAL_RANK_MONITOR_CLIENT = None
def _load_state_if_exists():
rmon_cli = get_rank_monitor_client()
if os.path.exists(_ft_state_path):
with open(_ft_state_path, "r") as f:
ft_state = json.load(f)
rmon_cli.load_state_dict(ft_state)
print_rank_0(f"FT: loaded timeouts from {_ft_state_path}. {rmon_cli.section_timeouts}")
def _update_timeouts(selected_sections, calc_out_of_section):
print_rank_0(
f"FT: updating timeouts for: {selected_sections} "
+ f"update out-of-section: {calc_out_of_section} ..."
)
rmon_cli = get_rank_monitor_client()
rmon_cli.calculate_and_set_section_timeouts(
selected_sections=selected_sections, calc_out_of_section=calc_out_of_section
)
if is_rank0():
ft_state = rmon_cli.state_dict()
with open(_ft_state_path, "w") as f:
json.dump(ft_state, f)
print_rank_0(f"FT: updated timeouts saved to {_ft_state_path}. {rmon_cli.section_timeouts}")
def _maybe_update_timeouts(is_closing_ft=False):
rmon_cli = get_rank_monitor_client()
if rmon_cli is None:
return
if not _is_calculating_timeouts:
return
# Decide which section timeouts can be updated
sections_to_update = []
if _is_persistent_chkpt_loaded:
sections_to_update.append("setup")
else:
print_rank_0(
"FT: can't update the setup section timeout until persistent checkpoint is loaded"
)
if _seen_tr_iters_cnt >= _MIN_ITERS_FOR_STEP_TIMEOUT_UPDATE:
sections_to_update.append("step")
else:
print_rank_0("FT: need to see more training iterations to update the step section timeout")
if _seen_checkpoints_cnt > 0:
if not _is_async_chkpt_enabled:
sections_to_update.append("checkpointing")
else:
# There can be too much checkpointing section time variability
# across runs with the async checkpointing, e.g. in some runs all checkpointing
# work can be parallelized (=short checkpointing sections) while in others we can
# hit a costly finalization.
print_rank_0(
"FT: can't update the checkpointing section timeout with async checkpointing"
)
else:
print_rank_0("FT: checkpointing section is not updated until a checkpoint was saved")
update_out_of_section = False
if is_closing_ft:
# with async checkpointing, "checkpointing" section is not updated,
# but we still want to see some checkpointing to ensure that it was a complete run
if {'setup', 'step'}.issubset(sections_to_update) and _seen_checkpoints_cnt > 0:
update_out_of_section = True
else:
print_rank_0(
"FT: the out-of-section timeout won't be updated until all FT sections were seen"
)
else:
print_rank_0("FT: the out-of-section timeout won't be updated as the FT is not closing yet")
if sections_to_update or update_out_of_section:
_update_timeouts(
selected_sections=sections_to_update, calc_out_of_section=update_out_of_section
)
def maybe_setup_simulated_fault() -> None:
"""Sets a simulated fault, based on `FT_SIM_FAULT_DESC` env variable.
Simulated fault description format:
rank_hung|rank_killed;rank_to_fail|"";base_delay
NOTE: This is for FT testing only.
"""
simulated_fault_desc = os.environ.get('FT_SIM_FAULT_DESC', None)
if not simulated_fault_desc:
return
fault_type: Any # silence mypy
rank_to_fail: Any # silence mypy
base_delay: Any # silence mypy
fault_type, rank_to_fail, base_delay = simulated_fault_desc.split(';')
fault_type = fault_type.strip()
rank_to_fail = rank_to_fail.strip()
rank_to_fail = int(rank_to_fail) if rank_to_fail else None
base_delay = float(base_delay.strip())
rng = random.Random()
print_rank_0(
f"FT: Initializing simulated fault: {fault_type},"
+ f"rank to fail: {rank_to_fail}, base delay: {base_delay}"
)
# rank that simulates a fault can be explicitly specified in the `rank_to_fail` field
# if not specified, it just picks a random rank
rank = torch.distributed.get_rank()
rand_rank = rng.randint(0, torch.distributed.get_world_size() - 1)
rank_to_fail = rank_to_fail if rank_to_fail is not None else rand_rank
rank_to_fail = torch.tensor([rank_to_fail], device=torch.cuda.current_device())
torch.distributed.broadcast(rank_to_fail, 0)
rank_to_fail = int(rank_to_fail.item())
if rank != rank_to_fail:
# this rank is not going to simulate a fault, nothing more to do
return
if fault_type == 'random':
fault_type = rng.choice(['rank_killed', 'rank_hung'])
if fault_type == 'rank_killed':
target_pid = os.getpid()
elif fault_type == 'rank_hung':
target_pid = os.getpid()
else:
raise Exception(f"Unknown fault type {fault_type} expected one of: rank_killed, rank_hung.")
# add some randomness to the delay
delay = base_delay + 0.2 * rng.random() * base_delay
print_rank_0(f"FT: Selected fault={fault_type}; target rank={rank_to_fail}; delay={delay}")
def __fault_thread():
time.sleep(delay)
for of in [sys.stdout, sys.stderr]:
print(
f"\n####\nFT: Simulating fault: {fault_type}; rank to fail: {rank_to_fail}\n####\n",
file=of,
flush=True,
)
if fault_type == 'rank_hung':
os.kill(target_pid, signal.SIGSTOP)
else:
os.kill(target_pid, signal.SIGKILL)
def can_update_timeouts():
global _GLOBAL_STATE_MACHINE
return _GLOBAL_STATE_MACHINE.can_update_timeouts
fault_sim_thread = threading.Thread(target=__fault_thread)
fault_sim_thread.daemon = True
fault_sim_thread.start()
File mode changed from 100755 to 100644
......@@ -2,29 +2,34 @@
"""Megatron initialization."""
import logging
import random
import os
import random
import time
import warnings
from datetime import timedelta
import numpy as np
import torch
from datetime import timedelta
from megatron.legacy import fused_kernels
from megatron.training import get_adlr_autoresume
from megatron.training import get_args
from megatron.training import get_tensorboard_writer
from megatron.core import mpu, tensor_parallel
from megatron.core.rerun_state_machine import initialize_rerun_state_machine, RerunErrorInjector, RerunDiagnostic, RerunMode
from megatron.training.arguments import parse_args, validate_args
from megatron.training.yaml_arguments import validate_yaml
from megatron.training.checkpointing import load_args_from_checkpoint
from megatron.training.global_vars import set_global_variables
from megatron.core.fusions.fused_bias_dropout import bias_dropout_add_fused_train
from megatron.core.fusions.fused_bias_gelu import bias_gelu
from megatron.core.fusions.fused_bias_swiglu import bias_swiglu
from megatron.core.parallel_state import create_group
from megatron.core.rerun_state_machine import (
RerunDiagnostic,
RerunErrorInjector,
RerunMode,
initialize_rerun_state_machine,
)
from megatron.core.utils import get_te_version, is_te_min_version, is_torch_min_version
from megatron.legacy import fused_kernels
from megatron.training import get_adlr_autoresume, get_args, get_tensorboard_writer
from megatron.training.arguments import parse_args, validate_args
from megatron.training.async_utils import init_persistent_async_worker
from megatron.training.checkpointing import load_args_from_checkpoint
from megatron.training.global_vars import set_global_variables
from megatron.training.yaml_arguments import validate_yaml
logger = logging.getLogger(__name__)
......@@ -36,7 +41,7 @@ def initialize_megatron(
allow_no_cuda=False,
skip_mpu_initialization=False,
get_embedding_ranks=None,
get_position_embedding_ranks=None
get_position_embedding_ranks=None,
):
"""Set global variables, initialize distributed, and
set autoresume and random seeds.
......@@ -61,14 +66,21 @@ def initialize_megatron(
if args.use_checkpoint_args or args_defaults.get("use_checkpoint_args", False):
assert args.load is not None, "--use-checkpoint-args requires --load argument"
assert args.non_persistent_ckpt_type != "local", (
"--use-checkpoint-args is not supported with --non_persistent_ckpt_type=local. "
"Two-stage checkpoint loading is not implemented, and all arguments must be defined "
"before initializing LocalCheckpointManager."
)
load_args_from_checkpoint(args)
if args.async_save and args.use_persistent_ckpt_worker:
init_persistent_async_worker()
if args.yaml_cfg is not None:
args = validate_yaml(args, args_defaults)
else:
validate_args(args, args_defaults)
# set global args, build tokenizer, and set adlr-autoresume,
# tensorboard-writer, and timers.
set_global_variables(args)
......@@ -78,10 +90,8 @@ def initialize_megatron(
# init rerun state
def state_save_func():
return {
'rng_tracker_states': tensor_parallel.get_cuda_rng_tracker().get_states()
}
return {'rng_tracker_states': tensor_parallel.get_cuda_rng_tracker().get_states()}
def state_restore_func(state_dict):
if state_dict['rng_tracker_states']:
tensor_parallel.get_cuda_rng_tracker().set_states(state_dict['rng_tracker_states'])
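# These callbacks let the rerun state machine capture and restore the CUDA RNG
# tracker state when an iteration is re-executed.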
......@@ -95,6 +105,7 @@ def initialize_megatron(
error_injection_rate=args.error_injection_rate,
error_injection_type=RerunDiagnostic(args.error_injection_type),
),
result_rejected_tracker_filename=args.result_rejected_tracker_filename,
)
# torch.distributed initialization
......@@ -106,7 +117,12 @@ def initialize_megatron(
# Random seeds for reproducibility.
if args.rank == 0:
print("> setting random seeds to {} ...".format(args.seed))
_set_random_seed(args.seed, args.data_parallel_random_init)
_set_random_seed(
args.seed,
args.data_parallel_random_init,
args.te_rng_tracker,
args.inference_rng_tracker,
)
if skip_mpu_initialization:
return None
......@@ -133,8 +149,8 @@ def initialize_megatron(
_compile_dependencies()
if args.tp_comm_overlap:
#TODO: Should this be activated with just decoder-tp-comm-overlap too?
_initialize_tp_communicators()
# TODO: Should this be activated with just decoder-tp-comm-overlap too?
_initialize_tp_communicators()
# No continuation function
return None
......@@ -172,17 +188,10 @@ def _compile_dependencies():
# Constraints on sequence length and attn_batch_size to enable warp based
# optimization and upper triangular optimization (for causal mask)
custom_kernel_constraint = (
seq_len > 16
and seq_len <= 16384
and seq_len % 4 == 0
and attn_batch_size % 4 == 0
seq_len > 16 and seq_len <= 16384 and seq_len % 4 == 0 and attn_batch_size % 4 == 0
)
# Print a warning.
if not (
(args.fp16 or args.bf16)
and custom_kernel_constraint
and args.masked_softmax_fusion
):
if not ((args.fp16 or args.bf16) and custom_kernel_constraint and args.masked_softmax_fusion):
if args.rank == 0:
print(
"WARNING: constraints for invoking optimized"
......@@ -192,14 +201,14 @@ def _compile_dependencies():
)
# Always build on rank zero first.
if torch.distributed.get_rank() == 0:
start_time = time.time()
print("> compiling and loading fused kernels ...", flush=True)
#fused_kernels.load(args)
torch.distributed.barrier()
else:
torch.distributed.barrier()
#fused_kernels.load(args)
# if torch.distributed.get_rank() == 0:
# start_time = time.time()
# print("> compiling and loading fused kernels ...", flush=True)
# fused_kernels.load(args)
# torch.distributed.barrier()
# else:
# torch.distributed.barrier()
# fused_kernels.load(args)
# Simple barrier to make sure all ranks have passed the
# compilation phase successfully before moving on to the
# rest of the program. We think this might ensure that
......@@ -212,48 +221,65 @@ def _compile_dependencies():
flush=True,
)
def _initialize_tp_communicators():
""" initializing the communicators with user buffers for high-performance tensor-model-parallel
communication overlap """
"""initializing the communicators with user buffers for high-performance tensor-model-parallel
communication overlap"""
try:
import yaml
import transformer_engine
from transformer_engine.pytorch import module as te_module
import transformer_engine
import yaml
from transformer_engine.pytorch import module as te_module
except ImportError:
raise RuntimeError("Tensor Parallel Communication/GEMM Overlap optimization needs 'yaml' and "
"'transformer_engine' packages")
raise RuntimeError(
"Tensor Parallel Communication/GEMM Overlap optimization needs 'yaml' and "
"'transformer_engine' packages"
)
args = get_args()
if args.tp_comm_overlap_cfg is not None:
with open(args.tp_comm_overlap_cfg,"r") as stream:
ub_cfgs = yaml.safe_load(stream)
with open(args.tp_comm_overlap_cfg, "r") as stream:
ub_cfgs = yaml.safe_load(stream)
else:
ub_cfgs = {}
ub_cfgs = {}
if getattr(args, 'decoder_tp_comm_overlap', False):
input_shape = [(args.decoder_seq_length * args.micro_batch_size) // args.context_parallel_size , args.hidden_size]
input_shape = [
(args.decoder_seq_length * args.micro_batch_size) // args.context_parallel_size,
args.hidden_size,
]
else:
input_shape = [(args.seq_length * args.micro_batch_size) // args.context_parallel_size , args.hidden_size]
input_shape = [
(args.seq_length * args.micro_batch_size) // args.context_parallel_size,
args.hidden_size,
]
if is_te_min_version("1.9.0"):
# The process group with the target bootstrap backend is created in Transformer Engine.
te_module.base.initialize_ub(shape = input_shape, tp_size = args.tensor_model_parallel_size,
use_fp8 = (args.fp8 is not None) , ub_cfgs = ub_cfgs,
bootstrap_backend = args.tp_comm_bootstrap_backend)
te_module.base.initialize_ub(
shape=input_shape,
tp_size=args.tensor_model_parallel_size,
use_fp8=(args.fp8 is not None),
ub_cfgs=ub_cfgs,
bootstrap_backend=args.tp_comm_bootstrap_backend,
)
else:
if args.tp_comm_bootstrap_backend != 'mpi':
warnings.warn(
f"Transformer Engine v{get_te_version()} supports only MPI bootstrap backend."
)
# Create a MPI process group to help with TP communication overlap bootstrap.
torch.distributed.new_group(backend='mpi')
te_module.base.initialize_ub(shape = input_shape, tp_size = args.tensor_model_parallel_size,
use_fp8 = (args.fp8 is not None) , ub_cfgs = ub_cfgs)
create_group(backend='mpi', group_desc='TP_BOOTSTRAP_GROUP_MPI')
te_module.base.initialize_ub(
shape=input_shape,
tp_size=args.tensor_model_parallel_size,
use_fp8=(args.fp8 is not None),
ub_cfgs=ub_cfgs,
)
def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
"""Initialize torch.distributed and core model parallel."""
......@@ -264,14 +290,14 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
if args.rank == 0:
print(
"torch distributed is already initialized, "
"skipping initialization ...",
"torch distributed is already initialized, " "skipping initialization ...",
flush=True,
)
args.rank = torch.distributed.get_rank()
args.world_size = torch.distributed.get_world_size()
else:
if args.rank == 0:
print("> initializing torch distributed ...", flush=True)
# Manually set the device ids.
......@@ -283,7 +309,7 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
# Call the init process
init_process_group_kwargs = {
'backend' : args.distributed_backend,
'backend': args.distributed_backend,
'world_size': args.world_size,
'rank': args.rank,
'init_method': args.dist_url,
......@@ -303,6 +329,7 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
args.pipeline_model_parallel_size,
args.virtual_pipeline_model_parallel_size,
args.pipeline_model_parallel_split_rank,
pipeline_model_parallel_comm_backend=args.pipeline_model_parallel_comm_backend,
context_parallel_size=args.context_parallel_size,
hierarchical_context_parallel_sizes=args.hierarchical_context_parallel_sizes,
expert_model_parallel_size=args.expert_model_parallel_size,
......@@ -310,11 +337,12 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
expert_tensor_parallel_size=args.expert_tensor_parallel_size,
distributed_timeout_minutes=args.distributed_timeout_minutes,
nccl_communicator_config_path=args.nccl_communicator_config_path,
order='tp-cp-ep-dp-pp' if not args.use_tp_pp_dp_mapping else 'tp-pp-dp',
order='tp-cp-ep-dp-pp' if not args.use_tp_pp_dp_mapping else 'tp-cp-ep-pp-dp',
encoder_tensor_model_parallel_size=args.encoder_tensor_model_parallel_size,
encoder_pipeline_model_parallel_size=args.encoder_pipeline_model_parallel_size,
get_embedding_ranks=get_embedding_ranks,
get_position_embedding_ranks=get_position_embedding_ranks,
create_gloo_process_groups=args.enable_gloo_process_groups,
)
if args.rank == 0:
print(
......@@ -336,7 +364,9 @@ def _init_autoresume():
torch.distributed.barrier()
def _set_random_seed(seed_, data_parallel_random_init=False):
def _set_random_seed(
seed_, data_parallel_random_init=False, te_rng_tracker=False, inference_rng_tracker=False
):
"""Set random seed for reproducability."""
if seed_ is not None and seed_ > 0:
# Ensure that different pipeline MP stages get different seeds.
......@@ -348,7 +378,9 @@ def _set_random_seed(seed_, data_parallel_random_init=False):
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.device_count() > 0:
tensor_parallel.model_parallel_cuda_manual_seed(seed)
tensor_parallel.model_parallel_cuda_manual_seed(
seed, te_rng_tracker, inference_rng_tracker
)
else:
raise ValueError("Seed ({}) should be a positive integer.".format(seed_))
......@@ -374,7 +406,7 @@ def set_jit_fusion_options():
torch._C._jit_override_can_fuse_on_cpu(False)
torch._C._jit_override_can_fuse_on_gpu(False)
torch._C._jit_set_texpr_fuser_enabled(False)
torch._C._jit_set_nvfuser_enabled(False)#(True)
torch._C._jit_set_nvfuser_enabled(True)
torch._C._debug_set_autodiff_subgraph_inlining(False)
else:
# legacy pytorch fuser
......@@ -398,9 +430,7 @@ def _warmup_jit_function():
# Warmup fused bias+gelu
bias = torch.rand(
args.ffn_hidden_size // args.tensor_model_parallel_size,
dtype=dtype,
device="cuda",
args.ffn_hidden_size // args.tensor_model_parallel_size, dtype=dtype, device="cuda"
)
input = torch.rand(
(
......@@ -437,15 +467,11 @@ def _warmup_jit_function():
dtype=dtype,
device="cuda",
)
bias = torch.rand((args.hidden_size), dtype=dtype, device="cuda").expand_as(
residual
)
bias = torch.rand((args.hidden_size), dtype=dtype, device="cuda").expand_as(residual)
dropout_rate = 0.1
# Warmup JIT fusions with the input grad_enable state of both forward
# prop and recomputation
for input_grad, bias_grad, residual_grad in zip(
[False, True], [True, True], [True, True]
):
for input_grad, bias_grad, residual_grad in zip([False, True], [True, True], [True, True]):
input.requires_grad = input_grad
bias.requires_grad = bias_grad
residual.requires_grad = residual_grad
......@@ -456,7 +482,7 @@ def _warmup_jit_function():
def setup_logging() -> None:
""" Sets the default logging level based on cmdline args and env vars.
"""Sets the default logging level based on cmdline args and env vars.
Precedence:
1. Command line argument `--logging-level`
......
File mode changed from 100755 to 100644
......@@ -3,6 +3,8 @@ import time, os
from .global_vars import get_one_logger, get_args
_one_logger_utils_version = "1.0.0-mlm"
def get_timestamp_in_ms():
"""Helper function to get timestamp in ms
......@@ -86,7 +88,7 @@ def _produce_e2e_metrics(log_throughput=False, throughput=None):
# Unpack and assign local vars
base_metrics = one_logger.store_get('get_e2e_base_metrics')()
(iteration, train_duration, eval_duration, eval_iterations,
total_flops, num_floating_point_operations_so_far,
total_flops_since_current_train_start, num_floating_point_operations_so_far,
consumed_train_samples, world_size, seq_length) = base_metrics.values()
iteration_start = one_logger.store_get('iteration_start')
......@@ -125,7 +127,7 @@ def _produce_e2e_metrics(log_throughput=False, throughput=None):
if log_throughput:
if train_duration:
train_throughput_per_gpu = total_flops / (train_duration * 10**12 * world_size)
train_throughput_per_gpu = total_flops_since_current_train_start / (train_duration * 10**12 * world_size)
else:
train_throughput_per_gpu = 0.0
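# e.g. 1.2e18 FLOPs over 600 s on 64 GPUs:
# 1.2e18 / (600 * 1e12 * 64) = 31.25 TFLOP/s per GPU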
......@@ -136,7 +138,7 @@ def _produce_e2e_metrics(log_throughput=False, throughput=None):
throughput_metrics = {
'train_tflop_end': float(num_floating_point_operations_so_far) / (10**12),
'train_tflop': float(total_flops) / (10**12),
'train_tflop': float(total_flops_since_current_train_start) / (10**12),
'train_throughput_per_gpu': train_throughput_per_gpu,
'train_throughput_per_gpu_max': train_throughput_per_gpu_max,
}
......@@ -234,7 +236,7 @@ def on_save_checkpoint_start(async_save):
# Unpack and assign local vars
base_metrics = one_logger.store_get('get_e2e_base_metrics')()
(iteration, train_duration, eval_duration, eval_iterations,
total_flops, num_floating_point_operations_so_far,
total_flops_since_current_train_start, num_floating_point_operations_so_far,
consumed_train_samples, world_size, seq_length) = base_metrics.values()
save_checkpoint_count = one_logger.store_get('save_checkpoint_count') + 1
......@@ -289,6 +291,7 @@ def on_pretrain_start():
'app_run_type': 'training',
'summary_data_schema_version': '1.0.0',
'app_metrics_feature_tags': 'full',
'one_logger_utils_version': _one_logger_utils_version,
})
def track_config_flags(train_iters, skip_train, do_train, do_valid, do_test,
......
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
......@@ -39,6 +39,11 @@ nvlm_yi_34b_template = "{{- bos_token }}{% for message in messages %}{{'<|im_sta
qwen2p0_custom_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
# Note: this is the same template as https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct/blob/main/tokenizer_config.json#L2053
# but we removed the forced system message.
llama3p1_chat_template = """{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- set date_string = \"26 Jul 2024\" %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = none %}\n{%- endif %}\n\n{%- if system_message is not none %}{#- System message + builtin tools #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if builtin_tools is defined or tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{%- if builtin_tools is defined %}\n {{- \"Tools: \" + builtin_tools | reject('equalto', 'code_interpreter') | join(\", \") + \"\\n\\n\"}}\n{%- endif %}{%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{%-endif %}{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {%- if builtin_tools is defined and tool_call.name in builtin_tools %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- \"<|python_tag|>\" + tool_call.name + \".call(\" }}\n {%- for arg_name, arg_val in tool_call.arguments | items %}\n {{- arg_name + '=\"' + arg_val + '\"' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \")\" }}\n {%- else %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {%- endif %}\n {%- if builtin_tools is defined %}\n {#- This means we're in ipython mode #}\n {{- \"<|eom_id|>\" }}\n {%- else %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n"""
@dataclass
class PromptConfig:
......@@ -104,6 +109,16 @@ class MultimodalTokenizer(MegatronTokenizer):
has_bos=True,
has_system_role=True,
)
elif prompt_format in ("llama3p1", "llama3p2"):
# "<|start_header_id|>assistant<|end_header|>\n\n" is the prefix for assistant messages.
# That occupies 4 tokens and can be masked in the target.
self._prompt_config = PromptConfig(
assistant_prefix_len=4,
pad_token_id=tokenizer.convert_tokens_to_ids("<|finetune_right_pad_id|>"),
custom_chat_template=llama3p1_chat_template,
has_bos=True,
has_system_role=True,
)
elif prompt_format == "nvlm-yi-34b":
self._prompt_config = PromptConfig(
assistant_prefix_len=4,
......@@ -121,7 +136,7 @@ class MultimodalTokenizer(MegatronTokenizer):
has_bos=False,
has_system_role=True,
)
elif prompt_format == "qwen2p0":
elif prompt_format in ("qwen2p0", "qwen2p5"):
# "<|im_start|>assistant\n" is the prefix for assistant messages
self._prompt_config = PromptConfig(
assistant_prefix_len=3,
......
......@@ -15,7 +15,6 @@ from megatron.core.datasets.megatron_tokenizer import MegatronTokenizer
from .bert_tokenization import FullTokenizer as FullBertTokenizer
from .gpt2_tokenization import GPT2Tokenizer
from megatron.training.tokenizer.multimodal_tokenizer import MultimodalTokenizer
from transformers import Qwen2Tokenizer
def build_tokenizer(args, **kwargs):
......@@ -51,11 +50,6 @@ def build_tokenizer(args, **kwargs):
elif args.tokenizer_type == 'Llama2Tokenizer':
assert args.tokenizer_model is not None
tokenizer = _Llama2Tokenizer(args.tokenizer_model)
elif args.tokenizer_type == 'Llama3Tokenizer':
assert args.tokenizer_model is not None
tokenizer = _Llama3Tokenizer(args.tokenizer_model)
elif args.tokenizer_type == 'QwenTokenizer':
tokenizer = _Qwen2Tokenizer(args.vocab_file, args.merge_file)
elif args.tokenizer_type == 'TikTokenizer':
assert args.tokenizer_model is not None
assert args.tiktoken_pattern is not None
......@@ -612,96 +606,6 @@ class _Llama2Tokenizer(_SentencePieceTokenizer):
return None
class _Llama3Tokenizer(MegatronTokenizer):
"""tiktokenTokenizer-Megatron llama3 改写"""
# https://github.com/meta-llama/llama3/blob/main/llama/tokenizer.py
def __init__(self, model_file):
super().__init__(model_file)
from pathlib import Path
import tiktoken
from tiktoken.load import load_tiktoken_bpe
tokenizer_path=model_file
special_tokens = [
"<|begin_of_text|>",
"<|end_of_text|>",
"<|reserved_special_token_0|>",
"<|reserved_special_token_1|>",
"<|reserved_special_token_2|>",
"<|reserved_special_token_3|>",
"<|start_header_id|>",
"<|end_header_id|>",
"<|reserved_special_token_4|>",
"<|eot_id|>", # end of turn
] + [f"<|reserved_special_token_{i}|>" for i in range (5, 256 - 5)]
mergeable_ranks = load_tiktoken_bpe(tokenizer_path)
self.tokenizer = tiktoken.Encoding(tokenizer_path,
pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+",
mergeable_ranks=mergeable_ranks,
special_tokens={token: len (mergeable_ranks) + i for i, token in enumerate (special_tokens)},
)
self.eod_id = self.tokenizer.encode("<|end_of_text|>", allowed_special="all")[0]
@property
def vocab_size(self):
return self.tokenizer.n_vocab
@property
def vocab(self):
return self.tokenizer.encode
@property
def inv_vocab(self):
return self.tokenizer.encode
def tokenize(self, text):
return self.tokenizer.encode(text)
def detokenize(self, token_ids):
return self.tokenizer.decode(token_ids)
@property
def eod(self):
return self.eod_id
class _Qwen2Tokenizer(MegatronTokenizer):
def __init__(self, vocab_file, merge_file,extra_vocab_size=0):
super().__init__(vocab_file, merge_file)
self.tokenizer = Qwen2Tokenizer(vocab_file, merge_file)
self.extra_vocab_size = extra_vocab_size
self.tokenizer.add_special_tokens(special_tokens_dict=dict(pad_token="<|extra_0|>"))
@property
def vocab_size(self):
return len(self.tokenizer.encoder) + self.extra_vocab_size
@property
def vocab(self):
return self.tokenizer.encoder
@property
def inv_vocab(self):
return self.tokenizer.decoder
def tokenize(self, text):
return self.tokenizer.encode(text)
def detokenize(self, token_ids):
return self.tokenizer.decode(token_ids)
@property
def eod(self):
return self.tokenizer.eos_token_id
@property
def eos_token(self):
return self.tokenizer.eos_token
@property
def pad_token_id(self):
return self.tokenizer.pad_token_id
def reload_mergeable_ranks(path: str, max_vocab: Optional[int] = None) -> Dict[bytes, int]:
"""
Reload our tokenizer JSON file and convert it to Tiktoken format.
......@@ -851,7 +755,21 @@ class CustomTikTokenizer(MegatronTokenizer):
return self._model.decode(tokens)
def offsets(self, ids: list[int], text: str) -> list[int]:
return self._model.decode_with_offsets(ids)[1]
try:
return self._model.decode_with_offsets(ids)[1]
except UnicodeDecodeError:
# Tiktoken has an unnecessary check that raises UnicodeDecodeError
# from `text = b"".join(token_bytes).decode("utf-8", errors="strict")`
# which is not needed for our use case. So we re-implement it, without
# the check.
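# Bytes in the range 0x80-0xBF are UTF-8 continuation bytes, so counting only the
# other bytes gives character offsets; a token that starts with a continuation byte
# is shifted back to the start of its character.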
token_bytes = self._model.decode_tokens_bytes(ids)
text_len = 0
offsets = []
for token in token_bytes:
offsets.append(max(0, text_len - (0x80 <= token[0] < 0xC0)))
text_len += sum(1 for c in token if not 0x80 <= c < 0xC0)
return offsets
@property
def vocab_size(self) -> int:
......
......@@ -10,6 +10,9 @@ import logging
import math
import os
import sys
from typing import List
import torch.distributed
from .log_handler import CustomHandler
# Make default logging level INFO, but filter out all log messages not from MCore.
logging.basicConfig(handlers=[CustomHandler()], level=logging.INFO)
......@@ -24,14 +27,15 @@ from megatron.core.utils import (
check_param_hashes_across_dp_replicas,
get_model_config,
StragglerDetector,
is_float8tensor,
)
from megatron.core.fp8_utils import is_float8tensor
from megatron.training.checkpointing import load_checkpoint
from megatron.training.checkpointing import save_checkpoint
from megatron.training.checkpointing import checkpoint_exists
from megatron.legacy.model import Float16Module
from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.distributed import DistributedDataParallel as DDP
from megatron.core.distributed.custom_fsdp import FullyShardedDataParallel as custom_FSDP
try:
from megatron.core.distributed import TorchFullyShardedDataParallel as torch_FSDP
......@@ -51,6 +55,10 @@ from megatron.core.rerun_state_machine import (
from megatron.training.initialize import initialize_megatron
from megatron.training.initialize import write_args_to_tensorboard
from megatron.training.initialize import set_jit_fusion_options
from megatron.training.utils import (
get_batch_on_this_cp_rank,
get_batch_on_this_tp_rank,
)
from megatron.legacy.data.data_samplers import build_pretraining_data_loader
from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler
from megatron.core.transformer.moe import upcycling_utils
......@@ -69,14 +77,16 @@ from megatron.core.num_microbatches_calculator import (
from .async_utils import maybe_finalize_async_save
from .utils import (
append_to_progress_log,
calc_params_l2_norm,
check_adlr_autoresume_termination,
logical_and_across_model_parallel_group,
reduce_max_stat_across_model_parallel_group,
is_last_rank,
print_rank_0,
print_rank_last,
report_memory,
unwrap_model,
append_to_progress_log,
update_use_dist_ckpt,
)
from .global_vars import (
......@@ -86,7 +96,8 @@ from .global_vars import (
get_timers,
get_tensorboard_writer,
get_wandb_writer,
get_one_logger)
get_one_logger,
)
from . import one_logger_utils
from . import ft_integration
......@@ -124,6 +135,10 @@ def num_floating_point_operations(args, batch_size):
if args.moe_shared_expert_intermediate_size is None
else args.moe_shared_expert_intermediate_size
)
if args.num_experts is None:
ffn_hidden_size = args.ffn_hidden_size
else:
ffn_hidden_size = args.moe_ffn_hidden_size
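# Dense models use the regular FFN width; MoE models count the MLP term with the
# per-expert FFN width, with the routed- and shared-expert factors applied below.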
# The 12x term below comes from the following factors; for more details, see
# "APPENDIX: FLOATING-POINT OPERATIONS" in https://arxiv.org/abs/2104.04473.
......@@ -135,13 +150,6 @@ def num_floating_point_operations(args, batch_size):
# - 2x: A GEMM of a m*n tensor with a n*k tensor requires 2mnk floating-point operations.
expansion_factor = 3 * 2 * 2
# print(f"batch_size: {batch_size}, \
# query_projection_to_hidden_size_ratio: {query_projection_to_hidden_size_ratio}, \
# num_experts_routed_to: {num_experts_routed_to}, \
# gated_linear_multiplier: {gated_linear_multiplier}, \
# shared_expert_ffn_hidden_size: {shared_expert_ffn_hidden_size}, \
# gated_linear_multiplier: {gated_linear_multiplier}, \
# ")
return (
expansion_factor
* batch_size
......@@ -160,7 +168,7 @@ def num_floating_point_operations(args, batch_size):
)
# MLP.
+ (
(args.ffn_hidden_size / args.hidden_size)
(ffn_hidden_size / args.hidden_size)
* num_experts_routed_to
* gated_linear_multiplier
)
......@@ -219,7 +227,7 @@ def get_start_time_from_progress_log():
def preprocess_common_state_dict(common_state_dict):
import copy
# Convert args key of type namespace to dictionary
# Convert args key of type namespace to dictionary
preprocessed_common_state_dict = copy.deepcopy(common_state_dict)
preprocessed_common_state_dict['args'] = vars(preprocessed_common_state_dict['args'])
# Remove rank and local rank from state dict if it exists, since they are expected to be different
......@@ -287,6 +295,12 @@ def pretrain(
if args.log_progress:
append_to_progress_log("Starting job")
# Initialize fault tolerance
# NOTE: ft_integration functions other than `setup` are no-op if the FT is not initialized
if args.enable_ft_package:
ft_integration.setup(args)
ft_integration.maybe_setup_simulated_fault()
# Set pytorch JIT layer fusion options and warmup JIT functions.
set_jit_fusion_options()
......@@ -315,11 +329,29 @@ def pretrain(
# Context used for persisting some state between checkpoint saves.
if args.non_persistent_ckpt_type == 'local':
raise RuntimeError('LocalCheckpointManagers are not yet integrated')
checkpointing_context = {
'local_checkpoint_manager': BasicLocalCheckpointManager(
args.non_persistent_local_ckpt_dir
try:
from nvidia_resiliency_ext.checkpointing.local.ckpt_managers.local_manager import \
LocalCheckpointManager
from nvidia_resiliency_ext.checkpointing.local.replication.group_utils import \
parse_group_sequence, GroupWrapper
from nvidia_resiliency_ext.checkpointing.local.replication.strategies import \
CliqueReplicationStrategy
except ModuleNotFoundError:
raise RuntimeError("The 'nvidia_resiliency_ext' module is required for local "
"checkpointing but was not found. Please ensure it is installed.")
if args.replication:
repl_strategy = CliqueReplicationStrategy.from_replication_params(
args.replication_jump,
args.replication_factor
)
else:
repl_strategy = None
checkpointing_context = {
'local_checkpoint_manager': LocalCheckpointManager(args.non_persistent_local_ckpt_dir,
repl_strategy=repl_strategy
)
}
else:
checkpointing_context = {}
......@@ -364,11 +396,6 @@ def pretrain(
args.do_valid, args.do_test, args.dataloader_type,
args.retro_project_dir, args.retro_cyclic_train_iters)
if args.enable_ft_package and ft_integration.get_rank_monitor_client() is not None:
ft_integration.get_rank_monitor_client().init_workload_monitoring()
ft_timeouts = ft_integration.get_rank_monitor_client().timeouts
print_rank_0(f"Fault tolerance client initialized. Timeouts: {ft_timeouts}")
# Print setup timing.
print_rank_0('done with setup ...')
timers.log(['model-and-optimizer-setup',
......@@ -400,8 +427,7 @@ def pretrain(
save_checkpoint(iteration, model, optimizer, opt_param_scheduler,
num_floating_point_operations_so_far, checkpointing_context,
train_data_iterator=train_data_iterator,
ft_client=ft_integration.get_rank_monitor_client(
ft_integration.StateMachineActions.SAVE_CHECKPOINT), preprocess_common_state_dict_fn=preprocess_common_state_dict)
preprocess_common_state_dict_fn=preprocess_common_state_dict)
one_logger and one_logger.log_metrics({
'app_train_loop_finish_time': one_logger_utils.get_timestamp_in_ms()
......@@ -431,11 +457,16 @@ def pretrain(
wandb_writer = get_wandb_writer()
if wandb_writer:
wandb_writer.finish()
maybe_finalize_async_save(blocking=True)
ft_integration.on_checkpointing_start()
maybe_finalize_async_save(blocking=True, terminate=True)
ft_integration.on_checkpointing_end(is_async_finalization=True)
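# The final async-save finalization is wrapped in the FT "checkpointing" section,
# matching how other checkpoint-saving operations are timed.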
one_logger and one_logger.log_metrics({
'app_finish_time': one_logger_utils.get_timestamp_in_ms()
})
ft_integration.shutdown()
one_logger_utils.finish()
......@@ -476,47 +507,54 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
args.model_type = model_type
# Build model.
if mpu.get_pipeline_model_parallel_world_size() > 1 and \
args.virtual_pipeline_model_parallel_size is not None:
assert model_type != ModelType.encoder_and_decoder, \
"Interleaved schedule not supported for model with both encoder and decoder"
model = []
for i in range(args.virtual_pipeline_model_parallel_size):
mpu.set_virtual_pipeline_model_parallel_rank(i)
# Set pre_process and post_process only after virtual rank is set.
def build_model():
if mpu.get_pipeline_model_parallel_world_size() > 1 and \
args.virtual_pipeline_model_parallel_size is not None:
assert model_type != ModelType.encoder_and_decoder, \
"Interleaved schedule not supported for model with both encoder and decoder"
model = []
for i in range(args.virtual_pipeline_model_parallel_size):
mpu.set_virtual_pipeline_model_parallel_rank(i)
# Set pre_process and post_process only after virtual rank is set.
pre_process = mpu.is_pipeline_first_stage()
post_process = mpu.is_pipeline_last_stage()
this_model = model_provider_func(
pre_process=pre_process,
post_process=post_process
)
this_model.model_type = model_type
model.append(this_model)
else:
pre_process = mpu.is_pipeline_first_stage()
post_process = mpu.is_pipeline_last_stage()
this_model = model_provider_func(
pre_process=pre_process,
post_process=post_process
)
this_model.model_type = model_type
model.append(this_model)
add_encoder = True
add_decoder = True
if model_type == ModelType.encoder_and_decoder:
if mpu.get_pipeline_model_parallel_world_size() > 1:
rank = mpu.get_pipeline_model_parallel_rank()
first_decoder_rank = args.encoder_pipeline_model_parallel_size
world_size = mpu.get_pipeline_model_parallel_world_size()
pre_process = rank == 0 or rank == first_decoder_rank
post_process = (rank == (first_decoder_rank - 1)) or (rank == (world_size - 1))
add_encoder = mpu.is_inside_encoder(rank)
add_decoder = mpu.is_inside_decoder(rank)
model = model_provider_func(
pre_process=pre_process,
post_process=post_process,
add_encoder=add_encoder,
add_decoder=add_decoder)
else:
model = model_provider_func(
pre_process=pre_process,
post_process=post_process
)
model.model_type = model_type
return model
if args.init_model_with_meta_device:
with torch.device('meta'):
model = build_model()
else:
pre_process = mpu.is_pipeline_first_stage()
post_process = mpu.is_pipeline_last_stage()
add_encoder = True
add_decoder = True
if model_type == ModelType.encoder_and_decoder:
if mpu.get_pipeline_model_parallel_world_size() > 1:
rank = mpu.get_pipeline_model_parallel_rank()
first_decoder_rank = args.encoder_pipeline_model_parallel_size
world_size = mpu.get_pipeline_model_parallel_world_size()
pre_process = rank == 0 or rank == first_decoder_rank
post_process = (rank == (first_decoder_rank - 1)) or (rank == (world_size - 1))
add_encoder = mpu.is_inside_encoder(rank)
add_decoder = mpu.is_inside_decoder(rank)
model = model_provider_func(
pre_process=pre_process,
post_process=post_process,
add_encoder=add_encoder,
add_decoder=add_decoder)
else:
model = model_provider_func(
pre_process=pre_process,
post_process=post_process
)
model.model_type = model_type
model = build_model()
if not isinstance(model, list):
model = [model]
......@@ -530,17 +568,23 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param)
# Print number of parameters.
num_parameters = sum(
[sum([p.nelement() for p in model_module.parameters()])
for model_module in model]
)
if mpu.get_data_parallel_rank() == 0:
print(' > number of parameters on (tensor, pipeline) '
'model parallel rank ({}, {}): {}'.format(
mpu.get_tensor_model_parallel_rank(),
mpu.get_pipeline_model_parallel_rank(),
sum([sum([p.nelement() for p in model_module.parameters()])
for model_module in model])), flush=True)
num_parameters), flush=True)
# GPU allocation.
for model_module in model:
model_module.cuda(torch.cuda.current_device())
# For FSDP2, we don't allocate GPU memory here. We allocate GPU memory
# in the fully_shard function of FSDP2 instead.
if not (args.use_torch_fsdp2 and args.use_cpu_initialization) and not args.init_model_with_meta_device:
for model_module in model:
model_module.cuda(torch.cuda.current_device())
# Fp16 conversion.
if args.fp16 or args.bf16:
......@@ -562,9 +606,11 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
fp8_meta.amax_history[0][fp8_meta_index] = 0
if wrap_with_ddp:
if getattr(args, "use_torch_fsdp2", False):
if args.use_torch_fsdp2:
assert HAVE_FSDP2, "Torch FSDP2 requires torch>=2.4.0"
DP = torch_FSDP
elif args.use_custom_fsdp:
DP = custom_FSDP
else:
DP = DDP
......@@ -576,17 +622,42 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
kwargs[f.name] = getattr(args, f.name)
kwargs['grad_reduce_in_fp32'] = args.accumulate_allreduce_grads_in_fp32
kwargs['check_for_nan_in_grad'] = args.check_for_nan_in_loss_and_grad
kwargs['bucket_size'] = args.ddp_bucket_size
kwargs['check_for_large_grads'] = args.check_for_large_grads
if args.ddp_num_buckets is not None:
assert args.ddp_bucket_size is None, \
"Cannot specify both --ddp-num-buckets and --ddp-bucket-size"
assert args.ddp_num_buckets > 0, \
"--ddp-num-buckets must be greater than 0"
kwargs['bucket_size'] = num_parameters // args.ddp_num_buckets
else:
kwargs['bucket_size'] = args.ddp_bucket_size
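        # Illustrative sizing (hypothetical numbers, not from the source): a model chunk with
        # 1_000_000_000 parameters on this rank and --ddp-num-buckets 8 would yield a bucket
        # size of 125_000_000 parameters per bucket; otherwise the explicit --ddp-bucket-size
        # value (or the default chosen below when it is None) is used.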
kwargs['pad_buckets_for_high_nccl_busbw'] = args.ddp_pad_buckets_for_high_nccl_busbw
kwargs['average_in_collective'] = args.ddp_average_in_collective
if args.use_custom_fsdp and args.use_precision_aware_optimizer:
kwargs["preserve_fp32_weights"] = False
ddp_config = DistributedDataParallelConfig(**kwargs)
overlap_param_gather_with_optimizer_step = getattr(args, 'overlap_param_gather_with_optimizer_step', False)
if not getattr(args, "use_torch_fsdp2", False):
        # In the custom FSDP and DDP code path, we need to initialize the bucket size.
# If bucket_size is not provided as an input, use sane default.
# If using very large dp_sizes, make buckets larger to ensure that chunks used in NCCL
# ring-reduce implementations are large enough to remain bandwidth-bound rather than
# latency-bound.
if ddp_config.bucket_size is None:
ddp_config.bucket_size = max(
40000000, 1000000 * mpu.get_data_parallel_world_size(with_context_parallel=True)
)
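            # Illustrative default (hypothetical numbers): with a data-parallel size of 512
            # (including context parallelism), this becomes max(40_000_000, 1_000_000 * 512)
            # = 512_000_000 elements per bucket, keeping each NCCL ring-reduce chunk large
            # enough to stay bandwidth-bound.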
# Set bucket_size to infinity if overlap_grad_reduce is False.
if not ddp_config.overlap_grad_reduce:
ddp_config.bucket_size = None
model = [DP(config=config,
ddp_config=ddp_config,
module=model_chunk,
# Turn off bucketing for model_chunk 2 onwards, since communication for these
# model chunks is overlapped with compute anyway.
disable_bucketing=(model_chunk_idx > 0) or overlap_param_gather_with_optimizer_step)
disable_bucketing=(model_chunk_idx > 0) or args.overlap_param_gather_with_optimizer_step)
for (model_chunk_idx, model_chunk) in enumerate(model)]
# Broadcast params from data parallel src rank to other data parallel ranks.
......@@ -674,7 +745,8 @@ def setup_model_and_optimizer(model_provider_func,
config = OptimizerConfig(**kwargs)
config.timers = timers
optimizer = get_megatron_optimizer(config, model, no_wd_decay_cond,
scale_lr_cond, lr_mult)
scale_lr_cond, lr_mult,
use_gloo_process_groups=args.enable_gloo_process_groups)
opt_param_scheduler = get_optimizer_param_scheduler(optimizer)
if args.moe_use_upcycling:
......@@ -713,9 +785,8 @@ def setup_model_and_optimizer(model_provider_func,
timers('load-checkpoint', log_level=0).start(barrier=True)
args.iteration, args.num_floating_point_operations_so_far = load_checkpoint(
model, optimizer, opt_param_scheduler,
ft_client=ft_integration.get_rank_monitor_client(), checkpointing_context=checkpointing_context,
skip_load_to_model_and_opt=HAVE_FSDP2 and getattr(args, "use_torch_fsdp2", False))
model, optimizer, opt_param_scheduler, checkpointing_context=checkpointing_context,
skip_load_to_model_and_opt=HAVE_FSDP2 and args.use_torch_fsdp2)
timers('load-checkpoint').stop(barrier=True)
timers.log(['load-checkpoint'])
one_logger and one_logger.log_metrics({
......@@ -752,8 +823,17 @@ def setup_model_and_optimizer(model_provider_func,
return model, optimizer, opt_param_scheduler
def dummy_train_step(data_iterator):
"""Single dummy training step."""
num_microbatches = get_num_microbatches()
for _ in range(num_microbatches):
# Re-use methods used in get_batch() from pretrain_{gpt, mamba}.py.
batch = get_batch_on_this_tp_rank(data_iterator)
batch = get_batch_on_this_cp_rank(batch)
def train_step(forward_step_func, data_iterator,
model, optimizer, opt_param_scheduler, config):
model, optimizer, opt_param_scheduler, config):
"""Single training step."""
args = get_args()
timers = get_timers()
......@@ -785,17 +865,27 @@ def train_step(forward_step_func, data_iterator,
torch.cuda.empty_cache()
# Vision gradients.
if getattr(args, 'vision_pretraining', False) and args.vision_pretraining_type == "dino":
if args.vision_pretraining and args.vision_pretraining_type == "dino":
unwrapped_model = unwrap_model(model[0])
unwrapped_model.cancel_gradients_last_layer(args.curr_iteration)
# Update parameters.
timers('optimizer', log_level=1).start(barrier=args.barrier_with_L1_time)
update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
timers('optimizer').stop()
    # When freezing sub-models we may have a mixture of successful and unsuccessful ranks,
    # so we must gather across mp ranks.
update_successful = logical_and_across_model_parallel_group(update_successful)
# grad_norm and num_zeros_in_grad will be None on ranks without trainable params,
# so we must gather across mp ranks
grad_norm = reduce_max_stat_across_model_parallel_group(grad_norm)
if args.log_num_zeros_in_grad:
num_zeros_in_grad = reduce_max_stat_across_model_parallel_group(num_zeros_in_grad)
# Vision momentum.
if getattr(args, 'vision_pretraining', False) and args.vision_pretraining_type == "dino":
if args.vision_pretraining and args.vision_pretraining_type == "dino":
unwrapped_model = unwrap_model(model[0])
unwrapped_model.update_momentum(args.curr_iteration)
......@@ -832,7 +922,6 @@ def train_step(forward_step_func, data_iterator,
numerator += val
denominator += 1
loss_reduced[key] = numerator / denominator
return loss_reduced, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad
return {}, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad
......@@ -913,6 +1002,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r
total_iterations = total_loss_dict[advanced_iters_key] + \
total_loss_dict[skipped_iters_key]
# learning rate will be None on ranks without trainable params, so we must gather across mp ranks
learning_rate = reduce_max_stat_across_model_parallel_group(learning_rate)
# Tensorboard values.
# Timer requires all the ranks to call.
if args.log_timers_to_tensorboard and \
......@@ -920,22 +1011,16 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r
timers.write(timers_to_log, writer, iteration,
normalizer=total_iterations)
if writer and (iteration % args.tensorboard_log_interval == 0):
if args.record_memory_history and is_last_rank():
snapshot = torch.cuda.memory._snapshot()
from pickle import dump
with open(args.memory_snapshot_path , 'wb') as f:
dump(snapshot, f)
if wandb_writer:
wandb_writer.log({'samples vs steps': args.consumed_train_samples},
iteration)
writer.add_scalar('learning-rate', learning_rate, iteration)
if args.decoupled_lr is not None:
writer.add_scalar('decoupled-learning-rate', decoupled_learning_rate, iteration)
writer.add_scalar('learning-rate vs samples', learning_rate,
args.consumed_train_samples)
args.consumed_train_samples)
if wandb_writer:
wandb_writer.log({'learning-rate': learning_rate}, iteration)
if args.decoupled_lr is not None:
writer.add_scalar('decoupled-learning-rate', decoupled_learning_rate, iteration)
if args.skipped_train_samples > 0:
writer.add_scalar('skipped-train-samples', args.skipped_train_samples, iteration)
if wandb_writer:
......@@ -993,6 +1078,11 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r
mem_stats["allocated_bytes.all.current"],
iteration,
)
writer.add_scalar(
"mem-max-allocated-bytes",
mem_stats["allocated_bytes.all.peak"],
iteration,
)
writer.add_scalar(
"mem-allocated-count",
mem_stats["allocation.all.current"],
......@@ -1003,6 +1093,12 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r
track_moe_metrics(moe_loss_scale, iteration, writer, wandb_writer, total_loss_dict, args.moe_per_layer_logging)
if iteration % args.log_interval == 0:
if args.record_memory_history and is_last_rank():
snapshot = torch.cuda.memory._snapshot()
from pickle import dump
with open(args.memory_snapshot_path, 'wb') as f:
dump(snapshot, f)
elapsed_time = timers('interval-time').elapsed(barrier=True)
elapsed_time_per_iteration = elapsed_time / total_iterations
......@@ -1035,7 +1131,6 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r
writer.add_scalar('throughput', throughput, iteration)
if wandb_writer:
wandb_writer.log({'throughput': throughput}, iteration)
assert learning_rate is not None
        # decoupled_learning_rate should be non-None only on the first and last pipeline stages.
log_string += f' learning rate: {learning_rate:.6E} |'
if args.decoupled_lr is not None and (mpu.is_pipeline_first_stage(ignore_virtual=True) or
......@@ -1068,7 +1163,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r
total_loss_dict[skipped_iters_key] = 0
total_loss_dict[nan_iters_key] = 0
print_rank_last(log_string)
if report_memory_flag and learning_rate > 0.:
if report_memory_flag:
# Report memory after optimizer state has been initialized.
if torch.distributed.get_rank() == 0:
num_microbatches = get_num_microbatches()
......@@ -1120,10 +1215,10 @@ def enable_forward_pre_hook(model_chunks):
model_chunk.enable_forward_pre_hook()
def disable_forward_pre_hook(model_chunks):
def disable_forward_pre_hook(model_chunks, param_sync=True):
for model_chunk in model_chunks:
assert isinstance(model_chunk, DDP)
model_chunk.disable_forward_pre_hook()
model_chunk.disable_forward_pre_hook(param_sync=param_sync)
def save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler,
......@@ -1137,26 +1232,23 @@ def save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler,
# Extra barrier is added to make sure all ranks report the max time.
timer_key = 'save-checkpoint-non-persistent' if non_persistent_ckpt else 'save-checkpoint'
timers(timer_key, log_level=0).start(barrier=True)
save_checkpoint_start_time = timers('save-checkpoint').active_time()
# Log E2E metrics before save-checkpoint
one_logger_utils.track_e2e_metrics()
if args.use_distributed_optimizer and args.overlap_param_gather:
if should_disable_forward_pre_hook(args):
disable_forward_pre_hook(model)
save_checkpoint(iteration, model, optimizer, opt_param_scheduler,
num_floating_point_operations_so_far, checkpointing_context,
non_persistent_ckpt=non_persistent_ckpt, train_data_iterator=train_data_iterator,
ft_client=ft_integration.get_rank_monitor_client(
ft_integration.StateMachineActions.SAVE_CHECKPOINT), preprocess_common_state_dict_fn=preprocess_common_state_dict)
if args.use_distributed_optimizer and args.overlap_param_gather:
preprocess_common_state_dict_fn=preprocess_common_state_dict)
if should_disable_forward_pre_hook(args):
enable_forward_pre_hook(model)
timers(timer_key).stop(barrier=True)
timers.log([timer_key])
save_checkpoint_finish_time = timers('save-checkpoint').active_time()
# Log E2E metrics after save-checkpoint
one_logger_utils.track_e2e_metrics()
save_checkpoint_duration = save_checkpoint_finish_time - save_checkpoint_start_time
save_checkpoint_duration = timers(timer_key).elapsed()
one_logger_utils.on_save_checkpoint_end(save_checkpoint_duration, iteration, args.async_save)
if args.log_progress and not non_persistent_ckpt:
......@@ -1172,21 +1264,6 @@ def post_training_step_callbacks(model, optimizer, opt_param_scheduler, iteratio
"""Run all post-training-step functions (e.g., FT heartbeats, GC)."""
args = get_args()
# Send heartbeat to FT package and update timeouts.
if args.enable_ft_package:
ft_client = ft_integration.get_rank_monitor_client(
ft_integration.StateMachineActions.TRAIN_HEARTBEAT)
if ft_client is not None:
ft_client.send_heartbeat()
# TODO: We are always calculating timeouts in the current implementation.
# If we want to rely on manually setting these, then we need to add additional
# arguments to training and pass it here.
if ft_integration.can_update_timeouts():
ft_integration.get_rank_monitor_client(
ft_integration.StateMachineActions.UPDATE_TIMEOUT).calculate_and_set_timeouts()
print_rank_0(f'Updated FT timeouts. New values: \
{ft_integration.get_rank_monitor_client().timeouts}')
# Bring CPU and GPU back in sync if on right iteration.
if args.train_sync_interval and iteration % args.train_sync_interval == 0:
torch.cuda.synchronize()
......@@ -1199,13 +1276,13 @@ def post_training_step_callbacks(model, optimizer, opt_param_scheduler, iteratio
# Check weight hash across DP replicas.
if args.check_weight_hash_across_dp_replicas_interval is not None and \
iteration % args.check_weight_hash_across_dp_replicas_interval == 0:
if args.use_distributed_optimizer and args.overlap_param_gather:
if should_disable_forward_pre_hook(args):
disable_forward_pre_hook(model)
assert check_param_hashes_across_dp_replicas(model, cross_check=True), \
"Parameter hashes not matching across DP replicas"
torch.distributed.barrier()
print_rank_0(f">>> Weight hashes match after {iteration} iterations...")
if args.use_distributed_optimizer and args.overlap_param_gather:
if should_disable_forward_pre_hook(args):
enable_forward_pre_hook(model)
# Autoresume.
......@@ -1223,7 +1300,6 @@ def post_training_step_callbacks(model, optimizer, opt_param_scheduler, iteratio
prof.stop()
else:
torch.cuda.cudart().cudaProfilerStop()
# Manual garbage collection.
if args.manual_gc:
......@@ -1265,14 +1341,12 @@ def checkpoint_and_decide_exit(model, optimizer, opt_param_scheduler, iteration,
elif args.save and args.non_persistent_save_interval and \
iteration % args.non_persistent_save_interval == 0:
timers('interval-time').stop()
save_checkpoint_and_time(iteration, model, optimizer,
opt_param_scheduler,
num_floating_point_operations_so_far,
checkpointing_context,
non_persistent_ckpt=True, train_data_iterator=train_data_iterator)
saved_checkpoint = True
timers('interval-time', log_level=0).start(barrier=True)
# Exit based on duration.
if args.exit_duration_in_mins:
......@@ -1328,6 +1402,11 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
# Iterations.
iteration = args.iteration
# Make sure rerun_state_machine has the right iteration loaded from checkpoint.
rerun_state_machine = get_rerun_state_machine()
if rerun_state_machine.current_iteration != iteration:
print_rank_0(f"Setting rerun_state_machine.current_iteration to {iteration}...")
rerun_state_machine.current_iteration = iteration
# Track E2E metrics at the start of training.
one_logger_utils.on_train_start(iteration=iteration, consumed_train_samples=args.consumed_train_samples,
......@@ -1341,7 +1420,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
# Setup some training config params.
config.grad_scale_func = optimizer.scale_loss
config.timers = timers
if isinstance(model[0], DDP) and args.overlap_grad_reduce:
if isinstance(model[0], (custom_FSDP, DDP)) and args.overlap_grad_reduce:
assert config.no_sync_func is None, \
('When overlap_grad_reduce is True, config.no_sync_func must be None; '
'a custom no_sync_func is not supported when overlapping grad-reduce')
......@@ -1361,6 +1440,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
timers('interval-time', log_level=0).start(barrier=True)
print_datetime('before the start of training step')
report_memory_flag = True
pre_hook_enabled = False
should_exit = False
exit_code = 0
......@@ -1391,12 +1471,14 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
def get_e2e_base_metrics():
"""Get base metrics values for one-logger to calculate E2E tracking metrics.
"""
num_floating_point_operations_since_current_train_start = \
num_floating_point_operations_so_far - args.num_floating_point_operations_so_far
return {
'iteration': iteration,
'train_duration': timers('interval-time').active_time(),
'eval_duration': eval_duration,
'eval_iterations': eval_iterations,
'total_flops': num_floating_point_operations_since_last_log_event,
'total_flops_since_current_train_start': num_floating_point_operations_since_current_train_start,
'num_floating_point_operations_so_far': num_floating_point_operations_so_far,
'consumed_train_samples': args.consumed_train_samples,
'world_size': args.world_size,
......@@ -1409,44 +1491,47 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
prof = None
if args.profile and torch.distributed.get_rank() in args.profile_ranks and args.use_pytorch_profiler:
def trace_handler(p):
from pathlib import Path
Path(f"{args.profile_dir}").mkdir(parents=True, exist_ok=True)
if args.rank in [0]:
print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
p.export_chrome_trace("{path}/trace_rank{rank}_step{step}.json".format(
path=args.profile_dir, rank=torch.distributed.get_rank(), step=p.step_num))
prof = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
schedule=torch.profiler.schedule(
wait=max(args.profile_step_start-1, 0),
warmup=1 if args.profile_step_start > 0 else 0,
active=args.profile_step_end-args.profile_step_start,
repeat=1),
on_trace_ready=trace_handler)
on_trace_ready=torch.profiler.tensorboard_trace_handler(args.tensorboard_dir),
record_shapes=True,
with_stack=True)
prof.start()
elif args.profile and torch.distributed.get_rank() in args.profile_ranks and args.use_hip_profiler:
import ctypes
roctracer = ctypes.cdll.LoadLibrary("/opt/dtk/roctracer/lib/libroctracer64.so")
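    # Illustrative schedule (hypothetical values, assuming training starts at iteration 0 so
    # prof.step() is called once per iteration from step 0): with --profile-step-start 10 and
    # --profile-step-end 12, the PyTorch profiler above waits 9 steps, warms up for 1 step,
    # and records 2 active steps, i.e. iterations 10 and 11.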
start_iteration = iteration
    # Disable the forward pre-hook at the start of training to ensure that errors in checkpoint
    # loading or random initialization don't propagate to all ranks in the first all-gather
    # (which is a no-op if things work correctly).
if should_disable_forward_pre_hook(args):
disable_forward_pre_hook(model, param_sync=False)
# Also remove param_sync_func temporarily so that sync calls made in
# `forward_backward_func` are no-ops.
param_sync_func = config.param_sync_func
config.param_sync_func = None
pre_hook_enabled = False
# Also, check weight hash across DP replicas to be very pedantic.
if args.check_weight_hash_across_dp_replicas_interval is not None:
assert check_param_hashes_across_dp_replicas(model, cross_check=True), \
"Parameter hashes not matching across DP replicas"
torch.distributed.barrier()
print_rank_0(f">>> Weight hashes match after {iteration} iterations...")
# Run training iterations till done.
while iteration < args.train_iters:
if args.profile and torch.distributed.get_rank() in args.profile_ranks:
if args.use_pytorch_profiler:
prof.step()
elif args.use_hip_profiler:
if iteration == args.profile_step_start: roctracer.roctracer_start()
if iteration == args.profile_step_end: roctracer.roctracer_stop()
elif iteration == args.profile_step_start:
torch.cuda.cudart().cudaProfilerStart()
torch.autograd.profiler.emit_nvtx(record_shapes=True).__enter__()
ft_integration.on_checkpointing_start()
maybe_finalize_async_save(blocking=False)
ft_integration.on_checkpointing_end(is_async_finalization=True)
# Update number of microbatches first without consistency check to decide if a
# checkpoint should be saved. If the number of microbatches is different
......@@ -1456,36 +1541,68 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
if get_num_microbatches() != num_microbatches and iteration != 0:
assert get_num_microbatches() > num_microbatches, \
(f"Number of microbatches should be increasing due to batch size rampup; "
f"instead going from {num_microbatches} to {get_num_microbatches()}")
f"instead going from {num_microbatches} to {get_num_microbatches()}")
if args.save is not None:
save_checkpoint_and_time(iteration, model, optimizer,
opt_param_scheduler,
num_floating_point_operations_so_far,
checkpointing_context, train_data_iterator=train_data_iterator)
opt_param_scheduler,
num_floating_point_operations_so_far,
checkpointing_context, train_data_iterator=train_data_iterator)
num_microbatches = get_num_microbatches()
update_num_microbatches(args.consumed_train_samples, consistency_check=True, verbose=True)
# Completely skip iteration if needed.
if iteration in args.iterations_to_skip:
# Dummy train_step to fast forward train_data_iterator.
dummy_train_step(train_data_iterator)
iteration += 1
batch_size = mpu.get_data_parallel_world_size() * \
args.micro_batch_size * \
get_num_microbatches()
args.consumed_train_samples += batch_size
args.skipped_train_samples += batch_size
continue
# Run training step.
args.curr_iteration = iteration
ft_integration.on_training_step_start()
loss_dict, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad = \
train_step(forward_step_func,
train_data_iterator,
model,
optimizer,
opt_param_scheduler,
config)
train_data_iterator,
model,
optimizer,
opt_param_scheduler,
config)
ft_integration.on_training_step_end()
if should_checkpoint:
save_checkpoint_and_time(iteration, model, optimizer,
opt_param_scheduler,
num_floating_point_operations_so_far,
checkpointing_context, train_data_iterator=train_data_iterator)
opt_param_scheduler,
num_floating_point_operations_so_far,
checkpointing_context, train_data_iterator=train_data_iterator)
if should_exit:
break
# why is skipped_iter ignored?
# Enable forward pre-hooks after first set of forward and backward passes.
# When running in fp16, skip all NaN iterations until steady-state loss scaling value
# is reached.
if iteration == start_iteration:
if skipped_iter:
# Only enable forward pre-hook after a training step has successfully run. Relevant
# for fp16 codepath where first XX iterations are skipped until steady-state loss
# scale value is reached.
start_iteration = iteration + 1
else:
# Enable forward pre-hook after training step has successfully run. All subsequent
# forward passes will use the forward pre-hook / `param_sync_func` in
# `forward_backward_func`.
if should_disable_forward_pre_hook(args):
enable_forward_pre_hook(model)
config.param_sync_func = param_sync_func
pre_hook_enabled = True
iteration += 1
batch_size = mpu.get_data_parallel_world_size() * \
args.micro_batch_size * \
get_num_microbatches()
args.micro_batch_size * \
get_num_microbatches()
args.consumed_train_samples += batch_size
num_skipped_samples_in_batch = (get_current_global_batch_size() -
get_current_running_global_batch_size())
......@@ -1499,8 +1616,12 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
num_floating_point_operations_since_last_log_event += num_floating_point_operations_in_batch
# Logging.
loss_scale = optimizer.get_loss_scale().item()
if not optimizer.is_stub_optimizer:
loss_scale = optimizer.get_loss_scale().item()
else:
loss_scale = 1.0
params_norm = None
if args.log_params_norm:
params_norm = calc_params_l2_norm(model)
learning_rate = None
......@@ -1511,28 +1632,29 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
else:
learning_rate = param_group['lr']
report_memory_flag = training_log(loss_dict, total_loss_dict,
learning_rate,
decoupled_learning_rate,
iteration, loss_scale,
report_memory_flag, skipped_iter,
grad_norm, params_norm, num_zeros_in_grad)
learning_rate,
decoupled_learning_rate,
iteration, loss_scale,
report_memory_flag, skipped_iter,
grad_norm, params_norm, num_zeros_in_grad)
# Evaluation.
if args.eval_interval and iteration % args.eval_interval == 0 and \
args.do_valid:
timers('interval-time').stop()
if args.use_distributed_optimizer and args.overlap_param_gather:
if should_disable_forward_pre_hook(args):
disable_forward_pre_hook(model)
pre_hook_enabled = False
if args.manual_gc and args.manual_gc_eval:
# Collect all objects.
gc.collect()
prefix = f'iteration {iteration}'
timers('eval-time', log_level=0).start(barrier=True)
evaluate_and_print_results(prefix, forward_step_func,
valid_data_iterator, model,
iteration, process_non_loss_data_func,
config, verbose=False, write_to_tensorboard=True,
non_loss_data_func=non_loss_data_func)
valid_data_iterator, model,
iteration, process_non_loss_data_func,
config, verbose=False, write_to_tensorboard=True,
non_loss_data_func=non_loss_data_func)
eval_duration += timers('eval-time').elapsed()
eval_iterations += args.eval_iters
timers('eval-time').stop()
......@@ -1541,23 +1663,20 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
if args.manual_gc and args.manual_gc_eval:
# Collect only the objects created and used in evaluation.
gc.collect(generation=0)
if args.use_distributed_optimizer and args.overlap_param_gather:
if should_disable_forward_pre_hook(args):
enable_forward_pre_hook(model)
pre_hook_enabled = True
timers('interval-time', log_level=0).start(barrier=True)
if args.enable_ft_package and ft_integration.get_rank_monitor_client() is not None:
ft_integration.get_rank_monitor_client(
ft_integration.StateMachineActions.EVAL_HEARTBEAT).send_heartbeat()
# Miscellaneous post-training-step functions (e.g., FT heartbeats, GC).
# Some of these only happen at specific iterations.
post_training_step_callbacks(model, optimizer, opt_param_scheduler, iteration, prof,
num_floating_point_operations_since_last_log_event)
num_floating_point_operations_since_last_log_event)
# Checkpoint and decide whether to exit.
should_exit = checkpoint_and_decide_exit(model, optimizer, opt_param_scheduler, iteration,
num_floating_point_operations_so_far,
checkpointing_context, train_data_iterator)
num_floating_point_operations_so_far,
checkpointing_context, train_data_iterator)
if should_exit:
break
......@@ -1569,19 +1688,23 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
writer.flush()
# Close out pre-hooks if using distributed optimizer and overlapped param gather.
if args.use_distributed_optimizer and args.overlap_param_gather:
if pre_hook_enabled:
disable_forward_pre_hook(model)
ft_integration.on_checkpointing_start()
    # This will finalize all unfinalized async requests and terminate the
    # persistent async worker if the persistent ckpt worker is enabled.
maybe_finalize_async_save(blocking=True, terminate=True)
ft_integration.on_checkpointing_end(is_async_finalization=True)
if args.enable_ft_package and ft_integration.get_rank_monitor_client() is not None:
ft_integration.get_rank_monitor_client().shutdown_workload_monitoring()
maybe_finalize_async_save(blocking=True)
# If any exit conditions (signal handler, duration, iterations) have been reached, exit.
if should_exit:
wandb_writer = get_wandb_writer()
if wandb_writer:
wandb_writer.finish()
ft_integration.shutdown()
sys.exit(exit_code)
return iteration, num_floating_point_operations_so_far
......@@ -1632,6 +1755,7 @@ def evaluate(forward_step_func,
forward_backward_func = get_forward_backward_func()
# Don't care about timing during evaluation
config.timers = None
ft_integration.on_eval_step_start()
loss_dicts = forward_backward_func(
forward_step_func=forward_step_func,
data_iterator=data_iterator,
......@@ -1641,6 +1765,7 @@ def evaluate(forward_step_func,
micro_batch_size=args.micro_batch_size,
decoder_seq_length=args.decoder_seq_length,
forward_only=True)
ft_integration.on_eval_step_end()
config.timers = get_timers()
# Empty unused memory
......@@ -1701,7 +1826,9 @@ def evaluate(forward_step_func,
timers('evaluate').stop()
timers.log(['evaluate'])
rerun_state_machine.set_mode(rerun_mode)
rerun_state_machine.set_mode(rerun_mode)
return total_loss_dict, collected_non_loss_data, False
......@@ -1869,12 +1996,15 @@ def build_train_valid_test_data_iterators(
def _get_iterator(dataloader_type, dataloader):
"""Return dataset iterator."""
if dataloader_type == "single":
return RerunDataIterator(dataloader)
return RerunDataIterator(iter(dataloader))
elif dataloader_type == "cyclic":
return RerunDataIterator(cyclic_iter(dataloader))
return RerunDataIterator(iter(cyclic_iter(dataloader)))
elif dataloader_type == "external":
# External dataloader is passed through. User is expected to define how to iterate.
return RerunDataIterator(dataloader, make_iterable=False)
if isinstance(dataloader, list):
return [RerunDataIterator(d) for d in dataloader]
else:
return RerunDataIterator(dataloader)
else:
raise RuntimeError("unexpected dataloader type")
......@@ -1894,3 +2024,8 @@ def build_train_valid_test_data_iterators(
test_data_iterator = None
return train_data_iterator, valid_data_iterator, test_data_iterator
def should_disable_forward_pre_hook(args):
"""Block forward pre-hook for certain configurations."""
return not args.use_custom_fsdp and args.use_distributed_optimizer and args.overlap_param_gather
......@@ -33,18 +33,23 @@ from megatron.training import (
get_adlr_autoresume,
)
from megatron.core import DistributedDataParallel as DDP
from megatron.core.distributed.custom_fsdp import FullyShardedDataParallel as custom_FSDP
from megatron.core import mpu
from megatron.core.datasets.utils import get_blend_from_list
from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate
from megatron.core.utils import get_data_parallel_group_if_dtensor, to_local_if_dtensor
from megatron.core.utils import (
get_batch_on_this_cp_rank,
get_data_parallel_group_if_dtensor,
to_local_if_dtensor,
)
from megatron.legacy.model import Float16Module
from megatron.legacy.model.module import param_is_not_shared
try:
from megatron.core.distributed import TorchFullyShardedDataParallel as torch_FSDP
ALL_MODULE_WRAPPER_CLASSNAMES = (DDP, torch_FSDP, Float16Module)
ALL_MODULE_WRAPPER_CLASSNAMES = (DDP, torch_FSDP, custom_FSDP, Float16Module)
except ImportError:
ALL_MODULE_WRAPPER_CLASSNAMES = (DDP, Float16Module)
ALL_MODULE_WRAPPER_CLASSNAMES = (DDP, custom_FSDP, Float16Module)
def unwrap_model(model, module_instances=ALL_MODULE_WRAPPER_CLASSNAMES):
......@@ -62,7 +67,7 @@ def unwrap_model(model, module_instances=ALL_MODULE_WRAPPER_CLASSNAMES):
return unwrapped_model
def calc_params_l2_norm(model):
def calc_params_l2_norm(model, force_create_fp32_copy=False):
"""Calculate l2 norm of parameters """
args = get_args()
if not isinstance(model, list):
......@@ -70,54 +75,110 @@ def calc_params_l2_norm(model):
    # Separate MoE and dense params.
params_data = []
moe_params_data = []
sharded_params_data = []
data_parallel_group = None
custom_fsdp_all_param_is_shared = False
for model_chunk in model:
for i, param in enumerate(model_chunk.parameters()):
for param in model_chunk.parameters():
data_parallel_group = get_data_parallel_group_if_dtensor(param, data_parallel_group)
is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
if not (param.requires_grad and is_not_tp_duplicate):
if not is_not_tp_duplicate:
continue
assert is_not_tp_duplicate
if hasattr(param, "fully_shard_param_local_shard"):
param = param.fully_shard_param_local_shard
                assert all(getattr(p, "fully_shard_param_local_shard", None) is not None
                           for p in model_chunk.parameters())
custom_fsdp_all_param_is_shared = True
if param.numel() == 0:
continue
if not getattr(param, 'allreduce', True):
# TODO: Implement memory optimization for MoE parameters.
assert param_is_not_shared(param)
param = to_local_if_dtensor(param)
moe_params_data.append(param.data.float() if args.bf16 else param.data)
else:
if param_is_not_shared(param):
param = to_local_if_dtensor(param)
params_data.append(param.data.float() if args.bf16 else param.data)
# Calculate dense param norm
if args.bf16:
if not force_create_fp32_copy and hasattr(param, 'main_param'):
if getattr(param, 'main_param_sharded', False):
if param.main_param is not None:
sharded_params_data.append(param.main_param)
else:
params_data.append(param.main_param)
else:
# Fallback to original logic of making a fp32 copy of the
# parameter if `.main_param` attribute is not available.
params_data.append(param.data.float())
else:
params_data.append(param.data)
# Calculate norm.
dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device='cuda')
norm, _ = multi_tensor_applier(
multi_tensor_l2norm,
dummy_overflow_buf,
[params_data],
False # no per-parameter norm
)
norm_2 = norm * norm
if len(params_data) > 0:
norm, _ = multi_tensor_applier(
multi_tensor_l2norm,
dummy_overflow_buf,
[params_data],
False # no per-parameter norm.
)
norm_2 = norm * norm
else:
norm_2 = torch.zeros((1,), dtype=torch.float32, device='cuda')
if data_parallel_group is not None:
torch.distributed.all_reduce(norm_2,
op=torch.distributed.ReduceOp.SUM,
group=data_parallel_group)
# Sum across all model-parallel GPUs(tensor + pipeline).
# Add norm contribution from params with sharded main_params. These norms need to be
# accumulated across the DP group since the main parameters are sharded because
# of distributed optimizer.
if len(sharded_params_data) > 0:
dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device='cuda')
sharded_norm, _ = multi_tensor_applier(
multi_tensor_l2norm,
dummy_overflow_buf,
[sharded_params_data],
False # no per-parameter norm.
)
sharded_norm_2 = sharded_norm * sharded_norm
# Sum over all DP groups.
torch.distributed.all_reduce(
sharded_norm_2,
op=torch.distributed.ReduceOp.SUM,
group=mpu.get_data_parallel_group()
)
norm_2 += sharded_norm_2
if custom_fsdp_all_param_is_shared:
torch.distributed.all_reduce(norm_2,
op=torch.distributed.ReduceOp.SUM,
group=mpu.get_data_parallel_group())
# Sum across all model-parallel GPUs (tensor + pipeline).
torch.distributed.all_reduce(
norm_2,
op=torch.distributed.ReduceOp.SUM,
group=mpu.get_model_parallel_group()
)
# Calculate moe norm
# Add norm contribution from expert layers in MoEs.
if len(moe_params_data) > 0:
moe_norm, _ = multi_tensor_applier(
multi_tensor_l2norm,
dummy_overflow_buf,
[moe_params_data],
False # no per-parameter norm
False # no per-parameter norm.
)
moe_norm_2 = moe_norm * moe_norm
if custom_fsdp_all_param_is_shared:
torch.distributed.all_reduce(moe_norm_2,
op=torch.distributed.ReduceOp.SUM,
group=mpu.get_expert_data_parallel_group())
# Sum across expert tensor, model and pipeline parallel GPUs.
torch.distributed.all_reduce(
moe_norm_2,
......@@ -125,6 +186,7 @@ def calc_params_l2_norm(model):
group=mpu.get_expert_tensor_model_pipeline_parallel_group()
)
norm_2 += moe_norm_2
return norm_2.item() ** 0.5
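# Roughly, the value returned above is sqrt(||dense||^2 + ||sharded main||^2 + ||moe||^2),
# where each squared contribution is all-reduced over the process groups that shard it
# (data-parallel for sharded main params, expert-parallel groups for MoE params, and the
# model-parallel group for the rest) before the final square root is taken.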
......@@ -140,6 +202,41 @@ def average_losses_across_data_parallel_group(losses):
return averaged_losses
def reduce_max_stat_across_model_parallel_group(stat: float) -> float:
"""
    Ranks without an optimizer have no grad_norm or num_zeros_in_grad stats, but the logging
    and writer ranks still need those values. This function reduces a stat across the model
    parallel group via an all-reduce max, since the values have already been summed across
    optimizer ranks where possible.
"""
if stat is None:
stat = -1.0
stat = torch.tensor([stat], dtype=torch.float32, device=torch.cuda.current_device())
torch.distributed.all_reduce(
stat, op=torch.distributed.ReduceOp.MAX, group=mpu.get_model_parallel_group()
)
if stat.item() == -1.0:
return None
else:
return stat.item()
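# Worked example (hypothetical values): if grad_norm is 3.2 on the ranks that own the
# optimizer state and None elsewhere, the None ranks contribute -1.0, the MAX all-reduce
# yields 3.2, and every model-parallel rank returns 3.2; only if all ranks pass None does
# the function return None everywhere.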
def logical_and_across_model_parallel_group(input: bool) -> bool:
"""
    Perform a logical AND of a bool value across the model parallel group
    (implemented as a MIN all-reduce).
"""
if input is True:
input = 1
else:
input = 0
input = torch.tensor([input], dtype=torch.int, device=torch.cuda.current_device())
torch.distributed.all_reduce(
input, op=torch.distributed.ReduceOp.MIN, group=mpu.get_model_parallel_group()
)
return bool(input.item())
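# Worked example (hypothetical values): if optimizer.step() reports update_successful=False
# on one model-parallel rank, that rank contributes 0 and the MIN all-reduce yields 0, so
# every rank in the group sees False and skips the iteration consistently.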
def report_memory(name):
"""Simple GPU memory report."""
mega_bytes = 1024.0 * 1024.0
......@@ -254,39 +351,6 @@ def get_ltor_masks_and_position_ids(data,
return attention_mask, loss_mask, position_ids
def get_batch_on_this_cp_rank(batch):
""" Slice batch input along sequence dimension into multiple chunks,
which are parallelized across GPUs in a context parallel group.
"""
    # With causal masking, each token only attends to its prior tokens. Simply splitting
    # the sequence into CP chunks can result in severe load imbalance: chunks at the end
    # of the sequence have a bigger workload than the others. To address this issue, we
    # split the sequence into 2*CP chunks. Assuming CP=2, we get 4 chunks; chunk_0 and
    # chunk_3 are assigned to GPU0, while chunk_1 and chunk_2 are assigned to GPU1, so
    # that the workload is balanced among GPUs in the context parallel group.
args = get_args()
cp_size = args.context_parallel_size
if cp_size > 1:
cp_rank = mpu.get_context_parallel_rank()
for key, val in batch.items():
if val is not None:
seq_dim = 1 if key != 'attention_mask' else 2
val = val.view(
*val.shape[0:seq_dim],
2 * cp_size,
val.shape[seq_dim] // (2 * cp_size),
*val.shape[(seq_dim + 1) :],
)
index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)],
device="cpu", pin_memory=True).cuda(non_blocking=True)
val = val.index_select(seq_dim, index)
val = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2) :])
batch[key] = val
return batch
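# Worked example (hypothetical toy shapes, torch assumed to be imported in this module):
#   cp_size, cp_rank, seq = 2, 0, torch.arange(8).unsqueeze(0)      # [1, 8]
#   seq = seq.view(1, 2 * cp_size, 2)                               # 4 chunks of 2 tokens
#   index = torch.tensor([cp_rank, 2 * cp_size - cp_rank - 1])      # chunks [0, 3]
#   seq.index_select(1, index).reshape(1, -1)                       # tensor([[0, 1, 6, 7]])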
def print_rank_0(message):
"""If distributed is initialized, print only on rank 0."""
if torch.distributed.is_initialized():
......@@ -295,6 +359,10 @@ def print_rank_0(message):
else:
print(message, flush=True)
def is_rank0():
"""Returns true if called in the rank0, false otherwise"""
return torch.distributed.is_initialized() and torch.distributed.get_rank() == 0
def is_last_rank():
return torch.distributed.get_rank() == (
torch.distributed.get_world_size() - 1)
......@@ -307,6 +375,9 @@ def print_rank_last(message):
else:
print(message, flush=True)
def get_device_arch_version():
"""Returns GPU arch version (8: Ampere, 9: Hopper, 10: Blackwell, ...)"""
return torch.cuda.get_device_properties(torch.device("cuda:0")).major
def append_to_progress_log(string, barrier=True):
"""Append given string to progress log."""
......@@ -431,11 +502,11 @@ def get_batch_on_this_tp_rank(data_iterator):
_broadcast(loss_mask)
_broadcast(attention_mask)
_broadcast(position_ids)
elif mpu.is_pipeline_first_stage():
labels=None
loss_mask=None
_broadcast(tokens)
_broadcast(attention_mask)
_broadcast(position_ids)
......@@ -443,11 +514,11 @@ def get_batch_on_this_tp_rank(data_iterator):
elif mpu.is_pipeline_last_stage():
tokens=None
position_ids=None
_broadcast(labels)
_broadcast(loss_mask)
_broadcast(attention_mask)
batch = {
'tokens': tokens,
'labels': labels,
......
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from pathlib import Path
from typing import Tuple
from megatron.training.global_vars import get_wandb_writer
from megatron.training.utils import print_rank_last
def _get_wandb_artifact_tracker_filename(save_dir: str) -> Path:
"""Wandb artifact tracker file rescords the latest artifact wandb entity and project"""
return Path(save_dir) / "latest_wandb_artifact_path.txt"
def _get_artifact_name_and_version(save_dir: Path, checkpoint_path: Path) -> Tuple[str, str]:
return save_dir.stem, checkpoint_path.stem
def on_save_checkpoint_success(checkpoint_path: str, tracker_filename: str, save_dir: str, iteration: int) -> None:
"""Function to be called after checkpointing succeeds and checkpoint is persisted for logging it as an artifact in W&B
Args:
checkpoint_path (str): path of the saved checkpoint
tracker_filename (str): path of the tracker filename for the checkpoint iteration
save_dir (str): path of the root save folder for all checkpoints
iteration (int): iteration of the checkpoint
"""
wandb_writer = get_wandb_writer()
if wandb_writer:
metadata = {"iteration": iteration}
artifact_name, artifact_version = _get_artifact_name_and_version(Path(save_dir), Path(checkpoint_path))
artifact = wandb_writer.Artifact(artifact_name, type="model", metadata=metadata)
artifact.add_reference(f"file://{checkpoint_path}", checksum=False)
artifact.add_file(tracker_filename)
wandb_writer.run.log_artifact(artifact, aliases=[artifact_version])
wandb_tracker_filename = _get_wandb_artifact_tracker_filename(save_dir)
wandb_tracker_filename.write_text(f"{wandb_writer.run.entity}/{wandb_writer.run.project}")
def on_load_checkpoint_success(checkpoint_path: str, load_dir: str) -> None:
"""Function to be called after succesful loading of a checkpoint, for aggregation and logging it to W&B
Args:
checkpoint_path (str): path of the loaded checkpoint
load_dir (str): path of the root save folder for all checkpoints
"""
wandb_writer = get_wandb_writer()
if wandb_writer:
try:
artifact_name, artifact_version = _get_artifact_name_and_version(Path(load_dir), Path(checkpoint_path))
wandb_tracker_filename = _get_wandb_artifact_tracker_filename(load_dir)
artifact_path = ""
if wandb_tracker_filename.is_file():
artifact_path = wandb_tracker_filename.read_text().strip()
artifact_path = f"{artifact_path}/"
wandb_writer.run.use_artifact(f"{artifact_path}{artifact_name}:{artifact_version}")
except Exception:
print_rank_last(f" failed to find checkpoint {checkpoint_path} in wandb")
\ No newline at end of file