base.py 46.9 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
#
# See LICENSE for license information.

"""Base modules and utilities for TransformerEngine PyTorch API"""
6
import io
7
8
9
10
import os
import pickle
import warnings
from abc import ABC, abstractmethod
11
from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union
12
from contextlib import contextmanager
13
from types import MethodType
14
15
16
17

import torch
import torch.nn.functional as F

18
import transformer_engine_torch as tex
19
20
from transformer_engine.common.recipe import Recipe

21
from ._common import _ParameterInitMeta
22
from ..fp8 import (
23
24
    MXFP8BlockScalingRecipeState,
    DelayedScalingRecipeState,
25
    Float8CurrentScalingRecipeState,
26
    FP8GlobalStateManager,
27
    RecipeState,
28
29
30
31
32
)
from ..distributed import (
    gather_along_first_dim,
    is_fp8_activation_recompute_enabled,
    in_fp8_activation_recompute_phase,
33
    _fsdp_gather_tensors,
34
35
)
from ..constants import dist_group_type
36
37
38
from ..tensor import QuantizedTensor, Quantizer
from ..tensor._internal.float8_tensor_base import Float8TensorBase
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
yuguo's avatar
yuguo committed
39
from torch.utils.cpp_extension import IS_HIP_EXTENSION
40

41
42
__all__ = ["initialize_ub", "destroy_ub"]


# Per-GEMM-kind flags consumed outside this chunk.
# NOTE(review): presumably select split ("2x") accumulation for FP8 GEMMs — confirm.
_2X_ACC_FPROP = False
_2X_ACC_DGRAD = True
_2X_ACC_WGRAD = True

# Lazily-populated workspace caches (see the get_*workspace() helpers below).
_multi_stream_cublas_workspace = []
_multi_stream_cublas_batchgemm_workspace = []
_cublas_workspace = None

# Userbuffers communicators keyed by GEMM layer name; None until initialize_ub().
_ub_communicators = None
# Max number of streams handed to each UB overlap object; ROCm builds use 2.
_NUM_MAX_UB_STREAMS = 2 if IS_HIP_EXTENSION else 3
# Stream priority range, queried lazily inside initialize_ub().
_MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = None, None
# Layers configured with atomic GEMM + ring-exchange overlap; reset by destroy_ub().
layers_atomic_ring_exchange = []


def get_cublas_workspace_size_bytes() -> int:
    """Return the per-stream cuBLAS/hipBLASLt workspace size in bytes.

    * ROCm (HIP) builds: 128 MiB.
    * CUDA devices with compute capability major >= 9 (Hopper and newer): 32 MiB.
    * All other CUDA devices: 4 MiB.

    On non-HIP builds this queries the properties of the current CUDA device,
    so a CUDA device must be available.
    """
    if IS_HIP_EXTENSION:
        # ROCm builds use a fixed, larger workspace regardless of the device.
        return 134_217_728  # 128 MiB
    if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9:
        return 33_554_432  # 32 MiB
    return 4_194_304  # 4 MiB


def get_workspace() -> torch.Tensor:
    """Return the shared cuBLAS workspace, allocating it on first use.

    The buffer is a flat ``uint8`` CUDA tensor of
    ``get_cublas_workspace_size_bytes()`` bytes, cached in the module-level
    ``_cublas_workspace``. initialize_ub() may later replace the cached tensor
    with an expanded one.
    """
    global _cublas_workspace
    if _cublas_workspace is not None:
        return _cublas_workspace
    num_bytes = get_cublas_workspace_size_bytes()
    _cublas_workspace = torch.empty(num_bytes, dtype=torch.uint8, device="cuda")
    return _cublas_workspace


75
76
77
78
def get_multi_stream_cublas_workspace() -> List[torch.Tensor]:
    """Return one cuBLAS workspace per multi-stream cuBLAS stream.

    On the first call, allocates ``tex._num_cublas_streams`` flat ``uint8`` CUDA
    buffers of ``get_cublas_workspace_size_bytes()`` bytes each and caches them
    in the module-level list. Later calls return the same list.
    """
    global _multi_stream_cublas_workspace
    if not _multi_stream_cublas_workspace:
        # Every stream gets the same size; query it once.
        num_bytes = get_cublas_workspace_size_bytes()
        # Extend in place so existing references to the module-level list see the buffers.
        _multi_stream_cublas_workspace.extend(
            torch.empty(num_bytes, dtype=torch.uint8, device="cuda")
            for _ in range(tex._num_cublas_streams)
        )
    return _multi_stream_cublas_workspace


def get_multi_stream_cublas_batchgemm_workspace() -> List[torch.Tensor]:
    """Return one small workspace per batched-GEMM stream.

    On the first call, allocates ``tex._num_cublas_batchgemm_streams`` flat
    ``uint8`` CUDA buffers of 128 bytes each and caches them in the
    module-level list. Later calls return the same list.
    """
    global _multi_stream_cublas_batchgemm_workspace
    if not _multi_stream_cublas_batchgemm_workspace:
        # NOTE(review): fixed 128-byte buffers, unlike get_cublas_workspace_size_bytes();
        # presumably sufficient for the batched-GEMM path — confirm.
        _multi_stream_cublas_batchgemm_workspace.extend(
            torch.empty(128, dtype=torch.uint8, device="cuda")
            for _ in range(tex._num_cublas_batchgemm_streams)
        )
    return _multi_stream_cublas_batchgemm_workspace


# Layers for which no Userbuffers communicator is created. With
# NVTE_DISABLE_FC2_DGRAD_OVERLAP=1, "fc2_dgrad" is excluded: initialize_ub()
# skips it and get_ub() returns None for it.
remove_ag_gemm_dgrad = (
    ["fc2_dgrad"] if bool(int(os.getenv("NVTE_DISABLE_FC2_DGRAD_OVERLAP", "0"))) else []
)
99

100
101
def initialize_ub(
    shape: list,
    tp_size: int,
    use_fp8: bool = False,
    dtype: torch.dtype = torch.bfloat16,
    ub_cfgs: Optional[dict] = None,
    bootstrap_backend: Union[str, torch.distributed.Backend] = None,
) -> None:
    r"""
    Initialize the Userbuffers communicator for overlapping tensor-parallel communications with
    GEMM compute in te.Linear, te.LayerNormLinear and te.LayerNormMLP modules.

    Parameters
    ----------
    shape : list
            shape of the communication buffer, typically set to be the same as the global shape of
            the input tensor to a te.TransformerLayer forward pass, with the sequence and batch
            dimensions collapsed together -- i.e.: `(sequence_length * batch_size, hidden_size)`
    tp_size : int
              number of GPUs in the tensor-parallel process group
    use_fp8 : bool = False
              allocate the communication buffer for FP8 GEMM inputs/outputs
    dtype : torch.dtype = torch.bfloat16
            non-FP8 data type of the communication buffer when `use_fp8 = False`
    ub_cfgs: dict = None
             Configuration dictionary with the structure
             ```
             {
                <gemm_name> : {
                    "method": <"ring_exchange", "pipeline" or "bulk">,
                    "is_reduce_scatter": bool,
                    "num_sm": int,
                    "cga_size": int,
                    "set_sm_margin": bool,
                    "num_splits": int,
                    "aggregate": bool,
                    "atomic_gemm": bool,
                    "use_ce": bool,
                    "fp8_buf": bool,
                    "comm_priority": int,
                    "gemm_priority": int,
                    "pipeline_rs_overlap_first_gemm": bool,
                }
             }
             ```
             for `te.TransformerLayer` GEMM layers in `["qkv_fprop", "qkv_dgrad", "qkv_wgrad",
             "proj_fprop", "proj_dgrad", "fc1_fprop", "fc1_dgrad", "fc1_wgrad", "fc2_fprop",
             "fc2_dgrad"]`. Keys omitted from a layer's entry keep their defaults.
             Giving "qkv_dgrad" or "fc1_dgrad" a non-"bulk" method turns it into a
             reduce-scatter overlap; the matching "*_wgrad" layer must then not appear in
             `ub_cfgs`. `fp8_buf` is always on for all-gather layers, is honored from
             `ub_cfgs` only for "pipeline" layers, and takes effect only when `use_fp8=True`.
             `aggregate` and `use_ce` apply only to "ring_exchange"; `num_splits` and
             `pipeline_rs_overlap_first_gemm` apply only to "pipeline" and "bulk".
    bootstrap_backend : str = None
                        `torch.distributed` communication backend for the all-gather, broadcast and
                        barrier collectives during Userbuffers initialization. Not all backends are
                        valid for every cluster configuration and distributed launch method even if
                        they are available in PyTorch. When left unset, the initialization prefers
                        to use the MPI backend, falling back first on Gloo and then NCCL if MPI is
                        not available. Setting `NVTE_UB_WITH_MPI=1` when building TE overrides this
                        option and always initializes Userbuffers with direct MPI calls in C++,
                        which also requires `MPI_HOME=/path/to/mpi/root` to be set at compile time.

    Environment variables
    ---------------------
    `UB_SKIPMC=1` allows initialization on devices without CUDA Multicast support.
    `NVTE_NO_PIPELINE_OVERLAP=1` makes "proj_fprop" and "fc2_fprop" default to
    "ring_exchange" instead of "pipeline". `NVTE_DISABLE_FC2_DGRAD_OVERLAP=1` skips
    creating the "fc2_dgrad" communicator.

    Raises
    ------
    AssertionError
        If communicators are already initialized, or a bootstrap/config check fails.
    ValueError
        If a non-reduce-scatter (AllGather) layer is configured with "pipeline".
    """
    if not tex.device_supports_multicast():
        assert bool(int(os.getenv("UB_SKIPMC", "0"))), (
            "CUDA device, driver and/or toolkit version does not support comm+GEMM overlap with "
            + "CUDA Multicast. Launch app with UB_SKIPMC=1 to try CUDA IPC instead."
        )

    global _ub_communicators
    assert _ub_communicators is None, "UB communicators are already initialized."
    _ub_communicators = {}

    if tex.ubuf_built_with_mpi():
        # We're bootstrapping with direct calls to MPI in Userbuffers code so we need to force
        # an MPI_Init() here by creating a new MPI process group...
        assert torch.distributed.is_mpi_available()
        _ = torch.distributed.new_group(backend="mpi")
        helper = tex.CommOverlapHelper()
    else:
        # Bootstrapping with torch.distributed API, so check backend and construct
        # intra/inter-node process groups...
        assert (
            torch.distributed.is_initialized()
        ), "torch.distributed must be initialized before Userbuffers"
        if bootstrap_backend is None:
            # Prefer MPI, then Gloo, and fall back to NCCL.
            bootstrap_backend = "nccl"
            if torch.distributed.is_mpi_available():
                bootstrap_backend = "mpi"
            elif torch.distributed.is_gloo_available():
                bootstrap_backend = "gloo"
        else:
            assert bootstrap_backend in [
                "gloo",
                "mpi",
                "nccl",
            ], "Invalid torch.distributed backend for bootstrapping Userbuffers!"
            assert torch.distributed.is_backend_available(bootstrap_backend), (
                f"PyTorch must be compiled with '{bootstrap_backend}' support in order to "
                f"bootstrap Userbuffers with '{bootstrap_backend}' collectives."
            )

        world_group = torch.distributed.new_group(backend=bootstrap_backend)
        world_rank = torch.distributed.get_rank(world_group)
        world_size = torch.distributed.get_world_size(world_group)

        # Ranks are grouped into contiguous TP domains of `tp_size` ranks each.
        num_domains = world_size // tp_size
        mydomain_idx = world_rank // tp_size
        if num_domains > 1:
            ranks_per_domain_list = [
                [i * tp_size + t for t in range(tp_size)] for i in range(num_domains)
            ]
            tp_domain_group, _ = torch.distributed.new_subgroups_by_enumeration(
                ranks_per_domain_list, backend=bootstrap_backend
            )
            local_rank = torch.distributed.get_rank(tp_domain_group)
            tp_domain_ranks = torch.distributed.get_process_group_ranks(tp_domain_group)

            helper = tex.CommOverlapHelper(world_group, tp_domain_group)
        else:
            # TP model on single NVLink domain, no replication, no data-parallelism
            mydomain_idx = 0
            local_rank = world_rank
            tp_domain_ranks = list(range(world_size))

            helper = tex.CommOverlapHelper(world_group)

        if world_rank == 0:
            print(f"!!! [UB] Number of TP domains: {num_domains}\n", end="", flush=True)
        if local_rank == 0:
            print(
                f"!!! [UB] Global ranks on TP domain {mydomain_idx}: {tp_domain_ranks}\n",
                end="",
                flush=True,
            )

    # Allocate cuBLAS workspace with expanded size for chunking in overlapping GEMM calls
    global _cublas_workspace
    if _cublas_workspace is None:
        _cublas_workspace = get_workspace().repeat(_NUM_MAX_UB_STREAMS)
    elif _cublas_workspace.numel() != get_cublas_workspace_size_bytes() * _NUM_MAX_UB_STREAMS:
        # This ensures we don't do `.repeat()` on an already expanded workspace
        _cublas_workspace = torch.empty(
            get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda"
        ).repeat(_NUM_MAX_UB_STREAMS)

    # Default buffer precision: AllGather buffers use fp8 when using fp8 recipe
    layers_all_gather_overlap = [
        "qkv_fprop",
        "qkv_dgrad",
        "proj_dgrad",
        "fc1_fprop",
        "fc1_dgrad",
        "fc2_dgrad",
    ]
    layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"]
    # dgrad layers that may be switched from bulk AG to RS overlap via `ub_cfgs`.
    dgrad_reduce_scatter_overlap = ["qkv_dgrad", "fc1_dgrad"]
    # Default overlap methods for layers
    if bool(int(os.getenv("NVTE_NO_PIPELINE_OVERLAP", "0"))):
        # The pipelined RS layers use ring-exchange instead.
        methods = {
            "ring_exchange": [
                "qkv_fprop",
                "fc1_fprop",
                "proj_dgrad",
                "fc2_dgrad",
                "proj_fprop",
                "fc2_fprop",
            ],
            "pipeline": [],
            "bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"],
        }
    else:
        methods = {
            "ring_exchange": ["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad"],
            "pipeline": ["proj_fprop", "fc2_fprop"],
            "bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"],
        }

    # AG-RS overlap pairs of layers forming a tensor-parallel block
    ag_rs_pairs = {"qkv_fprop": "proj_fprop", "fc1_fprop": "fc2_fprop"}
    rs_ag_pairs = {v: k for k, v in ag_rs_pairs.items()}
    global layers_atomic_ring_exchange
    layers_atomic_ring_exchange = []

    def get_method(name):
        """Return the overlap method currently assigned to layer `name`."""
        for method, names in methods.items():
            if name in names:
                return method
        raise KeyError(f"Given layer name {name} does not exist.")

    def get_default_config(name):
        """Build the default add_ub() keyword arguments for layer `name`."""
        global _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY
        method = get_method(name)
        is_reduce_scatter = name in layers_reduce_scatter_overlap
        if _MIN_STREAM_PRIORITY is None or _MAX_STREAM_PRIORITY is None:
            _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = tex.get_stream_priority_range()
        default_cfg = {
            "method": method,
            "is_reduce_scatter": is_reduce_scatter,
            "num_sm": 1 if method == "ring_exchange" else 8,
            "cga_size": 1 if method == "ring_exchange" else 2,
            "set_sm_margin": method != "ring_exchange",
            "num_splits": tp_size if method == "ring_exchange" else 4,
            "aggregate": False,
            "atomic_gemm": False,
            "use_ce": True,
            "fp8_buf": name in layers_all_gather_overlap,
            "comm_priority": _MAX_STREAM_PRIORITY,
            "gemm_priority": _MIN_STREAM_PRIORITY,
            "pipeline_rs_overlap_first_gemm": False,
        }
        return default_cfg

    def add_ub(
        name: str,
        method: str,
        is_reduce_scatter: bool,
        num_sm: int = 16,
        cga_size: int = 2,
        set_sm_margin: bool = False,
        num_splits: int = 0,
        aggregate: bool = False,
        atomic_gemm: bool = False,
        use_ce: bool = True,
        fp8_buf: bool = False,
        comm_priority: int = 0,
        gemm_priority: int = 0,
        pipeline_rs_overlap_first_gemm: bool = False,
    ) -> None:
        """Validate one layer's config and register its communicator in `_ub_communicators`."""
        if atomic_gemm:
            warnings.warn(
                "Atomic GEMM uses a beta API from cublas and is not tested for all use cases."
            )
            assert use_fp8, "Atomic GEMM overlap supported only for FP8 GEMM."
            if method == "bulk":
                warnings.warn(
                    f"At {name}, atomic GEMM is not supported for a bulk overlap. "
                    "Defaulting to `atomic_gemm=False`."
                )
                atomic_gemm = False
        if not is_reduce_scatter and method == "pipeline":
            raise ValueError(
                f"At {name}, `pipeline` overlap method is not supported for AllGather."
            )
        # Check if both AG and RS overlaps use `atomic GEMM` + `p2p ring-exchange`.
        # Using atomic GEMM + p2p ring-exchange in only one of the pair breaks functionality.
        global layers_atomic_ring_exchange
        if atomic_gemm and method == "ring_exchange" and name in ag_rs_pairs:
            layers_atomic_ring_exchange += [name, ag_rs_pairs[name]]
        if name in rs_ag_pairs:
            assert_message = (
                f"At {name}, atomic AG-GEMM overlap with `ring_exchange` shuffles GEMM chunk "
                "outputs, and RS-GEMM overlap un-shuffles them. When one of the GEMM-AG and "
                "GEMM-RS overlaps forming a TP block (e.g., qkv_fprop and proj_fprop) uses "
                "`atomic gemm` and `ring_exchange`, its pair must use the same overlap config "
                "for functionality."
            )
            if name in layers_atomic_ring_exchange:
                assert atomic_gemm and method == "ring_exchange", assert_message
            elif atomic_gemm and method == "ring_exchange":
                assert rs_ag_pairs[name] in layers_atomic_ring_exchange, assert_message

        buffer_dtype = torch.uint8 if (use_fp8 and fp8_buf) else dtype
        if method == "ring_exchange":
            ub_obj = tex.CommOverlapP2P(
                shape,  # Communication buffer shape
                buffer_dtype,  # Communication buffer data type
                helper,  # Helper for torch.distributed callbacks during bootstrapping
                tp_size,  # Tensor-parallel group size (may be different than local_size)
                tex.CommOverlapType.RS if is_reduce_scatter else tex.CommOverlapType.AG,
                num_max_streams=_NUM_MAX_UB_STREAMS,
                comm_cga_size=cga_size,
                num_comm_sm=num_sm,
                set_sm_margin=set_sm_margin,
                atomic_gemm=atomic_gemm,
                use_ce=use_ce,
                aggregate=aggregate,
                gemm_priority=gemm_priority,
                comm_priority=comm_priority,
            )
        else:
            ub_obj = tex.CommOverlap(
                shape,  # Communication buffer shape
                buffer_dtype,  # Communication buffer data type
                helper,  # Helper for torch.distributed callbacks during bootstrapping
                tp_size,  # Tensor-parallel group size (may be different than local_size)
                num_splits=num_splits,
                num_max_streams=_NUM_MAX_UB_STREAMS,
                comm_cga_size=cga_size,
                num_comm_sm=num_sm,
                set_sm_margin=set_sm_margin,
                atomic_gemm=atomic_gemm,
                gemm_priority=gemm_priority,
                comm_priority=comm_priority,
                rs_overlap_first_gemm=pipeline_rs_overlap_first_gemm,
            )
        _ub_communicators[name] = ub_obj

    if ub_cfgs is not None:
        # A dgrad layer given a non-bulk method becomes a ReduceScatter overlap, and its
        # wgrad layer is no longer treated as a reduce-scatter overlap.
        for name in dgrad_reduce_scatter_overlap:
            if name in ub_cfgs and "method" in ub_cfgs[name] and ub_cfgs[name]["method"] != "bulk":
                wgrad_name = name.replace("dgrad", "wgrad")
                assert wgrad_name not in ub_cfgs, (
                    f"{wgrad_name} must not be configured when {name} uses a non-bulk overlap."
                )
                layers_reduce_scatter_overlap.remove(wgrad_name)
                layers_all_gather_overlap.remove(name)
                layers_reduce_scatter_overlap.append(name)
                methods["bulk"].remove(name)
                new_method = ub_cfgs[name]["method"]
                methods[new_method].append(name)

    for name in methods["ring_exchange"] + methods["pipeline"] + methods["bulk"]:
        if name in remove_ag_gemm_dgrad:
            continue
        ub_cfg = get_default_config(name)
        if ub_cfgs is not None and name in ub_cfgs:
            # fp8_buf: forced on for AG layers; user value honored only for pipeline layers.
            fp8_buf = (name in layers_all_gather_overlap) or (
                ub_cfgs[name].get("fp8_buf", False) and name in methods["pipeline"]
            )
            ub_cfg.update(ub_cfgs[name])
            ub_cfg["fp8_buf"] = fp8_buf
        add_ub(name, **ub_cfg)
407
408
409
410
411


def get_ub(name: str):
    """Get the Userbuffers communicator registered under the given key.

    Returns ``None`` for layers listed in ``remove_ag_gemm_dgrad``, because
    initialize_ub() does not create a communicator for them.

    Raises
    ------
    AssertionError
        If initialize_ub() has not been called.
    KeyError
        If no communicator is registered for ``name``.
    """
    assert _ub_communicators is not None, "UB manager is not initialized."
    if name in remove_ag_gemm_dgrad:
        return None
    try:
        return _ub_communicators[name]
    except KeyError:
        raise KeyError(f"UB for {name} is not registered.") from None

417

418
419
420
421
422
423
424
def destroy_ub():
    """Drop every Userbuffers communicator and clear the atomic ring-exchange registry.

    After this call, initialize_ub() may be called again.
    """
    global _ub_communicators, layers_atomic_ring_exchange
    _ub_communicators = None
    layers_atomic_ring_exchange = []

425
426
427
428
429
430
431
432
433
434
435

class TransformerEngineBaseModule(torch.nn.Module, ABC):
    """Base TE module."""

    def __init__(self) -> None:
        """Initialize default (FP8 off, single-rank) module state.

        Raises
        ------
        AssertionError
            If CUDA is not available.
        """
        super().__init__()
        assert torch.cuda.is_available(), "TransformerEngine needs CUDA."
        # FP8 flags (written through the fast path in __setattr__).
        self.fp8_initialized = False
        self.fp8 = False
        self.fp8_calibration = False
        # FP8 metadata; recipe states are added later under "scaling_fwd"/"scaling_bwd".
        self.fp8_meta = {"fp8_checkpoint": False, "fp8_group": None}
        self.fp8_meta_tensors_initialized = False
        self.quantizers = {"scaling_fwd": {}, "scaling_bwd": {}}
        # Tensor-parallel defaults: no group, single rank.
        self.tp_group = None
        self.tp_size = 1
        self.sequence_parallel = False
        self.param_init_meta = {}
        # Global FP8 parameter settings, read once at construction time.
        self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
        self.preserve_high_precision_init_val = FP8GlobalStateManager.with_high_precision_init_val()
        self.fsdp_wrapped = False
        self.fsdp_group = None
        self._fp8_workspaces: Dict[str, QuantizedTensor] = {}
        self.activation_dtype: Optional[torch.dtype] = None

    # Names of attributes that can be set quickly (see __setattr__
    # method)
    _fast_setattr_names: Set[str] = {
        "activation_dtype",
        "fp8",
        "fp8_initialized",
        "fp8_calibration",
        "fp8_parameters",
    }

    def __setattr__(self, name: str, value: Any) -> None:
        if name in TransformerEngineBaseModule._fast_setattr_names:
            # torch.nn.Module has a custom __setattr__ that handles
            # modules, parameters, and buffers. This is unnecessary
            # overhead when setting plain attrs.
            self.__dict__[name] = value
        else:
            # Default case
            super().__setattr__(name, value)
470

471
    def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None:
472
473
474
475
        """
        Delayed scaling only.

        Increase or decrease size of amax history based on given `length`.
476
477
478
479
480
481
482
483
484
485

        .. warning::
            This changes the underlying amax memory location.
        """
        if fwd is None:
            fp8_meta_tensor_keys = ("scaling_fwd", "scaling_bwd")
        else:
            fp8_meta_tensor_keys = ("scaling_fwd" if fwd else "scaling_bwd",)

        for meta_key in fp8_meta_tensor_keys:
486
487
488
            if meta_key not in self.fp8_meta:
                # Handles non-parameter FP8 modules, e.g. DPA.
                continue
489
490
491
492
493
            curr_len = self.fp8_meta[meta_key].amax_history.shape[0]
            if length == curr_len:
                continue
            if length < curr_len:
                self.fp8_meta[meta_key].amax_history = (
494
495
                    self.fp8_meta[meta_key].amax_history[:length].clone()
                )
496
497
498
499
500
501
            elif length > curr_len:
                extra_rows = length - curr_len
                self.fp8_meta[meta_key].amax_history = F.pad(
                    self.fp8_meta[meta_key].amax_history, pad=(0, 0, 0, extra_rows)
                )

502
503
504
            # Update quantizers with new amax pointers.
            self.quantizers[meta_key] = self.fp8_meta[meta_key].make_quantizers()

505
506
            # Update the global buffers with new amax and history pointers.
            if FP8GlobalStateManager.get_buffer_info() in self.fp8_meta:
507
508
509
                fwd_pos, fwd_key, bwd_pos, bwd_key = self.fp8_meta[
                    FP8GlobalStateManager.get_buffer_info()
                ]
510
511
512
513
514
                for pos, buffer_key in zip((fwd_pos, bwd_pos), (fwd_key, bwd_key)):
                    if buffer_key in FP8GlobalStateManager.global_amax_buffer:
                        assert (
                            buffer_key in FP8GlobalStateManager.global_amax_history_buffer
                        ), "TE internal error during amax history change."
515
516
517
                        FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = self.fp8_meta[
                            meta_key
                        ].amax_history[0]
518
                        FP8GlobalStateManager.global_amax_history_buffer[buffer_key][pos] = (
519
520
                            self.fp8_meta[meta_key].amax_history
                        )
521

522
    def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
        """Init scales and amaxes for fwd | bwd.

        Creates the forward (`fwd=True`) or backward recipe state for `recipe` and
        rebuilds the matching quantizers. If meta tensors are already initialized
        with a state of the same recipe kind, that state is kept; for delayed
        scaling only its amax history length is adjusted.
        """
        fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd"

        # Return early if recipe state matches recipe
        if self.fp8_meta_tensors_initialized:
            recipe_state = self.fp8_meta[fp8_meta_tensor_key]
            if recipe.delayed() and isinstance(recipe_state, DelayedScalingRecipeState):
                self.adjust_amax_history_length(recipe.amax_history_len, fwd=fwd)
                return
            if recipe.mxfp8() and isinstance(recipe_state, MXFP8BlockScalingRecipeState):
                return
            if recipe.float8_current_scaling() and isinstance(
                recipe_state, Float8CurrentScalingRecipeState
            ):
                return

        # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and
        # 2 (grad_output and grad_input) for bwd
        tensors_per_gemm = 3 if fwd else 2
        num_fp8_tensors = self.fp8_meta["num_gemms"] * tensors_per_gemm

        # Initialize recipe state and quantizers
        recipe_state = RecipeState.create(
            recipe,
            mode=("forward" if fwd else "backward"),
            num_quantizers=num_fp8_tensors,
        )

        self.fp8_meta[fp8_meta_tensor_key] = recipe_state
        self.quantizers[fp8_meta_tensor_key] = recipe_state.make_quantizers()

    def init_fp8_meta_tensors(self, recipe: Recipe) -> None:
554
        """Init scales and amaxes."""
555
556
557
        self.set_meta_tensor(True, recipe)
        self.set_meta_tensor(False, recipe)

558
559
        self.fp8_meta_tensors_initialized = True

560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
    def get_fp8_meta_tensors(self) -> None:
        """Get scales and amaxes."""
        fwd_key, bwd_key = "scaling_fwd", "scaling_bwd"
        if fwd_key not in self.fp8_meta or bwd_key not in self.fp8_meta:
            return None

        fp8_meta_tensors = {fwd_key: [], bwd_key: []}
        with torch.no_grad():
            for key in (fwd_key, bwd_key):
                fp8_meta_tensors[key].append(self.fp8_meta[key].scale.clone())
                fp8_meta_tensors[key].append(self.fp8_meta[key].amax_history.clone())
        return fp8_meta_tensors

    def reset_fp8_meta_tensors(self, fp8_meta_tensors=None) -> None:
        """Reset scales and amaxes."""
575

576
577
578
579
580
        def reset(key):
            if key in self.fp8_meta:
                if fp8_meta_tensors is None:
                    self.fp8_meta[key].scale.copy_(torch.ones_like(self.fp8_meta[key].scale))
                    self.fp8_meta[key].amax_history.copy_(
581
582
                        torch.zeros_like(self.fp8_meta[key].amax_history)
                    )
583
584
585
                else:
                    assert key in fp8_meta_tensors, "Cannot reset fp8 tensors."
                    self.fp8_meta[key].scale.copy_(fp8_meta_tensors[key][0])
586
                    self.fp8_meta[key].amax_history.copy_(fp8_meta_tensors[key][1])
587

588
589
590
591
        with torch.no_grad():
            reset("scaling_fwd")
            reset("scaling_bwd")

592
593
    def get_extra_state(self) -> torch.Tensor:
        """Save before checkpointing."""
594

595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
        # This implementation is working around a few issues:
        #
        # (1) PyTorch's "extra state" infrastructure might be able to
        #     support any picklable type, but they make no guarantees.
        #     We have experienced problems (e.g. in ONNX export) with
        #     non-tensor extra state.
        # (2) PyTorch's checkpointing infrastructure does not remap
        #     devices for "extra state" like it does for "state dict".
        #     Thus, we want to avoid putting extra state on the GPU
        #     since it may be loaded on the wrong device.
        # (3) The extra state consists of many small tensors. If we
        #     want to copy them all to CPU, then we need to avoid the
        #     overhead of many GPU-CPU memory transfers.
        #
        # See: https://github.com/NVIDIA/TransformerEngine/pull/351
        # See: https://github.com/NVIDIA/TransformerEngine/pull/363

        def to_cpu(src: torch.Tensor) -> torch.Tensor:
            """Helper function to make CPU copy of tensor

            Memory transfer is asynchronous w.r.t. host, so GPU should
            be synchronized before using result.

            """
            dst = torch.empty_like(src, device="cpu")
            dst.copy_(src, non_blocking=True)
            return dst

        # Store FP8 state if needed
        state = None
625
        fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration
626
        if fp8_checkpoint:
627
628

            # Copy tensors to CPU and store
629
            state = {}
630
631
632
633
634
635
            state["recipe"] = self.fp8_meta["recipe"]
            if state["recipe"].delayed():
                state["scale_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].scale)
                state["amax_history_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].amax_history)
                state["scale_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].scale)
                state["amax_history_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].amax_history)
636
637

            # Store other pickelable values
638
639
            extra = {}
            for k, v in self.fp8_meta.items():
640
641
642
                if k != "buffer_index_and_autocast_key" and isinstance(
                    v, (bool, int, float, str, tuple, list)
                ):
643
644
645
                    extra[k] = v
            state["extra_fp8_variables"] = extra

646
647
648
649
        # Serialize state into byte tensor
        torch.cuda.synchronize()
        state_serialized = bytearray(pickle.dumps(state))
        state_serialized = torch.frombuffer(state_serialized, dtype=torch.uint8)
650
        return state_serialized
651
652
653
654
655
656

    def set_extra_state(self, state: torch.Tensor) -> None:
        """Load previously saved FP8 extra state into this module.

        Parameters
        ----------
        state : torch.Tensor, io.BytesIO or None
            Serialized state. The default format is a byte tensor holding
            pickled data; ``io.BytesIO`` (read with ``torch.load``) is accepted
            as a deprecated format. ``None`` is a no-op.

        Raises
        ------
        RuntimeError
            If ``state`` is neither ``None``, a tensor, nor ``io.BytesIO``.
        """
        if state is None:
            return

        # Load state
        # NOTE(review): pickle.loads / torch.load can execute arbitrary code;
        # checkpoints are assumed to come from a trusted source.
        if isinstance(state, torch.Tensor):
            # Default format: byte tensor with pickled data
            state = pickle.loads(state.detach().cpu().numpy().tobytes())
        elif isinstance(state, io.BytesIO):
            # Deprecated format with io.BytesIO
            state.seek(0)
            state = torch.load(state, map_location="cuda")
        else:
            raise RuntimeError("Unsupported checkpoint format.")

        # The deserialized payload itself may be None (no FP8 metadata saved)
        if state is None:
            return

        # Load extra items
        self.fp8_meta.update(state["extra_fp8_variables"])
        self.fp8_meta["recipe"] = state["recipe"]
        # Discard any recompute buffer position (it may arrive via the extras above)
        self.fp8_meta.pop("global_fp8_buffer_pos_fwd_recompute", None)

        # Initialize before loading
        self.init_fp8_meta_tensors(self.fp8_meta["recipe"])

        # Load tensors. The host-to-device copies are issued non-blocking, so
        # the device is synchronized before the destination tensors are used.
        if self.fp8_meta["recipe"].delayed():
            for direction in ("fwd", "bwd"):
                scaling = self.fp8_meta[f"scaling_{direction}"]
                scaling.scale.copy_(state[f"scale_{direction}"], non_blocking=True)
                scaling.amax_history.copy_(
                    state[f"amax_history_{direction}"], non_blocking=True
                )
        torch.cuda.synchronize()

    def set_activation_dtype(self, inp: torch.Tensor) -> None:
        """Set ``self.activation_dtype`` for the current forward pass.

        Under native AMP (``torch.autocast``) the autocast GPU dtype wins.
        Otherwise the input's dtype is used, after checking that every
        parameter of the module has that same dtype.

        Raises
        ------
        AssertionError
            Outside autocast, if a parameter's dtype differs from ``inp.dtype``.
        """
        # Native AMP (`torch.autocast`) gets highest priority
        if torch.is_autocast_enabled():
            self.activation_dtype = torch.get_autocast_gpu_dtype()
            return

        # All checks after this have already been performed once, thus skip
        dtype = inp.dtype
        if self.activation_dtype == dtype:
            return

        for name, param in self.named_parameters():
            if param is not None:
                assert dtype == param.dtype, (
                    "Data types for parameters must match when outside of autocasted region. "
                    f"Found input dtype: {dtype} and {name!r} dtype: {param.dtype}"
                )
        self.activation_dtype = dtype

    def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
        """
        Set the tensor parallel group for the given
        module before executing the forward pass.

        Also marks the group as initialized, which ``prepare_forward`` checks
        when ``tp_size > 1``.

        Parameters
        ----------
        tp_group : ProcessGroup, default = `None`
                  tensor parallel process group.
        """
        self.tp_group = tp_group
        self.tp_group_initialized = True

    def _get_fp8_params(self) -> Union[List[torch.Tensor], None]:
        """returns the FP8 weights."""
        fp8_params = []
733
        for param in self.parameters(recurse=False):
734
            if isinstance(param, QuantizedTensor) and param.requires_grad:
735
736
737
738
739
                fp8_params.append(param)
        if len(fp8_params) == 0:
            return None
        return fp8_params

740
741
    # This routine is shared across FP8 and FP8_calibration paths so should not actually
    # assume FP8 execution.
    def init_fp8_metadata(self, num_gemms: int = 1) -> None:
        """Initialize fp8 related metadata and tensors during fprop.

        Snapshots the global FP8 state (FP8 parameters, autocast, calibration)
        onto the module. When FP8 is in use and the module is not already
        initialized for the active recipe, records the recipe and allocates
        scaling metadata.

        Parameters
        ----------
        num_gemms : int, default = 1
            Number of GEMMs in the module, stored as ``fp8_meta["num_gemms"]``.
        """
        self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
        self.fp8 = FP8GlobalStateManager.is_fp8_enabled()
        self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
        fp8_enabled = self.fp8 or self.fp8_calibration
        self.fp8_meta["fp8_checkpoint"] = fp8_enabled

        if not (self.fp8_parameters or fp8_enabled):
            # If fp8 isn't enabled, turn off and return.
            self.fp8_initialized = False
            return

        if (
            self.fp8_initialized
            and FP8GlobalStateManager.get_fp8_recipe() == self.fp8_meta["recipe"]
        ):
            # FP8 init has already been run and recipe is the same, don't do anything.
            return
        self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()

        if self.fp8_parameters and not self.fp8_initialized:
            self.fp8_meta["num_gemms"] = num_gemms
            self.init_fp8_meta_tensors(self.fp8_meta["recipe"])

        if fp8_enabled:
            # Set FP8 and other FP8 metadata
            self.fp8_meta["num_gemms"] = num_gemms
            self.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()

            # Set FP8_MAX per tensor according to recipe
            self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd
            self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd

            # Allocate scales and amaxes
            self.init_fp8_meta_tensors(self.fp8_meta["recipe"])
            self.fp8_initialized = True

    @contextmanager
    def prepare_forward(
        self,
        inp: torch.Tensor,
        num_gemms: int = 1,
        allow_non_contiguous: bool = False,
    ) -> Generator[torch.Tensor, None, None]:
        """Checks and prep for FWD.
        The context manager is needed because there isn't a way for a module to know
        if it's the last FP8 module in the forward autocast. It is useful
        to setup the forward aggregated amax reduction for every module
        just in case. The autocast exit will pick up the most recent one.

        Parameters
        ----------
        inp : torch.Tensor
            Input of the forward pass. Outside the FP8 recompute phase it must
            be a CUDA tensor.
        num_gemms : int, default = 1
            Forwarded to ``init_fp8_metadata``.
        allow_non_contiguous : bool, default = `False`
            If `False`, a non-contiguous ``inp`` is made contiguous before it
            is yielded.

        Yields
        ------
        torch.Tensor
            The input, possibly made contiguous.
        """
        if self.fp8 and in_fp8_activation_recompute_phase():
            # Activation recomputation is used and this is the second forward phase.
            FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta)
        else:
            assert inp.is_cuda, "TransformerEngine needs CUDA."

            if self.tp_size > 1:
                assert self.tp_group_initialized, "TP group not initialized."

            self.set_activation_dtype(inp)
            self.init_fp8_metadata(num_gemms=num_gemms)

            if self.fp8 and self.sequence_parallel and self.fp8_meta["recipe"].delayed():
                assert self.fp8_meta["recipe"].reduce_amax, (
                    "Amax reduction across tensor parallel group is "
                    "necessary when using sequence parallelism with FP8."
                )

            if self.fp8 and not FP8GlobalStateManager.fp8_graph_capturing():
                FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(self.fp8_meta)

            # Activation recomputation is used and this is the first forward phase.
            if self.fp8 and self.training and is_fp8_activation_recompute_enabled():
                FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)

        try:
            with torch.cuda.nvtx.range(self.__class__.__name__ + " forward"):
                if not allow_non_contiguous and not inp.is_contiguous():
                    inp = inp.contiguous()
                yield inp
        finally:
            # @contextmanager re-raises body exceptions at the yield; restoring
            # in `finally` keeps the meta tensors swapped in for recompute
            # from leaking past a failed forward.
            if self.fp8 and in_fp8_activation_recompute_phase():
                FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)

    def set_nccl_overlap_warning_if_tp(self) -> None:
        """Warn when TP communication may not overlap with backward GEMMs.

        When using TP, the NCCL communication needs to be scheduled
        before the GEMM for there to be a guaranteed overlap. From the
        host side in TE, the comm calls are always launched first, but
        to ensure that the GEMM isn't scheduled first, the environment
        variable `CUDA_DEVICE_MAX_CONNECTIONS` needs to be set to 1 to
        force a single channel.

        Does nothing when ``self.tp_size == 1``.
        """
        if self.tp_size == 1:
            return
        num_cuda_work_queues = int(os.getenv("CUDA_DEVICE_MAX_CONNECTIONS", "0"))
        if num_cuda_work_queues != 1:
            warnings.warn(
                "To guarantee overlapping TP and SP collectives with the backward "
                "GEMMs, set environment variable CUDA_DEVICE_MAX_CONNECTIONS = 1"
            )

    @staticmethod
    def grad_output_preprocess(
        ctx,
        grad_output: torch.Tensor,
        row_parallel_mode: bool,
        quantizer: Optional[Quantizer],
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Utility function for backward.

        Flattens ``grad_output`` to 2D. For row-parallel layers with sequence
        parallelism it is all-gathered over ``ctx.tp_group``, through the
        userbuffers object when one is configured. In FP8 mode it is also
        quantized and, if ``ctx.use_bias``, the bias gradient is computed.

        Returns tuple in order (all optional/None based on training precision/recipe):
            R1: gathered `grad_output`.
            R2: bias gradient, or ``None`` when not in FP8 or without bias. In
                the all-gather path it is computed from the local, un-gathered
                gradient.
        """
        quantized_types = (QuantizedTensor, Float8TensorBase, MXFP8TensorBase)

        grad_output = grad_output.reshape((-1, grad_output.shape[-1]))
        grad_output = grad_output.contiguous()
        gather_grad_output = row_parallel_mode and ctx.sequence_parallel

        # Non-FP8 case: bgrad is fused with wgrad for this case.
        if not ctx.fp8:
            if gather_grad_output:
                if not ctx.ub_overlap_ag or ctx.ub_obj_gradout is None:
                    grad_output, _ = gather_along_first_dim(grad_output, ctx.tp_group)
                else:
                    ctx.ub_obj_gradout.copy_into_buffer(grad_output, quantizer, local_chunk=True)
                    grad_output = ctx.ub_obj_gradout.get_buffer(quantizer)
            return grad_output, None

        # FP8 with all-gather: unfused bgrad, fused cast + transpose
        if gather_grad_output:
            grad_bias = None
            if ctx.use_bias:
                grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
            if ctx.ub_overlap_ag and ctx.ub_obj_gradout is not None:
                # Quantize the gradient if needed
                if not isinstance(grad_output, quantized_types):
                    grad_output = quantizer(grad_output)

                # Copy into communication buffer, and replace original gradient with it
                ctx.ub_obj_gradout.copy_into_buffer(grad_output, quantizer, local_chunk=True)
                grad_output = ctx.ub_obj_gradout.get_buffer(quantizer)
            else:
                grad_output, _ = gather_along_first_dim(
                    grad_output,
                    ctx.tp_group,
                    quantizer=quantizer,
                )
            return grad_output, grad_bias

        # FP8 without all-gather: fused bgrad + cast + transpose
        grad_bias = None
        if ctx.use_bias:
            if isinstance(grad_output, quantized_types):
                grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0)
            else:
                grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer)
        if not isinstance(grad_output, quantized_types):
            grad_output = quantizer(grad_output)
        return grad_output, grad_bias

    def register_parameter(self, name, param, **kwargs):
        """
        Thin wrapper around PyTorch parameter registration to stash additional parameter
        metadata used in deferred initialization.

        Extra keyword arguments are forwarded to ``_ParameterInitMeta`` and the
        result is stored in ``self.param_init_meta[name]``.
        """
        super().register_parameter(name, param)
        self.param_init_meta[name] = _ParameterInitMeta(**kwargs)

    def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
        """
        Reset all module parameters to initial values. Unless deferred initialization
        is specified, all parameters on a 'meta' device are also materialized on a real cuda
        device before the values are reset to initial.

        Each parameter is initialized with the ``init_fn`` stored in
        ``self.param_init_meta[name]`` (inside an RNG-tracker fork when a tracker
        is provided). When ``self.primary_weights_in_fp8`` is set and the
        parameter has an ``fp8_meta_index``, it is then quantized. Every
        parameter is re-registered as an ``nn.Parameter``.
        """
        if defer_init:
            return

        for name, param in self.named_parameters(recurse=False):
            init_meta = self.param_init_meta[name]

            # Ensure parameter is on a real device
            if param.device == torch.device("meta"):
                param = torch.empty_like(param, device="cuda")

            # Initialize the parameter values on device
            init_fn = init_meta.init_fn
            get_rng_state_tracker = init_meta.get_rng_state_tracker
            if get_rng_state_tracker is None:
                init_fn(param)
            else:
                tracker_name = getattr(self, "rng_tracker_name", None)
                tracker = get_rng_state_tracker()
                with tracker.fork(tracker_name) if tracker_name else tracker.fork():
                    init_fn(param)

            # If primary weights are in fp8, wrap the parameter as FP8Tensor
            fp8_meta_index = init_meta.fp8_meta_index
            high_precision_init_val = None
            if self.primary_weights_in_fp8 and fp8_meta_index is not None:
                if self.preserve_high_precision_init_val:
                    high_precision_init_val = param.detach().cpu()

                quantizer = self.quantizers["scaling_fwd"][fp8_meta_index]
                assert quantizer is not None, (
                    "To use primary FP8 weights, one needs to use FP8 autocast "
                    "with a specific recipe."
                )
                quantizer.internal = False
                param = quantizer(param)

            # Redo parameter wrap in case we broke it above
            # NOTE: Currently this can only be broken when primary weights are in Fp8 but
            #       re-applying the nn.Parameter() wrap is a no-op when the input is already
            #       a parameter so we always re-apply it just for extra safety.
            param = torch.nn.Parameter(param)
            if high_precision_init_val is not None:

                # - Master weights are initialized from model weights, if we use fp8 primary
                #   weights to initialize master weights, the numerical values of master weights
                #   are not consistent with the numerical values when we initialize them from
                #   bf16/fp16 weights.
                # - So we add a `_high_precision_init_val` attribute to each model weight to store
                #   the original bf16/fp16 weight on cpu before casting it to fp8. And users can
                #   use `get_high_precision_init_val` to get this cpu tensor.
                # - This cpu tensor is not needed once the master weight is initialized, so users
                #   should call `clear_high_precision_init_val` to remove it after master weight
                #   is initialized.

                def get(self):
                    if hasattr(self, "_high_precision_init_val"):
                        return self._high_precision_init_val
                    return None

                def clear(self):
                    if hasattr(self, "_high_precision_init_val"):
                        del self._high_precision_init_val

                param._high_precision_init_val = high_precision_init_val
                param.get_high_precision_init_val = MethodType(get, param)
                param.clear_high_precision_init_val = MethodType(clear, param)

            setattr(self, name, param)

    @abstractmethod
    def forward(self):
        """Needs override: subclasses implement the module's forward pass."""

    def get_weight_workspace(
        self,
        *,
        tensor: Optional[torch.Tensor] = None,
        quantizer: Optional[Quantizer] = None,
        cache_name: Optional[str] = None,
        update_workspace: bool = True,
        skip_update_flag: Optional[torch.Tensor] = None,
        fsdp_group: Optional[dist_group_type] = None,
    ) -> QuantizedTensor:
        """Get FP8 workspace buffer and maybe update its values

        The workspace buffer may be cached for future function calls.

        Parameters
        ----------
        tensor : torch.Tensor, optional
            Values to copy into workspace. Required if the workspace
            is being constructed or updated. If it is already a
            ``QuantizedTensor`` (FP8 primary weights) it is returned as-is,
            after its usages are synced with ``quantizer`` when updating.
        quantizer: Quantizer, optional
            Quantizer used to cast the weights. Required if the
            workspace is being constructed or updated.
        cache_name: str, optional
            Key for caching.
        update_workspace: bool, default = `True`
            Update workspace with values from `tensor`.
        skip_update_flag: torch.Tensor, optional
            GPU flag to skip updating the workspace. Take precedence
            over `update_workspace` if provided.
        fsdp_group: ProcessGroup, default = None
            FSDP process group that the weights are distributed over.

        Raises
        ------
        ValueError
            If the workspace must be constructed but ``tensor`` or ``quantizer``
            is missing, or must be updated but ``tensor`` is missing.
        """
        # FP8 primary weights
        if isinstance(tensor, QuantizedTensor):
            if update_workspace and quantizer is not None:
                tensor.update_usage(
                    rowwise_usage=quantizer.rowwise_usage,
                    columnwise_usage=quantizer.columnwise_usage,
                )
            return tensor

        # Try getting workspace from cache
        out = None
        if cache_name is not None:
            out = self._fp8_workspaces.get(cache_name, None)
            # A cached MXFP8 workspace that lacks data for a usage the quantizer
            # requires cannot be reused: evict it so it is rebuilt below.
            if (
                quantizer is not None
                and isinstance(out, MXFP8TensorBase)
                and (
                    (quantizer.rowwise_usage and out._rowwise_data is None)
                    or (quantizer.columnwise_usage and out._columnwise_data is None)
                )
            ):
                out = None
                del self._fp8_workspaces[cache_name]

        # Gather cached Fp8 workspace if it's distributed
        # NOTE: FSDP sharding is supported only for Fp8 buffers and will not work
        #       for models initialized with Fp8 primary weights.
        if (
            out is not None
            and tensor is not None
            and fsdp_group is not None
            and out.data.shape != tensor.data.shape
        ):
            _fsdp_gather_tensors(fsdp_group, [tensor.data.shape], out)

        # Construct workspace if needed
        if out is None:
            if tensor is None or quantizer is None:
                raise ValueError(
                    "tensor and quantizer kwargs must be provided to construct FP8 workspace"
                )
            out = quantizer(tensor)

            # Update cache
            if cache_name is not None:
                self._fp8_workspaces[cache_name] = out
            return out

        # Update workspace if needed
        if skip_update_flag is not None:
            update_workspace = True
        if update_workspace:
            if tensor is None:
                raise ValueError("tensor kwarg must be provided to update FP8 workspace")
            if hasattr(out, "quantize_"):
                out.quantize_(tensor, noop_flag=skip_update_flag)
            else:
                tex.quantize(tensor, quantizer, out, skip_update_flag)

        return out

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        """
        This function loads tensors and extra state including fp8 metadata.
        This metadata is essential for copying fp8 tensors, as the copy_ function
        uses the scale_inv parameter from fp8_meta to set the correct scaling factor
        for the new tensor.
        Hence, this extra state must be loaded before the tensor copying process,
        not after, as is typically done in _load_from_state_dict.
        Tensors are copied into fp8 tensors only when self.primary_weights_in_fp8=True,
        otherwise, this behavior is not required.
        """
        if self.primary_weights_in_fp8:
            # Apply the extra state first so FP8 parameters are copied with the
            # checkpoint's scaling metadata.
            extra_state_key = prefix + torch.nn.modules.module._EXTRA_STATE_KEY_SUFFIX
            if extra_state_key in state_dict:
                self.set_extra_state(state_dict[extra_state_key])
        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )