graph.py 27.7 KB
Newer Older
1
2
3
4
5
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Functions for CUDA Graphs support in FP8"""
6
7
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union

8
9
10
11
12
import torch
from torch.utils._pytree import tree_flatten as _tree_flatten
from torch.utils._pytree import tree_unflatten as _tree_unflatten
from torch._C import _graph_pool_handle

13
from transformer_engine.common.recipe import DelayedScaling
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from .fp8 import (
    fp8_autocast,
    FP8GlobalStateManager,
    get_default_fp8_recipe,
)
from .distributed import get_all_rng_states, graph_safe_rng_available
from .module.base import TransformerEngineBaseModule


# Names exported by `from <this module> import *`.
__all__ = ["make_graphed_callables"]


# Module-wide flag: set to True by `set_capture_start`, reset to False by
# `set_capture_end`, and read by `is_graph_capturing`.
_IS_GRAPH_CAPTURING = False

28
29
30
# Generic alias used in signatures below: either a single value of type
# `_T` or a tuple of such values.
_T = TypeVar("_T")
SingleOrTuple = Union[_T, Tuple[_T, ...]]

31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56

def set_capture_start() -> None:
    """Flag that `make_graphed_callables` has started running."""
    # Write straight into the module namespace; equivalent to a
    # `global` declaration followed by an assignment.
    globals()["_IS_GRAPH_CAPTURING"] = True


def set_capture_end() -> None:
    """Flag that `make_graphed_callables` has finished running."""
    # Write straight into the module namespace; equivalent to a
    # `global` declaration followed by an assignment.
    globals()["_IS_GRAPH_CAPTURING"] = False


def is_graph_capturing() -> bool:
    """Return whether within `make_graphed_callables`.

    Returns
    -------
    bool
        Current value of the module-level `_IS_GRAPH_CAPTURING` flag.
    """
    return _IS_GRAPH_CAPTURING


def graph_pool_handle():
    """Return an opaque token representing the id of a graph memory pool.

    Thin wrapper that forwards to `torch._C._graph_pool_handle`.
    """
    pool_token = _graph_pool_handle()
    return pool_token


def _make_graphed_callables(
    callables: SingleOrTuple[Callable],
    sample_args: SingleOrTuple[Tuple[torch.Tensor, ...]],
    num_warmup_iters: int = 3,
    allow_unused_input: bool = False,
    fp8_weight_caching: bool = False,
    sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None,
    _order: Optional[List[int]] = None,
    pool: Optional[Tuple[int, ...]] = None,
) -> SingleOrTuple[Callable]:
    """
    Helper method for `make_graphed_callables`

    Warms up the callables on a side stream, captures one forward and
    one backward CUDA graph per forward pass into a shared memory
    pool, and wraps each graph pair in an autograd function that
    copies user inputs into the static buffers and replays the graphs.

    Parameters
    ----------
    callables: (tuple of) callable
               Callable or callables to graph.
    sample_args: (tuple of) tuple of torch.Tensor
                 Positional arguments used for warmup and capture. One
                 entry per callable, or one entry per forward pass when
                 `_order` is given.
    num_warmup_iters: int, default = 3
                      Number of eager forward/backward iterations run
                      for each forward pass before capture.
    allow_unused_input: bool, default = `False`
                        Forwarded as `allow_unused` to
                        `torch.autograd.grad`.
    fp8_weight_caching: bool, default = `False`
                        If set, the graphed callables require a boolean
                        `is_first_microbatch` kwarg and use it to set
                        the skip-FP8-weight-update flag.
    sample_kwargs: (tuple of) dict, optional
                   Keyword arguments matching `sample_args`. Defaults
                   to no keyword arguments.
    _order: list of int, optional
            Schedule for interleaved pipeline parallelism (see the
            comment in the size checks below).
    pool: optional
          Graph memory pool to capture into. A new pool from
          `graph_pool_handle` is used when `None`.

    Returns
    -------
    A single graphed callable if `callables` was not a tuple, otherwise
    a tuple with one entry per forward pass. For `torch.nn.Module`
    inputs without `_order`, the module itself is returned with its
    `forward` replaced. With `_order`, the replacement forward
    functions are returned and the modules are left untouched.
    """

    if torch.is_autocast_enabled() and torch.is_autocast_cache_enabled():
        raise RuntimeError(
            "make_graphed_callables does not support the autocast "
            "caching. Please set `cache_enabled=False`."
        )

    # Default is to pass no kwargs to callables
    if sample_kwargs is None:
        if isinstance(callables, tuple):
            sample_kwargs = tuple({} for _ in range(len(sample_args)))
        else:
            sample_kwargs = {}

    # Canonicalize args as tuples
    just_one_callable = False
    if not isinstance(callables, tuple):
        just_one_callable = True
        callables = (callables,)
        sample_args = (sample_args,)
        sample_kwargs = (sample_kwargs,)

    # Check sizes of args
    if _order is None:
        assert len(sample_args) == len(callables)
        assert len(sample_kwargs) == len(callables)
    else:
        # Custom logic for interleaved pipeline parallelism
        # Note: This is tightly coupled with the Megatron-core
        # implementation of interleaved pipeline parallelism at
        # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/pipeline_parallel/schedules.py.
        # Note: The model is assumed to consist of layers
        # (corresponding to callables) that are grouped into
        # equally-sized model chunks. _order is a list of chunk
        # indices (1-indexed) that indicates the order in which the
        # layers are evaluated. Positive values indicate forward
        # passes and negative values indicate backward passes. Each
        # entry in sample_args corresponds to one of the forward
        # passes.
        num_model_chunks = max(_order)
        num_microbatches = len(_order) // num_model_chunks // 2
        assert num_model_chunks * num_microbatches * 2 == len(_order)
        assert len(sample_args) * 2 >= len(_order) and (
            len(sample_args) * 2 % len(_order) == 0
        ), (
            f"{len(sample_args) * 2} >= {len(_order)} and "
            + f"{len(sample_args) * 2} % {len(_order)} == 0"
        )
        num_layers = len(sample_args) // num_model_chunks // num_microbatches
        assert len(callables) == num_model_chunks * num_layers, (
            f"Callables should have ({num_model_chunks * num_layers}) "
            + f"entries when order input is provided but got {len(callables)}."
        )
        assert len(sample_args) == num_model_chunks * num_microbatches * num_layers, (
            f"Expected {num_model_chunks * num_microbatches * num_layers} "
            + f"args tuple, but got {len(sample_args)}."
        )
        assert len(sample_kwargs) == len(sample_args)

    if fp8_weight_caching:
        # Initialize flag that controls FP8 weight updates
        FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(False)

    # Check callables
    for c in callables:
        if isinstance(c, torch.nn.Module):
            assert (
                len(c._backward_hooks) == 0
                and len(c._forward_hooks) == 0
                and len(c._forward_pre_hooks) == 0
            ), (
                "Modules must not have hooks registered at the time they are passed. "
                + "However, registering hooks on modules after passing them "
                + "through make_graphed_callables is allowed."
            )
            assert all(b.requires_grad is False for b in c.buffers()), (
                "In any :class:`~torch.nn.Module` passed to "
                + ":func:`~make_graphed_callables`, only parameters may be trainable. "
                + "All buffers must have ``requires_grad=False``."
            )

    # Flatten callable arguments
    per_callable_kwargs_keys = [list(kwargs.keys()) for kwargs in sample_kwargs]
    flatten_sample_args = []
    for args, kwargs, kwargs_keys in zip(sample_args, sample_kwargs, per_callable_kwargs_keys):
        flatten_arg, _ = _tree_flatten(args)
        flatten_kwarg, _ = _tree_flatten([kwargs[key] for key in kwargs_keys])
        flatten_sample_args.append(tuple(flatten_arg + flatten_kwarg))
        assert all(isinstance(arg, torch.Tensor) for arg in flatten_arg), (
            "In the beta API, sample_args "
            + "for each callable must contain only Tensors. Other types are not allowed."
        )
    num_fwd_passes = len(flatten_sample_args)

    # Map every forward pass to the callable it executes. Forward
    # passes are indexed as
    #   (chunk * num_microbatches + microbatch) * num_layers + layer
    # when _order is given, which is the same indexing used for
    # sample_args and for the graph capture below.
    if _order is None:
        fwd_pass_callables = list(callables)
    else:
        fwd_pass_callables = [
            callables[m_chunk * num_layers + l_no]
            for m_chunk in range(num_model_chunks)
            for _ in range(num_microbatches)
            for l_no in range(num_layers)
        ]
    assert len(fwd_pass_callables) == num_fwd_passes

    # If a callable is an nn.Module, its graph's full input surface is the args the user explicitly
    # passes to forward (ie, its sample_args) AND the module's parameter attributes.
    # Note: These per_callable_* variables are not actually
    # per-callable, but per-forward-pass (see description of _order).
    # The names are kept for consistency with
    # torch.cuda.make_graphed_callables.
    per_callable_len_user_args = [len(args) for args in flatten_sample_args]
    per_callable_module_params = [
        tuple(c.parameters()) if isinstance(c, torch.nn.Module) else ()
        for c in fwd_pass_callables
    ]
    per_callable_static_input_surfaces = [
        user_args + module_params
        for user_args, module_params in zip(flatten_sample_args, per_callable_module_params)
    ]

    fwd_graphs = [torch.cuda.CUDAGraph() for _ in range(num_fwd_passes)]
    bwd_graphs = [torch.cuda.CUDAGraph() for _ in range(num_fwd_passes)]

    # For cases with multiple active RNG states, e.g. TP.
    if graph_safe_rng_available():
        for state in get_all_rng_states().values():
            for fwd_graph, bwd_graph in zip(fwd_graphs, bwd_graphs):
                fwd_graph.register_generator_state(state)
                bwd_graph.register_generator_state(state)

    mempool = graph_pool_handle() if pool is None else pool

    # Warmup
    # Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work
    # from ending up in any captures.
    torch.cuda.synchronize()
    with torch.cuda.stream(torch.cuda.Stream()):
        for func_idx, func in enumerate(fwd_pass_callables):
            args = sample_args[func_idx]
            kwargs = sample_kwargs[func_idx]
            static_input_surface = per_callable_static_input_surfaces[func_idx]
            for _ in range(num_warmup_iters):
                outputs, _ = _tree_flatten(func(*args, **kwargs))
                grad_inputs = torch.autograd.grad(
                    outputs=tuple(o for o in outputs if o.requires_grad),
                    inputs=tuple(i for i in static_input_surface if i.requires_grad),
                    grad_outputs=tuple(torch.empty_like(o) for o in outputs if o.requires_grad),
                    only_inputs=True,
                    allow_unused=allow_unused_input,
                )
                del outputs, grad_inputs
    torch.cuda.synchronize()

    # Per-forward-pass capture results, filled in by the helpers below.
    per_callable_static_outputs = [None] * num_fwd_passes
    per_callable_output_unflatten_spec = [None] * num_fwd_passes
    per_callable_static_grad_outputs = [None] * num_fwd_passes
    per_callable_static_grad_inputs = [None] * num_fwd_passes

    def capture_forward(idx):
        """Capture the forward graph of forward pass `idx`."""
        with torch.cuda.graph(fwd_graphs[idx], pool=mempool):
            outputs = fwd_pass_callables[idx](*sample_args[idx], **sample_kwargs[idx])
        flatten_outputs, spec = _tree_flatten(outputs)
        per_callable_static_outputs[idx] = tuple(flatten_outputs)
        per_callable_output_unflatten_spec[idx] = spec

    def capture_backward(idx):
        """Capture the backward graph of forward pass `idx`."""
        static_input_surface = per_callable_static_input_surfaces[idx]
        static_outputs = per_callable_static_outputs[idx]
        # For now, assumes all static_outputs require grad
        static_grad_outputs = tuple(
            torch.empty_like(o) if o.requires_grad else None for o in static_outputs
        )
        with torch.cuda.graph(bwd_graphs[idx], pool=mempool):
            grad_inputs = torch.autograd.grad(
                outputs=tuple(o for o in static_outputs if o.requires_grad),
                inputs=tuple(i for i in static_input_surface if i.requires_grad),
                grad_outputs=tuple(o for o in static_grad_outputs if o is not None),
                only_inputs=True,
                allow_unused=allow_unused_input,
            )
        # Constructs a tuple suitable for returning from Graphed.backward:
        # Pads out the actually-needed grads with Nones in gradient slots for inputs
        # that don't require grad.
        grad_iter = iter(grad_inputs)
        per_callable_static_grad_outputs[idx] = static_grad_outputs
        per_callable_static_grad_inputs[idx] = tuple(
            next(grad_iter) if arg.requires_grad else None for arg in static_input_surface
        )

    # All captures here share a mempool. To avoid replays corrupting each other's memory,
    # the safest approach is to capture all passes in the same order they'll run:
    # fwd 1, fwd 2, ... fwd N, then bwd N, bwd N-1, ... bwd 1.
    if _order is not None:
        fwd_idx = [0] * num_model_chunks
        bwd_idx = [0] * num_model_chunks
        for c_id in _order:
            if c_id > 0:
                # Capture forward graphs for model chunk c_id, microbatch fwd_idx[c_id-1]
                m_chunk = c_id - 1
                first = (m_chunk * num_microbatches + fwd_idx[m_chunk]) * num_layers
                for l_no in range(num_layers):
                    capture_forward(first + l_no)
                fwd_idx[m_chunk] += 1
            else:
                # Capture backward graphs for model chunk -c_id, microbatch bwd_idx[-c_id-1]
                m_chunk = -c_id - 1
                first = (m_chunk * num_microbatches + bwd_idx[m_chunk]) * num_layers
                for l_no in reversed(range(num_layers)):
                    capture_backward(first + l_no)
                bwd_idx[m_chunk] += 1
    else:
        # Capture forward graphs
        for idx in range(num_fwd_passes):
            capture_forward(idx)
        # Capture backward graphs in reverse order
        for idx in reversed(range(num_fwd_passes)):
            capture_backward(idx)
    # Now for every per_callable list, per_callable_*[i] holds the stuff for the ith forward pass.

    def make_graphed_autograd_function(
        fwd_graph,
        bwd_graph,
        module_params,
        kwargs_keys,
        len_user_args,
        output_unflatten_spec,
        static_input_surface,
        static_outputs,
        static_grad_outputs,
        static_grad_inputs,
    ):
        class Graphed(torch.autograd.Function):
            """Autograd function for graph replay."""

            @staticmethod
            def forward(ctx, skip_fp8_weight_update, *inputs):
                # pylint: disable=missing-function-docstring

                # Set flag for whether to skip FP8 weight updates
                ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
                if ctx.is_first_module and skip_fp8_weight_update is not None:
                    FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(skip_fp8_weight_update)

                # Copy values from new tensors into static tensors
                for i in range(len_user_args):
                    if static_input_surface[i].data_ptr() != inputs[i].data_ptr():
                        static_input_surface[i].copy_(inputs[i])

                # Replay forward graph
                fwd_graph.replay()
                assert isinstance(static_outputs, tuple)
                return tuple(o.detach() for o in static_outputs)

            @staticmethod
            @torch.autograd.function.once_differentiable
            def backward(ctx, *grads):
                # pylint: disable=missing-function-docstring

                # Replay backward graph
                assert len(grads) == len(static_grad_outputs)
                for g, grad in zip(static_grad_outputs, grads):
                    if g is not None:
                        # don't copy if autograd gods have been kind and the
                        # incoming grad is already in the right place
                        if g.data_ptr() != grad.data_ptr():
                            g.copy_(grad)
                bwd_graph.replay()

                # Update FP8 scale factors if needed
                if ctx.is_first_module:
                    FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)

                # Input args that didn't require grad expect a None gradient.
                # The leading None is the gradient slot for skip_fp8_weight_update.
                assert isinstance(static_grad_inputs, tuple)
                return (None,) + tuple(
                    b.detach() if b is not None else b for b in static_grad_inputs
                )

        def functionalized(*user_args, **user_kwargs):

            # Decide whether to update FP8 weights
            skip_fp8_weight_update = None
            if fp8_weight_caching:
                assert "is_first_microbatch" in user_kwargs and isinstance(
                    user_kwargs["is_first_microbatch"], bool
                ), "`is_first_microbatch` boolean kwarg must be provided for FP8 weight caching."

                skip_fp8_weight_update = not user_kwargs["is_first_microbatch"]

            # Check that required kwargs are provided
            for key in kwargs_keys:
                if key not in user_kwargs:
                    raise TypeError(
                        f"Graphed callable was initialized with kwarg {key} ,"
                        "but it was not provided in graph replay"
                    )

            # Runs the autograd function with inputs == all inputs to
            # the graph that might require grad (explicit user args +
            # module parameters)
            # Assumes module params didn't change since capture.
            flatten_user_args, _ = _tree_flatten(user_args)
            flatten_user_kwargs, _ = _tree_flatten([user_kwargs[key] for key in kwargs_keys])
            func_args = tuple(flatten_user_args) + tuple(flatten_user_kwargs) + module_params
            out = Graphed.apply(skip_fp8_weight_update, *func_args)
            return _tree_unflatten(out, output_unflatten_spec)

        return functionalized

    def make_graphed_forward(func, graph_training_state, graphed, orig_fwd):
        def new_fwd(*user_args, **user_kwargs):
            # If the module's training-or-eval state matches what we graphed,
            # run the graph, otherwise run the original forward method
            if func.training == graph_training_state:
                # Set the FP8 group from global amax reduction.
                for m in func.modules():
                    if (
                        isinstance(m, TransformerEngineBaseModule)
                        and FP8GlobalStateManager.is_fp8_enabled()
                    ):
                        m.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()
                        m.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
                        FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
                            m.fp8_meta, fp8_weights=m._get_fp8_params()
                        )
                return graphed(*user_args, **user_kwargs)
            return orig_fwd(*user_args, **user_kwargs)

        return new_fwd

    # Put together the final graphed callables
    ret = []
    for i, func in enumerate(fwd_pass_callables):
        graphed = make_graphed_autograd_function(
            fwd_graphs[i],
            bwd_graphs[i],
            per_callable_module_params[i],
            per_callable_kwargs_keys[i],
            per_callable_len_user_args[i],
            per_callable_output_unflatten_spec[i],
            per_callable_static_input_surfaces[i],
            per_callable_static_outputs[i],
            per_callable_static_grad_outputs[i],
            per_callable_static_grad_inputs[i],
        )

        if isinstance(func, torch.nn.Module):
            forward = make_graphed_forward(func, func.training, graphed, func.forward)
            if _order is None:
                func.forward = forward
                ret.append(func)
            else:
                # The same module backs several forward passes here, so
                # hand back the per-pass forward instead of patching it.
                ret.append(forward)
        else:
            ret.append(graphed)

    if just_one_callable:
        return ret[0]

    return tuple(ret)


def save_fp8_tensors(modules, amax_history_len):
    """
    Collect the FP8 meta tensors of every Transformer Engine submodule.

    Walks `module.modules()` for each entry of `modules`. For each
    `TransformerEngineBaseModule` found, the amax history length is
    first adjusted to `amax_history_len` if the submodule has
    `primary_weights_in_fp8` set, and then the result of
    `get_fp8_meta_tensors()` is recorded.

    Returns the recorded values as a list, in traversal order.
    """
    te_submodules = (
        sub
        for top in modules
        for sub in top.modules()
        if isinstance(sub, TransformerEngineBaseModule)
    )
    snapshot = []
    for sub in te_submodules:
        if sub.primary_weights_in_fp8:
            sub.adjust_amax_history_length(amax_history_len)
        snapshot.append(sub.get_fp8_meta_tensors())
    return snapshot


def restore_fp8_tensors(modules, fp8_tensors):
    """Restore FP8 tensors.

    Walks `module.modules()` for each entry of `modules` and passes the
    entries of `fp8_tensors`, in order, to `reset_fp8_meta_tensors` on
    each `TransformerEngineBaseModule` found.

    `fp8_tensors` is read through an index cursor rather than drained
    with `pop(0)`, so the caller's list is left intact and the walk is
    linear in the number of entries.

    Raises
    ------
    IndexError
        If there are fewer entries than Transformer Engine submodules.
    AssertionError
        If entries are left over after all submodules were restored.
    """
    cursor = 0
    for module in modules:
        for m in module.modules():
            if isinstance(m, TransformerEngineBaseModule):
                m.reset_fp8_meta_tensors(fp8_tensors[cursor])
                cursor += 1
    assert cursor == len(fp8_tensors), "TE internal error."


def make_graphed_callables(
    modules: SingleOrTuple[Callable],
    sample_args: SingleOrTuple[Tuple[torch.Tensor, ...]],
    num_warmup_iters: int = 3,
    allow_unused_input: bool = False,
    sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None,
    fp8_enabled: bool = False,
    fp8_calibrating: bool = False,
    fp8_recipe: Optional[DelayedScaling] = None,
    fp8_weight_caching: bool = False,
    _order: Optional[List[int]] = None,
    pool: Optional[Tuple[int, ...]] = None,
) -> Union[Callable, Tuple[Callable, ...]]:
    """
    Make CUDA graph version of Transformer Engine modules

    A variation of PyTorch's `make_graphed_callables` utility function
    with support for Transformer Engine modules and FP8. Please see
    the
    `original PyTorch implementation <https://pytorch.org/docs/stable/generated/torch.cuda.make_graphed_callables.html>`_
    for more documentation.

    Graphing parameters
    -------------------
    modules: (tuple of) callable
             Callable or callables to graph. Each must be a
             `torch.nn.Module`.
    sample_args: (tuple of) tuple of torch.Tensor
                 Positional arguments to callable(s).
    num_warmup_iters: int, default = 3
                      Number of warmup iterations.
    allow_unused_input: bool, default = `False`
                        Whether to handle case where callable inputs
                        and outputs are disconnected in compute graph.
    sample_kwargs: (tuple of) dict, optional
                   Keyword arguments to callable(s)
    pool: (tuple of) int, default = `None`, optional
          An instance returned from function `torch.cuda.graph_pool_handle` that hints
          this graph may share memory with the indicated pool.

    FP8-related parameters
    ----------------------
    fp8_enabled: bool, default = `False`
                 whether or not to enable fp8
    fp8_calibrating: bool, default = `False`
                     calibration mode allows collecting statistics such as amax and scale
                     data of fp8 tensors even when executing without fp8 enabled. This is
                     useful for saving an inference ready fp8 checkpoint while training
                     using a higher precision.
    fp8_recipe: recipe.DelayedScaling, default = `None`
                recipe used for FP8 training.
    fp8_weight_caching: bool, default = `False`
                        Whether or not to cache FP8 weights across microbatches. if set to `True`,
                        the `is_first_microbatch` boolean argument must be passed into the forward
                        method for TransformerEngine modules. When storing primary weights in FP8
                        using TE's `fp8_model_init` API and using an FP8 aware optimizer, this arg
                        must be set to `False` if calculating weight transposes outside TE, e.g.,
                        in the optimizer step.

    """
    set_capture_start()
    try:
        fp8_recipe = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe

        # Handle single module.
        just_one_callable = False
        if not isinstance(modules, tuple):
            just_one_callable = True
            modules = (modules,)

        # Validate inputs before any module state is read or modified.
        for module in modules:
            assert isinstance(
                module, torch.nn.Module
            ), f"Graphing for {type(module)} is not supported."

        # Store FP8 tensors to reset later.
        saved_fp8_tensors = save_fp8_tensors(modules, fp8_recipe.amax_history_len)

        # FP8 wrapper.
        def wrap_autocast(block):
            old_forward = block.forward

            def forward_func(*args, **kwargs):
                with fp8_autocast(
                    enabled=fp8_enabled,
                    calibrating=fp8_calibrating,
                    fp8_recipe=fp8_recipe,
                    _graph=True,
                ):
                    outputs = old_forward(*args, **kwargs)
                return outputs

            block.forward = forward_func

        for module in modules:
            wrap_autocast(module)

        forward_funcs = modules[0] if just_one_callable else tuple(modules)

        # Save RNG state.
        if graph_safe_rng_available():
            generators = [
                torch.cuda.default_generators[torch.cuda.current_device()],
                *get_all_rng_states().values(),
            ]
            original_rng_states = [state.get_state() for state in generators]
        else:
            original_rng_states = torch.cuda.get_rng_state()

        graphed_callables = _make_graphed_callables(
            forward_funcs,
            sample_args,
            num_warmup_iters=num_warmup_iters,
            allow_unused_input=allow_unused_input,
            fp8_weight_caching=fp8_weight_caching,
            sample_kwargs=sample_kwargs,
            _order=_order,
            pool=pool,
        )

        # Ensures warmup does not affect numerics for ops such as dropout.
        if graph_safe_rng_available():
            for gen, state in zip(generators, original_rng_states):
                gen.set_state(state)
        else:
            torch.cuda.set_rng_state(original_rng_states)

        # Restore FP8 state.
        restore_fp8_tensors(modules, saved_fp8_tensors)

        return graphed_callables
    finally:
        # Always clear the capture flag, even if graphing failed, so
        # `is_graph_capturing` does not report True indefinitely.
        set_capture_end()