# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Functions for CUDA Graphs support in FP8"""
from collections.abc import Iterable
7
8
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union

9
10
11
12
13
import torch
from torch.utils._pytree import tree_flatten as _tree_flatten
from torch.utils._pytree import tree_unflatten as _tree_unflatten
from torch._C import _graph_pool_handle

14
from transformer_engine.common.recipe import DelayedScaling, Recipe
15
from transformer_engine.pytorch.constants import dist_group_type
16
17
18
19
20
21
22
from .fp8 import (
    fp8_autocast,
    FP8GlobalStateManager,
    get_default_fp8_recipe,
)
from .distributed import get_all_rng_states, graph_safe_rng_available
from .module.base import TransformerEngineBaseModule
23
from .ops.op import BasicOperation
24
25
26
27
28
29

__all__ = ["make_graphed_callables"]


# Module-wide flag: True between `set_capture_start()` and `set_capture_end()`,
# and reported by `is_graph_capturing()`.
_IS_GRAPH_CAPTURING = False

# Generic alias used in signatures that accept either one value or a tuple
# of values of the same type.
_T = TypeVar("_T")
SingleOrTuple = Union[_T, Tuple[_T, ...]]


def set_capture_start() -> None:
    """Record beginning of `make_graphed_callables`.

    Sets the module-level `_IS_GRAPH_CAPTURING` flag to ``True``; the flag
    is the value returned by `is_graph_capturing`.
    """
    global _IS_GRAPH_CAPTURING
    _IS_GRAPH_CAPTURING = True


def set_capture_end() -> None:
    """Record end of `make_graphed_callables`.

    Resets the module-level `_IS_GRAPH_CAPTURING` flag to ``False``; the
    flag is the value returned by `is_graph_capturing`.
    """
    global _IS_GRAPH_CAPTURING
    _IS_GRAPH_CAPTURING = False


def is_graph_capturing() -> bool:
    """Return whether within `make_graphed_callables`.

    Returns
    -------
    bool
        Current value of the module-level `_IS_GRAPH_CAPTURING` flag, which
        is set by `set_capture_start` and cleared by `set_capture_end`.
    """
    return _IS_GRAPH_CAPTURING


def graph_pool_handle():
    """
    Returns an opaque token representing the id of a graph memory pool.

    Thin wrapper around `torch._C._graph_pool_handle`; the returned token is
    what `_make_graphed_callables` passes as `pool=` to `torch.cuda.graph`
    when the caller does not supply a pool.
    """
    return _graph_pool_handle()


def _make_graphed_callables(
    callables: SingleOrTuple[Callable],
    sample_args: SingleOrTuple[Tuple[torch.Tensor, ...]],
    num_warmup_iters: int = 3,
    allow_unused_input: bool = False,
    fp8_weight_caching: bool = False,
    sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None,
    _order: Optional[List[int]] = None,
    pool: Optional[Tuple[int, ...]] = None,
    retain_graph_in_backward: bool = False,
) -> SingleOrTuple[Callable]:
    """
    Helper method for `make_graphed_callables`

    Runs each callable eagerly for a few warmup iterations, captures its
    forward pass (and, when all callables are in training mode, its backward
    pass) into CUDA graphs that share one memory pool, and returns
    replacements that replay those graphs.

    Parameters
    ----------
    callables: (tuple of) callable
               Callable or callables to graph.
    sample_args: (tuple of) tuple of torch.Tensor
                 Positional arguments, one tuple per forward pass. After
                 flattening, every positional argument must be a tensor.
    num_warmup_iters: int, default = 3
                      Number of eager iterations run per forward pass
                      before capture.
    allow_unused_input: bool, default = `False`
                        Forwarded as `allow_unused` to `torch.autograd.grad`.
    fp8_weight_caching: bool, default = `False`
                        If set, the graphed callables require a boolean
                        `is_first_microbatch` kwarg, which controls the
                        skip-FP8-weight-update flag.
    sample_kwargs: (tuple of) dict, optional
                   Keyword arguments, one dict per forward pass. Defaults
                   to no keyword arguments.
    _order: list of int, optional
            Schedule for interleaved pipeline parallelism (see the comment
            in the size-check section below).
    pool: optional
          Memory pool shared by all captures. A new handle from
          `graph_pool_handle` is used if not provided.
    retain_graph_in_backward: bool, default = `False`
                              Forwarded as `retain_graph` to
                              `torch.autograd.grad` during backward capture.

    Returns
    -------
    A single graphed callable if a single callable was passed, otherwise a
    tuple with one entry per forward pass. Without `_order`, modules are
    returned themselves with `forward` replaced; with `_order`, the
    replacement forward functions are returned.
    """

    if torch.is_autocast_enabled() and torch.is_autocast_cache_enabled():
        raise RuntimeError(
            "make_graphed_callables does not support the autocast "
            "caching. Please set `cache_enabled=False`."
        )

    # Default is to pass no kwargs to callables
    if sample_kwargs is None:
        if isinstance(callables, tuple):
            sample_kwargs = tuple({} for _ in range(len(sample_args)))
        else:
            sample_kwargs = {}

    # Canonicalize args as tuples
    just_one_callable = False
    if not isinstance(callables, tuple):
        just_one_callable = True
        callables = (callables,)
        sample_args = (sample_args,)
        sample_kwargs = (sample_kwargs,)

    # Check training/inference
    is_training = all(c.training for c in callables)
    if not is_training and any(c.training for c in callables):
        # Raise explicitly: `assert False` would be stripped under `python -O`.
        raise AssertionError(
            "make_graphed_callables only supports when modules are all in training or all in"
            " inference mode."
        )

    # Check sizes of args
    if _order is None:
        assert len(sample_args) == len(callables)
        assert len(sample_kwargs) == len(callables)
    else:
        # Custom logic for interleaved pipeline parallelism
        # Note: This is tightly coupled with the Megatron-core
        # implementation of interleaved pipeline parallelism at
        # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/pipeline_parallel/schedules.py.
        # Note: The model is assumed to consist of layers
        # (corresponding to callables) that are grouped into
        # equally-sized model chunks. _order is a list of chunk
        # indices (1-indexed) that indicates the order in which the
        # layers are evaluated. Positive values indicate forward
        # passes and negative values indicate backward passes. Each
        # entry in sample_args corresponds to one of the forward
        # passes.
        num_model_chunks = max(_order)
        num_microbatches = len(_order) // num_model_chunks // 2
        assert num_model_chunks * num_microbatches * 2 == len(_order)
        assert len(sample_args) * 2 >= len(_order) and (
            len(sample_args) * 2 % len(_order) == 0
        ), (
            f"{len(sample_args) * 2} >= {len(_order)} and "
            f"{len(sample_args) * 2} % {len(_order)} == 0"
        )
        num_layers = len(sample_args) // num_model_chunks // num_microbatches
        assert len(callables) == num_model_chunks * num_layers, (
            f"Callables should have ({num_model_chunks * num_layers}) "
            + f"entries when order input is provided but got {len(callables)}."
        )
        assert len(sample_args) == num_model_chunks * num_microbatches * num_layers, (
            f"Expected {num_model_chunks * num_microbatches * num_layers} "
            + f"args tuple, but got {len(sample_args)}."
        )
        assert len(sample_kwargs) == len(sample_args)

    if fp8_weight_caching:
        # Initialize flag that controls FP8 weight updates
        FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(False)

    # Check callables
    for c in callables:
        if isinstance(c, torch.nn.Module):
            assert (
                len(c._backward_hooks) == 0
                and len(c._forward_hooks) == 0
                and len(c._forward_pre_hooks) == 0
            ), (
                "Modules must not have hooks registered at the time they are passed. "
                + "However, registering hooks on modules after passing them "
                + "through make_graphed_callables is allowed."
            )
            assert all(b.requires_grad is False for b in c.buffers()), (
                "In any :class:`~torch.nn.Module` passed to "
                + ":func:`~make_graphed_callables`, only parameters may be trainable. "
                + "All buffers must have ``requires_grad=False``."
            )

    # Flatten callable arguments
    per_callable_kwargs_keys = [list(kwargs.keys()) for kwargs in sample_kwargs]
    flatten_sample_args = []
    for args, kwargs, kwargs_keys in zip(sample_args, sample_kwargs, per_callable_kwargs_keys):
        flatten_arg, _ = _tree_flatten(args)
        flatten_kwarg, _ = _tree_flatten([kwargs[key] for key in kwargs_keys])
        flatten_sample_args.append(tuple(flatten_arg + flatten_kwarg))
        assert all(isinstance(arg, torch.Tensor) for arg in flatten_arg), (
            "In the beta API, sample_args "
            + "for each callable must contain only Tensors. Other types are not allowed."
        )

    # If a callable is an nn.Module, its graph's full input surface is the args the user explicitly
    # passes to forward (ie, its sample_args) AND the module's parameter attributes.
    # Note: These per_callable_* variables are not actually
    # per-callable, but per-forward-pass (see description of _order).
    # The names are kept for consistency with
    # torch.cuda.make_graphed_callables.
    per_callable_len_user_args = [len(args) for args in flatten_sample_args]
    if _order is None:
        per_callable_module_params = [
            tuple(c.parameters()) if isinstance(c, torch.nn.Module) else () for c in callables
        ]
        per_callable_static_input_surfaces = [
            flatten_sample_args[i] + per_callable_module_params[i] for i in range(len(callables))
        ]
    else:
        per_callable_module_params = []
        for m_chunk in range(num_model_chunks):
            for _ in range(num_microbatches):
                for l_no in range(num_layers):
                    per_callable_module_params.append(
                        tuple(callables[m_chunk * num_layers + l_no].parameters())
                        if isinstance(callables[m_chunk * num_layers + l_no], torch.nn.Module)
                        else ()
                    )
        assert len(per_callable_module_params) == len(flatten_sample_args)
        per_callable_static_input_surfaces = [
            flatten_sample_args[i] + per_callable_module_params[i]
            for i in range(len(flatten_sample_args))
        ]

    fwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))]
    bwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))]
    graph_callables = [None for _ in range(len(flatten_sample_args))]

    # For cases with multiple active RNG states, e.g. TP.
    if graph_safe_rng_available():
        for _, state in get_all_rng_states().items():
            for fwd_graph, bwd_graph in zip(fwd_graphs, bwd_graphs):
                fwd_graph.register_generator_state(state)
                bwd_graph.register_generator_state(state)

    mempool = graph_pool_handle() if pool is None else pool

    # Warmup
    # Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work
    # from ending up in any captures.
    torch.cuda.synchronize()

    # Get warmup func and func_idx.
    warmup_func_idx = []
    warmup_func = []
    if _order is None:
        for func_idx, func in enumerate(callables):
            warmup_func_idx.append(func_idx)
            warmup_func.append(func)
    else:
        fwd_idx = [0] * num_model_chunks
        for c_id in _order:
            if c_id > 0:
                m_chunk = c_id - 1
                for l_no in range(num_layers):
                    func = callables[m_chunk * num_layers + l_no]
                    func_idx = (m_chunk * num_microbatches * num_layers) + (
                        fwd_idx[m_chunk] * num_layers + l_no
                    )
                    warmup_func_idx.append(func_idx)
                    warmup_func.append(func)
                fwd_idx[m_chunk] += 1
    assert len(warmup_func) == len(
        sample_args
    ), f"Warmup runs {len(warmup_func)} don't match args {len(sample_args)}."
    assert len(warmup_func_idx) == len(
        set(warmup_func_idx)
    ), f"Warmup runs {len(warmup_func)} but only {len(set(warmup_func_idx))} are unique."

    # Filter the TE modules that cudagraph can access.
    visited_te_modules = set()

    def hook_fn(module, inputs, outputs):  # pylint: disable=unused-argument
        if isinstance(module, TransformerEngineBaseModule):
            visited_te_modules.add(module)

    # Run warmup and do the above filtering.
    with torch.cuda.stream(torch.cuda.Stream()):
        for func_idx, func in zip(warmup_func_idx, warmup_func):
            args = sample_args[func_idx]
            kwargs = sample_kwargs[func_idx]
            static_input_surface = per_callable_static_input_surfaces[func_idx]
            for _ in range(num_warmup_iters):
                hooks = []
                for module in func.modules():
                    hook = module.register_forward_hook(hook_fn)
                    hooks.append(hook)
                outputs, _ = _tree_flatten(func(*args, **kwargs))
                for hook in hooks:
                    hook.remove()
                if is_training:
                    grad_inputs = torch.autograd.grad(
                        outputs=tuple(o for o in outputs if o.requires_grad),
                        inputs=tuple(i for i in static_input_surface if i.requires_grad),
                        grad_outputs=tuple(torch.empty_like(o) for o in outputs if o.requires_grad),
                        only_inputs=True,
                        allow_unused=allow_unused_input,
                    )
                else:
                    grad_inputs = None
                del outputs, grad_inputs
            # The following code is added specifically for MCore's special requirements,
            # aimed at preventing warmup from altering the control flow.
            for module in func.modules():
                if hasattr(module, "is_first_microbatch"):
                    module.is_first_microbatch = True
    torch.cuda.synchronize()

    # All captures here share a mempool. To avoid replays corrupting each other's memory,
    # the safest approach is to capture all passes in the same order they'll run:
    # fwd 1, fwd 2, ... fwd N, then bwd N, bwd N-1, ... bwd 1.

    if _order is not None:  # pylint: disable=too-many-nested-blocks
        per_callable_static_outputs = [None] * len(flatten_sample_args)
        per_callable_output_unflatten_spec = [None] * len(flatten_sample_args)
        per_callable_static_grad_outputs = [None] * len(flatten_sample_args)
        per_callable_static_grad_inputs = [None] * len(flatten_sample_args)
        fwd_idx = [0] * num_model_chunks
        bwd_idx = [0] * num_model_chunks
        for c_id in _order:
            if c_id > 0:
                # Capture forward graph for model chunk c_id, microbatch fwd_idx[c_id-1]
                m_chunk = c_id - 1
                for l_no in range(num_layers):
                    func = callables[m_chunk * num_layers + l_no]
                    per_callable_fwd_idx = (m_chunk * num_microbatches * num_layers) + (
                        fwd_idx[m_chunk] * num_layers + l_no
                    )
                    args = sample_args[per_callable_fwd_idx]
                    kwargs = sample_kwargs[per_callable_fwd_idx]
                    fwd_graph = fwd_graphs[per_callable_fwd_idx]
                    with torch.cuda.graph(fwd_graph, pool=mempool):
                        outputs = func(*args, **kwargs)
                    flatten_outputs, spec = _tree_flatten(outputs)
                    per_callable_static_outputs[per_callable_fwd_idx] = tuple(flatten_outputs)
                    per_callable_output_unflatten_spec[per_callable_fwd_idx] = spec
                    graph_callables[per_callable_fwd_idx] = func
                fwd_idx[m_chunk] += 1
            else:
                # Capture backward graph for model chunk c_id, microbatch bwd_idx[-c_id-1]
                m_chunk = -c_id - 1
                for l_no in list(reversed(range(num_layers))):
                    per_callable_bwd_idx = (m_chunk * num_microbatches * num_layers) + (
                        bwd_idx[m_chunk] * num_layers + l_no
                    )
                    static_input_surface = per_callable_static_input_surfaces[per_callable_bwd_idx]
                    static_outputs = per_callable_static_outputs[per_callable_bwd_idx]
                    bwd_graph = bwd_graphs[per_callable_bwd_idx]
                    # For now, assumes all static_outputs require grad
                    static_grad_outputs = tuple(
                        torch.empty_like(o) if o.requires_grad else None for o in static_outputs
                    )
                    if is_training:
                        with torch.cuda.graph(bwd_graph, pool=mempool):
                            grad_inputs = torch.autograd.grad(
                                outputs=tuple(o for o in static_outputs if o.requires_grad),
                                inputs=tuple(i for i in static_input_surface if i.requires_grad),
                                grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
                                only_inputs=True,
                                allow_unused=allow_unused_input,
                                retain_graph=retain_graph_in_backward,
                            )
                    # Constructs a tuple suitable for returning from Graphed.backward:
                    # Pads out the actually-needed grads with Nones in gradient slots for inputs
                    # that don't require grad. I couldn't think of a one-liner for this pattern.
                    static_grad_inputs = []
                    grad_idx = 0
                    for arg in static_input_surface:
                        if is_training and isinstance(arg, torch.Tensor) and arg.requires_grad:
                            static_grad_inputs.append(grad_inputs[grad_idx])
                            grad_idx += 1
                        else:
                            static_grad_inputs.append(None)  # type: ignore[arg-type]
                    static_grad_inputs = tuple(static_grad_inputs)  # type: ignore[assignment]

                    per_callable_static_grad_outputs[per_callable_bwd_idx] = static_grad_outputs
                    per_callable_static_grad_inputs[per_callable_bwd_idx] = static_grad_inputs
                bwd_idx[m_chunk] += 1
    else:
        # Capture forward graphs
        per_callable_static_outputs = []
        per_callable_output_unflatten_spec = []
        graph_id = 0
        for func, args, kwargs, fwd_graph in zip(callables, sample_args, sample_kwargs, fwd_graphs):
            with torch.cuda.graph(fwd_graph, pool=mempool):
                outputs = func(*args, **kwargs)
            graph_callables[graph_id] = func
            graph_id += 1

            flatten_outputs, spec = _tree_flatten(outputs)
            per_callable_static_outputs.append(tuple(flatten_outputs))
            per_callable_output_unflatten_spec.append(spec)

        # Capture backward graphs in reverse order
        per_callable_static_grad_outputs = []
        per_callable_static_grad_inputs = []
        for static_input_surface, static_outputs, bwd_graph in zip(
            reversed(per_callable_static_input_surfaces),
            reversed(per_callable_static_outputs),
            reversed(bwd_graphs),
        ):
            # For now, assumes all static_outputs require grad
            static_grad_outputs = tuple(
                torch.empty_like(o) if o.requires_grad else None for o in static_outputs
            )
            if is_training:
                with torch.cuda.graph(bwd_graph, pool=mempool):
                    grad_inputs = torch.autograd.grad(
                        outputs=tuple(o for o in static_outputs if o.requires_grad),
                        inputs=tuple(i for i in static_input_surface if i.requires_grad),
                        grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
                        only_inputs=True,
                        allow_unused=allow_unused_input,
                        retain_graph=retain_graph_in_backward,
                    )
            # Constructs a tuple suitable for returning from Graphed.backward:
            # Pads out the actually-needed grads with Nones in gradient slots for inputs that
            # don't require grad. I couldn't think of a slick one-liner for this pattern.
            static_grad_inputs = []
            grad_idx = 0
            for arg in static_input_surface:
                if is_training and isinstance(arg, torch.Tensor) and arg.requires_grad:
                    static_grad_inputs.append(grad_inputs[grad_idx])
                    grad_idx += 1
                else:
                    static_grad_inputs.append(None)  # type: ignore[arg-type]
            static_grad_inputs = tuple(static_grad_inputs)  # type: ignore[assignment]

            per_callable_static_grad_outputs.append(static_grad_outputs)
            per_callable_static_grad_inputs.append(static_grad_inputs)

        # Reverses the most recent two lists
        per_callable_static_grad_outputs = list(reversed(per_callable_static_grad_outputs))
        per_callable_static_grad_inputs = list(reversed(per_callable_static_grad_inputs))
    # Now for every per_callable list, per_callable_*[i] holds the stuff for the ith callable.

    def make_graphed_autograd_function(
        fwd_graph,
        bwd_graph,
        module_params,
        kwargs_keys,
        len_user_args,
        output_unflatten_spec,
        static_input_surface,
        static_outputs,
        static_grad_outputs,
        static_grad_inputs,
    ):
        class Graphed(torch.autograd.Function):
            """Autograd function for graph replay."""

            @staticmethod
            def forward(ctx, skip_fp8_weight_update, *inputs):
                # pylint: disable=missing-function-docstring

                # Set flag for whether to update FP8 weight updates
                ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
                if ctx.is_first_module and skip_fp8_weight_update is not None:
                    FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(skip_fp8_weight_update)

                # Copy values from new tensors into static tensors
                for i in range(len_user_args):
                    if (
                        isinstance(static_input_surface[i], torch.Tensor)
                        and static_input_surface[i].data_ptr() != inputs[i].data_ptr()
                    ):
                        static_input_surface[i].copy_(inputs[i])

                # Replay forward graph
                fwd_graph.replay()
                assert isinstance(static_outputs, tuple)
                return tuple(o.detach() for o in static_outputs)

            @staticmethod
            @torch.autograd.function.once_differentiable
            def backward(ctx, *grads):
                # pylint: disable=missing-function-docstring

                # Replay backward graph
                assert len(grads) == len(static_grad_outputs)
                for g, grad in zip(static_grad_outputs, grads):
                    if g is not None:
                        # don't copy if autograd gods have been kind and the
                        # incoming grad is already in the right place
                        if g.data_ptr() != grad.data_ptr():
                            g.copy_(grad)
                bwd_graph.replay()

                # Update FP8 scale factors if needed
                if ctx.is_first_module:
                    FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)

                # Input args that didn't require grad expect a None gradient.
                assert isinstance(static_grad_inputs, tuple)
                return (None,) + tuple(
                    b.detach() if b is not None else b for b in static_grad_inputs
                )

        def functionalized(*user_args, **user_kwargs):

            # Decide whether to update FP8 weights
            skip_fp8_weight_update = None
            if fp8_weight_caching:
                assert "is_first_microbatch" in user_kwargs and isinstance(
                    user_kwargs["is_first_microbatch"], bool
                ), "`is_first_microbatch` boolean kwarg must be provided for FP8 weight caching."

                skip_fp8_weight_update = not user_kwargs["is_first_microbatch"]

            # Check that required kwargs are provided
            for key in kwargs_keys:
                if key not in user_kwargs:
                    raise TypeError(
                        f"Graphed callable was initialized with kwarg {key}, "
                        "but it was not provided in graph replay"
                    )

            # Runs the autograd function with inputs == all inputs to
            # the graph that might require grad (explicit user args +
            # module parameters)
            # Assumes module params didn't change since capture.
            flatten_user_args, _ = _tree_flatten(user_args)
            flatten_user_kwargs, _ = _tree_flatten([user_kwargs[key] for key in kwargs_keys])
            func_args = tuple(flatten_user_args) + tuple(flatten_user_kwargs) + module_params
            out = Graphed.apply(skip_fp8_weight_update, *func_args)
            return _tree_unflatten(out, output_unflatten_spec)

        return functionalized

    # Put together the final graphed callables
    ret = []
    for i in range(len(sample_args)):
        graphed = make_graphed_autograd_function(
            fwd_graphs[i],
            bwd_graphs[i],
            per_callable_module_params[i],
            per_callable_kwargs_keys[i],
            per_callable_len_user_args[i],
            per_callable_output_unflatten_spec[i],
            per_callable_static_input_surfaces[i],
            per_callable_static_outputs[i],
            per_callable_static_grad_outputs[i],
            per_callable_static_grad_inputs[i],
        )

        func = graph_callables[i]
        if isinstance(func, torch.nn.Module):

            def make_graphed_forward(func, graph_training_state, graphed, orig_fwd):
                def new_fwd(*user_args, **user_kwargs):
                    # If the module's training-or-eval state matches what we graphed,
                    # run the graph, otherwise run the original forward method
                    if func.training == graph_training_state:
                        # Set the FP8 group from global amax reduction.
                        for m in func.modules():
                            if (
                                isinstance(m, TransformerEngineBaseModule)
                                and FP8GlobalStateManager.is_fp8_enabled()
                            ):
                                if m not in visited_te_modules:
                                    # Only Set the FP8 meta for the modules included by forward
                                    continue
                                fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
                                from transformer_engine.pytorch.attention import DotProductAttention

                                if (
                                    isinstance(m, DotProductAttention)
                                    and not fp8_recipe.fp8_mha
                                    and not fp8_recipe.fp8_dpa
                                ):
                                    # Don't need to update FP8 meta for non-FP8 DPA
                                    continue
                                m.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()
                                m.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
                                FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
                                    m.fp8_meta,
                                )
                        return graphed(*user_args, **user_kwargs)
                    return orig_fwd(*user_args, **user_kwargs)

                return new_fwd

            forward = make_graphed_forward(func, func.training, graphed, func.forward)
            if _order is None:
                # Patch the module in place and hand the module itself back.
                func.forward = forward
                ret.append(func)
            else:
                # The same module backs several forward passes, so return the
                # per-pass forward functions instead of patching the module.
                ret.append(forward)
        else:
            ret.append(graphed)

    if just_one_callable:
        return ret[0]

    return tuple(ret)


573
574
def save_fp8_tensors(
    modules: Iterable[torch.nn.Module],
    fp8_recipe: Recipe,
) -> Optional[List[Any]]:
    """
    Returns the FP8 tensors for all modules
    with adjusted amax history sizes.

    Parameters
    ----------
    modules: iterable of torch.nn.Module
             Root modules; every submodule yielded by `module.modules()`
             is visited.
    fp8_recipe: Recipe
                FP8 recipe. State is only collected for `DelayedScaling`.

    Returns
    -------
    `None` if `fp8_recipe` is not a `DelayedScaling` recipe. Otherwise a
    list with one entry per visited submodule, in traversal order; the
    entry is `None` for submodules that are neither a
    `TransformerEngineBaseModule` nor a `BasicOperation`.
    """

    if not isinstance(fp8_recipe, DelayedScaling):
        return None

    fp8_tensors = []
    for module in modules:
        for m in module.modules():
            # Placeholder keeps the list aligned with the traversal order
            # even for submodules that carry no FP8 state.
            module_tensors = None
            if isinstance(m, TransformerEngineBaseModule):
                if m.primary_weights_in_fp8:
                    m.adjust_amax_history_length(fp8_recipe.amax_history_len)
                module_tensors = m.get_fp8_meta_tensors()
            elif isinstance(m, BasicOperation):
                m.pre_forward(fp8_enabled=True, fp8_recipe=fp8_recipe)
                module_tensors = m._save_fp8_metas()
            fp8_tensors.append(module_tensors)
    return fp8_tensors


def restore_fp8_tensors(
    modules: Iterable[torch.nn.Module],
    fp8_tensors: Optional[List[Any]],
) -> None:
    """Restore FP8 tensors.

    Parameters
    ----------
    modules: iterable of torch.nn.Module
             Root modules; every submodule yielded by `module.modules()`
             is visited and paired with the next saved entry.
    fp8_tensors: list, optional
                 Saved state with one entry per visited submodule, in
                 traversal order. Nothing is done if `None`. The list is
                 not modified.

    Raises
    ------
    IndexError
        If `fp8_tensors` has fewer entries than there are submodules.
    RuntimeError
        If `fp8_tensors` has more entries than there are submodules.
    """

    if fp8_tensors is None:
        return

    # Walk the saved state with an iterator instead of `pop(0)`: this leaves
    # the caller's list intact and avoids shifting the list on every module.
    remaining = iter(fp8_tensors)
    for module in modules:
        for m in module.modules():
            try:
                module_tensors = next(remaining)
            except StopIteration:
                raise IndexError(
                    "Got FP8 state for fewer modules than expected. "
                    "There is probably a discrepancy with `save_fp8_tensors`."
                ) from None
            if isinstance(m, TransformerEngineBaseModule):
                m.reset_fp8_meta_tensors(module_tensors)
            elif isinstance(m, BasicOperation):
                m._load_fp8_metas(module_tensors)

    num_unused = sum(1 for _ in remaining)
    if num_unused != 0:
        raise RuntimeError(
            f"Got FP8 state for {num_unused} more modules than expected. "
            "There is probably a discrepancy with `save_fp8_tensors`."
        )
621
622
623


def make_graphed_callables(
    modules: SingleOrTuple[Callable],
    sample_args: SingleOrTuple[Tuple[torch.Tensor, ...]],
    num_warmup_iters: int = 3,
    allow_unused_input: bool = False,
    sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None,
    fp8_enabled: bool = False,
    fp8_calibrating: bool = False,
    fp8_recipe: Optional[DelayedScaling] = None,
    fp8_group: Optional[dist_group_type] = None,
    fp8_weight_caching: bool = False,
    _order: Optional[List[int]] = None,
    pool: Optional[Tuple[int, ...]] = None,
    retain_graph_in_backward: bool = False,
) -> Union[Callable, Tuple[Callable, ...]]:
    """
    Make CUDA graph version of Transformer Engine modules

    A variation of PyTorch's `make_graphed_callables` utility function
    with support for Transformer Engine modules and FP8. Please see
    the
    `original PyTorch implementation <https://pytorch.org/docs/stable/generated/torch.cuda.make_graphed_callables.html>`_
    for more documentation.

    Each module's `forward` attribute is replaced in place by a wrapper
    that runs the original forward inside `fp8_autocast`. The CUDA RNG
    state and the modules' FP8 state are saved before graphing and
    restored afterwards.

    Graphing parameters
    -------------------
    modules: (tuple of) callable
             Callable or callables to graph. Each one must be a
             `torch.nn.Module`.
    sample_args: (tuple of) tuple of torch.Tensor
                 Positional arguments to callable(s).
    num_warmup_iters: int, default = 3
                      Number of warmup iterations.
    allow_unused_input: bool, default = `False`
                        Whether to handle case where callable inputs
                        and outputs are disconnected in compute graph.
    sample_kwargs: (tuple of) dict, optional
                   Keyword arguments to callable(s)
    pool: (tuple of) int, default = `None`, optional
          An instance returned from function `torch.cuda.graph_pool_handle` that hints
          this graph may share memory with the indicated pool.
    retain_graph_in_backward: bool, default = `False`
                              Whether to set retain_graph=True in backward graph capture.

    FP8-related parameters
    ----------------------
    fp8_enabled: bool, default = `False`
                 whether or not to enable fp8
    fp8_calibrating: bool, default = `False`
                     calibration mode allows collecting statistics such as amax and scale
                     data of fp8 tensors even when executing without fp8 enabled. This is
                     useful for saving an inference ready fp8 checkpoint while training
                     using a higher precision.
    fp8_recipe: recipe.DelayedScaling, default = `None`
                recipe used for FP8 training. If `None`, the recipe
                returned by `get_default_fp8_recipe` is used.
    fp8_group: torch._C._distributed_c10d.ProcessGroup, default = `None`
               distributed group over which amaxes for the fp8 tensors
               are reduced at the end of each training step.
    fp8_weight_caching: bool, default = `False`
                        Whether or not to cache FP8 weights across microbatches. if set to `True`,
                        the `is_first_microbatch` boolean argument must be passed into the forward
                        method for TransformerEngine modules. When storing primary weights in FP8
                        using TE's `fp8_model_init` API and using an FP8 aware optimizer, this arg
                        must be set to `False` if calculating weight transposes' outside TE, e.g.,
                        in the optimizer step.

    Returns
    -------
    callable or tuple of callable
        The result of `_make_graphed_callables`. A single module is
        forwarded as a single callable, a tuple of modules as a tuple.

    Raises
    ------
    AssertionError
        If any entry of `modules` is not a `torch.nn.Module`.

    """
    set_capture_start()
    try:
        fp8_recipe = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe

        # Handle single module.
        just_one_callable = not isinstance(modules, tuple)
        if just_one_callable:
            modules = (modules,)

        # Validate every module before touching any state, so that an
        # unsupported callable is reported with this message and no
        # module is left with a half-applied `forward` wrapper.
        for module in modules:
            assert isinstance(
                module, torch.nn.Module
            ), f"Graphing for {type(module)} is not supported."

        # Store FP8 tensors to reset later.
        saved_fp8_tensors = save_fp8_tensors(modules, fp8_recipe=fp8_recipe)

        # FP8 wrapper.
        def wrap_autocast(block):
            old_forward = block.forward

            def forward_func(*args, **kwargs):
                with fp8_autocast(
                    enabled=fp8_enabled,
                    calibrating=fp8_calibrating,
                    fp8_recipe=fp8_recipe,
                    fp8_group=fp8_group,
                    _graph=True,
                ):
                    outputs = old_forward(*args, **kwargs)
                return outputs

            block.forward = forward_func

        for module in modules:
            wrap_autocast(module)

        # Forward the modules in the same single-vs-tuple form that the
        # caller used.
        forward_funcs = modules[0] if just_one_callable else modules

        # Save RNG state. The availability check is evaluated once so
        # that the save and restore steps always take the same branch.
        use_graph_safe_rng = graph_safe_rng_available()
        if use_graph_safe_rng:
            generators = [
                torch.cuda.default_generators[torch.cuda.current_device()],
                *get_all_rng_states().values(),
            ]
            original_rng_states = [state.get_state() for state in generators]
        else:
            original_rng_states = torch.cuda.get_rng_state()

        graphed_callables = _make_graphed_callables(
            forward_funcs,
            sample_args,
            num_warmup_iters=num_warmup_iters,
            allow_unused_input=allow_unused_input,
            fp8_weight_caching=fp8_weight_caching,
            sample_kwargs=sample_kwargs,
            _order=_order,
            pool=pool,
            retain_graph_in_backward=retain_graph_in_backward,
        )

        # Ensures warmup does not affect numerics for ops such as dropout.
        if use_graph_safe_rng:
            for gen, state in zip(generators, original_rng_states):
                gen.set_state(state)
        else:
            torch.cuda.set_rng_state(original_rng_states)

        # Restore FP8 state.
        restore_fp8_tensors(modules, saved_fp8_tensors)

        return graphed_callables
    finally:
        # Always clear the capture flag, even if graphing failed, so
        # that `is_graph_capturing()` does not stay `True`.
        set_capture_end()