base.py 64.5 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
#
# See LICENSE for license information.

"""Base modules and utilities for TransformerEngine PyTorch API"""
6
import io
7
import math
8
9
10
11
import os
import pickle
import warnings
from abc import ABC, abstractmethod
12
from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union
13
from contextlib import contextmanager
14
import logging
15
from types import MethodType
16
17
18
19

import torch
import torch.nn.functional as F

20
import transformer_engine_torch as tex
21
22
from transformer_engine.common.recipe import Recipe

23
from ._common import _ParameterInitMeta, noop_cat
24
from ..fp8 import (
25
26
    MXFP8BlockScalingRecipeState,
    DelayedScalingRecipeState,
27
    Float8CurrentScalingRecipeState,
28
    Float8BlockScalingRecipeState,
29
    FP8GlobalStateManager,
30
    RecipeState,
31
32
33
34
35
)
from ..distributed import (
    gather_along_first_dim,
    is_fp8_activation_recompute_enabled,
    in_fp8_activation_recompute_phase,
36
    _fsdp_gather_tensors,
37
38
)
from ..constants import dist_group_type
39
40
41
from ..tensor.quantized_tensor import QuantizedTensor, QuantizedTensorBase, Quantizer
from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
42
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
43
44
from ..tensor._internal.float8_tensor_base import Float8TensorBase
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
45
from ..utils import is_non_tn_fp8_gemm_supported, torch_get_autocast_gpu_dtype
46
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
47
from ...common.recipe import DelayedScaling, Recipe
48
49
from ...debug.pytorch.debug_state import TEDebugState
from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor
50
from ...debug.pytorch.utils import next_iter_when_debug_should_be_run, any_feature_enabled
51

52
53
# Public names of this module: only the Userbuffers setup and teardown entry points.
__all__ = ["initialize_ub", "destroy_ub"]

54
55
56
# GEMM accumulation flags for the forward, dgrad and wgrad passes. Their consumers
# are outside this part of the file.
# NOTE(review): presumably these toggle split accumulation -- confirm.
_2X_ACC_FPROP = False
_2X_ACC_DGRAD = True
_2X_ACC_WGRAD = True

# Lazily populated caches:
#   * `_cublas_workspace` is allocated by `get_workspace()` / `initialize_ub()`.
#   * `_multi_stream_cublas_workspace` is filled by `get_multi_stream_cublas_workspace()`.
#   * `_dummy_wgrads` maps `(shape[0], shape[1], dtype)` to a tensor in `get_dummy_wgrad()`.
_cublas_workspace = None
_multi_stream_cublas_workspace = []
_dummy_wgrads = {}

# Userbuffers state: the communicator registry is a dict created by `initialize_ub()`
# and reset to `None` by `destroy_ub()`.
_ub_communicators = None
_NUM_MAX_UB_STREAMS = 3

# Stream priority range, queried on first use inside `initialize_ub()`.
_MIN_STREAM_PRIORITY = None
_MAX_STREAM_PRIORITY = None

# Layer names that use atomic GEMM together with the `ring_exchange` overlap method.
layers_atomic_ring_exchange = []
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82


def get_cublas_workspace_size_bytes() -> int:
    """Return the cuBLAS workspace size in bytes for the current CUDA device.

    Devices whose compute capability major version is 9 (Hopper) or higher get
    32 MiB; every other architecture gets 4 MiB.
    """
    device_props = torch.cuda.get_device_properties(torch.cuda.current_device())
    if device_props.major >= 9:
        return 33_554_432  # 32 MiB
    return 4_194_304  # 4 MiB


def get_workspace() -> torch.Tensor:
    """Return the shared cuBLAS workspace, allocating it on first use.

    The workspace is a `uint8` tensor on the "cuda" device whose size comes from
    `get_cublas_workspace_size_bytes()`. It is cached in the module-level
    `_cublas_workspace`.
    """
    global _cublas_workspace
    if _cublas_workspace is not None:
        return _cublas_workspace
    workspace_bytes = get_cublas_workspace_size_bytes()
    _cublas_workspace = torch.empty(workspace_bytes, dtype=torch.uint8, device="cuda")
    return _cublas_workspace


83
84
85
86
def get_multi_stream_cublas_workspace() -> List[torch.Tensor]:
    """Return the per-stream cuBLAS workspaces, allocating them on first use.

    One `uint8` workspace tensor is created on the "cuda" device for each stream
    reported by `tex.get_num_cublas_streams()`. The list is cached in the
    module-level `_multi_stream_cublas_workspace`.
    """
    global _multi_stream_cublas_workspace
    if _multi_stream_cublas_workspace:
        return _multi_stream_cublas_workspace
    num_streams = tex.get_num_cublas_streams()
    for _ in range(num_streams):
        stream_workspace = torch.empty(
            get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda"
        )
        _multi_stream_cublas_workspace.append(stream_workspace)
    return _multi_stream_cublas_workspace


94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
def get_dummy_wgrad(shape: list, dtype: torch.dtype, zero=False) -> torch.Tensor:
    """Return a cached placeholder tensor with the given 2D shape and dtype.

    The tensor is allocated uninitialized on the "cuda" device the first time a
    `(shape[0], shape[1], dtype)` combination is requested and reused afterwards.
    When `zero` is true the cached tensor is zero-filled in place before a
    detached view of it is returned.
    """
    assert len(shape) == 2
    global _dummy_wgrads
    cache_key = (shape[0], shape[1], dtype)
    if cache_key not in _dummy_wgrads:
        _dummy_wgrads[cache_key] = torch.empty(
            shape,
            dtype=dtype,
            device="cuda",
            requires_grad=False,
        )
    cached = _dummy_wgrads[cache_key]
    if zero:
        cached.fill_(0)
    return cached.detach()


110
111
def initialize_ub(
    shape: list,
    tp_size: int,
    use_fp8: bool = False,
    dtype: torch.dtype = torch.bfloat16,
    ub_cfgs: Optional[dict] = None,
    bootstrap_backend: Optional[Union[str, torch.distributed.Backend]] = None,
) -> None:
    r"""
    Initialize the Userbuffers communicator for overlapping tensor-parallel communications with
    GEMM compute in te.Linear, te.LayerNormLinear and te.LayerNormMLP modules.

    Parameters
    ----------
    shape : list
            shape of the communication buffer, typically set to be the same as the global shape of
            the input tensor to a te.TransformerLayer forward pass, with the sequence and batch
            dimensions collapsed together -- i.e.: `(sequence_length * batch_size, hidden_size)`
    tp_size : int
              number of GPUs in the tensor-parallel process group
    use_fp8 : bool = False
              allocate the communication buffer for FP8 GEMM inputs/outputs
    dtype : torch.dtype = torch.bfloat16
            non-FP8 data type of the communication buffer when `use_fp8 = False`
    ub_cfgs: dict = None
             Configuration dictionary with the structure
             ```
             {
                <gemm_name> : {
                    "method": <"ring_exchange", "pipeline" or "bulk">,
                    "is_reduce_scatter": bool,
                    "num_sm": int,
                    "cga_size": int,
                    "set_sm_margin": bool,
                    "num_splits": int,
                    "aggregate": bool,
                    "atomic_gemm": bool,
                    "use_ce": bool,
                    "fp8_buf": bool,
                    "comm_priority": int,
                    "gemm_priority": int,
                    "pipeline_rs_overlap_first_gemm": bool,
                }
             }
             ```
             for `te.TransformerLayer` GEMM layers in `["qkv_fprop", "qkv_dgrad", "qkv_wgrad",
             "proj_fprop", "proj_dgrad", "fc1_fprop", "fc1_dgrad", "fc1_wgrad", "fc2_fprop",
             "fc2_dgrad"]`. Entries under any other name are ignored. Keys that are left out
             fall back to per-layer defaults.
    bootstrap_backend : str = None
                        `torch.distributed` communication backend for the all-gather, broadcast and
                        barrier collectives during Userbuffers initialization. Not all backends are
                        valid for every cluster configuration and distributed launch method even if
                        they are available in PyTorch. When left unset, the initialization prefers
                        to use the MPI backend, falling back first on Gloo and then NCCL if MPI is
                        not available. Setting `NVTE_UB_WITH_MPI=1` when building TE overrides this
                        option and always initializes Userbuffers with direct MPI calls in C++,
                        which also requires `MPI_HOME=/path/to/mpi/root` to be set at compile time.

    Raises
    ------
    AssertionError
        If Userbuffers is already initialized, the device lacks multicast support and
        `UB_SKIPMC` is not set, `torch.distributed` is not initialized, the bootstrap backend
        is invalid or unavailable, or an atomic-GEMM configuration is inconsistent.
    ValueError
        If the `pipeline` method is requested for an all-gather overlap.
    KeyError
        If a layer name is not known to any overlap method.
    """
    if not tex.device_supports_multicast():
        assert bool(int(os.getenv("UB_SKIPMC", "0"))), (
            "CUDA device, driver and/or toolkit version does not support comm+GEMM overlap with "
            + "CUDA Multicast. Launch app with UB_SKIPMC=1 to try CUDA IPC instead."
        )

    global _ub_communicators
    assert _ub_communicators is None, "UB communicators are already initialized."
    _ub_communicators = {}

    if tex.ubuf_built_with_mpi():
        # We're bootstrapping with direct calls to MPI in Userbuffers code so we need to force
        # an MPI_Init() here by creating a new MPI process group...
        assert torch.distributed.is_mpi_available()
        _ = torch.distributed.new_group(backend="mpi")
        helper = tex.CommOverlapHelper()
    else:
        # Bootstrapping with torch.distributed API, so check backend and construct
        # intra/inter-node process groups...
        assert (
            torch.distributed.is_initialized()
        ), "torch.distributed must be initialized before Userbuffers"
        if bootstrap_backend is None:
            # Preference order: MPI, then Gloo, then NCCL as the last resort
            bootstrap_backend = "nccl"
            if torch.distributed.is_mpi_available():
                bootstrap_backend = "mpi"
            elif torch.distributed.is_gloo_available():
                bootstrap_backend = "gloo"
        else:
            assert bootstrap_backend in [
                "gloo",
                "mpi",
                "nccl",
            ], "Invalid torch.distributed backend for bootstrapping Userbuffers!"
            assert torch.distributed.is_backend_available(bootstrap_backend), (
                f"PyTorch must be compiled with '{bootstrap_backend}' support in order to "
                f"bootstrap Userbuffers with '{bootstrap_backend}' collectives."
            )

        world_group = torch.distributed.new_group(backend=bootstrap_backend)
        world_rank = torch.distributed.get_rank(world_group)
        world_size = torch.distributed.get_world_size(world_group)

        # Consecutive blocks of `tp_size` global ranks form one TP domain
        num_domains = world_size // tp_size
        mydomain_idx = world_rank // tp_size
        if num_domains > 1:
            ranks_per_domain_list = [
                [i * tp_size + t for t in range(tp_size)] for i in range(num_domains)
            ]
            tp_domain_group, _ = torch.distributed.new_subgroups_by_enumeration(
                ranks_per_domain_list, backend=bootstrap_backend
            )
            local_rank = torch.distributed.get_rank(tp_domain_group)
            tp_domain_ranks = torch.distributed.get_process_group_ranks(tp_domain_group)

            helper = tex.CommOverlapHelper(world_group, tp_domain_group)
        else:
            # TP model on single NVLink domain, no replication, no data-parallelism
            mydomain_idx = 0
            local_rank = world_rank
            tp_domain_ranks = list(range(world_size))

            helper = tex.CommOverlapHelper(world_group)

        if world_rank == 0:
            print(f"!!! [UB] Number of TP domains: {num_domains}\n", end="", flush=True)
        if local_rank == 0:
            print(
                f"!!! [UB] Global ranks on TP domain {mydomain_idx}: {tp_domain_ranks}\n",
                end="",
                flush=True,
            )

    # Allocate cuBLAS workspace with expanded size for chunking in overlapping GEMM calls
    global _cublas_workspace
    if _cublas_workspace is None:
        _cublas_workspace = get_workspace().repeat(_NUM_MAX_UB_STREAMS)
    elif _cublas_workspace.numel() != get_cublas_workspace_size_bytes() * _NUM_MAX_UB_STREAMS:
        # This ensures we don't do `.repeat()` on an already expanded workspace
        _cublas_workspace = torch.empty(
            get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda"
        ).repeat(_NUM_MAX_UB_STREAMS)

    # Default buffer precision: AllGather buffers use fp8 when using fp8 recipe
    layers_all_gather_overlap = [
        "qkv_fprop",
        "qkv_dgrad",
        "proj_dgrad",
        "fc1_fprop",
        "fc1_dgrad",
        "fc2_dgrad",
    ]
    layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"]
    dgrad_reduce_scatter_overlap = ["qkv_dgrad", "fc1_dgrad"]
    # Default overlap methods for layers
    methods = {
        "ring_exchange": ["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad"],
        "pipeline": ["proj_fprop", "fc2_fprop"],
        "bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"],
    }

    # AG-RS overlap pairs of layers forming a tensor-parallel block
    ag_rs_pairs = {"qkv_fprop": "proj_fprop", "fc1_fprop": "fc2_fprop"}
    rs_ag_pairs = {v: k for k, v in ag_rs_pairs.items()}
    global layers_atomic_ring_exchange
    layers_atomic_ring_exchange = []

    def get_method(name):
        """Return the overlap method whose layer list contains `name`."""
        for method, names in methods.items():
            if name in names:
                return method
        raise KeyError(f"Given layer name {name} does not exist.")

    def get_default_config(name):
        """Build the default `add_ub` keyword arguments for layer `name`."""
        global _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY
        method = get_method(name)
        is_reduce_scatter = name in layers_reduce_scatter_overlap
        # Query the stream priority range only once per process
        if _MIN_STREAM_PRIORITY is None or _MAX_STREAM_PRIORITY is None:
            _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = tex.get_stream_priority_range()
        default_cfg = {
            "method": method,
            "is_reduce_scatter": is_reduce_scatter,
            "num_sm": 1 if method == "ring_exchange" else 16,
            "cga_size": 1 if method == "ring_exchange" else 2,
            "set_sm_margin": not method == "ring_exchange",
            "num_splits": tp_size if method == "ring_exchange" else 4,
            "aggregate": False,
            "atomic_gemm": False,
            "use_ce": True,
            "fp8_buf": name in layers_all_gather_overlap,
            "comm_priority": _MAX_STREAM_PRIORITY,
            "gemm_priority": _MIN_STREAM_PRIORITY,
            "pipeline_rs_overlap_first_gemm": False,
        }
        return default_cfg

    def add_ub(
        name: str,
        method: str,
        is_reduce_scatter: bool,
        num_sm: int = 16,
        cga_size: int = 2,
        set_sm_margin: bool = False,
        num_splits: int = 0,
        aggregate: bool = False,
        atomic_gemm: bool = False,
        use_ce: bool = True,
        fp8_buf: bool = False,
        comm_priority: int = 0,
        gemm_priority: int = 0,
        pipeline_rs_overlap_first_gemm: bool = False,
    ) -> None:
        """Validate one layer's overlap config and register its communicator under `name`."""
        if atomic_gemm:
            warnings.warn(
                "Atomic GEMM uses a beta API from cublas and is not tested for all use cases."
            )
            assert use_fp8, "Atomic GEMM overlap supported only for FP8 GEMM."
            if method == "bulk":
                warnings.warn(
                    f"At {name}, atomic GEMM is not supported for a bulk overlap. "
                    "Defaulting to `atomic_gemm=False`."
                )
                atomic_gemm = False
        if not is_reduce_scatter and method == "pipeline":
            raise ValueError(
                f"At {name}, `pipeline` overlap method is not supported for AllGather."
            )
        # Check if both AG and RS overlaps use `atomic GEMM` + `p2p ring-exchange`.
        # Using atomic GEMM + p2p ring-exchange in only one of the pair breaks functionality.
        global layers_atomic_ring_exchange
        if atomic_gemm and method == "ring_exchange" and name in ag_rs_pairs:
            # Record both members of the pair so the RS partner can be checked later
            layers_atomic_ring_exchange += [name, ag_rs_pairs[name]]
        if name in rs_ag_pairs:
            assert_message = (
                f"At {name}, atomic AG-GEMM overlap with `ring_exchange` shuffles GEMM chunk "
                "outputs, and RS-GEMM overlap un-shuffles them. When one of the GEMM-AG and "
                "GEMM-RS overlaps forming a TP block (e.g., qkv_fprop and proj_fprop) uses "
                "`atomic gemm` and `ring_exchange`, its pair must use the same overlap config "
                "for functionality."
            )
            if name in layers_atomic_ring_exchange:
                assert atomic_gemm and method == "ring_exchange", assert_message
            else:
                if atomic_gemm and method == "ring_exchange":
                    assert rs_ag_pairs[name] in layers_atomic_ring_exchange, assert_message

        # FP8 payloads are stored as raw bytes in the communication buffer
        buffer_dtype = torch.uint8 if (use_fp8 and fp8_buf) else dtype
        if method == "ring_exchange":
            ub_obj = tex.CommOverlapP2P(
                shape,  # Communication buffer shape
                buffer_dtype,  # Communication buffer data type
                helper,  # Helper for torch.distributed callbacks during bootstrapping
                tp_size,  # Tensor-parallel group size (may be different than local_size)
                tex.CommOverlapType.RS if is_reduce_scatter else tex.CommOverlapType.AG,
                num_max_streams=_NUM_MAX_UB_STREAMS,
                comm_cga_size=cga_size,
                num_comm_sm=num_sm,
                set_sm_margin=set_sm_margin,
                atomic_gemm=atomic_gemm,
                use_ce=use_ce,
                aggregate=aggregate,
                gemm_priority=gemm_priority,
                comm_priority=comm_priority,
            )
        else:
            ub_obj = tex.CommOverlap(
                shape,  # Communication buffer shape
                buffer_dtype,  # Communication buffer data type
                helper,  # Helper for torch.distributed callbacks during bootstrapping
                tp_size,  # Tensor-parallel group size (may be different than local_size)
                num_splits=num_splits,
                num_max_streams=_NUM_MAX_UB_STREAMS,
                comm_cga_size=cga_size,
                num_comm_sm=num_sm,
                set_sm_margin=set_sm_margin,
                atomic_gemm=atomic_gemm,
                gemm_priority=gemm_priority,
                comm_priority=comm_priority,
                rs_overlap_first_gemm=pipeline_rs_overlap_first_gemm,
            )
        _ub_communicators[name] = ub_obj

    if ub_cfgs is not None:
        # A dgrad layer configured with a non-bulk method is re-classified as a reduce-scatter
        # overlap, and its wgrad counterpart is dropped from the reduce-scatter list.
        for name in dgrad_reduce_scatter_overlap:
            if name in ub_cfgs and "method" in ub_cfgs[name] and ub_cfgs[name]["method"] != "bulk":
                wgrad_name = name.replace("dgrad", "wgrad")
                assert wgrad_name not in ub_cfgs, (
                    f"`{wgrad_name}` cannot be configured in `ub_cfgs` when `{name}` uses a "
                    "non-bulk overlap method."
                )
                layers_reduce_scatter_overlap.remove(wgrad_name)
                layers_all_gather_overlap.remove(name)
                layers_reduce_scatter_overlap.append(name)
                methods["bulk"].remove(name)
                new_method = ub_cfgs[name]["method"]
                methods[new_method].append(name)

    for name in methods["ring_exchange"] + methods["pipeline"] + methods["bulk"]:
        ub_cfg = get_default_config(name)
        if ub_cfgs is not None and name in ub_cfgs:
            # A user-supplied `fp8_buf` is honoured only for pipeline layers; all-gather
            # layers always get an FP8 buffer.
            fp8_buf = (name in layers_all_gather_overlap) or (
                ub_cfgs[name].get("fp8_buf", False) and name in methods["pipeline"]
            )
            ub_cfg.update(ub_cfgs[name])
            ub_cfg["fp8_buf"] = fp8_buf
        add_ub(name, **ub_cfg)
408
409
410
411
412
413
414
415


def get_ub(name: str):
    """Return the Userbuffers communicator registered under the given key.

    Fails an assertion when Userbuffers has not been initialized or when no
    communicator is registered under `name`.
    """
    registry = _ub_communicators
    assert registry is not None, "UB manager is not initialized."
    assert name in registry, f"UB for {name} is not registered."
    return registry[name]

416

417
418
419
420
421
422
423
def destroy_ub():
    """Drop every Userbuffers communicator and reset the related module state.

    After this call `_ub_communicators` is `None` and
    `layers_atomic_ring_exchange` is a fresh empty list.
    """
    global _ub_communicators, layers_atomic_ring_exchange
    _ub_communicators, layers_atomic_ring_exchange = None, []

424

425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
def fill_userbuffers_buffer_for_all_gather(
    comm,
    local_tensor: torch.Tensor,
    quantizer: Optional[Quantizer],
    process_group,
) -> tuple[torch.Tensor | QuantizedTensorBase, torch.Tensor | QuantizedTensorBase]:
    """Fill local shard of Userbuffers buffer with data for all-gather

    Returns ``(global_tensor, local_tensor)``. The global tensor is
    built on the buffer returned by ``comm.get_buffer``, so it should
    be used carefully (e.g. only immediately before and after a
    Userbuffers operation) since the underlying data may be
    overwritten by other Userbuffers operations. The returned local
    tensor is the input after any dequantize/quantize conversion.
    NOTE(review): earlier documentation stated that the local shard
    also aliases the Userbuffers buffer, but this code returns the
    converted input itself -- confirm the intended contract.

    May perform blocking communication if needed for the gathered
    tensor's metadata, e.g. scaling factors.

    Parameters
    ----------
    comm
        Userbuffers communicator providing ``is_fp8_ubuf``,
        ``copy_into_buffer`` and ``get_buffer``.
    local_tensor : torch.Tensor
        Local shard; its first dimension is multiplied by the process
        group size to obtain the gathered shape.
    quantizer : Quantizer, optional
        ``None`` for unquantized data, otherwise an FP8 (delayed or
        current scaling) or MXFP8 quantizer.
    process_group
        Process group the all-gather runs over.

    Raises
    ------
    ValueError
        If the local tensor has no dimensions, the quantizer type is
        unsupported, an MXFP8 quantizer does not request exactly one
        of row-wise/column-wise usage, or the MXFP8 leading dims are
        not divisible by 128.
    RuntimeError
        If the data precision does not match how the Userbuffers
        buffer was initialized.

    """

    # Tensor dimensions
    local_shape = local_tensor.size()
    if not local_shape:
        raise ValueError(f"Invalid local tensor (shape={tuple(local_shape)})")
    process_group_size = torch.distributed.get_world_size(process_group)
    global_shape = list(local_shape)
    global_shape[0] *= process_group_size

    # Unquantized data
    if quantizer is None:
        if isinstance(local_tensor, QuantizedTensorBase):
            local_tensor = local_tensor.dequantize()
        if comm.is_fp8_ubuf():
            raise RuntimeError(
                "Attempting to all-gather unquantized tensor, "
                "but Userbuffers is initialized with FP8 buffers"
            )
        comm.copy_into_buffer(local_tensor, local_chunk=True)
        global_tensor = comm.get_buffer(shape=global_shape)
        return global_tensor, local_tensor

    # FP8 data
    if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
        if not isinstance(local_tensor, Float8TensorBase):
            if isinstance(local_tensor, QuantizedTensorBase):
                # Keep the dequantized result: `dequantize()` returns a new tensor,
                # so the value must be rebound before re-quantizing.
                local_tensor = local_tensor.dequantize()
            quantizer.set_usage(rowwise=True, columnwise=False)
            local_tensor = quantizer(local_tensor)
        if not comm.is_fp8_ubuf():
            raise RuntimeError(
                "Attempting to all-gather FP8 tensor, "
                "but Userbuffers is not initialized with FP8 buffers"
            )
        comm.copy_into_buffer(local_tensor._data, local_chunk=True)
        global_tensor_data = comm.get_buffer(shape=global_shape)
        # The gathered tensor reuses the local scale-inverse and FP8 dtype
        global_tensor = Float8TensorBase(
            data=global_tensor_data,
            fp8_scale_inv=local_tensor._scale_inv,
            fp8_dtype=local_tensor._fp8_dtype,
            quantizer=quantizer,
        )
        return global_tensor, local_tensor

    # MXFP8 data
    if isinstance(quantizer, MXFP8Quantizer):

        # Cast to MXFP8 if needed
        if not isinstance(local_tensor, MXFP8TensorBase):
            if isinstance(local_tensor, QuantizedTensorBase):
                # Keep the dequantized result (see the FP8 branch above)
                local_tensor = local_tensor.dequantize()
            local_tensor = quantizer(local_tensor)
        if not comm.is_fp8_ubuf():
            raise RuntimeError(
                "Attempting to all-gather MXFP8 tensor, "
                "but Userbuffers is not initialized with FP8 buffers"
            )

        # Check which MXFP8 buffer to communicate
        if quantizer.rowwise_usage == quantizer.columnwise_usage:
            raise ValueError(
                "Userbuffers can only communicate one MXFP8 buffer at a time, "
                f"but quantizer has rowwise_usage={quantizer.rowwise_usage}, "
                f"columnwise_usage={quantizer.columnwise_usage}"
            )
        with_rowwise_data = quantizer.rowwise_usage

        # Copy MXFP8 data to local chunk of Userbuffers buffer
        local_data = (
            local_tensor._rowwise_data if with_rowwise_data else local_tensor._columnwise_data
        )
        comm.copy_into_buffer(local_data, local_chunk=True)

        # Gather scaling-inverses
        if math.prod(local_shape[:-1]) % 128 != 0:
            raise ValueError(
                "Userbuffers requires MXFP8 tensor dims that are divisible by 128, "
                f"but got MXFP8 tensor with shape={tuple(local_shape)}"
            )
        local_scale_inv = (
            local_tensor._rowwise_scale_inv
            if with_rowwise_data
            else local_tensor._columnwise_scale_inv
        )
        local_scale_inv_size = list(local_scale_inv.size())
        global_scale_inv = torch.empty(
            [process_group_size * local_scale_inv_size[0]] + local_scale_inv_size[1:],
            dtype=local_scale_inv.dtype,
            device=local_scale_inv.device,
        )
        torch.distributed.all_gather_into_tensor(
            global_scale_inv,
            local_scale_inv,
            group=process_group,
        )

        # Construct MXFP8 tensor with Userbuffers buffer
        rowwise_data, rowwise_scale_inv = None, None
        columnwise_data, columnwise_scale_inv = None, None
        global_data = comm.get_buffer(shape=global_shape)
        if with_rowwise_data:
            rowwise_data, rowwise_scale_inv = global_data, global_scale_inv
        else:
            columnwise_data, columnwise_scale_inv = global_data, global_scale_inv
        global_tensor = MXFP8TensorBase(
            rowwise_data=rowwise_data,
            rowwise_scale_inv=rowwise_scale_inv,
            columnwise_data=columnwise_data,
            columnwise_scale_inv=columnwise_scale_inv,
            fp8_dtype=local_tensor._fp8_dtype,
            quantizer=quantizer,
        )
        return global_tensor, local_tensor

    # Unsupported data format
    raise ValueError(f"Unsupported quantizer for Userbuffers ({quantizer})")


561
562
563
564
565
566
class TransformerEngineBaseModule(torch.nn.Module, ABC):
    """Base TE module."""

    def __init__(self) -> None:
        """Initialize the state shared by all Transformer Engine modules.

        Sets the FP8 flags and metadata to their disabled defaults,
        creates empty forward/backward quantizer tables, configures a
        single-rank tensor-parallel setup, and captures the FP8-parameter
        settings currently reported by ``FP8GlobalStateManager``. The
        debug state is initialized if debugging is not already enabled.
        """
        super().__init__()
        assert torch.cuda.is_available(), "TransformerEngine needs CUDA."
        self.name = None
        self.next_iter_when_debug_should_be_run = 0

        # FP8 state; filled in later by init_fp8_metadata()
        self.fp8_initialized = False
        self.fp8 = False
        self.fp8_calibration = False
        self.fp8_meta = {}
        self.fp8_meta["fp8_checkpoint"] = False
        self.fp8_meta["fp8_group"] = None
        self.fp8_meta_tensors_initialized = False
        self.quantizers = {"scaling_fwd": {}, "scaling_bwd": {}}

        # Tensor-parallel configuration (defaults to a single rank)
        self.tp_group = None
        self.tp_size = 1
        self.sequence_parallel = False

        # Parameter initialization metadata, keyed by parameter name
        # (see register_parameter and reset_parameters)
        self.param_init_meta = {}
        self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
        self.preserve_high_precision_init_val = FP8GlobalStateManager.with_high_precision_init_val()

        # FSDP state
        self.fsdp_wrapped = False
        self.fsdp_group = None

        # Cached FP8 workspaces
        self._fp8_workspaces: Dict[str, QuantizedTensor] = {}
        self.activation_dtype: Optional[torch.dtype] = None
        self.wgrad_accumulation_and_reduce_hooks = []

        if not TEDebugState.debug_enabled:
            TEDebugState.initialize()

592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
    # Names of attributes that can be set quickly: __setattr__ writes
    # them straight into the instance __dict__ instead of going through
    # torch.nn.Module.__setattr__ (see __setattr__ method).
    _fast_setattr_names: Set[str] = {
        "activation_dtype",
        "fp8",
        "fp8_initialized",
        "fp8_calibration",
        "fp8_parameters",
    }

    def __setattr__(self, name: str, value: Any) -> None:
        if name in TransformerEngineBaseModule._fast_setattr_names:
            # torch.nn.Module has a custom __setattr__ that handles
            # modules, parameters, and buffers. This is unnecessary
            # overhead when setting plain attrs.
            self.__dict__[name] = value
        else:
            # Default case
            super().__setattr__(name, value)
611

612
    def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None:
        """
        Delayed scaling only.

        Increase or decrease size of amax history based on given `length`.

        Parameters
        ----------
        length : int
            New size of the first dimension of the amax history.
        fwd : bool, optional
            `True` resizes only the forward history, `False` only the
            backward history, and `None` (default) resizes both.

        .. warning::
            This changes the underlying amax memory location.
        """
        if fwd is None:
            fp8_meta_tensor_keys = ("scaling_fwd", "scaling_bwd")
        else:
            fp8_meta_tensor_keys = ("scaling_fwd" if fwd else "scaling_bwd",)

        for meta_key in fp8_meta_tensor_keys:
            if meta_key not in self.fp8_meta:
                # Handles non-parameter FP8 modules, e.g. DPA.
                continue
            recipe_state = self.fp8_meta[meta_key]
            curr_len = recipe_state.amax_history.shape[0]
            if length == curr_len:
                continue
            if length < curr_len:
                # Keep the first `length` entries in a fresh tensor
                recipe_state.amax_history = recipe_state.amax_history[:length].clone()
            else:
                # Pad with the F.pad default fill value after the existing
                # entries of the second-to-last dimension
                extra_rows = length - curr_len
                recipe_state.amax_history = F.pad(
                    recipe_state.amax_history, pad=(0, 0, 0, extra_rows)
                )

            # Update quantizers with new amax pointers.
            self.quantizers[meta_key] = recipe_state.make_quantizers()
            # Make sure weight tensors has correct quantizers
            self._update_weight_quantizers()

            # Update the global buffers with new amax and history pointers.
            if FP8GlobalStateManager.get_buffer_info() in self.fp8_meta:
                fwd_pos, fwd_key, bwd_pos, bwd_key = self.fp8_meta[
                    FP8GlobalStateManager.get_buffer_info()
                ]
                # NOTE(review): both the forward and the backward buffer
                # entries are pointed at the history of `meta_key` (the
                # outer loop variable). Confirm this is intended rather
                # than pairing each buffer with its own direction.
                for pos, buffer_key in zip((fwd_pos, bwd_pos), (fwd_key, bwd_key)):
                    if buffer_key in FP8GlobalStateManager.global_amax_buffer:
                        assert (
                            buffer_key in FP8GlobalStateManager.global_amax_history_buffer
                        ), "TE internal error during amax history change."
                        FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = (
                            recipe_state.amax_history[0]
                        )
                        FP8GlobalStateManager.global_amax_history_buffer[buffer_key][pos] = (
                            recipe_state.amax_history
                        )
664

665
    def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
        """Init scales and amaxes for fwd | bwd.

        If the meta tensors were already initialized and the stored recipe
        state matches the type of `recipe`, the existing state is kept
        (for delayed scaling, only the amax history length is adjusted).
        Otherwise a new recipe state and its quantizers are created.

        Parameters
        ----------
        fwd : bool
            `True` for the forward state ("scaling_fwd"), `False` for the
            backward state ("scaling_bwd").
        recipe : Recipe
            Recipe used to create or validate the state.
        """
        fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd"

        # Return early if recipe state matches recipe
        if self.fp8_meta_tensors_initialized:
            recipe_state = self.fp8_meta[fp8_meta_tensor_key]
            if recipe.delayed() and isinstance(recipe_state, DelayedScalingRecipeState):
                self.adjust_amax_history_length(recipe.amax_history_len, fwd=fwd)
                return
            if recipe.mxfp8() and isinstance(recipe_state, MXFP8BlockScalingRecipeState):
                return
            if recipe.float8_current_scaling() and isinstance(
                recipe_state, Float8CurrentScalingRecipeState
            ):
                return
            if recipe.float8_block_scaling() and isinstance(
                recipe_state, Float8BlockScalingRecipeState
            ):
                return

        # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and
        # 2 (grad_output and grad_input) for bwd
        num_gemms = self.fp8_meta["num_gemms"]
        num_fp8_tensors = num_gemms * 3 if fwd else num_gemms * 2

        # Initialize recipe state and quantizers
        recipe_state = RecipeState.create(
            recipe,
            mode=("forward" if fwd else "backward"),
            num_quantizers=num_fp8_tensors,
        )

        self.fp8_meta[fp8_meta_tensor_key] = recipe_state
        self.quantizers[fp8_meta_tensor_key] = recipe_state.make_quantizers()

700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
    def _update_weight_quantizers(self) -> None:
        """Hand each quantized weight tensor its current quantizer.

        Weights that are not ``QuantizedTensorBase`` instances, and
        positions whose quantizer is ``None``, are left untouched.
        """
        tensors = self._get_weight_tensors()
        quantizers = self._get_weight_quantizers()
        num_tensors, num_quantizers = len(tensors), len(quantizers)
        assert num_tensors == num_quantizers, (
            f"Number of weight tensors ({num_tensors}) and quantizers "
            f"({num_quantizers}) must match"
        )
        for tensor, weight_quantizer in zip(tensors, quantizers):
            if weight_quantizer is None:
                continue
            if not isinstance(tensor, QuantizedTensorBase):
                continue
            tensor.update_quantizer(weight_quantizer)

    def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]:
        """Get the weight tensors of the module."""
        raise NotImplementedError(
            f"{self.__class__.__name__} class does not implement _get_weight_tensors function"
        )

    def _get_weight_quantizers(self) -> List[Quantizer]:
        """Get the weight quantizers of the module."""
        raise NotImplementedError(
            f"{self.__class__.__name__} class does not implement _get_weight_quantizers function"
        )

724
    def init_fp8_meta_tensors(self, recipe: Recipe) -> None:
        """Init scales and amaxes.

        Sets up the forward state first, then the backward state, and
        finally marks the meta tensors as initialized.

        Parameters
        ----------
        recipe : Recipe
            Recipe passed on to `set_meta_tensor` for both directions.
        """
        self.set_meta_tensor(True, recipe)
        self.set_meta_tensor(False, recipe)

        self.fp8_meta_tensors_initialized = True

731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
    def get_fp8_meta_tensors(self) -> Optional[Dict[str, List[torch.Tensor]]]:
        """Get scales and amaxes.

        Returns
        -------
        dict or None
            `None` if either "scaling_fwd" or "scaling_bwd" is missing from
            `self.fp8_meta`. Otherwise a dict with those two keys, each
            mapping to `[scale, amax_history]` where both entries are
            clones of the stored tensors.
        """
        keys = ("scaling_fwd", "scaling_bwd")
        if any(key not in self.fp8_meta for key in keys):
            return None

        with torch.no_grad():
            return {
                key: [
                    self.fp8_meta[key].scale.clone(),
                    self.fp8_meta[key].amax_history.clone(),
                ]
                for key in keys
            }

    def reset_fp8_meta_tensors(self, fp8_meta_tensors=None) -> None:
        """Reset scales and amaxes.

        Parameters
        ----------
        fp8_meta_tensors : dict, optional
            Mapping from "scaling_fwd"/"scaling_bwd" to a sequence whose
            first item is the scale and second item is the amax history
            to restore (the layout produced by `get_fp8_meta_tensors`).
            If `None`, scales are set to one and amax histories to zero.
        """

        def reset(key):
            # Directions without stored state are skipped
            if key not in self.fp8_meta:
                return
            recipe_state = self.fp8_meta[key]
            if fp8_meta_tensors is None:
                # In-place fills avoid allocating temporary tensors
                recipe_state.scale.fill_(1)
                recipe_state.amax_history.zero_()
            else:
                assert key in fp8_meta_tensors, "Cannot reset fp8 tensors."
                recipe_state.scale.copy_(fp8_meta_tensors[key][0])
                recipe_state.amax_history.copy_(fp8_meta_tensors[key][1])

        with torch.no_grad():
            reset("scaling_fwd")
            reset("scaling_bwd")

763
    def get_extra_state(self) -> torch.Tensor:
        """Save before checkpointing.

        Returns
        -------
        torch.Tensor
            An empty ``uint8`` tensor if no FP8 state needs to be saved.
            Otherwise a ``uint8`` tensor holding the pickled state dict:
            the recipe, the delayed-scaling scale/amax tensors (if the
            recipe is delayed scaling), and the simple-typed entries of
            ``self.fp8_meta`` under "extra_fp8_variables".
        """

        # This implementation is working around a few issues:
        #
        # (1) PyTorch's "extra state" infrastructure might be able to
        #     support any picklable type, but they make no guarantees.
        #     We have experienced problems (e.g. in ONNX export) with
        #     non-tensor extra state.
        # (2) PyTorch's checkpointing infrastructure does not remap
        #     devices for "extra state" like it does for "state dict".
        #     Thus, we want to avoid putting extra state on the GPU
        #     since it may be loaded on the wrong device.
        # (3) The extra state consists of many small tensors. If we
        #     want to copy them all to CPU, then we need to avoid the
        #     overhead of many GPU-CPU memory transfers.
        #
        # See: https://github.com/NVIDIA/TransformerEngine/pull/351
        # See: https://github.com/NVIDIA/TransformerEngine/pull/363

        def to_cpu(src: torch.Tensor) -> torch.Tensor:
            """Helper function to make CPU copy of tensor

            Memory transfer is asynchronous w.r.t. host, so GPU should
            be synchronized before using result.

            """
            dst = torch.empty_like(src, device="cpu")
            dst.copy_(src, non_blocking=True)
            return dst

        # Store FP8 state if needed
        fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration
        if not fp8_checkpoint:
            return torch.empty(0, dtype=torch.uint8)

        # Copy tensors to CPU and store
        state = {}
        state["recipe"] = self.fp8_meta["recipe"]
        if state["recipe"].delayed():
            state["scale_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].scale)
            state["amax_history_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].amax_history)
            state["scale_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].scale)
            state["amax_history_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].amax_history)

        # Store other picklable values
        extra = {}
        for k, v in self.fp8_meta.items():
            if k != "buffer_index_and_autocast_key" and isinstance(
                v, (bool, int, float, str, tuple, list)
            ):
                extra[k] = v
        state["extra_fp8_variables"] = extra

        # Serialize state into byte tensor. The synchronize makes sure
        # the non-blocking copies above have finished before pickling.
        torch.cuda.synchronize()
        state_serialized = bytearray(pickle.dumps(state))
        state_serialized = torch.frombuffer(state_serialized, dtype=torch.uint8)
        return state_serialized
823

824
    def set_extra_state(self, state: Optional[Union[torch.Tensor, io.BytesIO]]) -> None:
        """Load previous state.

        Parameters
        ----------
        state : torch.Tensor or io.BytesIO or None
            Extra state from a checkpoint. `None` and empty tensors mean
            there is nothing to load. A non-empty tensor is treated as
            pickled bytes; `io.BytesIO` is the deprecated format and is
            read with `torch.load`.

        Raises
        ------
        RuntimeError
            If `state` is of any other type.
        """

        # Maintain backwards compatibility with older checkpoints.
        if state is None:
            return

        # Load state
        if isinstance(state, torch.Tensor):
            # No FP8 is indicated by an empty tensor we don't need to unpickle.
            if state.numel() == 0:
                return
            # Default format: byte tensor with pickled data
            # NOTE(review): unpickling executes arbitrary code, so only
            # checkpoints from a trusted source should be loaded here.
            state = pickle.loads(state.detach().cpu().numpy().tobytes())
        elif isinstance(state, io.BytesIO):
            # Deprecated format with io.BytesIO
            state.seek(0)
            state = torch.load(state, map_location="cuda")
        else:
            raise RuntimeError("Unsupported checkpoint format.")

        if state is None:
            return

        # TE 1.x checkpoint compatibility: add DelayedScaling recipe if missing
        if "recipe" not in state:
            # TE 1.x only supported delayed scaling, which was the default recipe
            state["recipe"] = DelayedScaling()
            # TE 1.x also saved scale_inv, which is not needed with Recipe object
            state.pop("scale_inv_fwd", None)
            state.pop("scale_inv_bwd", None)

        # Load extra items
        self.fp8_meta.update(state["extra_fp8_variables"])
        self.fp8_meta["recipe"] = state["recipe"]
        if "global_fp8_buffer_pos_fwd_recompute" in self.fp8_meta:
            del self.fp8_meta["global_fp8_buffer_pos_fwd_recompute"]

        # Initialize before loading
        self.init_fp8_meta_tensors(self.fp8_meta["recipe"])

        def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None:
            """Helper function to copy tensor from CPU

            Memory transfer is asynchronous w.r.t. host, so GPU should
            be synchronized before using result.

            """
            dst.copy_(src, non_blocking=True)

        # Load tensors
        if self.fp8_meta["recipe"].delayed():
            copy_tensor(state["scale_fwd"], self.fp8_meta["scaling_fwd"].scale)
            copy_tensor(state["amax_history_fwd"], self.fp8_meta["scaling_fwd"].amax_history)
            copy_tensor(state["scale_bwd"], self.fp8_meta["scaling_bwd"].scale)
            copy_tensor(state["amax_history_bwd"], self.fp8_meta["scaling_bwd"].amax_history)
        torch.cuda.synchronize()
881
882
883
884
885

    def set_activation_dtype(self, inp: torch.Tensor) -> None:
        """Set activation data type for AMP.

        Inside a `torch.autocast` region the autocast dtype is used.
        Otherwise the dtype of `inp` is used, after checking that every
        parameter of the module has that same dtype.

        Parameters
        ----------
        inp : torch.Tensor
            Input tensor whose dtype is used outside of autocast.
        """
        # Native AMP (`torch.autocast`) gets highest priority
        if torch.is_autocast_enabled():
            self.activation_dtype = torch_get_autocast_gpu_dtype()
            return

        # All checks after this have already been performed once, thus skip
        if self.activation_dtype == inp.dtype:
            return

        dtype = inp.dtype
        for name, param in self.named_parameters():
            if param is not None:
                assert dtype == param.dtype, (
                    "Data types for parameters must match when outside of autocasted region. "
                    f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}"
                )
        self.activation_dtype = dtype
901
902

    def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
        """
        Set the tensor parallel group for the given
        module before executing the forward pass.

        Also marks the tensor parallel group as initialized, which
        `prepare_forward` checks when `tp_size` is greater than one.

        Parameters
        ----------
        tp_group : ProcessGroup, default = `None`
                  tensor parallel process group.
        """
        self.tp_group = tp_group
        self.tp_group_initialized = True

915
916
917
    def _get_fp8_params(self) -> Union[List[torch.Tensor], None]:
        """returns the FP8 weights."""
        fp8_params = []
918
        for param in self.parameters(recurse=False):
919
            if isinstance(param, QuantizedTensor) and param.requires_grad:
920
921
922
923
924
                fp8_params.append(param)
        if len(fp8_params) == 0:
            return None
        return fp8_params

925
926
    # This routine is shared across FP8 and FP8_calibration paths so should not actually
    # assume FP8 execution.
927
    def init_fp8_metadata(self, num_gemms: int = 1) -> None:
        """Initialize fp8 related metadata and tensors during fprop.

        Reads the current FP8 settings from `FP8GlobalStateManager`,
        (re)creates the recipe state when FP8 is newly enabled or the
        recipe changed, and clears cached FP8 workspaces if the recipe
        type changed compared to the previous call.

        Parameters
        ----------
        num_gemms : int, default = 1
            Stored in `self.fp8_meta["num_gemms"]`, which determines how
            many quantizers are created.
        """
        _original_recipe = self.fp8_meta.get("recipe", None)

        self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
        self.fp8 = FP8GlobalStateManager.is_fp8_enabled()
        self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
        fp8_enabled = self.fp8 or self.fp8_calibration
        self.fp8_meta["fp8_checkpoint"] = fp8_enabled

        if self.fp8_parameters or fp8_enabled:
            if (
                self.fp8_initialized
                and FP8GlobalStateManager.get_fp8_recipe() == self.fp8_meta["recipe"]
            ):
                # FP8 init has already been run and recipe is the same, don't do anything.
                return
            self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
        else:
            # If fp8 isn't enabled, turn off and return.
            self.fp8_initialized = False
            return

        if self.fp8_parameters and not self.fp8_initialized:
            self.fp8_meta["num_gemms"] = num_gemms
            self.init_fp8_meta_tensors(self.fp8_meta["recipe"])

        if fp8_enabled:
            # Set FP8 and other FP8 metadata
            self.fp8_meta["num_gemms"] = num_gemms
            self.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()

            # Set FP8_MAX per tensor according to recipe
            self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd
            self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd

            # Allocate scales and amaxes
            self.init_fp8_meta_tensors(self.fp8_meta["recipe"])
            self.fp8_initialized = True

            self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()

        # Warn and drop cached workspaces when the recipe class changed to
        # an unrelated one (neither class is a subclass of the other).
        _current_recipe = self.fp8_meta["recipe"]
        if _original_recipe is not None and not (
            issubclass(_current_recipe.__class__, _original_recipe.__class__)
            or issubclass(_original_recipe.__class__, _current_recipe.__class__)
        ):
            warnings.warn(
                f"Recipe type changed from {_original_recipe.__class__.__name__} "
                f"to {_current_recipe.__class__.__name__}. "
                "This may affect model behavior."
            )
            # Clear cached workspaces as they were created with the old recipe/quantizer type
            self._fp8_workspaces.clear()

982
983
984
985
986
    @contextmanager
    def prepare_forward(
        self,
        inp: torch.Tensor,
        num_gemms: int = 1,
        allow_non_contiguous: bool = False,
    ) -> Generator[torch.Tensor, None, None]:
        """Checks and prep for FWD.
        The context manager is needed because there isn't a way for a module to know
        if it's the last FP8 module in the forward autocast. It is useful
        to setup the forward aggregated amax reduction for every module
        just in case. The autocast exit will pick up the most recent one.

        Parameters
        ----------
        inp : torch.Tensor
            Input tensor of the forward pass.
        num_gemms : int, default = 1
            Forwarded to `init_fp8_metadata`.
        allow_non_contiguous : bool, default = `False`
            If `False`, a non-contiguous `inp` is made contiguous before
            it is yielded.

        Yields
        ------
        torch.Tensor
            The (possibly made contiguous) input tensor.
        """
        self.forwarded_at_least_once = True
        # Activation recomputation is used and this is the second forward phase.
        if self.fp8 and in_fp8_activation_recompute_phase():
            FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta)
        else:
            assert inp.is_cuda, "TransformerEngine needs CUDA."

            if self.tp_size > 1:
                assert self.tp_group_initialized, "TP group not initialized."

            self.set_activation_dtype(inp)
            self.init_fp8_metadata(num_gemms=num_gemms)
            self._check_weight_tensor_recipe_correspondence()

            if self.fp8 and self.sequence_parallel and self.fp8_meta["recipe"].delayed():
                assert self.fp8_meta["recipe"].reduce_amax, (
                    "Amax reduction across tensor parallel group is "
                    "necessary when using sequence parallelism with FP8."
                )

            if self.fp8 and not FP8GlobalStateManager.fp8_graph_capturing():
                FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(self.fp8_meta)

            # Activation recomputation is used and this is the first forward phase.
            if self.fp8 and self.training and is_fp8_activation_recompute_enabled():
                FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)

        with torch.cuda.nvtx.range(self.__class__.__name__ + " forward"):
            if not allow_non_contiguous and not inp.is_contiguous():
                inp = inp.contiguous()
            yield inp

        # Runs only when the body of the `with` statement exits normally
        if self.fp8 and in_fp8_activation_recompute_phase():
            FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048

    def set_nccl_overlap_warning_if_tp(self) -> None:
        """When using TP, the NCCL communication needs to be scheduled
        before the GEMM for there to be a guaranteed overlap. From the
        host side in TE, the comm calls are always launched first, but
        to ensure that the GEMM isn't scheduled first, the environment
        variable `CUDA_DEVICE_MAX_CONNECTIONS` needs to be set to 1 to
        force a single channel.

        Emits a warning if `tp_size` is not 1 and the environment
        variable is unset or not equal to 1.
        """
        if self.tp_size == 1:
            return
        num_cuda_work_queues = int(os.getenv("CUDA_DEVICE_MAX_CONNECTIONS", "0"))
        if num_cuda_work_queues != 1:
            # Adjacent literals are concatenated, so the first one must
            # end with a space to keep the words apart.
            warnings.warn(
                "To guarantee overlapping TP and SP collectives with the backward "
                "GEMMs, set environment variable CUDA_DEVICE_MAX_CONNECTIONS = 1"
            )

    @staticmethod
    def grad_output_preprocess(
1049
1050
1051
1052
        ctx,
        grad_output: torch.Tensor,
        row_parallel_mode: bool,
        quantizer: Optional[Quantizer],
1053
1054
1055
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Utility function for backward.
        Returns tuple in order (all optional/None based on training precion/recipe):
1056
1057
            R1: gathered `grad_output`.
            R2: bias gradient on R1.
1058
1059

        """
1060
1061
        grad_output = grad_output.reshape((-1, grad_output.shape[-1]))
        grad_output = grad_output.contiguous()
1062
1063
        gather_grad_output = row_parallel_mode and ctx.sequence_parallel

1064
        # Non-FP8 case: bgrad is fused with wgrad for this case.
1065
        if not ctx.fp8 and not ctx.debug:
1066
            if gather_grad_output:
1067
                if not ctx.ub_overlap_ag:  # Perform NCCL all-gather
1068
                    grad_output, _ = gather_along_first_dim(grad_output, ctx.tp_group)
1069
1070
1071
1072
1073
1074
1075
                else:  # Initialize Userbuffers all-gather
                    grad_output, _ = fill_userbuffers_buffer_for_all_gather(
                        ctx.ub_obj_gradout,
                        grad_output,
                        None,
                        ctx.tp_group,
                    )
1076
1077
1078
            return grad_output, None

        # FP8 with all-gather: unfused bgrad, fused cast + transpose
1079
        # Also supports debug quantization, which is handled inside gather_along_first_dim.
1080
1081
        if gather_grad_output:
            grad_bias = None
1082
            if ctx.use_bias:
1083
                grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
1084
            if ctx.ub_overlap_ag:
1085
1086
                # Quantize the gradient if needed
                if not isinstance(
1087
1088
1089
1090
1091
1092
1093
                    grad_output,
                    (
                        QuantizedTensor,
                        Float8TensorBase,
                        MXFP8TensorBase,
                        Float8BlockwiseQTensorBase,
                    ),
1094
1095
1096
1097
                ):
                    grad_output = quantizer(grad_output)

                # Copy into communication buffer, and replace original gradient with it
1098
1099
1100
1101
1102
1103
                grad_output, _ = fill_userbuffers_buffer_for_all_gather(
                    ctx.ub_obj_gradout,
                    grad_output,
                    quantizer,
                    ctx.tp_group,
                )
1104
            else:
1105
1106
1107
1108
                grad_output, _ = gather_along_first_dim(
                    grad_output,
                    ctx.tp_group,
                    quantizer=quantizer,
1109
                )
1110
            return grad_output, grad_bias
1111

1112
1113
1114
1115
1116
1117
1118
        # Debug without all-gather: unfused cast and bgrad
        # bgrad only if wgrad is in FP8, otherwise it is fused with wgrad and we return None
        if ctx.debug:
            grad_output_ = quantizer(grad_output)
            if (
                isinstance(
                    grad_output_.get_tensor(True),
1119
1120
1121
1122
1123
1124
                    (
                        QuantizedTensor,
                        Float8TensorBase,
                        MXFP8TensorBase,
                        Float8BlockwiseQTensorBase,
                    ),
1125
1126
1127
1128
1129
1130
1131
1132
1133
                )
                and ctx.use_bias
            ):
                grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
            else:
                grad_bias = None
            grad_output = grad_output_
            return grad_output, grad_bias

1134
1135
        # FP8 without all-gather: fused bgrad + cast + transpose
        grad_bias = None
1136
        if ctx.use_bias:
1137
1138
1139
1140
            if isinstance(
                grad_output,
                (QuantizedTensor, Float8TensorBase, MXFP8TensorBase, Float8BlockwiseQTensorBase),
            ):
1141
                grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0)
1142
            else:
1143
1144
1145
1146
1147
1148
1149
1150
1151
                if isinstance(quantizer, Float8BlockQuantizer):
                    # unfuse bgrad for now until cast_transpose + dgrad calculation is ready for Float8BlockQuantizer.
                    grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
                else:
                    grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer)
        if not isinstance(
            grad_output,
            (QuantizedTensor, Float8TensorBase, MXFP8TensorBase, Float8BlockwiseQTensorBase),
        ):
1152
1153
            grad_output = quantizer(grad_output)
        return grad_output, grad_bias
1154

1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
    def register_parameter(self, name, param, **kwargs):
        """
        Register a parameter and remember how to initialize it later.

        The parameter itself is registered through the parent class. The
        extra keyword arguments are packed into a `_ParameterInitMeta`
        and stored in `self.param_init_meta` under the parameter name,
        for use in deferred initialization.
        """
        super().register_parameter(name, param)
        init_meta = _ParameterInitMeta(**kwargs)
        self.param_init_meta[name] = init_meta

    def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
        """
        Reset all module parameters to initial values. Unless deferred initialization
        is specified, all parameters on a 'meta' device are also materialized on a real cuda
        device before the values are reset to initial.

        Parameters
        ----------
        defer_init : bool, default = `False`
            If `True`, return immediately without touching any parameter.
        """
        if defer_init:
            return

        for name, param in self.named_parameters(recurse=False):
            # Ensure parameter is on a real device
            if param.device == torch.device("meta"):
                param = torch.empty_like(param, device="cuda")

            # Initialize the parameter values on device
            init_fn = self.param_init_meta[name].init_fn
            get_rng_state_tracker = self.param_init_meta[name].get_rng_state_tracker
            if get_rng_state_tracker is None:
                init_fn(param)
            else:
                # Run the initializer inside a forked RNG state, using the
                # module's named tracker state when one is configured
                if hasattr(self, "rng_tracker_name") and self.rng_tracker_name:
                    with get_rng_state_tracker().fork(self.rng_tracker_name):
                        init_fn(param)
                else:
                    with get_rng_state_tracker().fork():
                        init_fn(param)

            # Wrap parameters in QuantizedTensor if needed
            fp8_meta_index = self.param_init_meta[name].fp8_meta_index
            high_precision_init_val = None
            if self.primary_weights_in_fp8 and fp8_meta_index is not None:

                # Keep high-precision values on CPU if needed
                if self.preserve_high_precision_init_val:
                    high_precision_init_val = param.detach().cpu()

                # Configure quantizer
                quantizer = self.quantizers["scaling_fwd"][fp8_meta_index]
                if quantizer is None:
                    raise RuntimeError("Weight quantizer has not been initialized")
                quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled())
                quantizer.internal = False

                # Quantize parameter
                param = quantizer(param)

            # Redo parameter wrap in case we broke it above
            # NOTE: Currently this can only be broken when primary weights are in Fp8 but
            #       re-applying the nn.Parameter() wrap is a no-op when the input is already
            #       a parameter so we always re-apply it just for extra safety.
            param = torch.nn.Parameter(param)

            # Keep high-precision values on CPU if needed
            if high_precision_init_val is not None:

                # - Master weights are initialized from model weights, if we use fp8 primary
                #   weights to initialize master weights, the numerical values of master weights
                #   are not consistent with the numerical values when we initialize them from
                #   bf16/fp16 weights.
                # - So we add a `_high_precision_init_val` attribute to each model weight to store
                #   the original bf16/fp16 weight on cpu before casting it to fp8. And users can
                #   use `get_high_precision_init_val` to get this cpu tensor.
                # - This cpu tensor is not needed once the master weight is initialized, so users
                #   should call `clear_high_precision_init_val` to remove it after master weight
                #   is initialized.

                def get(self):
                    if hasattr(self, "_high_precision_init_val"):
                        return self._high_precision_init_val
                    return None

                def clear(self):
                    if hasattr(self, "_high_precision_init_val"):
                        del self._high_precision_init_val

                param._high_precision_init_val = high_precision_init_val
                param.get_high_precision_init_val = MethodType(get, param)
                param.clear_high_precision_init_val = MethodType(clear, param)

            setattr(self, name, param)
1243

1244
1245
1246
    @abstractmethod
    def forward(self):
        """Run the module's forward pass.

        Declared abstract: this base class provides no implementation, so
        every concrete subclass needs to override it.
        """
1247

1248
    def get_weight_workspace(
        self,
        *,
        tensor: Optional[torch.Tensor] = None,
        quantizer: Optional[Quantizer] = None,
        cache_name: Optional[str] = None,
        update_workspace: bool = True,
        skip_update_flag: Optional[torch.Tensor] = None,
        fsdp_group: Optional[dist_group_type] = None,
        workspace_dtype: Optional[torch.dtype] = None,
    ) -> QuantizedTensor:
        """Get workspace buffer for weights and maybe update its values

        The workspace buffer may be cached for future function calls.

        Parameters
        ----------
        tensor : torch.Tensor, optional
            Values to copy into workspace. Required if the workspace
            is being constructed or updated.
        quantizer: Quantizer, optional
            Quantizer used to cast the weights. Required if the
            workspace is being constructed or updated.
        cache_name: str, optional
            Key for caching.
        update_workspace: bool, default = `True`
            Update workspace with values from `tensor`.
        skip_update_flag: torch.Tensor, optional
            GPU flag to skip updating the workspace. Takes precedence
            over `update_workspace` if provided.
        fsdp_group: dist_group_type, default = None
            FSDP process group that the weights are distributed over.
        workspace_dtype: torch.dtype, default = None
            If weight workspace contains high-precision tensor - for example
            for debug quantization, this is dtype of the tensor.

        Returns
        -------
        QuantizedTensor
            `tensor` itself if it is already a `QuantizedTensor`,
            otherwise the (possibly cached) workspace.

        Raises
        ------
        ValueError
            If the workspace must be constructed but `tensor` or
            `quantizer` is missing, or if it must be updated but
            `tensor` is missing.

        """

        # Handle case where weights are already quantized
        # Note: Make sure weights have required usages, but do not
        # destroy unnecessary usages since they may be used later.
        if isinstance(tensor, QuantizedTensor):
            # Without a quantizer there are no usage requirements to
            # enforce, so the tensor is returned untouched.
            if quantizer is not None:
                update_rowwise_usage = True if quantizer.rowwise_usage else None
                update_columnwise_usage = True if quantizer.columnwise_usage else None
                tensor.update_usage(
                    rowwise_usage=update_rowwise_usage,
                    columnwise_usage=update_columnwise_usage,
                )
            return tensor

        # Try getting workspace from cache
        out = None
        if cache_name is not None:
            out = self._fp8_workspaces.get(cache_name, None)

        # Reset cache if workspace is invalid
        if out is not None and quantizer is not None:
            reset_cache = False
            if isinstance(out, Float8TensorBase):
                if (
                    not is_non_tn_fp8_gemm_supported()
                    and quantizer.columnwise_usage
                    and out._transpose is None
                ):
                    reset_cache = True
            elif isinstance(out, MXFP8TensorBase):
                if quantizer.rowwise_usage and out._rowwise_data is None:
                    reset_cache = True
                elif quantizer.columnwise_usage and out._columnwise_data is None:
                    reset_cache = True
            # Cached workspace and quantizer must agree on debug vs. non-debug
            if isinstance(out, DebugQuantizedTensor) != isinstance(quantizer, DebugQuantizer):
                reset_cache = True
            if reset_cache:
                out = None
                # `out` came from the cache, so `cache_name` is a valid key here
                del self._fp8_workspaces[cache_name]

        # Gather cached Fp8 workspace if it's distributed
        # NOTE: FSDP sharding is supported only for Fp8 buffers and will not work
        #       for models initialized with Fp8 primary weights.
        if (
            out is not None
            and tensor is not None
            and fsdp_group is not None
            and out.data.shape != tensor.data.shape
        ):
            _fsdp_gather_tensors(fsdp_group, [tensor.data.shape], out)

        # Construct workspace if needed
        if out is None:
            if tensor is None or quantizer is None:
                raise ValueError(
                    "tensor and quantizer kwargs must be provided to construct FP8 workspace"
                )

            if cache_name is None:
                return quantizer.quantize(tensor, dtype=workspace_dtype)

            # Ensure the tensor in the cache is an instance of torch.Tensor,
            # as it persists beyond a single forward pass.
            # Setting internal=True would cause the data to be removed in prepare_for_saving(...).
            quantizer_internal = quantizer.internal
            quantizer.internal = False
            try:
                out = quantizer.quantize(tensor, dtype=workspace_dtype)
            finally:
                # Restore the caller's setting even if quantization fails
                quantizer.internal = quantizer_internal

            # Update cache
            self._fp8_workspaces[cache_name] = out
            return out

        # Update workspace if needed
        # A provided skip flag forces the update call; the flag itself is
        # forwarded to the quantize call below as the no-op flag.
        if skip_update_flag is not None:
            update_workspace = True
        if update_workspace:
            if tensor is None:
                raise ValueError("tensor kwarg must be provided to update FP8 workspace")
            if hasattr(out, "quantize_"):
                out.quantize_(tensor, noop_flag=skip_update_flag)
            else:
                tex.quantize(tensor, quantizer, out, skip_update_flag)
        return out
1367

1368
1369
1370
    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        """
        Load tensors and extra state (including fp8 metadata) for this module.

        Copying into fp8 tensors relies on the scale_inv stored in fp8_meta,
        so with self.primary_weights_in_fp8=True the extra state has to be
        restored *before* the tensors are copied, rather than afterwards as
        the default _load_from_state_dict does. When primary weights are not
        in fp8 no early restore is performed.
        """
        if self.primary_weights_in_fp8:
            suffix = torch.nn.modules.module._EXTRA_STATE_KEY_SUFFIX
            key = prefix + suffix
            # Restore fp8 metadata first, if the checkpoint carries it
            if key in state_dict:
                self.set_extra_state(state_dict[key])

        # Regular loading path
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )
1388

1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
    def register_wgrad_accumulation_and_reduce_hooks(self, wgrad_accumulation_and_reduce_hook):
        """
        Register a hook for manual weight-gradient accumulation and reduction.

        The hook is appended to this module's list of registered hooks;
        backward_dw() calls every registered hook, without arguments and in
        registration order, once the delayed weight gradients have been
        handled. Call this method before backward().
        """
        registered_hooks = self.wgrad_accumulation_and_reduce_hooks
        registered_hooks.append(wgrad_accumulation_and_reduce_hook)

1399
1400
1401
1402
1403
1404
1405
1406
    def backward_dw(self):
        """
        Run the delayed weight-gradient step.

        Does nothing unless a wgrad store is attached and it reports that
        weight-gradient computation is delayed. Otherwise pops the stored
        result, assigns weight/bias gradients that are still unset, and then
        calls the registered accumulation-and-reduce hooks.
        """
        store = self.wgrad_store
        if store is None or not store.delay_wgrad_compute():
            return

        with torch.cuda.nvtx.range(f"_{self.__class__.__name__}_wgrad"):
            (wgrad, bgrad), _ = store.pop()

            # Weight grad is only assigned here when accumulation is not fused
            if not self.fuse_wgrad_accumulation:
                weights = noop_cat(self._get_weight_tensors())
                if weights.grad is None:
                    weights.grad = wgrad.to(weights.dtype)

            if self.use_bias:
                biases = noop_cat([getattr(self, bias_name) for bias_name in self.bias_names])
                if biases.grad is None:
                    biases.grad = bgrad.to(biases.dtype)

            # Drop the local references before running the hooks
            del wgrad, bgrad

            for hook in self.wgrad_accumulation_and_reduce_hooks:
                hook()
1420

1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
    def is_debug_iter(self) -> bool:
        """
        Decide whether debug should be enabled for this layer right now.

        Returns False when debugging is globally disabled. On the first call
        of a new iteration the result depends on the iteration at which debug
        features asked to be run next; on later calls within the same
        iteration the global debug flag is returned as is.
        """
        enabled = TEDebugState.debug_enabled
        if not enabled:
            return False
        self._validate_name()

        current_iteration = TEDebugState.get_iteration()

        # First call in a new iteration: earlier iterations may have reported
        # that no feature will be active for this layer for several upcoming
        # iterations, so re-evaluate against the recorded next-run iteration.
        if current_iteration != getattr(self, "debug_last_iteration", None):
            next_run = self.next_iter_when_debug_should_be_run
            enabled = next_run is not None and current_iteration >= next_run

        self.debug_last_iteration = current_iteration
        return enabled

    def no_debug_features_active(self, quantizers):
        """
        Return True when no debug feature is active for this layer.

        Also records the next iteration at which debug should be run, since
        features may report that they stay disabled for this layer for
        several upcoming iterations. Raises RuntimeError if a feature is
        active while primary weights are in FP8.
        """
        feature_active = any_feature_enabled(quantizers)

        # Remember when debug needs to be looked at again for this layer
        self.next_iter_when_debug_should_be_run = next_iter_when_debug_should_be_run(quantizers)

        if feature_active and self.primary_weights_in_fp8:
            raise RuntimeError("FP8 weights are not supported in debug mode.")
        return not feature_active

1462
1463
1464
1465
1466
1467
    def _validate_name(self):
        """
        Validate name passed to the module.

        Called from is_debug_iter(). Per the original author's note, the
        check is deferred to run time because module names are assigned
        after the model is initialized in Megatron-LM.

        If a name is already set this is a no-op. Otherwise a warning is
        logged and a default name of the form "Layer_<layer count>" is
        assigned.
        """
        if self.name is not None:
            return

        assert TEDebugState.debug_enabled
        import nvdlfw_inspect.api as debug_api

        # The warning is passed as ONE message string. Previously a stray comma
        # split it into two positional arguments.
        # NOTE(review): the second positional parameter of log_message appears
        # to be the layer name, not message text - confirm against nvdlfw_inspect.
        debug_api.log_message(
            "Names are not provided to debug modules. "
            "Creating and using generic names. Pass names to debug modules for better"
            " insight. ",
            level=logging.WARNING,
        )
        self.name = f"Layer_{TEDebugState.get_layer_count()}"

1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
    def _check_weight_tensor_recipe_correspondence(self) -> None:
        """
        Check that every quantized weight tensor is compatible with the active recipe.
        This is invoked in the forward().

        Expected 1:1 correspondence between recipe types and tensor types:
        - DelayedScaling → Float8Tensor
        - Float8CurrentScaling → Float8Tensor
        - MXFP8BlockScaling → MXFP8Tensor
        - Float8BlockScaling → Float8BlockTensor

        Typical failure: the recipe is DelayedScaling (set in fp8_autocast()) while
        the weight tensor is MXFP8Tensor (MXFP8BlockScaling set in fp8_model_init()).

        Raises RuntimeError on the first mismatching weight.
        """
        # Nothing to verify outside of FP8 execution / calibration
        if not (self.fp8 or self.fp8_calibration):
            return
        # ... or when the module declares no weights
        if not getattr(self, "weight_names", None):
            return

        active_recipe = self.fp8_meta["recipe"]
        weights = [getattr(self, weight_name) for weight_name in self.weight_names]
        for weight_name, weight in zip(self.weight_names, weights):
            if not isinstance(weight, QuantizedTensorBase):
                continue
            weight_quantizer = weight._get_quantizer()
            if weight_quantizer is None:
                continue
            expected_recipe_cls = weight_quantizer._get_compatible_recipe()
            if expected_recipe_cls is None or isinstance(active_recipe, expected_recipe_cls):
                continue
            raise RuntimeError(
                f"Recipe mismatch for '{weight_name}': tensor supports recipe"
                f" {expected_recipe_cls.__name__}, but got {active_recipe.__class__.__name__}."
                " Please check the recipes assigned during fp8_model_init() and"
                " fp8_autocast() calls."
            )