base.py 67.7 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
#
# See LICENSE for license information.

"""Base modules and utilities for TransformerEngine PyTorch API"""
6
import io
7
import math
8
9
10
11
import os
import pickle
import warnings
from abc import ABC, abstractmethod
12
from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union
13
from contextlib import contextmanager
14
import logging
15
from types import MethodType
16
17
18
19

import torch
import torch.nn.functional as F

20
import transformer_engine_torch as tex
21
22
from transformer_engine.common.recipe import Recipe

23
from ._common import _ParameterInitMeta, noop_cat
24
from ..fp8 import (
25
26
    MXFP8BlockScalingRecipeState,
    DelayedScalingRecipeState,
27
    Float8CurrentScalingRecipeState,
28
    Float8BlockScalingRecipeState,
29
    FP8GlobalStateManager,
30
    RecipeState,
31
32
33
34
35
)
from ..distributed import (
    gather_along_first_dim,
    is_fp8_activation_recompute_enabled,
    in_fp8_activation_recompute_phase,
36
    _fsdp_gather_tensors,
37
38
)
from ..constants import dist_group_type
39
40
41
from ..tensor.quantized_tensor import QuantizedTensor, QuantizedTensorBase, Quantizer
from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
42
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
43
44
from ..tensor._internal.float8_tensor_base import Float8TensorBase
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
45
from ..utils import is_non_tn_fp8_gemm_supported, torch_get_autocast_gpu_dtype
46
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
47
from ...common.recipe import DelayedScaling, Recipe
48
49
from ...debug.pytorch.debug_state import TEDebugState
from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor
50
from ...debug.pytorch.utils import next_iter_when_debug_should_be_run, any_feature_enabled
yuguo's avatar
yuguo committed
51
from torch.utils.cpp_extension import IS_HIP_EXTENSION
52

53
54
__all__ = ["initialize_ub", "destroy_ub"]

# Module-level accumulation flags (not read in this part of the file).
_2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD = False, True, True

# Lazily-populated caches filled in by the workspace / dummy-wgrad getters.
_multi_stream_cublas_workspace = []
_dummy_wgrads = {}
_multi_stream_cublas_batchgemm_workspace = []
_cublas_workspace = None
_ub_communicators = None

# Maximum number of Userbuffers streams: configurable through NVTE_UB_STREAM_NUMS on
# HIP builds, fixed at 3 otherwise. The env var is parsed unconditionally.
ub_stream_nums = int(os.getenv("NVTE_UB_STREAM_NUMS", "2"))
_NUM_MAX_UB_STREAMS = ub_stream_nums if IS_HIP_EXTENSION else 3
# Stream priority range, queried lazily by `initialize_ub`.
_MIN_STREAM_PRIORITY = _MAX_STREAM_PRIORITY = None
# Layers configured with atomic GEMM + ring-exchange overlap (populated by `initialize_ub`).
layers_atomic_ring_exchange = []
67
68
69
70


def get_cublas_workspace_size_bytes() -> int:
    """Return the cuBLAS/hipBLASLt workspace size in bytes for the current device.

    * HIP (ROCm) builds: 128 MiB.
    * CUDA devices with compute-capability major version >= 9 (Hopper and newer): 32 MiB.
    * All other CUDA devices: 4 MiB.
    """
    if IS_HIP_EXTENSION:
        # ROCm builds always use a fixed, larger workspace.
        return 134_217_728
    if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9:
        return 33_554_432
    return 4_194_304


def get_workspace() -> torch.Tensor:
    """Return the process-wide cuBLAS workspace, allocating it on first use.

    The buffer is a flat ``uint8`` CUDA tensor with
    ``get_cublas_workspace_size_bytes()`` elements. Once allocated, the cached
    tensor is returned as-is. ``initialize_ub`` may later replace it with an
    expanded buffer.
    """
    global _cublas_workspace
    if _cublas_workspace is not None:
        return _cublas_workspace
    _cublas_workspace = torch.empty(
        get_cublas_workspace_size_bytes(),
        dtype=torch.uint8,
        device="cuda",
    )
    return _cublas_workspace


89
90
91
92
def get_multi_stream_cublas_workspace() -> List[torch.Tensor]:
    """Return one cuBLAS workspace tensor per cuBLAS stream.

    On the first call (while the cached list is empty), one ``uint8`` CUDA
    buffer of ``get_cublas_workspace_size_bytes()`` bytes is appended for each
    of the ``tex.get_num_cublas_streams()`` streams. Later calls return the
    same cached list object.
    """
    global _multi_stream_cublas_workspace
    if _multi_stream_cublas_workspace:
        return _multi_stream_cublas_workspace
    _multi_stream_cublas_workspace.extend(
        torch.empty(get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda")
        for _ in range(tex.get_num_cublas_streams())
    )
    return _multi_stream_cublas_workspace


yuguo's avatar
yuguo committed
100
101
102
103
104
105
def get_multi_stream_cublas_batchgemm_workspace() -> List[torch.Tensor]:
    """Return the per-stream workspaces used by batched cuBLAS GEMMs.

    On the first call (while the cached list is empty), one 128-byte ``uint8``
    CUDA buffer is appended for each of the ``tex._num_cublas_batchgemm_streams``
    streams. Later calls return the same cached list object.
    """
    global _multi_stream_cublas_batchgemm_workspace
    if _multi_stream_cublas_batchgemm_workspace:
        return _multi_stream_cublas_batchgemm_workspace
    _multi_stream_cublas_batchgemm_workspace.extend(
        torch.empty(128, dtype=torch.uint8, device="cuda")
        for _ in range(tex._num_cublas_batchgemm_streams)
    )
    return _multi_stream_cublas_batchgemm_workspace

110

yuguo's avatar
yuguo committed
111
112
113
114
# Layers for which no Userbuffers communicator is created (`initialize_ub` skips them and
# `get_ub` returns None). NVTE_DISABLE_FC2_DGRAD_OVERLAP=1 disables the fc2_dgrad overlap.
remove_ag_gemm_dgrad = (
    ["fc2_dgrad"] if bool(int(os.getenv("NVTE_DISABLE_FC2_DGRAD_OVERLAP", "0"))) else []
)
115

116

117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
def get_dummy_wgrad(shape: list, dtype: torch.dtype, zero=False) -> torch.Tensor:
    """Return a cached, detached 2-D placeholder tensor of ``shape`` and ``dtype``.

    One uninitialized CUDA tensor is allocated per ``(rows, cols, dtype)``
    combination and reused on later calls. When ``zero`` is true, the cached
    tensor is filled with zeros in place before it is returned.
    """
    assert len(shape) == 2
    global _dummy_wgrads
    cache_key = (shape[0], shape[1], dtype)
    if cache_key not in _dummy_wgrads:
        _dummy_wgrads[cache_key] = torch.empty(
            shape, dtype=dtype, device="cuda", requires_grad=False
        )
    dummy = _dummy_wgrads[cache_key]
    if zero:
        dummy.fill_(0)
    return dummy.detach()

yuguo's avatar
yuguo committed
132
# Default "num_sm" for non-ring-exchange Userbuffers overlaps (override: NVTE_UB_COMM_CU_NUMS).
ub_comm_cu_nums = int(os.environ.get("NVTE_UB_COMM_CU_NUMS", "8"))
133
134
def initialize_ub(
    shape: list,
    tp_size: int,
    use_fp8: bool = False,
    dtype: torch.dtype = torch.bfloat16,
    ub_cfgs: Optional[dict] = None,
    bootstrap_backend: Union[str, torch.distributed.Backend] = None,
) -> None:
    r"""
    Initialize the Userbuffers communicator for overlapping tensor-parallel communications with
    GEMM compute in te.Linear, te.LayerNormLinear and te.LayerNormMLP modules.

    Parameters
    ----------
    shape : list
            shape of the communication buffer, typically set to be the same as the global shape of
            the input tensor to a te.TransformerLayer forward pass, with the sequence and batch
            dimensions collapsed together -- i.e.: `(sequence_length * batch_size, hidden_size)`
    tp_size : int
              number of GPUs in the tensor-parallel process group
    use_fp8 : bool = False
              allocate the communication buffer for FP8 GEMM inputs/outputs
    dtype : torch.dtype = torch.bfloat16
            non-FP8 data type of the communication buffer when `use_fp8 = False`
    ub_cfgs: dict = None
             Configuration dictionary with the structure
             ```
             {
                <gemm_name> : {
                    "method": <"ring_exchange", "pipeline", "bulk" or "external">,
                    "is_reduce_scatter": bool,
                    "num_sm": int,
                    "cga_size": int,
                    "set_sm_margin": bool,
                    "num_splits": int,
                    "aggregate": bool,
                    "atomic_gemm": bool,
                    "use_ce": bool,
                    "fp8_buf": bool,
                    "comm_priority": int,
                    "gemm_priority": int,
                    "pipeline_rs_overlap_first_gemm": bool,
                }
             }
             ```
             for `te.TransformerLayer` GEMM layers in `["qkv_fprop", "qkv_dgrad", "qkv_wgrad",
             "proj_fprop", "proj_dgrad", "proj_wgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad",
             "fc2_fprop", "fc2_wgrad"]`. Keys that are not given fall back to per-layer defaults.
    bootstrap_backend : str = None
                        `torch.distributed` communication backend for the all-gather, broadcast and
                        barrier collectives during Userbuffers initialization. Not all backends are
                        valid for every cluster configuration and distributed launch method even if
                        they are available in PyTorch. When left unset, the initialization prefers
                        to use the MPI backend, falling back first on Gloo and then NCCL if MPI is
                        not available. Setting `NVTE_UB_WITH_MPI=1` when building TE overrides this
                        option and always initializes Userbuffers with direct MPI calls in C++,
                        which also requires `MPI_HOME=/path/to/mpi/root` to be set at compile time.

    Notes
    -----
    Environment variables read here:

    * `UB_SKIPMC` (default 1) -- must be non-zero when the device lacks CUDA Multicast support.
    * `NVTE_PROJ_NO_PIPELINE_OVERLAP` / `NVTE_FC2_NO_PIPELINE_OVERLAP` (default 0) -- switch
      `proj_fprop` / `fc2_fprop` from the `pipeline` method to `ring_exchange`.
    * `NVTE_TP_OVERLAP_AGGREGATE` (default 0) -- default for the `aggregate` option.

    The default `num_sm` of non-ring-exchange overlaps comes from `NVTE_UB_COMM_CU_NUMS`. No
    communicator is created for layers in `remove_ag_gemm_dgrad` (`NVTE_DISABLE_FC2_DGRAD_OVERLAP`).
    """
    if not tex.device_supports_multicast():
        assert bool(int(os.getenv("UB_SKIPMC", "1"))), (
            "CUDA device, driver and/or toolkit version does not support comm+GEMM overlap with "
            + "CUDA Multicast. Launch app with UB_SKIPMC=1 to try CUDA IPC instead."
        )

    global _ub_communicators
    assert _ub_communicators is None, "UB communicators are already initialized."
    _ub_communicators = {}

    if tex.ubuf_built_with_mpi():
        # We're bootstrapping with direct calls to MPI in Userbuffers code so we need to force
        # an MPI_Init() here by creating a new MPI process group...
        assert torch.distributed.is_mpi_available()
        _ = torch.distributed.new_group(backend="mpi")
        helper = tex.CommOverlapHelper()
    else:
        # Bootstrapping with torch.distributed API, so check backend and construct
        # intra/inter-node process groups...
        assert (
            torch.distributed.is_initialized()
        ), "torch.distributed must be initialized before Userbuffers"
        if bootstrap_backend is None:
            bootstrap_backend = "nccl"
            if torch.distributed.is_mpi_available():
                bootstrap_backend = "mpi"
            elif torch.distributed.is_gloo_available():
                bootstrap_backend = "gloo"
        else:
            assert bootstrap_backend in [
                "gloo",
                "mpi",
                "nccl",
            ], "Invalid torch.distributed backend for bootstrapping Userbuffers!"
            assert torch.distributed.is_backend_available(bootstrap_backend), (
                f"PyTorch must be compiled with '{bootstrap_backend}' support in order to "
                f"bootstrap Userbuffers with '{bootstrap_backend}' collectives."
            )

        world_group = torch.distributed.new_group(backend=bootstrap_backend)
        world_rank = torch.distributed.get_rank(world_group)
        world_size = torch.distributed.get_world_size(world_group)

        num_domains = world_size // tp_size
        mydomain_idx = world_rank // tp_size
        if num_domains > 1:
            ranks_per_domain_list = [
                [i * tp_size + t for t in range(tp_size)] for i in range(num_domains)
            ]
            tp_domain_group, _ = torch.distributed.new_subgroups_by_enumeration(
                ranks_per_domain_list, backend=bootstrap_backend
            )
            local_rank = torch.distributed.get_rank(tp_domain_group)
            tp_domain_ranks = torch.distributed.get_process_group_ranks(tp_domain_group)

            helper = tex.CommOverlapHelper(world_group, tp_domain_group)
        else:
            # TP model on single NVLink domain, no replication, no data-parallelism
            mydomain_idx = 0
            local_rank = world_rank
            tp_domain_ranks = list(range(world_size))

            helper = tex.CommOverlapHelper(world_group)

        if world_rank == 0:
            print(f"!!! [UB] Number of TP domains: {num_domains}\n", end="", flush=True)
        if local_rank == 0:
            print(
                f"!!! [UB] Global ranks on TP domain {mydomain_idx}: {tp_domain_ranks}\n",
                end="",
                flush=True,
            )

    # Allocate cuBLAS workspace with expanded size for chunking in overlapping GEMM calls
    global _cublas_workspace
    if _cublas_workspace is None:
        _cublas_workspace = get_workspace().repeat(_NUM_MAX_UB_STREAMS)
    elif _cublas_workspace.numel() != get_cublas_workspace_size_bytes() * _NUM_MAX_UB_STREAMS:
        # This ensures we don't do `.repeat()` on an already expanded workspace
        _cublas_workspace = torch.empty(
            get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda"
        ).repeat(_NUM_MAX_UB_STREAMS)

    # Default buffer precision: AllGather buffers use fp8 when using fp8 recipe
    layers_all_gather_overlap = [
        "qkv_fprop",
        "qkv_dgrad",
        "proj_dgrad",
        "proj_wgrad",
        "fc1_fprop",
        "fc1_dgrad",
        "fc2_dgrad",
        "fc2_wgrad",
    ]
    layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"]
    dgrad_reduce_scatter_overlap = ["qkv_dgrad", "fc1_dgrad"]

    # Default overlap methods for layers. proj_fprop and fc2_fprop use the `pipeline` method
    # unless the matching NVTE_*_NO_PIPELINE_OVERLAP env var moves them to `ring_exchange`.
    # Dict order (proj_fprop, then fc2_fprop) fixes their order within each method list.
    no_pipeline_overlap = {
        "proj_fprop": bool(int(os.getenv("NVTE_PROJ_NO_PIPELINE_OVERLAP", "0"))),
        "fc2_fprop": bool(int(os.getenv("NVTE_FC2_NO_PIPELINE_OVERLAP", "0"))),
    }
    methods = {
        "ring_exchange": ["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad"]
        + [layer for layer, disabled in no_pipeline_overlap.items() if disabled],
        "pipeline": [layer for layer, disabled in no_pipeline_overlap.items() if not disabled],
        "bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"],
        "external": ["proj_wgrad", "fc2_wgrad"],
    }

    # AG-RS overlap pairs of layers forming a tensor-parallel block
    ag_rs_pairs = {"qkv_fprop": "proj_fprop", "fc1_fprop": "fc2_fprop"}
    rs_ag_pairs = {v: k for k, v in ag_rs_pairs.items()}
    external_gemm_to_overlap = {"proj_wgrad": "proj_dgrad", "fc2_wgrad": "fc2_dgrad"}
    global layers_atomic_ring_exchange
    layers_atomic_ring_exchange = []

    def get_method(name):
        for method, names in methods.items():
            if name in names:
                return method
        raise KeyError(f"Given layer name {name} does not exist.")

    def get_default_config(name):
        global _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY
        method = get_method(name)
        is_reduce_scatter = name in layers_reduce_scatter_overlap
        if _MIN_STREAM_PRIORITY is None or _MAX_STREAM_PRIORITY is None:
            _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = tex.get_stream_priority_range()
        default_cfg = {
            "method": method,
            "is_reduce_scatter": is_reduce_scatter,
            "num_sm": 1 if method == "ring_exchange" else ub_comm_cu_nums,
            "cga_size": 1 if method == "ring_exchange" else 2,
            "set_sm_margin": not method == "ring_exchange",
            "num_splits": tp_size if method == "ring_exchange" else 4,
            "aggregate": bool(int(os.getenv("NVTE_TP_OVERLAP_AGGREGATE", "0"))),
            "atomic_gemm": False,
            "use_ce": True,
            "fp8_buf": name in layers_all_gather_overlap,
            "comm_priority": _MAX_STREAM_PRIORITY,
            "gemm_priority": _MIN_STREAM_PRIORITY,
            "pipeline_rs_overlap_first_gemm": False,
        }
        return default_cfg

    def add_ub(
        name: str,
        method: str,
        is_reduce_scatter: bool,
        num_sm: int = 16,
        cga_size: int = 2,
        set_sm_margin: bool = False,
        num_splits: int = 0,
        aggregate: bool = False,
        atomic_gemm: bool = False,
        use_ce: bool = True,
        fp8_buf: bool = False,
        comm_priority: int = 0,
        gemm_priority: int = 0,
        pipeline_rs_overlap_first_gemm: bool = False,
    ) -> None:
        if atomic_gemm:
            warnings.warn(
                "Atomic GEMM uses a beta API from cublas and is not tested for all use cases."
            )
            assert use_fp8, "Atomic GEMM overlap supported only for FP8 GEMM."
            if method in ("bulk", "external"):
                warnings.warn(
                    f"At {name}, atomic GEMM is not supported for a `{method}` overlap. "
                    "Defaulting to `atomic_gemm=False`."
                )
                # Keep a real bool: the value is forwarded as a bool kwarg to the C++ bindings.
                atomic_gemm = False
        if not is_reduce_scatter and method == "pipeline":
            raise ValueError(
                f"At {name}, `pipeline` overlap method is not supported for AllGather."
            )
        # Check if both AG and RS overlaps use `atomic GEMM` + `p2p ring-exchange`.
        # Using atomic GEMM + p2p ring-exchange in only one of the pair breaks functionality.
        global layers_atomic_ring_exchange
        if atomic_gemm and method == "ring_exchange" and name in ag_rs_pairs:
            layers_atomic_ring_exchange += [name, ag_rs_pairs[name]]
        if name in rs_ag_pairs:
            assert_message = (
                f"At {name}, atomic AG-GEMM overlap with `ring_exchange` shuffles GEMM chunk "
                "outputs, and RS-GEMM overlap un-shuffles them. When one of the GEMM-AG and "
                "GEMM-RS overlaps forming a TP block (e.g., qkv_fprop and proj_fprop) uses "
                "`atomic gemm` and `ring_exchange`, its pair must use the same overlap config "
                "for functionality."
            )
            if name in layers_atomic_ring_exchange:
                assert atomic_gemm and method == "ring_exchange", assert_message
            else:
                if atomic_gemm and method == "ring_exchange":
                    assert rs_ag_pairs[name] in layers_atomic_ring_exchange, assert_message

        if name in external_gemm_to_overlap:
            assert method == "external", (
                f"At {name}, `external` overlap method is specified, but the selected method is"
                f" {method}"
            )
            assert external_gemm_to_overlap[name] in methods["ring_exchange"], (
                f"At {name}, `external` overlap method is specified, but the external gemm"
                f" {external_gemm_to_overlap[name]} is not using `ring_exchange` overlap method"
            )

        buffer_dtype = torch.uint8 if (use_fp8 and fp8_buf) else dtype
        if method == "ring_exchange":
            ub_obj = tex.CommOverlapP2P(
                shape,  # Communication buffer shape
                buffer_dtype,  # Communication buffer data type
                helper,  # Helper for torch.distributed callbacks during bootstrapping
                tp_size,  # Tensor-parallel group size (may be different than local_size)
                tex.CommOverlapType.RS if is_reduce_scatter else tex.CommOverlapType.AG,
                num_max_streams=_NUM_MAX_UB_STREAMS,
                comm_cga_size=cga_size,
                num_comm_sm=num_sm,
                set_sm_margin=set_sm_margin,
                atomic_gemm=atomic_gemm,
                use_ce=use_ce,
                aggregate=aggregate,
                gemm_priority=gemm_priority,
                comm_priority=comm_priority,
            )
        else:
            ub_obj = tex.CommOverlap(
                shape,  # Communication buffer shape
                buffer_dtype,  # Communication buffer data type
                helper,  # Helper for torch.distributed callbacks during bootstrapping
                tp_size,  # Tensor-parallel group size (may be different than local_size)
                num_splits=num_splits,
                num_max_streams=_NUM_MAX_UB_STREAMS,
                comm_cga_size=cga_size,
                num_comm_sm=num_sm,
                set_sm_margin=set_sm_margin,
                atomic_gemm=atomic_gemm,
                gemm_priority=gemm_priority,
                comm_priority=comm_priority,
                rs_overlap_first_gemm=pipeline_rs_overlap_first_gemm,
            )
        _ub_communicators[name] = ub_obj

    # A non-bulk method configured for qkv_dgrad/fc1_dgrad turns that dgrad GEMM into a
    # reduce-scatter overlap; the matching wgrad must then not be configured explicitly.
    if ub_cfgs is not None:
        for name in dgrad_reduce_scatter_overlap:
            if name in ub_cfgs and "method" in ub_cfgs[name] and ub_cfgs[name]["method"] != "bulk":
                wgrad_name = name.replace("dgrad", "wgrad")
                assert wgrad_name not in ub_cfgs
                layers_reduce_scatter_overlap.remove(wgrad_name)
                layers_all_gather_overlap.remove(name)
                layers_reduce_scatter_overlap.append(name)
                methods["bulk"].remove(name)
                new_method = ub_cfgs[name]["method"]
                methods[new_method].append(name)

    for name in (
        methods["ring_exchange"] + methods["pipeline"] + methods["bulk"] + methods["external"]
    ):
        if name in remove_ag_gemm_dgrad:
            continue
        ub_cfg = get_default_config(name)
        if ub_cfgs is not None and name in ub_cfgs:
            fp8_buf = (name in layers_all_gather_overlap) or (
                ub_cfgs[name].get("fp8_buf", False) and name in methods["pipeline"]
            )
            ub_cfg.update(ub_cfgs[name])
            ub_cfg["fp8_buf"] = fp8_buf
        add_ub(name, **ub_cfg)
471
472
473
474
475


def get_ub(name: str):
    """Get the Userbuffers communicator registered under the given key.

    Returns ``None`` for layers listed in ``remove_ag_gemm_dgrad``, because no
    communicator is created for them.

    Raises:
        AssertionError: if Userbuffers has not been initialized.
        KeyError: if no communicator is registered under ``name``.
    """
    assert _ub_communicators is not None, "UB manager is not initialized."
    if name in remove_ag_gemm_dgrad:
        return None
    try:
        return _ub_communicators[name]
    except KeyError:
        # Same exception type as a bare dict lookup, but with an actionable message.
        raise KeyError(f"UB for {name} is not registered.") from None

481

482
483
484
485
486
487
488
def destroy_ub():
    """Drop all Userbuffers communicators and clear the atomic ring-exchange layer list.

    Sets the module-level communicator registry back to ``None`` so that
    ``initialize_ub`` may be called again.
    """
    global _ub_communicators, layers_atomic_ring_exchange
    _ub_communicators, layers_atomic_ring_exchange = None, []

489

490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
def fill_userbuffers_buffer_for_all_gather(
    comm,
    local_tensor: torch.Tensor,
    quantizer: Optional[Quantizer],
    process_group,
) -> tuple[torch.Tensor | QuantizedTensorBase, torch.Tensor | QuantizedTensorBase]:
    """Fill local shard of Userbuffers buffer with data for all-gather

    Copies ``local_tensor`` (converted to the format implied by ``quantizer``
    if needed) into the local chunk of the Userbuffers buffer of ``comm``.

    Returns a pair ``(global_tensor, local_tensor)``:

    * ``global_tensor`` views the full Userbuffers buffer, with the first
      dimension scaled by the size of ``process_group``.
    * ``local_tensor`` is the local shard whose data was copied into the
      buffer. It may have been dequantized or quantized here.

    ``global_tensor`` should be used carefully (e.g. only immediately before
    and after a Userbuffers operation) since the underlying data may be
    overwritten by other Userbuffers operations.

    May perform blocking communication if needed for the gathered
    tensor's metadata, e.g. scaling factors (MXFP8 scale-inverses are
    all-gathered over ``process_group``).

    Raises:
        ValueError: for a 0-dim tensor, an MXFP8 quantizer with both or neither
            usage enabled, MXFP8 leading dims not divisible by 128, or an
            unsupported quantizer type.
        RuntimeError: if the buffer precision (FP8 or not) does not match the data.
    """

    # Tensor dimensions
    local_shape = local_tensor.size()
    if not local_shape:
        raise ValueError(f"Invalid local tensor (shape={tuple(local_shape)})")
    process_group_size = torch.distributed.get_world_size(process_group)
    global_shape = list(local_shape)
    global_shape[0] *= process_group_size

    # Unquantized data
    if quantizer is None:
        if isinstance(local_tensor, QuantizedTensorBase):
            local_tensor = local_tensor.dequantize()
        if comm.is_fp8_ubuf():
            raise RuntimeError(
                "Attempting to all-gather unquantized tensor, "
                "but Userbuffers is initialized with FP8 buffers"
            )
        comm.copy_into_buffer(local_tensor, local_chunk=True)
        global_tensor = comm.get_buffer(shape=global_shape)
        return global_tensor, local_tensor

    # FP8 data
    if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
        if not isinstance(local_tensor, Float8TensorBase):
            if isinstance(local_tensor, QuantizedTensorBase):
                # Re-quantize from high precision; the result must be kept,
                # dequantize() does not modify the tensor in place.
                local_tensor = local_tensor.dequantize()
            quantizer.set_usage(rowwise=True, columnwise=False)
            local_tensor = quantizer(local_tensor)
        if not comm.is_fp8_ubuf():
            raise RuntimeError(
                "Attempting to all-gather FP8 tensor, "
                "but Userbuffers is not initialized with FP8 buffers"
            )
        comm.copy_into_buffer(local_tensor._data, local_chunk=True)
        global_tensor_data = comm.get_buffer(shape=global_shape)
        global_tensor = Float8TensorBase(
            data=global_tensor_data,
            fp8_scale_inv=local_tensor._scale_inv,
            fp8_dtype=local_tensor._fp8_dtype,
            quantizer=quantizer,
        )
        return global_tensor, local_tensor

    # MXFP8 data
    if isinstance(quantizer, MXFP8Quantizer):

        # Cast to MXFP8 if needed
        if not isinstance(local_tensor, MXFP8TensorBase):
            if isinstance(local_tensor, QuantizedTensorBase):
                # Keep the dequantized result (dequantize() is not in place).
                local_tensor = local_tensor.dequantize()
            local_tensor = quantizer(local_tensor)
        if not comm.is_fp8_ubuf():
            raise RuntimeError(
                "Attempting to all-gather MXFP8 tensor, "
                "but Userbuffers is not initialized with FP8 buffers"
            )

        # Check which MXFP8 buffer to communicate
        if quantizer.rowwise_usage == quantizer.columnwise_usage:
            raise ValueError(
                "Userbuffers can only communicate one MXFP8 buffer at a time, "
                f"but quantizer has rowwise_usage={quantizer.rowwise_usage}, "
                f"columnwise_usage={quantizer.columnwise_usage}"
            )
        with_rowwise_data = quantizer.rowwise_usage

        # Copy MXFP8 data to local chunk of Userbuffers buffer
        local_data = (
            local_tensor._rowwise_data if with_rowwise_data else local_tensor._columnwise_data
        )
        comm.copy_into_buffer(local_data, local_chunk=True)

        # Gather scaling-inverses
        if math.prod(local_shape[:-1]) % 128 != 0:
            raise ValueError(
                "Userbuffers requires MXFP8 tensor dims that are divisible by 128, "
                f"but got MXFP8 tensor with shape={tuple(local_shape)}"
            )
        local_scale_inv = (
            local_tensor._rowwise_scale_inv
            if with_rowwise_data
            else local_tensor._columnwise_scale_inv
        )
        local_scale_inv_size = list(local_scale_inv.size())
        global_scale_inv = torch.empty(
            [process_group_size * local_scale_inv_size[0]] + local_scale_inv_size[1:],
            dtype=local_scale_inv.dtype,
            device=local_scale_inv.device,
        )
        torch.distributed.all_gather_into_tensor(
            global_scale_inv,
            local_scale_inv,
            group=process_group,
        )

        # Construct MXFP8 tensor with Userbuffers buffer
        rowwise_data, rowwise_scale_inv = None, None
        columnwise_data, columnwise_scale_inv = None, None
        global_data = comm.get_buffer(shape=global_shape)
        if with_rowwise_data:
            rowwise_data, rowwise_scale_inv = global_data, global_scale_inv
        else:
            columnwise_data, columnwise_scale_inv = global_data, global_scale_inv
        global_tensor = MXFP8TensorBase(
            rowwise_data=rowwise_data,
            rowwise_scale_inv=rowwise_scale_inv,
            columnwise_data=columnwise_data,
            columnwise_scale_inv=columnwise_scale_inv,
            fp8_dtype=local_tensor._fp8_dtype,
            quantizer=quantizer,
        )
        return global_tensor, local_tensor

    # Unsupported data format
    raise ValueError(f"Unsupported quantizer for Userbuffers ({quantizer})")


626
627
628
629
630
631
class TransformerEngineBaseModule(torch.nn.Module, ABC):
    """Base TE module."""

    def __init__(self) -> None:
        """Set up the bookkeeping shared by every TE module.

        Only plain attributes are initialized here; FP8 recipe state and
        quantizers are created later (see ``init_fp8_metadata``).
        """
        super().__init__()
        assert torch.cuda.is_available(), "TransformerEngine needs CUDA."

        # Debug-tool bookkeeping.
        self.name = None
        self.next_iter_when_debug_should_be_run = 0

        # FP8 execution state; recipe states are filled in lazily.
        self.fp8_initialized = False
        self.fp8 = False
        self.fp8_calibration = False
        self.fp8_meta = {"fp8_checkpoint": False, "fp8_group": None}
        self.fp8_meta_tensors_initialized = False
        self.quantizers = {"scaling_fwd": {}, "scaling_bwd": {}}

        # Tensor-parallel configuration (defaults to no parallelism).
        self.tp_group = None
        self.tp_size = 1
        self.sequence_parallel = False

        # Parameter-initialization options.
        self.param_init_meta = {}
        self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
        self.preserve_high_precision_init_val = FP8GlobalStateManager.with_high_precision_init_val()

        # FSDP configuration.
        self.fsdp_wrapped = False
        self.fsdp_group = None

        self._fp8_workspaces: Dict[str, QuantizedTensor] = {}
        self.activation_dtype: Optional[torch.dtype] = None
        self.wgrad_accumulation_and_reduce_hooks = []

        if not TEDebugState.debug_enabled:
            TEDebugState.initialize()

    # Names of attributes that can be set quickly (see __setattr__
    # method)
    _fast_setattr_names: Set[str] = {
        "activation_dtype",
        "fp8",
        "fp8_initialized",
        "fp8_calibration",
        "fp8_parameters",
    }

    def __setattr__(self, name: str, value: Any) -> None:
        """Assign an attribute, skipping nn.Module routing for hot plain attrs.

        ``torch.nn.Module.__setattr__`` inspects every assignment so it can
        register modules, parameters, and buffers. Names listed in
        ``_fast_setattr_names`` are expected to hold plain values, so they are
        written directly into the instance ``__dict__`` to avoid that overhead.
        """
        if name not in TransformerEngineBaseModule._fast_setattr_names:
            # Default case: let nn.Module handle registration.
            super().__setattr__(name, value)
            return
        self.__dict__[name] = value

    def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None:
        """
        Delayed scaling only.

        Increase or decrease size of amax history based on given `length`.

        Parameters
        ----------
        length : int
            New number of rows in the amax history. Shrinking keeps the first
            `length` rows; growing appends zero-padded rows.
        fwd : bool, optional
            `True` adjusts only the forward state, `False` only the backward
            state, `None` (default) adjusts both.

        .. warning::
            This changes the underlying amax memory location.
        """
        if fwd is None:
            fp8_meta_tensor_keys = ("scaling_fwd", "scaling_bwd")
        else:
            fp8_meta_tensor_keys = ("scaling_fwd" if fwd else "scaling_bwd",)

        for meta_key in fp8_meta_tensor_keys:
            if meta_key not in self.fp8_meta:
                # Handles non-parameter FP8 modules, e.g. DPA.
                continue

            recipe_state = self.fp8_meta[meta_key]
            curr_len = recipe_state.amax_history.shape[0]
            if length == curr_len:
                continue
            if length < curr_len:
                recipe_state.amax_history = recipe_state.amax_history[:length].clone()
            else:
                extra_rows = length - curr_len
                recipe_state.amax_history = F.pad(
                    recipe_state.amax_history, pad=(0, 0, 0, extra_rows)
                )

            # Update quantizers with new amax pointers.
            self.quantizers[meta_key] = recipe_state.make_quantizers()

            # Make sure weight tensors have the correct quantizers
            self._update_weight_quantizers()

            # Update the global buffers with new amax and history pointers.
            buffer_info_key = FP8GlobalStateManager.get_buffer_info()
            if buffer_info_key not in self.fp8_meta:
                continue
            fwd_pos, fwd_key, bwd_pos, bwd_key = self.fp8_meta[buffer_info_key]
            # Only repoint the slot that belongs to `meta_key`. Writing this
            # history into the other direction's slot would detach that
            # direction's real amax history from the global reduction buffers.
            if meta_key == "scaling_fwd":
                pos, buffer_key = fwd_pos, fwd_key
            else:
                pos, buffer_key = bwd_pos, bwd_key
            if buffer_key in FP8GlobalStateManager.global_amax_buffer:
                assert (
                    buffer_key in FP8GlobalStateManager.global_amax_history_buffer
                ), "TE internal error during amax history change."
                FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = (
                    recipe_state.amax_history[0]
                )
                FP8GlobalStateManager.global_amax_history_buffer[buffer_key][pos] = (
                    recipe_state.amax_history
                )

    def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
        """Init scales and amaxes for fwd | bwd.

        Builds the recipe state and quantizers for one direction
        (``fwd=True`` -> ``"scaling_fwd"``, otherwise ``"scaling_bwd"``). If
        state of the matching type already exists, it is reused. For delayed
        scaling, the amax history is resized to ``recipe.amax_history_len``
        instead of rebuilding the state.
        """
        meta_key = "scaling_fwd" if fwd else "scaling_bwd"

        if self.fp8_meta_tensors_initialized:
            existing = self.fp8_meta[meta_key]
            if recipe.delayed() and isinstance(existing, DelayedScalingRecipeState):
                self.adjust_amax_history_length(recipe.amax_history_len, fwd=fwd)
                return
            if (
                (recipe.mxfp8() and isinstance(existing, MXFP8BlockScalingRecipeState))
                or (
                    recipe.float8_current_scaling()
                    and isinstance(existing, Float8CurrentScalingRecipeState)
                )
                or (
                    recipe.float8_block_scaling()
                    and isinstance(existing, Float8BlockScalingRecipeState)
                )
            ):
                # Existing state already matches the recipe type.
                return

        # Max. number of FP8 tensors per GEMM: 3 in fwd (input, weight, output),
        # 2 in bwd (grad_output, grad_input).
        tensors_per_gemm = 3 if fwd else 2
        new_state = RecipeState.create(
            recipe,
            mode="forward" if fwd else "backward",
            num_quantizers=self.fp8_meta["num_gemms"] * tensors_per_gemm,
        )
        self.fp8_meta[meta_key] = new_state
        self.quantizers[meta_key] = new_state.make_quantizers()

    def _update_weight_quantizers(self) -> None:
        """Push the module's current weight quantizers into its quantized weights.

        Weights that are not ``QuantizedTensorBase`` instances, or whose
        quantizer slot is ``None``, are left untouched.
        """
        tensors = self._get_weight_tensors()
        quantizers = self._get_weight_quantizers()
        assert len(tensors) == len(quantizers), (
            f"Number of weight tensors ({len(tensors)}) and quantizers "
            f"({len(quantizers)}) must match"
        )
        for tensor, new_quantizer in zip(tensors, quantizers):
            if new_quantizer is None:
                continue
            if not isinstance(tensor, QuantizedTensorBase):
                continue
            tensor.update_quantizer(new_quantizer)

    def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
        """Return the module's weight tensors.

        Subclasses that own weights must override this; the base
        implementation always raises ``NotImplementedError``.
        """
        class_name = self.__class__.__name__
        raise NotImplementedError(
            f"{class_name} class does not implement _get_weight_tensors function"
        )

    def _get_weight_quantizers(self) -> List[Quantizer]:
        """Return the quantizers matching ``_get_weight_tensors`` one-to-one.

        Subclasses that own weights must override this; the base
        implementation always raises ``NotImplementedError``.
        """
        class_name = self.__class__.__name__
        raise NotImplementedError(
            f"{class_name} class does not implement _get_weight_quantizers function"
        )

    def init_fp8_meta_tensors(self, recipe: Recipe) -> None:
        """Create forward, then backward, recipe state and quantizers for `recipe`."""
        for is_forward in (True, False):
            self.set_meta_tensor(is_forward, recipe)
        self.fp8_meta_tensors_initialized = True

    def get_fp8_meta_tensors(self) -> Optional[Dict[str, List[torch.Tensor]]]:
        """Get scales and amaxes.

        Returns
        -------
        dict or None
            ``{"scaling_fwd": [scale, amax_history], "scaling_bwd": [scale, amax_history]}``
            holding clones taken under ``torch.no_grad``, or ``None`` if either
            recipe state has not been created yet. The result has the layout
            expected by ``reset_fp8_meta_tensors``.

        Notes
        -----
        Reads ``.scale`` and ``.amax_history`` on each recipe state, so it
        presumably only applies to delayed-scaling state. Confirm before
        calling it with other recipes.
        """
        fwd_key, bwd_key = "scaling_fwd", "scaling_bwd"
        if fwd_key not in self.fp8_meta or bwd_key not in self.fp8_meta:
            return None

        with torch.no_grad():
            return {
                key: [
                    self.fp8_meta[key].scale.clone(),
                    self.fp8_meta[key].amax_history.clone(),
                ]
                for key in (fwd_key, bwd_key)
            }

    def reset_fp8_meta_tensors(self, fp8_meta_tensors=None) -> None:
        """Reset scales and amaxes in place.

        With no argument, every scale is filled with 1 and every amax-history
        entry with 0. Otherwise the values are copied from
        ``fp8_meta_tensors``, a snapshot laid out as ``{key: [scale,
        amax_history]}`` (the format of ``get_fp8_meta_tensors``). Directions
        whose recipe state does not exist on this module are skipped.
        """
        with torch.no_grad():
            for key in ("scaling_fwd", "scaling_bwd"):
                if key not in self.fp8_meta:
                    continue
                recipe_state = self.fp8_meta[key]
                if fp8_meta_tensors is None:
                    recipe_state.scale.copy_(torch.ones_like(recipe_state.scale))
                    recipe_state.amax_history.copy_(
                        torch.zeros_like(recipe_state.amax_history)
                    )
                    continue
                assert key in fp8_meta_tensors, "Cannot reset fp8 tensors."
                saved = fp8_meta_tensors[key]
                recipe_state.scale.copy_(saved[0])
                recipe_state.amax_history.copy_(saved[1])

    def get_extra_state(self) -> torch.Tensor:
        """Save before checkpointing.

        Returns a CPU ``uint8`` tensor holding a pickled dict with the FP8
        recipe, the scales and amax histories (delayed-scaling recipes only),
        and the plain-valued entries of ``self.fp8_meta``. When the module has
        no FP8 state to record, an empty ``uint8`` tensor is returned.
        """
        # This implementation works around a few issues:
        #
        # (1) PyTorch's "extra state" infrastructure might be able to
        #     support any picklable type, but they make no guarantees.
        #     We have experienced problems (e.g. in ONNX export) with
        #     non-tensor extra state.
        # (2) PyTorch's checkpointing infrastructure does not remap
        #     devices for "extra state" like it does for "state dict".
        #     Thus, we want to avoid putting extra state on the GPU
        #     since it may be loaded on the wrong device.
        # (3) The extra state consists of many small tensors. If we
        #     want to copy them all to CPU, then we need to avoid the
        #     overhead of many GPU-CPU memory transfers.
        #
        # See: https://github.com/NVIDIA/TransformerEngine/pull/351
        # See: https://github.com/NVIDIA/TransformerEngine/pull/363

        needs_fp8_checkpoint = (
            self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration
        )
        if not needs_fp8_checkpoint:
            return torch.empty(0, dtype=torch.uint8)

        def to_cpu(src: torch.Tensor) -> torch.Tensor:
            """Return a CPU copy of ``src``.

            The copy is issued non-blocking, so the GPU must be synchronized
            before the result is read.
            """
            dst = torch.empty_like(src, device="cpu")
            dst.copy_(src, non_blocking=True)
            return dst

        recipe = self.fp8_meta["recipe"]
        state = {"recipe": recipe}
        if recipe.delayed():
            fwd_state = self.fp8_meta["scaling_fwd"]
            bwd_state = self.fp8_meta["scaling_bwd"]
            state["scale_fwd"] = to_cpu(fwd_state.scale)
            state["amax_history_fwd"] = to_cpu(fwd_state.amax_history)
            state["scale_bwd"] = to_cpu(bwd_state.scale)
            state["amax_history_bwd"] = to_cpu(bwd_state.amax_history)

        # Keep the remaining picklable plain values, excluding the buffer-index entry.
        state["extra_fp8_variables"] = {
            key: value
            for key, value in self.fp8_meta.items()
            if key != "buffer_index_and_autocast_key"
            and isinstance(value, (bool, int, float, str, tuple, list))
        }

        # Wait for the non-blocking device-to-host copies before pickling.
        torch.cuda.synchronize()
        payload = bytearray(pickle.dumps(state))
        return torch.frombuffer(payload, dtype=torch.uint8)

    def set_extra_state(self, state: torch.Tensor) -> None:
        """Load previous state.

        Restores the FP8 metadata written by ``get_extra_state``. Accepted
        inputs:

        - ``None`` (older checkpoints without extra state): ignored.
        - An empty tensor (module had no FP8 state): ignored.
        - A ``uint8`` tensor holding a pickled dict.
        - The deprecated ``io.BytesIO`` format.

        Raises
        ------
        RuntimeError
            If ``state`` is of any other type.
        """
        # Maintain backwards compatibility with older checkpoints.
        if state is None:
            return

        if isinstance(state, torch.Tensor):
            if state.numel() == 0:
                # No FP8 is indicated by an empty tensor; nothing to unpickle.
                return
            # Default format: byte tensor with pickled data.
            # NOTE(review): pickle.loads can execute arbitrary code embedded in
            # the checkpoint; only load checkpoints from trusted sources.
            raw_bytes = state.detach().cpu().numpy().tobytes()
            state = pickle.loads(raw_bytes)
        elif isinstance(state, io.BytesIO):
            # Deprecated format with io.BytesIO
            state.seek(0)
            state = torch.load(state, map_location="cuda")
        else:
            raise RuntimeError("Unsupported checkpoint format.")

        if state is None:
            return

        # TE 1.x checkpoint compatibility. TE 1.x only supported delayed scaling
        # (the default recipe) and also saved scale_inv, which is not needed
        # once a Recipe object is present.
        if "recipe" not in state:
            state["recipe"] = DelayedScaling()
            for stale_key in ("scale_inv_fwd", "scale_inv_bwd"):
                state.pop(stale_key, None)

        recipe = state["recipe"]
        self.fp8_meta.update(state["extra_fp8_variables"])
        self.fp8_meta["recipe"] = recipe
        self.fp8_meta.pop("global_fp8_buffer_pos_fwd_recompute", None)

        # Initialize before loading
        self.init_fp8_meta_tensors(recipe)

        if recipe.delayed():
            # Non-blocking host-to-device copies; synchronized below.
            for meta_key, scale_key, history_key in (
                ("scaling_fwd", "scale_fwd", "amax_history_fwd"),
                ("scaling_bwd", "scale_bwd", "amax_history_bwd"),
            ):
                recipe_state = self.fp8_meta[meta_key]
                recipe_state.scale.copy_(state[scale_key], non_blocking=True)
                recipe_state.amax_history.copy_(state[history_key], non_blocking=True)
        torch.cuda.synchronize()

    def set_activation_dtype(self, inp: torch.Tensor) -> None:
        """Get activation data type for AMP.

        Under ``torch.autocast`` the autocast GPU dtype is used. Otherwise the
        input dtype is used, after asserting that every parameter yielded by
        ``self.named_parameters()`` has that same dtype.
        """
        # Native AMP (`torch.autocast`) gets highest priority
        if torch.is_autocast_enabled():
            self.activation_dtype = torch_get_autocast_gpu_dtype()
            return

        # All checks after this have already been performed once, thus skip
        if self.activation_dtype == inp.dtype:
            return

        dtype = inp.dtype
        for name, param in self.named_parameters():
            if param is not None:
                assert dtype == param.dtype, (
                    "Data types for parameters must match when outside of autocasted region. "
                    f"Found input dtype: {dtype} and {name!r} dtype: {param.dtype}"
                )
        self.activation_dtype = dtype

    def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
        """
        Attach the tensor-parallel process group used by this module.

        Call this before running the forward pass. It also marks the group as
        initialized; ``prepare_forward`` asserts that flag when ``tp_size > 1``.

        Parameters
        ----------
        tp_group : ProcessGroup, default = `None`
                  tensor parallel process group.
        """
        self.tp_group, self.tp_group_initialized = tp_group, True

    def _get_fp8_params(self) -> Union[List[torch.Tensor], None]:
        """Return this module's own quantized, trainable parameters.

        Only direct parameters are considered (``recurse=False``). Returns
        ``None`` rather than an empty list when there are none.
        """
        quantized_params = [
            p
            for p in self.parameters(recurse=False)
            if isinstance(p, QuantizedTensor) and p.requires_grad
        ]
        return quantized_params or None

    def init_fp8_metadata(self, num_gemms: int = 1) -> None:
        """Initialize fp8 related metadata and tensors during fprop.

        Shared by the FP8 and FP8-calibration paths, so it must not assume FP8
        execution. Refreshes ``self.fp8``, ``self.fp8_calibration`` and
        ``self.fp8_parameters`` from the global FP8 state. Recipe state is
        (re)created unless the module is already initialized with the active
        recipe. If the recipe class changes, a warning is issued and cached
        weight workspaces are dropped.

        Parameters
        ----------
        num_gemms : int, default = 1
            Number of GEMMs in the module; used to size the quantizer lists.
        """
        previous_recipe = self.fp8_meta.get("recipe", None)

        self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
        self.fp8 = FP8GlobalStateManager.is_fp8_enabled()
        self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
        quantized_run = self.fp8 or self.fp8_calibration
        self.fp8_meta["fp8_checkpoint"] = quantized_run

        if not (self.fp8_parameters or quantized_run):
            # If fp8 isn't enabled, turn off and return.
            self.fp8_initialized = False
            return

        recipe_unchanged = (
            self.fp8_initialized
            and FP8GlobalStateManager.get_fp8_recipe() == self.fp8_meta["recipe"]
        )
        if recipe_unchanged:
            # FP8 init has already been run and recipe is the same, don't do anything.
            return
        self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()

        if self.fp8_parameters and not self.fp8_initialized:
            # FP8 weights need recipe state even without FP8 compute.
            self.fp8_meta["num_gemms"] = num_gemms
            self.init_fp8_meta_tensors(self.fp8_meta["recipe"])

        if quantized_run:
            recipe = self.fp8_meta["recipe"]
            self.fp8_meta["num_gemms"] = num_gemms
            self.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()

            # Set FP8_MAX per tensor according to recipe
            self.fp8_meta["fp8_max_fwd"] = recipe.fp8_format.value.max_fwd
            self.fp8_meta["fp8_max_bwd"] = recipe.fp8_format.value.max_bwd

            # Allocate scales and amaxes
            self.init_fp8_meta_tensors(recipe)
            self.fp8_initialized = True

            self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()

        current_recipe = self.fp8_meta["recipe"]
        if previous_recipe is None:
            return
        previous_cls = previous_recipe.__class__
        current_cls = current_recipe.__class__
        if issubclass(current_cls, previous_cls) or issubclass(previous_cls, current_cls):
            return
        warnings.warn(
            f"Recipe type changed from {previous_cls.__name__} "
            f"to {current_cls.__name__}. "
            "This may affect model behavior."
        )
        # Clear cached workspaces as they were created with the old recipe/quantizer type
        self._fp8_workspaces.clear()

    @contextmanager
    def prepare_forward(
        self,
        inp: torch.Tensor,
        num_gemms: int = 1,
        allow_non_contiguous: bool = False,
    ) -> Generator[torch.Tensor, None, None]:
        """Checks and prep for FWD.
        The context manager is needed because there isn't a way for a module to know
        if it's the last FP8 module in the forward autocast. It is useful
        to setup the forward aggregated amax reduction for every module
        just in case. The autocast exit will pick up the most recent one.

        Yields the input, made contiguous unless ``allow_non_contiguous``. In
        the FP8 activation-recompute phase, the meta tensors swapped in on
        entry are restored on exit, including when the wrapped forward raises.
        """
        self.forwarded_at_least_once = True

        # Activation recomputation is used and this is the second forward phase.
        if self.fp8 and in_fp8_activation_recompute_phase():
            FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta)
        else:
            assert inp.is_cuda, "TransformerEngine needs CUDA."

            if self.tp_size > 1:
                assert self.tp_group_initialized, "TP group not initialized."

            self.set_activation_dtype(inp)
            self.init_fp8_metadata(num_gemms=num_gemms)
            self._check_weight_tensor_recipe_correspondence()

            if self.fp8 and self.sequence_parallel and self.fp8_meta["recipe"].delayed():
                assert self.fp8_meta["recipe"].reduce_amax, (
                    "Amax reduction across tensor parallel group is "
                    "necessary when using sequence parallelism with FP8."
                )

            if self.fp8 and not FP8GlobalStateManager.fp8_graph_capturing():
                FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(self.fp8_meta)

            # Activation recomputation is used and this is the first forward phase.
            if self.fp8 and self.training and is_fp8_activation_recompute_enabled():
                FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)

        try:
            with torch.cuda.nvtx.range(self.__class__.__name__ + " forward"):
                if not allow_non_contiguous and not inp.is_contiguous():
                    inp = inp.contiguous()
                yield inp
        finally:
            # Always undo the swap made by get_old_fp8_meta_tensors_for_recompute,
            # even if the forward pass raised, so the module does not keep the
            # swapped-in FP8 meta tensors.
            if self.fp8 and in_fp8_activation_recompute_phase():
                FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)

    def set_nccl_overlap_warning_if_tp(self) -> None:
        """When using TP, the NCCL communication needs to be scheduled
        before the GEMM for there to be a guaranteed overlap. From the
        host side in TE, the comm calls are always launched first, but
        to ensure that the GEMM isn't scheduled first, the environment
        variable `CUDA_DEVICE_MAX_CONNECTIONS` needs to be set to 1 to
        force a single channel.

        Does nothing when ``tp_size == 1``; otherwise warns unless
        ``CUDA_DEVICE_MAX_CONNECTIONS`` is exactly 1 (unset counts as 0).
        """
        if self.tp_size == 1:
            return
        num_cuda_work_queues = int(os.getenv("CUDA_DEVICE_MAX_CONNECTIONS", "0"))
        if num_cuda_work_queues != 1:
            warnings.warn(
                "To guarantee overlapping TP and SP collectives with the backward "
                "GEMMs, set environment variable CUDA_DEVICE_MAX_CONNECTIONS = 1"
            )

    @staticmethod
    def grad_output_preprocess(
        ctx,
        grad_output: torch.Tensor,
        row_parallel_mode: bool,
        quantizer: Optional[Quantizer],
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Utility function for backward.

        Returns a 2-tuple; either element may be None depending on the
        training precision / recipe:
            R1: `grad_output`, flattened to 2D and, as needed, all-gathered
                along the first dim (sequence parallel) and/or quantized.
            R2: bias gradient on R1, or None when no bias is used or bgrad is
                fused with wgrad elsewhere.
        """
        # Types treated as "already quantized"; such inputs are not re-quantized.
        quantized_types = (
            QuantizedTensor,
            Float8TensorBase,
            MXFP8TensorBase,
            Float8BlockwiseQTensorBase,
        )

        grad_output = grad_output.reshape((-1, grad_output.shape[-1])).contiguous()
        needs_gather = row_parallel_mode and ctx.sequence_parallel

        # Non-FP8 case: bgrad is fused with wgrad for this case.
        if not (ctx.fp8 or ctx.debug):
            if needs_gather:
                if ctx.ub_overlap_ag and ctx.ub_obj_gradout is not None:
                    # Initialize Userbuffers all-gather
                    grad_output, _ = fill_userbuffers_buffer_for_all_gather(
                        ctx.ub_obj_gradout,
                        grad_output,
                        None,
                        ctx.tp_group,
                    )
                else:
                    # Perform NCCL all-gather
                    grad_output, _ = gather_along_first_dim(grad_output, ctx.tp_group)
            return grad_output, None

        # FP8 with all-gather: unfused bgrad, fused cast + transpose.
        # Also supports debug quantization, which is handled inside gather_along_first_dim.
        if needs_gather:
            grad_bias = (
                grad_output.view(-1, grad_output.shape[-1]).sum(dim=0) if ctx.use_bias else None
            )
            if ctx.ub_overlap_ag and ctx.ub_obj_gradout is not None:
                # Quantize first if needed, then copy into the communication
                # buffer and replace the original gradient with it.
                if not isinstance(grad_output, quantized_types):
                    grad_output = quantizer(grad_output)
                grad_output, _ = fill_userbuffers_buffer_for_all_gather(
                    ctx.ub_obj_gradout,
                    grad_output,
                    quantizer,
                    ctx.tp_group,
                )
            else:
                grad_output, _ = gather_along_first_dim(
                    grad_output,
                    ctx.tp_group,
                    quantizer=quantizer,
                )
            return grad_output, grad_bias

        # Debug without all-gather: unfused cast and bgrad.
        # bgrad only if wgrad is quantized; otherwise it is fused with wgrad (None here).
        if ctx.debug:
            debug_grad = quantizer(grad_output)
            wgrad_input_quantized = isinstance(debug_grad.get_tensor(True), quantized_types)
            grad_bias = None
            if wgrad_input_quantized and ctx.use_bias:
                grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
            return debug_grad, grad_bias

        # FP8 without all-gather: fused bgrad + cast + transpose.
        grad_bias = None
        if ctx.use_bias:
            if isinstance(grad_output, quantized_types):
                grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0)
            elif isinstance(quantizer, Float8BlockQuantizer) or (
                isinstance(quantizer, Float8CurrentScalingQuantizer) and IS_HIP_EXTENSION
            ):
                # Unfused bgrad for now, until cast_transpose + dgrad calculation is
                # ready for Float8BlockQuantizer (and current scaling on HIP builds).
                grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
            else:
                grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer)
        if not isinstance(grad_output, quantized_types):
            grad_output = quantizer(grad_output)
        return grad_output, grad_bias

    def register_parameter(self, name, param, **kwargs):
        """
        Register ``param`` with PyTorch and record how to (re)initialize it.

        Extra keyword arguments are forwarded to ``_ParameterInitMeta``. The
        resulting metadata is stored in ``self.param_init_meta[name]``, where
        ``reset_parameters`` reads it for deferred initialization.
        """
        super().register_parameter(name, param)
        init_meta = _ParameterInitMeta(**kwargs)
        self.param_init_meta[name] = init_meta

    def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
        """
        Reset all module parameters to initial values. Unless deferred initialization
        is specified, all parameters on a 'meta' device are also materialized on a real cuda
        device before the values are reset to initial.

        Each direct parameter is initialized with the ``init_fn`` recorded by
        ``register_parameter``, inside the RNG tracker's fork when a tracker
        was given. When FP8 primary weights are enabled and the parameter has
        an ``fp8_meta_index``, it is quantized with the matching forward
        quantizer.

        Raises
        ------
        RuntimeError
            If a parameter must be quantized but its weight quantizer has not
            been initialized.
        """
        if defer_init:
            return

        for name, param in self.named_parameters(recurse=False):
            # Ensure parameter is on a real device
            if param.device == torch.device("meta"):
                param = torch.empty_like(param, device="cuda")

            # Initialize the parameter values on device
            init_fn = self.param_init_meta[name].init_fn
            get_rng_state_tracker = self.param_init_meta[name].get_rng_state_tracker
            if get_rng_state_tracker is None:
                init_fn(param)
            else:
                if hasattr(self, "rng_tracker_name") and self.rng_tracker_name:
                    with get_rng_state_tracker().fork(self.rng_tracker_name):
                        init_fn(param)
                else:
                    with get_rng_state_tracker().fork():
                        init_fn(param)

            # Wrap parameters in QuantizedTensor if needed
            fp8_meta_index = self.param_init_meta[name].fp8_meta_index
            high_precision_init_val = None
            if self.primary_weights_in_fp8 and fp8_meta_index is not None:

                # Keep high-precision values on CPU if needed
                if self.preserve_high_precision_init_val:
                    high_precision_init_val = param.detach().cpu()

                # Configure quantizer. Until quantizers are created,
                # self.quantizers["scaling_fwd"] is an empty dict, so a plain
                # index would raise a bare KeyError instead of the error below.
                try:
                    quantizer = self.quantizers["scaling_fwd"][fp8_meta_index]
                except (KeyError, IndexError):
                    quantizer = None
                if quantizer is None:
                    raise RuntimeError("Weight quantizer has not been initialized")
                quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled())
                quantizer.internal = False

                # Quantize parameter
                param = quantizer(param)

            # Redo parameter wrap in case we broke it above
            # NOTE: Currently this can only be broken when primary weights are in Fp8 but
            #       re-applying the nn.Parameter() wrap is a no-op when the input is already
            #       a parameter so we always re-apply it just for extra safety.
            param = torch.nn.Parameter(param)

            # Keep high-precision values on CPU if needed
            if high_precision_init_val is not None:

                # - Master weights are initialized from model weights, if we use fp8 primary
                #   weights to initialize master weights, the numerical values of master weights
                #   are not consistent with the numerical values when we initialize them from
                #   bf16/fp16 weights.
                # - So we add a `_high_precision_init_val` attribute to each model weight to store
                #   the original bf16/fp16 weight on cpu before casting it to fp8. And users can
                #   use `get_high_precision_init_val` to get this cpu tensor.
                # - This cpu tensor is not needed once the master weight is initialized, so users
                #   should call `clear_high_precision_init_val` to remove it after master weight
                #   is initialized.

                def get(self):
                    if hasattr(self, "_high_precision_init_val"):
                        return self._high_precision_init_val
                    return None

                def clear(self):
                    if hasattr(self, "_high_precision_init_val"):
                        del self._high_precision_init_val

                param._high_precision_init_val = high_precision_init_val
                param.get_high_precision_init_val = MethodType(get, param)
                param.clear_high_precision_init_val = MethodType(clear, param)

            setattr(self, name, param)

    @abstractmethod
    def forward(self):
        """Compute the module's forward pass.

        Abstract hook: every concrete module derived from this base must
        supply its own implementation.
        """
        ...
1312

1313
    def get_weight_workspace(
        self,
        *,
        tensor: Optional[torch.Tensor] = None,
        quantizer: Optional[Quantizer] = None,
        cache_name: Optional[str] = None,
        update_workspace: bool = True,
        skip_update_flag: Optional[torch.Tensor] = None,
        fsdp_group: Optional[dist_group_type] = None,
        workspace_dtype: Optional[torch.dtype] = None,
    ) -> QuantizedTensor:
        """Get workspace buffer for weights and maybe update its values

        The workspace buffer may be cached for future function calls.

        Parameters
        ----------
        tensor : torch.Tensor, optional
            Values to copy into workspace. Required if the workspace
            is being constructed or updated. If it is already a
            `QuantizedTensor`, it is returned directly (after requesting
            the usages required by `quantizer`, when one is given).
        quantizer: Quantizer, optional
            Quantizer used to cast the weights. Required if the
            workspace is being constructed or updated.
        cache_name: str, optional
            Key for caching. If `None`, the workspace is neither looked
            up in nor stored into the cache.
        update_workspace: bool, default = `True`
            Update workspace with values from `tensor`.
        skip_update_flag: torch.Tensor, optional
            GPU flag to skip updating the workspace. Takes precedence
            over `update_workspace` if provided.
        fsdp_group: dist_group_type, default = None
            FSDP process group that the weights are distributed over.
        workspace_dtype: torch.dtype, default = None
            If weight workspace contains high-precision tensor - for example
            for debug quantization, this is dtype of the tensor.

        Returns
        -------
        QuantizedTensor
            The (possibly cached, possibly updated) weight workspace.

        Raises
        ------
        ValueError
            If a workspace has to be constructed but `tensor` or `quantizer`
            is missing, or if it has to be updated but `tensor` is missing.
        """

        # Handle case where weights are already quantized
        # Note: Make sure weights have required usages, but do not
        # destroy unnecessary usages since they may be used later.
        if isinstance(tensor, QuantizedTensor):
            if quantizer is not None:
                # Only request usages (True); passing None instead of False is
                # how existing, currently-unneeded usages are kept (see note above).
                tensor.update_usage(
                    rowwise_usage=True if quantizer.rowwise_usage else None,
                    columnwise_usage=True if quantizer.columnwise_usage else None,
                )
            return tensor

        # Try getting workspace from cache
        out = None
        if cache_name is not None:
            out = self._fp8_workspaces.get(cache_name, None)

        # Reset cache if workspace is invalid
        if out is not None and quantizer is not None:
            reset_cache = False
            if isinstance(out, Float8TensorBase):
                # When non-TN FP8 GEMMs are unsupported, column-wise usage appears to
                # rely on a materialized `_transpose`; rebuild if it is missing.
                if (
                    not is_non_tn_fp8_gemm_supported()
                    and quantizer.columnwise_usage
                    and out._transpose is None
                ):
                    reset_cache = True
            elif isinstance(out, (MXFP8TensorBase, Float8BlockwiseQTensorBase)):
                # The cached buffer lacks a usage the quantizer now requires
                # (e.g. it was created while only row-wise data was needed).
                if quantizer.rowwise_usage and out._rowwise_data is None:
                    reset_cache = True
                elif quantizer.columnwise_usage and out._columnwise_data is None:
                    reset_cache = True
            if isinstance(out, DebugQuantizedTensor) != isinstance(quantizer, DebugQuantizer):
                reset_cache = True
            if reset_cache:
                out = None
                del self._fp8_workspaces[cache_name]

        # Gather cached Fp8 workspace if it's distributed
        # NOTE: FSDP sharding is supported only for Fp8 buffers and will not work
        #       for models initialized with Fp8 primary weights.
        if (
            out is not None
            and tensor is not None
            and fsdp_group is not None
            and out.data.shape != tensor.data.shape
        ):
            _fsdp_gather_tensors(fsdp_group, [tensor.data.shape], out)

        # Construct workspace if needed
        if out is None:
            if tensor is None or quantizer is None:
                raise ValueError(
                    "tensor and quantizer kwargs must be provided to construct FP8 workspace"
                )
            if cache_name is None:
                return quantizer.quantize(tensor, dtype=workspace_dtype)

            # Ensure the tensor in the cache is an instance of torch.Tensor,
            # as it persists beyond a single forward pass.
            # Setting internal=True would cause the data to be removed in prepare_for_saving(...).
            # The caller's `internal` setting is restored even if quantization raises.
            quantizer_internal = quantizer.internal
            quantizer.internal = False
            try:
                out = quantizer.quantize(tensor, dtype=workspace_dtype)
            finally:
                quantizer.internal = quantizer_internal

            # Update cache
            self._fp8_workspaces[cache_name] = out
            return out

        # Update workspace if needed
        if skip_update_flag is not None:
            update_workspace = True
        if update_workspace:
            if tensor is None:
                raise ValueError("tensor kwarg must be provided to update FP8 workspace")
            if hasattr(out, "quantize_"):
                out.quantize_(tensor, noop_flag=skip_update_flag)
            else:
                tex.quantize(tensor, quantizer, out, skip_update_flag)
        return out
1432

1433
1434
1435
    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        """
        This function loads tensors and extra state including fp8 metadata.
        This metadata is essential for copying fp8 tensors, as the copy_ function
        uses the scale_inv parameter from fp8_meta to set the correct scaling factor
        for the new tensor.
        Hence, this extra state must be loaded before the tensor copying process,
        not after, as is typically done in _load_from_state_dict.
        Tensors are copied into fp8 tensors only when self.primary_weights_in_fp8=True,
        otherwise, this behavior is not required.

        NOTE(review): the base ``torch.nn.Module._load_from_state_dict`` called below
        may apply the extra state a second time after the tensors are copied;
        presumably harmless, but worth confirming.
        """
        if self.primary_weights_in_fp8:
            extra_state_key = prefix + torch.nn.modules.module._EXTRA_STATE_KEY_SUFFIX
            if extra_state_key in state_dict:
                # Restore FP8 metadata first so the tensor copies below see the right scales.
                self.set_extra_state(state_dict[extra_state_key])
        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )
1453

1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
    def register_wgrad_accumulation_and_reduce_hooks(self, wgrad_accumulation_and_reduce_hook):
        """
        Register a callback to run after the delayed weight-gradient computation.

        Hooks are appended, in registration order, to
        ``self.wgrad_accumulation_and_reduce_hooks``. ``backward_dw()`` calls each of
        them with no arguments once the delayed weight/bias gradients have been
        handled, which lets the caller take manual control of weight-gradient
        accumulation and reduction. Register hooks before ``backward()`` is run.

        Parameters
        ----------
        wgrad_accumulation_and_reduce_hook : callable
            Zero-argument callable to invoke from ``backward_dw()``.
        """
        registered_hooks = self.wgrad_accumulation_and_reduce_hooks
        registered_hooks.append(wgrad_accumulation_and_reduce_hook)

1464
1465
1466
1467
1468
1469
1470
1471
    def backward_dw(self):
        """
        Execute the delayed weight gradient computation.

        This method is called after the main backward pass to compute weight gradients.
        It does nothing unless ``self.wgrad_store`` is set and its
        ``delay_wgrad_compute()`` returns a truthy value. Otherwise it pops the stored
        ``(wgrad, bgrad)`` pair, assigns them as ``.grad`` of the weight / bias tensors
        when those have no gradient yet, and then invokes every hook registered through
        ``register_wgrad_accumulation_and_reduce_hooks`` (with no arguments).
        """
        if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute():
            return
        with torch.cuda.nvtx.range(f"_{self.__class__.__name__}_wgrad"):
            (wgrad, bgrad), _ = self.wgrad_store.pop()
            # With fused wgrad accumulation the gradient is presumably already written
            # elsewhere (e.g. a main-grad buffer), so it is not assigned here.
            if not self.fuse_wgrad_accumulation:
                weight_tensor = noop_cat(self._get_weight_tensors())
                # NOTE(review): if `.grad` already exists, this delayed wgrad is dropped
                # rather than accumulated — confirm that is intended.
                if weight_tensor.grad is None:
                    weight_tensor.grad = wgrad.to(weight_tensor.dtype)
            if self.use_bias:
                bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
                if bias_tensor.grad is None:
                    bias_tensor.grad = bgrad.to(bias_tensor.dtype)
            # Drop local references to the popped gradients before running the hooks.
            del wgrad
            del bgrad
            for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks:
                wgrad_accumulation_and_reduce_hook()
1485

1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
    def is_debug_iter(self) -> bool:
        """
        Decide whether debug features should run for this layer on this call.

        Returns False immediately when debug mode is globally disabled. Otherwise the
        decision is taken on the first call of the layer in each iteration (as reported
        by ``TEDebugState.get_iteration()``): debug runs only if
        ``self.next_iter_when_debug_should_be_run`` is not None and the current
        iteration has reached it. Later calls within the same iteration reuse that
        decision, so every invocation of the layer in one iteration agrees.
        """
        debug = TEDebugState.debug_enabled
        if not debug:
            return False
        self._validate_name()

        # If layer is run first time in new iteration,
        # we need to check if the debug should be enabled for this layer -
        # maybe in previous iterations debug features returned information
        # that no feature will be active for this layer for multiple next iterations.
        iteration = TEDebugState.get_iteration()
        started_new_iteration = iteration != getattr(self, "debug_last_iteration", None)
        if started_new_iteration:
            if self.next_iter_when_debug_should_be_run is None:
                debug = False
            else:
                debug = iteration >= self.next_iter_when_debug_should_be_run
            self.debug_enabled_in_this_iteration = debug
        else:
            # Same iteration as the previous call: reuse the decision taken on the first
            # call instead of unconditionally enabling debug. Falls back to the global
            # flag if no decision has been recorded on this module.
            debug = getattr(self, "debug_enabled_in_this_iteration", debug)
        self.debug_last_iteration = iteration
        return debug

    def no_debug_features_active(self, quantizers):
        """
        Report whether every debug feature is inactive for this layer.

        As a side effect, refreshes ``self.next_iter_when_debug_should_be_run`` from
        ``quantizers``, because features may announce that they stay disabled for this
        layer over several upcoming iterations.

        Returns
        -------
        bool
            True if no debug feature is enabled for ``quantizers``, False otherwise.

        Raises
        ------
        RuntimeError
            If a feature is enabled while the module's primary weights are in FP8.
        """
        feature_enabled = any_feature_enabled(quantizers)
        # The right-hand side is the module-level helper, not this attribute.
        self.next_iter_when_debug_should_be_run = next_iter_when_debug_should_be_run(quantizers)

        if feature_enabled:
            if self.primary_weights_in_fp8:
                raise RuntimeError("FP8 weights are not supported in debug mode.")
            return False
        return True
1526

1527
1528
1529
1530
1531
1532
    def _validate_name(self):
        """
        Validate name passed to the module.
        This is invoked in the forward() method as module names are assigned after Model is initialized in Megatron-LM.
        If no name is assigned, it creates a default name with layer count as the variable.
        """
1533
1534
        if self.name is not None:
            return
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
        assert TEDebugState.debug_enabled
        import nvdlfw_inspect.api as debug_api

        if self.name is None:
            debug_api.log_message(
                "Names are not provided to debug modules. ",
                "Creating and using generic names. Pass names to debug modules for better"
                " insight. ",
                level=logging.WARNING,
            )
            self.name = f"Layer_{TEDebugState.get_layer_count()}"

1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
    def _check_weight_tensor_recipe_correspondence(self) -> None:
        """
        Verify that the weight tensor types match their corresponding recipe type.
        This is invoked in the forward().

        This establishes a correspondence between recipe types and tensor types
        (several recipes may map to the same tensor type):
        - DelayedScaling → Float8Tensor
        - Float8CurrentScaling → Float8Tensor
        - MXFP8BlockScaling → MXFP8Tensor
        - Float8BlockScaling → Float8BlockTensor

        Example case to check: recipe is DelayedScaling (DelayedScaling is set in fp8_autocast()),
        but the weight tensor is MXFP8Tensor (MXFP8BlockScaling is set in fp8_model_init()).

        Only weights that are ``QuantizedTensorBase`` instances are checked. Weights whose
        quantizer is missing, or whose quantizer reports no compatible recipe class, are
        skipped.

        Raises
        ------
        RuntimeError
            If a quantized weight's compatible recipe class does not match the recipe
            stored in ``self.fp8_meta["recipe"]``.
        """
        # Nothing to check unless FP8 execution or FP8 calibration is active.
        if not self.fp8 and not self.fp8_calibration:
            return
        # Modules without (non-empty) `weight_names` have no weights to validate.
        if not hasattr(self, "weight_names") or not self.weight_names:
            return

        recipe = self.fp8_meta["recipe"]
        weight_tensors = [getattr(self, name) for name in self.weight_names]
        for i, tensor in enumerate(weight_tensors):
            # Only quantized weights are checked; other weights are skipped.
            if isinstance(tensor, QuantizedTensorBase):
                quantizer = tensor._get_quantizer()
                if quantizer is None:
                    continue
                # A `None` compatible recipe class is treated as "no constraint".
                compatible_recipe_class = quantizer._get_compatible_recipe()
                if compatible_recipe_class is None:
                    continue
                if not isinstance(recipe, compatible_recipe_class):
                    raise RuntimeError(
                        f"Recipe mismatch for '{self.weight_names[i]}': tensor supports recipe"
                        f" {compatible_recipe_class.__name__}, but got {recipe.__class__.__name__}."
                        " Please check the recipes assigned during fp8_model_init() and"
                        " fp8_autocast() calls."
                    )