# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

"""
Fault Tolerance (FT) package integration for Megatron-LM, using the FT section-based API.

The FT package is included in "nvidia-resiliency-ext" 
(https://github.com/NVIDIA/nvidia-resiliency-ext).

NOTE: The workload must be run using the `ft_launcher` tool provided by `nvidia-resiliency-ext`.
NOTE: Calls to the public API of this module are no-ops if FT is not initialized
(`ft_integration.setup` was not called).
NOTE: The default distributed process group must be initialized before calling `ft_integration.setup`.

The "setup" FT section is opened during FT initialization and closed before the first training or 
eval iteration. Training and evaluation steps are wrapped in the "step" section, but only after a 
few warmup iterations. This is because the initial iterations may be slower, and we want the "step" 
timeout to be short. These warmup steps, which are not wrapped in the "step" section, will fall into
the out-of-section area. All checkpoint-saving-related operations (including asynchronous 
checkpointing finalization) are wrapped in the "checkpointing" section.

If timeout calculation is enabled (--calc-ft-timeouts), the FT timeouts are updated after each
checkpoint and at the end of the run, based on the observed section durations.

`ft_launcher` command example:
```
ft_launcher \
    --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} \
    --nnodes=${NUM_NODES} --nproc-per-node=${NUM_GPUS_PER_NODE} \
    --ft-param-rank_section_timeouts=setup:600,step:180,checkpointing:420 \
    --ft-param-rank_out_of_section_timeout=300 \
    train_script_with_ft.py
```
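
A minimal sketch of how these hooks might be wired into a training loop. The `train_step`,
`save_checkpoint`, `save_now`, and `num_iterations` names below are placeholders for the actual
Megatron-LM training code, and the import path may differ depending on the package layout:
```
# NOTE: the default torch.distributed process group must already be initialized here
import ft_integration  # placeholder import path

ft_integration.setup(args)                    # reads args.save, args.async_save, args.calc_ft_timeouts
ft_integration.maybe_setup_simulated_fault()  # optional, FT testing only

for iteration in range(num_iterations):
    ft_integration.on_training_step_start()
    train_step(...)                           # placeholder training step
    ft_integration.on_training_step_end()

    if save_now(iteration):                   # placeholder checkpointing condition
        ft_integration.on_checkpointing_start()
        save_checkpoint(...)                  # placeholder checkpoint save
        ft_integration.on_checkpointing_end(is_async_finalization=False)

ft_integration.shutdown()
```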
"""

import argparse
import json
import os
import random
import signal
import sys
import threading
import time
from typing import Any, Optional

import torch

from . import global_vars
from .utils import is_rank0, print_rank_0

_GLOBAL_RANK_MONITOR_CLIENT = None

_ft_state_path = None
_is_persistent_chkpt_loaded = False
_is_async_chkpt_enabled = False
_is_calculating_timeouts = False
_is_setup_section_open = False
_seen_checkpoints_cnt = 0
_seen_tr_iters_cnt = 0
_curr_eval_iter_idx = 0

# Initial iterations are not wrapped in the "step" section, as they may be slower.
_NUM_WARMUP_ITERS = 1
# Minimum number of training iterations that must be observed before the "step" timeout is updated.
_MIN_ITERS_FOR_STEP_TIMEOUT_UPDATE = 16


def get_rank_monitor_client() -> Optional[Any]:
    """Returns the underlying fault tolerance client instance

    Returns:
        RankMonitorClient: rank monitor client instance, or None if FT was not initialized
    """
    return _GLOBAL_RANK_MONITOR_CLIENT


def setup(args: argparse.Namespace) -> None:
    """Initialize fault tolerance

    Args:
        args (argparse.Namespace): parsed Megatron-LM command line arguments

    Raises:
        ValueError: if invalid config is provided
    """
    from nvidia_resiliency_ext.fault_tolerance import RankMonitorClient

    print_rank_0(f"FT: initializing...")

    checkpoint_dir = args.save
    if not checkpoint_dir:
        raise ValueError("checkpointing save dir must be set to enable fault tolerance")
    if is_rank0() and not os.path.exists(checkpoint_dir):
        # The Megatron-LM checkpoint dir is needed for saving the FT state, and it may be
        # required before the first checkpoint is saved, so create it in advance.
        os.makedirs(checkpoint_dir, exist_ok=True)

    cli = RankMonitorClient()
    global _GLOBAL_RANK_MONITOR_CLIENT
    global_vars._ensure_var_is_not_initialized(_GLOBAL_RANK_MONITOR_CLIENT, 'rank monitor client')
    _GLOBAL_RANK_MONITOR_CLIENT = cli

    global _ft_state_path
    _ft_state_path = os.path.join(checkpoint_dir, "ft_state.json")

    global _is_async_chkpt_enabled
    _is_async_chkpt_enabled = args.async_save

    global _is_calculating_timeouts
    _is_calculating_timeouts = args.calc_ft_timeouts

    cli.init_workload_monitoring()
    _load_state_if_exists()
    print_rank_0(f"FT: initialized. Timeouts={cli.section_timeouts}")

    cli.start_section("setup")
    global _is_setup_section_open
    _is_setup_section_open = True


def on_training_step_start() -> None:
    """Should be called before each training step"""
    rmon_cli = get_rank_monitor_client()
    if rmon_cli is not None:
        global _is_setup_section_open
        if _is_setup_section_open:
            rmon_cli.end_section("setup")
            _is_setup_section_open = False
        if _seen_tr_iters_cnt >= _NUM_WARMUP_ITERS:
            rmon_cli.start_section("step")
        # Reset the eval step index; training has resumed, so evaluation is done.
        global _curr_eval_iter_idx
        _curr_eval_iter_idx = 0


def on_training_step_end() -> None:
    """Should be called after each training step"""
    rmon_cli = get_rank_monitor_client()
    if rmon_cli is not None:
        global _seen_tr_iters_cnt
        if _seen_tr_iters_cnt >= _NUM_WARMUP_ITERS:
            rmon_cli.end_section("step")
        _seen_tr_iters_cnt += 1


def on_eval_step_start() -> None:
    """Should be called before each validation step"""
    rmon_cli = get_rank_monitor_client()
    if rmon_cli is not None:
        global _is_setup_section_open
        if _is_setup_section_open:
            # setup section can be open if there were no training iters before evaluation
            rmon_cli.end_section("setup")
            _is_setup_section_open = False
        if _curr_eval_iter_idx >= _NUM_WARMUP_ITERS:
            rmon_cli.start_section("step")


def on_eval_step_end() -> None:
    """Should be called after each validation step"""
    rmon_cli = get_rank_monitor_client()
    if rmon_cli is not None:
        global _curr_eval_iter_idx
        if _curr_eval_iter_idx >= _NUM_WARMUP_ITERS:
            rmon_cli.end_section("step")
        _curr_eval_iter_idx += 1


def on_checkpointing_start() -> None:
    """Should be called before each checkpoint-saving-related operation."""
    rmon_cli = get_rank_monitor_client()
    if rmon_cli is not None:
        rmon_cli.start_section("checkpointing")


def on_checkpointing_end(is_async_finalization: bool) -> None:
    """Should be called after each checkpoint-saving-related operation.

    Args:
        is_async_finalization (bool): true if called after an async checkpointing finalization
    """
    rmon_cli = get_rank_monitor_client()
    if rmon_cli is not None:
        rmon_cli.end_section("checkpointing")
    # Async checkpointing finalization is called before each training iteration and can be a
    # no-op, so only try to update the timeouts after an actual `save_checkpoint`.
    if not is_async_finalization:
        global _seen_checkpoints_cnt
        _seen_checkpoints_cnt += 1
        _maybe_update_timeouts()


def on_checkpoint_loaded(is_local_chkpt: bool) -> None:
    """Should be called after a checkpoint was loaded

    Args:
        is_local_chkpt (bool): true if it was a local checkpoint, false if global
    """
    # A checkpoint can be loaded during the "setup" section. Track whether a persistent
    # checkpoint was loaded: reading an in-memory (local) checkpoint can be very fast,
    # so it could lead to underestimating the "setup" timeout.
    global _is_persistent_chkpt_loaded
    _is_persistent_chkpt_loaded = not is_local_chkpt


def shutdown() -> None:
    """Shutdowns fault folerance, updates the FT timeouts if possible"""
    global _GLOBAL_RANK_MONITOR_CLIENT
    rmon_cli = get_rank_monitor_client()
    if rmon_cli is not None:
        print_rank_0("FT: closing...")
        _maybe_update_timeouts(is_closing_ft=True)
        rmon_cli.shutdown_workload_monitoring()
        print_rank_0("FT: closed.")
    _GLOBAL_RANK_MONITOR_CLIENT = None


def _load_state_if_exists():
    rmon_cli = get_rank_monitor_client()
    if os.path.exists(_ft_state_path):
        with open(_ft_state_path, "r") as f:
            ft_state = json.load(f)
        rmon_cli.load_state_dict(ft_state)
        print_rank_0(f"FT: loaded timeouts from {_ft_state_path}. {rmon_cli.section_timeouts}")


def _update_timeouts(selected_sections, calc_out_of_section):
    print_rank_0(
        f"FT: updating timeouts for: {selected_sections} "
        + f"update out-of-section: {calc_out_of_section} ..."
    )
    rmon_cli = get_rank_monitor_client()
    rmon_cli.calculate_and_set_section_timeouts(
        selected_sections=selected_sections, calc_out_of_section=calc_out_of_section
    )
    if is_rank0():
        ft_state = rmon_cli.state_dict()
        with open(_ft_state_path, "w") as f:
            json.dump(ft_state, f)
        print_rank_0(f"FT: updated timeouts saved to {_ft_state_path}. {rmon_cli.section_timeouts}")


def _maybe_update_timeouts(is_closing_ft=False):
    rmon_cli = get_rank_monitor_client()
    if rmon_cli is None:
        return
    if not _is_calculating_timeouts:
        return

    # Decide which section timeouts can be updated
    sections_to_update = []

    if _is_persistent_chkpt_loaded:
        sections_to_update.append("setup")
    else:
        print_rank_0(
            "FT: can't update the setup section timeout until a persistent checkpoint is loaded"
        )

    if _seen_tr_iters_cnt >= _MIN_ITERS_FOR_STEP_TIMEOUT_UPDATE:
        sections_to_update.append("step")
    else:
        print_rank_0("FT: need to see more training iterations to update the step section timeout")

    if _seen_checkpoints_cnt > 0:
        if not _is_async_chkpt_enabled:
            sections_to_update.append("checkpointing")
        else:
            # There can be too much checkpointing section time variability
            # across runs with the async checkpointing, e.g. in some runs all checkpointing
            # work can be parallelized (=short checkpointing sections) while in others we can
            # hit a costly finalization.
            print_rank_0(
                "FT: can't update the checkpointing section timeout with async checkpointing"
            )
    else:
        print_rank_0("FT: checkpointing section is not updated until a checkpoint was saved")

    update_out_of_section = False
    if is_closing_ft:
        # With async checkpointing, the "checkpointing" section is not updated, but we still
        # want to see some checkpointing to ensure that it was a complete run.
        if {'setup', 'step'}.issubset(sections_to_update) and _seen_checkpoints_cnt > 0:
            update_out_of_section = True
        else:
            print_rank_0(
                "FT: the out-of-section timeout won't be updated until all FT sections have been seen"
            )

    else:
        print_rank_0("FT: the out-of-section timeout won't be updated as the FT is not closing yet")

    if sections_to_update or update_out_of_section:
        _update_timeouts(
            selected_sections=sections_to_update, calc_out_of_section=update_out_of_section
        )


def maybe_setup_simulated_fault() -> None:
    """Sets a simulated fault, based on `FT_SIM_FAULT_DESC` env variable.
    Simulated fault description format:
    rank_hung|rank_killed;rank_to_fail|"";base_delay
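    Example (illustrative): FT_SIM_FAULT_DESC="rank_killed;0;120" kills rank 0 after ~120s;
    an empty rank field means a random rank is picked.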
    NOTE: This is for FT testing only.
    """

    simulated_fault_desc = os.environ.get('FT_SIM_FAULT_DESC', None)
    if not simulated_fault_desc:
        return
    fault_type: Any  # silence mypy
    rank_to_fail: Any  # silence mypy
    base_delay: Any  # silence mypy
    fault_type, rank_to_fail, base_delay = simulated_fault_desc.split(';')
    fault_type = fault_type.strip()
    rank_to_fail = rank_to_fail.strip()
    rank_to_fail = int(rank_to_fail) if rank_to_fail else None
    base_delay = float(base_delay.strip())

    rng = random.Random()

    print_rank_0(
        f"FT: Initializing simulated fault: {fault_type},"
        + f"rank to fail: {rank_to_fail}, base delay: {base_delay}"
    )

    # The rank that simulates a fault can be explicitly specified in the `rank_to_fail` field;
    # if not specified, a random rank is picked.
    rank = torch.distributed.get_rank()
    rand_rank = rng.randint(0, torch.distributed.get_world_size() - 1)
    rank_to_fail = rank_to_fail if rank_to_fail is not None else rand_rank
    rank_to_fail = torch.tensor([rank_to_fail], device=torch.cuda.current_device())
    torch.distributed.broadcast(rank_to_fail, 0)
    rank_to_fail = int(rank_to_fail.item())

    if rank != rank_to_fail:
        # this rank is not going to simulate a fault, nothing more to do
        return

    if fault_type == 'random':
        fault_type = rng.choice(['rank_killed', 'rank_hung'])

    if fault_type in ('rank_killed', 'rank_hung'):
        target_pid = os.getpid()
    else:
        raise Exception(
            f"Unknown fault type {fault_type}; expected one of: rank_killed, rank_hung."
        )

    # add some randomness to the delay
    delay = base_delay + 0.2 * rng.random() * base_delay

    print_rank_0(f"FT: Selected fault={fault_type}; target rank={rank_to_fail}; delay={delay}")

    def __fault_thread():
        time.sleep(delay)
        for of in [sys.stdout, sys.stderr]:
            print(
                f"\n####\nFT: Simulating fault: {fault_type}; rank to fail: {rank_to_fail}\n####\n",
                file=of,
                flush=True,
            )
        if fault_type == 'rank_hung':
            os.kill(target_pid, signal.SIGSTOP)
        else:
            os.kill(target_pid, signal.SIGKILL)

    fault_sim_thread = threading.Thread(target=__fault_thread)
    fault_sim_thread.daemon = True
    fault_sim_thread.start()