base.py 72.6 KB
Newer Older
1
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
#
# See LICENSE for license information.

"""Base modules and utilities for TransformerEngine PyTorch API"""
6
import io
7
import math
8
9
10
import os
import pickle
import warnings
11
from enum import Enum
12
from abc import ABC, abstractmethod
13
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
14
from contextlib import contextmanager
15
from types import MethodType
16
17
18

import torch
import torch.nn.functional as F
19
from torch.distributed.tensor import DTensor
20

21
import transformer_engine_torch as tex
22

23
from ._common import _ParameterInitMeta, noop_cat
24
from ..quantization import (
25
26
    MXFP8BlockScalingRecipeState,
    DelayedScalingRecipeState,
27
    Float8CurrentScalingRecipeState,
28
    Float8BlockScalingRecipeState,
29
    NVFP4BlockScalingRecipeState,
30
    FP8GlobalStateManager,
31
    RecipeState,
32
33
34
35
36
)
from ..distributed import (
    gather_along_first_dim,
    is_fp8_activation_recompute_enabled,
    in_fp8_activation_recompute_phase,
37
    _fsdp_gather_tensors,
38
39
)
from ..constants import dist_group_type
40
from ..cpp_extensions.gemm import _NUM_MAX_UB_STREAMS
41
from ..quantized_tensor import QuantizedTensor, QuantizedTensorStorage, Quantizer
42
43
from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
44
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
45
46
from ..tensor.storage.float8_tensor_storage import Float8TensorStorage
from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
47
from ..tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage
48
49
50
51
from ..utils import (
    is_non_tn_fp8_gemm_supported,
    torch_get_autocast_gpu_dtype,
    get_nvtx_range_context,
52
53
    nvtx_range_push,
    nvtx_range_pop,
54
)
55
from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
56
from ...common.recipe import DelayedScaling, Recipe
57
58
from ...debug.pytorch.debug_state import TEDebugState
from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor
59
from ...debug.pytorch.utils import next_iter_when_debug_should_be_run, any_feature_enabled
yuguo's avatar
yuguo committed
60
from torch.utils.cpp_extension import IS_HIP_EXTENSION
61

62
# Names exported by a star-import of this module.
__all__ = ["initialize_ub", "destroy_ub", "UserBufferQuantizationMode"]
63

64
65
66
# Per-GEMM flags for the forward (fprop), data-gradient (dgrad) and weight-gradient
# (wgrad) passes.
# NOTE(review): presumably these toggle split ("2x") accumulation -- confirm at the
# GEMM call sites, which are not visible here.
_2X_ACC_FPROP = False
_2X_ACC_DGRAD = True
_2X_ACC_WGRAD = True
# Cache used by get_dummy_wgrad(), keyed by (shape[0], shape[1], dtype).
_dummy_wgrads = {}
# Filled on first use by get_multi_stream_cublas_batchgemm_workspace().
_multi_stream_cublas_batchgemm_workspace = []
# Maps (layer name, UserBufferQuantizationMode) -> communicator object.
# None until initialize_ub() runs and again after destroy_ub().
_ub_communicators = None
# Cached result of tex.get_stream_priority_range(); filled by initialize_ub().
_MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = None, None
# Layer names registered with atomic GEMM + ring-exchange overlap (see initialize_ub()).
layers_atomic_ring_exchange = []
72
73


74
75
76
77
78
79
80
81
class UserBufferQuantizationMode(Enum):
    """Quantization setting a Userbuffers communicator is created for.

    ``initialize_ub`` registers one communicator per (layer name, mode) pair and
    ``get_ub`` selects between ``FP8`` and ``NONE`` based on its ``use_fp8`` flag.
    """

    # No quantization: communication buffers use the regular (non-FP8) dtype.
    NONE = "none"
    # FP8 quantization: eligible buffers are allocated as FP8 (uint8 storage).
    FP8 = "fp8"

yuguo's avatar
yuguo committed
82
83
84
85
86
87
def get_multi_stream_cublas_batchgemm_workspace() -> List[torch.Tensor]:
    """Return the workspace tensors for multi-stream cuBLAS batched GEMM.

    The workspaces are allocated on first use and cached in the module-level
    list, so subsequent calls return the same list object.

    Returns
    -------
    List[torch.Tensor]
        One 128-byte ``uint8`` tensor on the ``"cuda"`` device per stream, with
        the stream count taken from ``tex._num_cublas_batchgemm_streams``.
    """
    global _multi_stream_cublas_batchgemm_workspace
    if not _multi_stream_cublas_batchgemm_workspace:
        _multi_stream_cublas_batchgemm_workspace.extend(
            torch.empty(128, dtype=torch.uint8, device="cuda")
            for _ in range(tex._num_cublas_batchgemm_streams)
        )
    return _multi_stream_cublas_batchgemm_workspace

92

yuguo's avatar
yuguo committed
93
94
95
96
# Layers for which no Userbuffers communicator is created: initialize_ub() skips
# them and get_ub() is meant to answer None for them. Setting
# NVTE_DISABLE_FC2_DGRAD_OVERLAP=1 puts "fc2_dgrad" on this list.
remove_ag_gemm_dgrad = (
    ["fc2_dgrad"] if bool(int(os.getenv("NVTE_DISABLE_FC2_DGRAD_OVERLAP", "0"))) else []
)
97

98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
def get_dummy_wgrad(shape: list, dtype: torch.dtype, zero=False) -> torch.Tensor:
    """Return a cached placeholder weight-gradient tensor.

    One tensor per ``(shape[0], shape[1], dtype)`` is allocated on the ``"cuda"``
    device and reused by later calls with the same key.

    Parameters
    ----------
    shape : list
        Two-element shape of the requested tensor.
    dtype : torch.dtype
        Data type of the requested tensor.
    zero : bool, default = False
        If set, the cached tensor is filled with zeros in place before it is
        returned; otherwise its contents are whatever the buffer currently holds.

    Returns
    -------
    torch.Tensor
        A detached view of the cached tensor (it shares storage with the cache).
    """
    global _dummy_wgrads
    assert len(shape) == 2
    cache_key = (shape[0], shape[1], dtype)
    if cache_key not in _dummy_wgrads:
        _dummy_wgrads[cache_key] = torch.empty(
            shape,
            dtype=dtype,
            device="cuda",
            requires_grad=False,
        )
    cached = _dummy_wgrads[cache_key]
    if zero:
        cached.fill_(0)
    return cached.detach()

yuguo's avatar
yuguo committed
113
# Default "num_sm" that initialize_ub() assigns to overlaps whose method is not
# "ring_exchange"; overridable through the NVTE_UB_COMM_CU_NUMS environment variable.
ub_comm_cu_nums = int(os.environ.get("NVTE_UB_COMM_CU_NUMS", "8"))
114
115
def initialize_ub(
    shape: list,
    tp_size: int,
    use_fp8: bool = False,
    quantization_modes: Optional[List[UserBufferQuantizationMode]] = None,
    dtype: torch.dtype = torch.bfloat16,
    ub_cfgs: Optional[Union[dict, List[dict]]] = None,
    bootstrap_backend: Optional[Union[str, torch.distributed.Backend]] = None,
) -> None:
    r"""
    Initialize the Userbuffers communicator for overlapping tensor-parallel communications with
    GEMM compute in ``te.Linear``, ``te.LayerNormLinear`` and ``te.LayerNormMLP`` modules.

    Parameters
    ----------
    shape : list
            shape of the communication buffer, typically set to be the same as the global shape of
            the input tensor to a ``te.TransformerLayer`` forward pass, with the sequence and batch
            dimensions collapsed together -- i.e.: ``(sequence_length * batch_size, hidden_size)``
    tp_size : int
              number of GPUs in the tensor-parallel process group
    use_fp8 : bool = False
              allocate the communication buffer for FP8 GEMM inputs/outputs.
              DEPRECATED: Please use ``quantization_modes`` instead.
    quantization_modes : List[UserBufferQuantizationMode] = None
              if a list of UserBufferQuantizationMode is provided, a UB communicator is created
              for each quantization setting in the list.
              falls back to the legacy ``use_fp8`` parameter if ``None`` is provided.
    dtype : torch.dtype = torch.bfloat16
            non-FP8 data type of the communication buffer when ``use_fp8 = False``
    ub_cfgs : dict = None
             Configuration dictionary with the structure::

                 {
                    <gemm_name> : {
                        "method": <"ring_exchange", "pipeline", "bulk" or "external">,
                        "is_reduce_scatter": bool,
                        "num_sm": int,
                        "cga_size": int,
                        "set_sm_margin": bool,
                        "num_splits": int,
                        "aggregate": bool,
                        "atomic_gemm": bool,
                        "use_ce": bool,
                        "fp8_buf": bool,
                        "comm_priority": int,
                        "gemm_priority": int,
                        "pipeline_rs_overlap_first_gemm": bool,
                    }
                 }

             for ``te.TransformerLayer`` GEMM layers in ``["qkv_fprop", "qkv_dgrad", "qkv_wgrad",
             "proj_fprop", "proj_dgrad", "proj_wgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad",
             "fc2_fprop", "fc2_wgrad"]``.
             a list may be provided to specify different overlap configurations for the
             different quantization settings in ``quantization_modes``
    bootstrap_backend : str = None
                        ``torch.distributed`` communication backend for the all-gather, broadcast and
                        barrier collectives during Userbuffers initialization. Not all backends are
                        valid for every cluster configuration and distributed launch method even if
                        they are available in PyTorch. When left unset, the initialization prefers
                        to use the MPI backend, falling back first on Gloo and then NCCL if MPI is
                        not available. Setting ``NVTE_UB_WITH_MPI=1`` when building TE overrides this
                        option and always initializes Userbuffers with direct MPI calls in C++,
                        which also requires ``MPI_HOME=/path/to/mpi/root`` to be set at compile time.
    """
    global _ub_communicators
    global layers_atomic_ring_exchange

    if not tex.device_supports_multicast():
        assert bool(int(os.getenv("UB_SKIPMC", "1"))), (
            "CUDA device, driver and/or toolkit version does not support comm+GEMM overlap with "
            + "CUDA Multicast. Launch app with UB_SKIPMC=1 to try CUDA IPC instead."
        )

    if not quantization_modes:
        warnings.warn(
            "Initializing Userbuffers with use_fp8 is deprecated. Please use quantization_modes"
            " instead.",
            DeprecationWarning,
        )
        quantization_modes = [
            UserBufferQuantizationMode.FP8 if use_fp8 else UserBufferQuantizationMode.NONE
        ]
    else:
        assert isinstance(quantization_modes, list), "quantization_modes must be a list"
        assert all(
            isinstance(mode, UserBufferQuantizationMode) for mode in quantization_modes
        ), "quantization_modes must be a list of UserBufferQuantizationMode"

    # A single dict (or None) applies to every quantization mode.
    if isinstance(ub_cfgs, dict) or ub_cfgs is None:
        ub_cfgs = [ub_cfgs] * len(quantization_modes)
    else:
        assert len(ub_cfgs) == len(
            quantization_modes
        ), "Number of ub_cfgs settings must match number of quantization configurations"

    assert _ub_communicators is None, "UB communicators are already initialized."
    _ub_communicators = {}

    if tex.ubuf_built_with_mpi():
        # We're bootstrapping with direct calls to MPI in Userbuffers code so we need to force
        # an MPI_Init() here by creating a new MPI process group...
        assert torch.distributed.is_mpi_available()
        _ = torch.distributed.new_group(backend="mpi")
        helper = tex.CommOverlapHelper()
    else:
        # Bootstrapping with torch.distributed API, so check backend and construct
        # intra/inter-node process groups...
        assert (
            torch.distributed.is_initialized()
        ), "torch.distributed must be initialized before Userbuffers"
        if bootstrap_backend is None:
            bootstrap_backend = "nccl"
            if torch.distributed.is_mpi_available():
                bootstrap_backend = "mpi"
            elif torch.distributed.is_gloo_available():
                bootstrap_backend = "gloo"
        else:
            assert bootstrap_backend in [
                "gloo",
                "mpi",
                "nccl",
            ], "Invalid torch.distributed backend for bootstrapping Userbuffers!"
            assert torch.distributed.is_backend_available(bootstrap_backend), (
                f"PyTorch must be compiled with '{bootstrap_backend}' support in order to "
                f"bootstrap Userbuffers with '{bootstrap_backend}' collectives."
            )

        world_group = torch.distributed.new_group(backend=bootstrap_backend)
        world_rank = torch.distributed.get_rank(world_group)
        world_size = torch.distributed.get_world_size(world_group)

        num_domains = world_size // tp_size
        mydomain_idx = world_rank // tp_size
        if num_domains > 1:
            ranks_per_domain_list = [
                [i * tp_size + t for t in range(tp_size)] for i in range(num_domains)
            ]
            tp_domain_group, _ = torch.distributed.new_subgroups_by_enumeration(
                ranks_per_domain_list, backend=bootstrap_backend
            )
            local_rank = torch.distributed.get_rank(tp_domain_group)
            tp_domain_ranks = torch.distributed.get_process_group_ranks(tp_domain_group)

            helper = tex.CommOverlapHelper(world_group, tp_domain_group)
        else:
            # TP model on single NVLink domain, no replication, no data-parallelism
            mydomain_idx = 0
            local_rank = world_rank
            tp_domain_ranks = list(range(world_size))

            helper = tex.CommOverlapHelper(world_group)

        if world_rank == 0:
            print(f"!!! [UB] Number of TP domains: {num_domains}\n", end="", flush=True)
        if local_rank == 0:
            print(
                f"!!! [UB] Global ranks on TP domain {mydomain_idx}: {tp_domain_ranks}\n",
                end="",
                flush=True,
            )

    # Default buffer precision: AllGather buffers use fp8 when using fp8 recipe
    default_all_gather_layers = (
        "qkv_fprop",
        "qkv_dgrad",
        "proj_dgrad",
        "proj_wgrad",
        "fc1_fprop",
        "fc1_dgrad",
        "fc2_dgrad",
        "fc2_wgrad",
    )
    default_reduce_scatter_layers = ("proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad")
    dgrad_reduce_scatter_overlap = ["qkv_dgrad", "fc1_dgrad"]

    # Default overlap methods for layers. The proj/fc2 forward GEMMs use the "pipeline"
    # method unless the matching NVTE_*_NO_PIPELINE_OVERLAP variable moves them to
    # "ring_exchange".
    proj_no_pipeline = bool(int(os.getenv("NVTE_PROJ_NO_PIPELINE_OVERLAP", "0")))
    fc2_no_pipeline = bool(int(os.getenv("NVTE_FC2_NO_PIPELINE_OVERLAP", "0")))
    default_methods = {
        "ring_exchange": ["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad"],
        "pipeline": [],
        "bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"],
        "external": ["proj_wgrad", "fc2_wgrad"],
    }
    default_methods["ring_exchange" if proj_no_pipeline else "pipeline"].append("proj_fprop")
    default_methods["ring_exchange" if fc2_no_pipeline else "pipeline"].append("fc2_fprop")

    # Working copies read by the nested helpers below. They are rebuilt from the defaults
    # for every quantization mode so that the user overrides of one mode neither leak into
    # the next one nor get applied twice.
    layers_all_gather_overlap = list(default_all_gather_layers)
    layers_reduce_scatter_overlap = list(default_reduce_scatter_layers)
    methods = {method: list(names) for method, names in default_methods.items()}

    # AG-RS overlap pairs of layers forming a tensor-parallel block
    ag_rs_pairs = {"qkv_fprop": "proj_fprop", "fc1_fprop": "fc2_fprop"}
    rs_ag_pairs = {v: k for k, v in ag_rs_pairs.items()}
    external_gemm_to_overlap = {"proj_wgrad": "proj_dgrad", "fc2_wgrad": "fc2_dgrad"}

    layers_atomic_ring_exchange = []

    def get_method(name):
        """Return the overlap method currently assigned to layer `name`."""
        for method, names in methods.items():
            if name in names:
                return method
        raise KeyError(f"Given layer name {name} does not exist.")

    def get_default_config(name):
        """Build the default `add_ub` keyword arguments for layer `name`."""
        global _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY
        method = get_method(name)
        is_reduce_scatter = name in layers_reduce_scatter_overlap
        if _MIN_STREAM_PRIORITY is None or _MAX_STREAM_PRIORITY is None:
            _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = tex.get_stream_priority_range()
        default_cfg = {
            "method": method,
            "is_reduce_scatter": is_reduce_scatter,
            "num_sm": 1 if method == "ring_exchange" else ub_comm_cu_nums,
            "cga_size": 1 if method == "ring_exchange" else 2,
            "set_sm_margin": not method == "ring_exchange",
            "num_splits": tp_size if method == "ring_exchange" else 4,
            "aggregate": bool(int(os.getenv("NVTE_TP_OVERLAP_AGGREGATE", "0"))),
            "atomic_gemm": False,
            "use_ce": True,
            "fp8_buf": name in layers_all_gather_overlap,
            "comm_priority": _MAX_STREAM_PRIORITY,
            "gemm_priority": _MIN_STREAM_PRIORITY,
            "pipeline_rs_overlap_first_gemm": False,
        }
        return default_cfg

    def add_ub(
        name: str,
        quantization_mode: UserBufferQuantizationMode,
        method: str,
        is_reduce_scatter: bool,
        num_sm: int = 16,
        cga_size: int = 2,
        set_sm_margin: bool = False,
        num_splits: int = 0,
        aggregate: bool = False,
        atomic_gemm: bool = False,
        use_ce: bool = True,
        fp8_buf: bool = False,
        comm_priority: int = 0,
        gemm_priority: int = 0,
        pipeline_rs_overlap_first_gemm: bool = False,
    ) -> None:
        """Validate one layer's overlap config and register its communicator."""
        if atomic_gemm:
            warnings.warn(
                "Atomic GEMM uses a beta API from cublas and is not tested for all use cases."
            )
            assert (
                quantization_mode == UserBufferQuantizationMode.FP8
            ), "Atomic GEMM overlap supported only for FP8 GEMM."
            if method in ("bulk", "external"):
                warnings.warn(
                    f"At {name}, atomic GEMM is not supported for a bulk overlap. "
                    "Defaulting to `atomic_gemm=False`."
                )
                atomic_gemm = False
        if not is_reduce_scatter and method == "pipeline":
            raise ValueError(
                f"At {name}, `pipeline` overlap method is not supported for AllGather."
            )
        # Check if both AG and RS overlaps use `atomic GEMM` + `p2p ring-exchange`.
        # Using atomic GEMM + p2p ring-exchange in only one of the pair breaks functionality.
        global layers_atomic_ring_exchange
        if atomic_gemm and method == "ring_exchange" and name in ag_rs_pairs:
            layers_atomic_ring_exchange += [name, ag_rs_pairs[name]]
        if name in rs_ag_pairs:
            assert_message = (
                f"At {name}, atomic AG-GEMM overlap with `ring_exchange` shuffles GEMM chunk "
                "outputs, and RS-GEMM overlap un-shuffles them. When one of the GEMM-AG and "
                "GEMM-RS overlaps forming a TP block (e.g., qkv_fprop and proj_fprop) uses "
                "`atomic gemm` and `ring_exchange`, its pair must use the same overlap config "
                "for functionality."
            )
            if name in layers_atomic_ring_exchange:
                assert atomic_gemm and method == "ring_exchange", assert_message
            else:
                if atomic_gemm and method == "ring_exchange":
                    assert rs_ag_pairs[name] in layers_atomic_ring_exchange, assert_message

        if name in external_gemm_to_overlap:
            assert method == "external", (
                f"At {name}, `external` overlap method is specified, but the selected method is"
                f" {method}"
            )
            assert external_gemm_to_overlap[name] in methods["ring_exchange"], (
                f"At {name}, `external` overlap method is specified, but the external gemm"
                f" {external_gemm_to_overlap[name]} is not using `ring_exchange` overlap method"
            )

        # FP8 payloads are stored as raw bytes in the communication buffer.
        buffer_dtype = (
            torch.uint8
            if (quantization_mode == UserBufferQuantizationMode.FP8 and fp8_buf)
            else dtype
        )
        if method == "ring_exchange":
            ub_obj = tex.CommOverlapP2P(
                shape,  # Communication buffer shape
                buffer_dtype,  # Communication buffer data type
                helper,  # Helper for torch.distributed callbacks during bootstrapping
                tp_size,  # Tensor-parallel group size (may be different than local_size)
                tex.CommOverlapType.RS if is_reduce_scatter else tex.CommOverlapType.AG,
                num_max_streams=_NUM_MAX_UB_STREAMS,
                comm_cga_size=cga_size,
                num_comm_sm=num_sm,
                set_sm_margin=set_sm_margin,
                atomic_gemm=atomic_gemm,
                use_ce=use_ce,
                aggregate=aggregate,
                gemm_priority=gemm_priority,
                comm_priority=comm_priority,
            )
        else:
            ub_obj = tex.CommOverlap(
                shape,  # Communication buffer shape
                buffer_dtype,  # Communication buffer data type
                helper,  # Helper for torch.distributed callbacks during bootstrapping
                tp_size,  # Tensor-parallel group size (may be different than local_size)
                num_splits=num_splits,
                num_max_streams=_NUM_MAX_UB_STREAMS,
                comm_cga_size=cga_size,
                num_comm_sm=num_sm,
                set_sm_margin=set_sm_margin,
                atomic_gemm=atomic_gemm,
                gemm_priority=gemm_priority,
                comm_priority=comm_priority,
                rs_overlap_first_gemm=pipeline_rs_overlap_first_gemm,
            )
        _ub_communicators[(name, quantization_mode)] = ub_obj

    for quantization_mode, user_ub_cfg in zip(quantization_modes, ub_cfgs):
        # Start every quantization mode from the pristine defaults. Mutating one shared set
        # of lists made the second mode fail in `list.remove` when the same user config was
        # replicated across modes.
        layers_all_gather_overlap = list(default_all_gather_layers)
        layers_reduce_scatter_overlap = list(default_reduce_scatter_layers)
        methods = {method: list(names) for method, names in default_methods.items()}

        if user_ub_cfg is not None:
            # A dgrad GEMM that the user moves off the default "bulk" method becomes a
            # reduce-scatter overlap in place of its wgrad GEMM.
            for name in dgrad_reduce_scatter_overlap:
                if (
                    name in user_ub_cfg
                    and "method" in user_ub_cfg[name]
                    and user_ub_cfg[name]["method"] != "bulk"
                ):
                    wgrad_name = name.replace("dgrad", "wgrad")
                    assert wgrad_name not in user_ub_cfg, (
                        f"{wgrad_name} cannot be configured when {name} uses a non-bulk "
                        "overlap method"
                    )
                    layers_reduce_scatter_overlap.remove(wgrad_name)
                    layers_all_gather_overlap.remove(name)
                    layers_reduce_scatter_overlap.append(name)
                    methods["bulk"].remove(name)
                    new_method = user_ub_cfg[name]["method"]
                    methods[new_method].append(name)

        for name in (
            methods["ring_exchange"] + methods["pipeline"] + methods["bulk"] + methods["external"]
        ):
            # No communicator is created for layers whose overlap is disabled.
            if name in remove_ag_gemm_dgrad:
                continue
            ub_cfg = get_default_config(name)
            if user_ub_cfg is not None and name in user_ub_cfg:
                # The user's "fp8_buf" is honoured only for pipeline overlaps; all-gather
                # layers always get an FP8 buffer.
                fp8_buf = (name in layers_all_gather_overlap) or (
                    user_ub_cfg[name].get("fp8_buf", False) and name in methods["pipeline"]
                )
                ub_cfg.update(user_ub_cfg[name])
                ub_cfg["fp8_buf"] = fp8_buf
            add_ub(name, quantization_mode, **ub_cfg)
482
483


484
def get_ub(name: str, use_fp8: bool):
    """Get the Userbuffers communicator registered for the given key.

    Parameters
    ----------
    name : str
        Layer name the communicator was registered under (e.g. ``"qkv_fprop"``).
    use_fp8 : bool
        Selects the ``FP8`` communicator when set, the ``NONE`` one otherwise.

    Returns
    -------
    The communicator object, or ``None`` when `name` is listed in
    ``remove_ag_gemm_dgrad`` (no communicator is created for those layers).

    Raises
    ------
    AssertionError
        If Userbuffers has not been initialized, or no communicator is
        registered for the requested key.
    """
    assert _ub_communicators is not None, "UB manager is not initialized."
    # initialize_ub() never registers these layers, so this must be answered before
    # the registration check below; otherwise that check always fails for them.
    if name in remove_ag_gemm_dgrad:
        return None
    # For now use `use_fp8` boolean input as it matches the current design in the modules
    # So favour simplicity until the correct design becomes clear.
    # This is mainly an internal API so we don't need to worry about future changes
    key = (name, UserBufferQuantizationMode.FP8 if use_fp8 else UserBufferQuantizationMode.NONE)
    assert key in _ub_communicators, f"UB for {name} with use_fp8={use_fp8} is not registered."
    return _ub_communicators[key]
496

497

498
499
500
501
502
503
504
def destroy_ub():
    """Drop every registered Userbuffers communicator.

    Resets the module-level registry to ``None`` (the "not initialized" state)
    and clears the list of atomic ring-exchange layers.
    """
    global _ub_communicators, layers_atomic_ring_exchange
    layers_atomic_ring_exchange = []
    _ub_communicators = None

505

506
507
508
509
510
def fill_userbuffers_buffer_for_all_gather(
    comm,
    local_tensor: torch.Tensor,
    quantizer: Optional[Quantizer],
    process_group,
) -> tuple[torch.Tensor | QuantizedTensorStorage, torch.Tensor | QuantizedTensorStorage]:
    """Fill local shard of Userbuffers buffer with data for all-gather

    Returns the full tensor and the local shard, both using the
    Userbuffers buffer as their underlying data. These tensors should
    be used carefully (e.g. only immediately before and after a
    Userbuffers operation) since the underlying data may be
    overwritten by other Userbuffers operations.

    May perform blocking communication if needed for the gathered
    tensor's metadata, e.g. scaling factors.

    Parameters
    ----------
    comm
        Userbuffers communicator; must provide ``is_fp8_ubuf``,
        ``copy_into_buffer`` and ``get_buffer``.
    local_tensor : torch.Tensor
        Local shard. A quantized tensor in a format that does not match
        `quantizer` is dequantized and, if needed, re-quantized.
    quantizer : Quantizer, optional
        Target quantization. ``None`` communicates unquantized data.
    process_group
        Process group the first dimension is gathered over.

    Returns
    -------
    tuple
        ``(global_tensor, local_tensor)``.

    Raises
    ------
    ValueError
        For a zero-dim local tensor, an unsupported quantizer, or an
        MXFP8 tensor/quantizer configuration Userbuffers cannot handle.
    RuntimeError
        If the data format does not match the buffer's FP8 setting.

    """

    # Tensor dimensions
    local_shape = local_tensor.size()
    if not local_shape:
        raise ValueError(f"Invalid local tensor (shape={tuple(local_shape)})")
    process_group_size = torch.distributed.get_world_size(process_group)
    global_shape = list(local_shape)
    global_shape[0] *= process_group_size

    # Unquantized data
    if quantizer is None:
        if isinstance(local_tensor, QuantizedTensorStorage):
            local_tensor = local_tensor.dequantize()
        if comm.is_fp8_ubuf():
            raise RuntimeError(
                "Attempting to all-gather unquantized tensor, "
                "but Userbuffers is initialized with FP8 buffers"
            )
        comm.copy_into_buffer(local_tensor, local_chunk=True)
        global_tensor = comm.get_buffer(shape=global_shape)
        return global_tensor, local_tensor

    # FP8 data
    if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
        if not isinstance(local_tensor, Float8TensorStorage):
            if isinstance(local_tensor, QuantizedTensorStorage):
                # Keep the dequantized result: the quantizer below must receive
                # high-precision data, not a tensor in another quantized format.
                local_tensor = local_tensor.dequantize()
            quantizer.set_usage(rowwise=True, columnwise=False)
            local_tensor = quantizer(local_tensor)
        if not comm.is_fp8_ubuf():
            raise RuntimeError(
                "Attempting to all-gather FP8 tensor, "
                "but Userbuffers is not initialized with FP8 buffers"
            )
        comm.copy_into_buffer(local_tensor._data, local_chunk=True)
        global_tensor_data = comm.get_buffer(shape=global_shape)
        # The gathered tensor reuses the local shard's scale-inverse and FP8 dtype.
        global_tensor = Float8TensorStorage(
            data=global_tensor_data,
            fp8_scale_inv=local_tensor._scale_inv,
            fp8_dtype=local_tensor._fp8_dtype,
            quantizer=quantizer,
        )
        return global_tensor, local_tensor

    # MXFP8 data
    if isinstance(quantizer, MXFP8Quantizer):

        # Cast to MXFP8 if needed
        if not isinstance(local_tensor, MXFP8TensorStorage):
            if isinstance(local_tensor, QuantizedTensorStorage):
                # Keep the dequantized result (see the FP8 branch above).
                local_tensor = local_tensor.dequantize()
            local_tensor = quantizer(local_tensor)
        if not comm.is_fp8_ubuf():
            raise RuntimeError(
                "Attempting to all-gather MXFP8 tensor, "
                "but Userbuffers is not initialized with FP8 buffers"
            )

        # Check which MXFP8 buffer to communicate
        if quantizer.rowwise_usage == quantizer.columnwise_usage:
            raise ValueError(
                "Userbuffers can only communicate one MXFP8 buffer at a time, "
                f"but quantizer has rowwise_usage={quantizer.rowwise_usage}, "
                f"columnwise_usage={quantizer.columnwise_usage}"
            )
        with_rowwise_data = quantizer.rowwise_usage

        # Copy MXFP8 data to local chunk of Userbuffers buffer
        local_data = (
            local_tensor._rowwise_data if with_rowwise_data else local_tensor._columnwise_data
        )
        comm.copy_into_buffer(local_data, local_chunk=True)

        # Gather scaling-inverses
        if math.prod(local_shape[:-1]) % 128 != 0:
            raise ValueError(
                "Userbuffers requires MXFP8 tensor dims that are divisible by 128, "
                f"but got MXFP8 tensor with shape={tuple(local_shape)}"
            )
        if local_tensor._with_gemm_swizzled_scales:
            raise ValueError("Userbuffers assumes MXFP8 tensors have unswizzled scales")
        local_scale_inv = (
            local_tensor._rowwise_scale_inv
            if with_rowwise_data
            else local_tensor._columnwise_scale_inv
        )
        local_scale_inv_size = list(local_scale_inv.size())
        global_scale_inv = torch.empty(
            [process_group_size * local_scale_inv_size[0]] + local_scale_inv_size[1:],
            dtype=local_scale_inv.dtype,
            device=local_scale_inv.device,
        )
        # The scales do not live in the Userbuffers buffer, so they are gathered
        # with a regular collective.
        torch.distributed.all_gather_into_tensor(
            global_scale_inv,
            local_scale_inv,
            group=process_group,
        )

        # Construct MXFP8 tensor with Userbuffers buffer
        rowwise_data, rowwise_scale_inv = None, None
        columnwise_data, columnwise_scale_inv = None, None
        global_data = comm.get_buffer(shape=global_shape)
        if with_rowwise_data:
            rowwise_data, rowwise_scale_inv = global_data, global_scale_inv
        else:
            columnwise_data, columnwise_scale_inv = global_data, global_scale_inv
        global_tensor = MXFP8TensorStorage(
            rowwise_data=rowwise_data,
            rowwise_scale_inv=rowwise_scale_inv,
            columnwise_data=columnwise_data,
            columnwise_scale_inv=columnwise_scale_inv,
            fp8_dtype=local_tensor._fp8_dtype,
            quantizer=quantizer,
            with_gemm_swizzled_scales=False,
        )
        return global_tensor, local_tensor

    # Unsupported data format
    raise ValueError(f"Unsupported quantizer for Userbuffers ({quantizer})")


645
646
647
class TransformerEngineBaseModule(torch.nn.Module, ABC):
    """Base TE module."""

648
    def __init__(self, name: Optional[str] = None) -> None:
        """Set up the default (non-FP8, non-parallel) state shared by all TE modules.

        Parameters
        ----------
        name : str, optional
            Module name, stored on ``self.name`` and checked by ``_validate_name``.
        """
        super().__init__()
        assert torch.cuda.is_available(), "TransformerEngine needs CUDA."

        # Naming / debug bookkeeping
        self.name = name
        self.next_iter_when_debug_should_be_run = 0

        # FP8 state: everything starts disabled and is filled in lazily
        # (see `init_fp8_metadata`).
        self.fp8_initialized = False
        self.fp8 = False
        self.fp8_calibration = False
        self.fp8_meta = {"fp8_checkpoint": False, "fp8_group": None}
        self.fp8_meta_tensors_initialized = False
        self.quantizers = {"scaling_fwd": {}, "scaling_bwd": {}}

        # Tensor / sequence parallel defaults
        self.tp_group = None
        self.tp_size = 1
        self.sequence_parallel = False

        # Parameter initialization metadata (filled by `register_parameter`)
        self.param_init_meta = {}
        self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
        self.preserve_high_precision_init_val = FP8GlobalStateManager.with_high_precision_init_val()

        # FSDP defaults
        self.fsdp_wrapped = False
        self.fsdp_group = None

        # Caches and hooks
        self._fp8_workspaces: Dict[str, QuantizedTensor] = {}
        self.activation_dtype: Optional[torch.dtype] = None
        self.wgrad_accumulation_and_reduce_hooks = []
        self.wgrad_store = None

        if not TEDebugState.debug_enabled:
            TEDebugState.initialize()
        self._validate_name()
677

678
679
680
681
682
683
684
685
686
687
688
689
690
691
    def fast_setattr(self, name: str, value: Any) -> None:
        """
        Fast version of the Module's set attribute function.

        Stores `value` directly in the instance ``__dict__``, so the class's
        ``__setattr__`` is not invoked. Should be used for regular attributes,
        but not properties nor parameters/buffers.

        Parameters
        ----------
        name : str
            Name of the attribute to set.
        value : Any
            Value to store under `name`.
        """
        self.__dict__[name] = value

    def module_setattr(self, name: str, value: Any) -> None:
        """
        Regular version of the Module's set attribute function.

        Delegates to the parent class ``__setattr__``. Should be used only when
        the fast version cannot be used - for the properties, parameters and
        buffers.

        Parameters
        ----------
        name : str
            Name of the attribute to set.
        value : Any
            Value to assign.
        """
        super().__setattr__(name, value)
692

693
    def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None:
        """
        Delayed scaling only.

        Increase or decrease size of amax history based on given `length`.

        Parameters
        ----------
        length : int
            Desired size of the first dimension of the amax history.
        fwd : bool, optional
            ``True`` adjusts only ``"scaling_fwd"``, ``False`` only
            ``"scaling_bwd"``; ``None`` (default) adjusts both.

        .. warning::
            This changes the underlying amax memory location.
        """
        if fwd is None:
            fp8_meta_tensor_keys = ("scaling_fwd", "scaling_bwd")
        else:
            fp8_meta_tensor_keys = ("scaling_fwd" if fwd else "scaling_bwd",)

        for meta_key in fp8_meta_tensor_keys:
            if meta_key not in self.fp8_meta:
                # Handles non-parameter FP8 modules, e.g. DPA.
                continue

            curr_len = self.fp8_meta[meta_key].amax_history.shape[0]
            if length == curr_len:
                continue
            if length < curr_len:
                # Truncate; clone so the history no longer aliases the old storage.
                self.fp8_meta[meta_key].amax_history = (
                    self.fp8_meta[meta_key].amax_history[:length].clone()
                )
            else:
                # Grow by appending zero rows at the end of the first dimension.
                extra_rows = length - curr_len
                self.fp8_meta[meta_key].amax_history = F.pad(
                    self.fp8_meta[meta_key].amax_history, pad=(0, 0, 0, extra_rows)
                )

            # Update quantizers with new amax pointers.
            self.quantizers[meta_key] = self.fp8_meta[meta_key].make_quantizers()

            # Make sure weight tensors has correct quantizers
            self._update_weight_quantizers()

            # Update the global buffers with new amax and history pointers.
            if FP8GlobalStateManager.get_buffer_info() in self.fp8_meta:
                fwd_pos, fwd_key, bwd_pos, bwd_key = self.fp8_meta[
                    FP8GlobalStateManager.get_buffer_info()
                ]
                # NOTE(review): both the fwd and bwd buffer slots are written with
                # the history of the current `meta_key`; confirm this is intended
                # rather than updating only the slot that matches `meta_key`.
                for pos, buffer_key in zip((fwd_pos, bwd_pos), (fwd_key, bwd_key)):
                    if buffer_key in FP8GlobalStateManager.global_amax_buffer:
                        assert (
                            buffer_key in FP8GlobalStateManager.global_amax_history_buffer
                        ), "TE internal error during amax history change."
                        FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = self.fp8_meta[
                            meta_key
                        ].amax_history[0]
                        FP8GlobalStateManager.global_amax_history_buffer[buffer_key][pos] = (
                            self.fp8_meta[meta_key].amax_history
                        )
745

746
    def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
        """Init scales and amaxes for fwd | bwd.

        Parameters
        ----------
        fwd : bool
            ``True`` to set up ``"scaling_fwd"``, ``False`` for ``"scaling_bwd"``.
        recipe : Recipe
            Quantization recipe used to create the recipe state.
        """
        fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd"

        # Return early if recipe state matches recipe
        if self.fp8_meta_tensors_initialized:
            recipe_state = self.fp8_meta[fp8_meta_tensor_key]
            if recipe.delayed() and isinstance(recipe_state, DelayedScalingRecipeState):
                # Delayed scaling keeps its state, but the history length may differ.
                self.adjust_amax_history_length(recipe.amax_history_len, fwd=fwd)
                return
            if recipe.mxfp8() and isinstance(recipe_state, MXFP8BlockScalingRecipeState):
                return
            if recipe.float8_current_scaling() and isinstance(
                recipe_state, Float8CurrentScalingRecipeState
            ):
                return
            if recipe.float8_block_scaling() and isinstance(
                recipe_state, Float8BlockScalingRecipeState
            ):
                return
            if recipe.nvfp4() and isinstance(recipe_state, NVFP4BlockScalingRecipeState):
                return

        # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and
        # 2 (grad_output and grad_input) for bwd
        num_fp8_tensors = self.fp8_meta["num_gemms"] * (3 if fwd else 2)

        # Initialize recipe state and quantizers
        recipe_state = RecipeState.create(
            recipe,
            mode=("forward" if fwd else "backward"),
            num_quantizers=num_fp8_tensors,
        )

        self.fp8_meta[fp8_meta_tensor_key] = recipe_state
        self.quantizers[fp8_meta_tensor_key] = recipe_state.make_quantizers()

783
784
785
786
787
788
789
790
791
    def _update_weight_quantizers(self) -> None:
        """Update the quantizers for the weight tensors."""
        weight_tensors = self._get_weight_tensors()
        weight_quantizers = self._get_weight_quantizers()
        assert len(weight_tensors) == len(weight_quantizers), (
            f"Number of weight tensors ({len(weight_tensors)}) and quantizers "
            f"({len(weight_quantizers)}) must match"
        )
        for weight, quantizer in zip(weight_tensors, weight_quantizers):
792
            if quantizer is not None and isinstance(weight, QuantizedTensorStorage):
793
794
                weight.update_quantizer(quantizer)

795
    def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]:
        """Get the weight tensors of the module.

        This base implementation always raises ``NotImplementedError``.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} class does not implement _get_weight_tensors function"
        )

    def _get_weight_quantizers(self) -> List[Quantizer]:
        """Get the weight quantizers of the module.

        This base implementation always raises ``NotImplementedError``.
        """
        owner = self.__class__.__name__
        raise NotImplementedError(
            f"{owner} class does not implement _get_weight_quantizers function"
        )

807
    def init_fp8_meta_tensors(self, recipe: Recipe) -> None:
        """Init scales and amaxes.

        Sets up the forward and then the backward meta tensors for `recipe`, and
        marks them as initialized.
        """
        self.set_meta_tensor(True, recipe)
        self.set_meta_tensor(False, recipe)

        self.fast_setattr("fp8_meta_tensors_initialized", True)
813

814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
    def get_fp8_meta_tensors(self) -> Optional[Dict[str, List[torch.Tensor]]]:
        """Get scales and amaxes.

        Returns
        -------
        dict or None
            ``None`` if ``"scaling_fwd"`` or ``"scaling_bwd"`` is missing from
            ``self.fp8_meta``. Otherwise a dict with those two keys, each mapping
            to a ``[scale, amax_history]`` list of cloned tensors.
        """
        fwd_key, bwd_key = "scaling_fwd", "scaling_bwd"
        if fwd_key not in self.fp8_meta or bwd_key not in self.fp8_meta:
            return None

        with torch.no_grad():
            return {
                key: [
                    self.fp8_meta[key].scale.clone(),
                    self.fp8_meta[key].amax_history.clone(),
                ]
                for key in (fwd_key, bwd_key)
            }

    def reset_fp8_meta_tensors(self, fp8_meta_tensors=None) -> None:
        """Reset scales and amaxes.

        Parameters
        ----------
        fp8_meta_tensors : dict, optional
            Mapping with ``"scaling_fwd"`` / ``"scaling_bwd"`` keys whose values
            are ``[scale, amax_history]`` (the layout produced by
            ``get_fp8_meta_tensors``). If ``None``, scales are reset to one and
            amax histories to zero.
        """

        def reset(key):
            if key not in self.fp8_meta:
                return
            recipe_state = self.fp8_meta[key]
            if fp8_meta_tensors is None:
                # In-place fills avoid allocating temporary tensors.
                recipe_state.scale.fill_(1)
                recipe_state.amax_history.zero_()
            else:
                assert key in fp8_meta_tensors, "Cannot reset fp8 tensors."
                recipe_state.scale.copy_(fp8_meta_tensors[key][0])
                recipe_state.amax_history.copy_(fp8_meta_tensors[key][1])

        with torch.no_grad():
            reset("scaling_fwd")
            reset("scaling_bwd")

846
    def get_extra_state(self) -> torch.Tensor:
        """Save before checkpointing.

        Returns
        -------
        torch.Tensor
            An empty ``uint8`` tensor when no FP8 state needs to be stored,
            otherwise a ``uint8`` tensor holding the pickled state dict.
        """

        # This implementation is working around a few issues:
        #
        # (1) PyTorch's "extra state" infrastructure might be able to
        #     support any picklable type, but they make no guarantees.
        #     We have experienced problems (e.g. in ONNX export) with
        #     non-tensor extra state.
        # (2) PyTorch's checkpointing infrastructure does not remap
        #     devices for "extra state" like it does for "state dict".
        #     Thus, we want to avoid putting extra state on the GPU
        #     since it may be loaded on the wrong device.
        # (3) The extra state consists of many small tensors. If we
        #     want to copy them all to CPU, then we need to avoid the
        #     overhead of many GPU-CPU memory transfers.
        #
        # See: https://github.com/NVIDIA/TransformerEngine/pull/351
        # See: https://github.com/NVIDIA/TransformerEngine/pull/363

        def to_cpu(src: torch.Tensor) -> torch.Tensor:
            """Helper function to make CPU copy of tensor

            Memory transfer is asynchronous w.r.t. host, so GPU should
            be synchronized before using result.

            """
            dst = torch.empty_like(src, device="cpu")
            dst.copy_(src, non_blocking=True)
            return dst

        # Store FP8 state if needed
        fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration
        if not fp8_checkpoint:
            return torch.empty(0, dtype=torch.uint8)

        # Copy tensors to CPU and store
        state = {}
        state["recipe"] = self.fp8_meta["recipe"]
        if state["recipe"].delayed():
            state["scale_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].scale)
            state["amax_history_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].amax_history)
            state["scale_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].scale)
            state["amax_history_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].amax_history)

        # Store other picklable values
        extra = {}
        for k, v in self.fp8_meta.items():
            if k != "buffer_index_and_autocast_key" and isinstance(
                v, (bool, int, float, str, tuple, list)
            ):
                extra[k] = v
        state["extra_fp8_variables"] = extra

        # Serialize state into byte tensor.
        # The synchronize makes sure the non-blocking copies above have finished.
        torch.cuda.synchronize()
        state_serialized = bytearray(pickle.dumps(state))
        state_serialized = torch.frombuffer(state_serialized, dtype=torch.uint8)
        return state_serialized
906

907
    def set_extra_state(self, state: torch.Tensor) -> None:
        """Load previous state.

        Parameters
        ----------
        state : torch.Tensor or io.BytesIO or None
            ``None`` or an empty tensor means there is nothing to load. A
            non-empty tensor is treated as pickled bytes; ``io.BytesIO`` is the
            deprecated format loaded with ``torch.load``.

        Raises
        ------
        RuntimeError
            If `state` is of any other type.
        """

        # Maintain backwards compatibility with older checkpoints.
        if state is None:
            return

        # Load state
        # SECURITY NOTE: both branches below unpickle checkpoint content, which
        # can execute arbitrary code. Only load checkpoints from trusted sources.
        if isinstance(state, torch.Tensor):
            # No FP8 is indicated by an empty tensor we don't need to unpickle.
            if state.numel() == 0:
                return
            # Default format: byte tensor with pickled data
            state = pickle.loads(state.detach().cpu().numpy().tobytes())
        elif isinstance(state, io.BytesIO):
            # Deprecated format with io.BytesIO
            state.seek(0)
            state = torch.load(state, map_location="cuda")
        else:
            raise RuntimeError("Unsupported checkpoint format.")

        if state is None:
            return

        # TE 1.x checkpoint compatibility: add DelayedScaling recipe if missing
        if "recipe" not in state:
            # TE 1.x only supported delayed scaling, which was the default recipe
            state["recipe"] = DelayedScaling()
            # TE 1.x also saved scale_inv, which is not needed with Recipe object
            state.pop("scale_inv_fwd", None)
            state.pop("scale_inv_bwd", None)

        # Load extra items
        self.fp8_meta.update(state["extra_fp8_variables"])
        self.fp8_meta["recipe"] = state["recipe"]
        if "global_fp8_buffer_pos_fwd_recompute" in self.fp8_meta:
            del self.fp8_meta["global_fp8_buffer_pos_fwd_recompute"]

        # Initialize before loading
        self.init_fp8_meta_tensors(self.fp8_meta["recipe"])

        def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None:
            """Helper function to copy tensor from CPU

            Memory transfer is asynchronous w.r.t. host, so GPU should
            be synchronized before using result.

            """
            dst.copy_(src, non_blocking=True)

        # Load tensors
        if self.fp8_meta["recipe"].delayed():
            copy_tensor(state["scale_fwd"], self.fp8_meta["scaling_fwd"].scale)
            copy_tensor(state["amax_history_fwd"], self.fp8_meta["scaling_fwd"].amax_history)
            copy_tensor(state["scale_bwd"], self.fp8_meta["scaling_bwd"].scale)
            copy_tensor(state["amax_history_bwd"], self.fp8_meta["scaling_bwd"].amax_history)
        torch.cuda.synchronize()
964
965
966
967
968

    def set_activation_dtype(self, inp: torch.Tensor) -> None:
        """Set the activation data type for AMP.

        Stores the result in ``self.activation_dtype``. Under ``torch.autocast``
        the autocast GPU dtype is used; otherwise the dtype of `inp` is used.
        Unless ``self.allow_different_data_and_param_types`` is set, every
        parameter must then have the same dtype as `inp`.

        Parameters
        ----------
        inp : torch.Tensor
            Input tensor whose dtype is used outside of autocast regions.
        """
        # Native AMP (`torch.autocast`) gets highest priority
        if torch.is_autocast_enabled():
            self.fast_setattr("activation_dtype", torch_get_autocast_gpu_dtype())
            return

        # All checks after this have already been performed once, thus skip
        if self.activation_dtype == inp.dtype:
            return

        dtype = inp.dtype
        if not self.allow_different_data_and_param_types:
            for name, param in self.named_parameters():
                if param is not None:
                    assert dtype == param.dtype, (
                        "Data types for parameters must match when outside of autocasted region. "
                        f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}"
                    )
        self.fast_setattr("activation_dtype", dtype)
985
986

    def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
        """
        Set the tensor parallel group for the given
        module before executing the forward pass.

        Also marks the tensor parallel group as initialized.

        Parameters
        ----------
        tp_group : ProcessGroup or None
                  tensor parallel process group.
        """
        self.fast_setattr("tp_group", tp_group)
        self.fast_setattr("tp_group_initialized", True)
998

999
1000
1001
    def _get_fp8_params(self) -> Union[List[torch.Tensor], None]:
        """returns the FP8 weights."""
        fp8_params = []
1002
        for param in self.parameters(recurse=False):
1003
            if isinstance(param, QuantizedTensor) and param.requires_grad:
1004
1005
1006
1007
1008
                fp8_params.append(param)
        if len(fp8_params) == 0:
            return None
        return fp8_params

1009
1010
    # This routine is shared across FP8 and FP8_calibration paths so should not actually
    # assume FP8 execution.
1011
    def init_fp8_metadata(self, num_gemms: int = 1) -> None:
        """Initialize fp8 related metadata and tensors during fprop.

        Reads the FP8 flags and recipe from ``FP8GlobalStateManager``, mirrors the
        flags on the module, and (re)creates the FP8 meta tensors when needed.
        Warns and clears the cached FP8 workspaces if the recipe type changed.

        Parameters
        ----------
        num_gemms : int, default = 1
            Stored in ``fp8_meta["num_gemms"]``.
        """
        meta = self.fp8_meta

        fp8 = FP8GlobalStateManager.is_fp8_enabled()
        fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
        fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
        self.fast_setattr("fp8_parameters", fp8_parameters)
        self.fast_setattr("fp8", fp8)
        self.fast_setattr("fp8_calibration", fp8_calibration)
        fp8_enabled = fp8 or fp8_calibration
        meta["fp8_checkpoint"] = fp8_enabled

        _original_recipe = None

        if fp8_parameters or fp8_enabled:
            _original_recipe = meta.get("recipe", None)
            if self.fp8_initialized and FP8GlobalStateManager.get_fp8_recipe() == _original_recipe:
                # FP8 init has already been run and recipe is the same, don't do anything.
                return
            meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
        else:
            # If fp8 isn't enabled, turn off and return.
            self.fast_setattr("fp8_initialized", False)
            return

        if fp8_parameters and not self.fp8_initialized:
            meta["num_gemms"] = num_gemms
            self.init_fp8_meta_tensors(meta["recipe"])

        if fp8_enabled:
            # Set FP8 and other FP8 metadata
            meta["num_gemms"] = num_gemms
            meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()

            # Set FP8_MAX per tensor according to recipe
            if hasattr(meta["recipe"], "fp8_format"):
                meta["fp8_max_fwd"] = meta["recipe"].fp8_format.value.max_fwd
                meta["fp8_max_bwd"] = meta["recipe"].fp8_format.value.max_bwd

            # Allocate scales and amaxes
            self.init_fp8_meta_tensors(meta["recipe"])
            self.fast_setattr("fp8_initialized", True)

            # NOTE(review): the recipe was already stored above; this second
            # lookup is kept as-is - confirm whether it is still required.
            meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()

        _current_recipe = meta["recipe"]
        # The recipe type counts as changed only if neither class derives from the other.
        if _original_recipe is not None and not (
            issubclass(_current_recipe.__class__, _original_recipe.__class__)
            or issubclass(_original_recipe.__class__, _current_recipe.__class__)
        ):
            warnings.warn(
                f"Recipe type changed from {_original_recipe.__class__.__name__} "
                f"to {_current_recipe.__class__.__name__}. "
                "This may affect model behavior."
            )
            # Clear cached workspaces as they were created with the old recipe/quantizer type
            self._fp8_workspaces.clear()

1070
1071
1072
1073
    def prepare_forward(
        self,
        inp: torch.Tensor,
        num_gemms: int = 1,
        allow_non_contiguous: bool = False,
        allow_different_data_and_param_types: bool = False,
    ) -> torch.Tensor:
        """Checks and prepares for FWD execution.

        Pushes an NVTX range named after the class; ``end_forward`` must be
        called afterwards to pop it.

        Parameters
        ----------
        inp : torch.Tensor
            Input tensor of the forward pass.
        num_gemms : int, default = 1
            Forwarded to ``init_fp8_metadata``.
        allow_non_contiguous : bool, default = False
            If ``False``, a non-contiguous `inp` is made contiguous.
        allow_different_data_and_param_types : bool, default = False
            Stored on the module and read by ``set_activation_dtype``.

        Returns
        -------
        torch.Tensor
            `inp`, or a contiguous copy of it.
        """
        self.fast_setattr(
            "allow_different_data_and_param_types", allow_different_data_and_param_types
        )
        self.fast_setattr("forwarded_at_least_once", True)

        # Activation recomputation is used and this is the second forward phase.
        if self.fp8 and in_fp8_activation_recompute_phase():
            # NOTE(review): this local is not read again in this branch; kept as in
            # the original - confirm whether the call below should be gated on it.
            delayed_scaling_recipe = self.fp8_meta["recipe"].delayed()
            FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta)
        else:
            assert inp.is_cuda, "TransformerEngine needs CUDA."

            if self.tp_size > 1:
                assert self.tp_group_initialized, "TP group not initialized."

            self.set_activation_dtype(inp)
            self.init_fp8_metadata(num_gemms=num_gemms)
            self._check_weight_tensor_recipe_correspondence()

            delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed()
            if delayed_scaling_recipe:
                if self.sequence_parallel:
                    assert self.fp8_meta["recipe"].reduce_amax, (
                        "Amax reduction across tensor parallel group is "
                        "necessary when using sequence parallelism with FP8."
                    )

                if not FP8GlobalStateManager.fp8_graph_capturing():
                    FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(self.fp8_meta)

                # Activation recomputation is used and this is the first forward phase.
                if self.training and is_fp8_activation_recompute_enabled():
                    FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)

        nvtx_range_push(self.__class__.__name__ + " forward")
        if not allow_non_contiguous and not inp.is_contiguous():
            inp = inp.contiguous()
        return inp
1116

1117
1118
1119
1120
1121
1122
    def end_forward(self):
        """
        Required to be called at the end of the forward function to properly handle
        DelayedScaling metadata handling and the NVTX ranges.

        Restores the FP8 meta tensors when a delayed-scaling recipe is active and
        this is the activation recompute phase, then pops the NVTX range.
        """
        # Truthy only when FP8 is enabled, so no separate `self.fp8` check is needed.
        delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed()
        if delayed_scaling_recipe and in_fp8_activation_recompute_phase():
            FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
        nvtx_range_pop()

    @contextmanager
    def prepare_forward_ctx(
        self,
        inp: torch.Tensor,
        num_gemms: int = 1,
        allow_non_contiguous: bool = False,
        allow_different_data_and_param_types: bool = False,
    ) -> Generator[torch.Tensor, None, None]:
        """Context-manager form of ``prepare_forward`` / ``end_forward``.

        Yields the tensor returned by ``prepare_forward`` and calls
        ``end_forward`` on exit, including when the body raises.
        """
        prepared = self.prepare_forward(
            inp,
            num_gemms,
            allow_non_contiguous,
            allow_different_data_and_param_types,
        )
        try:
            yield prepared
        finally:
            self.end_forward()
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162

    def set_nccl_overlap_warning_if_tp(self) -> None:
        """When using TP, the NCCL communication needs to be scheduled
        before the GEMM for there to be a guaranteed overlap. From the
        host side in TE, the comm calls are always launched first, but
        to ensure that the GEMM isn't scheduled first, the environment
        variable `CUDA_DEVICE_MAX_CONNECTIONS` needs to be set to 1 to
        force a single channel.

        Emits a warning when ``self.tp_size`` is not 1 and the environment
        variable is unset or not equal to 1.
        """
        if self.tp_size == 1:
            return
        num_cuda_work_queues = int(os.getenv("CUDA_DEVICE_MAX_CONNECTIONS", "0"))
        if num_cuda_work_queues != 1:
            # Trailing space keeps the two string fragments from running together.
            warnings.warn(
                "To guarantee overlapping TP and SP collectives with the backward "
                "GEMMs, set environment variable CUDA_DEVICE_MAX_CONNECTIONS = 1"
            )

    @staticmethod
    def grad_output_preprocess(
        ctx,
        grad_output: torch.Tensor,
        row_parallel_mode: bool,
        quantizer: Optional[Quantizer],
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Utility function for backward.

        `grad_output` is first flattened to 2D (keeping the last dimension) and
        made contiguous. It is all-gathered only when `row_parallel_mode` and
        ``ctx.sequence_parallel`` are both set.

        Returns tuple in order (all optional/None based on training precision/recipe):
            R1: processed (possibly gathered and/or quantized) `grad_output`.
            R2: bias gradient, or None when it is not computed here.

        """
        grad_output = grad_output.reshape((-1, grad_output.shape[-1]))
        grad_output = grad_output.contiguous()
        gather_grad_output = row_parallel_mode and ctx.sequence_parallel

        # Non-FP8 case: bgrad is fused with wgrad for this case.
        if not ctx.fp8 and not ctx.debug:
            if gather_grad_output:
                if not ctx.ub_overlap_ag or ctx.ub_obj_gradout is None:  # Perform NCCL all-gather
                    grad_output, _ = gather_along_first_dim(grad_output, ctx.tp_group)
                else:  # Initialize Userbuffers all-gather
                    grad_output, _ = fill_userbuffers_buffer_for_all_gather(
                        ctx.ub_obj_gradout,
                        grad_output,
                        None,
                        ctx.tp_group,
                    )
            return grad_output, None

        # FP8 with all-gather: unfused bgrad, fused cast + transpose
        # Also supports debug quantization, which is handled inside gather_along_first_dim.
        if gather_grad_output:
            grad_bias = None
            if ctx.use_bias:
                # Bias gradient is computed on the local (pre-gather) gradient.
                grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
            if ctx.ub_overlap_ag and ctx.ub_obj_gradout is not None:
                # Quantize the gradient if needed
                if not isinstance(
                    grad_output,
                    (
                        QuantizedTensor,
                        Float8TensorStorage,
                        MXFP8TensorStorage,
                        Float8BlockwiseQTensorStorage,
                    ),
                ):
                    grad_output = quantizer(grad_output)

                # Copy into communication buffer, and replace original gradient with it
                grad_output, _ = fill_userbuffers_buffer_for_all_gather(
                    ctx.ub_obj_gradout,
                    grad_output,
                    quantizer,
                    ctx.tp_group,
                )
            else:
                grad_output, _ = gather_along_first_dim(
                    grad_output,
                    ctx.tp_group,
                    quantizer=quantizer,
                )
            return grad_output, grad_bias

        # Debug without all-gather: unfused cast and bgrad
        # bgrad only if wgrad is in FP8, otherwise it is fused with wgrad and we return None
        if ctx.debug:
            grad_output_ = quantizer(grad_output)
            if ctx.use_bias:
                grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
            else:
                grad_bias = None
            grad_output = grad_output_
            return grad_output, grad_bias

        # FP8 without all-gather: fused bgrad + cast + transpose
        grad_bias = None
        if ctx.use_bias:
            if isinstance(
                grad_output,
                (
                    QuantizedTensor,
                    Float8TensorStorage,
                    MXFP8TensorStorage,
                    Float8BlockwiseQTensorStorage,
                ),
            ):
                grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0)
            else:
                if isinstance(quantizer, Float8BlockQuantizer) or (
                    isinstance(quantizer, Float8CurrentScalingQuantizer) and IS_HIP_EXTENSION
                ):
                    # unfuse bgrad for now until cast_transpose + dgrad calculation is ready for Float8BlockQuantizer.
                    grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
                else:
                    grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer)
        if not isinstance(grad_output, QuantizedTensorStorage):
            grad_output = quantizer(grad_output)
        return grad_output, grad_bias
1259

1260
1261
1262
1263
1264
1265
    def register_parameter(self, name, param, **kwargs):
        """
        Thin wrapper around PyTorch parameter registration to stash additional parameter
        metadata used in deferred initialization.

        Parameters
        ----------
        name : str
            Parameter name, passed on to the parent ``register_parameter``.
        param
            Parameter to register, passed on to the parent ``register_parameter``.
        **kwargs
            Forwarded to ``_ParameterInitMeta`` the first time `name` is registered.
        """
        super().register_parameter(name, param)
        # Initialize param_init_meta exactly once during the init. FSDP2 can call
        # register parameter again to change parameters to DTensors. And it calls
        # it without custom fp8 specific kwargs that we need. And so we don't want
        # to reset/lose our fp8 init attributes.
        if hasattr(self, "param_init_meta") and name not in self.param_init_meta:
            self.param_init_meta[name] = _ParameterInitMeta(**kwargs)
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282

    def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
        """
        Reset all module parameters to initial values. Unless deferred initialization
        is specified, all parameters on a 'meta' device are also materialized on a real cuda
        device before the values are reset to initial.
        """
        if defer_init:
            return

        for name, param in self.named_parameters(recurse=False):
1283
1284
1285
1286
1287
            # Check if parameter is a DTensor (FSDP2) or regular tensor
            is_dtensor = isinstance(param, DTensor)
            dtensor_param = param if is_dtensor else None
            # Need to update/quantize local tensor in case of DTensor
            param = param._local_tensor if is_dtensor else param
1288
            # Ensure parameter is on a real device
1289
1290
            if param.device == torch.device("meta"):
                param = torch.empty_like(param, device="cuda")
1291
1292
1293
1294
1295
1296
            # Initialize the parameter values on device
            init_fn = self.param_init_meta[name].init_fn
            get_rng_state_tracker = self.param_init_meta[name].get_rng_state_tracker
            if get_rng_state_tracker is None:
                init_fn(param)
            else:
1297
1298
1299
1300
1301
1302
                if hasattr(self, "rng_tracker_name") and self.rng_tracker_name:
                    with get_rng_state_tracker().fork(self.rng_tracker_name):
                        init_fn(param)
                else:
                    with get_rng_state_tracker().fork():
                        init_fn(param)
1303

1304
            # Wrap parameters in QuantizedTensor if needed
1305
            fp8_meta_index = self.param_init_meta[name].fp8_meta_index
1306
            high_precision_init_val = None
1307
            if self.primary_weights_in_fp8 and fp8_meta_index is not None:
1308
1309

                # Keep high-precision values on CPU if needed
1310
1311
1312
                if self.preserve_high_precision_init_val:
                    high_precision_init_val = param.detach().cpu()

1313
                # Configure quantizer
1314
                quantizer = self.quantizers["scaling_fwd"][fp8_meta_index]
1315
1316
1317
                if quantizer is None:
                    raise RuntimeError("Weight quantizer has not been initialized")
                quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled())
1318
                quantizer.internal = False
1319
1320
1321
1322
1323
1324
1325
1326
1327
                if is_dtensor and isinstance(quantizer, Float8CurrentScalingQuantizer):
                    device_mesh = dtensor_param.device_mesh
                    amax_reduction_group = (
                        device_mesh.get_group(mesh_dim="shard")
                        if device_mesh.ndim > 1
                        else device_mesh.get_group()
                    )
                    quantizer.amax_reduction_group = amax_reduction_group
                    quantizer.with_amax_reduction = True
1328
                # Quantize parameter
1329
                param = quantizer(param)
1330
1331
1332
1333
1334

            # Redo parameter wrap in case we broke it above
            # NOTE: Currently this can only be broken when primary weights are in Fp8 but
            #       re-applying the nn.Parameter() wrap is a no-op when the input is already
            #       a parameter so we always re-apply it just for extra safety.
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
            if is_dtensor:
                # recreate the DTensor from the parameter.
                dtensor_param = DTensor.from_local(
                    param,
                    device_mesh=dtensor_param.device_mesh,
                    placements=dtensor_param.placements,
                    shape=dtensor_param.size(),
                    stride=dtensor_param.stride(),
                )
                dtensor_param = torch.nn.Parameter(dtensor_param)
            else:
                param = torch.nn.Parameter(param)
1347
1348

            # Keep high-precision values on CPU if needed
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
            if high_precision_init_val is not None:

                # - Master weights are initialized from model weights, if we use fp8 primary
                #   weights to initialize master weights, the numerical values of master weights
                #   are not consistent with the numerical values when we initialize them from
                #   bf16/fp16 weights.
                # - So we add a `_high_precision_init_val` attribute to each model weight to store
                #   the original bf16/fp16 weight on cpu before casting it to fp8. And users can
                #   use `get_high_precision_init_val` to get this cpu tensor.
                # - This cpu tensor is not needed once the master weight is initialized, so users
                #   should call `clear_high_precision_init_val` to remove it after master weight
                #   is initialized.

                def get(self):
                    if hasattr(self, "_high_precision_init_val"):
                        return self._high_precision_init_val
                    return None

                def clear(self):
                    if hasattr(self, "_high_precision_init_val"):
                        del self._high_precision_init_val

                param._high_precision_init_val = high_precision_init_val
                param.get_high_precision_init_val = MethodType(get, param)
                param.clear_high_precision_init_val = MethodType(clear, param)
1374
                # Update the parameter based on its type
1375

1376
            if not is_dtensor:
1377
                self.module_setattr(name, param)
1378
            else:
1379
                self.module_setattr(name, dtensor_param)
1380

1381
1382
1383
    @abstractmethod
    def forward(self):
        """Run the module's forward pass; concrete subclasses must implement this."""
1384

1385
    def get_weight_workspace(
        self,
        *,
        tensor: Optional[torch.Tensor] = None,
        quantizer: Optional[Quantizer] = None,
        cache_name: Optional[str] = None,
        update_workspace: bool = True,
        skip_update_flag: Optional[torch.Tensor] = None,
        fsdp_group: Optional[dist_group_type] = None,
        workspace_dtype: Optional[torch.dtype] = None,
    ) -> QuantizedTensor:
        """Get workspace buffer for weights and maybe update its values

        The workspace buffer may be cached for future function calls.

        Parameters
        ----------
        tensor : torch.Tensor, optional
            Values to copy into workspace. Required if the workspace
            is being constructed or updated.
        quantizer: Quantizer, optional
            Quantizer used to cast the weights. Required if the
            workspace is being constructed or updated.
        cache_name: str, optional
            Key for caching.
        update_workspace: bool, default = True
            Update workspace with values from `tensor`.
        skip_update_flag: torch.Tensor, optional
            GPU flag to skip updating the workspace. Take precedence
            over `update_workspace` if provided.
        fsdp_group: dist_group_type, default = None
            FSDP process group that the weights are distributed over.
        workspace_dtype: torch.dtype, default = None
            If weight workspace contains high-precision tensor - for example
            for debug quantization, this is dtype of the tensor.

        Returns
        -------
        QuantizedTensor
            `tensor` itself if it is already a `QuantizedTensor`, otherwise
            the (possibly cached, possibly freshly quantized) workspace.

        Raises
        ------
        ValueError
            If a workspace has to be constructed without `tensor` or
            `quantizer`, or has to be updated without `tensor`.
        """

        # Handle case where weights are already quantized
        # Note: Make sure weights have required usages, but do not
        # destroy unnecessary usages since they may be used later.
        if isinstance(tensor, QuantizedTensor):
            tensor.update_usage(
                rowwise_usage=True if quantizer.rowwise_usage else None,
                columnwise_usage=True if quantizer.columnwise_usage else None,
            )
            return tensor

        # Try getting workspace from cache
        out = None
        if cache_name is not None:
            out = self._fp8_workspaces.get(cache_name, None)

        # Reset cache if the cached workspace lacks data the quantizer needs
        if out is not None and quantizer is not None:
            reset_cache = False
            if isinstance(out, Float8TensorStorage):
                # Column-wise usage needs a transpose buffer when non-TN FP8
                # GEMMs are not supported
                reset_cache = bool(
                    not is_non_tn_fp8_gemm_supported()
                    and quantizer.columnwise_usage
                    and out._transpose is None
                )
            elif isinstance(out, (MXFP8TensorStorage, NVFP4TensorStorage)):
                reset_cache = bool(
                    (quantizer.rowwise_usage and out._rowwise_data is None)
                    or (quantizer.columnwise_usage and out._columnwise_data is None)
                )
            # A debug workspace must only be paired with a debug quantizer
            # and vice versa
            if isinstance(out, DebugQuantizedTensor) != isinstance(quantizer, DebugQuantizer):
                reset_cache = True
            if reset_cache:
                out = None
                del self._fp8_workspaces[cache_name]

        # Gather cached Fp8 workspace if it's distributed
        # NOTE: FSDP sharding is supported only for Fp8 buffers and will not work
        #       for models initialized with Fp8 primary weights.
        if (
            out is not None
            and tensor is not None
            and fsdp_group is not None
            and out.data.shape != tensor.data.shape
        ):
            _fsdp_gather_tensors(fsdp_group, [tensor.data.shape], out)

        # Construct workspace if needed
        if out is None:
            if tensor is None or quantizer is None:
                raise ValueError(
                    "tensor and quantizer kwargs must be provided to construct FP8 workspace"
                )

            # Uncached workspaces are quantized with the quantizer as configured
            if cache_name is None:
                return quantizer.quantize(tensor, dtype=workspace_dtype)

            # Ensure the tensor in the cache is an instance of torch.Tensor,
            # as it persists beyond a single forward pass.
            # Setting internal=True would cause the data to be removed in prepare_for_saving(...).
            quantizer_internal = quantizer.internal
            quantizer.internal = False
            try:
                out = quantizer.quantize(tensor, dtype=workspace_dtype)
            finally:
                # Restore the caller's setting even if quantization fails, so a
                # failed call does not leave the quantizer reconfigured
                quantizer.internal = quantizer_internal

            # Update cache
            self._fp8_workspaces[cache_name] = out
            return out

        # Update workspace if needed. A provided skip flag always takes the
        # update path; the flag itself is forwarded to the quantize call.
        if update_workspace or skip_update_flag is not None:
            if tensor is None:
                raise ValueError("tensor kwarg must be provided to update FP8 workspace")
            if hasattr(out, "quantize_"):
                out.quantize_(tensor, noop_flag=skip_update_flag)
            else:
                tex.quantize(tensor, quantizer, out, skip_update_flag)
        return out
1509

1510
1511
1512
    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        """
        Load tensors and extra state (including fp8 metadata) for this module.

        When ``self.primary_weights_in_fp8`` is set, the module's extra state
        is applied through ``set_extra_state`` *before* delegating to the
        parent implementation. Copying into fp8 tensors relies on the
        ``scale_inv`` held in the fp8 metadata, so that metadata has to be in
        place before the tensors are copied rather than after, which is the
        usual order in ``_load_from_state_dict``. Without fp8 primary weights
        no early load is needed and the parent implementation is used as-is.
        """
        if self.primary_weights_in_fp8:
            suffix = torch.nn.modules.module._EXTRA_STATE_KEY_SUFFIX
            extra_state_key = prefix + suffix
            # A state dict without extra state for this module is left alone
            if extra_state_key in state_dict:
                extra_state = state_dict[extra_state_key]
                self.set_extra_state(extra_state)

        load_args = (
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )
        super()._load_from_state_dict(*load_args)
1530

1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
    def register_wgrad_accumulation_and_reduce_hooks(self, wgrad_accumulation_and_reduce_hook):
        """
        Register a hook for manual control of weight gradient accumulation and reduce.

        The hook is appended to ``self.wgrad_accumulation_and_reduce_hooks``;
        ``backward_dw()`` invokes every registered hook, without arguments,
        after it has assigned the delayed gradients. Register hooks before
        calling ``backward()``.
        """
        registered_hooks = self.wgrad_accumulation_and_reduce_hooks
        registered_hooks.append(wgrad_accumulation_and_reduce_hook)

1541
1542
1543
1544
1545
1546
1547
1548
1549
    def need_backward_dw(self):
        """
        Tell whether the delayed weight gradient computation has to run.

        ``backward_dw()`` checks this first and returns without doing anything
        when the answer is falsy; callers may also query it themselves before
        calling ``backward_dw()``. The result is ``False`` when the module has
        no ``wgrad_store``, otherwise whatever
        ``wgrad_store.delay_wgrad_compute()`` returns.
        """
        store = self.wgrad_store
        if store is None:
            return False
        return store.delay_wgrad_compute()

1550
1551
1552
1553
1554
    def backward_dw(self):
        """
        Execute the delayed weight gradient computation.

        Intended to be called after the main backward pass. Does nothing
        unless ``need_backward_dw()`` is truthy. Otherwise it pops
        ``(wgrad, bgrad)`` from ``wgrad_store`` and:

        - unless ``fuse_wgrad_accumulation`` is set, assigns ``wgrad`` (cast to
          the weight dtype) as ``.grad`` of the concatenated weight tensor;
        - if ``use_bias`` is set and the concatenated bias tensor has no grad
          yet, assigns ``bgrad`` (cast to the bias dtype) as its ``.grad``;
        - finally calls every hook in ``wgrad_accumulation_and_reduce_hooks``.
        """
        if not self.need_backward_dw():
            return

        nvtx_label = f"_{self.__class__.__name__}_wgrad"
        with get_nvtx_range_context(nvtx_label):
            (wgrad, bgrad), _ = self.wgrad_store.pop()

            if not self.fuse_wgrad_accumulation:
                weight_tensor = noop_cat(self._get_weight_tensors())
                weight_tensor.grad = wgrad.to(weight_tensor.dtype)

            if self.use_bias:
                bias_params = [getattr(self, bias_name) for bias_name in self.bias_names]
                bias_tensor = noop_cat(bias_params)
                # An already populated bias grad is left untouched
                if bias_tensor.grad is None:
                    bias_tensor.grad = bgrad.to(bias_tensor.dtype)

            # Drop the local gradient references before the hooks run
            del wgrad, bgrad

            for hook in self.wgrad_accumulation_and_reduce_hooks:
                hook()
1570

1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
    def is_debug_iter(self) -> bool:
        """
        Decide whether debug is enabled for this layer in the current iteration.

        Returns ``False`` right away when ``TEDebugState.debug_enabled`` is
        falsy. The decision is otherwise made once per iteration: on the first
        call of a new iteration it is recomputed from
        ``next_iter_when_debug_should_be_run`` and stored in
        ``debug_enabled_in_this_iteration``; later calls within the same
        iteration reuse the stored value.

        Raises
        ------
        RuntimeError
            If debug is enabled while ``wgrad_store`` delays the weight
            gradient computation.
        """
        if not TEDebugState.debug_enabled:
            return False

        # A layer run for the first time in a new iteration must re-check
        # whether debug should be enabled: in previous iterations the debug
        # features may have reported that none of them will be active for this
        # layer for several upcoming iterations.
        started_new_iteration = TEDebugState.get_iteration() != getattr(
            self, "debug_last_iteration", None
        )
        if not started_new_iteration:
            # Same iteration as the previous invocation of the module:
            # reuse the value from the first invocation in the iteration.
            debug = self.debug_enabled_in_this_iteration
        else:
            first_debug_iter = self.next_iter_when_debug_should_be_run
            debug = (
                first_debug_iter is not None
                and TEDebugState.get_iteration() >= first_debug_iter
            )
            self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())
            self.fast_setattr("debug_enabled_in_this_iteration", debug)

        # NOTE(review): this repeats the write done in the new-iteration branch
        # above; kept so the sequence of calls stays unchanged.
        self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())

        if self.wgrad_store is not None and debug and self.wgrad_store.delay_wgrad_compute():
            raise RuntimeError("Delayed wgrad compute is not supported in debug mode.")

        return debug

    def no_debug_features_active(self, quantizers):
        """
        Return ``True`` when no debug feature is enabled for `quantizers`.

        As a side effect the value of
        ``next_iter_when_debug_should_be_run(quantizers)`` is stored on the
        module under ``next_iter_when_debug_should_be_run``, because features
        may report that they will stay disabled for this layer for several
        upcoming iterations.

        Raises
        ------
        RuntimeError
            If a debug feature is enabled while ``primary_weights_in_fp8`` is set.
        """
        any_feature_active = any_feature_enabled(quantizers)

        # Record when debug has to be looked at again for this layer
        upcoming_debug_iter = next_iter_when_debug_should_be_run(quantizers)
        self.fast_setattr("next_iter_when_debug_should_be_run", upcoming_debug_iter)

        if any_feature_active:
            if self.primary_weights_in_fp8:
                raise RuntimeError("FP8 weights are not supported in debug mode.")
            return False
        return True
1624

1625
1626
1627
    def _validate_name(self):
        """
        Validate name passed to the module.
1628
1629
        It creates a default name with layer count as the variable
        which may be changed by the user of the module.
1630
        """
1631
1632
        if self.name is not None:
            return
1633
1634

        self.name = f"Layer_{TEDebugState.get_layer_count()}"
1635

1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
    def _check_weight_tensor_recipe_correspondence(self) -> None:
        """
        Verify that the weight tensor types match their corresponding recipe type.
        This is invoked in the forward().

        This establishes a 1:1 correspondence between recipe types and tensor types:
        - DelayedScaling → Float8Tensor
        - Float8CurrentScaling → Float8Tensor
        - MXFP8BlockScaling → MXFP8Tensor
        - Float8BlockScaling → Float8BlockTensor

1647
1648
        Example case to check: recipe is DelayedScaling (DelayedScaling is set in autocast()),
        but the weight tensor is MXFP8Tensor (MXFP8BlockScaling is set in quantized_model_init()).
1649
1650
1651
        """
        if not self.fp8 and not self.fp8_calibration:
            return
1652
1653
        if not self.primary_weights_in_fp8:
            return
1654
1655
1656
1657
1658
1659
        if not hasattr(self, "weight_names") or not self.weight_names:
            return

        recipe = self.fp8_meta["recipe"]
        weight_tensors = [getattr(self, name) for name in self.weight_names]
        for i, tensor in enumerate(weight_tensors):
1660
            if isinstance(tensor, QuantizedTensorStorage):
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
                quantizer = tensor._get_quantizer()
                if quantizer is None:
                    continue
                compatible_recipe_class = quantizer._get_compatible_recipe()
                if compatible_recipe_class is None:
                    continue
                if not isinstance(recipe, compatible_recipe_class):
                    raise RuntimeError(
                        f"Recipe mismatch for '{self.weight_names[i]}': tensor supports recipe"
                        f" {compatible_recipe_class.__name__}, but got {recipe.__class__.__name__}."
1671
1672
                        " Please check the recipes assigned during quantized_model_init() and"
                        " autocast() calls."
1673
                    )