Commit 2c63b5cd authored by wangxj

Upgrade to version 0.12

parent c271aaae
Pipeline #2451 passed with stage
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
@@ -18,11 +18,15 @@ from megatron.core.models.retro.utils import (
     get_config_path as get_retro_config_path,
     get_gpt_data_dir as get_retro_data_dir,
 )
+from megatron.core.rerun_state_machine import RerunStateMachine
 from megatron.core.transformer import TransformerConfig, MLATransformerConfig
 from megatron.core.transformer.enums import AttnBackend
-from megatron.core.utils import is_torch_min_version
+from megatron.core.utils import (
+    is_torch_min_version,
+    get_torch_version,
+)
 from megatron.training.activations import squared_relu
-from megatron.training.utils import update_use_dist_ckpt
+from megatron.training.utils import update_use_dist_ckpt, get_device_arch_version

 def parse_args(extra_args_provider=None, ignore_unknown_args=False):
@@ -187,21 +191,28 @@ def moe_freq_type(x):
 def validate_args(args, defaults={}):
     # Temporary
-    assert args.non_persistent_ckpt_type in ['global', None], \
-        'Currently only global checkpoints are supported'
+    assert args.non_persistent_ckpt_type in ['global', 'local', None], \
+        'Currently only global and local checkpoints are supported'
+    if args.non_persistent_ckpt_type == 'local':
+        try:
+            from nvidia_resiliency_ext.checkpointing.local.ckpt_managers.local_manager import \
+                LocalCheckpointManager
+        except ModuleNotFoundError as e:
+            raise RuntimeError('nvidia_resiliency_ext is required for local checkpointing') from e

     # Load saved args from Retro (if applicable).
     load_retro_args(args)

     # Set args.use_dist_ckpt from args.ckpt_format.
+    if args.use_legacy_models:
+        assert args.ckpt_format == "torch", \
+            "legacy model format only supports the 'torch' checkpoint format."
     update_use_dist_ckpt(args)

     if args.encoder_pipeline_model_parallel_size == 0 and args.num_experts == 0:
         assert args.encoder_tensor_model_parallel_size == args.tensor_model_parallel_size, "If non-MOE encoder shares first decoder pipeline rank it must have the same TP as the decoder."

     if args.encoder_tensor_model_parallel_size > 0:
+        assert args.encoder_pipeline_model_parallel_size > 0, "encoder_pipeline_model_parallel_size must be defined."
         assert args.num_attention_heads % args.encoder_tensor_model_parallel_size == 0
         assert args.encoder_tensor_model_parallel_size <= args.tensor_model_parallel_size, "We do not support encoders with more TP than the decoder."
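The new 'local' branch above is a standard optional-dependency guard: import the checkpoint manager from nvidia_resiliency_ext and fail fast with a clear error if the package is missing. A minimal standalone sketch of the same pattern (package path taken from the hunk above, helper name hypothetical):

    def _require_local_checkpoint_manager():
        # Import lazily so the dependency is only needed when local checkpointing is requested.
        try:
            from nvidia_resiliency_ext.checkpointing.local.ckpt_managers.local_manager import (
                LocalCheckpointManager,
            )
        except ModuleNotFoundError as e:
            raise RuntimeError('nvidia_resiliency_ext is required for local checkpointing') from e
        return LocalCheckpointManager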
@@ -220,12 +231,8 @@ def validate_args(args, defaults={}):
     if args.attention_backend == AttnBackend.local:
         assert args.spec[0] == 'local' , '--attention-backend local is only supported with --spec local'

     # Pipeline model parallel size.
-    args.transformer_pipeline_model_parallel_size = (
-        args.pipeline_model_parallel_size - 1
-        if args.standalone_embedding_stage else
-        args.pipeline_model_parallel_size
-    )
+    args.transformer_pipeline_model_parallel_size = args.pipeline_model_parallel_size

     args.data_parallel_size = args.world_size // total_model_size
@@ -329,13 +336,12 @@ def validate_args(args, defaults={}):
         print('setting global batch size to {}'.format(
             args.global_batch_size), flush=True)
     assert args.global_batch_size > 0
-    if args.decoder_first_pipeline_num_layers is None and args.decoder_last_pipeline_num_layers is None:
-        # Divisibility check not applicable for T5 models which specify encoder_num_layers
-        # and decoder_num_layers.
-        if args.num_layers is not None:
-            assert args.num_layers % args.transformer_pipeline_model_parallel_size == 0, \
-                'Number of layers should be divisible by the pipeline-model-parallel size'
-    if args.num_layers_per_virtual_pipeline_stage is not None:
+
+    # Uneven virtual pipeline parallelism
+    assert args.num_layers_per_virtual_pipeline_stage is None or args.num_virtual_stages_per_pipeline_rank is None, \
+        '--num-layers-per-virtual-pipeline-stage and --num-virtual-stages-per-pipeline-rank cannot be set at the same time'
+
+    if args.num_layers_per_virtual_pipeline_stage is not None or args.num_virtual_stages_per_pipeline_rank is not None:
         if args.overlap_p2p_comm:
             assert args.pipeline_model_parallel_size > 1, \
                 'When interleaved schedule is used, pipeline-model-parallel size '\
@@ -345,15 +351,28 @@
                 'When interleaved schedule is used and p2p communication overlap is disabled, '\
                 'pipeline-model-parallel size should be greater than 2 to avoid having multiple '\
                 'p2p sends and recvs between same 2 ranks per communication batch'
-        assert args.num_layers is not None
-        # Double check divisibility check here since check above is if guarded.
-        assert args.num_layers % args.transformer_pipeline_model_parallel_size == 0, \
-            'Number of layers should be divisible by the pipeline-model-parallel size'
-        num_layers_per_pipeline_stage = args.num_layers // args.transformer_pipeline_model_parallel_size
-        assert num_layers_per_pipeline_stage % args.num_layers_per_virtual_pipeline_stage == 0, \
-            'Number of layers per pipeline stage must be divisible by number of layers per virtual pipeline stage'
-        args.virtual_pipeline_model_parallel_size = num_layers_per_pipeline_stage // \
-            args.num_layers_per_virtual_pipeline_stage
+        if args.num_virtual_stages_per_pipeline_rank is None:
+            assert args.decoder_first_pipeline_num_layers is None and args.decoder_last_pipeline_num_layers is None, \
+                'please use --num-virtual-stages-per-pipeline-rank to specify virtual pipeline parallel degree when enable uneven pipeline parallelism'
+            num_layers = args.num_layers
+            if args.account_for_embedding_in_pipeline_split:
+                num_layers += 1
+            if args.account_for_loss_in_pipeline_split:
+                num_layers += 1
+            assert num_layers % args.transformer_pipeline_model_parallel_size == 0, \
+                'number of layers of the model must be divisible pipeline model parallel size'
+            num_layers_per_pipeline_stage = num_layers // args.transformer_pipeline_model_parallel_size
+            assert num_layers_per_pipeline_stage % args.num_layers_per_virtual_pipeline_stage == 0, \
+                'number of layers per pipeline stage must be divisible number of layers per virtual pipeline stage'
+            args.virtual_pipeline_model_parallel_size = num_layers_per_pipeline_stage // \
+                args.num_layers_per_virtual_pipeline_stage
+        else:
+            args.virtual_pipeline_model_parallel_size = args.num_virtual_stages_per_pipeline_rank
     else:
         args.virtual_pipeline_model_parallel_size = None

     # Overlap P2P communication is disabled if not using the interleaved schedule.
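Worked example of the new bookkeeping: when the embedding and loss layers are counted as pipeline stages, they are added to the layer count before the divisibility checks and the virtual-stage division. The values below are illustrative, not taken from the diff:

    num_layers = 30
    num_layers += 1                                   # --account-for-embedding-in-pipeline-split
    num_layers += 1                                   # --account-for-loss-in-pipeline-split
    pipeline_model_parallel_size = 4                  # 32 % 4 == 0, so the assert passes
    num_layers_per_virtual_pipeline_stage = 2
    num_layers_per_pipeline_stage = num_layers // pipeline_model_parallel_size       # 8
    virtual_pipeline_model_parallel_size = (
        num_layers_per_pipeline_stage // num_layers_per_virtual_pipeline_stage       # 4
    )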
@@ -364,6 +383,30 @@
             print('WARNING: Setting args.overlap_p2p_comm and args.align_param_gather to False '
                   'since non-interleaved schedule does not support overlapping p2p communication '
                   'and aligned param AG')

+    if args.decoder_first_pipeline_num_layers is None and args.decoder_last_pipeline_num_layers is None:
+        # Divisibility check not applicable for T5 models which specify encoder_num_layers
+        # and decoder_num_layers.
+        if args.num_layers is not None:
+            num_layers = args.num_layers
+            if args.account_for_embedding_in_pipeline_split:
+                num_layers += 1
+            if args.account_for_loss_in_pipeline_split:
+                num_layers += 1
+            assert num_layers % args.transformer_pipeline_model_parallel_size == 0, \
+                'Number of layers should be divisible by the pipeline-model-parallel size'
+
+    if args.rank == 0:
+        print(f"Number of virtual stages per pipeline stage: {args.virtual_pipeline_model_parallel_size}")
+
+    if args.data_parallel_sharding_strategy == "optim_grads_params":
+        args.overlap_param_gather = True
+        args.overlap_grad_reduce = True
+
+    if args.data_parallel_sharding_strategy == "optim_grads":
+        args.overlap_grad_reduce = True

     if args.overlap_param_gather:
         assert args.use_distributed_optimizer, \
@@ -373,8 +416,8 @@
         assert not args.use_legacy_models, \
             '--overlap-param-gather only supported with MCore models'

-    if getattr(args, "use_torch_fsdp2", False):
-        assert get_torch_version() >= PkgVersion("2.4"), \
+    if args.use_torch_fsdp2:
+        assert is_torch_min_version("2.4.0"), \
             'FSDP2 requires PyTorch >= 2.4.0 with FSDP 2 support.'
         assert args.pipeline_model_parallel_size == 1, \
             '--use-torch-fsdp2 is not supported with pipeline parallelism'
@@ -401,10 +444,33 @@
         assert not args.use_dist_ckpt, \
             '--overlap-param-gather-with-optimizer-step not supported with distributed checkpointing yet'

+    dtype_map = {
+        'fp32': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16, 'fp8': torch.uint8,
+    }
+    map_dtype = lambda d: d if isinstance(d, torch.dtype) else dtype_map[d]
+    args.main_grads_dtype = map_dtype(args.main_grads_dtype)
+    args.main_params_dtype = map_dtype(args.main_params_dtype)
+    args.exp_avg_dtype = map_dtype(args.exp_avg_dtype)
+    args.exp_avg_sq_dtype = map_dtype(args.exp_avg_sq_dtype)

     if args.fp8_param_gather:
         assert args.use_distributed_optimizer, \
             '--fp8-param-gather only supported with distributed optimizer'

+    if args.use_custom_fsdp:
+        assert args.use_distributed_optimizer, \
+            '--use-custom-fsdp only supported with distributed optimizer'
+
+        if args.data_parallel_sharding_strategy in ["optim_grads_params", "optim_grads"]:
+            warnings.warn('Please make sure your TransformerEngine support FSDP + gradient accumulation fusion')
+            assert args.gradient_accumulation_fusion is False, \
+                "optim_grads_params optim_grads are not supported with gradient accumulation fusion"
+
+        if args.data_parallel_sharding_strategy == "optim_grads_params":
+            assert args.check_weight_hash_across_dp_replicas_interval is None, \
+                'check_weight_hash_across_dp_replicas_interval is not supported with optim_grads_params'

     # Parameters dtype.
     args.params_dtype = torch.float
     if args.fp16:
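The dtype_map normalization above lets the precision-aware-optimizer flags (defined near the end of this file) be passed either as strings from the command line or as torch dtypes. A quick illustration of the mapping:

    import torch

    dtype_map = {'fp32': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16, 'fp8': torch.uint8}
    map_dtype = lambda d: d if isinstance(d, torch.dtype) else dtype_map[d]

    assert map_dtype('bf16') is torch.bfloat16         # string from the command line
    assert map_dtype(torch.float32) is torch.float32   # already a dtype: passed through unchanged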
@@ -422,7 +488,13 @@
         args.params_dtype = torch.bfloat16
         # bfloat16 requires gradient accumulation and all-reduce to
         # be done in fp32.
-        if not args.accumulate_allreduce_grads_in_fp32:
+        if args.accumulate_allreduce_grads_in_fp32:
+            assert args.main_grads_dtype == torch.float32, \
+                "--main-grads-dtype can only be fp32 when --accumulate-allreduce-grads-in-fp32 is set"
+
+        if args.grad_reduce_in_bf16:
+            args.accumulate_allreduce_grads_in_fp32 = False
+        elif not args.accumulate_allreduce_grads_in_fp32 and args.main_grads_dtype == torch.float32:
             args.accumulate_allreduce_grads_in_fp32 = True
             if args.rank == 0:
                 print('accumulate and all-reduce gradients in fp32 for '
@@ -525,7 +597,9 @@
         args.seq_length = args.encoder_seq_length

     if args.seq_length is not None:
-        assert args.max_position_embeddings >= args.seq_length
+        assert args.max_position_embeddings >= args.seq_length, \
+            f"max_position_embeddings ({args.max_position_embeddings}) must be greater than " \
+            f"or equal to seq_length ({args.seq_length})."
     if args.decoder_seq_length is not None:
         assert args.max_position_embeddings >= args.decoder_seq_length
     if args.lr is not None:
@@ -597,7 +671,7 @@
     # model parallel memory optimization is enabled
     if args.sequence_parallel:
         args.async_tensor_model_parallel_allreduce = False
-    if getattr(args, "use_torch_fsdp2", False):
+    if args.use_torch_fsdp2:
         warnings.warn(
             "Using sequence parallelism with FSDP2 together. Try not to using them "
             "together since they require different CUDA_MAX_CONNECTIONS settings "
@@ -605,13 +679,14 @@
             "environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 while FSDP2 "
             "requires not setting CUDA_DEVICE_MAX_CONNECTIONS=1 for better parallelization.")

-    if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
+    if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1" and get_device_arch_version() < 10:
+        # CUDA_DEVICE_MAX_CONNECTIONS requirement no longer exists since the Blackwell architecture
         if args.sequence_parallel:
-            raise RuntimeError(
+            warnings.warn(
                 "Using sequence parallelism requires setting the environment variable "
                 "CUDA_DEVICE_MAX_CONNECTIONS to 1")
         if args.async_tensor_model_parallel_allreduce:
-            raise RuntimeError(
+            warnings.warn(
                 "Using async gradient all reduce requires setting the environment "
                 "variable CUDA_DEVICE_MAX_CONNECTIONS to 1")
@@ -642,9 +717,6 @@
         assert not args.use_legacy_models, \
             '--decoupled-lr and --decoupled-min-lr is not supported in legacy models.'

-    # FlashAttention
-    args.use_flash_attn = args.use_flash_attn_cutlass or args.use_flash_attn_triton

     # Legacy RoPE arguments
     if args.use_rotary_position_embeddings:
         args.position_embedding_type = 'rope'
@@ -660,15 +732,21 @@
     if not args.add_position_embedding and args.position_embedding_type != 'rope':
         raise RuntimeError('--no-position-embedding is deprecated, use --position-embedding-type')

+    # Relative position embeddings arguments
+    if args.position_embedding_type == 'relative':
+        assert (
+            args.transformer_impl == "transformer_engine"
+        ), 'Local transformer implementation currently does not support attention bias-based position embeddings.'

     # MoE Spec check
     if args.num_experts == 0:
         args.num_experts = None

     if args.num_experts is not None:
         assert args.spec is None, "Model Spec must be None when using MoEs"
         if args.moe_ffn_hidden_size is None:
             args.moe_ffn_hidden_size = args.ffn_hidden_size

     # Context parallel
     if args.context_parallel_size > 1:
         assert not args.use_legacy_models, "Context parallelism is not supported in legacy models."
@@ -691,10 +769,6 @@
         any([args.train_data_path, args.valid_data_path, args.test_data_path]) \
             <= 1, "A single data source must be provided in training mode, else None"

-    if args.use_tp_pp_dp_mapping:
-        assert args.context_parallel_size * args.expert_model_parallel_size <= 1, \
-            "context_parallel and expert_model_parallel can't be used with tp-pp-dp mapping."

     # Deterministic mode
     if args.deterministic_mode:
         assert not args.use_flash_attn, "Flash attention can not be used in deterministic mode."
@@ -710,6 +784,21 @@
     if args.apply_query_key_layer_scaling:
         args.attention_softmax_in_fp32 = True

+    if args.result_rejected_tracker_filename is not None:
+        # Append to passed-in args.iterations_to_skip.
+        iterations_to_skip_from_file = RerunStateMachine.get_skipped_iterations_from_tracker_file(
+            args.result_rejected_tracker_filename
+        )
+        args.iterations_to_skip.extend(iterations_to_skip_from_file)
+
+    # Make sure all functionality that requires Gloo process groups is disabled.
+    if not args.enable_gloo_process_groups:
+        if args.use_distributed_optimizer:
+            # If using distributed optimizer, must use distributed checkpointing.
+            # Legacy checkpointing uses Gloo process groups to collect full distributed
+            # optimizer state in the CPU memory of DP rank 0.
+            assert args.use_dist_ckpt

     # Checkpointing
     if args.ckpt_fully_parallel_save_deprecated and args.rank == 0:
         print('--ckpt-fully-parallel-save flag is deprecated and has no effect.'
@@ -745,6 +834,31 @@
             args.no_load_rng = True
             print('Warning: disabling --no-load-rng for upcycling.')

+    # Optimizer CPU offload check
+    if args.optimizer_cpu_offload:
+        assert args.use_precision_aware_optimizer, (
+            "The optimizer cpu offload must be used in conjunction with `--use-precision-aware-optimizer`, "
+            "as the hybrid device optimizer reuses the code path of this flag."
+        )
+
+    # MoE loss and include embedding and loss layer check
+    if args.num_experts is not None:
+        if args.moe_router_load_balancing_type != "none" or args.moe_z_loss_coeff is not None:
+            assert not args.account_for_embedding_in_pipeline_split, \
+                "Cannot support load balancing loss and z loss with --account-for-embedding-in-pipeline-split"
+            assert not args.account_for_loss_in_pipeline_split, \
+                "Cannot support load balancing loss and z loss with --account-for-loss-in-pipeline-split"
+
+    if args.non_persistent_ckpt_type == "local":
+        assert args.non_persistent_local_ckpt_dir is not None, "Tried to use local checkpointing without specifying --local-ckpt-dir!"
+
+    if args.replication:
+        assert args.replication_jump is not None, "--replication requires the value of --replication-jump!"
+        assert args.non_persistent_ckpt_type == "local", f"--replication requires args.non_persistent_ckpt_type == 'local', but got: {args.non_persistent_ckpt_type}"
+    elif args.replication_jump:
+        print("Warning: --replication-jump was specified despite not using replication. Ignoring.")
+        args.replication_jump = None

     # Print arguments.
     _print_args("arguments", args)
@@ -791,8 +905,8 @@ def core_transformer_config_from_args(args, config_class=None):
     kw_args['batch_p2p_comm'] = not args.overlap_p2p_comm
     kw_args['num_moe_experts'] = args.num_experts
     kw_args['rotary_interleaved'] = args.rotary_interleaved
-    kw_args['first_pipeline_num_layers']= args.decoder_first_pipeline_num_layers
-    kw_args['last_pipeline_num_layers']= args.decoder_last_pipeline_num_layers
+    kw_args['num_layers_in_first_pipeline_stage']= args.decoder_first_pipeline_num_layers
+    kw_args['num_layers_in_last_pipeline_stage']= args.decoder_last_pipeline_num_layers
     if args.swiglu:
         kw_args['activation_func'] = F.silu
         kw_args['gated_linear_unit'] = True
@@ -847,6 +961,11 @@ def _add_transformer_engine_args(parser):
     group.add_argument('--fp8-param-gather', action='store_true',
                        help='Keep the compute param in fp8 (do not use any other intermediate '
                        'dtype) and perform the param all-gather in fp8.')
+    group.add_argument('--te-rng-tracker', action='store_true', default=False,
+                       help='Use the Transformer Engine version of the random number generator. '
+                       'Required for CUDA graphs support.')
+    group.add_argument('--inference-rng-tracker', action='store_true', default=False,
+                       help='Use a random number generator configured for inference.')
     return parser

 def _add_inference_args(parser):
@@ -873,8 +992,15 @@ def _add_inference_args(parser):
                        'Bert embedder.')
     group.add_argument('--flash-decode', default=False, action="store_true",
                        help='Whether to use the flash decoding kernel.')
+    group.add_argument('--enable-cuda-graph', default=False, action="store_true",
+                       help='Use CUDA graph capture and replay.')
+    group.add_argument("--cuda-graph-warmup-steps", type=int, default=3,
+                       help="Number of CUDA graph warmup steps")
+    group.add_argument('--inference-max-requests', type=int, default=8,
+                       help='Maximum number of requests for inference.',
+                       dest='inference_max_batch_size')
     group.add_argument('--inference-max-seq-length', type=int, default=2560,
-                       help='Maximum sequence length allocated for prefill during inference.',
+                       help='Maximum sequence length expected for inference (prefill + decode).',
                        dest='inference_max_seq_length')
     return parser
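Rough sizing intuition for the two inference limits above: together with the model shape they bound the KV-cache that the inference engine must pre-allocate. The model dimensions below are assumed for illustration only:

    max_requests, max_seq_len = 8, 2560                 # the defaults added above
    num_layers, num_kv_heads, head_dim = 32, 8, 128     # assumed model shape
    bytes_per_elem = 2                                  # bf16
    kv_cache_bytes = (2 * num_layers * max_requests * max_seq_len
                      * num_kv_heads * head_dim * bytes_per_elem)   # factor 2: K and V
    print(f"{kv_cache_bytes / 2**30:.1f} GiB")          # ~2.5 GiB for this shape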
@@ -957,8 +1083,12 @@ def _add_network_size_args(parser):
                        help='Maximum number of position embeddings to use. '
                        'This is the size of position embedding.')
     group.add_argument('--position-embedding-type', type=str, default='learned_absolute',
-                       choices=['learned_absolute', 'rope', 'none'],
+                       choices=['learned_absolute', 'rope', 'relative', 'none'],
                        help='Position embedding type.')
+    group.add_argument('--relative-attention-num-buckets', type=int, default=32,
+                       help='Number of buckets for relative position embeddings.')
+    group.add_argument('--relative-attention-max-distance', type=int, default=128,
+                       help='Maximum distance for relative position embeddings calculation.')
     group.add_argument('--use-rotary-position-embeddings', action='store_true',
                        help='Use rotary positional embeddings or not. '
                        'Deprecated: use --position-embedding-type')
@@ -971,7 +1101,9 @@ def _add_network_size_args(parser):
     group.add_argument('--rotary-seq-len-interpolation-factor', type=int, default=None,
                        help='Sequence length interpolation factor for rotary embeddings.')
     group.add_argument('--use-rope-scaling', action='store_true',
-                       help='Apply rope scaling as used in llama3.1')
+                       help='Apply rope scaling as used in llama3.x')
+    group.add_argument('--rope-scaling-factor', type=float, default=8.0,
+                       help='Rope scaling factor in llama3.x models')
     group.add_argument('--no-position-embedding',
                        action='store_false',
                        help='Disable position embedding. Deprecated: use --position-embedding-type',
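For orientation, the Llama 3.x RoPE scaling that --use-rope-scaling/--rope-scaling-factor refer to rescales the rotary inverse frequencies band by band. The sketch below follows the published Llama 3.1 recipe (the low/high frequency factors and original context length are the published reference values, not read from this diff) and is not Megatron's implementation:

    import math

    def scale_rope_frequencies(inv_freq, factor=8.0, low_freq_factor=1.0,
                               high_freq_factor=4.0, original_max_pos=8192):
        low_wavelen = original_max_pos / low_freq_factor
        high_wavelen = original_max_pos / high_freq_factor
        scaled = []
        for f in inv_freq:
            wavelen = 2 * math.pi / f
            if wavelen < high_wavelen:        # high-frequency band: unchanged
                scaled.append(f)
            elif wavelen > low_wavelen:       # low-frequency band: stretched by `factor`
                scaled.append(f / factor)
            else:                             # smooth interpolation between the two bands
                smooth = (original_max_pos / wavelen - low_freq_factor) / \
                         (high_freq_factor - low_freq_factor)
                scaled.append((1 - smooth) * f / factor + smooth * f)
        return scaled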
@@ -1059,6 +1191,9 @@ def _add_ft_package_args(parser):
     group.add_argument('--enable-ft-package', action='store_true',
                        help='If set, Fault Tolerance package is enabled. '
                        'Note: This feature is for Nvidia internal use only.')
+    group.add_argument('--calc-ft-timeouts', action='store_true',
+                       help='If set, FT package will try to automatically compute the timeouts. '
+                       'Note: This feature is for Nvidia internal use only.')
     return parser
@@ -1227,6 +1362,9 @@ def _add_training_args(parser):
     group.add_argument('--check-for-spiky-loss', action='store_true',
                        help='Check for spiky loss',
                        dest='check_for_spiky_loss')
+    group.add_argument('--check-for-large-grads', action='store_true',
+                       help='Check for unexpectedly large grads',
+                       dest='check_for_large_grads')
     group.add_argument('--distribute-saved-activations',
                        action='store_true',
                        help='If set, distribute recomputed activations '
@@ -1259,17 +1397,19 @@ def _add_training_args(parser):
                        help='Global step to start profiling.')
     group.add_argument('--profile-step-end', type=int, default=12,
                        help='Global step to stop profiling.')
+    group.add_argument('--iterations-to-skip', nargs='+', type=int, default=[],
+                       help='List of iterations to skip, empty by default.')
+    group.add_argument('--result-rejected-tracker-filename', type=str, default=None,
+                       help='Optional name of file tracking `result_rejected` events.')
+    group.add_argument('--disable-gloo-process-groups', action='store_false',
+                       dest='enable_gloo_process_groups',
+                       help='Disables creation and usage of Gloo process groups.')
     group.add_argument('--use-pytorch-profiler', action='store_true',
                        help='Use the built-in pytorch profiler. '
                        'Useful if you wish to view profiles in tensorboard.',
                        dest='use_pytorch_profiler')
-    group.add_argument('--use-hip-profiler', action='store_true',
-                       help='Use HIP PROFILER',
-                       dest='use_hip_profiler')
     group.add_argument('--profile-ranks', nargs='+', type=int, default=[0],
                        help='Global ranks to profile.')
-    group.add_argument('--profile-dir', type=str, default="./",
-                       help='profile dir to save.')
     group.add_argument('--record-memory-history', action="store_true", default=False,
                        help='Record memory history in last rank.')
     group.add_argument('--memory-snapshot-path', type=str, default="snapshot.pickle",
@@ -1363,11 +1503,9 @@ def _add_training_args(parser):
     group.add_argument('--cross-entropy-loss-fusion', action='store_true',
                        help='Enabled fusion of cross entropy loss calculation.',
                        dest='cross_entropy_loss_fusion')
-    group.add_argument('--use-flash-attn-cutlass', action='store_true',
+    group.add_argument('--use-flash-attn', action='store_true',
                        help='use FlashAttention implementation of attention. '
                        'https://arxiv.org/abs/2205.14135')
-    group.add_argument('--use-flash-attn-triton', action='store_true',
-                       help='use FlashAttention implementation of attention using Triton.')
     group.add_argument('--disable-bias-linear', action='store_false',
                        help='Disable bias in the linear layers',
                        dest='add_bias_linear')
@@ -1377,6 +1515,18 @@ def _add_training_args(parser):
     group.add_argument('--optimizer', type=str, default='adam',
                        choices=['adam', 'sgd'],
                        help='Optimizer function')
+    group.add_argument('--optimizer-cpu-offload', action='store_true',
+                       help='Offload optimizer state to CPU')
+    group.add_argument('--optimizer-offload-fraction', type=float, default=1.0,
+                       help='Ratio of optimizer state to offload to CPU')
+    group.add_argument('--use-torch-optimizer-for-cpu-offload', action='store_true',
+                       help="Use torch.optim.Optimizer instead of Megatron's optimizer in optimizer cpu offload mode.")
+    group.add_argument('--overlap-cpu-optimizer-d2h-h2d', action='store_true', default=False,
+                       help='Overlap CPU optimizer step, gradients D2H and updated parameters H2D.')
+    group.add_argument('--no-pin-cpu-grads', action='store_false', dest='pin_cpu_grads',
+                       help='Disable pinning of CPU memory for gradients.')
+    group.add_argument('--no-pin-cpu-params', action='store_false', dest='pin_cpu_params',
+                       help='Disable pinning of CPU memory for parameters.')
     group.add_argument('--dataloader-type', type=str, default=None,
                        choices=['single', 'cyclic', 'external'],
                        help='Single pass vs multiple pass data loader')
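Back-of-envelope effect of --optimizer-offload-fraction for an Adam-style optimizer (two fp32 moment tensors, 8 bytes per parameter). The model size is an assumed example, and this ignores the fp32 main params the distributed optimizer may also keep:

    num_params = 7e9                 # assumed 7B-parameter model
    offload_fraction = 0.5           # --optimizer-offload-fraction 0.5
    moment_bytes_per_param = 8       # exp_avg + exp_avg_sq in fp32
    offloaded_bytes = offload_fraction * moment_bytes_per_param * num_params
    print(f"~{offloaded_bytes / 2**30:.0f} GiB of optimizer moments moved to host memory")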
@@ -1425,6 +1575,10 @@ def _add_training_args(parser):
     group.add_argument('--disable-tp-comm-split-rs', action='store_false',
                        help='Disables the Reduce-Scatter overlap with fprop GEMM.',
                        dest='tp_comm_split_rs')
+    group.add_argument('--pipeline-model-parallel-comm-backend', type=str, default=None,
+                       choices=['nccl', 'ucc'],
+                       help='Select a communicator backend for pipeline parallel communication. '
+                       'If None, the default backend will be used.')
     return parser
@@ -1549,8 +1703,7 @@ def _add_checkpointing_args(parser):
                        choices=['global', 'local', 'in_memory', None],
                        help='Type of non-persistent model checkpoints. '
                        '"global" - Saved as a standard checkpoint (e.g., on Lustre) with old checkpoints being removed. '
-                       '"local" - [TBD] Each rank saves a portion of the checkpoint locally (e.g., on SSD/ramdisk). '
-                       '"in_memory" - [TBD] A special kind of local checkpoint that avoids serialization. '
+                       '"local" - Each rank saves a portion of the checkpoint locally (e.g., on SSD/ramdisk). '
                        'None - No non-persistent checkpointing (default option).')
     group.add_argument('--non-persistent-global-ckpt-dir', type=str, default=None,
                        help='Directory containing global non-persistent model checkpoints.')
@@ -1586,6 +1739,9 @@ def _add_checkpointing_args(parser):
     group.add_argument('--use-dist-ckpt', action='store_true',
                        dest='use_dist_ckpt_deprecated',
                        help='Deprecated: see --ckpt-format.')
+    group.add_argument('--use-persistent-ckpt-worker', action='store_true',
+                       help='Enables a persitent checkpoint worker for async save')
     group.add_argument('--auto-detect-ckpt-format', action='store_true',
                        help='Determine if the checkpoint format is in legacy or distributed format.'
                        ' If False, expects distributed checkpoint iff args.ckpt_format != "torch".'
@@ -1641,6 +1797,8 @@ def _add_mixed_precision_args(parser):
                        help='Run model in fp16 mode.')
     group.add_argument('--bf16', action='store_true',
                        help='Run model in bfloat16 mode.')
+    group.add_argument('--grad-reduce-in-bf16', action='store_true',
+                       help='Reduce gradients in bfloat16.')
     group.add_argument('--loss-scale', type=float, default=None,
                        help='Static loss scaling, positive power of 2 '
                        'values can improve fp16 convergence. If None, dynamic'
@@ -1699,6 +1857,8 @@ def _add_distributed_args(parser):
                        '--tensor-model-parallel-size instead.')
     group.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None,
                        help='Number of layers per virtual pipeline stage')
+    group.add_argument('--num-virtual-stages-per-pipeline-rank', type=int, default=None,
+                       help='Number of virtual pipeline stages per pipeline parallelism rank')
     group.add_argument('--microbatch-group-size-per-virtual-pipeline-stage', type=int, default=None,
                        help='Number of contiguous microbatches per virtual pipeline stage',
                        dest='microbatch_group_size_per_vp_stage')
@@ -1726,8 +1886,15 @@ def _add_distributed_args(parser):
                        help='If not set, all PP stages will launch gradient reduces simultaneously. '
                        'Otherwise, each PP stage will independently launch as needed.',
                        dest='align_grad_reduce')
+    group.add_argument('--ddp-num-buckets', type=int, default=None,
+                       help='Number of buckets for data-parallel communication')
     group.add_argument('--ddp-bucket-size', type=int, default=None,
                        help='Bucket size for data-parallel communication')
+    group.add_argument('--ddp-pad-buckets-for-high-nccl-busbw', action='store_true',
+                       default=False, help='If set, make sure the bucket size is divisible by a large power '
+                       'of 2 (2^16) to ensure NCCL collectives have high bus bandwidth at large DP counts, '
+                       'since NCCL message size (which for ring algorithms is bucket_size / dp_size) '
+                       'apparently needs to be divisible by a power of 2 for high busbw.')
     group.add_argument('--ddp-average-in-collective', action='store_true',
                        default=False, help='If set, average directly in data-parallel communication collective.')
     group.add_argument('--overlap-param-gather', action='store_true',
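The padding described by --ddp-pad-buckets-for-high-nccl-busbw is simple rounding: each gradient bucket is grown to the next multiple of 2^16 elements so the per-rank NCCL message (roughly bucket_size / dp_size for ring algorithms) stays divisible by a large power of two. Illustrative arithmetic only:

    import math

    def pad_bucket(bucket_size, multiple=2**16):
        return int(math.ceil(bucket_size / multiple)) * multiple

    padded = pad_bucket(40_000_000)      # -> 40_042_496 elements
    message_size = padded // 64          # per-rank ring message at DP=64: 625_664 elements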
@@ -1745,21 +1912,33 @@ def _add_distributed_args(parser):
                        default=False, help='If set, use custom-built ring exchange '
                        'for p2p communications. Note that this option will require '
                        'a custom built image that support ring-exchange p2p.')
-    group.add_argument('--local-rank', type=int, default=int(os.getenv('LOCAL_RANK', '0')),
-                       help='local rank passed from distributed launcher.')
     group.add_argument('--lazy-mpu-init', type=bool, required=False,
                        help='If set to True, initialize_megatron() '
                        'skips DDP initialization and returns function to '
-                       'complete it instead.Also turns on '
+                       'complete it instead. Also turns on '
                        '--use-cpu-initialization flag. This is for '
                        'external DDP manager.' )
-    group.add_argument('--standalone-embedding-stage', action='store_true',
-                       default=False, help='If set, *input* embedding layer '
-                       'is placed on its own pipeline stage, without any '
-                       'transformer layers. (For T5, this flag currently only '
-                       'affects the encoder embedding.)')
+    group.add_argument('--account-for-embedding-in-pipeline-split', action='store_true',
+                       default=False, help='If set, *input* embedding layer will be treated as a standard transformer'
+                       'layer in the context of partition and placement for pipeline parallelism.')
+    group.add_argument('--account-for-loss-in-pipeline-split', action='store_true',
+                       default=False, help='If set, loss layer will be treated as a standard transformer'
+                       'layer in the context of partition and placement for pipeline parallelism.')
     group.add_argument('--use-distributed-optimizer', action='store_true',
                        help='Use distributed optimizer.')
+    group.add_argument('--use-custom-fsdp', action='store_true',
+                       help='Use the Megatron FSDP code path in DDP.')
+    group.add_argument('--init-model-with-meta-device', action='store_true')
+    group.add_argument('--data-parallel-sharding-strategy', type=str, default='no_shard',
+                       choices=['no_shard', 'optim', 'optim_grads', 'optim_grads_params'],
+                       help='Sharding strategy of data parallelism.')
+    group.add_argument('--no-gradient-reduce-div-fusion', action='store_false', dest='gradient_reduce_div_fusion',
+                       help='If not set, fuse the division in gradient reduce.')
+    group.add_argument('--suggested-communication-unit-size', type=int, default=400_000_000,
+                       help='When batch communication is needed across multiple buckets, '
+                       'this environment variable guides the size of communication unit size.')
+    group.add_argument('--keep-fp8-transpose-cache-when-using-custom-fsdp', action='store_true',
+                       help='If set, keep the fp8 transpose cache when using custom FSDP.')
     group.add_argument('--num-distributed-optimizer-instances', type=int, default=1,
                        help='Number of Distributed Optimizer copies across Data Parallel domain.')
     group.add_argument('--use-torch-fsdp2', action='store_true',
@@ -1786,10 +1965,21 @@
                        'setting `min_ctas`, `max_ctas`, and `cga_cluster_size`.')
     group.add_argument('--use-tp-pp-dp-mapping', action='store_true', default=False,
                        help='If set, distributed ranks initialize order is changed '
-                       'from tp-dp-pp to tp-pp-dp. Make sure EP and CP aren\'t used '
-                       'with this option enabled')
+                       'from tp-cp-ep-dp-pp to tp-cp-ep-pp-dp.')
+    group.add_argument('--replication', action='store_true', default=False,
+                       help="If set, replication of local checkpoints is enabled. "
+                       "Needs to be enabled on all ranks.")
+    group.add_argument('--replication-jump', default=None, type=int,
+                       help="Specifies `J`, the spacing between ranks storing replicas of a given rank's data. "
+                       "Replicas for rank `n` may be on ranks `n+J`, `n+2J`, ..., or `n-J`, `n-2J`, etc. "
+                       "This flag has an effect only if --replication is used. "
+                       "and must be consistent across all ranks.")
+    group.add_argument('--replication-factor', default=2, type=int,
+                       help="Number of machines storing the replica of a given rank's data.")
     group.add_argument('--rank', default=-1, type=int,
                        help='node rank for distributed training')
+    group.add_argument('--local-rank', type=int, default=int(os.getenv('LOCAL_RANK', '0')),
+                       help='local rank passed from distributed launcher.')
     group.add_argument('--world-size', type=int, default=8,
                        help='number of nodes for distributed training')
     group.add_argument('--dist-url',
@@ -1834,8 +2024,6 @@ def _add_tokenizer_args(parser):
                        'GPTSentencePieceTokenizer',
                        'HuggingFaceTokenizer',
                        'Llama2Tokenizer',
-                       'Llama3Tokenizer',
-                       'QwenTokenizer',
                        'TikTokenizer',
                        'MultimodalTokenizer',
                        'NullTokenizer'],
@@ -1862,11 +2050,6 @@ def _add_data_args(parser):
                        '(3) a list of prefixes e.g. prefix1 prefix2. '
                        'For (3), weights are inferred from the lengths of the contributing datasets. '
                        'This argument is exclusive to the other independent --*-data-path arguments.')
-    group.add_argument('--renormalize-blend-weights', action='store_true',
-                       help='Renormalize the blend weights to account for the mid-level dataset '
-                       'oversampling done to ensure fulfillment of the requested number of '
-                       'samples. Use this option if prompted. Defaults to False for backward '
-                       'comparability in the data sample order.')
     group.add_argument('--split', type=str, default=None,
                        help='Comma-separated list of proportions for training,'
                        ' validation, and test split. For example the split '
@@ -2078,6 +2261,7 @@ def _add_vision_args(parser):
 def _add_moe_args(parser):
     group = parser.add_argument_group(title="moe")
+    # General arguments
     group.add_argument('--expert-model-parallel-size', type=int, default=1,
                        help='Degree of expert model parallelism.')
     group.add_argument('--expert-tensor-parallel-size', type=int, default=None,
@@ -2103,16 +2287,39 @@
                        help='Enable overlapping between shared expert computations and dispatcher communications. '
                        'Without this, the shared epxerts execute after the routed experts. '
                        'Only effective when moe-shared-expert-intermediate-size is set.')
+    group.add_argument('--moe-grouped-gemm', action='store_true',
+                       help='When there are multiple experts per rank, launch multiple local GEMM kernels in multiple streams to improve the utilization and performance with GroupedLinear in TransformerEngine.')
+    # Router arguments
     group.add_argument('--moe-router-load-balancing-type', type=str,
-                       choices=['aux_loss', 'sinkhorn', 'none'],
+                       choices=['aux_loss', 'seq_aux_loss', 'sinkhorn', 'none'],
                        default='aux_loss',
-                       help='Determines the load balancing strategy for the router. "aux_loss" corresponds to the load balancing loss used in GShard and SwitchTransformer, "sinkhorn" corresponds to the balancing algorithm used in S-BASE, and "none" implies no load balancing. The default is "aux_loss".')
+                       help='Determines the load balancing strategy for the router. "aux_loss" corresponds to the load balancing loss used in GShard and SwitchTransformer; "seq_aux_loss" corresponds to the load balancing loss used in DeepSeekV2, which computes the loss for each individual sample; "sinkhorn" corresponds to the balancing algorithm used in S-BASE, and "none" implies no load balancing. The default is "aux_loss".')
+    group.add_argument('--moe-router-score-function', type=str,
+                       choices=['softmax', 'sigmoid'],
+                       default='softmax',
+                       help='Score function for MoE TopK routing. Can be "softmax" or "sigmoid".')
     group.add_argument('--moe-router-topk', type=int, default=2,
                        help='Number of experts to route to for each token. The default is 2.')
     group.add_argument('--moe-router-pre-softmax', action='store_true',
                        help='Enable pre-softmax routing for MoE, which means softmax is before the top-k selection. By default, softmax is done after top-k.')
-    group.add_argument('--moe-grouped-gemm', action='store_true',
-                       help='When there are multiple experts per rank, launch multiple local GEMM kernels in multiple streams to improve the utilization and performance with GroupedLinear in TransformerEngine.')
+    group.add_argument('--moe-router-num-groups', type=int, default=None,
+                       help='Number of groups to divide experts into for group-limited routing. When using group-limited routing: 1) Experts are divided into equal-sized groups, 2) For each token, a subset of groups are selected based on routing scores (sum of top-2 expert scores within each group), 3) From these selected groups, moe_router_topk experts are chosen.'
+                       'Two common use cases: 1) Device-limited routing: Set equal to expert parallel size (EP) to limit each token to experts on a subset of devices (See DeepSeek-V2: https://arxiv.org/pdf/2405.04434) 2) Node-limited routing: Set equal to number of nodes in EP group to limit each token to experts on a subset of nodes (See DeepSeek-V3: https://arxiv.org/pdf/2412.19437)')
+    group.add_argument('--moe-router-group-topk', type=int, default=None,
+                       help='Number of selected groups for group-limited routing.')
+    group.add_argument('--moe-router-topk-scaling-factor', type=float, default=None,
+                       help='Scaling factor for routing score in top-k selection, only works when --moe-router-pre-softmax enabled. Defaults to None, which means no scaling.')
+    group.add_argument('--moe-router-enable-expert-bias', action='store_true',
+                       help='TopK routing with dynamic expert bias in the aux-loss-free load balancing strategy. '
+                       'The routing decision is based on the sum of the routing scores and the expert bias. '
+                       'See https://arxiv.org/abs/2408.15664 for details.')
+    group.add_argument('--moe-router-bias-update-rate', type=float, default=1e-3,
+                       help='Expert bias update rate in the aux-loss-free load balancing strategy. '
+                       'The expert bias is updated based on the number of assigned tokens to each expert in a global batch, '
+                       'where the bias is increased for the experts with less assigned tokens and decreased for the experts with more assigned tokens. '
+                       'The default value 1e-3 is same as that used in DeepSeekV3.')
+    group.add_argument('--moe-use-legacy-grouped-gemm', action='store_true',
+                       help='Use legacy GroupedMLP rather than TEGroupedMLP. Note: The legacy one will be deprecated soon.')
     group.add_argument('--moe-aux-loss-coeff', type=float, default=0.0,
                        help='Scaling coefficient for the aux loss: a starting value of 1e-2 is recommended.')
     group.add_argument('--moe-z-loss-coeff', type=float, default=None,
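A compact sketch of the group-limited routing that --moe-router-num-groups/--moe-router-group-topk describe: groups are ranked by the sum of their top-2 expert scores, and the final top-k experts are chosen only inside the selected groups. This follows the help text above, not Megatron's actual router code:

    import torch

    def group_limited_topk(scores, num_groups, group_topk, topk):
        # scores: [num_tokens, num_experts]; assumes num_experts % num_groups == 0
        num_tokens, num_experts = scores.shape
        group_scores = scores.view(num_tokens, num_groups, -1).topk(2, dim=-1).values.sum(dim=-1)
        group_idx = group_scores.topk(group_topk, dim=-1).indices
        group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1.0)
        expert_mask = group_mask.unsqueeze(-1).expand(
            num_tokens, num_groups, num_experts // num_groups).reshape(num_tokens, num_experts)
        masked_scores = scores.masked_fill(expert_mask == 0, float('-inf'))
        return masked_scores.topk(topk, dim=-1)   # (values, expert indices)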
...@@ -2120,9 +2327,11 @@ def _add_moe_args(parser): ...@@ -2120,9 +2327,11 @@ def _add_moe_args(parser):
group.add_argument('--moe-input-jitter-eps', type=float, default=None, group.add_argument('--moe-input-jitter-eps', type=float, default=None,
help='Add noise to the input tensor by applying jitter with a specified epsilon value.') help='Add noise to the input tensor by applying jitter with a specified epsilon value.')
group.add_argument('--moe-token-dispatcher-type', type=str, group.add_argument('--moe-token-dispatcher-type', type=str,
choices=['allgather', 'alltoall', 'alltoall_seq'], choices=['allgather', 'alltoall', 'flex', 'alltoall_seq'],
default='allgather', default='allgather',
help="The type of token dispatcher to use. The default is 'allgather'. Options are 'allgather', 'alltoall' and 'alltoall_seq'. We recommend using 'alltoall' when applying expert parallelism. For more information, please refer to the documentation in core/moe/README.") help="The type of token dispatcher to use. The default is 'allgather'. Options are 'allgather', 'alltoall' and 'alltoall_seq'. We recommend using 'alltoall' when applying expert parallelism. For more information, please refer to the documentation in core/moe/README.")
group.add_argument('--moe-enable-deepep', action='store_true',
help='[Experimental] Enable DeepSeek/DeepEP for efficient token dispatching and combine in MoE models. Only works with flex token dispatcher by setting --moe-token-dispatcher-type=flex.')
group.add_argument('--moe-per-layer-logging', action='store_true', group.add_argument('--moe-per-layer-logging', action='store_true',
help='Enable per-layer logging for MoE, currently supports auxiliary loss and z loss.') help='Enable per-layer logging for MoE, currently supports auxiliary loss and z loss.')
# Token dropping arguments # Token dropping arguments
...@@ -2139,6 +2348,8 @@ def _add_moe_args(parser): ...@@ -2139,6 +2348,8 @@ def _add_moe_args(parser):
group.add_argument('--moe-use-upcycling', action='store_true', group.add_argument('--moe-use-upcycling', action='store_true',
help='Load a checkpoint of a dense model, convert it into an MoE model, and save the converted model to the path specified by --save. ' help='Load a checkpoint of a dense model, convert it into an MoE model, and save the converted model to the path specified by --save. '
'Upcycling is implemented on the top of distributed checkpointing, so it supports parallel modes different from the dense model.') 'Upcycling is implemented on the top of distributed checkpointing, so it supports parallel modes different from the dense model.')
group.add_argument('--moe-permute-fusion', action='store_true',
help='Fuse token rearrangement ops during token dispatching.')
return parser return parser
...@@ -2156,6 +2367,10 @@ def _add_mla_args(parser): ...@@ -2156,6 +2367,10 @@ def _add_mla_args(parser):
help="Dimension of the head in the V projection.") help="Dimension of the head in the V projection.")
group.add_argument('--rotary-scaling-factor', type=float, default=1.0, group.add_argument('--rotary-scaling-factor', type=float, default=1.0,
help="Rotary scaling factor for the rotary embeddings.") help="Rotary scaling factor for the rotary embeddings.")
group.add_argument('--mscale', type=float, default=1.0,
help="Mscale for YaRN RoPE in multi-latent attention.")
group.add_argument('--mscale-all-dim', type=float, default=1.0,
help="Mscale all dimensions for YaRN RoPE in multi-latent attention.")
return parser return parser
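As context for --mscale, --mscale-all-dim and --rotary-scaling-factor, a minimal sketch of the commonly published YaRN magnitude correction (an assumption based on the public YaRN recipe, not read from this diff; the exact expression Megatron applies inside MLA may differ):
```
import math

# Hypothetical helper: YaRN magnitude ("mscale") correction for scaled RoPE.
def yarn_mscale(scaling_factor: float, mscale: float = 1.0) -> float:
    if scaling_factor <= 1.0:
        return 1.0
    return 0.1 * mscale * math.log(scaling_factor) + 1.0

# e.g. correction for a 40x rotary scaling factor with --mscale-all-dim 0.707
print(yarn_mscale(40.0, mscale=0.707))
```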
...@@ -2185,4 +2400,18 @@ def _add_experimental_args(parser): ...@@ -2185,4 +2400,18 @@ def _add_experimental_args(parser):
'the overridden pattern') 'the overridden pattern')
group.add_argument('--yaml-cfg', type=str, default=None, group.add_argument('--yaml-cfg', type=str, default=None,
help = 'Config file to add additional arguments') help = 'Config file to add additional arguments')
# Args of precision-aware optimizer
group.add_argument('--use-precision-aware-optimizer', action='store_true',
help='Use the precision-aware optimizer in TransformerEngine, which allows '
'setting the main params and optimizer states to lower precision, such as '
'fp16 and fp8.')
group.add_argument('--main-grads-dtype', default='fp32', choices=['fp32', 'bf16'],
help='Dtype of main grads when enabling precision-aware-optimizer')
group.add_argument('--main-params-dtype', default='fp32', choices=['fp32', 'fp16'],
help='Dtype of main params when enabling precision-aware-optimizer')
group.add_argument('--exp-avg-dtype', default='fp32', choices=['fp32', 'fp16', 'fp8'],
help='Dtype of exp_avg when enabling precision-aware-optimizer')
group.add_argument('--exp-avg-sq-dtype', default='fp32', choices=['fp32', 'fp16', 'fp8'],
help='Dtype of exp_avg_sq when enabling precision-aware-optimizer')
return parser return parser
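A minimal sketch of how the dtype strings accepted by the precision-aware-optimizer flags could map onto torch dtypes (hypothetical helper, not part of this file; 'fp8' has no plain torch dtype and is handled by TransformerEngine, so it is only tagged here):
```
import torch

_DTYPE_MAP = {'fp32': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}

def resolve_optimizer_dtypes(main_grads='fp32', main_params='fp32',
                             exp_avg='fp32', exp_avg_sq='fp32'):
    # 'fp8' is tagged rather than converted, since it is TE-internal.
    to_dtype = lambda name: 'te-fp8' if name == 'fp8' else _DTYPE_MAP[name]
    return {
        'main_grads': to_dtype(main_grads),
        'main_params': to_dtype(main_params),
        'exp_avg': to_dtype(exp_avg),
        'exp_avg_sq': to_dtype(exp_avg_sq),
    }

# e.g. bf16 main grads with fp8 Adam moments:
print(resolve_optimizer_dtypes(main_grads='bf16', exp_avg='fp8', exp_avg_sq='fp8'))
```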
...@@ -13,11 +13,19 @@ from megatron.training.utils import print_rank_0 ...@@ -13,11 +13,19 @@ from megatron.training.utils import print_rank_0
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Singleton manager of async calls # Singleton manager of async calls
# The default is `TemporalAsyncCaller`
_async_calls_queue = AsyncCallsQueue() _async_calls_queue = AsyncCallsQueue()
def init_persistent_async_worker():
global _async_calls_queue
# Recreate the async_calls_queue for the persistent worker.
# This duplicate step is kept for backward compatibility.
_async_calls_queue = AsyncCallsQueue(persistent=True)
def schedule_async_save(async_request: AsyncRequest): def schedule_async_save(async_request: AsyncRequest):
""" Schedule the async save request. """Schedule the async save request.
Args: Args:
async_request (AsyncRequest): the async save request. async_request (AsyncRequest): the async save request.
...@@ -25,19 +33,33 @@ def schedule_async_save(async_request: AsyncRequest): ...@@ -25,19 +33,33 @@ def schedule_async_save(async_request: AsyncRequest):
_async_calls_queue.schedule_async_request(async_request) _async_calls_queue.schedule_async_request(async_request)
def maybe_finalize_async_save(blocking: bool = False): def maybe_finalize_async_save(blocking: bool = False, terminate=False):
""" Finalizes active async save calls. """Finalizes active async save calls.
Args: Args:
blocking (bool, optional): if True, will wait until all active requests blocking (bool, optional): if True, will wait until all active requests
are done. Otherwise, finalizes only the async request that already are done. Otherwise, finalizes only the async request that already
finished. Defaults to False. finished. Defaults to False.
terminate (bool, optional): if True, the asynchronous queue will
be closed as the last action of this function.
""" """
args = get_args() args = get_args()
if not args.async_save: if not args.async_save:
return return
if blocking and _async_calls_queue.get_num_unfinalized_calls() > 0: if blocking and not is_empty_async_queue():
print_rank_0('Unfinalized async checkpoint saves. Finalizing them synchronously now.') print_rank_0('Unfinalized async checkpoint saves. Finalizing them synchronously now.')
_async_calls_queue.maybe_finalize_async_calls(blocking) _async_calls_queue.maybe_finalize_async_calls(blocking, no_dist=False)
if terminate:
_async_calls_queue.close()
def is_empty_async_queue() -> bool:
"""Check if async calls queue is empty. This result is consistent across ranks.
Returns:
bool: True if there is any ongoing async call.
"""
return _async_calls_queue.get_num_unfinalized_calls() == 0
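Taken together, these helpers form a small producer/consumer API around the async queue. A minimal usage sketch (make_async_ckpt_request is a hypothetical callable; in Megatron the scheduling is actually done inside save_checkpoint):
```
from megatron.training.async_utils import (
    schedule_async_save, maybe_finalize_async_save, is_empty_async_queue)

def training_loop(num_iters, make_async_ckpt_request, save_interval=100):
    for it in range(num_iters):
        # Finalize any already-finished async saves without blocking the step.
        maybe_finalize_async_save(blocking=False)
        # ... forward/backward/optimizer step ...
        if it % save_interval == 0 and is_empty_async_queue():
            schedule_async_save(make_async_ckpt_request(it))
    # Drain all outstanding saves and close the (possibly persistent) worker.
    maybe_finalize_async_save(blocking=True, terminate=True)
```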
...@@ -20,21 +20,20 @@ import torch ...@@ -20,21 +20,20 @@ import torch
from megatron.core import mpu, tensor_parallel, dist_checkpointing from megatron.core import mpu, tensor_parallel, dist_checkpointing
from megatron.core.dist_checkpointing.mapping import ShardedObject from megatron.core.dist_checkpointing.mapping import ShardedObject
from megatron.core.dist_checkpointing.serialization import get_default_load_sharded_strategy from megatron.core.dist_checkpointing.serialization import get_default_load_sharded_strategy
from megatron.core.dist_checkpointing.state_dict_transformation import (
prepare_state_dict_for_save,
recreate_state_dict_after_load,
)
from megatron.core.dist_checkpointing.strategies.fully_parallel import \ from megatron.core.dist_checkpointing.strategies.fully_parallel import \
FullyParallelSaveStrategyWrapper, FullyParallelLoadStrategyWrapper FullyParallelSaveStrategyWrapper, FullyParallelLoadStrategyWrapper
from megatron.core.num_microbatches_calculator import update_num_microbatches from megatron.core.num_microbatches_calculator import update_num_microbatches
from megatron.core.utils import is_float8tensor from megatron.core.fp8_utils import is_float8tensor
from megatron.core.rerun_state_machine import get_rerun_state_machine from megatron.core.rerun_state_machine import get_rerun_state_machine
from .async_utils import schedule_async_save from .async_utils import schedule_async_save, is_empty_async_queue
from .global_vars import get_args, get_one_logger from .global_vars import get_args, get_one_logger
from .utils import unwrap_model, print_rank_0, append_to_progress_log, is_last_rank from .utils import unwrap_model, print_rank_0, append_to_progress_log, is_last_rank
from ..core.dist_checkpointing.serialization import \ from ..core.dist_checkpointing.serialization import \
get_default_save_sharded_strategy get_default_save_sharded_strategy
from .one_logger_utils import on_save_checkpoint_start, on_save_checkpoint_success from .one_logger_utils import on_save_checkpoint_start, on_save_checkpoint_success
from . import wandb_utils
from . import ft_integration
# [ModelOpt]: Import # [ModelOpt]: Import
try: try:
...@@ -305,7 +304,7 @@ class CheckpointType(Enum): ...@@ -305,7 +304,7 @@ class CheckpointType(Enum):
def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floating_point_operations_so_far, def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floating_point_operations_so_far,
checkpointing_context=None, pipeline_rank=None, expert_rank=None, tensor_rank=None, pipeline_parallel=None, expert_parallel=None, non_persistent_ckpt=False, checkpointing_context=None, pipeline_rank=None, expert_rank=None, tensor_rank=None, pipeline_parallel=None, expert_parallel=None, non_persistent_ckpt=False,
train_data_iterator=None, ft_client=None, preprocess_common_state_dict_fn = None): train_data_iterator=None, preprocess_common_state_dict_fn = None):
"""Save a model, optimizer and optionally dataloader checkpoint. """Save a model, optimizer and optionally dataloader checkpoint.
Checkpointing context is used to persist some checkpointing state Checkpointing context is used to persist some checkpointing state
...@@ -315,8 +314,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati ...@@ -315,8 +314,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati
the checkpoint will be saved with special functionality for removing old checkpoints. the checkpoint will be saved with special functionality for removing old checkpoints.
There are several types of non-persistent checkpoints: There are several types of non-persistent checkpoints:
"global" - Saved as a standard checkpoint (e.g., on Lustre) with old checkpoints being removed. "global" - Saved as a standard checkpoint (e.g., on Lustre) with old checkpoints being removed.
"local" - [TBD] Each rank saves a portion of the checkpoint locally (e.g., on SSD/ramdisk). "local" - Each rank saves a portion of the checkpoint locally (e.g., on SSD/ramdisk).
"in_memory" - [TBD] A special kind of local checkpoint that avoids serialization.
Dataloader checkpoint is only saved if the dataloader supports it. Currently this applies only Dataloader checkpoint is only saved if the dataloader supports it. Currently this applies only
to the Megatron Energon dataloader (multimodal) and not the built-in Megatron dataloader (text-only). to the Megatron Energon dataloader (multimodal) and not the built-in Megatron dataloader (text-only).
...@@ -324,9 +322,15 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati ...@@ -324,9 +322,15 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati
start_ckpt = time() start_ckpt = time()
args = get_args() args = get_args()
if args.async_save and not is_empty_async_queue():
print_rank_0('WARNING: Starting a checkpoint save before previous has finished. Consider increasing the checkpoint interval.')
# Prepare E2E metrics at start of save checkpoint # Prepare E2E metrics at start of save checkpoint
productive_metrics = on_save_checkpoint_start(args.async_save) productive_metrics = on_save_checkpoint_start(args.async_save)
# Monitor for the checkpointing timeout (no-op if FT is not enabled)
ft_integration.on_checkpointing_start()
# Only rank zero of the data parallel writes to the disk. # Only rank zero of the data parallel writes to the disk.
model = unwrap_model(model) model = unwrap_model(model)
...@@ -347,7 +351,6 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati ...@@ -347,7 +351,6 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati
save_dir, leave_ckpt_num=1, do_async=args.async_save save_dir, leave_ckpt_num=1, do_async=args.async_save
) )
elif args.non_persistent_ckpt_type == 'local': elif args.non_persistent_ckpt_type == 'local':
raise RuntimeError('LocalCheckpointManagers are not yet integrated')
ckpt_type = CheckpointType.LOCAL ckpt_type = CheckpointType.LOCAL
save_dir = checkpointing_context['local_checkpoint_manager'].local_ckpt_dir save_dir = checkpointing_context['local_checkpoint_manager'].local_ckpt_dir
else: else:
...@@ -361,13 +364,19 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati ...@@ -361,13 +364,19 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati
# Collect rng state across data parallel ranks. # Collect rng state across data parallel ranks.
rng_state = get_rng_state(ckpt_type != CheckpointType.LEGACY) rng_state = get_rng_state(ckpt_type != CheckpointType.LEGACY)
# Collect rerun state across all ranks
rerun_state_machine = get_rerun_state_machine()
rerun_state = rerun_state_machine.state_dict(
data_iterator=train_data_iterator, use_dist_ckpt=ckpt_type != CheckpointType.LEGACY
)
# Checkpoint name. # Checkpoint name.
return_base_dir = (ckpt_type != CheckpointType.LEGACY) return_base_dir = (ckpt_type != CheckpointType.LEGACY)
checkpoint_name = get_checkpoint_name(save_dir, iteration, release=False, pipeline_parallel=pipeline_parallel, checkpoint_name = get_checkpoint_name(save_dir, iteration, release=False, pipeline_parallel=pipeline_parallel,
tensor_rank=tensor_rank, pipeline_rank=pipeline_rank, expert_parallel=expert_parallel, expert_rank=expert_rank, return_base_dir=return_base_dir) tensor_rank=tensor_rank, pipeline_rank=pipeline_rank, expert_parallel=expert_parallel, expert_rank=expert_rank, return_base_dir=return_base_dir)
# Save dataloader state if the dataloader supports it (currently only Megatron Energon). # Save dataloader state if the dataloader supports it (currently only Megatron Energon).
save_dataloader_state(train_data_iterator, iteration, getattr(args, "dataloader_save", None)) maybe_save_dataloader_state(train_data_iterator, iteration, getattr(args, "dataloader_save", None))
# Save distributed optimizer's custom parameter state. # Save distributed optimizer's custom parameter state.
if ( if (
...@@ -379,7 +388,8 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati ...@@ -379,7 +388,8 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati
optim_checkpoint_name = \ optim_checkpoint_name = \
get_distributed_optimizer_checkpoint_name(checkpoint_name) get_distributed_optimizer_checkpoint_name(checkpoint_name)
ensure_directory_exists(optim_checkpoint_name) ensure_directory_exists(optim_checkpoint_name)
optimizer.save_parameter_state(optim_checkpoint_name) if not optimizer.is_stub_optimizer:
optimizer.save_parameter_state(optim_checkpoint_name)
async_save_request = None async_save_request = None
if args.async_save: if args.async_save:
...@@ -409,11 +419,9 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati ...@@ -409,11 +419,9 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati
use_dist_ckpt=ckpt_type != CheckpointType.LEGACY, use_dist_ckpt=ckpt_type != CheckpointType.LEGACY,
iteration=iteration, iteration=iteration,
optim_sd_kwargs=optim_sd_kwargs, optim_sd_kwargs=optim_sd_kwargs,
train_data_iterator=train_data_iterator, rerun_state=rerun_state,
) )
if args.enable_ft_package and ft_client is not None:
state_dict["ft_state"] = ft_client.state_dict()
state_dict['num_floating_point_operations_so_far'] = num_floating_point_operations_so_far state_dict['num_floating_point_operations_so_far'] = num_floating_point_operations_so_far
if ckpt_type == CheckpointType.GLOBAL: if ckpt_type == CheckpointType.GLOBAL:
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
...@@ -428,6 +436,14 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati ...@@ -428,6 +436,14 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati
save_strategy = get_default_save_sharded_strategy(args.ckpt_format) save_strategy = get_default_save_sharded_strategy(args.ckpt_format)
if args.ckpt_assume_constant_structure and args.ckpt_format == 'torch_dist': if args.ckpt_assume_constant_structure and args.ckpt_format == 'torch_dist':
save_strategy.use_cached_ckpt_structure = args.ckpt_assume_constant_structure save_strategy.use_cached_ckpt_structure = args.ckpt_assume_constant_structure
if checkpointing_context is not None and 'load_strategy' in checkpointing_context:
cached_global_metadata = getattr(checkpointing_context['load_strategy'], 'cached_global_metadata', None)
if cached_global_metadata is not None:
logger.debug("Plugging in the read metadata from the load strategy...")
save_strategy.cached_global_metadata = cached_global_metadata
else:
logger.debug("Failed to plug in the read metadata from the load strategy...")
if args.ckpt_fully_parallel_save: if args.ckpt_fully_parallel_save:
save_strategy = FullyParallelSaveStrategyWrapper(save_strategy, mpu.get_data_parallel_group(with_context_parallel=True), save_strategy = FullyParallelSaveStrategyWrapper(save_strategy, mpu.get_data_parallel_group(with_context_parallel=True),
args.ckpt_assume_constant_structure) args.ckpt_assume_constant_structure)
...@@ -446,26 +462,44 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati ...@@ -446,26 +462,44 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati
else: else:
# [ModelOpt]: Inject modelopt_state into state_dict # [ModelOpt]: Inject modelopt_state into state_dict
if has_nvidia_modelopt: if has_nvidia_modelopt:
save_modelopt_state(model, state_dict) if ckpt_type == CheckpointType.LOCAL:
print_rank_0('WARNING: Local checkpointing does not support nvidia_modelopt.')
else:
save_modelopt_state(model, state_dict)
end_ckpt = time()
logger.debug(f"rank: {rank}, takes {end_ckpt - start_ckpt} to prepare state dict for ckpt ")
if ckpt_type == CheckpointType.LOCAL: if ckpt_type == CheckpointType.LOCAL:
state_dict_for_save = prepare_state_dict_for_save( try:
state_dict, algo=args.non_persistent_local_ckpt_algo from megatron.core.dist_checkpointing.tensor_aware_state_dict import MCoreTensorAwareStateDict
except ModuleNotFoundError:
raise RuntimeError("The 'nvidia_resiliency_ext' module is required for local "
"checkpointing but was not found. Please ensure it is installed.")
algo = args.non_persistent_local_ckpt_algo
cached_metadata = None
if args.ckpt_assume_constant_structure and 'local_checkpoint_cache' in checkpointing_context:
cached_metadata = checkpointing_context['local_checkpoint_cache']
state_dict_for_save, cacheable_metadata = MCoreTensorAwareStateDict.from_state_dict(
state_dict, algo=algo, cached_metadata=cached_metadata,
parallelization_group=mpu.get_data_parallel_group(with_context_parallel=True)
) )
async_save_request = checkpointing_context['local_checkpoint_manager'].save( async_save_request = checkpointing_context['local_checkpoint_manager'].save(
state_dict_for_save, iteration, is_async=bool(args.async_save) state_dict_for_save, iteration, is_async=bool(args.async_save)
) )
checkpointing_context['local_checkpoint_cache'] = cacheable_metadata
else: else:
assert ckpt_type == CheckpointType.LEGACY assert ckpt_type == CheckpointType.LEGACY
# Save. # Save.
ensure_directory_exists(checkpoint_name) ensure_directory_exists(checkpoint_name)
torch.save(state_dict, checkpoint_name) torch.save(state_dict, checkpoint_name)
start_misc = time() start_misc = time()
if not args.async_save: if ckpt_type != CheckpointType.LOCAL:
assert async_save_request is None if not args.async_save:
# Wait so everyone is done (necessary) assert async_save_request is None
if torch.distributed.is_initialized(): # Wait so everyone is done (necessary)
torch.distributed.barrier() if torch.distributed.is_initialized():
torch.distributed.barrier()
# And update the latest iteration # And update the latest iteration
if not torch.distributed.is_initialized() \ if not torch.distributed.is_initialized() \
...@@ -507,6 +541,17 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati ...@@ -507,6 +541,17 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati
else: else:
onelogger_finalize_fn() onelogger_finalize_fn()
# Additional callback for wandb (last rank)
if not torch.distributed.is_initialized() \
or is_last_rank():
def wandb_finalize_fn():
wandb_utils.on_save_checkpoint_success(checkpoint_name, get_checkpoint_tracker_filename(save_dir), save_dir, iteration)
if args.async_save:
assert async_save_request is not None
async_save_request.add_finalize_fn(wandb_finalize_fn)
else:
wandb_finalize_fn()
if args.async_save: if args.async_save:
schedule_async_save(async_save_request) schedule_async_save(async_save_request)
print_rank_0(' scheduled an async checkpoint save at iteration {:7d} to {}' \ print_rank_0(' scheduled an async checkpoint save at iteration {:7d} to {}' \
...@@ -519,6 +564,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati ...@@ -519,6 +564,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati
end_misc = time() end_misc = time()
logger.debug(f"rank: {rank}, takes {end_misc - start_misc} to finalize ckpt save ") logger.debug(f"rank: {rank}, takes {end_misc - start_misc} to finalize ckpt save ")
ft_integration.on_checkpointing_end(is_async_finalization=False)
def cleanup_old_non_persistent_checkpoint(save_dir, leave_ckpt_num=1, do_async=False): def cleanup_old_non_persistent_checkpoint(save_dir, leave_ckpt_num=1, do_async=False):
if torch.distributed.is_initialized() and torch.distributed.get_rank() != 0: if torch.distributed.is_initialized() and torch.distributed.get_rank() != 0:
...@@ -543,7 +589,7 @@ def cleanup_old_non_persistent_checkpoint(save_dir, leave_ckpt_num=1, do_async=F ...@@ -543,7 +589,7 @@ def cleanup_old_non_persistent_checkpoint(save_dir, leave_ckpt_num=1, do_async=F
remove_iter_ckpts(rm_iter_ckpts) remove_iter_ckpts(rm_iter_ckpts)
def save_dataloader_state(train_iterator, iteration, dataloader_save_path): def maybe_save_dataloader_state(train_iterator, iteration, dataloader_save_path):
"""Saves dataloader state if the dataloader supports it. """Saves dataloader state if the dataloader supports it.
Currently, this is only used by Megatron Energon dataloader (multimodal) to store its state at a Currently, this is only used by Megatron Energon dataloader (multimodal) to store its state at a
...@@ -558,13 +604,13 @@ def save_dataloader_state(train_iterator, iteration, dataloader_save_path): ...@@ -558,13 +604,13 @@ def save_dataloader_state(train_iterator, iteration, dataloader_save_path):
iteration (int): Current iteration. iteration (int): Current iteration.
dataloader_save_path (str): Path where the dataloader state is saved. dataloader_save_path (str): Path where the dataloader state is saved.
""" """
# If no dataloader or saving path is provided, then exit early. # If no dataloader or saving path is provided, exit early; an unsupported dataloader raises an error below.
if train_iterator is None or dataloader_save_path is None: if train_iterator is None or dataloader_save_path is None or dataloader_save_path == "":
return return
# If dataloader doesn't support saving state, exit early. # If dataloader doesn't support saving state, raise an error.
if not hasattr(train_iterator, "save_state"): if not hasattr(train_iterator.iterable, "save_state"):
return raise RuntimeError(f"Could not find a save_state for the train_iterator of type {type(train_iterator)}")
# Save dataloader state for each data parallel rank only once. # Save dataloader state for each data parallel rank only once.
first_rank = mpu.is_pipeline_first_stage(ignore_virtual=True) and mpu.get_tensor_model_parallel_rank() == 0 first_rank = mpu.is_pipeline_first_stage(ignore_virtual=True) and mpu.get_tensor_model_parallel_rank() == 0
...@@ -573,7 +619,7 @@ def save_dataloader_state(train_iterator, iteration, dataloader_save_path): ...@@ -573,7 +619,7 @@ def save_dataloader_state(train_iterator, iteration, dataloader_save_path):
dp_rank = mpu.get_data_parallel_rank() dp_rank = mpu.get_data_parallel_rank()
print(f"saving dataloader checkpoint at iteration {iteration} to {dataloader_save_path}") print(f"saving dataloader checkpoint at iteration {iteration} to {dataloader_save_path}")
train_dataloader_state_dict = train_iterator.save_state() train_dataloader_state_dict = train_iterator.iterable.save_state()
data_state_save_path = get_checkpoint_name( data_state_save_path = get_checkpoint_name(
dataloader_save_path, iteration, dataloader_save_path, iteration,
basename=f'train_dataloader_dprank{dp_rank:03d}.pt' basename=f'train_dataloader_dprank{dp_rank:03d}.pt'
...@@ -593,7 +639,7 @@ def save_dataloader_state(train_iterator, iteration, dataloader_save_path): ...@@ -593,7 +639,7 @@ def save_dataloader_state(train_iterator, iteration, dataloader_save_path):
def generate_state_dict(args, model, optimizer, opt_param_scheduler, def generate_state_dict(args, model, optimizer, opt_param_scheduler,
rng_state, use_dist_ckpt=False, iteration=None, rng_state, use_dist_ckpt=False, iteration=None,
optim_sd_kwargs=None, train_data_iterator=None): optim_sd_kwargs=None, rerun_state=None):
# Arguments, iteration, and model. # Arguments, iteration, and model.
state_dict = {} state_dict = {}
state_dict['args'] = args state_dict['args'] = args
...@@ -614,7 +660,7 @@ def generate_state_dict(args, model, optimizer, opt_param_scheduler, ...@@ -614,7 +660,7 @@ def generate_state_dict(args, model, optimizer, opt_param_scheduler,
model[i].state_dict_for_save_checkpoint()) model[i].state_dict_for_save_checkpoint())
# Optimizer stuff. # Optimizer stuff.
if not args.no_save_optim: if not args.no_save_optim:
if optimizer is not None: if optimizer is not None and not optimizer.is_stub_optimizer:
state_dict['optimizer'] = (optimizer.sharded_state_dict(state_dict, **(optim_sd_kwargs or {})) state_dict['optimizer'] = (optimizer.sharded_state_dict(state_dict, **(optim_sd_kwargs or {}))
if use_dist_ckpt else if use_dist_ckpt else
optimizer.state_dict()) optimizer.state_dict())
...@@ -623,10 +669,7 @@ def generate_state_dict(args, model, optimizer, opt_param_scheduler, ...@@ -623,10 +669,7 @@ def generate_state_dict(args, model, optimizer, opt_param_scheduler,
opt_param_scheduler.state_dict() opt_param_scheduler.state_dict()
# Rerun state # Rerun state
rerun_state_machine = get_rerun_state_machine() state_dict['rerun_state_machine'] = rerun_state
state_dict['rerun_state_machine'] = rerun_state_machine.get_checkpoint_state(
train_data_iterator
)
# RNG states. # RNG states.
if not args.no_save_rng: if not args.no_save_rng:
...@@ -719,8 +762,7 @@ def _get_non_persistent_iteration(non_persistent_global_dir, args, checkpointing ...@@ -719,8 +762,7 @@ def _get_non_persistent_iteration(non_persistent_global_dir, args, checkpointing
print_rank_0(' will not load any non-persistent checkpoint') print_rank_0(' will not load any non-persistent checkpoint')
return iteration return iteration
elif args.non_persistent_ckpt_type == "local": elif args.non_persistent_ckpt_type == "local":
raise RuntimeError('LocalCheckpointManagers are not yet integrated') return checkpointing_context['local_checkpoint_manager'].find_latest()
return checkpointing_context['local_checkpoint_manager'].get_latest_checkpoint_iteration()
else: else:
assert False, 'Please use local or global non-persistent checkpoints' \ assert False, 'Please use local or global non-persistent checkpoints' \
f'(got: {args.non_persistent_ckpt_type})' f'(got: {args.non_persistent_ckpt_type})'
...@@ -744,17 +786,17 @@ def _load_non_persistent_base_checkpoint( ...@@ -744,17 +786,17 @@ def _load_non_persistent_base_checkpoint(
f'Loading from a non-persistent checkpoint (non-persistent iter {non_persistent_iteration})' f'Loading from a non-persistent checkpoint (non-persistent iter {non_persistent_iteration})'
) )
return _load_global_dist_base_checkpoint( return _load_global_dist_base_checkpoint(
non_persistent_global_dir, args, rank0, sharded_state_dict, non_persistent_iteration, False non_persistent_global_dir, args, rank0, sharded_state_dict, non_persistent_iteration, False,
checkpointing_context=checkpointing_context
) )
elif args.non_persistent_ckpt_type == "local": elif args.non_persistent_ckpt_type == "local":
raise RuntimeError('LocalCheckpointManagers are not yet integrated')
intermediate_state_dict, checkpoint_name = checkpointing_context[ intermediate_state_dict, checkpoint_name = checkpointing_context[
'local_checkpoint_manager' 'local_checkpoint_manager'
].load() ].load()
state_dict = recreate_state_dict_after_load( state_dict = intermediate_state_dict.to_state_dict(
sharded_state_dict, sharded_state_dict,
intermediate_state_dict,
algo=args.non_persistent_local_ckpt_algo, algo=args.non_persistent_local_ckpt_algo,
parallelization_group = mpu.get_data_parallel_group(with_context_parallel=True)
) )
return state_dict, checkpoint_name, False, CheckpointType.LOCAL return state_dict, checkpoint_name, False, CheckpointType.LOCAL
else: else:
...@@ -763,7 +805,7 @@ def _load_non_persistent_base_checkpoint( ...@@ -763,7 +805,7 @@ def _load_non_persistent_base_checkpoint(
def _load_global_dist_base_checkpoint( def _load_global_dist_base_checkpoint(
load_dir, args, rank0, sharded_state_dict, iteration, release load_dir, args, rank0, sharded_state_dict, iteration, release, checkpointing_context=None
): ):
""" Load the base state_dict from the given directory containing the global distributed checkpoint """ """ Load the base state_dict from the given directory containing the global distributed checkpoint """
if rank0: if rank0:
...@@ -787,6 +829,8 @@ def _load_global_dist_base_checkpoint( ...@@ -787,6 +829,8 @@ def _load_global_dist_base_checkpoint(
load_strategy = FullyParallelLoadStrategyWrapper( load_strategy = FullyParallelLoadStrategyWrapper(
load_strategy, mpu.get_data_parallel_group(with_context_parallel=True) load_strategy, mpu.get_data_parallel_group(with_context_parallel=True)
) )
if checkpointing_context is not None:
checkpointing_context["load_strategy"] = load_strategy
state_dict = dist_checkpointing.load(sharded_state_dict, checkpoint_name, load_strategy, strict=args.dist_ckpt_strictness) state_dict = dist_checkpointing.load(sharded_state_dict, checkpoint_name, load_strategy, strict=args.dist_ckpt_strictness)
return state_dict, checkpoint_name, release, CheckpointType.GLOBAL return state_dict, checkpoint_name, release, CheckpointType.GLOBAL
...@@ -860,7 +904,7 @@ def _load_base_checkpoint( ...@@ -860,7 +904,7 @@ def _load_base_checkpoint(
# Handle global distributed checkpoint # Handle global distributed checkpoint
if is_dist_ckpt: if is_dist_ckpt:
return _load_global_dist_base_checkpoint( return _load_global_dist_base_checkpoint(
load_dir, args, rank0, sharded_state_dict, iteration, release load_dir, args, rank0, sharded_state_dict, iteration, release, checkpointing_context=checkpointing_context
) )
# Handle global legacy checkpoint # Handle global legacy checkpoint
if rank0: if rank0:
...@@ -1035,7 +1079,7 @@ def fix_fp8_params_lose_precision_when_loading_dist_ckpt(state_dict): ...@@ -1035,7 +1079,7 @@ def fix_fp8_params_lose_precision_when_loading_dist_ckpt(state_dict):
def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True, def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True,
ft_client=None, checkpointing_context=None, skip_load_to_model_and_opt=False): checkpointing_context=None, skip_load_to_model_and_opt=False):
"""Load a model checkpoint and return the iteration. """Load a model checkpoint and return the iteration.
strict (bool): whether to strictly enforce that the keys in strict (bool): whether to strictly enforce that the keys in
:attr:`state_dict` of the checkpoint match the names of :attr:`state_dict` of the checkpoint match the names of
...@@ -1059,7 +1103,8 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri ...@@ -1059,7 +1103,8 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
raise FileNotFoundError("No checkpoint found in load directory or pretrained directory") raise FileNotFoundError("No checkpoint found in load directory or pretrained directory")
args.finetune = True args.finetune = True
model = unwrap_model(model) ddp_model = model
model = unwrap_model(ddp_model)
load_kwargs = {} load_kwargs = {}
is_dist_ckpt = False is_dist_ckpt = False
...@@ -1074,11 +1119,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri ...@@ -1074,11 +1119,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
rank0=True, rank0=True,
checkpointing_context=checkpointing_context, checkpointing_context=checkpointing_context,
) )
if args.enable_ft_package and ft_client is not None and state_dict is not None:
if 'ft_state' in state_dict:
ft_client.load_state_dict(state_dict['ft_state'])
else:
print_rank_0("ft_state is not present in state_dict")
is_dist_ckpt = ( is_dist_ckpt = (
ckpt_type == CheckpointType.LOCAL ckpt_type == CheckpointType.LOCAL
or dist_checkpointing.check_is_distributed_checkpoint(checkpoint_name) or dist_checkpointing.check_is_distributed_checkpoint(checkpoint_name)
...@@ -1136,6 +1177,32 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri ...@@ -1136,6 +1177,32 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
gen_sd_optim = None gen_sd_optim = None
gen_sd_opt_param_scheduler = None gen_sd_opt_param_scheduler = None
# Determine if rerun state will be loaded
if (
ckpt_tp_pp == run_tp_pp
and not release
and not args.finetune
and 'rerun_state_machine' in state_dict
):
rerun_state_machine = get_rerun_state_machine()
gen_sd_rerun_state = rerun_state_machine.state_dict(
data_iterator=None, use_dist_ckpt=True
)
else:
gen_sd_rerun_state = None
if ckpt_tp_pp != run_tp_pp:
print_rank_0("{}: Rerun state will be ignored".format(mismatch_msg))
# [ModelOpt]: IMPORTANT! Restoring modelopt_state (sharded or not) must be performed
# after the model instance has been created and before _load_base_checkpoint is called.
if has_nvidia_modelopt:
if ckpt_type == CheckpointType.LOCAL:
print_rank_0('WARNING: Local checkpointing does not support nvidia_modelopt.')
elif ckpt_type == CheckpointType.GLOBAL:
restore_modelopt_state(model, state_dict)
else:
restore_sharded_modelopt_state(model, checkpoint_name)
# [ModelOpt]: Initial loading from non-resume sharded checkpoint to a Distillation Model # [ModelOpt]: Initial loading from non-resume sharded checkpoint to a Distillation Model
# will result in key mismatch with loss modules potentially containing parameters, since # will result in key mismatch with loss modules potentially containing parameters, since
# it requires generating a state_dict before loading. Here we hide those modules if present. # it requires generating a state_dict before loading. Here we hide those modules if present.
...@@ -1145,9 +1212,9 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri ...@@ -1145,9 +1212,9 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
stack.enter_context(m.hide_loss_modules()) stack.enter_context(m.hide_loss_modules())
load_kwargs['sharded_state_dict'] = generate_state_dict( load_kwargs['sharded_state_dict'] = generate_state_dict(
args, model, gen_sd_optim, gen_sd_opt_param_scheduler, gen_sd_rng_state, args, model, gen_sd_optim, gen_sd_opt_param_scheduler, gen_sd_rng_state,
use_dist_ckpt=True, optim_sd_kwargs=optim_sd_kwargs, train_data_iterator=None use_dist_ckpt=True, optim_sd_kwargs=optim_sd_kwargs, rerun_state=gen_sd_rerun_state
) )
# When "--fp8-param-gather" is disabled, this function doesn't modify anything. # When "--fp8-param-gather" is disabled, this function doesn't modify anything.
fix_fp8_params_lose_precision_when_loading_dist_ckpt(load_kwargs['sharded_state_dict']) fix_fp8_params_lose_precision_when_loading_dist_ckpt(load_kwargs['sharded_state_dict'])
...@@ -1156,12 +1223,6 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri ...@@ -1156,12 +1223,6 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
**load_kwargs **load_kwargs
) )
if args.enable_ft_package and ft_client is not None and state_dict is not None:
if 'ft_state' in state_dict:
ft_client.load_state_dict(state_dict['ft_state'])
else:
print_rank_0("ft_state is not present in state_dict")
# Checkpoint not loaded. # Checkpoint not loaded.
if state_dict is None: if state_dict is None:
# Iteration and num_floating_point_operations_so_far default to 0. # Iteration and num_floating_point_operations_so_far default to 0.
...@@ -1202,24 +1263,15 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri ...@@ -1202,24 +1263,15 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
else: else:
print_rank_0('could not find arguments in the checkpoint ...') print_rank_0('could not find arguments in the checkpoint ...')
# [ModelOpt]: loading modelopt_state (sharded or not)
if has_nvidia_modelopt:
if ckpt_type == CheckpointType.LOCAL:
raise NotImplementedError('Local checkpointing does not support model opt')
if not args.use_dist_ckpt:
restore_modelopt_state(model, state_dict)
else:
restore_sharded_modelopt_state(model, checkpoint_name)
# Model. # Model.
strict = False if args.retro_add_retriever else strict strict = False if args.retro_add_retriever else strict
if not skip_load_to_model_and_opt: if not skip_load_to_model_and_opt:
if len(model) == 1: if len(ddp_model) == 1:
model[0].load_state_dict(state_dict['model'], strict=strict) ddp_model[0].load_state_dict(state_dict['model'], strict=strict)
else: else:
for i in range(len(model)): for i in range(len(ddp_model)):
mpu.set_virtual_pipeline_model_parallel_rank(i) mpu.set_virtual_pipeline_model_parallel_rank(i)
model[i].load_state_dict(state_dict['model%d' % i], strict=strict) ddp_model[i].load_state_dict(state_dict['model%d' % i], strict=strict)
# Fix up query/key/value matrix ordering if needed. # Fix up query/key/value matrix ordering if needed.
checkpoint_version = get_checkpoint_version() checkpoint_version = get_checkpoint_version()
...@@ -1230,7 +1282,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri ...@@ -1230,7 +1282,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
if not release and not args.finetune and not args.no_load_optim: if not release and not args.finetune and not args.no_load_optim:
try: try:
# Load state dict. # Load state dict.
if not skip_load_to_model_and_opt and optimizer is not None: if not skip_load_to_model_and_opt and optimizer is not None and not optimizer.is_stub_optimizer:
optimizer.load_state_dict(state_dict['optimizer']) optimizer.load_state_dict(state_dict['optimizer'])
# Load distributed optimizer's custom parameter state. # Load distributed optimizer's custom parameter state.
...@@ -1268,7 +1320,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri ...@@ -1268,7 +1320,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
# rerun state # rerun state
try: try:
if 'rerun_state_machine' in state_dict: if 'rerun_state_machine' in state_dict:
get_rerun_state_machine().set_checkpoint_state(state_dict['rerun_state_machine']) get_rerun_state_machine().load_state_dict(state_dict['rerun_state_machine'])
except Exception as e: except Exception as e:
print(f"Unable to restore RerunMachine from checkpoint: {e}") print(f"Unable to restore RerunMachine from checkpoint: {e}")
sys.exit() sys.exit()
...@@ -1317,7 +1369,18 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri ...@@ -1317,7 +1369,18 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
f'p {mpu.get_pipeline_model_parallel_rank() + 1}/{mpu.get_pipeline_model_parallel_world_size()} ] ' f'p {mpu.get_pipeline_model_parallel_rank() + 1}/{mpu.get_pipeline_model_parallel_world_size()} ] '
f'at iteration {iteration}') f'at iteration {iteration}')
# Additional callback for wandb (last rank)
if not torch.distributed.is_initialized() \
or is_last_rank():
wandb_utils.on_load_checkpoint_success(checkpoint_name, load_dir)
torch.cuda.empty_cache() torch.cuda.empty_cache()
if iteration > 0:
# Notify FT that a checkpoint was loaded.
is_local_chkpt = (ckpt_type == CheckpointType.LOCAL)
ft_integration.on_checkpoint_loaded(is_local_chkpt=is_local_chkpt)
return iteration, num_floating_point_operations_so_far return iteration, num_floating_point_operations_so_far
......
File mode changed from 100755 to 100644
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
""" """
FT Package Integration Fault Tolerance (FT) package integration for Megatron-LM, using the FT section-based API.
This file is part of the integration process for the FT package, a custom heartbeat-based The FT package is included in "nvidia-resiliency-ext"
system developed by NVIDIA. The FT package monitors the ranks to detect hangs, gracefully (https://github.com/NVIDIA/nvidia-resiliency-ext).
terminates the workload, and respawns it from the last checkpoints. It includes an auto
config feature that automatically sets up timeouts based on the observed time of iterations.
Note: This tool is an internal NVIDIA tool and is not open source. This file does not NOTE: The workload must be run using the `ft_launcher` tool provided by `nvidia-resiliency-ext.`
contain the FT package itself but supports its integration. NOTE: Calls to the public API of this module are no-ops if FT is not initialized
(`ft_integration.setup` was not called).
NOTE: The default distributed process group should be initialized before calling `ft_integration.setup`.
The "setup" FT section is opened during FT initialization and closed before the first training or
eval iteration. Training and evaluation steps are wrapped in the "step" section, but only after a
few warmup iterations. This is because the initial iterations may be slower, and we want the "step"
timeout to be short. These warmup steps, which are not wrapped in the "step" section, will fall into
the out-of-section area. All checkpoint-saving-related operations (including asynchronous
checkpointing finalization) are wrapped in the "checkpointing" section.
If timeout calculation is enabled (--calc-ft-timeouts),
FT timeouts are updated after each checkpoint and at the end of the run.
Updated values are based on observed intervals.
`ft_launcher` command example:
```
ft_launcher \
--rdzv_backend=c10d --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} \
--nnodes=${NUM_NODES} --nproc-per-node=${NUM_GPUS_PER_NODE} \
--ft-param-rank_section_timeouts=setup:600,step:180,checkpointing:420 \
--ft-param-rank_out_of_section_timeout=300 \
train_script_with_ft.py
```
""" """
import types import argparse
from enum import Enum, auto import json
import os
import random
import signal
import sys
import threading
import time
from typing import Any, Optional
import torch
from . import global_vars from . import global_vars
from .utils import is_rank0, print_rank_0
class StateMachineActions(Enum): _GLOBAL_RANK_MONITOR_CLIENT = None
NONE = auto()
SAVE_CHECKPOINT = auto() _ft_state_path = None
TRAIN_HEARTBEAT = auto() _is_persistent_chkpt_loaded = False
EVAL_HEARTBEAT = auto() _is_async_chkpt_enabled = False
UPDATE_TIMEOUT = auto() _is_calculating_timeouts = False
_is_setup_section_open = False
_seen_checkpoints_cnt = 0
_seen_tr_iters_cnt = 0
_curr_eval_iter_idx = 0
_NUM_WARMUP_ITERS = 1
_MIN_ITERS_FOR_STEP_TIMEOUT_UPDATE = 16
class _TrainingStateMachine:
"""
This class encapsulates logic for determining when:
- FT timeouts can be updated (`.can_update_timeouts` property)
`on_ ...` methods update the state and should be called from the corresponding places. def get_rank_monitor_client() -> Optional[Any]:
"""Returns the underlying fault tolerance client instance
Returns:
RankMonitorClient: rank monitor client instance, or None if FT was not initialized
""" """
return _GLOBAL_RANK_MONITOR_CLIENT
MIN_ITERS_FOR_TIMEOUT_UPDATE = 2
def __init__(self):
self.num_tr_iters_total = 0
self.num_tr_iter_at_last_save = None
self.seen_checkpointing = False
self.timeouts_updated = False
def on_save_checkpoint(self):
self.num_tr_iter_at_last_save = self.num_tr_iters_total
def on_train_heartbeat(self):
self.num_tr_iters_total += 1
if not self.seen_checkpointing and self.num_tr_iter_at_last_save is not None:
# detect mid-epoch checkpointing that makes the heartbeat interval longer
iters_pre_save = self.num_tr_iter_at_last_save
iters_post_save = self.num_tr_iters_total - self.num_tr_iter_at_last_save
self.seen_checkpointing = iters_pre_save > 0 and iters_post_save > 0
def on_eval_heartbeat(self):
pass
def on_timeouts_updated(self):
self.timeouts_updated = True
@property
def can_update_timeouts(self) -> bool:
"""
Returns True if new timeouts can be computed.
`.on_timeouts_updated()` resets this property back to False.
"""
if self.timeouts_updated:
# timeouts are updated at most once per training run
return False
if self.num_tr_iters_total < self.MIN_ITERS_FOR_TIMEOUT_UPDATE:
# need a few training iters
return False
# check if there was checkpoint saving
# this makes the heartbeat interval longer than usual.
return self.seen_checkpointing
def perform_action(self, action: StateMachineActions):
if action == StateMachineActions.TRAIN_HEARTBEAT:
self.on_train_heartbeat()
elif action == StateMachineActions.SAVE_CHECKPOINT:
self.on_save_checkpoint()
elif action == StateMachineActions.EVAL_HEARTBEAT:
self.on_eval_heartbeat()
elif action == StateMachineActions.UPDATE_TIMEOUT:
self.on_timeouts_updated()
assert not self.can_update_timeouts
# No action for StateMachineActions.NONE
def setup(args: argparse.Namespace) -> None:
"""Initialize fault tolerance
_GLOBAL_RANK_MONITOR_CLIENT = None Args:
_GLOBAL_STATE_MACHINE = _TrainingStateMachine() args (argparse.Namespace): parsed Megatron-LM command line arguments
def _set_rank_monitor_client(): Raises:
ValueError: if invalid config is provided
"""
from nvidia_resiliency_ext.fault_tolerance import RankMonitorClient from nvidia_resiliency_ext.fault_tolerance import RankMonitorClient
print_rank_0(f"FT: initializing...")
checkpoint_dir = args.save
if not checkpoint_dir:
raise ValueError("checkpointing save dir must be set to enable fault tolerance")
if is_rank0() and not os.path.exists(checkpoint_dir):
# The MLM checkpoint dir will be needed for saving the FT state.
# This can happen before checkpointing, so create it in advance.
os.makedirs(checkpoint_dir, exist_ok=True)
cli = RankMonitorClient() cli = RankMonitorClient()
global _GLOBAL_RANK_MONITOR_CLIENT global _GLOBAL_RANK_MONITOR_CLIENT
global_vars._ensure_var_is_not_initialized(_GLOBAL_RANK_MONITOR_CLIENT, 'rank monitor client') global_vars._ensure_var_is_not_initialized(_GLOBAL_RANK_MONITOR_CLIENT, 'rank monitor client')
_GLOBAL_RANK_MONITOR_CLIENT = cli _GLOBAL_RANK_MONITOR_CLIENT = cli
def get_rank_monitor_client(action=StateMachineActions.NONE): global _ft_state_path
global _GLOBAL_RANK_MONITOR_CLIENT, _GLOBAL_STATE_MACHINE _ft_state_path = os.path.join(checkpoint_dir, "ft_state.json")
if _GLOBAL_RANK_MONITOR_CLIENT is None:
try: global _is_async_chkpt_enabled
_set_rank_monitor_client() _is_async_chkpt_enabled = args.async_save
except ImportError:
_GLOBAL_RANK_MONITOR_CLIENT = None global _is_calculating_timeouts
_GLOBAL_STATE_MACHINE.perform_action(action) _is_calculating_timeouts = args.calc_ft_timeouts
return _GLOBAL_RANK_MONITOR_CLIENT
cli.init_workload_monitoring()
_load_state_if_exists()
print_rank_0(f"FT: initialized. Timeouts={cli.section_timeouts}")
cli.start_section("setup")
global _is_setup_section_open
_is_setup_section_open = True
def on_training_step_start() -> None:
"""Should be called before each training step"""
rmon_cli = get_rank_monitor_client()
if rmon_cli is not None:
global _is_setup_section_open
if _is_setup_section_open:
rmon_cli.end_section("setup")
_is_setup_section_open = False
if _seen_tr_iters_cnt >= _NUM_WARMUP_ITERS:
rmon_cli.start_section("step")
# reset eval step index. we started training, so evaluation is done
global _curr_eval_iter_idx
_curr_eval_iter_idx = 0
def on_training_step_end() -> None:
"""Should be called after each training step"""
rmon_cli = get_rank_monitor_client()
if rmon_cli is not None:
global _seen_tr_iters_cnt
if _seen_tr_iters_cnt >= _NUM_WARMUP_ITERS:
rmon_cli.end_section("step")
_seen_tr_iters_cnt += 1
def on_eval_step_start() -> None:
"""Should be called before each validation step"""
rmon_cli = get_rank_monitor_client()
if rmon_cli is not None:
global _is_setup_section_open
if _is_setup_section_open:
# setup section can be open if there were no training iters before evaluation
rmon_cli.end_section("setup")
_is_setup_section_open = False
if _curr_eval_iter_idx >= _NUM_WARMUP_ITERS:
rmon_cli.start_section("step")
def on_eval_step_end() -> None:
"""Should be called after each validation step"""
rmon_cli = get_rank_monitor_client()
if rmon_cli is not None:
global _curr_eval_iter_idx
if _curr_eval_iter_idx >= _NUM_WARMUP_ITERS:
rmon_cli.end_section("step")
_curr_eval_iter_idx += 1
def on_checkpointing_start() -> None:
"""Should be called before each checkpoint-saving-related operation."""
rmon_cli = get_rank_monitor_client()
if rmon_cli is not None:
rmon_cli.start_section("checkpointing")
def on_checkpointing_end(is_async_finalization: bool) -> None:
"""Should be called after each checkpoint-saving-related operation.
Args:
is_async_finalization (bool): true if called after an async checkpointing finalization
"""
rmon_cli = get_rank_monitor_client()
if rmon_cli is not None:
rmon_cli.end_section("checkpointing")
# async checkpointing finalization is called before each training iter and can be a no-op,
# so try to update the timeouts only on the `save_checkpoint` path
if not is_async_finalization:
global _seen_checkpoints_cnt
_seen_checkpoints_cnt += 1
_maybe_update_timeouts()
def on_checkpoint_loaded(is_local_chkpt: bool) -> None:
"""Should be called after a checkpoint was loaded
Args:
is_local_chkpt (bool): true if it was a local checkpoint, false if global
"""
# checkpoint can be loaded during "setup"
# check if persistent checkpoint was loaded,
# in-memory checkpoint reading can be very fast,
# so we could underestimate the "setup" timeout
global _is_persistent_chkpt_loaded
_is_persistent_chkpt_loaded = not is_local_chkpt
def shutdown() -> None:
"""Shutdowns fault folerance, updates the FT timeouts if possible"""
global _GLOBAL_RANK_MONITOR_CLIENT
rmon_cli = get_rank_monitor_client()
if rmon_cli is not None:
print_rank_0("FT: closing...")
_maybe_update_timeouts(is_closing_ft=True)
rmon_cli.shutdown_workload_monitoring()
print_rank_0("FT: closed.")
_GLOBAL_RANK_MONITOR_CLIENT = None
def _load_state_if_exists():
rmon_cli = get_rank_monitor_client()
if os.path.exists(_ft_state_path):
with open(_ft_state_path, "r") as f:
ft_state = json.load(f)
rmon_cli.load_state_dict(ft_state)
print_rank_0(f"FT: loaded timeouts from {_ft_state_path}. {rmon_cli.section_timeouts}")
def _update_timeouts(selected_sections, calc_out_of_section):
print_rank_0(
f"FT: updating timeouts for: {selected_sections} "
+ f"update out-of-section: {calc_out_of_section} ..."
)
rmon_cli = get_rank_monitor_client()
rmon_cli.calculate_and_set_section_timeouts(
selected_sections=selected_sections, calc_out_of_section=calc_out_of_section
)
if is_rank0():
ft_state = rmon_cli.state_dict()
with open(_ft_state_path, "w") as f:
json.dump(ft_state, f)
print_rank_0(f"FT: updated timeouts saved to {_ft_state_path}. {rmon_cli.section_timeouts}")
def _maybe_update_timeouts(is_closing_ft=False):
rmon_cli = get_rank_monitor_client()
if rmon_cli is None:
return
if not _is_calculating_timeouts:
return
# Decide which section timeouts can be updated
sections_to_update = []
if _is_persistent_chkpt_loaded:
sections_to_update.append("setup")
else:
print_rank_0(
"FT: can't update the setup section timeout until persistent checkpoint is loaded"
)
if _seen_tr_iters_cnt >= _MIN_ITERS_FOR_STEP_TIMEOUT_UPDATE:
sections_to_update.append("step")
else:
print_rank_0("FT: need to see more training iterations to update the step section timeout")
if _seen_checkpoints_cnt > 0:
if not _is_async_chkpt_enabled:
sections_to_update.append("checkpointing")
else:
# There can be too much checkpointing section time variability
# across runs with the async checkpointing, e.g. in some runs all checkpointing
# work can be parallelized (=short checkpointing sections) while in others we can
# hit a costly finalization.
print_rank_0(
"FT: can't update the checkpointing section timeout with async checkpointing"
)
else:
print_rank_0("FT: checkpointing section is not updated until a checkpoint was saved")
update_out_of_section = False
if is_closing_ft:
# with async checkpointing, "checkpointing" section is not updated,
# but we still want to see some checkpointing to ensure that it was a complete run
if {'setup', 'step'}.issubset(sections_to_update) and _seen_checkpoints_cnt > 0:
update_out_of_section = True
else:
print_rank_0(
"FT: the out-of-section timeout won't be updated until all FT sections were seen"
)
else:
print_rank_0("FT: the out-of-section timeout won't be updated as the FT is not closing yet")
if sections_to_update or update_out_of_section:
_update_timeouts(
selected_sections=sections_to_update, calc_out_of_section=update_out_of_section
)
def maybe_setup_simulated_fault() -> None:
"""Sets a simulated fault, based on `FT_SIM_FAULT_DESC` env variable.
Simulated fault description format:
rank_hung|rank_killed|random;rank_to_fail|"";base_delay
NOTE: This is for FT testing only
"""
simulated_fault_desc = os.environ.get('FT_SIM_FAULT_DESC', None)
if not simulated_fault_desc:
return
fault_type: Any # silence mypy
rank_to_fail: Any # silence mypy
base_delay: Any # silence mypy
fault_type, rank_to_fail, base_delay = simulated_fault_desc.split(';')
fault_type = fault_type.strip()
rank_to_fail = rank_to_fail.strip()
rank_to_fail = int(rank_to_fail) if rank_to_fail else None
base_delay = float(base_delay.strip())
rng = random.Random()
print_rank_0(
f"FT: Initializing simulated fault: {fault_type},"
+ f"rank to fail: {rank_to_fail}, base delay: {base_delay}"
)
# rank that simulates a fault can be explicitly specified in the `rank_to_fail` field
# if not specified, it just picks a random rank
rank = torch.distributed.get_rank()
rand_rank = rng.randint(0, torch.distributed.get_world_size() - 1)
rank_to_fail = rank_to_fail if rank_to_fail is not None else rand_rank
rank_to_fail = torch.tensor([rank_to_fail], device=torch.cuda.current_device())
torch.distributed.broadcast(rank_to_fail, 0)
rank_to_fail = int(rank_to_fail.item())
if rank != rank_to_fail:
# this rank is not going to simulate a fault, nothing more to do
return
if fault_type == 'random':
fault_type = rng.choice(['rank_killed', 'rank_hung'])
if fault_type == 'rank_killed':
target_pid = os.getpid()
elif fault_type == 'rank_hung':
target_pid = os.getpid()
else:
raise Exception(f"Unknown fault type {fault_type} expected one of: rank_killed, rank_hung.")
# add some randomness to the delay
delay = base_delay + 0.2 * rng.random() * base_delay
print_rank_0(f"FT: Selected fault={fault_type}; target rank={rank_to_fail}; delay={delay}")
def __fault_thread():
time.sleep(delay)
for of in [sys.stdout, sys.stderr]:
print(
f"\n####\nFT: Simulating fault: {fault_type}; rank to fail: {rank_to_fail}\n####\n",
file=of,
flush=True,
)
if fault_type == 'rank_hung':
os.kill(target_pid, signal.SIGSTOP)
else:
os.kill(target_pid, signal.SIGKILL)
def can_update_timeouts(): fault_sim_thread = threading.Thread(target=__fault_thread)
global _GLOBAL_STATE_MACHINE fault_sim_thread.daemon = True
return _GLOBAL_STATE_MACHINE.can_update_timeouts fault_sim_thread.start()
File mode changed from 100755 to 100644
...@@ -2,29 +2,34 @@ ...@@ -2,29 +2,34 @@
"""Megatron initialization.""" """Megatron initialization."""
import logging import logging
import random
import os import os
import random
import time import time
import warnings import warnings
from datetime import timedelta
import numpy as np import numpy as np
import torch import torch
from datetime import timedelta
from megatron.legacy import fused_kernels
from megatron.training import get_adlr_autoresume
from megatron.training import get_args
from megatron.training import get_tensorboard_writer
from megatron.core import mpu, tensor_parallel from megatron.core import mpu, tensor_parallel
from megatron.core.rerun_state_machine import initialize_rerun_state_machine, RerunErrorInjector, RerunDiagnostic, RerunMode
from megatron.training.arguments import parse_args, validate_args
from megatron.training.yaml_arguments import validate_yaml
from megatron.training.checkpointing import load_args_from_checkpoint
from megatron.training.global_vars import set_global_variables
from megatron.core.fusions.fused_bias_dropout import bias_dropout_add_fused_train from megatron.core.fusions.fused_bias_dropout import bias_dropout_add_fused_train
from megatron.core.fusions.fused_bias_gelu import bias_gelu from megatron.core.fusions.fused_bias_gelu import bias_gelu
from megatron.core.fusions.fused_bias_swiglu import bias_swiglu from megatron.core.fusions.fused_bias_swiglu import bias_swiglu
from megatron.core.parallel_state import create_group
from megatron.core.rerun_state_machine import (
RerunDiagnostic,
RerunErrorInjector,
RerunMode,
initialize_rerun_state_machine,
)
from megatron.core.utils import get_te_version, is_te_min_version, is_torch_min_version from megatron.core.utils import get_te_version, is_te_min_version, is_torch_min_version
from megatron.legacy import fused_kernels
from megatron.training import get_adlr_autoresume, get_args, get_tensorboard_writer
from megatron.training.arguments import parse_args, validate_args
from megatron.training.async_utils import init_persistent_async_worker
from megatron.training.checkpointing import load_args_from_checkpoint
from megatron.training.global_vars import set_global_variables
from megatron.training.yaml_arguments import validate_yaml
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -36,7 +41,7 @@ def initialize_megatron( ...@@ -36,7 +41,7 @@ def initialize_megatron(
allow_no_cuda=False, allow_no_cuda=False,
skip_mpu_initialization=False, skip_mpu_initialization=False,
get_embedding_ranks=None, get_embedding_ranks=None,
get_position_embedding_ranks=None get_position_embedding_ranks=None,
): ):
"""Set global variables, initialize distributed, and """Set global variables, initialize distributed, and
set autoresume and random seeds. set autoresume and random seeds.
...@@ -61,14 +66,21 @@ def initialize_megatron( ...@@ -61,14 +66,21 @@ def initialize_megatron(
if args.use_checkpoint_args or args_defaults.get("use_checkpoint_args", False): if args.use_checkpoint_args or args_defaults.get("use_checkpoint_args", False):
assert args.load is not None, "--use-checkpoint-args requires --load argument" assert args.load is not None, "--use-checkpoint-args requires --load argument"
assert args.non_persistent_ckpt_type != "local", (
"--use-checkpoint-args is not supported with --non_persistent_ckpt_type=local. "
"Two-stage checkpoint loading is not implemented, and all arguments must be defined "
"before initializing LocalCheckpointManager."
)
load_args_from_checkpoint(args) load_args_from_checkpoint(args)
if args.async_save and args.use_persistent_ckpt_worker:
init_persistent_async_worker()
if args.yaml_cfg is not None: if args.yaml_cfg is not None:
args = validate_yaml(args, args_defaults) args = validate_yaml(args, args_defaults)
else: else:
validate_args(args, args_defaults) validate_args(args, args_defaults)
# set global args, build tokenizer, and set adlr-autoresume, # set global args, build tokenizer, and set adlr-autoresume,
# tensorboard-writer, and timers. # tensorboard-writer, and timers.
set_global_variables(args) set_global_variables(args)
...@@ -78,10 +90,8 @@ def initialize_megatron( ...@@ -78,10 +90,8 @@ def initialize_megatron(
# init rerun state # init rerun state
def state_save_func(): def state_save_func():
return { return {'rng_tracker_states': tensor_parallel.get_cuda_rng_tracker().get_states()}
'rng_tracker_states': tensor_parallel.get_cuda_rng_tracker().get_states()
}
def state_restore_func(state_dict): def state_restore_func(state_dict):
if state_dict['rng_tracker_states']: if state_dict['rng_tracker_states']:
tensor_parallel.get_cuda_rng_tracker().set_states(state_dict['rng_tracker_states']) tensor_parallel.get_cuda_rng_tracker().set_states(state_dict['rng_tracker_states'])
...@@ -95,6 +105,7 @@ def initialize_megatron( ...@@ -95,6 +105,7 @@ def initialize_megatron(
error_injection_rate=args.error_injection_rate, error_injection_rate=args.error_injection_rate,
error_injection_type=RerunDiagnostic(args.error_injection_type), error_injection_type=RerunDiagnostic(args.error_injection_type),
), ),
result_rejected_tracker_filename=args.result_rejected_tracker_filename,
) )
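The two callbacks passed above only need to round-trip whatever state the rerun machinery should restore before re-executing an iteration. A minimal illustration of that contract, using plain PyTorch RNG state instead of the CUDA RNG tracker used above (illustrative only, not the Megatron callbacks):

# Illustrative save/restore callback pair for a rerun-style state machine.
import torch

def toy_state_save_func():
    return {"torch_rng_state": torch.get_rng_state()}

def toy_state_restore_func(state_dict):
    if state_dict.get("torch_rng_state") is not None:
        torch.set_rng_state(state_dict["torch_rng_state"])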
# torch.distributed initialization # torch.distributed initialization
...@@ -106,7 +117,12 @@ def initialize_megatron( ...@@ -106,7 +117,12 @@ def initialize_megatron(
# Random seeds for reproducibility. # Random seeds for reproducibility.
if args.rank == 0: if args.rank == 0:
print("> setting random seeds to {} ...".format(args.seed)) print("> setting random seeds to {} ...".format(args.seed))
_set_random_seed(args.seed, args.data_parallel_random_init) _set_random_seed(
args.seed,
args.data_parallel_random_init,
args.te_rng_tracker,
args.inference_rng_tracker,
)
if skip_mpu_initialization: if skip_mpu_initialization:
return None return None
...@@ -133,8 +149,8 @@ def initialize_megatron( ...@@ -133,8 +149,8 @@ def initialize_megatron(
_compile_dependencies() _compile_dependencies()
if args.tp_comm_overlap: if args.tp_comm_overlap:
#TODO: Should this be activated with just decoder-tp-comm-overlap too? # TODO: Should this be activated with just decoder-tp-comm-overlap too?
_initialize_tp_communicators() _initialize_tp_communicators()
# No continuation function # No continuation function
return None return None
...@@ -172,17 +188,10 @@ def _compile_dependencies(): ...@@ -172,17 +188,10 @@ def _compile_dependencies():
# Constraints on sequence length and attn_batch_size to enable warp based # Constraints on sequence length and attn_batch_size to enable warp based
# optimization and upper triangular optimization (for causal mask) # optimization and upper triangular optimization (for causal mask)
custom_kernel_constraint = ( custom_kernel_constraint = (
seq_len > 16 seq_len > 16 and seq_len <= 16384 and seq_len % 4 == 0 and attn_batch_size % 4 == 0
and seq_len <= 16384
and seq_len % 4 == 0
and attn_batch_size % 4 == 0
) )
# Print a warning. # Print a warning.
if not ( if not ((args.fp16 or args.bf16) and custom_kernel_constraint and args.masked_softmax_fusion):
(args.fp16 or args.bf16)
and custom_kernel_constraint
and args.masked_softmax_fusion
):
if args.rank == 0: if args.rank == 0:
print( print(
"WARNING: constraints for invoking optimized" "WARNING: constraints for invoking optimized"
...@@ -192,14 +201,14 @@ def _compile_dependencies(): ...@@ -192,14 +201,14 @@ def _compile_dependencies():
) )
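The eligibility condition spelled out above can be read as a single predicate over the run configuration; a small standalone sketch:

# Sketch of the fused masked-softmax eligibility check described above.
def fused_softmax_eligible(seq_len: int, attn_batch_size: int,
                           fp16: bool, bf16: bool,
                           masked_softmax_fusion: bool) -> bool:
    custom_kernel_constraint = (
        16 < seq_len <= 16384
        and seq_len % 4 == 0
        and attn_batch_size % 4 == 0
    )
    return (fp16 or bf16) and custom_kernel_constraint and masked_softmax_fusion

# Example: a 4096-token bf16 run with attn_batch_size divisible by 4 qualifies.
assert fused_softmax_eligible(4096, 32, fp16=False, bf16=True,
                              masked_softmax_fusion=True)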
# Always build on rank zero first. # Always build on rank zero first.
if torch.distributed.get_rank() == 0: # if torch.distributed.get_rank() == 0:
start_time = time.time() # start_time = time.time()
print("> compiling and loading fused kernels ...", flush=True) # print("> compiling and loading fused kernels ...", flush=True)
#fused_kernels.load(args) # fused_kernels.load(args)
torch.distributed.barrier() # torch.distributed.barrier()
else: # else:
torch.distributed.barrier() # torch.distributed.barrier()
#fused_kernels.load(args) # fused_kernels.load(args)
# Simple barrier to make sure all ranks have passed the # Simple barrier to make sure all ranks have passed the
# compilation phase successfully before moving on to the # compilation phase successfully before moving on to the
# rest of the program. We think this might ensure that # rest of the program. We think this might ensure that
...@@ -212,48 +221,65 @@ def _compile_dependencies(): ...@@ -212,48 +221,65 @@ def _compile_dependencies():
flush=True, flush=True,
) )
def _initialize_tp_communicators(): def _initialize_tp_communicators():
""" initializing the communicators with user buffers for high-performance tensor-model-parallel """initializing the communicators with user buffers for high-performance tensor-model-parallel
communication overlap """ communication overlap"""
try: try:
import yaml import transformer_engine
import yaml
import transformer_engine from transformer_engine.pytorch import module as te_module
from transformer_engine.pytorch import module as te_module
except ImportError: except ImportError:
raise RuntimeError("Tensor Parallel Communication/GEMM Overlap optimization needs 'yaml' and " raise RuntimeError(
"'transformer_engine' packages") "Tensor Parallel Communication/GEMM Overlap optimization needs 'yaml' and "
"'transformer_engine' packages"
)
args = get_args() args = get_args()
if args.tp_comm_overlap_cfg is not None: if args.tp_comm_overlap_cfg is not None:
with open(args.tp_comm_overlap_cfg,"r") as stream: with open(args.tp_comm_overlap_cfg, "r") as stream:
ub_cfgs = yaml.safe_load(stream) ub_cfgs = yaml.safe_load(stream)
else: else:
ub_cfgs = {} ub_cfgs = {}
if getattr(args, 'decoder_tp_comm_overlap', False): if getattr(args, 'decoder_tp_comm_overlap', False):
input_shape = [(args.decoder_seq_length * args.micro_batch_size) // args.context_parallel_size , args.hidden_size] input_shape = [
(args.decoder_seq_length * args.micro_batch_size) // args.context_parallel_size,
args.hidden_size,
]
else: else:
input_shape = [(args.seq_length * args.micro_batch_size) // args.context_parallel_size , args.hidden_size] input_shape = [
(args.seq_length * args.micro_batch_size) // args.context_parallel_size,
args.hidden_size,
]
if is_te_min_version("1.9.0"): if is_te_min_version("1.9.0"):
# The process group with the target bootstrap backend is created in Transformer Engine. # The process group with the target bootstrap backend is created in Transformer Engine.
te_module.base.initialize_ub(shape = input_shape, tp_size = args.tensor_model_parallel_size, te_module.base.initialize_ub(
use_fp8 = (args.fp8 is not None) , ub_cfgs = ub_cfgs, shape=input_shape,
bootstrap_backend = args.tp_comm_bootstrap_backend) tp_size=args.tensor_model_parallel_size,
use_fp8=(args.fp8 is not None),
ub_cfgs=ub_cfgs,
bootstrap_backend=args.tp_comm_bootstrap_backend,
)
else: else:
if args.tp_comm_bootstrap_backend != 'mpi': if args.tp_comm_bootstrap_backend != 'mpi':
warnings.warn( warnings.warn(
f"Transformer Engine v{get_te_version()} supports only MPI bootstrap backend." f"Transformer Engine v{get_te_version()} supports only MPI bootstrap backend."
) )
# Create a MPI process group to help with TP communication overlap bootstrap. # Create a MPI process group to help with TP communication overlap bootstrap.
torch.distributed.new_group(backend='mpi') create_group(backend='mpi', group_desc='TP_BOOTSTRAP_GROUP_MPI')
te_module.base.initialize_ub(shape = input_shape, tp_size = args.tensor_model_parallel_size, te_module.base.initialize_ub(
use_fp8 = (args.fp8 is not None) , ub_cfgs = ub_cfgs) shape=input_shape,
tp_size=args.tensor_model_parallel_size,
use_fp8=(args.fp8 is not None),
ub_cfgs=ub_cfgs,
)
def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks): def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
"""Initialize torch.distributed and core model parallel.""" """Initialize torch.distributed and core model parallel."""
...@@ -264,14 +290,14 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks): ...@@ -264,14 +290,14 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
if args.rank == 0: if args.rank == 0:
print( print(
"torch distributed is already initialized, " "torch distributed is already initialized, " "skipping initialization ...",
"skipping initialization ...",
flush=True, flush=True,
) )
args.rank = torch.distributed.get_rank() args.rank = torch.distributed.get_rank()
args.world_size = torch.distributed.get_world_size() args.world_size = torch.distributed.get_world_size()
else: else:
if args.rank == 0: if args.rank == 0:
print("> initializing torch distributed ...", flush=True) print("> initializing torch distributed ...", flush=True)
# Manually set the device ids. # Manually set the device ids.
...@@ -283,7 +309,7 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks): ...@@ -283,7 +309,7 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
# Call the init process # Call the init process
init_process_group_kwargs = { init_process_group_kwargs = {
'backend' : args.distributed_backend, 'backend': args.distributed_backend,
'world_size': args.world_size, 'world_size': args.world_size,
'rank': args.rank, 'rank': args.rank,
'init_method': args.dist_url, 'init_method': args.dist_url,
...@@ -303,6 +329,7 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks): ...@@ -303,6 +329,7 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
args.pipeline_model_parallel_size, args.pipeline_model_parallel_size,
args.virtual_pipeline_model_parallel_size, args.virtual_pipeline_model_parallel_size,
args.pipeline_model_parallel_split_rank, args.pipeline_model_parallel_split_rank,
pipeline_model_parallel_comm_backend=args.pipeline_model_parallel_comm_backend,
context_parallel_size=args.context_parallel_size, context_parallel_size=args.context_parallel_size,
hierarchical_context_parallel_sizes=args.hierarchical_context_parallel_sizes, hierarchical_context_parallel_sizes=args.hierarchical_context_parallel_sizes,
expert_model_parallel_size=args.expert_model_parallel_size, expert_model_parallel_size=args.expert_model_parallel_size,
...@@ -310,11 +337,12 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks): ...@@ -310,11 +337,12 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
expert_tensor_parallel_size=args.expert_tensor_parallel_size, expert_tensor_parallel_size=args.expert_tensor_parallel_size,
distributed_timeout_minutes=args.distributed_timeout_minutes, distributed_timeout_minutes=args.distributed_timeout_minutes,
nccl_communicator_config_path=args.nccl_communicator_config_path, nccl_communicator_config_path=args.nccl_communicator_config_path,
order='tp-cp-ep-dp-pp' if not args.use_tp_pp_dp_mapping else 'tp-pp-dp', order='tp-cp-ep-dp-pp' if not args.use_tp_pp_dp_mapping else 'tp-cp-ep-pp-dp',
encoder_tensor_model_parallel_size=args.encoder_tensor_model_parallel_size, encoder_tensor_model_parallel_size=args.encoder_tensor_model_parallel_size,
encoder_pipeline_model_parallel_size=args.encoder_pipeline_model_parallel_size, encoder_pipeline_model_parallel_size=args.encoder_pipeline_model_parallel_size,
get_embedding_ranks=get_embedding_ranks, get_embedding_ranks=get_embedding_ranks,
get_position_embedding_ranks=get_position_embedding_ranks, get_position_embedding_ranks=get_position_embedding_ranks,
create_gloo_process_groups=args.enable_gloo_process_groups,
) )
if args.rank == 0: if args.rank == 0:
print( print(
...@@ -336,7 +364,9 @@ def _init_autoresume(): ...@@ -336,7 +364,9 @@ def _init_autoresume():
torch.distributed.barrier() torch.distributed.barrier()
def _set_random_seed(seed_, data_parallel_random_init=False): def _set_random_seed(
seed_, data_parallel_random_init=False, te_rng_tracker=False, inference_rng_tracker=False
):
"""Set random seed for reproducability.""" """Set random seed for reproducability."""
if seed_ is not None and seed_ > 0: if seed_ is not None and seed_ > 0:
# Ensure that different pipeline MP stages get different seeds. # Ensure that different pipeline MP stages get different seeds.
...@@ -348,7 +378,9 @@ def _set_random_seed(seed_, data_parallel_random_init=False): ...@@ -348,7 +378,9 @@ def _set_random_seed(seed_, data_parallel_random_init=False):
np.random.seed(seed) np.random.seed(seed)
torch.manual_seed(seed) torch.manual_seed(seed)
if torch.cuda.device_count() > 0: if torch.cuda.device_count() > 0:
tensor_parallel.model_parallel_cuda_manual_seed(seed) tensor_parallel.model_parallel_cuda_manual_seed(
seed, te_rng_tracker, inference_rng_tracker
)
else: else:
raise ValueError("Seed ({}) should be a positive integer.".format(seed_)) raise ValueError("Seed ({}) should be a positive integer.".format(seed_))
...@@ -374,7 +406,7 @@ def set_jit_fusion_options(): ...@@ -374,7 +406,7 @@ def set_jit_fusion_options():
torch._C._jit_override_can_fuse_on_cpu(False) torch._C._jit_override_can_fuse_on_cpu(False)
torch._C._jit_override_can_fuse_on_gpu(False) torch._C._jit_override_can_fuse_on_gpu(False)
torch._C._jit_set_texpr_fuser_enabled(False) torch._C._jit_set_texpr_fuser_enabled(False)
torch._C._jit_set_nvfuser_enabled(False)#(True) torch._C._jit_set_nvfuser_enabled(True)
torch._C._debug_set_autodiff_subgraph_inlining(False) torch._C._debug_set_autodiff_subgraph_inlining(False)
else: else:
# legacy pytorch fuser # legacy pytorch fuser
...@@ -398,9 +430,7 @@ def _warmup_jit_function(): ...@@ -398,9 +430,7 @@ def _warmup_jit_function():
# Warmup fused bias+gelu # Warmup fused bias+gelu
bias = torch.rand( bias = torch.rand(
args.ffn_hidden_size // args.tensor_model_parallel_size, args.ffn_hidden_size // args.tensor_model_parallel_size, dtype=dtype, device="cuda"
dtype=dtype,
device="cuda",
) )
input = torch.rand( input = torch.rand(
( (
...@@ -437,15 +467,11 @@ def _warmup_jit_function(): ...@@ -437,15 +467,11 @@ def _warmup_jit_function():
dtype=dtype, dtype=dtype,
device="cuda", device="cuda",
) )
bias = torch.rand((args.hidden_size), dtype=dtype, device="cuda").expand_as( bias = torch.rand((args.hidden_size), dtype=dtype, device="cuda").expand_as(residual)
residual
)
dropout_rate = 0.1 dropout_rate = 0.1
# Warmup JIT fusions with the input grad_enable state of both forward # Warmup JIT fusions with the input grad_enable state of both forward
# prop and recomputation # prop and recomputation
for input_grad, bias_grad, residual_grad in zip( for input_grad, bias_grad, residual_grad in zip([False, True], [True, True], [True, True]):
[False, True], [True, True], [True, True]
):
input.requires_grad = input_grad input.requires_grad = input_grad
bias.requires_grad = bias_grad bias.requires_grad = bias_grad
residual.requires_grad = residual_grad residual.requires_grad = residual_grad
...@@ -456,7 +482,7 @@ def _warmup_jit_function(): ...@@ -456,7 +482,7 @@ def _warmup_jit_function():
def setup_logging() -> None: def setup_logging() -> None:
""" Sets the default logging level based on cmdline args and env vars. """Sets the default logging level based on cmdline args and env vars.
Precedence: Precedence:
1. Command line argument `--logging-level` 1. Command line argument `--logging-level`
......
File mode changed from 100755 to 100644
...@@ -3,6 +3,8 @@ import time, os ...@@ -3,6 +3,8 @@ import time, os
from .global_vars import get_one_logger, get_args from .global_vars import get_one_logger, get_args
_one_logger_utils_version = "1.0.0-mlm"
def get_timestamp_in_ms(): def get_timestamp_in_ms():
"""Helper function to get timestamp in ms """Helper function to get timestamp in ms
...@@ -86,7 +88,7 @@ def _produce_e2e_metrics(log_throughput=False, throughput=None): ...@@ -86,7 +88,7 @@ def _produce_e2e_metrics(log_throughput=False, throughput=None):
# Unpack and assign local vars # Unpack and assign local vars
base_metrics = one_logger.store_get('get_e2e_base_metrics')() base_metrics = one_logger.store_get('get_e2e_base_metrics')()
(iteration, train_duration, eval_duration, eval_iterations, (iteration, train_duration, eval_duration, eval_iterations,
total_flops, num_floating_point_operations_so_far, total_flops_since_current_train_start, num_floating_point_operations_so_far,
consumed_train_samples, world_size, seq_length) = base_metrics.values() consumed_train_samples, world_size, seq_length) = base_metrics.values()
iteration_start = one_logger.store_get('iteration_start') iteration_start = one_logger.store_get('iteration_start')
...@@ -125,7 +127,7 @@ def _produce_e2e_metrics(log_throughput=False, throughput=None): ...@@ -125,7 +127,7 @@ def _produce_e2e_metrics(log_throughput=False, throughput=None):
if log_throughput: if log_throughput:
if train_duration: if train_duration:
train_throughput_per_gpu = total_flops / (train_duration * 10**12 * world_size) train_throughput_per_gpu = total_flops_since_current_train_start / (train_duration * 10**12 * world_size)
else: else:
train_throughput_per_gpu = 0.0 train_throughput_per_gpu = 0.0
...@@ -136,7 +138,7 @@ def _produce_e2e_metrics(log_throughput=False, throughput=None): ...@@ -136,7 +138,7 @@ def _produce_e2e_metrics(log_throughput=False, throughput=None):
throughput_metrics = { throughput_metrics = {
'train_tflop_end': float(num_floating_point_operations_so_far) / (10**12), 'train_tflop_end': float(num_floating_point_operations_so_far) / (10**12),
'train_tflop': float(total_flops) / (10**12), 'train_tflop': float(total_flops_since_current_train_start) / (10**12),
'train_throughput_per_gpu': train_throughput_per_gpu, 'train_throughput_per_gpu': train_throughput_per_gpu,
'train_throughput_per_gpu_max': train_throughput_per_gpu_max, 'train_throughput_per_gpu_max': train_throughput_per_gpu_max,
} }
...@@ -234,7 +236,7 @@ def on_save_checkpoint_start(async_save): ...@@ -234,7 +236,7 @@ def on_save_checkpoint_start(async_save):
# Unpack and assign local vars # Unpack and assign local vars
base_metrics = one_logger.store_get('get_e2e_base_metrics')() base_metrics = one_logger.store_get('get_e2e_base_metrics')()
(iteration, train_duration, eval_duration, eval_iterations, (iteration, train_duration, eval_duration, eval_iterations,
total_flops, num_floating_point_operations_so_far, total_flops_since_current_train_start, num_floating_point_operations_so_far,
consumed_train_samples, world_size, seq_length) = base_metrics.values() consumed_train_samples, world_size, seq_length) = base_metrics.values()
save_checkpoint_count = one_logger.store_get('save_checkpoint_count') + 1 save_checkpoint_count = one_logger.store_get('save_checkpoint_count') + 1
...@@ -289,6 +291,7 @@ def on_pretrain_start(): ...@@ -289,6 +291,7 @@ def on_pretrain_start():
'app_run_type': 'training', 'app_run_type': 'training',
'summary_data_schema_version': '1.0.0', 'summary_data_schema_version': '1.0.0',
'app_metrics_feature_tags': 'full', 'app_metrics_feature_tags': 'full',
'one_logger_utils_version': _one_logger_utils_version,
}) })
def track_config_flags(train_iters, skip_train, do_train, do_valid, do_test, def track_config_flags(train_iters, skip_train, do_train, do_valid, do_test,
......
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
...@@ -39,6 +39,11 @@ nvlm_yi_34b_template = "{{- bos_token }}{% for message in messages %}{{'<|im_sta ...@@ -39,6 +39,11 @@ nvlm_yi_34b_template = "{{- bos_token }}{% for message in messages %}{{'<|im_sta
qwen2p0_custom_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" qwen2p0_custom_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
# Note: this is the same template as https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct/blob/main/tokenizer_config.json#L2053
# but we removed the forced system message.
llama3p1_chat_template = """{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- set date_string = \"26 Jul 2024\" %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = none %}\n{%- endif %}\n\n{%- if system_message is not none %}{#- System message + builtin tools #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if builtin_tools is defined or tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{%- if builtin_tools is defined %}\n {{- \"Tools: \" + builtin_tools | reject('equalto', 'code_interpreter') | join(\", \") + \"\\n\\n\"}}\n{%- endif %}{%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{%-endif %}{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {%- if builtin_tools is defined and tool_call.name in builtin_tools %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- \"<|python_tag|>\" + tool_call.name + \".call(\" }}\n {%- for arg_name, arg_val in tool_call.arguments | items %}\n {{- arg_name + '=\"' + arg_val + '\"' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \")\" }}\n {%- else %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {%- endif %}\n {%- if builtin_tools is defined %}\n {#- This means we're in ipython mode #}\n {{- \"<|eom_id|>\" }}\n {%- else %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n"""
@dataclass @dataclass
class PromptConfig: class PromptConfig:
...@@ -104,6 +109,16 @@ class MultimodalTokenizer(MegatronTokenizer): ...@@ -104,6 +109,16 @@ class MultimodalTokenizer(MegatronTokenizer):
has_bos=True, has_bos=True,
has_system_role=True, has_system_role=True,
) )
elif prompt_format in ("llama3p1", "llama3p2"):
# "<|start_header_id|>assistant<|end_header|>\n\n" is the prefix for assistant messages.
# That occupies 4 tokens and can be masked in the target.
self._prompt_config = PromptConfig(
assistant_prefix_len=4,
pad_token_id=tokenizer.convert_tokens_to_ids("<|finetune_right_pad_id|>"),
custom_chat_template=llama3p1_chat_template,
has_bos=True,
has_system_role=True,
)
elif prompt_format == "nvlm-yi-34b": elif prompt_format == "nvlm-yi-34b":
self._prompt_config = PromptConfig( self._prompt_config = PromptConfig(
assistant_prefix_len=4, assistant_prefix_len=4,
...@@ -121,7 +136,7 @@ class MultimodalTokenizer(MegatronTokenizer): ...@@ -121,7 +136,7 @@ class MultimodalTokenizer(MegatronTokenizer):
has_bos=False, has_bos=False,
has_system_role=True, has_system_role=True,
) )
elif prompt_format == "qwen2p0": elif prompt_format in ("qwen2p0", "qwen2p5"):
# "<|im_start|>assistant\n" is the prefix for assistant messages # "<|im_start|>assistant\n" is the prefix for assistant messages
self._prompt_config = PromptConfig( self._prompt_config = PromptConfig(
assistant_prefix_len=3, assistant_prefix_len=3,
......
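The assistant_prefix_len=4 used for the llama3p1/llama3p2 formats above assumes that the assistant header "<|start_header_id|>assistant<|end_header_id|>\n\n" tokenizes to exactly four tokens. A hedged spot check of that assumption (the model id, gated-repo access, and the transformers dependency belong to the sketch, not to the code above):

# Hypothetical spot check of the 4-token assistant prefix noted above.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
prefix = "<|start_header_id|>assistant<|end_header_id|>\n\n"
prefix_ids = tok(prefix, add_special_tokens=False)["input_ids"]
print(len(prefix_ids))  # expected: 4, matching assistant_prefix_len above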
...@@ -15,7 +15,6 @@ from megatron.core.datasets.megatron_tokenizer import MegatronTokenizer ...@@ -15,7 +15,6 @@ from megatron.core.datasets.megatron_tokenizer import MegatronTokenizer
from .bert_tokenization import FullTokenizer as FullBertTokenizer from .bert_tokenization import FullTokenizer as FullBertTokenizer
from .gpt2_tokenization import GPT2Tokenizer from .gpt2_tokenization import GPT2Tokenizer
from megatron.training.tokenizer.multimodal_tokenizer import MultimodalTokenizer from megatron.training.tokenizer.multimodal_tokenizer import MultimodalTokenizer
from transformers import Qwen2Tokenizer
def build_tokenizer(args, **kwargs): def build_tokenizer(args, **kwargs):
...@@ -51,11 +50,6 @@ def build_tokenizer(args, **kwargs): ...@@ -51,11 +50,6 @@ def build_tokenizer(args, **kwargs):
elif args.tokenizer_type == 'Llama2Tokenizer': elif args.tokenizer_type == 'Llama2Tokenizer':
assert args.tokenizer_model is not None assert args.tokenizer_model is not None
tokenizer = _Llama2Tokenizer(args.tokenizer_model) tokenizer = _Llama2Tokenizer(args.tokenizer_model)
elif args.tokenizer_type == 'Llama3Tokenizer':
assert args.tokenizer_model is not None
tokenizer = _Llama3Tokenizer(args.tokenizer_model)
elif args.tokenizer_type == 'QwenTokenizer':
tokenizer = _Qwen2Tokenizer(args.vocab_file, args.merge_file)
elif args.tokenizer_type == 'TikTokenizer': elif args.tokenizer_type == 'TikTokenizer':
assert args.tokenizer_model is not None assert args.tokenizer_model is not None
assert args.tiktoken_pattern is not None assert args.tiktoken_pattern is not None
...@@ -612,96 +606,6 @@ class _Llama2Tokenizer(_SentencePieceTokenizer): ...@@ -612,96 +606,6 @@ class _Llama2Tokenizer(_SentencePieceTokenizer):
return None return None
class _Llama3Tokenizer(MegatronTokenizer):
"""tiktokenTokenizer-Megatron llama3 改写"""
# https://github.com/meta-llama/llama3/blob/main/llama/tokenizer.py
def __init__(self, model_file):
super().__init__(model_file)
from pathlib import Path
import tiktoken
from tiktoken.load import load_tiktoken_bpe
tokenizer_path=model_file
special_tokens = [
"<|begin_of_text|>",
"<|end_of_text|>",
"<|reserved_special_token_0|>",
"<|reserved_special_token_1|>",
"<|reserved_special_token_2|>",
"<|reserved_special_token_3|>",
"<|start_header_id|>",
"<|end_header_id|>",
"<|reserved_special_token_4|>",
"<|eot_id|>", # end of turn
] + [f"<|reserved_special_token_{i}|>" for i in range (5, 256 - 5)]
mergeable_ranks = load_tiktoken_bpe(tokenizer_path)
self.tokenizer = tiktoken.Encoding(tokenizer_path,
pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+",
mergeable_ranks=mergeable_ranks,
special_tokens={token: len (mergeable_ranks) + i for i, token in enumerate (special_tokens)},
)
self.eod_id = self.tokenizer.encode("<|end_of_text|>", allowed_special="all")[0]
@property
def vocab_size(self):
return self.tokenizer.n_vocab
@property
def vocab(self):
return self.tokenizer.encode
@property
def inv_vocab(self):
return self.tokenizer.encode
def tokenize(self, text):
return self.tokenizer.encode(text)
def detokenize(self, token_ids):
return self.tokenizer.decode(token_ids)
@property
def eod(self):
return self.eod_id
class _Qwen2Tokenizer(MegatronTokenizer):
def __init__(self, vocab_file, merge_file,extra_vocab_size=0):
super().__init__(vocab_file, merge_file)
self.tokenizer = Qwen2Tokenizer(vocab_file, merge_file)
self.extra_vocab_size = extra_vocab_size
self.tokenizer.add_special_tokens(special_tokens_dict=dict(pad_token="<|extra_0|>"))
@property
def vocab_size(self):
return len(self.tokenizer.encoder) + self.extra_vocab_size
@property
def vocab(self):
return self.tokenizer.encoder
@property
def inv_vocab(self):
return self.tokenizer.decoder
def tokenize(self, text):
return self.tokenizer.encode(text)
def detokenize(self, token_ids):
return self.tokenizer.decode(token_ids)
@property
def eod(self):
return self.tokenizer.eos_token_id
@property
def eos_token(self):
return self.tokenizer.eos_token
@property
def pad_token_id(self):
return self.tokenizer.pad_token_id
def reload_mergeable_ranks(path: str, max_vocab: Optional[int] = None) -> Dict[bytes, int]: def reload_mergeable_ranks(path: str, max_vocab: Optional[int] = None) -> Dict[bytes, int]:
""" """
Reload our tokenizer JSON file and convert it to Tiktoken format. Reload our tokenizer JSON file and convert it to Tiktoken format.
...@@ -851,7 +755,21 @@ class CustomTikTokenizer(MegatronTokenizer): ...@@ -851,7 +755,21 @@ class CustomTikTokenizer(MegatronTokenizer):
return self._model.decode(tokens) return self._model.decode(tokens)
def offsets(self, ids: list[int], text: str) -> list[int]: def offsets(self, ids: list[int], text: str) -> list[int]:
return self._model.decode_with_offsets(ids)[1] try:
return self._model.decode_with_offsets(ids)[1]
except UnicodeDecodeError:
# Tiktoken has an unnecessary check that raises UnicodeDecodeError
# from `text = b"".join(token_bytes).decode("utf-8", errors="strict")`
# which is not needed for our use case. So we re-implement it, without
# the check.
token_bytes = self._model.decode_tokens_bytes(ids)
text_len = 0
offsets = []
for token in token_bytes:
offsets.append(max(0, text_len - (0x80 <= token[0] < 0xC0)))
text_len += sum(1 for c in token if not 0x80 <= c < 0xC0)
return offsets
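The fallback above reimplements tiktoken's offset computation without the strict UTF-8 decode: it counts only lead bytes (anything outside 0x80-0xBF), so a token that starts mid-codepoint is attributed to the character opened by the previous token. A standalone illustration of the same byte arithmetic:

# Standalone illustration of the continuation-byte counting used above.
def char_offsets(token_bytes_list):
    text_len = 0
    offsets = []
    for token in token_bytes_list:
        # If the token starts on a continuation byte, point its offset back at
        # the character that began in the previous token.
        offsets.append(max(0, text_len - (0x80 <= token[0] < 0xC0)))
        # Count only lead bytes: each non-continuation byte starts a character.
        text_len += sum(1 for c in token if not 0x80 <= c < 0xC0)
    return offsets

# "héllo" split into byte-level tokens; the second token starts with the
# continuation byte of 'é', so its offset points back at character 1.
tokens = ["h".encode(), "é".encode()[:1], "é".encode()[1:] + "llo".encode()]
print(char_offsets(tokens))  # [0, 1, 1]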
@property @property
def vocab_size(self) -> int: def vocab_size(self) -> int:
......
...@@ -10,6 +10,9 @@ import logging ...@@ -10,6 +10,9 @@ import logging
import math import math
import os import os
import sys import sys
from typing import List
import torch.distributed
from .log_handler import CustomHandler from .log_handler import CustomHandler
# Make default logging level INFO, but filter out all log messages not from MCore. # Make default logging level INFO, but filter out all log messages not from MCore.
logging.basicConfig(handlers=[CustomHandler()], level=logging.INFO) logging.basicConfig(handlers=[CustomHandler()], level=logging.INFO)
...@@ -24,14 +27,15 @@ from megatron.core.utils import ( ...@@ -24,14 +27,15 @@ from megatron.core.utils import (
check_param_hashes_across_dp_replicas, check_param_hashes_across_dp_replicas,
get_model_config, get_model_config,
StragglerDetector, StragglerDetector,
is_float8tensor,
) )
from megatron.core.fp8_utils import is_float8tensor
from megatron.training.checkpointing import load_checkpoint from megatron.training.checkpointing import load_checkpoint
from megatron.training.checkpointing import save_checkpoint from megatron.training.checkpointing import save_checkpoint
from megatron.training.checkpointing import checkpoint_exists from megatron.training.checkpointing import checkpoint_exists
from megatron.legacy.model import Float16Module from megatron.legacy.model import Float16Module
from megatron.core.distributed import DistributedDataParallelConfig from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.distributed import DistributedDataParallel as DDP from megatron.core.distributed import DistributedDataParallel as DDP
from megatron.core.distributed.custom_fsdp import FullyShardedDataParallel as custom_FSDP
try: try:
from megatron.core.distributed import TorchFullyShardedDataParallel as torch_FSDP from megatron.core.distributed import TorchFullyShardedDataParallel as torch_FSDP
...@@ -51,6 +55,10 @@ from megatron.core.rerun_state_machine import ( ...@@ -51,6 +55,10 @@ from megatron.core.rerun_state_machine import (
from megatron.training.initialize import initialize_megatron from megatron.training.initialize import initialize_megatron
from megatron.training.initialize import write_args_to_tensorboard from megatron.training.initialize import write_args_to_tensorboard
from megatron.training.initialize import set_jit_fusion_options from megatron.training.initialize import set_jit_fusion_options
from megatron.training.utils import (
get_batch_on_this_cp_rank,
get_batch_on_this_tp_rank,
)
from megatron.legacy.data.data_samplers import build_pretraining_data_loader from megatron.legacy.data.data_samplers import build_pretraining_data_loader
from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler
from megatron.core.transformer.moe import upcycling_utils from megatron.core.transformer.moe import upcycling_utils
...@@ -69,14 +77,16 @@ from megatron.core.num_microbatches_calculator import ( ...@@ -69,14 +77,16 @@ from megatron.core.num_microbatches_calculator import (
from .async_utils import maybe_finalize_async_save from .async_utils import maybe_finalize_async_save
from .utils import ( from .utils import (
append_to_progress_log,
calc_params_l2_norm, calc_params_l2_norm,
check_adlr_autoresume_termination, check_adlr_autoresume_termination,
logical_and_across_model_parallel_group,
reduce_max_stat_across_model_parallel_group,
is_last_rank, is_last_rank,
print_rank_0, print_rank_0,
print_rank_last, print_rank_last,
report_memory, report_memory,
unwrap_model, unwrap_model,
append_to_progress_log,
update_use_dist_ckpt, update_use_dist_ckpt,
) )
from .global_vars import ( from .global_vars import (
...@@ -86,7 +96,8 @@ from .global_vars import ( ...@@ -86,7 +96,8 @@ from .global_vars import (
get_timers, get_timers,
get_tensorboard_writer, get_tensorboard_writer,
get_wandb_writer, get_wandb_writer,
get_one_logger) get_one_logger,
)
from . import one_logger_utils from . import one_logger_utils
from . import ft_integration from . import ft_integration
...@@ -124,6 +135,10 @@ def num_floating_point_operations(args, batch_size): ...@@ -124,6 +135,10 @@ def num_floating_point_operations(args, batch_size):
if args.moe_shared_expert_intermediate_size is None if args.moe_shared_expert_intermediate_size is None
else args.moe_shared_expert_intermediate_size else args.moe_shared_expert_intermediate_size
) )
if args.num_experts is None:
ffn_hidden_size = args.ffn_hidden_size
else:
ffn_hidden_size = args.moe_ffn_hidden_size
# The 12x term below comes from the following factors; for more details, see # The 12x term below comes from the following factors; for more details, see
# "APPENDIX: FLOATING-POINT OPERATIONS" in https://arxiv.org/abs/2104.04473. # "APPENDIX: FLOATING-POINT OPERATIONS" in https://arxiv.org/abs/2104.04473.
...@@ -135,13 +150,6 @@ def num_floating_point_operations(args, batch_size): ...@@ -135,13 +150,6 @@ def num_floating_point_operations(args, batch_size):
# - 2x: A GEMM of a m*n tensor with a n*k tensor requires 2mnk floating-point operations. # - 2x: A GEMM of a m*n tensor with a n*k tensor requires 2mnk floating-point operations.
expansion_factor = 3 * 2 * 2 expansion_factor = 3 * 2 * 2
# print(f"batch_size: {batch_size}, \
# query_projection_to_hidden_size_ratio: {query_projection_to_hidden_size_ratio}, \
# num_experts_routed_to: {num_experts_routed_to}, \
# gated_linear_multiplier: {gated_linear_multiplier}, \
# shared_expert_ffn_hidden_size: {shared_expert_ffn_hidden_size}, \
# gated_linear_multiplier: {gated_linear_multiplier}, \
# ")
return ( return (
expansion_factor expansion_factor
* batch_size * batch_size
...@@ -160,7 +168,7 @@ def num_floating_point_operations(args, batch_size): ...@@ -160,7 +168,7 @@ def num_floating_point_operations(args, batch_size):
) )
# MLP. # MLP.
+ ( + (
(args.ffn_hidden_size / args.hidden_size) (ffn_hidden_size / args.hidden_size)
* num_experts_routed_to * num_experts_routed_to
* gated_linear_multiplier * gated_linear_multiplier
) )
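The change above makes the MLP term use moe_ffn_hidden_size for MoE models and ffn_hidden_size otherwise. A hedged sketch of just that factor (the outer batch/sequence/hidden/layer product is elided by this hunk, so only the relative MLP factor is shown):

# Hedged sketch of the MoE-aware MLP factor selected above.
def mlp_flops_factor(hidden_size, ffn_hidden_size, moe_ffn_hidden_size,
                     num_experts, num_experts_routed_to, gated_linear_multiplier):
    # Dense models use ffn_hidden_size; MoE models use the per-expert
    # moe_ffn_hidden_size scaled by the number of experts each token visits.
    ffn = ffn_hidden_size if num_experts is None else moe_ffn_hidden_size
    return (ffn / hidden_size) * num_experts_routed_to * gated_linear_multiplier

# Example: a dense SwiGLU model with ffn = 4 * hidden and a 3/2 gated-linear
# multiplier contributes a 4 * 1 * 1.5 = 6.0x factor to the MLP GEMMs.
print(mlp_flops_factor(4096, 16384, None, None, 1, 1.5))  # 6.0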
...@@ -219,7 +227,7 @@ def get_start_time_from_progress_log(): ...@@ -219,7 +227,7 @@ def get_start_time_from_progress_log():
def preprocess_common_state_dict(common_state_dict): def preprocess_common_state_dict(common_state_dict):
import copy import copy
# Convert args key of type namespace to dictionary # Convert args key of type namespace to dictionary
preprocessed_common_state_dict = copy.deepcopy(common_state_dict) preprocessed_common_state_dict = copy.deepcopy(common_state_dict)
preprocessed_common_state_dict['args'] = vars(preprocessed_common_state_dict['args']) preprocessed_common_state_dict['args'] = vars(preprocessed_common_state_dict['args'])
# Remove rank and local rank from state dict if it exists, since they are expected to be different # Remove rank and local rank from state dict if it exists, since they are expected to be different
...@@ -287,6 +295,12 @@ def pretrain( ...@@ -287,6 +295,12 @@ def pretrain(
if args.log_progress: if args.log_progress:
append_to_progress_log("Starting job") append_to_progress_log("Starting job")
# Initialize fault tolerance
# NOTE: ft_integration functions other than `setup` are no-op if the FT is not initialized
if args.enable_ft_package:
ft_integration.setup(args)
ft_integration.maybe_setup_simulated_fault()
# Set pytorch JIT layer fusion options and warmup JIT functions. # Set pytorch JIT layer fusion options and warmup JIT functions.
set_jit_fusion_options() set_jit_fusion_options()
...@@ -315,11 +329,29 @@ def pretrain( ...@@ -315,11 +329,29 @@ def pretrain(
# Context used for persisting some state between checkpoint saves. # Context used for persisting some state between checkpoint saves.
if args.non_persistent_ckpt_type == 'local': if args.non_persistent_ckpt_type == 'local':
raise RuntimeError('LocalCheckpointManagers are not yet integrated') try:
checkpointing_context = { from nvidia_resiliency_ext.checkpointing.local.ckpt_managers.local_manager import \
'local_checkpoint_manager': BasicLocalCheckpointManager( LocalCheckpointManager
args.non_persistent_local_ckpt_dir from nvidia_resiliency_ext.checkpointing.local.replication.group_utils import \
parse_group_sequence, GroupWrapper
from nvidia_resiliency_ext.checkpointing.local.replication.strategies import \
CliqueReplicationStrategy
except ModuleNotFoundError:
raise RuntimeError("The 'nvidia_resiliency_ext' module is required for local "
"checkpointing but was not found. Please ensure it is installed.")
if args.replication:
repl_strategy = CliqueReplicationStrategy.from_replication_params(
args.replication_jump,
args.replication_factor
) )
else:
repl_strategy = None
checkpointing_context = {
'local_checkpoint_manager': LocalCheckpointManager(args.non_persistent_local_ckpt_dir,
repl_strategy=repl_strategy
)
} }
else: else:
checkpointing_context = {} checkpointing_context = {}
...@@ -364,11 +396,6 @@ def pretrain( ...@@ -364,11 +396,6 @@ def pretrain(
args.do_valid, args.do_test, args.dataloader_type, args.do_valid, args.do_test, args.dataloader_type,
args.retro_project_dir, args.retro_cyclic_train_iters) args.retro_project_dir, args.retro_cyclic_train_iters)
if args.enable_ft_package and ft_integration.get_rank_monitor_client() is not None:
ft_integration.get_rank_monitor_client().init_workload_monitoring()
ft_timeouts = ft_integration.get_rank_monitor_client().timeouts
print_rank_0(f"Fault tolerance client initialized. Timeouts: {ft_timeouts}")
# Print setup timing. # Print setup timing.
print_rank_0('done with setup ...') print_rank_0('done with setup ...')
timers.log(['model-and-optimizer-setup', timers.log(['model-and-optimizer-setup',
...@@ -400,8 +427,7 @@ def pretrain( ...@@ -400,8 +427,7 @@ def pretrain(
save_checkpoint(iteration, model, optimizer, opt_param_scheduler, save_checkpoint(iteration, model, optimizer, opt_param_scheduler,
num_floating_point_operations_so_far, checkpointing_context, num_floating_point_operations_so_far, checkpointing_context,
train_data_iterator=train_data_iterator, train_data_iterator=train_data_iterator,
ft_client=ft_integration.get_rank_monitor_client( preprocess_common_state_dict_fn=preprocess_common_state_dict)
ft_integration.StateMachineActions.SAVE_CHECKPOINT), preprocess_common_state_dict_fn=preprocess_common_state_dict)
one_logger and one_logger.log_metrics({ one_logger and one_logger.log_metrics({
'app_train_loop_finish_time': one_logger_utils.get_timestamp_in_ms() 'app_train_loop_finish_time': one_logger_utils.get_timestamp_in_ms()
...@@ -431,11 +457,16 @@ def pretrain( ...@@ -431,11 +457,16 @@ def pretrain(
wandb_writer = get_wandb_writer() wandb_writer = get_wandb_writer()
if wandb_writer: if wandb_writer:
wandb_writer.finish() wandb_writer.finish()
maybe_finalize_async_save(blocking=True)
ft_integration.on_checkpointing_start()
maybe_finalize_async_save(blocking=True, terminate=True)
ft_integration.on_checkpointing_end(is_async_finalization=True)
one_logger and one_logger.log_metrics({ one_logger and one_logger.log_metrics({
'app_finish_time': one_logger_utils.get_timestamp_in_ms() 'app_finish_time': one_logger_utils.get_timestamp_in_ms()
}) })
ft_integration.shutdown()
one_logger_utils.finish() one_logger_utils.finish()
...@@ -476,47 +507,54 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap ...@@ -476,47 +507,54 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
args.model_type = model_type args.model_type = model_type
# Build model. # Build model.
if mpu.get_pipeline_model_parallel_world_size() > 1 and \ def build_model():
args.virtual_pipeline_model_parallel_size is not None: if mpu.get_pipeline_model_parallel_world_size() > 1 and \
assert model_type != ModelType.encoder_and_decoder, \ args.virtual_pipeline_model_parallel_size is not None:
"Interleaved schedule not supported for model with both encoder and decoder" assert model_type != ModelType.encoder_and_decoder, \
model = [] "Interleaved schedule not supported for model with both encoder and decoder"
for i in range(args.virtual_pipeline_model_parallel_size): model = []
mpu.set_virtual_pipeline_model_parallel_rank(i) for i in range(args.virtual_pipeline_model_parallel_size):
# Set pre_process and post_process only after virtual rank is set. mpu.set_virtual_pipeline_model_parallel_rank(i)
# Set pre_process and post_process only after virtual rank is set.
pre_process = mpu.is_pipeline_first_stage()
post_process = mpu.is_pipeline_last_stage()
this_model = model_provider_func(
pre_process=pre_process,
post_process=post_process
)
this_model.model_type = model_type
model.append(this_model)
else:
pre_process = mpu.is_pipeline_first_stage() pre_process = mpu.is_pipeline_first_stage()
post_process = mpu.is_pipeline_last_stage() post_process = mpu.is_pipeline_last_stage()
this_model = model_provider_func( add_encoder = True
pre_process=pre_process, add_decoder = True
post_process=post_process if model_type == ModelType.encoder_and_decoder:
) if mpu.get_pipeline_model_parallel_world_size() > 1:
this_model.model_type = model_type rank = mpu.get_pipeline_model_parallel_rank()
model.append(this_model) first_decoder_rank = args.encoder_pipeline_model_parallel_size
world_size = mpu.get_pipeline_model_parallel_world_size()
pre_process = rank == 0 or rank == first_decoder_rank
post_process = (rank == (first_decoder_rank - 1)) or (rank == (world_size - 1))
add_encoder = mpu.is_inside_encoder(rank)
add_decoder = mpu.is_inside_decoder(rank)
model = model_provider_func(
pre_process=pre_process,
post_process=post_process,
add_encoder=add_encoder,
add_decoder=add_decoder)
else:
model = model_provider_func(
pre_process=pre_process,
post_process=post_process
)
model.model_type = model_type
return model
if args.init_model_with_meta_device:
with torch.device('meta'):
model = build_model()
else: else:
pre_process = mpu.is_pipeline_first_stage() model = build_model()
post_process = mpu.is_pipeline_last_stage()
add_encoder = True
add_decoder = True
if model_type == ModelType.encoder_and_decoder:
if mpu.get_pipeline_model_parallel_world_size() > 1:
rank = mpu.get_pipeline_model_parallel_rank()
first_decoder_rank = args.encoder_pipeline_model_parallel_size
world_size = mpu.get_pipeline_model_parallel_world_size()
pre_process = rank == 0 or rank == first_decoder_rank
post_process = (rank == (first_decoder_rank - 1)) or (rank == (world_size - 1))
add_encoder = mpu.is_inside_encoder(rank)
add_decoder = mpu.is_inside_decoder(rank)
model = model_provider_func(
pre_process=pre_process,
post_process=post_process,
add_encoder=add_encoder,
add_decoder=add_decoder)
else:
model = model_provider_func(
pre_process=pre_process,
post_process=post_process
)
model.model_type = model_type
if not isinstance(model, list): if not isinstance(model, list):
model = [model] model = [model]
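A worked example of the pre/post-process flags computed above for an encoder-decoder model, assuming encoder_pipeline_model_parallel_size = 1 and four pipeline stages in total:

# Worked example of the stage-flag logic above (assumed sizes: 1 encoder
# pipeline stage, 4 pipeline stages overall).
first_decoder_rank = 1   # args.encoder_pipeline_model_parallel_size
world_size = 4           # pipeline-model-parallel world size
for rank in range(world_size):
    pre_process = rank == 0 or rank == first_decoder_rank
    post_process = (rank == first_decoder_rank - 1) or (rank == world_size - 1)
    print(rank, pre_process, post_process)
# 0 True True   <- the single encoder stage is both first and last encoder stage
# 1 True False  <- first decoder stage
# 2 False False
# 3 False True  <- last decoder stage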
...@@ -530,17 +568,23 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap ...@@ -530,17 +568,23 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param) tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param)
# Print number of parameters. # Print number of parameters.
num_parameters = sum(
[sum([p.nelement() for p in model_module.parameters()])
for model_module in model]
)
if mpu.get_data_parallel_rank() == 0: if mpu.get_data_parallel_rank() == 0:
print(' > number of parameters on (tensor, pipeline) ' print(' > number of parameters on (tensor, pipeline) '
'model parallel rank ({}, {}): {}'.format( 'model parallel rank ({}, {}): {}'.format(
mpu.get_tensor_model_parallel_rank(), mpu.get_tensor_model_parallel_rank(),
mpu.get_pipeline_model_parallel_rank(), mpu.get_pipeline_model_parallel_rank(),
sum([sum([p.nelement() for p in model_module.parameters()]) num_parameters), flush=True)
for model_module in model])), flush=True)
# GPU allocation. # GPU allocation.
for model_module in model: # For FSDP2, we don't allocate GPU memory here. We allocate GPU memory
model_module.cuda(torch.cuda.current_device()) # in the fully_shard function of FSDP2 instead.
if not (args.use_torch_fsdp2 and args.use_cpu_initialization) and not args.init_model_with_meta_device:
for model_module in model:
model_module.cuda(torch.cuda.current_device())
# Fp16 conversion. # Fp16 conversion.
if args.fp16 or args.bf16: if args.fp16 or args.bf16:
...@@ -562,9 +606,11 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap ...@@ -562,9 +606,11 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
fp8_meta.amax_history[0][fp8_meta_index] = 0 fp8_meta.amax_history[0][fp8_meta_index] = 0
if wrap_with_ddp: if wrap_with_ddp:
if getattr(args, "use_torch_fsdp2", False): if args.use_torch_fsdp2:
assert HAVE_FSDP2, "Torch FSDP2 requires torch>=2.4.0" assert HAVE_FSDP2, "Torch FSDP2 requires torch>=2.4.0"
DP = torch_FSDP DP = torch_FSDP
elif args.use_custom_fsdp:
DP = custom_FSDP
else: else:
DP = DDP DP = DDP
...@@ -576,17 +622,42 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap ...@@ -576,17 +622,42 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
kwargs[f.name] = getattr(args, f.name) kwargs[f.name] = getattr(args, f.name)
kwargs['grad_reduce_in_fp32'] = args.accumulate_allreduce_grads_in_fp32 kwargs['grad_reduce_in_fp32'] = args.accumulate_allreduce_grads_in_fp32
kwargs['check_for_nan_in_grad'] = args.check_for_nan_in_loss_and_grad kwargs['check_for_nan_in_grad'] = args.check_for_nan_in_loss_and_grad
kwargs['bucket_size'] = args.ddp_bucket_size kwargs['check_for_large_grads'] = args.check_for_large_grads
if args.ddp_num_buckets is not None:
assert args.ddp_bucket_size is None, \
"Cannot specify both --ddp-num-buckets and --ddp-bucket-size"
assert args.ddp_num_buckets > 0, \
"--ddp-num-buckets must be greater than 0"
kwargs['bucket_size'] = num_parameters // args.ddp_num_buckets
else:
kwargs['bucket_size'] = args.ddp_bucket_size
kwargs['pad_buckets_for_high_nccl_busbw'] = args.ddp_pad_buckets_for_high_nccl_busbw
kwargs['average_in_collective'] = args.ddp_average_in_collective kwargs['average_in_collective'] = args.ddp_average_in_collective
if args.use_custom_fsdp and args.use_precision_aware_optimizer:
kwargs["preserve_fp32_weights"] = False
ddp_config = DistributedDataParallelConfig(**kwargs) ddp_config = DistributedDataParallelConfig(**kwargs)
overlap_param_gather_with_optimizer_step = getattr(args, 'overlap_param_gather_with_optimizer_step', False) if not getattr(args, "use_torch_fsdp2", False):
# In the custom FSDP and DDP use path, we need to initialize the bucket size.
# If bucket_size is not provided as an input, use sane default.
# If using very large dp_sizes, make buckets larger to ensure that chunks used in NCCL
# ring-reduce implementations are large enough to remain bandwidth-bound rather than
# latency-bound.
if ddp_config.bucket_size is None:
ddp_config.bucket_size = max(
40000000, 1000000 * mpu.get_data_parallel_world_size(with_context_parallel=True)
)
# Set bucket_size to infinity if overlap_grad_reduce is False.
if not ddp_config.overlap_grad_reduce:
ddp_config.bucket_size = None
model = [DP(config=config, model = [DP(config=config,
ddp_config=ddp_config, ddp_config=ddp_config,
module=model_chunk, module=model_chunk,
# Turn off bucketing for model_chunk 2 onwards, since communication for these # Turn off bucketing for model_chunk 2 onwards, since communication for these
# model chunks is overlapped with compute anyway. # model chunks is overlapped with compute anyway.
disable_bucketing=(model_chunk_idx > 0) or overlap_param_gather_with_optimizer_step) disable_bucketing=(model_chunk_idx > 0) or args.overlap_param_gather_with_optimizer_step)
for (model_chunk_idx, model_chunk) in enumerate(model)] for (model_chunk_idx, model_chunk) in enumerate(model)]
# Broadcast params from data parallel src rank to other data parallel ranks. # Broadcast params from data parallel src rank to other data parallel ranks.
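The bucket sizing added above follows a simple precedence: --ddp-num-buckets (derived from the parameter count), then an explicit --ddp-bucket-size, then a data-parallel-size-scaled default, and finally a single bucket (None) when gradient-reduce overlap is disabled. A sketch of that policy as a pure function (the FSDP2 branch, which skips this logic, is ignored here):

# Sketch of the gradient-bucket sizing policy shown above.
def resolve_bucket_size(num_parameters, ddp_num_buckets, ddp_bucket_size,
                        data_parallel_world_size, overlap_grad_reduce):
    if ddp_num_buckets is not None:
        assert ddp_bucket_size is None, \
            "Cannot specify both --ddp-num-buckets and --ddp-bucket-size"
        assert ddp_num_buckets > 0
        bucket_size = num_parameters // ddp_num_buckets
    else:
        bucket_size = ddp_bucket_size
    # Default: at least 40M elements, scaled up for large DP sizes so NCCL
    # ring-reduce chunks stay bandwidth-bound rather than latency-bound.
    if bucket_size is None:
        bucket_size = max(40_000_000, 1_000_000 * data_parallel_world_size)
    # With no grad-reduce overlap, fall back to a single bucket (None).
    if not overlap_grad_reduce:
        bucket_size = None
    return bucket_size

print(resolve_bucket_size(7_000_000_000, None, None, 64, True))  # 64000000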
...@@ -674,7 +745,8 @@ def setup_model_and_optimizer(model_provider_func, ...@@ -674,7 +745,8 @@ def setup_model_and_optimizer(model_provider_func,
config = OptimizerConfig(**kwargs) config = OptimizerConfig(**kwargs)
config.timers = timers config.timers = timers
optimizer = get_megatron_optimizer(config, model, no_wd_decay_cond, optimizer = get_megatron_optimizer(config, model, no_wd_decay_cond,
scale_lr_cond, lr_mult) scale_lr_cond, lr_mult,
use_gloo_process_groups=args.enable_gloo_process_groups)
opt_param_scheduler = get_optimizer_param_scheduler(optimizer) opt_param_scheduler = get_optimizer_param_scheduler(optimizer)
if args.moe_use_upcycling: if args.moe_use_upcycling:
...@@ -713,9 +785,8 @@ def setup_model_and_optimizer(model_provider_func, ...@@ -713,9 +785,8 @@ def setup_model_and_optimizer(model_provider_func,
timers('load-checkpoint', log_level=0).start(barrier=True) timers('load-checkpoint', log_level=0).start(barrier=True)
args.iteration, args.num_floating_point_operations_so_far = load_checkpoint( args.iteration, args.num_floating_point_operations_so_far = load_checkpoint(
model, optimizer, opt_param_scheduler, model, optimizer, opt_param_scheduler, checkpointing_context=checkpointing_context,
ft_client=ft_integration.get_rank_monitor_client(), checkpointing_context=checkpointing_context, skip_load_to_model_and_opt=HAVE_FSDP2 and args.use_torch_fsdp2)
skip_load_to_model_and_opt=HAVE_FSDP2 and getattr(args, "use_torch_fsdp2", False))
timers('load-checkpoint').stop(barrier=True) timers('load-checkpoint').stop(barrier=True)
timers.log(['load-checkpoint']) timers.log(['load-checkpoint'])
one_logger and one_logger.log_metrics({ one_logger and one_logger.log_metrics({
...@@ -752,8 +823,17 @@ def setup_model_and_optimizer(model_provider_func, ...@@ -752,8 +823,17 @@ def setup_model_and_optimizer(model_provider_func,
return model, optimizer, opt_param_scheduler return model, optimizer, opt_param_scheduler
def dummy_train_step(data_iterator):
"""Single dummy training step."""
num_microbatches = get_num_microbatches()
for _ in range(num_microbatches):
# Re-use methods used in get_batch() from pretrain_{gpt, mamba}.py.
batch = get_batch_on_this_tp_rank(data_iterator)
batch = get_batch_on_this_cp_rank(batch)
def train_step(forward_step_func, data_iterator, def train_step(forward_step_func, data_iterator,
model, optimizer, opt_param_scheduler, config): model, optimizer, opt_param_scheduler, config):
"""Single training step.""" """Single training step."""
args = get_args() args = get_args()
timers = get_timers() timers = get_timers()
...@@ -785,17 +865,27 @@ def train_step(forward_step_func, data_iterator, ...@@ -785,17 +865,27 @@ def train_step(forward_step_func, data_iterator,
torch.cuda.empty_cache() torch.cuda.empty_cache()
# Vision gradients. # Vision gradients.
if getattr(args, 'vision_pretraining', False) and args.vision_pretraining_type == "dino": if args.vision_pretraining and args.vision_pretraining_type == "dino":
unwrapped_model = unwrap_model(model[0]) unwrapped_model = unwrap_model(model[0])
unwrapped_model.cancel_gradients_last_layer(args.curr_iteration) unwrapped_model.cancel_gradients_last_layer(args.curr_iteration)
# Update parameters. # Update parameters.
timers('optimizer', log_level=1).start(barrier=args.barrier_with_L1_time) timers('optimizer', log_level=1).start(barrier=args.barrier_with_L1_time)
update_successful, grad_norm, num_zeros_in_grad = optimizer.step() update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
timers('optimizer').stop() timers('optimizer').stop()
# when freezing sub-models we may have a mixture of successful and unsuccessful ranks,
# so we must gather across mp ranks
update_successful = logical_and_across_model_parallel_group(update_successful)
# grad_norm and num_zeros_in_grad will be None on ranks without trainable params,
# so we must gather across mp ranks
grad_norm = reduce_max_stat_across_model_parallel_group(grad_norm)
if args.log_num_zeros_in_grad:
num_zeros_in_grad = reduce_max_stat_across_model_parallel_group(num_zeros_in_grad)
# Vision momentum. # Vision momentum.
if getattr(args, 'vision_pretraining', False) and args.vision_pretraining_type == "dino": if args.vision_pretraining and args.vision_pretraining_type == "dino":
unwrapped_model = unwrap_model(model[0]) unwrapped_model = unwrap_model(model[0])
unwrapped_model.update_momentum(args.curr_iteration) unwrapped_model.update_momentum(args.curr_iteration)
...@@ -832,7 +922,6 @@ def train_step(forward_step_func, data_iterator, ...@@ -832,7 +922,6 @@ def train_step(forward_step_func, data_iterator,
numerator += val numerator += val
denominator += 1 denominator += 1
loss_reduced[key] = numerator / denominator loss_reduced[key] = numerator / denominator
return loss_reduced, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad return loss_reduced, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad
return {}, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad return {}, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad
...@@ -913,6 +1002,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r ...@@ -913,6 +1002,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r
total_iterations = total_loss_dict[advanced_iters_key] + \ total_iterations = total_loss_dict[advanced_iters_key] + \
total_loss_dict[skipped_iters_key] total_loss_dict[skipped_iters_key]
# learning rate will be None on ranks without trainable params, so we must gather across mp ranks
learning_rate = reduce_max_stat_across_model_parallel_group(learning_rate)
# Tensorboard values. # Tensorboard values.
# Timer requires all the ranks to call. # Timer requires all the ranks to call.
if args.log_timers_to_tensorboard and \ if args.log_timers_to_tensorboard and \
...@@ -920,22 +1011,16 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r ...@@ -920,22 +1011,16 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r
timers.write(timers_to_log, writer, iteration, timers.write(timers_to_log, writer, iteration,
normalizer=total_iterations) normalizer=total_iterations)
if writer and (iteration % args.tensorboard_log_interval == 0): if writer and (iteration % args.tensorboard_log_interval == 0):
if args.record_memory_history and is_last_rank():
snapshot = torch.cuda.memory._snapshot()
from pickle import dump
with open(args.memory_snapshot_path , 'wb') as f:
dump(snapshot, f)
if wandb_writer: if wandb_writer:
wandb_writer.log({'samples vs steps': args.consumed_train_samples}, wandb_writer.log({'samples vs steps': args.consumed_train_samples},
iteration) iteration)
writer.add_scalar('learning-rate', learning_rate, iteration) writer.add_scalar('learning-rate', learning_rate, iteration)
if args.decoupled_lr is not None:
writer.add_scalar('decoupled-learning-rate', decoupled_learning_rate, iteration)
writer.add_scalar('learning-rate vs samples', learning_rate, writer.add_scalar('learning-rate vs samples', learning_rate,
args.consumed_train_samples) args.consumed_train_samples)
if wandb_writer: if wandb_writer:
wandb_writer.log({'learning-rate': learning_rate}, iteration) wandb_writer.log({'learning-rate': learning_rate}, iteration)
if args.decoupled_lr is not None:
writer.add_scalar('decoupled-learning-rate', decoupled_learning_rate, iteration)
if args.skipped_train_samples > 0: if args.skipped_train_samples > 0:
writer.add_scalar('skipped-train-samples', args.skipped_train_samples, iteration) writer.add_scalar('skipped-train-samples', args.skipped_train_samples, iteration)
if wandb_writer: if wandb_writer:
...@@ -993,6 +1078,11 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r ...@@ -993,6 +1078,11 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r
mem_stats["allocated_bytes.all.current"], mem_stats["allocated_bytes.all.current"],
iteration, iteration,
) )
writer.add_scalar(
"mem-max-allocated-bytes",
mem_stats["allocated_bytes.all.peak"],
iteration,
)
writer.add_scalar( writer.add_scalar(
"mem-allocated-count", "mem-allocated-count",
mem_stats["allocation.all.current"], mem_stats["allocation.all.current"],
...@@ -1003,6 +1093,12 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r ...@@ -1003,6 +1093,12 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r
track_moe_metrics(moe_loss_scale, iteration, writer, wandb_writer, total_loss_dict, args.moe_per_layer_logging) track_moe_metrics(moe_loss_scale, iteration, writer, wandb_writer, total_loss_dict, args.moe_per_layer_logging)
if iteration % args.log_interval == 0: if iteration % args.log_interval == 0:
if args.record_memory_history and is_last_rank():
snapshot = torch.cuda.memory._snapshot()
from pickle import dump
with open(args.memory_snapshot_path, 'wb') as f:
dump(snapshot, f)
elapsed_time = timers('interval-time').elapsed(barrier=True) elapsed_time = timers('interval-time').elapsed(barrier=True)
elapsed_time_per_iteration = elapsed_time / total_iterations elapsed_time_per_iteration = elapsed_time / total_iterations
...@@ -1035,7 +1131,6 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r ...@@ -1035,7 +1131,6 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r
writer.add_scalar('throughput', throughput, iteration) writer.add_scalar('throughput', throughput, iteration)
if wandb_writer: if wandb_writer:
wandb_writer.log({'throughput': throughput}, iteration) wandb_writer.log({'throughput': throughput}, iteration)
assert learning_rate is not None
# Decoupled_learning_rate should be not None only on first and last pipeline stage. # Decoupled_learning_rate should be not None only on first and last pipeline stage.
log_string += f' learning rate: {learning_rate:.6E} |' log_string += f' learning rate: {learning_rate:.6E} |'
if args.decoupled_lr is not None and (mpu.is_pipeline_first_stage(ignore_virtual=True) or if args.decoupled_lr is not None and (mpu.is_pipeline_first_stage(ignore_virtual=True) or
...@@ -1068,7 +1163,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r ...@@ -1068,7 +1163,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_r
total_loss_dict[skipped_iters_key] = 0 total_loss_dict[skipped_iters_key] = 0
total_loss_dict[nan_iters_key] = 0 total_loss_dict[nan_iters_key] = 0
print_rank_last(log_string) print_rank_last(log_string)
if report_memory_flag and learning_rate > 0.: if report_memory_flag:
# Report memory after optimizer state has been initialized. # Report memory after optimizer state has been initialized.
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
num_microbatches = get_num_microbatches() num_microbatches = get_num_microbatches()
...@@ -1120,10 +1215,10 @@ def enable_forward_pre_hook(model_chunks): ...@@ -1120,10 +1215,10 @@ def enable_forward_pre_hook(model_chunks):
model_chunk.enable_forward_pre_hook() model_chunk.enable_forward_pre_hook()
def disable_forward_pre_hook(model_chunks): def disable_forward_pre_hook(model_chunks, param_sync=True):
for model_chunk in model_chunks: for model_chunk in model_chunks:
assert isinstance(model_chunk, DDP) assert isinstance(model_chunk, DDP)
model_chunk.disable_forward_pre_hook() model_chunk.disable_forward_pre_hook(param_sync=param_sync)
def save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler, def save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler,
...@@ -1137,26 +1232,23 @@ def save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler, ...@@ -1137,26 +1232,23 @@ def save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler,
# Extra barrier is added to make sure all ranks report the max time. # Extra barrier is added to make sure all ranks report the max time.
timer_key = 'save-checkpoint-non-persistent' if non_persistent_ckpt else 'save-checkpoint' timer_key = 'save-checkpoint-non-persistent' if non_persistent_ckpt else 'save-checkpoint'
timers(timer_key, log_level=0).start(barrier=True) timers(timer_key, log_level=0).start(barrier=True)
save_checkpoint_start_time = timers('save-checkpoint').active_time()
# Log E2E metrics before save-checkpoint # Log E2E metrics before save-checkpoint
one_logger_utils.track_e2e_metrics() one_logger_utils.track_e2e_metrics()
if args.use_distributed_optimizer and args.overlap_param_gather: if should_disable_forward_pre_hook(args):
disable_forward_pre_hook(model) disable_forward_pre_hook(model)
save_checkpoint(iteration, model, optimizer, opt_param_scheduler, save_checkpoint(iteration, model, optimizer, opt_param_scheduler,
num_floating_point_operations_so_far, checkpointing_context, num_floating_point_operations_so_far, checkpointing_context,
non_persistent_ckpt=non_persistent_ckpt, train_data_iterator=train_data_iterator, non_persistent_ckpt=non_persistent_ckpt, train_data_iterator=train_data_iterator,
ft_client=ft_integration.get_rank_monitor_client( preprocess_common_state_dict_fn=preprocess_common_state_dict)
ft_integration.StateMachineActions.SAVE_CHECKPOINT), preprocess_common_state_dict_fn=preprocess_common_state_dict) if should_disable_forward_pre_hook(args):
if args.use_distributed_optimizer and args.overlap_param_gather:
enable_forward_pre_hook(model) enable_forward_pre_hook(model)
timers(timer_key).stop(barrier=True) timers(timer_key).stop(barrier=True)
timers.log([timer_key]) timers.log([timer_key])
save_checkpoint_finish_time = timers('save-checkpoint').active_time()
# Log E2E metrics after save-checkpoint # Log E2E metrics after save-checkpoint
one_logger_utils.track_e2e_metrics() one_logger_utils.track_e2e_metrics()
save_checkpoint_duration = save_checkpoint_finish_time - save_checkpoint_start_time save_checkpoint_duration = timers(timer_key).elapsed()
one_logger_utils.on_save_checkpoint_end(save_checkpoint_duration, iteration, args.async_save) one_logger_utils.on_save_checkpoint_end(save_checkpoint_duration, iteration, args.async_save)
if args.log_progress and not non_persistent_ckpt: if args.log_progress and not non_persistent_ckpt:
...@@ -1172,21 +1264,6 @@ def post_training_step_callbacks(model, optimizer, opt_param_scheduler, iteratio ...@@ -1172,21 +1264,6 @@ def post_training_step_callbacks(model, optimizer, opt_param_scheduler, iteratio
"""Run all post-training-step functions (e.g., FT heartbeats, GC).""" """Run all post-training-step functions (e.g., FT heartbeats, GC)."""
args = get_args() args = get_args()
# Send heartbeat to FT package and update timeouts.
if args.enable_ft_package:
ft_client = ft_integration.get_rank_monitor_client(
ft_integration.StateMachineActions.TRAIN_HEARTBEAT)
if ft_client is not None:
ft_client.send_heartbeat()
# TODO: We are always calculating timeouts in the current implementation.
# If we want to rely on manually setting these, then we need to add additional
# arguments to training and pass it here.
if ft_integration.can_update_timeouts():
ft_integration.get_rank_monitor_client(
ft_integration.StateMachineActions.UPDATE_TIMEOUT).calculate_and_set_timeouts()
print_rank_0(f'Updated FT timeouts. New values: \
{ft_integration.get_rank_monitor_client().timeouts}')
# Bring CPU and GPU back in sync if on right iteration. # Bring CPU and GPU back in sync if on right iteration.
if args.train_sync_interval and iteration % args.train_sync_interval == 0: if args.train_sync_interval and iteration % args.train_sync_interval == 0:
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -1199,13 +1276,13 @@ def post_training_step_callbacks(model, optimizer, opt_param_scheduler, iteratio ...@@ -1199,13 +1276,13 @@ def post_training_step_callbacks(model, optimizer, opt_param_scheduler, iteratio
# Check weight hash across DP replicas. # Check weight hash across DP replicas.
if args.check_weight_hash_across_dp_replicas_interval is not None and \ if args.check_weight_hash_across_dp_replicas_interval is not None and \
iteration % args.check_weight_hash_across_dp_replicas_interval == 0: iteration % args.check_weight_hash_across_dp_replicas_interval == 0:
if args.use_distributed_optimizer and args.overlap_param_gather: if should_disable_forward_pre_hook(args):
disable_forward_pre_hook(model) disable_forward_pre_hook(model)
assert check_param_hashes_across_dp_replicas(model, cross_check=True), \ assert check_param_hashes_across_dp_replicas(model, cross_check=True), \
"Parameter hashes not matching across DP replicas" "Parameter hashes not matching across DP replicas"
torch.distributed.barrier() torch.distributed.barrier()
print_rank_0(f">>> Weight hashes match after {iteration} iterations...") print_rank_0(f">>> Weight hashes match after {iteration} iterations...")
if args.use_distributed_optimizer and args.overlap_param_gather: if should_disable_forward_pre_hook(args):
enable_forward_pre_hook(model) enable_forward_pre_hook(model)
# Autoresume. # Autoresume.
...@@ -1223,7 +1300,6 @@ def post_training_step_callbacks(model, optimizer, opt_param_scheduler, iteratio ...@@ -1223,7 +1300,6 @@ def post_training_step_callbacks(model, optimizer, opt_param_scheduler, iteratio
prof.stop() prof.stop()
else: else:
torch.cuda.cudart().cudaProfilerStop() torch.cuda.cudart().cudaProfilerStop()
# Manual garbage collection. # Manual garbage collection.
if args.manual_gc: if args.manual_gc:
...@@ -1265,14 +1341,12 @@ def checkpoint_and_decide_exit(model, optimizer, opt_param_scheduler, iteration, ...@@ -1265,14 +1341,12 @@ def checkpoint_and_decide_exit(model, optimizer, opt_param_scheduler, iteration,
elif args.save and args.non_persistent_save_interval and \ elif args.save and args.non_persistent_save_interval and \
iteration % args.non_persistent_save_interval == 0: iteration % args.non_persistent_save_interval == 0:
timers('interval-time').stop()
save_checkpoint_and_time(iteration, model, optimizer, save_checkpoint_and_time(iteration, model, optimizer,
opt_param_scheduler, opt_param_scheduler,
num_floating_point_operations_so_far, num_floating_point_operations_so_far,
checkpointing_context, checkpointing_context,
non_persistent_ckpt=True, train_data_iterator=train_data_iterator) non_persistent_ckpt=True, train_data_iterator=train_data_iterator)
saved_checkpoint = True saved_checkpoint = True
timers('interval-time', log_level=0).start(barrier=True)
# Exit based on duration. # Exit based on duration.
if args.exit_duration_in_mins: if args.exit_duration_in_mins:
...@@ -1328,6 +1402,11 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler, ...@@ -1328,6 +1402,11 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
# Iterations. # Iterations.
iteration = args.iteration iteration = args.iteration
# Make sure rerun_state_machine has the right iteration loaded from checkpoint.
rerun_state_machine = get_rerun_state_machine()
if rerun_state_machine.current_iteration != iteration:
print_rank_0(f"Setting rerun_state_machine.current_iteration to {iteration}...")
rerun_state_machine.current_iteration = iteration
# Track E2E metrics at the start of training. # Track E2E metrics at the start of training.
one_logger_utils.on_train_start(iteration=iteration, consumed_train_samples=args.consumed_train_samples, one_logger_utils.on_train_start(iteration=iteration, consumed_train_samples=args.consumed_train_samples,
...@@ -1341,7 +1420,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler, ...@@ -1341,7 +1420,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
# Setup some training config params. # Setup some training config params.
config.grad_scale_func = optimizer.scale_loss config.grad_scale_func = optimizer.scale_loss
config.timers = timers config.timers = timers
if isinstance(model[0], DDP) and args.overlap_grad_reduce: if isinstance(model[0], (custom_FSDP, DDP)) and args.overlap_grad_reduce:
assert config.no_sync_func is None, \ assert config.no_sync_func is None, \
('When overlap_grad_reduce is True, config.no_sync_func must be None; ' ('When overlap_grad_reduce is True, config.no_sync_func must be None; '
'a custom no_sync_func is not supported when overlapping grad-reduce') 'a custom no_sync_func is not supported when overlapping grad-reduce')
...@@ -1361,6 +1440,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler, ...@@ -1361,6 +1440,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
timers('interval-time', log_level=0).start(barrier=True) timers('interval-time', log_level=0).start(barrier=True)
print_datetime('before the start of training step') print_datetime('before the start of training step')
report_memory_flag = True report_memory_flag = True
pre_hook_enabled = False
should_exit = False should_exit = False
exit_code = 0 exit_code = 0
...@@ -1391,12 +1471,14 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler, ...@@ -1391,12 +1471,14 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
def get_e2e_base_metrics(): def get_e2e_base_metrics():
"""Get base metrics values for one-logger to calculate E2E tracking metrics. """Get base metrics values for one-logger to calculate E2E tracking metrics.
""" """
num_floating_point_operations_since_current_train_start = \
num_floating_point_operations_so_far - args.num_floating_point_operations_so_far
return { return {
'iteration': iteration, 'iteration': iteration,
'train_duration': timers('interval-time').active_time(), 'train_duration': timers('interval-time').active_time(),
'eval_duration': eval_duration, 'eval_duration': eval_duration,
'eval_iterations': eval_iterations, 'eval_iterations': eval_iterations,
'total_flops': num_floating_point_operations_since_last_log_event, 'total_flops_since_current_train_start': num_floating_point_operations_since_current_train_start,
'num_floating_point_operations_so_far': num_floating_point_operations_so_far, 'num_floating_point_operations_so_far': num_floating_point_operations_so_far,
'consumed_train_samples': args.consumed_train_samples, 'consumed_train_samples': args.consumed_train_samples,
'world_size': args.world_size, 'world_size': args.world_size,
...@@ -1409,44 +1491,47 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler, ...@@ -1409,44 +1491,47 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
prof = None prof = None
if args.profile and torch.distributed.get_rank() in args.profile_ranks and args.use_pytorch_profiler: if args.profile and torch.distributed.get_rank() in args.profile_ranks and args.use_pytorch_profiler:
def trace_handler(p):
from pathlib import Path
Path(f"{args.profile_dir}").mkdir(parents=True, exist_ok=True)
if args.rank in [0]:
print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
p.export_chrome_trace("{path}/trace_rank{rank}_step{step}.json".format(
path=args.profile_dir, rank=torch.distributed.get_rank(), step=p.step_num))
prof = torch.profiler.profile( prof = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
schedule=torch.profiler.schedule( schedule=torch.profiler.schedule(
wait=max(args.profile_step_start-1, 0), wait=max(args.profile_step_start-1, 0),
warmup=1 if args.profile_step_start > 0 else 0, warmup=1 if args.profile_step_start > 0 else 0,
active=args.profile_step_end-args.profile_step_start, active=args.profile_step_end-args.profile_step_start,
repeat=1), repeat=1),
on_trace_ready=trace_handler) on_trace_ready=torch.profiler.tensorboard_trace_handler(args.tensorboard_dir),
record_shapes=True,
with_stack=True)
prof.start() prof.start()
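For reference, the schedule above maps profile_step_start/profile_step_end onto the profiler's wait/warmup/active phases. A small standalone sketch, assuming illustrative start/end values and CPU-only activity so it runs anywhere:

# Hedged sketch: how profile_step_start/profile_step_end translate into a
# torch.profiler schedule (the step values below are illustrative).
import torch

profile_step_start, profile_step_end = 3, 5
prof = torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU],
    schedule=torch.profiler.schedule(
        wait=max(profile_step_start - 1, 0),            # idle steps before warmup
        warmup=1 if profile_step_start > 0 else 0,
        active=profile_step_end - profile_step_start,   # steps actually recorded
        repeat=1),
    on_trace_ready=lambda p: print(p.key_averages().table(row_limit=5)))
prof.start()
for step in range(profile_step_end + 1):
    torch.randn(64, 64) @ torch.randn(64, 64)           # stand-in for a training step
    prof.step()
prof.stop()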
elif args.profile and torch.distributed.get_rank() in args.profile_ranks and args.use_hip_profiler:
import ctypes start_iteration = iteration
roctracer = ctypes.cdll.LoadLibrary("/opt/dtk/roctracer/lib/libroctracer64.so") # Disable forward pre-hook to start training to ensure that errors in checkpoint loading
# or random initialization don't propagate to all ranks in first all-gather (which is a
# no-op if things work correctly).
if should_disable_forward_pre_hook(args):
disable_forward_pre_hook(model, param_sync=False)
# Also remove param_sync_func temporarily so that sync calls made in
# `forward_backward_func` are no-ops.
param_sync_func = config.param_sync_func
config.param_sync_func = None
pre_hook_enabled = False
# Also, check weight hash across DP replicas to be very pedantic.
if args.check_weight_hash_across_dp_replicas_interval is not None:
assert check_param_hashes_across_dp_replicas(model, cross_check=True), \
"Parameter hashes not matching across DP replicas"
torch.distributed.barrier()
print_rank_0(f">>> Weight hashes match after {iteration} iterations...")
# Run training iterations till done. # Run training iterations till done.
while iteration < args.train_iters: while iteration < args.train_iters:
if args.profile and torch.distributed.get_rank() in args.profile_ranks: if args.profile and torch.distributed.get_rank() in args.profile_ranks:
if args.use_pytorch_profiler: if args.use_pytorch_profiler:
prof.step() prof.step()
elif args.use_hip_profiler:
if iteration == args.profile_step_start: roctracer.roctracer_start()
if iteration == args.profile_step_end: roctracer.roctracer_stop()
elif iteration == args.profile_step_start: elif iteration == args.profile_step_start:
torch.cuda.cudart().cudaProfilerStart() torch.cuda.cudart().cudaProfilerStart()
torch.autograd.profiler.emit_nvtx(record_shapes=True).__enter__() torch.autograd.profiler.emit_nvtx(record_shapes=True).__enter__()
ft_integration.on_checkpointing_start()
maybe_finalize_async_save(blocking=False) maybe_finalize_async_save(blocking=False)
ft_integration.on_checkpointing_end(is_async_finalization=True)
# Update number of microbatches first without consistency check to decide if a # Update number of microbatches first without consistency check to decide if a
# checkpoint should be saved. If the number of microbatches is different # checkpoint should be saved. If the number of microbatches is different
...@@ -1456,36 +1541,68 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler, ...@@ -1456,36 +1541,68 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
if get_num_microbatches() != num_microbatches and iteration != 0: if get_num_microbatches() != num_microbatches and iteration != 0:
assert get_num_microbatches() > num_microbatches, \ assert get_num_microbatches() > num_microbatches, \
(f"Number of microbatches should be increasing due to batch size rampup; " (f"Number of microbatches should be increasing due to batch size rampup; "
f"instead going from {num_microbatches} to {get_num_microbatches()}") f"instead going from {num_microbatches} to {get_num_microbatches()}")
if args.save is not None: if args.save is not None:
save_checkpoint_and_time(iteration, model, optimizer, save_checkpoint_and_time(iteration, model, optimizer,
opt_param_scheduler, opt_param_scheduler,
num_floating_point_operations_so_far, num_floating_point_operations_so_far,
checkpointing_context, train_data_iterator=train_data_iterator) checkpointing_context, train_data_iterator=train_data_iterator)
num_microbatches = get_num_microbatches() num_microbatches = get_num_microbatches()
update_num_microbatches(args.consumed_train_samples, consistency_check=True, verbose=True) update_num_microbatches(args.consumed_train_samples, consistency_check=True, verbose=True)
# Completely skip iteration if needed.
if iteration in args.iterations_to_skip:
# Dummy train_step to fast forward train_data_iterator.
dummy_train_step(train_data_iterator)
iteration += 1
batch_size = mpu.get_data_parallel_world_size() * \
args.micro_batch_size * \
get_num_microbatches()
args.consumed_train_samples += batch_size
args.skipped_train_samples += batch_size
continue
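For a skipped iteration the same global batch is counted as both consumed and skipped, so sample-based schedules keep advancing while the skipped work remains visible in the logs. A small arithmetic sketch with illustrative sizes:

# Hedged sketch: sample bookkeeping for a skipped iteration
# (dp_world_size, micro_batch_size, num_microbatches are illustrative).
dp_world_size, micro_batch_size, num_microbatches = 8, 2, 4
batch_size = dp_world_size * micro_batch_size * num_microbatches  # 64

consumed_train_samples, skipped_train_samples = 0, 0
consumed_train_samples += batch_size   # LR/batch-size schedules still advance
skipped_train_samples += batch_size    # but the samples are marked as skipped
print(consumed_train_samples, skipped_train_samples)  # 64 64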
# Run training step. # Run training step.
args.curr_iteration = iteration args.curr_iteration = iteration
ft_integration.on_training_step_start()
loss_dict, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad = \ loss_dict, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad = \
train_step(forward_step_func, train_step(forward_step_func,
train_data_iterator, train_data_iterator,
model, model,
optimizer, optimizer,
opt_param_scheduler, opt_param_scheduler,
config) config)
ft_integration.on_training_step_end()
if should_checkpoint: if should_checkpoint:
save_checkpoint_and_time(iteration, model, optimizer, save_checkpoint_and_time(iteration, model, optimizer,
opt_param_scheduler, opt_param_scheduler,
num_floating_point_operations_so_far, num_floating_point_operations_so_far,
checkpointing_context, train_data_iterator=train_data_iterator) checkpointing_context, train_data_iterator=train_data_iterator)
if should_exit: if should_exit:
break break
# Note: skipped_iter is checked below to decide when the forward pre-hook can be re-enabled.
# Enable forward pre-hooks after first set of forward and backward passes.
# When running in fp16, skip all NaN iterations until steady-state loss scaling value
# is reached.
if iteration == start_iteration:
if skipped_iter:
# Only enable forward pre-hook after a training step has successfully run. Relevant
# for fp16 codepath where first XX iterations are skipped until steady-state loss
# scale value is reached.
start_iteration = iteration + 1
else:
# Enable forward pre-hook after training step has successfully run. All subsequent
# forward passes will use the forward pre-hook / `param_sync_func` in
# `forward_backward_func`.
if should_disable_forward_pre_hook(args):
enable_forward_pre_hook(model)
config.param_sync_func = param_sync_func
pre_hook_enabled = True
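In other words, the first step(s) run with the overlapped all-gather pre-hook disabled, and the hook (plus param_sync_func) is only restored once a step has completed without being skipped. A condensed sketch of that state machine, with the real hook toggles reduced to a boolean flag:

# Hedged sketch of the "enable pre-hook after first successful step" logic;
# hook_enabled stands in for the real DDP forward pre-hook toggle.
def run_startup(steps_skipped_flags):
    start_iteration, iteration, hook_enabled = 0, 0, False
    for skipped in steps_skipped_flags:           # one flag per training step
        if iteration == start_iteration:
            if skipped:
                start_iteration = iteration + 1   # keep waiting (fp16 loss-scale search)
            else:
                hook_enabled = True               # first good step: restore overlap
        iteration += 1
    return hook_enabled, start_iteration

print(run_startup([True, True, False, False]))    # (True, 2)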
iteration += 1 iteration += 1
batch_size = mpu.get_data_parallel_world_size() * \ batch_size = mpu.get_data_parallel_world_size() * \
args.micro_batch_size * \ args.micro_batch_size * \
get_num_microbatches() get_num_microbatches()
args.consumed_train_samples += batch_size args.consumed_train_samples += batch_size
num_skipped_samples_in_batch = (get_current_global_batch_size() - num_skipped_samples_in_batch = (get_current_global_batch_size() -
get_current_running_global_batch_size()) get_current_running_global_batch_size())
...@@ -1499,8 +1616,12 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler, ...@@ -1499,8 +1616,12 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
num_floating_point_operations_since_last_log_event += num_floating_point_operations_in_batch num_floating_point_operations_since_last_log_event += num_floating_point_operations_in_batch
# Logging. # Logging.
loss_scale = optimizer.get_loss_scale().item() if not optimizer.is_stub_optimizer:
loss_scale = optimizer.get_loss_scale().item()
else:
loss_scale = 1.0
params_norm = None params_norm = None
if args.log_params_norm: if args.log_params_norm:
params_norm = calc_params_l2_norm(model) params_norm = calc_params_l2_norm(model)
learning_rate = None learning_rate = None
...@@ -1511,28 +1632,29 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler, ...@@ -1511,28 +1632,29 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
else: else:
learning_rate = param_group['lr'] learning_rate = param_group['lr']
report_memory_flag = training_log(loss_dict, total_loss_dict, report_memory_flag = training_log(loss_dict, total_loss_dict,
learning_rate, learning_rate,
decoupled_learning_rate, decoupled_learning_rate,
iteration, loss_scale, iteration, loss_scale,
report_memory_flag, skipped_iter, report_memory_flag, skipped_iter,
grad_norm, params_norm, num_zeros_in_grad) grad_norm, params_norm, num_zeros_in_grad)
# Evaluation. # Evaluation.
if args.eval_interval and iteration % args.eval_interval == 0 and \ if args.eval_interval and iteration % args.eval_interval == 0 and \
args.do_valid: args.do_valid:
timers('interval-time').stop() timers('interval-time').stop()
if args.use_distributed_optimizer and args.overlap_param_gather: if should_disable_forward_pre_hook(args):
disable_forward_pre_hook(model) disable_forward_pre_hook(model)
pre_hook_enabled = False
if args.manual_gc and args.manual_gc_eval: if args.manual_gc and args.manual_gc_eval:
# Collect all objects. # Collect all objects.
gc.collect() gc.collect()
prefix = f'iteration {iteration}' prefix = f'iteration {iteration}'
timers('eval-time', log_level=0).start(barrier=True) timers('eval-time', log_level=0).start(barrier=True)
evaluate_and_print_results(prefix, forward_step_func, evaluate_and_print_results(prefix, forward_step_func,
valid_data_iterator, model, valid_data_iterator, model,
iteration, process_non_loss_data_func, iteration, process_non_loss_data_func,
config, verbose=False, write_to_tensorboard=True, config, verbose=False, write_to_tensorboard=True,
non_loss_data_func=non_loss_data_func) non_loss_data_func=non_loss_data_func)
eval_duration += timers('eval-time').elapsed() eval_duration += timers('eval-time').elapsed()
eval_iterations += args.eval_iters eval_iterations += args.eval_iters
timers('eval-time').stop() timers('eval-time').stop()
...@@ -1541,23 +1663,20 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler, ...@@ -1541,23 +1663,20 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
if args.manual_gc and args.manual_gc_eval: if args.manual_gc and args.manual_gc_eval:
# Collect only the objects created and used in evaluation. # Collect only the objects created and used in evaluation.
gc.collect(generation=0) gc.collect(generation=0)
if args.use_distributed_optimizer and args.overlap_param_gather: if should_disable_forward_pre_hook(args):
enable_forward_pre_hook(model) enable_forward_pre_hook(model)
pre_hook_enabled = True
timers('interval-time', log_level=0).start(barrier=True) timers('interval-time', log_level=0).start(barrier=True)
if args.enable_ft_package and ft_integration.get_rank_monitor_client() is not None:
ft_integration.get_rank_monitor_client(
ft_integration.StateMachineActions.EVAL_HEARTBEAT).send_heartbeat()
# Miscellaneous post-training-step functions (e.g., FT heartbeats, GC). # Miscellaneous post-training-step functions (e.g., FT heartbeats, GC).
# Some of these only happen at specific iterations. # Some of these only happen at specific iterations.
post_training_step_callbacks(model, optimizer, opt_param_scheduler, iteration, prof, post_training_step_callbacks(model, optimizer, opt_param_scheduler, iteration, prof,
num_floating_point_operations_since_last_log_event) num_floating_point_operations_since_last_log_event)
# Checkpoint and decide whether to exit. # Checkpoint and decide whether to exit.
should_exit = checkpoint_and_decide_exit(model, optimizer, opt_param_scheduler, iteration, should_exit = checkpoint_and_decide_exit(model, optimizer, opt_param_scheduler, iteration,
num_floating_point_operations_so_far, num_floating_point_operations_so_far,
checkpointing_context, train_data_iterator) checkpointing_context, train_data_iterator)
if should_exit: if should_exit:
break break
...@@ -1569,19 +1688,23 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler, ...@@ -1569,19 +1688,23 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
writer.flush() writer.flush()
# Close out pre-hooks if using distributed optimizer and overlapped param gather. # Close out pre-hooks if using distributed optimizer and overlapped param gather.
if args.use_distributed_optimizer and args.overlap_param_gather: if pre_hook_enabled:
disable_forward_pre_hook(model) disable_forward_pre_hook(model)
ft_integration.on_checkpointing_start()
# This will finalize all unfinalized async request and terminate
# a persistent async worker if persistent ckpt worker is enabled
maybe_finalize_async_save(blocking=True, terminate=True)
ft_integration.on_checkpointing_end(is_async_finalization=True)
if args.enable_ft_package and ft_integration.get_rank_monitor_client() is not None: if args.enable_ft_package and ft_integration.get_rank_monitor_client() is not None:
ft_integration.get_rank_monitor_client().shutdown_workload_monitoring() ft_integration.get_rank_monitor_client().shutdown_workload_monitoring()
maybe_finalize_async_save(blocking=True)
# If any exit conditions (signal handler, duration, iterations) have been reached, exit. # If any exit conditions (signal handler, duration, iterations) have been reached, exit.
if should_exit: if should_exit:
wandb_writer = get_wandb_writer() wandb_writer = get_wandb_writer()
if wandb_writer: if wandb_writer:
wandb_writer.finish() wandb_writer.finish()
ft_integration.shutdown()
sys.exit(exit_code) sys.exit(exit_code)
return iteration, num_floating_point_operations_so_far return iteration, num_floating_point_operations_so_far
...@@ -1632,6 +1755,7 @@ def evaluate(forward_step_func, ...@@ -1632,6 +1755,7 @@ def evaluate(forward_step_func,
forward_backward_func = get_forward_backward_func() forward_backward_func = get_forward_backward_func()
# Don't care about timing during evaluation # Don't care about timing during evaluation
config.timers = None config.timers = None
ft_integration.on_eval_step_start()
loss_dicts = forward_backward_func( loss_dicts = forward_backward_func(
forward_step_func=forward_step_func, forward_step_func=forward_step_func,
data_iterator=data_iterator, data_iterator=data_iterator,
...@@ -1641,6 +1765,7 @@ def evaluate(forward_step_func, ...@@ -1641,6 +1765,7 @@ def evaluate(forward_step_func,
micro_batch_size=args.micro_batch_size, micro_batch_size=args.micro_batch_size,
decoder_seq_length=args.decoder_seq_length, decoder_seq_length=args.decoder_seq_length,
forward_only=True) forward_only=True)
ft_integration.on_eval_step_end()
config.timers = get_timers() config.timers = get_timers()
# Empty unused memory # Empty unused memory
...@@ -1701,7 +1826,9 @@ def evaluate(forward_step_func, ...@@ -1701,7 +1826,9 @@ def evaluate(forward_step_func,
timers('evaluate').stop() timers('evaluate').stop()
timers.log(['evaluate']) timers.log(['evaluate'])
rerun_state_machine.set_mode(rerun_mode)
rerun_state_machine.set_mode(rerun_mode) rerun_state_machine.set_mode(rerun_mode)
return total_loss_dict, collected_non_loss_data, False return total_loss_dict, collected_non_loss_data, False
...@@ -1869,12 +1996,15 @@ def build_train_valid_test_data_iterators( ...@@ -1869,12 +1996,15 @@ def build_train_valid_test_data_iterators(
def _get_iterator(dataloader_type, dataloader): def _get_iterator(dataloader_type, dataloader):
"""Return dataset iterator.""" """Return dataset iterator."""
if dataloader_type == "single": if dataloader_type == "single":
return RerunDataIterator(dataloader) return RerunDataIterator(iter(dataloader))
elif dataloader_type == "cyclic": elif dataloader_type == "cyclic":
return RerunDataIterator(cyclic_iter(dataloader)) return RerunDataIterator(iter(cyclic_iter(dataloader)))
elif dataloader_type == "external": elif dataloader_type == "external":
# External dataloader is passed through. User is expected to define how to iterate. # External dataloader is passed through. User is expected to define how to iterate.
return RerunDataIterator(dataloader, make_iterable=False) if isinstance(dataloader, list):
return [RerunDataIterator(d) for d in dataloader]
else:
return RerunDataIterator(dataloader)
else: else:
raise RuntimeError("unexpected dataloader type") raise RuntimeError("unexpected dataloader type")
...@@ -1894,3 +2024,8 @@ def build_train_valid_test_data_iterators( ...@@ -1894,3 +2024,8 @@ def build_train_valid_test_data_iterators(
test_data_iterator = None test_data_iterator = None
return train_data_iterator, valid_data_iterator, test_data_iterator return train_data_iterator, valid_data_iterator, test_data_iterator
def should_disable_forward_pre_hook(args):
"""Block forward pre-hook for certain configurations."""
return not args.use_custom_fsdp and args.use_distributed_optimizer and args.overlap_param_gather
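A quick usage sketch for the helper above, with a stand-in args namespace (attribute names follow the function; the values are illustrative):

# Hedged sketch: which configurations toggle the forward pre-hook around
# checkpointing/eval, per should_disable_forward_pre_hook above (re-declared
# here so the example is self-contained).
from types import SimpleNamespace

def should_disable_forward_pre_hook(args):
    return (not args.use_custom_fsdp
            and args.use_distributed_optimizer
            and args.overlap_param_gather)

for cfg in [
    dict(use_custom_fsdp=False, use_distributed_optimizer=True,  overlap_param_gather=True),
    dict(use_custom_fsdp=True,  use_distributed_optimizer=True,  overlap_param_gather=True),
    dict(use_custom_fsdp=False, use_distributed_optimizer=True,  overlap_param_gather=False),
]:
    print(cfg, should_disable_forward_pre_hook(SimpleNamespace(**cfg)))
# True only for the first configuration.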
...@@ -33,18 +33,23 @@ from megatron.training import ( ...@@ -33,18 +33,23 @@ from megatron.training import (
get_adlr_autoresume, get_adlr_autoresume,
) )
from megatron.core import DistributedDataParallel as DDP from megatron.core import DistributedDataParallel as DDP
from megatron.core.distributed.custom_fsdp import FullyShardedDataParallel as custom_FSDP
from megatron.core import mpu from megatron.core import mpu
from megatron.core.datasets.utils import get_blend_from_list from megatron.core.datasets.utils import get_blend_from_list
from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate
from megatron.core.utils import get_data_parallel_group_if_dtensor, to_local_if_dtensor from megatron.core.utils import (
get_batch_on_this_cp_rank,
get_data_parallel_group_if_dtensor,
to_local_if_dtensor,
)
from megatron.legacy.model import Float16Module from megatron.legacy.model import Float16Module
from megatron.legacy.model.module import param_is_not_shared from megatron.legacy.model.module import param_is_not_shared
try: try:
from megatron.core.distributed import TorchFullyShardedDataParallel as torch_FSDP from megatron.core.distributed import TorchFullyShardedDataParallel as torch_FSDP
ALL_MODULE_WRAPPER_CLASSNAMES = (DDP, torch_FSDP, Float16Module) ALL_MODULE_WRAPPER_CLASSNAMES = (DDP, torch_FSDP, custom_FSDP, Float16Module)
except ImportError: except ImportError:
ALL_MODULE_WRAPPER_CLASSNAMES = (DDP, Float16Module) ALL_MODULE_WRAPPER_CLASSNAMES = (DDP, custom_FSDP, Float16Module)
def unwrap_model(model, module_instances=ALL_MODULE_WRAPPER_CLASSNAMES): def unwrap_model(model, module_instances=ALL_MODULE_WRAPPER_CLASSNAMES):
...@@ -62,7 +67,7 @@ def unwrap_model(model, module_instances=ALL_MODULE_WRAPPER_CLASSNAMES): ...@@ -62,7 +67,7 @@ def unwrap_model(model, module_instances=ALL_MODULE_WRAPPER_CLASSNAMES):
return unwrapped_model return unwrapped_model
def calc_params_l2_norm(model): def calc_params_l2_norm(model, force_create_fp32_copy=False):
"""Calculate l2 norm of parameters """ """Calculate l2 norm of parameters """
args = get_args() args = get_args()
if not isinstance(model, list): if not isinstance(model, list):
...@@ -70,54 +75,110 @@ def calc_params_l2_norm(model): ...@@ -70,54 +75,110 @@ def calc_params_l2_norm(model):
# Separate MoE and dense params # Separate MoE and dense params
params_data = [] params_data = []
moe_params_data = [] moe_params_data = []
sharded_params_data = []
data_parallel_group = None data_parallel_group = None
custom_fsdp_all_param_is_shared = False
for model_chunk in model: for model_chunk in model:
for i, param in enumerate(model_chunk.parameters()): for param in model_chunk.parameters():
data_parallel_group = get_data_parallel_group_if_dtensor(param, data_parallel_group) data_parallel_group = get_data_parallel_group_if_dtensor(param, data_parallel_group)
is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param) is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
if not (param.requires_grad and is_not_tp_duplicate): if not is_not_tp_duplicate:
continue continue
assert is_not_tp_duplicate assert is_not_tp_duplicate
if hasattr(param, "fully_shard_param_local_shard"):
param = param.fully_shard_param_local_shard
assert all(getattr(p, "fully_shard_param_local_shard", None) is not None for p in model_chunk.parameters())
custom_fsdp_all_param_is_shared = True
if param.numel() == 0:
continue
if not getattr(param, 'allreduce', True): if not getattr(param, 'allreduce', True):
# TODO: Implement memory optimization for MoE parameters.
assert param_is_not_shared(param) assert param_is_not_shared(param)
param = to_local_if_dtensor(param) param = to_local_if_dtensor(param)
moe_params_data.append(param.data.float() if args.bf16 else param.data) moe_params_data.append(param.data.float() if args.bf16 else param.data)
else: else:
if param_is_not_shared(param): if param_is_not_shared(param):
param = to_local_if_dtensor(param) param = to_local_if_dtensor(param)
params_data.append(param.data.float() if args.bf16 else param.data) if args.bf16:
if not force_create_fp32_copy and hasattr(param, 'main_param'):
# Calculate dense param norm if getattr(param, 'main_param_sharded', False):
if param.main_param is not None:
sharded_params_data.append(param.main_param)
else:
params_data.append(param.main_param)
else:
# Fallback to original logic of making a fp32 copy of the
# parameter if `.main_param` attribute is not available.
params_data.append(param.data.float())
else:
params_data.append(param.data)
# Calculate norm.
dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device='cuda') dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device='cuda')
norm, _ = multi_tensor_applier( if len(params_data) > 0:
multi_tensor_l2norm, norm, _ = multi_tensor_applier(
dummy_overflow_buf, multi_tensor_l2norm,
[params_data], dummy_overflow_buf,
False # no per-parameter norm [params_data],
) False # no per-parameter norm.
norm_2 = norm * norm )
norm_2 = norm * norm
else:
norm_2 = torch.zeros((1,), dtype=torch.float32, device='cuda')
if data_parallel_group is not None: if data_parallel_group is not None:
torch.distributed.all_reduce(norm_2, torch.distributed.all_reduce(norm_2,
op=torch.distributed.ReduceOp.SUM, op=torch.distributed.ReduceOp.SUM,
group=data_parallel_group) group=data_parallel_group)
# Sum across all model-parallel GPUs(tensor + pipeline). # Add norm contribution from params with sharded main_params. These norms need to be
# accumulated across the DP group since the main parameters are sharded because
# of distributed optimizer.
if len(sharded_params_data) > 0:
dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device='cuda')
sharded_norm, _ = multi_tensor_applier(
multi_tensor_l2norm,
dummy_overflow_buf,
[sharded_params_data],
False # no per-parameter norm.
)
sharded_norm_2 = sharded_norm * sharded_norm
# Sum over all DP groups.
torch.distributed.all_reduce(
sharded_norm_2,
op=torch.distributed.ReduceOp.SUM,
group=mpu.get_data_parallel_group()
)
norm_2 += sharded_norm_2
if custom_fsdp_all_param_is_shared:
torch.distributed.all_reduce(norm_2,
op=torch.distributed.ReduceOp.SUM,
group=mpu.get_data_parallel_group())
# Sum across all model-parallel GPUs (tensor + pipeline).
torch.distributed.all_reduce( torch.distributed.all_reduce(
norm_2, norm_2,
op=torch.distributed.ReduceOp.SUM, op=torch.distributed.ReduceOp.SUM,
group=mpu.get_model_parallel_group() group=mpu.get_model_parallel_group()
) )
# Calculate moe norm
# Add norm contribution from expert layers in MoEs.
if len(moe_params_data) > 0: if len(moe_params_data) > 0:
moe_norm, _ = multi_tensor_applier( moe_norm, _ = multi_tensor_applier(
multi_tensor_l2norm, multi_tensor_l2norm,
dummy_overflow_buf, dummy_overflow_buf,
[moe_params_data], [moe_params_data],
False # no per-parameter norm False # no per-parameter norm.
) )
moe_norm_2 = moe_norm * moe_norm moe_norm_2 = moe_norm * moe_norm
if custom_fsdp_all_param_is_shared:
torch.distributed.all_reduce(moe_norm_2,
op=torch.distributed.ReduceOp.SUM,
group=mpu.get_expert_data_parallel_group())
# Sum across expert tensor, model and pipeline parallel GPUs. # Sum across expert tensor, model and pipeline parallel GPUs.
torch.distributed.all_reduce( torch.distributed.all_reduce(
moe_norm_2, moe_norm_2,
...@@ -125,6 +186,7 @@ def calc_params_l2_norm(model): ...@@ -125,6 +186,7 @@ def calc_params_l2_norm(model):
group=mpu.get_expert_tensor_model_pipeline_parallel_group() group=mpu.get_expert_tensor_model_pipeline_parallel_group()
) )
norm_2 += moe_norm_2 norm_2 += moe_norm_2
return norm_2.item() ** 0.5 return norm_2.item() ** 0.5
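The reduction pattern throughout this function is: compute a squared norm over the local shard of parameters, all-reduce the squares (SUM) across the relevant data/model parallel groups, and take a single square root at the very end. A single-process numerical sketch of that identity (illustrative tensors, no parallelism):

# Hedged sketch: L2 norm via sum of squared chunk norms, as in calc_params_l2_norm.
import torch

params = [torch.randn(1000), torch.randn(50, 50)]
norm_2 = sum(p.float().norm() ** 2 for p in params)        # local squared norm
# In the real code this scalar is all-reduced (SUM) across DP/MP groups here.
global_norm = norm_2.item() ** 0.5
reference = torch.cat([p.flatten() for p in params]).norm().item()
assert abs(global_norm - reference) < 1e-2                  # equal up to float error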
...@@ -140,6 +202,41 @@ def average_losses_across_data_parallel_group(losses): ...@@ -140,6 +202,41 @@ def average_losses_across_data_parallel_group(losses):
return averaged_losses return averaged_losses
def reduce_max_stat_across_model_parallel_group(stat: float) -> float:
"""
Ranks without an optimizer will have no grad_norm or num_zeros_in_grad stats.
We need to ensure the logging and writer rank has those values.
This function reduces a stat tensor across the model parallel group.
We use an all-reduce MAX since the values have already been summed across optimizer ranks where possible.
"""
if stat is None:
stat = -1.0
stat = torch.tensor([stat], dtype=torch.float32, device=torch.cuda.current_device())
torch.distributed.all_reduce(
stat, op=torch.distributed.ReduceOp.MAX, group=mpu.get_model_parallel_group()
)
if stat.item() == -1.0:
return None
else:
return stat.item()
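The helper encodes "no stat on this rank" as -1.0 so a single MAX all-reduce both propagates the real value and preserves None when no rank had one. A single-process simulation of that encoding (Python max stands in for the all-reduce):

# Hedged sketch: simulate reduce_max_stat_across_model_parallel_group without
# torch.distributed; max() stands in for the MAX all-reduce.
def reduce_max_stat_simulated(per_rank_stats):
    encoded = [(-1.0 if s is None else s) for s in per_rank_stats]
    reduced = max(encoded)
    return None if reduced == -1.0 else reduced

print(reduce_max_stat_simulated([None, 1.7, None]))  # 1.7
print(reduce_max_stat_simulated([None, None]))       # None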
def logical_and_across_model_parallel_group(input: bool) -> bool:
"""
This function computes a logical AND of a bool value across the model parallel group (a MIN all-reduce over 0/1 flags).
"""
if input is True:
input = 1
else:
input = 0
input = torch.tensor([input], dtype=torch.int, device=torch.cuda.current_device())
torch.distributed.all_reduce(
input, op=torch.distributed.ReduceOp.MIN, group=mpu.get_model_parallel_group()
)
return bool(input.item())
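Likewise, the boolean is encoded as 0/1 and reduced with MIN, which is exactly a logical AND across ranks. Simulated the same way:

# Hedged sketch: MIN over 0/1 flags == logical AND across ranks
# (min() stands in for the MIN all-reduce).
def logical_and_simulated(per_rank_flags):
    return bool(min(1 if f else 0 for f in per_rank_flags))

print(logical_and_simulated([True, True, True]))    # True
print(logical_and_simulated([True, False, True]))   # False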
def report_memory(name): def report_memory(name):
"""Simple GPU memory report.""" """Simple GPU memory report."""
mega_bytes = 1024.0 * 1024.0 mega_bytes = 1024.0 * 1024.0
...@@ -254,39 +351,6 @@ def get_ltor_masks_and_position_ids(data, ...@@ -254,39 +351,6 @@ def get_ltor_masks_and_position_ids(data,
return attention_mask, loss_mask, position_ids return attention_mask, loss_mask, position_ids
def get_batch_on_this_cp_rank(batch):
""" Slice batch input along sequence dimension into multiple chunks,
which are parallelized across GPUs in a context parallel group.
"""
# With causal masking, each token only attends to its prior tokens. Simply splitting the
# sequence into CP chunks can result in severe load imbalance; that is, chunks at the end
# of the sequence have a bigger workload than the others. To address this issue, we split
# the sequence into 2*CP chunks. Assuming CP=2, we get 4 chunks: chunk_0 and chunk_3 are
# assigned to GPU0, and chunk_1 and chunk_2 are assigned to GPU1, so that the workload is
# balanced among GPUs in a context parallel group.
args = get_args()
cp_size = args.context_parallel_size
if cp_size > 1:
cp_rank = mpu.get_context_parallel_rank()
for key, val in batch.items():
if val is not None:
seq_dim = 1 if key != 'attention_mask' else 2
val = val.view(
*val.shape[0:seq_dim],
2 * cp_size,
val.shape[seq_dim] // (2 * cp_size),
*val.shape[(seq_dim + 1) :],
)
index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)],
device="cpu", pin_memory=True).cuda(non_blocking=True)
val = val.index_select(seq_dim, index)
val = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2) :])
batch[key] = val
return batch
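This helper is removed here in favor of the core implementation, but the head/tail chunk pairing it describes is easy to see on a concrete tensor. A standalone sketch for cp_size=2 and a sequence of length 8 (illustrative values):

# Hedged sketch: head/tail chunk assignment used for context-parallel load
# balancing (cp_size and sequence length are illustrative).
import torch

cp_size, seq_len = 2, 8
tokens = torch.arange(seq_len).unsqueeze(0)                       # [1, 8]
chunks = tokens.view(1, 2 * cp_size, seq_len // (2 * cp_size))    # 4 chunks of 2 tokens
for cp_rank in range(cp_size):
    index = torch.tensor([cp_rank, 2 * cp_size - cp_rank - 1])
    print(cp_rank, chunks.index_select(1, index).reshape(1, -1))
# rank 0 keeps chunks 0 and 3 -> tokens [0,1,6,7]; rank 1 keeps chunks 1 and 2 -> [2,3,4,5]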
def print_rank_0(message): def print_rank_0(message):
"""If distributed is initialized, print only on rank 0.""" """If distributed is initialized, print only on rank 0."""
if torch.distributed.is_initialized(): if torch.distributed.is_initialized():
...@@ -295,6 +359,10 @@ def print_rank_0(message): ...@@ -295,6 +359,10 @@ def print_rank_0(message):
else: else:
print(message, flush=True) print(message, flush=True)
def is_rank0():
"""Returns true if called in the rank0, false otherwise"""
return torch.distributed.is_initialized() and torch.distributed.get_rank() == 0
def is_last_rank(): def is_last_rank():
return torch.distributed.get_rank() == ( return torch.distributed.get_rank() == (
torch.distributed.get_world_size() - 1) torch.distributed.get_world_size() - 1)
...@@ -307,6 +375,9 @@ def print_rank_last(message): ...@@ -307,6 +375,9 @@ def print_rank_last(message):
else: else:
print(message, flush=True) print(message, flush=True)
def get_device_arch_version():
"""Returns GPU arch version (8: Ampere, 9: Hopper, 10: Blackwell, ...)"""
return torch.cuda.get_device_properties(torch.device("cuda:0")).major
def append_to_progress_log(string, barrier=True): def append_to_progress_log(string, barrier=True):
"""Append given string to progress log.""" """Append given string to progress log."""
...@@ -431,11 +502,11 @@ def get_batch_on_this_tp_rank(data_iterator): ...@@ -431,11 +502,11 @@ def get_batch_on_this_tp_rank(data_iterator):
_broadcast(loss_mask) _broadcast(loss_mask)
_broadcast(attention_mask) _broadcast(attention_mask)
_broadcast(position_ids) _broadcast(position_ids)
elif mpu.is_pipeline_first_stage(): elif mpu.is_pipeline_first_stage():
labels=None labels=None
loss_mask=None loss_mask=None
_broadcast(tokens) _broadcast(tokens)
_broadcast(attention_mask) _broadcast(attention_mask)
_broadcast(position_ids) _broadcast(position_ids)
...@@ -443,11 +514,11 @@ def get_batch_on_this_tp_rank(data_iterator): ...@@ -443,11 +514,11 @@ def get_batch_on_this_tp_rank(data_iterator):
elif mpu.is_pipeline_last_stage(): elif mpu.is_pipeline_last_stage():
tokens=None tokens=None
position_ids=None position_ids=None
_broadcast(labels) _broadcast(labels)
_broadcast(loss_mask) _broadcast(loss_mask)
_broadcast(attention_mask) _broadcast(attention_mask)
batch = { batch = {
'tokens': tokens, 'tokens': tokens,
'labels': labels, 'labels': labels,
......
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from pathlib import Path
from typing import Tuple
from megatron.training.global_vars import get_wandb_writer
from megatron.training.utils import print_rank_last
def _get_wandb_artifact_tracker_filename(save_dir: str) -> Path:
"""Wandb artifact tracker file rescords the latest artifact wandb entity and project"""
return Path(save_dir) / "latest_wandb_artifact_path.txt"
def _get_artifact_name_and_version(save_dir: Path, checkpoint_path: Path) -> Tuple[str, str]:
return save_dir.stem, checkpoint_path.stem
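The artifact name and version are simply the directory stems, so a checkpoint saved under, say, a hypothetical <save_dir>/iter_0001000 is logged as "<save_dir stem>:iter_0001000". A quick sketch with hypothetical paths:

# Hedged sketch: artifact naming from path stems (paths are hypothetical).
from pathlib import Path

save_dir = Path("/checkpoints/gpt3-8b")
checkpoint_path = save_dir / "iter_0001000"
artifact_name, artifact_version = save_dir.stem, checkpoint_path.stem
print(f"{artifact_name}:{artifact_version}")            # gpt3-8b:iter_0001000
tracker = save_dir / "latest_wandb_artifact_path.txt"   # written after a successful save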
def on_save_checkpoint_success(checkpoint_path: str, tracker_filename: str, save_dir: str, iteration: int) -> None:
"""Function to be called after checkpointing succeeds and checkpoint is persisted for logging it as an artifact in W&B
Args:
checkpoint_path (str): path of the saved checkpoint
tracker_filename (str): path of the tracker filename for the checkpoint iteration
save_dir (str): path of the root save folder for all checkpoints
iteration (int): iteration of the checkpoint
"""
wandb_writer = get_wandb_writer()
if wandb_writer:
metadata = {"iteration": iteration}
artifact_name, artifact_version = _get_artifact_name_and_version(Path(save_dir), Path(checkpoint_path))
artifact = wandb_writer.Artifact(artifact_name, type="model", metadata=metadata)
artifact.add_reference(f"file://{checkpoint_path}", checksum=False)
artifact.add_file(tracker_filename)
wandb_writer.run.log_artifact(artifact, aliases=[artifact_version])
wandb_tracker_filename = _get_wandb_artifact_tracker_filename(save_dir)
wandb_tracker_filename.write_text(f"{wandb_writer.run.entity}/{wandb_writer.run.project}")
def on_load_checkpoint_success(checkpoint_path: str, load_dir: str) -> None:
"""Function to be called after succesful loading of a checkpoint, for aggregation and logging it to W&B
Args:
checkpoint_path (str): path of the loaded checkpoint
load_dir (str): path of the root save folder for all checkpoints
"""
wandb_writer = get_wandb_writer()
if wandb_writer:
try:
artifact_name, artifact_version = _get_artifact_name_and_version(Path(load_dir), Path(checkpoint_path))
wandb_tracker_filename = _get_wandb_artifact_tracker_filename(load_dir)
artifact_path = ""
if wandb_tracker_filename.is_file():
artifact_path = wandb_tracker_filename.read_text().strip()
artifact_path = f"{artifact_path}/"
wandb_writer.run.use_artifact(f"{artifact_path}{artifact_name}:{artifact_version}")
except Exception:
print_rank_last(f" failed to find checkpoint {checkpoint_path} in wandb")
\ No newline at end of file