base.py 46.7 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
#
# See LICENSE for license information.

"""Base modules and utilities for TransformerEngine PyTorch API"""
6
import io
7
8
9
10
import os
import pickle
import warnings
from abc import ABC, abstractmethod
11
from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union
12
from contextlib import contextmanager
13
from types import MethodType
14
15
16
17

import torch
import torch.nn.functional as F

18
import transformer_engine_torch as tex
19
20
from transformer_engine.common.recipe import Recipe

21
from ._common import _ParameterInitMeta
22
from ..fp8 import (
23
24
    MXFP8BlockScalingRecipeState,
    DelayedScalingRecipeState,
25
    Float8CurrentScalingRecipeState,
26
    Float8BlockScalingRecipeState,
27
    FP8GlobalStateManager,
28
    RecipeState,
29
30
31
32
33
)
from ..distributed import (
    gather_along_first_dim,
    is_fp8_activation_recompute_enabled,
    in_fp8_activation_recompute_phase,
34
    _fsdp_gather_tensors,
35
36
)
from ..constants import dist_group_type
37
from ..tensor import QuantizedTensor, Quantizer
38
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
39
40
from ..tensor._internal.float8_tensor_base import Float8TensorBase
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
41
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
42

43
44
__all__ = ["initialize_ub", "destroy_ub"]

45
46
47
# Presumably toggles split ("2x") accumulation in FP8 GEMMs for the forward,
# data-gradient and weight-gradient passes -- consumed outside this chunk; confirm.
_2X_ACC_FPROP = False
_2X_ACC_DGRAD = True
_2X_ACC_WGRAD = True

# Lazily populated caches (see get_workspace, get_multi_stream_cublas_workspace
# and get_dummy_wgrad).
_cublas_workspace = None
_multi_stream_cublas_workspace = []
_dummy_wgrads = {}

# Userbuffers settings/state: communicators are created by initialize_ub() and
# dropped by destroy_ub(); stream priorities are queried lazily by initialize_ub().
_ub_communicators = None
_NUM_MAX_UB_STREAMS = 3
_MIN_STREAM_PRIORITY = None
_MAX_STREAM_PRIORITY = None
layers_atomic_ring_exchange = []
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73


def get_cublas_workspace_size_bytes() -> int:
    """Return the cuBLAS workspace size in bytes for the current CUDA device.

    Devices whose compute capability major version is 9 or newer (Hopper and
    later) get 32 MiB; every other architecture gets 4 MiB.
    """
    major = torch.cuda.get_device_properties(torch.cuda.current_device()).major
    if major >= 9:
        return 33_554_432  # 32 MiB
    return 4_194_304  # 4 MiB


def get_workspace() -> torch.Tensor:
    """Return the shared cuBLAS workspace, allocating it on first use.

    The buffer is a flat uint8 CUDA tensor of ``get_cublas_workspace_size_bytes()``
    bytes. ``initialize_ub`` may later replace it with a larger, repeated copy.
    """
    global _cublas_workspace
    if _cublas_workspace is not None:
        return _cublas_workspace
    workspace_bytes = get_cublas_workspace_size_bytes()
    _cublas_workspace = torch.empty(workspace_bytes, dtype=torch.uint8, device="cuda")
    return _cublas_workspace


74
75
76
77
def get_multi_stream_cublas_workspace() -> List[torch.Tensor]:
    """Return the per-stream cuBLAS workspaces, allocating them on first use.

    One uint8 CUDA buffer of ``get_cublas_workspace_size_bytes()`` bytes is created
    for each of the ``tex._num_cublas_streams`` streams. The module-level list is
    filled in place once and returned by reference on every call.
    """
    global _multi_stream_cublas_workspace
    if _multi_stream_cublas_workspace:
        return _multi_stream_cublas_workspace
    _multi_stream_cublas_workspace.extend(
        torch.empty(get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda")
        for _ in range(tex._num_cublas_streams)
    )
    return _multi_stream_cublas_workspace


85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
def get_dummy_wgrad(shape: list, dtype: torch.dtype, zero=False) -> torch.Tensor:
    """Return a cached, detached 2-D CUDA tensor with the given shape and dtype.

    Buffers are cached per ``(rows, cols, dtype)`` and reused across calls. Their
    contents are not cleared unless ``zero=True``, in which case the cached buffer
    is zero-filled in place before a detached view of it is returned.
    """
    assert len(shape) == 2
    global _dummy_wgrads
    cache_key = (shape[0], shape[1], dtype)
    buffer = _dummy_wgrads.get(cache_key)
    if buffer is None:
        buffer = torch.empty(shape, dtype=dtype, device="cuda", requires_grad=False)
        _dummy_wgrads[cache_key] = buffer
    if zero:
        buffer.fill_(0)
    return buffer.detach()


101
102
def initialize_ub(
    shape: list,
    tp_size: int,
    use_fp8: bool = False,
    dtype: torch.dtype = torch.bfloat16,
    ub_cfgs: Optional[dict] = None,
    bootstrap_backend: Optional[Union[str, torch.distributed.Backend]] = None,
) -> None:
    r"""
    Initialize the Userbuffers communicator for overlapping tensor-parallel communications with
    GEMM compute in te.Linear, te.LayerNormLinear and te.LayerNormMLP modules.

    Parameters
    ----------
    shape : list
            shape of the communication buffer, typically set to be the same as the global shape of
            the input tensor to a te.TransformerLayer forward pass, with the sequence and batch
            dimensions collapsed together -- i.e.: `(sequence_length * batch_size, hidden_size)`
    tp_size : int
              number of GPUs in the tensor-parallel process group
    use_fp8 : bool = False
              allocate the communication buffer for FP8 GEMM inputs/outputs
    dtype : torch.dtype = torch.bfloat16
            non-FP8 data type of the communication buffer when `use_fp8 = False`
    ub_cfgs: dict = None
             Configuration dictionary with the structure
             ```
             {
                <gemm_name> : {
                    "method": <"ring_exchange", "pipeline" or "bulk">,
                    "is_reduce_scatter": bool,
                    "num_sm": int,
                    "cga_size": int,
                    "set_sm_margin": bool,
                    "num_splits": int,
                    "aggregate": bool,
                    "atomic_gemm": bool,
                    "use_ce": bool,
                    "fp8_buf": bool,
                    "comm_priority": int,
                    "gemm_priority": int,
                    "pipeline_rs_overlap_first_gemm": bool,
                }
             }
             ```
             for `te.TransformerLayer` GEMM layers in `["qkv_fprop", "qkv_dgrad", "qkv_wgrad",
             "proj_fprop", "proj_dgrad", "fc1_fprop", "fc1_dgrad", "fc1_wgrad", "fc2_fprop",
             "fc2_dgrad"]`. Omitted keys fall back to per-layer defaults and entries for any
             other layer name are ignored. Giving `"qkv_dgrad"` or `"fc1_dgrad"` a method other
             than `"bulk"` turns that layer into a reduce-scatter overlap and drops the bulk
             overlap of the matching `*_wgrad` layer, which must then not appear in `ub_cfgs`.
             AllGather-overlap layers always request an FP8 buffer; for other layers
             `"fp8_buf"` is honored only with the `"pipeline"` method. In both cases the
             buffer is FP8 (uint8) only when `use_fp8=True`.
    bootstrap_backend : str = None
                        `torch.distributed` communication backend for the all-gather, broadcast and
                        barrier collectives during Userbuffers initialization. Not all backends are
                        valid for every cluster configuration and distributed launch method even if
                        they are available in PyTorch. When left unset, the initialization prefers
                        to use the MPI backend, falling back first on Gloo and then NCCL if MPI is
                        not available. Setting `NVTE_UB_WITH_MPI=1` when building TE overrides this
                        option and always initializes Userbuffers with direct MPI calls in C++,
                        which also requires `MPI_HOME=/path/to/mpi/root` to be set at compile time.
    """
    if not tex.device_supports_multicast():
        assert bool(int(os.getenv("UB_SKIPMC", "0"))), (
            "CUDA device, driver and/or toolkit version does not support comm+GEMM overlap with "
            + "CUDA Multicast. Launch app with UB_SKIPMC=1 to try CUDA IPC instead."
        )

    global _ub_communicators, _cublas_workspace, layers_atomic_ring_exchange
    assert _ub_communicators is None, "UB communicators are already initialized."
    _ub_communicators = {}

    if tex.ubuf_built_with_mpi():
        # We're bootstrapping with direct calls to MPI in Userbuffers code so we need to force
        # an MPI_Init() here by creating a new MPI process group...
        assert torch.distributed.is_mpi_available()
        _ = torch.distributed.new_group(backend="mpi")
        helper = tex.CommOverlapHelper()
    else:
        # Bootstrapping with torch.distributed API, so check backend and construct
        # intra/inter-node process groups...
        assert (
            torch.distributed.is_initialized()
        ), "torch.distributed must be initialized before Userbuffers"
        if bootstrap_backend is None:
            # Preference order: MPI, then Gloo, then NCCL.
            bootstrap_backend = "nccl"
            if torch.distributed.is_mpi_available():
                bootstrap_backend = "mpi"
            elif torch.distributed.is_gloo_available():
                bootstrap_backend = "gloo"
        else:
            assert bootstrap_backend in [
                "gloo",
                "mpi",
                "nccl",
            ], "Invalid torch.distributed backend for bootstrapping Userbuffers!"
            assert torch.distributed.is_backend_available(bootstrap_backend), (
                f"PyTorch must be compiled with '{bootstrap_backend}' support in order to "
                f"bootstrap Userbuffers with '{bootstrap_backend}' collectives."
            )

        world_group = torch.distributed.new_group(backend=bootstrap_backend)
        world_rank = torch.distributed.get_rank(world_group)
        world_size = torch.distributed.get_world_size(world_group)

        # Consecutive blocks of `tp_size` global ranks form one TP domain.
        num_domains = world_size // tp_size
        mydomain_idx = world_rank // tp_size
        if num_domains > 1:
            ranks_per_domain_list = [
                [i * tp_size + t for t in range(tp_size)] for i in range(num_domains)
            ]
            tp_domain_group, _ = torch.distributed.new_subgroups_by_enumeration(
                ranks_per_domain_list, backend=bootstrap_backend
            )
            local_rank = torch.distributed.get_rank(tp_domain_group)
            tp_domain_ranks = torch.distributed.get_process_group_ranks(tp_domain_group)

            helper = tex.CommOverlapHelper(world_group, tp_domain_group)
        else:
            # TP model on single NVLink domain, no replication, no data-parallelism
            mydomain_idx = 0
            local_rank = world_rank
            tp_domain_ranks = list(range(world_size))

            helper = tex.CommOverlapHelper(world_group)

        if world_rank == 0:
            print(f"!!! [UB] Number of TP domains: {num_domains}\n", end="", flush=True)
        if local_rank == 0:
            print(
                f"!!! [UB] Global ranks on TP domain {mydomain_idx}: {tp_domain_ranks}\n",
                end="",
                flush=True,
            )

    # Increase the workspace by the number of maximum concurrent streams
    _cublas_workspace = get_workspace().repeat(_NUM_MAX_UB_STREAMS)

    # Default buffer precision: AllGather buffers use fp8 when using fp8 recipe
    layers_all_gather_overlap = [
        "qkv_fprop",
        "qkv_dgrad",
        "proj_dgrad",
        "fc1_fprop",
        "fc1_dgrad",
        "fc2_dgrad",
    ]
    layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"]
    # dgrad layers that may be switched from bulk AG to a reduce-scatter overlap via ub_cfgs.
    dgrad_reduce_scatter_overlap = ["qkv_dgrad", "fc1_dgrad"]

    # Default overlap methods for layers
    methods = {
        "ring_exchange": ["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad"],
        "pipeline": ["proj_fprop", "fc2_fprop"],
        "bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"],
    }

    # AG-RS overlap pairs of layers forming a tensor-parallel block
    ag_rs_pairs = {"qkv_fprop": "proj_fprop", "fc1_fprop": "fc2_fprop"}
    rs_ag_pairs = {v: k for k, v in ag_rs_pairs.items()}

    layers_atomic_ring_exchange = []

    def get_method(name):
        """Return the overlap method currently assigned to layer `name`."""
        for method, names in methods.items():
            if name in names:
                return method
        raise KeyError(f"Given layer name {name} does not exist.")

    def get_default_config(name):
        """Build the default `add_ub` keyword arguments for layer `name`."""
        global _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY
        method = get_method(name)
        is_reduce_scatter = name in layers_reduce_scatter_overlap
        if _MIN_STREAM_PRIORITY is None or _MAX_STREAM_PRIORITY is None:
            _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = tex.get_stream_priority_range()
        default_cfg = {
            "method": method,
            "is_reduce_scatter": is_reduce_scatter,
            "num_sm": 1 if method == "ring_exchange" else 16,
            "cga_size": 1 if method == "ring_exchange" else 2,
            "set_sm_margin": method != "ring_exchange",
            "num_splits": tp_size if method == "ring_exchange" else 4,
            "aggregate": False,
            "atomic_gemm": False,
            "use_ce": True,
            "fp8_buf": name in layers_all_gather_overlap,
            "comm_priority": _MAX_STREAM_PRIORITY,
            "gemm_priority": _MIN_STREAM_PRIORITY,
            "pipeline_rs_overlap_first_gemm": False,
        }
        return default_cfg

    def add_ub(
        name: str,
        method: str,
        is_reduce_scatter: bool,
        num_sm: int = 16,
        cga_size: int = 2,
        set_sm_margin: bool = False,
        num_splits: int = 0,
        aggregate: bool = False,
        atomic_gemm: bool = False,
        use_ce: bool = True,
        fp8_buf: bool = False,
        comm_priority: int = 0,
        gemm_priority: int = 0,
        pipeline_rs_overlap_first_gemm: bool = False,
    ) -> None:
        """Validate one layer's overlap config and register its communicator."""
        if atomic_gemm:
            warnings.warn(
                "Atomic GEMM uses a beta API from cublas and is not tested for all use cases."
            )
            assert use_fp8, "Atomic GEMM overlap supported only for FP8 GEMM."
            if method == "bulk":
                warnings.warn(
                    f"At {name}, atomic GEMM is not supported for a bulk overlap. "
                    "Defaulting to `atomic_gemm=False`."
                )
                atomic_gemm = False
        if not is_reduce_scatter and method == "pipeline":
            raise ValueError(
                f"At {name}, `pipeline` overlap method is not supported for AllGather."
            )
        # Check if both AG and RS overlaps use `atomic GEMM`` + `p2p ring-exchange`.
        # Using atomic GEMM + p2p ring-exchange in only one of the pair breaks functionality.
        global layers_atomic_ring_exchange
        if atomic_gemm and method == "ring_exchange" and name in ag_rs_pairs:
            layers_atomic_ring_exchange += [name, ag_rs_pairs[name]]
        if name in rs_ag_pairs:
            assert_message = (
                f"At {name}, atomic AG-GEMM overlap with `ring_exchange` shuffles GEMM chunk "
                "outputs, and RS-GEMM overlap un-shuffles them. When one of the GEMM-AG and "
                "GEMM-RS overlaps forming a TP block (e.g., qkv_fprop and proj_fprop) uses "
                "`atomic gemm` and `ring_exchange`, its pair must use the same overlap config "
                "for functionality."
            )
            if name in layers_atomic_ring_exchange:
                assert atomic_gemm and method == "ring_exchange", assert_message
            else:
                if atomic_gemm and method == "ring_exchange":
                    assert rs_ag_pairs[name] in layers_atomic_ring_exchange, assert_message

        buffer_dtype = torch.uint8 if (use_fp8 and fp8_buf) else dtype
        if method == "ring_exchange":
            ub_obj = tex.CommOverlapP2P(
                shape,  # Communication buffer shape
                buffer_dtype,  # Communication buffer data type
                helper,  # Helper for torch.distributed callbacks during bootstrapping
                tp_size,  # Tensor-parallel group size (may be different than local_size)
                tex.CommOverlapType.RS if is_reduce_scatter else tex.CommOverlapType.AG,
                num_max_streams=_NUM_MAX_UB_STREAMS,
                comm_cga_size=cga_size,
                num_comm_sm=num_sm,
                set_sm_margin=set_sm_margin,
                atomic_gemm=atomic_gemm,
                use_ce=use_ce,
                aggregate=aggregate,
                gemm_priority=gemm_priority,
                comm_priority=comm_priority,
            )
        else:
            ub_obj = tex.CommOverlap(
                shape,  # Communication buffer shape
                buffer_dtype,  # Communication buffer data type
                helper,  # Helper for torch.distributed callbacks during bootstrapping
                tp_size,  # Tensor-parallel group size (may be different than local_size)
                num_splits=num_splits,
                num_max_streams=_NUM_MAX_UB_STREAMS,
                comm_cga_size=cga_size,
                num_comm_sm=num_sm,
                set_sm_margin=set_sm_margin,
                atomic_gemm=atomic_gemm,
                gemm_priority=gemm_priority,
                comm_priority=comm_priority,
                rs_overlap_first_gemm=pipeline_rs_overlap_first_gemm,
            )
        _ub_communicators[name] = ub_obj

    if ub_cfgs is not None:
        # A non-bulk dgrad overlap becomes a reduce-scatter and replaces the wgrad overlap.
        for name in dgrad_reduce_scatter_overlap:
            if name in ub_cfgs and "method" in ub_cfgs[name] and ub_cfgs[name]["method"] != "bulk":
                wgrad_name = name.replace("dgrad", "wgrad")
                assert wgrad_name not in ub_cfgs
                layers_reduce_scatter_overlap.remove(wgrad_name)
                layers_all_gather_overlap.remove(name)
                layers_reduce_scatter_overlap.append(name)
                methods["bulk"].remove(name)
                new_method = ub_cfgs[name]["method"]
                methods[new_method].append(name)

    for name in methods["ring_exchange"] + methods["pipeline"] + methods["bulk"]:
        ub_cfg = get_default_config(name)
        if ub_cfgs is not None and name in ub_cfgs:
            # AG layers always keep an FP8-capable buffer; others only for `pipeline`.
            fp8_buf = (name in layers_all_gather_overlap) or (
                ub_cfgs[name].get("fp8_buf", False) and name in methods["pipeline"]
            )
            ub_cfg.update(ub_cfgs[name])
            ub_cfg["fp8_buf"] = fp8_buf
        add_ub(name, **ub_cfg)
393
394
395
396
397
398
399
400


def get_ub(name: str):
    """Return the Userbuffers communicator registered under ``name``.

    Asserts that ``initialize_ub`` has run and that ``name`` was registered by it.
    """
    communicators = _ub_communicators
    assert communicators is not None, "UB manager is not initialized."
    assert name in communicators, f"UB for {name} is not registered."
    return communicators[name]

401

402
403
404
405
406
407
408
def destroy_ub():
    """Drop all Userbuffers communicators and clear the atomic ring-exchange list.

    Afterwards ``initialize_ub`` can be called again, since it requires the
    communicator registry to be ``None``.

    NOTE(review): ``_cublas_workspace`` is not reset here, so a subsequent
    ``initialize_ub`` repeats the already-enlarged workspace again -- confirm
    whether that growth is intended.
    """
    global _ub_communicators, layers_atomic_ring_exchange
    _ub_communicators = None
    layers_atomic_ring_exchange = []

409
410
411
412
413
414
415
416
417
418
419

class TransformerEngineBaseModule(torch.nn.Module, ABC):
    """Base TE module."""

    def __init__(self) -> None:
        """Set up default FP8, tensor-parallel and FSDP bookkeeping for a TE module."""
        super().__init__()
        assert torch.cuda.is_available(), "TransformerEngine needs CUDA."

        # FP8 execution / calibration flags.
        self.fp8_initialized = False
        self.fp8 = False
        self.fp8_calibration = False

        # Per-module FP8 metadata; recipe states are later stored under
        # "scaling_fwd" / "scaling_bwd" together with their quantizers.
        self.fp8_meta = {"fp8_checkpoint": False, "fp8_group": None}
        self.fp8_meta_tensors_initialized = False
        self.quantizers = {"scaling_fwd": {}, "scaling_bwd": {}}

        # Tensor-parallel configuration (no TP group until one is set).
        self.tp_group = None
        self.tp_size = 1
        self.sequence_parallel = False

        # Parameter-initialization settings captured from the global FP8 state.
        self.param_init_meta = {}
        self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
        self.preserve_high_precision_init_val = FP8GlobalStateManager.with_high_precision_init_val()

        # FSDP bookkeeping.
        self.fsdp_wrapped = False
        self.fsdp_group = None

        # Cached quantized workspaces, keyed by name.
        self._fp8_workspaces: Dict[str, QuantizedTensor] = {}

        # Resolved later from autocast or the input dtype (see set_activation_dtype).
        self.activation_dtype: Optional[torch.dtype] = None

    # Plain attributes that are written often; __setattr__ stores them directly
    # in the instance __dict__ instead of going through torch.nn.Module.
    _fast_setattr_names: Set[str] = {
        "activation_dtype",
        "fp8",
        "fp8_calibration",
        "fp8_initialized",
        "fp8_parameters",
    }

    def __setattr__(self, name: str, value: Any) -> None:
        """Fast-path assignment for plain attributes listed in ``_fast_setattr_names``.

        torch.nn.Module.__setattr__ inspects values for parameters, buffers and
        submodules; that work is skipped for the whitelisted names.
        """
        if name not in TransformerEngineBaseModule._fast_setattr_names:
            super().__setattr__(name, value)
            return
        self.__dict__[name] = value
454

455
    def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None:
        """
        Delayed scaling only.

        Increase or decrease size of amax history based on given `length`.

        Parameters
        ----------
        length : int
            New number of rows in the amax history. Shrinking keeps the first
            `length` rows; growing appends zero rows at the end.
        fwd : bool, optional
            `True` adjusts only "scaling_fwd", `False` only "scaling_bwd" and
            `None` (default) adjusts both.

        .. warning::
            This changes the underlying amax memory location.
        """
        if fwd is None:
            fp8_meta_tensor_keys = ("scaling_fwd", "scaling_bwd")
        else:
            fp8_meta_tensor_keys = ("scaling_fwd" if fwd else "scaling_bwd",)

        for meta_key in fp8_meta_tensor_keys:
            if meta_key not in self.fp8_meta:
                # Handles non-parameter FP8 modules, e.g. DPA.
                continue
            recipe_state = self.fp8_meta[meta_key]
            curr_len = recipe_state.amax_history.shape[0]
            if length == curr_len:
                continue
            if length < curr_len:
                recipe_state.amax_history = recipe_state.amax_history[:length].clone()
            else:
                # F.pad consumes pairs from the last dim backwards: (0, 0) leaves the last
                # dim alone, (0, extra_rows) appends zero rows along dim 0 (the history).
                extra_rows = length - curr_len
                recipe_state.amax_history = F.pad(
                    recipe_state.amax_history, pad=(0, 0, 0, extra_rows)
                )

            # Update quantizers with new amax pointers.
            self.quantizers[meta_key] = recipe_state.make_quantizers()

            # Update the global buffers with new amax and history pointers.
            buffer_info_key = FP8GlobalStateManager.get_buffer_info()
            if buffer_info_key not in self.fp8_meta:
                continue
            fwd_pos, fwd_key, bwd_pos, bwd_key = self.fp8_meta[buffer_info_key]
            # Only the global-buffer slot of the pass being adjusted may point at this
            # recipe state's history; forward and backward histories are distinct tensors.
            if meta_key == "scaling_fwd":
                pos, buffer_key = fwd_pos, fwd_key
            else:
                pos, buffer_key = bwd_pos, bwd_key
            if buffer_key in FP8GlobalStateManager.global_amax_buffer:
                assert (
                    buffer_key in FP8GlobalStateManager.global_amax_history_buffer
                ), "TE internal error during amax history change."
                FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = (
                    recipe_state.amax_history[0]
                )
                FP8GlobalStateManager.global_amax_history_buffer[buffer_key][pos] = (
                    recipe_state.amax_history
                )
505

506
    def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
        """Create (or keep) the recipe state and quantizers for one pass.

        If the existing recipe state already matches ``recipe``'s family it is kept.
        For delayed scaling, only its amax history length is adjusted. Otherwise a
        fresh state is built for ``num_gemms`` GEMMs and its quantizers are stored
        in ``self.quantizers``.
        """
        meta_key = "scaling_fwd" if fwd else "scaling_bwd"

        if self.fp8_meta_tensors_initialized:
            existing = self.fp8_meta[meta_key]
            if recipe.delayed() and isinstance(existing, DelayedScalingRecipeState):
                self.adjust_amax_history_length(recipe.amax_history_len, fwd=fwd)
                return
            same_family = (
                (recipe.mxfp8() and isinstance(existing, MXFP8BlockScalingRecipeState))
                or (
                    recipe.float8_current_scaling()
                    and isinstance(existing, Float8CurrentScalingRecipeState)
                )
                or (
                    recipe.float8_block_scaling()
                    and isinstance(existing, Float8BlockScalingRecipeState)
                )
            )
            if same_family:
                return

        # FP8 tensors per GEMM: input, weight, output in forward;
        # grad_output, grad_input in backward.
        tensors_per_gemm = 3 if fwd else 2
        new_state = RecipeState.create(
            recipe,
            mode="forward" if fwd else "backward",
            num_quantizers=self.fp8_meta["num_gemms"] * tensors_per_gemm,
        )
        self.fp8_meta[meta_key] = new_state
        self.quantizers[meta_key] = new_state.make_quantizers()

    def init_fp8_meta_tensors(self, recipe: Recipe) -> None:
        """Set up forward then backward recipe state, and mark them initialized."""
        for is_forward in (True, False):
            self.set_meta_tensor(is_forward, recipe)
        self.fp8_meta_tensors_initialized = True

548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
    def get_fp8_meta_tensors(self) -> Optional[Dict[str, List[torch.Tensor]]]:
        """Return cloned FP8 scales and amax histories.

        Returns
        -------
        dict or None
            ``{"scaling_fwd": [scale, amax_history], "scaling_bwd": [scale, amax_history]}``
            with clones taken under ``torch.no_grad()``. This is the layout that
            ``reset_fp8_meta_tensors`` accepts. Returns ``None`` if either recipe
            state is missing from ``fp8_meta``.
        """
        meta_keys = ("scaling_fwd", "scaling_bwd")
        if any(key not in self.fp8_meta for key in meta_keys):
            return None

        with torch.no_grad():
            return {
                key: [
                    self.fp8_meta[key].scale.clone(),
                    self.fp8_meta[key].amax_history.clone(),
                ]
                for key in meta_keys
            }

    def reset_fp8_meta_tensors(self, fp8_meta_tensors=None) -> None:
        """Reset FP8 scales and amax histories in place.

        Parameters
        ----------
        fp8_meta_tensors : dict, optional
            Snapshot in the ``get_fp8_meta_tensors`` layout (``{key: [scale, amax_history]}``).
            When ``None``, scales are set to one and amax histories to zero.
        """
        with torch.no_grad():
            for meta_key in ("scaling_fwd", "scaling_bwd"):
                if meta_key not in self.fp8_meta:
                    continue
                recipe_state = self.fp8_meta[meta_key]
                if fp8_meta_tensors is None:
                    recipe_state.scale.copy_(torch.ones_like(recipe_state.scale))
                    recipe_state.amax_history.copy_(torch.zeros_like(recipe_state.amax_history))
                    continue
                assert meta_key in fp8_meta_tensors, "Cannot reset fp8 tensors."
                saved = fp8_meta_tensors[meta_key]
                recipe_state.scale.copy_(saved[0])
                recipe_state.amax_history.copy_(saved[1])

580
581
    def get_extra_state(self) -> torch.Tensor:
        """Serialize FP8 state into a CPU ``uint8`` tensor for checkpointing.

        When FP8 checkpointing, execution or calibration is active, the payload is
        a pickled dict. It holds the recipe, the delayed-scaling scales and amax
        histories (as CPU copies), and the plain-typed entries of ``fp8_meta``.
        Otherwise it is a pickled ``None``.
        """
        # Workarounds baked into this format:
        # (1) PyTorch "extra state" is only reliable for tensors (non-tensor extra
        #     state has broken e.g. ONNX export), so the state is pickled into bytes.
        # (2) Checkpoint loading does not remap devices for extra state, so nothing
        #     is left on the GPU.
        # (3) There are many small tensors; copies are issued asynchronously and
        #     followed by one synchronization instead of one transfer-sync each.
        #
        # See: https://github.com/NVIDIA/TransformerEngine/pull/351
        # See: https://github.com/NVIDIA/TransformerEngine/pull/363

        def _async_cpu_copy(tensor: torch.Tensor) -> torch.Tensor:
            """Start a non-blocking device-to-host copy; valid only after a sync."""
            host = torch.empty_like(tensor, device="cpu")
            host.copy_(tensor, non_blocking=True)
            return host

        state = None
        if self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration:
            recipe = self.fp8_meta["recipe"]
            state = {"recipe": recipe}
            if recipe.delayed():
                for meta_key, scale_key, history_key in (
                    ("scaling_fwd", "scale_fwd", "amax_history_fwd"),
                    ("scaling_bwd", "scale_bwd", "amax_history_bwd"),
                ):
                    recipe_state = self.fp8_meta[meta_key]
                    state[scale_key] = _async_cpu_copy(recipe_state.scale)
                    state[history_key] = _async_cpu_copy(recipe_state.amax_history)

            # Keep only picklable plain values (the global-buffer index is runtime-only).
            state["extra_fp8_variables"] = {
                key: value
                for key, value in self.fp8_meta.items()
                if key != "buffer_index_and_autocast_key"
                and isinstance(value, (bool, int, float, str, tuple, list))
            }

        # Wait for the asynchronous host copies before pickling them.
        torch.cuda.synchronize()
        payload = bytearray(pickle.dumps(state))
        return torch.frombuffer(payload, dtype=torch.uint8)
639
640
641
642
643
644

    def set_extra_state(self, state: torch.Tensor) -> None:
        """Restore FP8 state produced by ``get_extra_state``.

        Accepts the current ``uint8`` tensor format (pickled dict) or the deprecated
        ``io.BytesIO`` format, and raises ``RuntimeError`` for anything else. The
        FP8 recipe state is rebuilt from the saved recipe, and delayed-scaling
        scales/amax histories are copied back in.
        """
        if state is None:
            return

        if isinstance(state, torch.Tensor):
            # Current format: byte tensor holding pickled data.
            # NOTE(review): pickle.loads can execute arbitrary code -- only load
            # checkpoints from trusted sources.
            raw_bytes = state.detach().cpu().numpy().tobytes()
            state = pickle.loads(raw_bytes)
        elif isinstance(state, io.BytesIO):
            # Deprecated format with io.BytesIO
            state.seek(0)
            state = torch.load(state, map_location="cuda")
        else:
            raise RuntimeError("Unsupported checkpoint format.")

        # A checkpoint saved without FP8 state deserializes to None.
        if state is None:
            return

        self.fp8_meta.update(state["extra_fp8_variables"])
        self.fp8_meta["recipe"] = state["recipe"]
        self.fp8_meta.pop("global_fp8_buffer_pos_fwd_recompute", None)

        # Build fresh recipe state before copying saved tensors into it.
        self.init_fp8_meta_tensors(self.fp8_meta["recipe"])

        if self.fp8_meta["recipe"].delayed():
            for meta_key, scale_key, history_key in (
                ("scaling_fwd", "scale_fwd", "amax_history_fwd"),
                ("scaling_bwd", "scale_bwd", "amax_history_bwd"),
            ):
                recipe_state = self.fp8_meta[meta_key]
                # Non-blocking host-to-device copies; the sync below completes them.
                recipe_state.scale.copy_(state[scale_key], non_blocking=True)
                recipe_state.amax_history.copy_(state[history_key], non_blocking=True)
        torch.cuda.synchronize()
684
685
686
687
688
689
690
691
692

    def set_activation_dtype(self, inp: torch.Tensor) -> None:
        """Decide the activation dtype for this forward pass.

        Inside ``torch.autocast`` the autocast GPU dtype wins. Otherwise the input
        dtype is used, after asserting that every parameter shares it. The
        parameter check is skipped when the dtype is unchanged since the last call.
        """
        if torch.is_autocast_enabled():
            self.activation_dtype = torch.get_autocast_gpu_dtype()
            return

        input_dtype = inp.dtype
        if self.activation_dtype == input_dtype:
            # Already validated for this dtype on an earlier call.
            return

        for param_name, param in self.named_parameters():
            if param is None:
                continue
            assert input_dtype == param.dtype, (
                "Data types for parameters must match when outside of autocasted region. "
                f" Found input dtype: {input_dtype} and {param_name!r} dtype: {param.dtype}"
            )
        self.activation_dtype = input_dtype
704
705

    def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
        """
        Attach the tensor parallel process group to this module.

        Call this before running the forward pass; it also marks the group as
        initialized so later checks on ``self.tp_group_initialized`` succeed.

        Parameters
        ----------
        tp_group : ProcessGroup, default = `None`
                  tensor parallel process group.
        """
        self.tp_group, self.tp_group_initialized = tp_group, True

718
719
720
    def _get_fp8_params(self) -> Union[List[torch.Tensor], None]:
        """returns the FP8 weights."""
        fp8_params = []
721
        for param in self.parameters(recurse=False):
722
            if isinstance(param, QuantizedTensor) and param.requires_grad:
723
724
725
726
727
                fp8_params.append(param)
        if len(fp8_params) == 0:
            return None
        return fp8_params

728
729
    # This routine is shared across FP8 and FP8_calibration paths so should not actually
    # assume FP8 execution.
    def init_fp8_metadata(self, num_gemms: int = 1) -> None:
        """Initialize fp8 related metadata and tensors during fprop.

        Snapshots the current FP8 state from ``FP8GlobalStateManager`` onto the
        module (``self.fp8_parameters``, ``self.fp8``, ``self.fp8_calibration``).

        - If neither FP8 parameters nor FP8 execution/calibration is active,
          ``self.fp8_initialized`` is reset to ``False`` and nothing else happens.
        - If FP8 was already initialized and the global recipe compares equal to
          the stored one, this is a no-op.
        - Otherwise the recipe is (re)stored and the scaling metadata tensors are
          (re)allocated via ``self.init_fp8_meta_tensors``.

        Parameters
        ----------
        num_gemms : int, default = 1
            Number of GEMMs in the module; stored in ``fp8_meta["num_gemms"]``.
        """
        self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
        self.fp8 = FP8GlobalStateManager.is_fp8_enabled()
        self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
        fp8_enabled = self.fp8 or self.fp8_calibration
        # Same truth value as `fp8_enabled`; stored in fp8_meta for later consumers.
        self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration

        if self.fp8_parameters or fp8_enabled:
            if (
                self.fp8_initialized
                and FP8GlobalStateManager.get_fp8_recipe() == self.fp8_meta["recipe"]
            ):
                # FP8 init has already been run and recipe is the same, don't do anything.
                return
            self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
        else:
            # If fp8 isn't enabled, turn off and return.
            self.fp8_initialized = False
            return

        # FP8 parameters without FP8 execution: allocate metadata tensors, but leave
        # `fp8_initialized` False (it is only set in the `fp8_enabled` branch below).
        if self.fp8_parameters and not self.fp8_initialized:
            self.fp8_meta["num_gemms"] = num_gemms
            self.init_fp8_meta_tensors(self.fp8_meta["recipe"])

        if fp8_enabled:
            # Set FP8 and other FP8 metadata
            # NOTE: when both conditions hold, init_fp8_meta_tensors runs a second
            # time here.
            self.fp8_meta["num_gemms"] = num_gemms
            self.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()

            # Set FP8_MAX per tensor according to recipe
            self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd
            self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd

            # Allocate scales and amaxes
            self.init_fp8_meta_tensors(self.fp8_meta["recipe"])
            self.fp8_initialized = True

            # NOTE(review): the recipe was already stored above; this re-read looks
            # redundant unless get_fp8_recipe() can return a different object — confirm.
            self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
769
770
771
772
773
774

    @contextmanager
    def prepare_forward(
        self,
        inp: torch.Tensor,
        num_gemms: int = 1,
        allow_non_contiguous: bool = False,
    ) -> Generator[torch.Tensor, None, None]:
        """Checks and prep for FWD.
        The context manager is needed because there isn't a way for a module to know
        if it's the last FP8 module in the forward autocast. It is useful
        to setup the forward aggregated amax reduction for every module
        just in case. The autocast exit will pick up the most recent one.

        Parameters
        ----------
        inp : torch.Tensor
            Forward input; must be a CUDA tensor outside the recompute phase.
        num_gemms : int, default = 1
            Forwarded to ``self.init_fp8_metadata``.
        allow_non_contiguous : bool, default = `False`
            If `False`, a non-contiguous ``inp`` is replaced by a contiguous copy.

        Yields
        ------
        torch.Tensor
            The (possibly made-contiguous) input tensor.
        """
        # Activation recomputation is used and this is the second forward phase.
        if self.fp8 and in_fp8_activation_recompute_phase():
            FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta)
        else:
            assert inp.is_cuda, "TransformerEngine needs CUDA."

            if self.tp_size > 1:
                assert self.tp_group_initialized, "TP group not initialized."

            self.set_activation_dtype(inp)
            self.init_fp8_metadata(num_gemms=num_gemms)

            # Sequence parallelism with delayed scaling requires amax reduction.
            if self.fp8 and self.sequence_parallel and self.fp8_meta["recipe"].delayed():
                assert self.fp8_meta["recipe"].reduce_amax, (
                    "Amax reduction across tensor parallel group is "
                    "necessary when using sequence parallelism with FP8."
                )

            # Register this module's FP8 tensors globally, except while a graph is captured.
            if self.fp8 and not FP8GlobalStateManager.fp8_graph_capturing():
                FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(self.fp8_meta)

            # Activation recomputation is used and this is the first forward phase.
            if self.fp8 and self.training and is_fp8_activation_recompute_enabled():
                FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)

        with torch.cuda.nvtx.range(self.__class__.__name__ + " forward"):
            if not allow_non_contiguous and not inp.is_contiguous():
                inp = inp.contiguous()
            yield inp

        # NOTE(review): this restore is not in a `finally`, so an exception raised in the
        # caller's `with` body during the recompute phase skips it — confirm intended.
        if self.fp8 and in_fp8_activation_recompute_phase():
            FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834

    def set_nccl_overlap_warning_if_tp(self) -> None:
        """Warn when TP collectives may not overlap with the backward GEMMs.

        When using TP, the NCCL communication needs to be scheduled
        before the GEMM for there to be a guaranteed overlap. From the
        host side in TE, the comm calls are always launched first, but
        to ensure that the GEMM isn't scheduled first, the environment
        variable `CUDA_DEVICE_MAX_CONNECTIONS` needs to be set to 1 to
        force a single channel.

        Does nothing when ``self.tp_size == 1``. An unset, empty or non-integer
        `CUDA_DEVICE_MAX_CONNECTIONS` is treated as "not 1" and triggers the
        warning instead of raising ``ValueError``.
        """
        if self.tp_size == 1:
            return
        raw_value = os.getenv("CUDA_DEVICE_MAX_CONNECTIONS", "0")
        try:
            num_cuda_work_queues = int(raw_value)
        except ValueError:
            # A malformed value cannot be 1; warn instead of raising from an
            # advisory-only check.
            num_cuda_work_queues = None
        if num_cuda_work_queues != 1:
            warnings.warn(
                "To guarantee overlapping TP and SP collectives with the backward "
                "GEMMs, set environment variable CUDA_DEVICE_MAX_CONNECTIONS = 1"
            )

    @staticmethod
    def grad_output_preprocess(
        ctx,
        grad_output: torch.Tensor,
        row_parallel_mode: bool,
        quantizer: Optional[Quantizer],
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Utility function for backward.
        Returns tuple in order (all optional/None based on training precision/recipe):
            R1: gathered `grad_output`.
            R2: bias gradient on R1.

        Parameters
        ----------
        ctx
            Autograd context. Attributes read here: ``sequence_parallel``, ``fp8``,
            ``ub_overlap_ag``, ``tp_group``, ``ub_obj_gradout``, ``use_bias``.
        grad_output : torch.Tensor
            Incoming gradient; flattened to 2D (last dim preserved) and made contiguous.
        row_parallel_mode : bool
            Together with ``ctx.sequence_parallel``, decides whether the gradient
            is all-gathered along its first dimension.
        quantizer : Quantizer, optional
            Used to quantize the gradient on the FP8 paths (assumed non-None
            whenever ``ctx.fp8`` is set — TODO confirm with callers).
        """
        grad_output = grad_output.reshape((-1, grad_output.shape[-1]))
        grad_output = grad_output.contiguous()
        gather_grad_output = row_parallel_mode and ctx.sequence_parallel

        # Non-FP8 case: bgrad is fused with wgrad for this case.
        if not ctx.fp8:
            if gather_grad_output:
                if not ctx.ub_overlap_ag:
                    grad_output, _ = gather_along_first_dim(grad_output, ctx.tp_group)
                else:
                    # Userbuffers path: stage the local chunk, then use the full buffer.
                    ctx.ub_obj_gradout.copy_into_buffer(grad_output, quantizer, local_chunk=True)
                    grad_output = ctx.ub_obj_gradout.get_buffer(quantizer)
            return grad_output, None

        # FP8 with all-gather: unfused bgrad, fused cast + transpose
        if gather_grad_output:
            grad_bias = None
            if ctx.use_bias:
                # Bias gradient is taken from the local (un-gathered) gradient.
                grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
            if ctx.ub_overlap_ag:
                # Quantize the gradient if needed
                if not isinstance(
                    grad_output,
                    (
                        QuantizedTensor,
                        Float8TensorBase,
                        MXFP8TensorBase,
                        Float8BlockwiseQTensorBase,
                    ),
                ):
                    grad_output = quantizer(grad_output)

                # Copy into communication buffer, and replace original gradient with it
                ctx.ub_obj_gradout.copy_into_buffer(grad_output, quantizer, local_chunk=True)
                grad_output = ctx.ub_obj_gradout.get_buffer(quantizer)
            else:
                grad_output, _ = gather_along_first_dim(
                    grad_output,
                    ctx.tp_group,
                    quantizer=quantizer,
                )
            return grad_output, grad_bias

        # FP8 without all-gather: fused bgrad + cast + transpose
        grad_bias = None
        if ctx.use_bias:
            if isinstance(
                grad_output,
                (QuantizedTensor, Float8TensorBase, MXFP8TensorBase, Float8BlockwiseQTensorBase),
            ):
                # Already quantized: compute bgrad from the dequantized values.
                grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0)
            else:
                if isinstance(quantizer, Float8BlockQuantizer):
                    # unfuse bgrad for now until cast_transpose + dgrad calculation is ready for Float8BlockQuantizer.
                    grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
                else:
                    # Fused kernel returns both the bias gradient and the quantized gradient.
                    grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer)
        # Quantize whatever is still in high precision at this point.
        if not isinstance(
            grad_output,
            (QuantizedTensor, Float8TensorBase, MXFP8TensorBase, Float8BlockwiseQTensorBase),
        ):
            grad_output = quantizer(grad_output)
        return grad_output, grad_bias
909

910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
    def register_parameter(self, name, param, **kwargs):
        """
        Register ``param`` under ``name`` with PyTorch and record deferred-init metadata.

        Extra keyword arguments are not forwarded to PyTorch; they build the
        ``_ParameterInitMeta`` stored in ``self.param_init_meta[name]``, which is
        consulted when parameters are (re)initialized.
        """
        super().register_parameter(name, param)
        init_meta = _ParameterInitMeta(**kwargs)
        self.param_init_meta[name] = init_meta

    def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
        """
        Reset all module parameters to initial values. Unless deferred initialization
        is specified, all parameters on a 'meta' device are also materialized on a real cuda
        device before the values are reset to initial.

        Parameters
        ----------
        defer_init : bool, default = `False`
            If true, return immediately without touching any parameter.

        Raises
        ------
        AssertionError
            If primary FP8 weights are requested but no quantizer is available
            for a parameter.
        """
        if defer_init:
            return

        # - Master weights are initialized from model weights, if we use fp8 primary
        #   weights to initialize master weights, the numerical values of master weights
        #   are not consistent with the numerical values when we initialize them from
        #   bf16/fp16 weights.
        # - So we add a `_high_precision_init_val` attribute to each model weight to store
        #   the original bf16/fp16 weight on cpu before casting it to fp8. And users can
        #   use `get_high_precision_init_val` to get this cpu tensor.
        # - This cpu tensor is not needed once the master weight is initialized, so users
        #   should call `clear_high_precision_init_val` to remove it after master weight
        #   is initialized.
        # The accessors are bound per parameter with MethodType; they are defined once here
        # instead of being re-created for every parameter in the loop.
        def get(tensor):
            if hasattr(tensor, "_high_precision_init_val"):
                return tensor._high_precision_init_val
            return None

        def clear(tensor):
            if hasattr(tensor, "_high_precision_init_val"):
                del tensor._high_precision_init_val

        for name, param in self.named_parameters(recurse=False):
            init_meta = self.param_init_meta[name]

            # Ensure parameter is on a real device
            if param.device == torch.device("meta"):
                param = torch.empty_like(param, device="cuda")

            # Initialize the parameter values on device
            init_fn = init_meta.init_fn
            get_rng_state_tracker = init_meta.get_rng_state_tracker
            if get_rng_state_tracker is None:
                init_fn(param)
            else:
                rng_tracker_name = getattr(self, "rng_tracker_name", None)
                if rng_tracker_name:
                    with get_rng_state_tracker().fork(rng_tracker_name):
                        init_fn(param)
                else:
                    with get_rng_state_tracker().fork():
                        init_fn(param)

            # If primary weights are in fp8, wrap the parameter as FP8Tensor
            fp8_meta_index = init_meta.fp8_meta_index
            high_precision_init_val = None
            if self.primary_weights_in_fp8 and fp8_meta_index is not None:
                if self.preserve_high_precision_init_val:
                    # Keep a CPU copy of the high-precision values before quantizing.
                    high_precision_init_val = param.detach().cpu()

                quantizer = self.quantizers["scaling_fwd"][fp8_meta_index]
                assert quantizer is not None, (
                    "To use primary FP8 weights, the module must be initialized under FP8"
                    f" autocast with a specific recipe (no quantizer for parameter {name!r})."
                )
                quantizer.internal = False
                param = quantizer(param)

            # Redo parameter wrap in case we broke it above
            # NOTE: Currently this can only be broken when primary weights are in Fp8 but
            #       re-applying the nn.Parameter() wrap is a no-op when the input is already
            #       a parameter so we always re-apply it just for extra safety.
            param = torch.nn.Parameter(param)
            if high_precision_init_val is not None:
                param._high_precision_init_val = high_precision_init_val
                param.get_high_precision_init_val = MethodType(get, param)
                param.clear_high_precision_init_val = MethodType(clear, param)

            setattr(self, name, param)
991

992
993
994
    @abstractmethod
    def forward(self):
        """Run the module's forward computation.

        Abstract placeholder: every concrete module must provide its own override.
        """
995

996
    def get_weight_workspace(
        self,
        *,
        tensor: Optional[torch.Tensor] = None,
        quantizer: Optional[Quantizer] = None,
        cache_name: Optional[str] = None,
        update_workspace: bool = True,
        skip_update_flag: Optional[torch.Tensor] = None,
        fsdp_group: Optional[dist_group_type] = None,
    ) -> QuantizedTensor:
        """Get FP8 workspace buffer and maybe update its values

        The workspace buffer may be cached for future function calls.

        Parameters
        ----------
        tensor : torch.Tensor, optional
            Values to copy into workspace. Required if the workspace
            is being constructed or updated.
        quantizer: Quantizer, optional
            Quantizer used to cast the weights. Required if the
            workspace is being constructed or updated.
        cache_name: str, optional
            Key for caching.
        update_workspace: bool, default = `True`
            Update workspace with values from `tensor`.
        skip_update_flag: torch.Tensor, optional
            GPU flag to skip updating the workspace. Take precedence
            over `update_workspace` if provided.
        fsdp_group: dist_group_type, optional
            FSDP process group that the weights are distributed over.

        Returns
        -------
        The quantized workspace. If ``tensor`` is already a ``QuantizedTensor``
        (FP8 primary weights) it is returned as-is.

        Raises
        ------
        ValueError
            If a workspace must be constructed without both ``tensor`` and
            ``quantizer``, or updated without ``tensor``.
        """

        # FP8 primary weights
        if isinstance(tensor, QuantizedTensor):
            if update_workspace and quantizer is not None:
                tensor.update_usage(
                    rowwise_usage=quantizer.rowwise_usage,
                    columnwise_usage=quantizer.columnwise_usage,
                )
            return tensor

        # Try getting workspace from cache
        out = None
        if cache_name is not None:
            out = self._fp8_workspaces.get(cache_name, None)
            # Drop a cached MXFP8 workspace that lacks the rowwise/columnwise data the
            # quantizer now requires, so it is rebuilt below.
            if quantizer is not None and isinstance(out, MXFP8TensorBase):
                if quantizer.rowwise_usage and out._rowwise_data is None:
                    out = None
                    del self._fp8_workspaces[cache_name]
                elif quantizer.columnwise_usage and out._columnwise_data is None:
                    out = None
                    del self._fp8_workspaces[cache_name]

        # Gather cached Fp8 workspace if it's distributed
        # NOTE: FSDP sharding is supported only for Fp8 buffers and will not work
        #       for models initialized with Fp8 primary weights.
        if (
            out is not None
            and tensor is not None
            and fsdp_group is not None
            and out.data.shape != tensor.data.shape
        ):
            _fsdp_gather_tensors(fsdp_group, [tensor.data.shape], out)

        # Construct workspace if needed
        if out is None:
            if tensor is None or quantizer is None:
                raise ValueError(
                    "tensor and quantizer kwargs must be provided to construct FP8 workspace"
                )
            out = quantizer(tensor)

            # Update cache
            if cache_name is not None:
                self._fp8_workspaces[cache_name] = out
            # A freshly constructed workspace already holds `tensor`'s values, so the
            # update step below is skipped.
            return out

        # Update workspace if needed
        # A provided flag forces the update call; the kernel itself decides (via
        # `noop_flag`) whether to skip it.
        if skip_update_flag is not None:
            update_workspace = True
        if update_workspace:
            if tensor is None:
                raise ValueError("tensor kwarg must be provided to update FP8 workspace")
            if hasattr(out, "quantize_"):
                out.quantize_(tensor, noop_flag=skip_update_flag)
            else:
                # NOTE(review): `quantizer` may be None here if it was not passed on a
                # cache hit — confirm tex.quantize accepts that.
                tex.quantize(tensor, quantizer, out, skip_update_flag)

        return out
1086

1087
1088
1089
    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        """
        Load parameters and buffers, restoring FP8 extra state beforehand when needed.

        Copying into FP8 tensors uses the scale_inv held in fp8_meta to set the correct
        scaling factor, so with ``self.primary_weights_in_fp8`` the module's extra state
        is applied via ``set_extra_state`` *before* the regular PyTorch loading (which
        normally restores extra state afterwards). Without FP8 primary weights, nothing
        is pre-loaded and loading is delegated unchanged.
        """
        if not self.primary_weights_in_fp8:
            super()._load_from_state_dict(
                state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
            )
            return

        key = f"{prefix}{torch.nn.modules.module._EXTRA_STATE_KEY_SUFFIX}"
        if key in state_dict:
            self.set_extra_state(state_dict[key])
        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )