# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

"""
Fault Tolerance (FT) package integration for Megatron-LM, using the FT section-based API.

The FT package is included in "nvidia-resiliency-ext"
(https://github.com/NVIDIA/nvidia-resiliency-ext).

NOTE: The workload must be run using the `ft_launcher` tool provided by `nvidia-resiliency-ext`.
NOTE: Calls to the public API of this module are no-ops if FT is not initialized
(`ft_integration.setup` was not called).
NOTE: The default distributed process group should be initialized before calling
`ft_integration.setup`.

The "setup" FT section is opened during FT initialization and closed before the first training
or eval iteration. Training and evaluation steps are wrapped in the "step" section, but only
after a few warmup iterations. This is because the initial iterations may be slower, and we want
the "step" timeout to be short. These warmup steps, which are not wrapped in the "step" section,
will fall into the out-of-section area. All checkpoint-saving-related operations (including
asynchronous checkpointing finalization) are wrapped in the "checkpointing" section.

If timeout calculation is enabled (--calc-ft-timeouts), FT timeouts are updated after each
checkpoint and at the end of the run. Updated values are based on observed intervals.

`ft_launcher` command example:
```
ft_launcher \
    --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} \
    --nnodes=${NUM_NODES} --nproc-per-node=${NUM_GPUS_PER_NODE} \
    --ft-param-rank_section_timeouts=setup:600,step:180,checkpointing:420 \
    --ft-param-rank_out_of_section_timeout=300 \
    train_script_with_ft.py
```
"""

import argparse
import json
import os
import random
import signal
import sys
import threading
import time
from typing import Any, Optional

import torch

from . import global_vars
from .utils import is_rank0, print_rank_0

_GLOBAL_RANK_MONITOR_CLIENT = None

_ft_state_path = None
_is_persistent_chkpt_loaded = False
_is_async_chkpt_enabled = False
_is_calculating_timeouts = False
_is_setup_section_open = False
_seen_checkpoints_cnt = 0
_seen_tr_iters_cnt = 0
_curr_eval_iter_idx = 0

_NUM_WARMUP_ITERS = 1
_MIN_ITERS_FOR_STEP_TIMEOUT_UPDATE = 16


def get_rank_monitor_client() -> Optional[Any]:
    """Returns the underlying fault tolerance client instance.

    Returns:
        RankMonitorClient: rank monitor client instance, or None if FT was not initialized
    """
    return _GLOBAL_RANK_MONITOR_CLIENT


def setup(args: argparse.Namespace) -> None:
    """Initialize fault tolerance.

    Args:
        args (argparse.Namespace): parsed Megatron-LM command line arguments

    Raises:
        ValueError: if an invalid config is provided
    """
    from nvidia_resiliency_ext.fault_tolerance import RankMonitorClient

    print_rank_0(f"FT: initializing...")
    checkpoint_dir = args.save
    if not checkpoint_dir:
        raise ValueError("checkpointing save dir must be set to enable fault tolerance")
    if is_rank0() and not os.path.exists(checkpoint_dir):
        # The MLM checkpoint dir is also used for saving the FT state, which can happen
        # before the first checkpoint is written, so create the directory in advance.
        os.makedirs(checkpoint_dir, exist_ok=True)
    cli = RankMonitorClient()
    global _GLOBAL_RANK_MONITOR_CLIENT
    global_vars._ensure_var_is_not_initialized(_GLOBAL_RANK_MONITOR_CLIENT, 'rank monitor client')
    _GLOBAL_RANK_MONITOR_CLIENT = cli
    global _ft_state_path
    _ft_state_path = os.path.join(checkpoint_dir, "ft_state.json")
    global _is_async_chkpt_enabled
    _is_async_chkpt_enabled = args.async_save
    global _is_calculating_timeouts
    _is_calculating_timeouts = args.calc_ft_timeouts
    cli.init_workload_monitoring()
    _load_state_if_exists()
    print_rank_0(f"FT: initialized. Timeouts={cli.section_timeouts}")
    cli.start_section("setup")
    global _is_setup_section_open
    _is_setup_section_open = True

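# Illustrative usage sketch (not part of the Megatron-LM API and not called anywhere in this
# module): it only shows the intended call order at the script level. It assumes the caller has
# already parsed the Megatron-LM arguments (providing `save`, `async_save`, `calc_ft_timeouts`)
# and initialized the default distributed process group; `run_training` is a hypothetical
# placeholder for the actual training driver.
def _example_script_level_wiring(args: argparse.Namespace, run_training) -> None:
    """Sketch of how this module is meant to be driven from a training script."""
    setup(args)  # opens the "setup" FT section
    maybe_setup_simulated_fault()  # no-op unless FT_SIM_FAULT_DESC is set
    try:
        run_training(args)  # expected to call the per-step / checkpointing hooks below
    finally:
        shutdown()  # updates FT timeouts if possible, then stops workload monitoring
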
def on_training_step_start() -> None:
    """Should be called before each training step"""
    rmon_cli = get_rank_monitor_client()
    if rmon_cli is not None:
        global _is_setup_section_open
        if _is_setup_section_open:
            rmon_cli.end_section("setup")
            _is_setup_section_open = False
        if _seen_tr_iters_cnt >= _NUM_WARMUP_ITERS:
            rmon_cli.start_section("step")
        # reset the eval step index: we started training, so evaluation is done
        global _curr_eval_iter_idx
        _curr_eval_iter_idx = 0


def on_training_step_end() -> None:
    """Should be called after each training step"""
    rmon_cli = get_rank_monitor_client()
    if rmon_cli is not None:
        global _seen_tr_iters_cnt
        if _seen_tr_iters_cnt >= _NUM_WARMUP_ITERS:
            rmon_cli.end_section("step")
        _seen_tr_iters_cnt += 1


def on_eval_step_start() -> None:
    """Should be called before each validation step"""
    rmon_cli = get_rank_monitor_client()
    if rmon_cli is not None:
        global _is_setup_section_open
        if _is_setup_section_open:
            # the setup section can still be open if there were no training iters before evaluation
            rmon_cli.end_section("setup")
            _is_setup_section_open = False
        if _curr_eval_iter_idx >= _NUM_WARMUP_ITERS:
            rmon_cli.start_section("step")


def on_eval_step_end() -> None:
    """Should be called after each validation step"""
    rmon_cli = get_rank_monitor_client()
    if rmon_cli is not None:
        global _curr_eval_iter_idx
        if _curr_eval_iter_idx >= _NUM_WARMUP_ITERS:
            rmon_cli.end_section("step")
        _curr_eval_iter_idx += 1


def on_checkpointing_start() -> None:
    """Should be called before each checkpoint-saving-related operation."""
    rmon_cli = get_rank_monitor_client()
    if rmon_cli is not None:
        rmon_cli.start_section("checkpointing")


def on_checkpointing_end(is_async_finalization: bool) -> None:
    """Should be called after each checkpoint-saving-related operation.

    Args:
        is_async_finalization (bool): true if called after an async checkpointing finalization
    """
    rmon_cli = get_rank_monitor_client()
    if rmon_cli is not None:
        rmon_cli.end_section("checkpointing")
    # Async checkpointing finalization is called before each training iter and can be a no-op,
    # so try to update the timeouts only on an actual `save_checkpoint`.
    if not is_async_finalization:
        global _seen_checkpoints_cnt
        _seen_checkpoints_cnt += 1
        _maybe_update_timeouts()

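# Illustrative sketch only: how the hooks above are typically wrapped around training,
# evaluation, and checkpoint saving. `train_step`, `eval_step`, `save_checkpoint`, and the
# interval checks are hypothetical placeholders for the caller's own logic; the real wiring
# lives in the Megatron-LM training loop.
def _example_hook_wiring(train_step, eval_step, save_checkpoint, num_iters: int) -> None:
    """Sketch of wrapping a simple train/eval/checkpoint loop with the FT hooks."""
    for iteration in range(num_iters):
        on_training_step_start()  # closes "setup" once, opens "step" after warmup iters
        train_step()
        on_training_step_end()

        if (iteration + 1) % 1000 == 0:  # hypothetical eval interval
            on_eval_step_start()
            eval_step()
            on_eval_step_end()

        if (iteration + 1) % 2000 == 0:  # hypothetical checkpoint interval
            on_checkpointing_start()
            save_checkpoint()
            on_checkpointing_end(is_async_finalization=False)
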
def on_checkpoint_loaded(is_local_chkpt: bool) -> None:
    """Should be called after a checkpoint was loaded

    Args:
        is_local_chkpt (bool): true if it was a local checkpoint, false if global
    """
    # A checkpoint can be loaded during "setup". Record whether a persistent checkpoint was
    # loaded: in-memory (local) checkpoint reading can be very fast, so using it as a reference
    # could underestimate the "setup" timeout.
    global _is_persistent_chkpt_loaded
    _is_persistent_chkpt_loaded = not is_local_chkpt


def shutdown() -> None:
    """Shuts down fault tolerance and updates the FT timeouts if possible"""
    global _GLOBAL_RANK_MONITOR_CLIENT
    rmon_cli = get_rank_monitor_client()
    if rmon_cli is not None:
        print_rank_0("FT: closing...")
        _maybe_update_timeouts(is_closing_ft=True)
        rmon_cli.shutdown_workload_monitoring()
        print_rank_0("FT: closed.")
    _GLOBAL_RANK_MONITOR_CLIENT = None


def _load_state_if_exists():
    rmon_cli = get_rank_monitor_client()
    if os.path.exists(_ft_state_path):
        with open(_ft_state_path, "r") as f:
            ft_state = json.load(f)
        rmon_cli.load_state_dict(ft_state)
        print_rank_0(f"FT: loaded timeouts from {_ft_state_path}. {rmon_cli.section_timeouts}")


def _update_timeouts(selected_sections, calc_out_of_section):
    print_rank_0(
        f"FT: updating timeouts for: {selected_sections} "
        + f"update out-of-section: {calc_out_of_section} ..."
    )
    rmon_cli = get_rank_monitor_client()
    rmon_cli.calculate_and_set_section_timeouts(
        selected_sections=selected_sections, calc_out_of_section=calc_out_of_section
    )
    if is_rank0():
        ft_state = rmon_cli.state_dict()
        with open(_ft_state_path, "w") as f:
            json.dump(ft_state, f)
        print_rank_0(f"FT: updated timeouts saved to {_ft_state_path}. {rmon_cli.section_timeouts}")


def _maybe_update_timeouts(is_closing_ft=False):
    rmon_cli = get_rank_monitor_client()
    if rmon_cli is None:
        return
    if not _is_calculating_timeouts:
        return

    # Decide which section timeouts can be updated
    sections_to_update = []

    if _is_persistent_chkpt_loaded:
        sections_to_update.append("setup")
    else:
        print_rank_0(
            "FT: can't update the setup section timeout until persistent checkpoint is loaded"
        )

    if _seen_tr_iters_cnt >= _MIN_ITERS_FOR_STEP_TIMEOUT_UPDATE:
        sections_to_update.append("step")
    else:
        print_rank_0("FT: need to see more training iterations to update the step section timeout")

    if _seen_checkpoints_cnt > 0:
        if not _is_async_chkpt_enabled:
            sections_to_update.append("checkpointing")
        else:
            # The checkpointing section time can vary too much across runs with async
            # checkpointing: in some runs all checkpointing work is parallelized
            # (= short checkpointing sections), while in others we can hit a costly finalization.
            print_rank_0(
                "FT: can't update the checkpointing section timeout with async checkpointing"
            )
    else:
        print_rank_0("FT: checkpointing section is not updated until a checkpoint was saved")

    update_out_of_section = False
    if is_closing_ft:
        # With async checkpointing, the "checkpointing" section is not updated,
        # but we still want to see some checkpointing to ensure that it was a complete run.
        if {'setup', 'step'}.issubset(sections_to_update) and _seen_checkpoints_cnt > 0:
            update_out_of_section = True
        else:
            print_rank_0(
                "FT: the out-of-section timeout won't be updated until all FT sections were seen"
            )
    else:
        print_rank_0("FT: the out-of-section timeout won't be updated as the FT is not closing yet")

    if sections_to_update or update_out_of_section:
        _update_timeouts(
            selected_sections=sections_to_update, calc_out_of_section=update_out_of_section
        )

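# Illustrative sketch only: the timeouts computed above are persisted as JSON in
# `<args.save>/ft_state.json` (see `_update_timeouts`), so they can be inspected offline between
# runs. The exact layout of the state dict is defined by nvidia-resiliency-ext and may change
# between versions.
def _example_print_saved_ft_state(checkpoint_dir: str) -> None:
    """Sketch: pretty-print the FT state persisted next to the Megatron-LM checkpoints."""
    path = os.path.join(checkpoint_dir, "ft_state.json")
    if os.path.exists(path):
        with open(path, "r") as f:
            print(json.dumps(json.load(f), indent=2))
    else:
        print(f"No FT state found at {path}")
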
def maybe_setup_simulated_fault() -> None:
    """Sets up a simulated fault, based on the `FT_SIM_FAULT_DESC` env variable.

    Simulated fault description format:
    rank_hung|rank_killed|random;rank_to_fail|"";base_delay

    NOTE: This is for FT testing only
    """
    simulated_fault_desc = os.environ.get('FT_SIM_FAULT_DESC', None)
    if not simulated_fault_desc:
        return
    fault_type: Any  # silence mypy
    rank_to_fail: Any  # silence mypy
    base_delay: Any  # silence mypy
    fault_type, rank_to_fail, base_delay = simulated_fault_desc.split(';')
    fault_type = fault_type.strip()
    rank_to_fail = rank_to_fail.strip()
    rank_to_fail = int(rank_to_fail) if rank_to_fail else None
    base_delay = float(base_delay.strip())
    rng = random.Random()
    print_rank_0(
        f"FT: Initializing simulated fault: {fault_type}, "
        + f"rank to fail: {rank_to_fail}, base delay: {base_delay}"
    )
    # The rank that simulates the fault can be explicitly specified in the `rank_to_fail` field;
    # if not specified, a random rank is picked.
    rank = torch.distributed.get_rank()
    rand_rank = rng.randint(0, torch.distributed.get_world_size() - 1)
    rank_to_fail = rank_to_fail if rank_to_fail is not None else rand_rank
    rank_to_fail = torch.tensor([rank_to_fail], device=torch.cuda.current_device())
    torch.distributed.broadcast(rank_to_fail, 0)
    rank_to_fail = int(rank_to_fail.item())
    if rank != rank_to_fail:
        # this rank is not going to simulate a fault, nothing more to do
        return
    if fault_type == 'random':
        fault_type = rng.choice(['rank_killed', 'rank_hung'])
    if fault_type == 'rank_killed':
        target_pid = os.getpid()
    elif fault_type == 'rank_hung':
        target_pid = os.getpid()
    else:
        raise Exception(
            f"Unknown fault type {fault_type}; expected one of: rank_killed, rank_hung."
        )
    # add some randomness to the delay
    delay = base_delay + 0.2 * rng.random() * base_delay
    print_rank_0(f"FT: Selected fault={fault_type}; target rank={rank_to_fail}; delay={delay}")

    def __fault_thread():
        time.sleep(delay)
        for of in [sys.stdout, sys.stderr]:
            print(
                f"\n####\nFT: Simulating fault: {fault_type}; rank to fail: {rank_to_fail}\n####\n",
                file=of,
                flush=True,
            )
        if fault_type == 'rank_hung':
            os.kill(target_pid, signal.SIGSTOP)
        else:
            os.kill(target_pid, signal.SIGKILL)

    fault_sim_thread = threading.Thread(target=__fault_thread)
    fault_sim_thread.daemon = True
    fault_sim_thread.start()
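
# Illustrative sketch only: enabling a simulated fault for FT testing. The variable must be set
# in the environment of every rank before `maybe_setup_simulated_fault` runs (e.g. exported
# before launching via `ft_launcher`); the value format follows the docstring above.
def _example_enable_simulated_fault() -> None:
    """Sketch: SIGSTOP a randomly chosen rank after ~300s (plus up to 20% jitter)."""
    os.environ["FT_SIM_FAULT_DESC"] = "rank_hung;;300"  # empty rank field -> random rank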