base.py 46.7 KB
Newer Older
1
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
#
# See LICENSE for license information.

"""Base modules and utilities for TransformerEngine PyTorch API"""
6
import io
7
8
9
import os
import pickle
import warnings
10
11
12
import socket
import fcntl
import struct
13
from abc import ABC, abstractmethod
14
from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union
15
16
17
18
19
from contextlib import contextmanager

import torch
import torch.nn.functional as F

20
import transformer_engine_torch as tex
21
from ._common import _ParameterInitMeta
22
from ..export import is_in_onnx_export_mode
23
24
25
from ..fp8 import (
    get_default_fp8_recipe,
    get_fp8_te_dtype,
26
    FP8GlobalStateManager,
27
28
29
30
31
)
from ..distributed import (
    gather_along_first_dim,
    is_fp8_activation_recompute_enabled,
    in_fp8_activation_recompute_phase,
32
    _fsdp_gather_tensors,
33
34
35
36
37
38
39
)
from ..cpp_extensions import (
    fp8_cast_transpose_fused,
    fp8_cast_transpose_bgrad_fused,
    cast_to_fp8,
)
from ..constants import dist_group_type
40
from ..float8_tensor import Float8Tensor
41

42
43
# Names exported by ``from <module> import *``. Other public helpers defined in this
# module (e.g. get_ub) are not listed here but remain importable by name.
__all__ = ["initialize_ub", "destroy_ub"]

44
45
46
# NOTE(review): presumably toggle split ("2x") accumulation for FP8 GEMMs in the
# forward, dgrad and wgrad passes -- TODO confirm against the GEMM call sites.
_2X_ACC_FPROP: bool = False
_2X_ACC_DGRAD: bool = True
_2X_ACC_WGRAD: bool = True

# Per-stream cuBLAS workspaces, filled lazily by get_multi_stream_cublas_workspace().
_multi_stream_cublas_workspace: List[torch.Tensor] = []

# Shared cuBLAS workspace, allocated lazily by get_workspace() and enlarged by
# initialize_ub() for concurrent Userbuffers streams.
_cublas_workspace: Optional[torch.Tensor] = None

# Userbuffers communicators keyed by GEMM layer name; None until initialize_ub() runs.
_ub_communicators: Optional[Dict[str, Any]] = None

# Maximum number of concurrent Userbuffers streams (also scales the cuBLAS workspace).
_NUM_MAX_UB_STREAMS: int = 3

# Layer names configured with atomic GEMM + ring-exchange overlap (set by initialize_ub()).
layers_atomic_ring_exchange: List[str] = []
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70


def get_cublas_workspace_size_bytes() -> int:
    """Return the cuBLAS workspace size in bytes for the current CUDA device.

    Devices whose compute capability major version is >= 9 (Hopper and newer) get
    32 MiB; all other architectures get 4 MiB.
    """
    if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9:
        return 33_554_432  # 32 MiB
    return 4_194_304  # 4 MiB


def get_workspace() -> torch.Tensor:
    """Return the shared cuBLAS workspace, allocating it on first use.

    The buffer is a flat ``uint8`` CUDA tensor of ``get_cublas_workspace_size_bytes()``
    bytes, cached in the module-level ``_cublas_workspace``. Once set, the cached tensor
    is returned as-is.
    """
    global _cublas_workspace
    if _cublas_workspace is not None:
        return _cublas_workspace
    num_bytes = get_cublas_workspace_size_bytes()
    _cublas_workspace = torch.empty(num_bytes, dtype=torch.uint8, device="cuda")
    return _cublas_workspace


71
72
73
74
def get_multi_stream_cublas_workspace() -> List[torch.Tensor]:
    """Return one cuBLAS workspace per cuBLAS stream, allocating them on first use.

    Each workspace is a flat ``uint8`` CUDA tensor of ``get_cublas_workspace_size_bytes()``
    bytes, and there is one per ``tex._num_cublas_streams``. They are appended to the
    module-level ``_multi_stream_cublas_workspace`` list, and that same list object is
    returned on every call.
    """
    global _multi_stream_cublas_workspace
    if _multi_stream_cublas_workspace:
        return _multi_stream_cublas_workspace
    _multi_stream_cublas_workspace.extend(
        torch.empty(get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda")
        for _ in range(tex._num_cublas_streams)
    )
    return _multi_stream_cublas_workspace


82
83
def _get_ub_node_identifier(bootstrap_backend: str) -> str:
    """Return a string identifying this rank's node, used to color NVLink domains.

    Prefers the IPv4 address of the network interface named by ``NVTE_UB_SOCKET_IFNAME``
    (falling back on ``<BOOTSTRAP_BACKEND>_SOCKET_IFNAME``). If no interface is named,
    the hostname from ``socket.gethostname()`` is returned. If the named interface does
    not exist on this host, a warning is issued and the hostname is also returned.

    Raises
    ------
    OSError
        If the named interface exists but its address cannot be queried.
    """
    hostname = socket.gethostname()
    ifname = os.getenv(
        "NVTE_UB_SOCKET_IFNAME", os.getenv(f"{bootstrap_backend.upper()}_SOCKET_IFNAME")
    )
    if ifname is None:
        return hostname

    # Make sure the ifname found in the environment is a valid network interface
    if ifname not in {name for _, name in socket.if_nameindex()}:
        ifname_warning = (
            f"'{ifname}' is not a valid network interface! `te.initialize_ub()` will"
            + " attempt to detect ranks on the same node by matching "
            + "'socket.gethostname()', which is known to fail on virtual clusters like "
            + "Kubernetes. If Userbuffers initialization fails, please set the "
            + "'NVTE_UB_SOCKET_IFNAME' variable in your environment to the correct network "
            + "interface."
        )
        warnings.warn(ifname_warning, UserWarning)
        return hostname

    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
        try:
            # SIOCGIFADDR (0x8915) fills a struct ifreq: a 16-byte interface name followed
            # by a sockaddr_in whose IPv4 address sits at bytes 20:24.
            ifreq = fcntl.ioctl(
                s.fileno(), 0x8915, struct.pack("256s", ifname[:15].encode("UTF-8"))
            )
            return socket.inet_ntoa(ifreq[20:24])
        except OSError as err:
            raise OSError(f"Invalid network interface: {ifname}") from err


def initialize_ub(
    shape: list,
    tp_size: int,
    use_fp8: bool = False,
    dtype: torch.dtype = torch.bfloat16,
    ub_cfgs: Optional[dict] = None,
    bootstrap_backend: Optional[Union[str, torch.distributed.Backend]] = None,
) -> None:
    r"""
    Initialize the Userbuffers communicator for overlapping tensor-parallel communications with
    GEMM compute in te.Linear, te.LayerNormLinear and te.LayerNormMLP modules.

    Parameters
    ----------
    shape : list
            shape of the communication buffer, typically set to be the same as the global shape of
            the input tensor to a te.TransformerLayer forward pass, with the sequence and batch
            dimensions collapsed together -- i.e.: `(sequence_length * batch_size, hidden_size)`
    tp_size : int
              number of GPUs in the tensor-parallel process group
    use_fp8 : bool = False
              allocate the communication buffer for FP8 GEMM inputs/outputs
    dtype : torch.dtype = torch.bfloat16
            non-FP8 data type of the communication buffer when `use_fp8 = False`
    ub_cfgs: dict = None
             Configuration dictionary with the structure
             ```
             {
                <gemm_name> : {
                    "method": <"ring_exchange", "pipeline" or "bulk">,
                    "is_reduce_scatter": bool,
                    "num_sm": int,
                    "cga_size": int,
                    "set_sm_margin": bool,
                    "num_splits": int,
                    "aggregate": bool,
                    "atomic_gemm": bool,
                    "use_ce": bool,
                    "fp8_buf": bool,
                }
             }
             ```
             for `te.TransformerLayer` GEMM layers in `["qkv_fprop", "qkv_dgrad", "qkv_wgrad",
             "proj_fprop", "proj_dgrad", "fc1_fprop", "fc1_dgrad", "fc1_wgrad", "fc2_fprop",
             "fc2_dgrad"]`. Entries override the per-layer defaults. Giving `"qkv_dgrad"` or
             `"fc1_dgrad"` a non-bulk `"method"` turns it into a reduce-scatter overlap, and
             the matching `*_wgrad` entry must then be left out.
    bootstrap_backend : str = None
                        `torch.distributed` communication backend for the all-gather, broadcast and
                        barrier collectives during Userbuffers initialization. Not all backends are
                        valid for every cluster configuration and distributed launch method even if
                        they are available in PyTorch. When left unset, the initialization prefers
                        to use the MPI backend, falling back first on Gloo and then NCCL if MPI is
                        not available. Setting `NVTE_UB_WITH_MPI=1` when building TE overrides this
                        option and always initializes Userbuffers with direct MPI calls in C++,
                        which also requires `MPI_HOME=/path/to/mpi/root` to be set at compile time.

    Raises
    ------
    AssertionError
        If Userbuffers is already initialized, or a configuration check fails.
    ValueError
        If the `"pipeline"` method is requested for an all-gather overlap.
    OSError
        If the requested network interface exists but its address cannot be queried.
    """
    global _ub_communicators, _cublas_workspace, layers_atomic_ring_exchange

    if not tex.device_supports_multicast():
        assert bool(int(os.getenv("UB_SKIPMC", "0"))), (
            "CUDA device, driver and/or toolkit version does not support comm+GEMM overlap with "
            + "CUDA Multicast. Launch app with UB_SKIPMC=1 to try CUDA IPC instead."
        )

    assert _ub_communicators is None, "UB communicators are already initialized."
    _ub_communicators = {}

    if tex.ubuf_built_with_mpi():
        # We're bootstrapping with direct calls to MPI in Userbuffers code so we need to force
        # an MPI_Init() here by creating a new MPI process group...
        assert torch.distributed.is_mpi_available()
        _ = torch.distributed.new_group(backend="mpi")
        helper = tex.CommOverlapHelper()
    else:
        # Bootstrapping with torch.distributed API, so check backend and construct
        # intra/inter-node process groups...
        assert (
            torch.distributed.is_initialized()
        ), "torch.distributed must be initialized before Userbuffers"
        if bootstrap_backend is None:
            bootstrap_backend = "nccl"
            if torch.distributed.is_mpi_available():
                bootstrap_backend = "mpi"
            elif torch.distributed.is_gloo_available():
                bootstrap_backend = "gloo"
        else:
            assert bootstrap_backend in [
                "gloo",
                "mpi",
                "nccl",
            ], "Invalid torch.distributed backend for bootstrapping Userbuffers!"
            assert torch.distributed.is_backend_available(bootstrap_backend), (
                f"PyTorch must be compiled with '{bootstrap_backend}' support in order to "
                f"bootstrap Userbuffers with '{bootstrap_backend}' collectives."
            )

        world_group = torch.distributed.new_group(backend=bootstrap_backend)
        world_rank = torch.distributed.get_rank(world_group)
        world_size = torch.distributed.get_world_size(world_group)

        # We have single-node NVLink so we can color based on physical node identity
        # (hostname, or the IP of a user-selected network interface).
        mydomain = _get_ub_node_identifier(bootstrap_backend)

        # Allgather the domain colors across ranks and reduce to an ordered list of unique
        # domains (first-seen order).
        domain_per_rank_list = [None for _ in range(world_size)]
        torch.distributed.all_gather_object(domain_per_rank_list, mydomain, world_group)
        unique_domains = list(dict.fromkeys(domain_per_rank_list))
        num_domains = len(unique_domains)

        if num_domains > 1:
            # DP/TP model replicated on multiple NVLink domains
            domain_to_idx = {domain: idx for idx, domain in enumerate(unique_domains)}
            ranks_per_domain_list = [[] for _ in range(num_domains)]
            mydomain_idx = -1
            for i, domain in enumerate(domain_per_rank_list):
                domain_idx = domain_to_idx[domain]
                ranks_per_domain_list[domain_idx].append(i)
                if domain == mydomain:
                    mydomain_idx = domain_idx
            assert mydomain_idx >= 0, "Internal TE error!"
            # Ranks sharing this rank's NVLink domain; reported by the print below and
            # previously left unbound on this branch.
            intra_domain_ranks = ranks_per_domain_list[mydomain_idx]

            intra_domain_group, _ = torch.distributed.new_subgroups_by_enumeration(
                ranks_per_domain_list, backend=bootstrap_backend
            )
            local_rank = torch.distributed.get_rank(intra_domain_group)

            inter_domain_group, _ = torch.distributed.new_subgroups_by_enumeration(
                [list(ranks) for ranks in zip(*ranks_per_domain_list)],
                backend=bootstrap_backend,
            )

            helper = tex.CommOverlapHelper(world_group, intra_domain_group, inter_domain_group)
        else:
            # TP model on single NVLink domain, no replication, no data-parallelism
            mydomain_idx = 0
            local_rank = world_rank
            intra_domain_ranks = list(range(world_size))

            helper = tex.CommOverlapHelper(world_group)

        if world_rank == 0:
            print(f"!!! [UB] Number of NVLink domains: {num_domains}\n", end="", flush=True)
        if local_rank == 0:
            print(
                f"!!! [UB] Global ranks on domain {mydomain_idx}: {intra_domain_ranks}\n",
                end="",
                flush=True,
            )

    # Size the cuBLAS workspace for the maximum number of concurrent UB streams. Allocate
    # from the base size rather than repeating the current workspace: after destroy_ub()
    # the current workspace is already enlarged, and repeating it again would keep
    # multiplying its size on every re-initialization.
    _cublas_workspace = torch.empty(
        get_cublas_workspace_size_bytes() * _NUM_MAX_UB_STREAMS,
        dtype=torch.uint8,
        device="cuda",
    )

    # Default buffer precision: AllGather buffers use fp8 when using fp8 recipe
    layers_all_gather_overlap = [
        "qkv_fprop",
        "qkv_dgrad",
        "proj_dgrad",
        "fc1_fprop",
        "fc1_dgrad",
        "fc2_dgrad",
    ]
    layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"]
    dgrad_reduce_scatter_overlap = ["qkv_dgrad", "fc1_dgrad"]
    # Default overlap methods for layers
    methods = {
        "ring_exchange": ["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad"],
        "pipeline": ["proj_fprop", "fc2_fprop"],
        "bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"],
    }

    # AG-RS overlap pairs of layers forming a tensor-parallel block
    ag_rs_pairs = {"qkv_fprop": "proj_fprop", "fc1_fprop": "fc2_fprop"}
    rs_ag_pairs = {v: k for k, v in ag_rs_pairs.items()}

    layers_atomic_ring_exchange = []

    def get_method(name):
        """Return the overlap method currently assigned to layer `name`."""
        for method, names in methods.items():
            if name in names:
                return method
        raise KeyError(f"Given layer name {name} does not exist.")

    def get_default_config(name):
        """Build the default add_ub() keyword arguments for layer `name`."""
        method = get_method(name)
        is_reduce_scatter = name in layers_reduce_scatter_overlap
        default_cfg = {
            "method": method,
            "is_reduce_scatter": is_reduce_scatter,
            "num_sm": 1 if method == "ring_exchange" else 16,
            "cga_size": 1 if method == "ring_exchange" else 2,
            "set_sm_margin": False,
            "num_splits": 4 if method == "pipeline" else tp_size,
            "aggregate": False,
            "atomic_gemm": False,
            "use_ce": True,
            "fp8_buf": name in layers_all_gather_overlap,
        }
        return default_cfg

    def add_ub(
        name: str,
        method: str,
        is_reduce_scatter: int,
        num_sm: int = 16,
        cga_size: int = 2,
        set_sm_margin: int = 0,
        num_splits: int = 0,
        aggregate: int = 0,
        atomic_gemm: int = 0,
        use_ce: bool = True,
        fp8_buf: bool = False,
    ) -> None:
        """Validate one layer's overlap config and register its communicator."""
        global layers_atomic_ring_exchange
        if atomic_gemm:
            warnings.warn(
                "Atomic GEMM uses a beta API from cublas and is not tested for all use cases."
            )
            assert use_fp8, "Atomic GEMM overlap supported only for FP8 GEMM."
            if method == "bulk":
                warnings.warn(
                    f"At {name}, atomic GEMM is not supported for a bulk overlap. "
                    "Defaulting to `atomic_gemm=False`."
                )
                atomic_gemm = 0
        if not is_reduce_scatter and method == "pipeline":
            raise ValueError(
                f"At {name}, `pipeline` overlap method is not supported for AllGather."
            )
        # Check if both AG and RS overlaps use `atomic GEMM` + `p2p ring-exchange`.
        # Using atomic GEMM + p2p ring-exchange in only one of the pair breaks functionality.
        if atomic_gemm and method == "ring_exchange" and name in ag_rs_pairs:
            layers_atomic_ring_exchange += [name, ag_rs_pairs[name]]
        if name in rs_ag_pairs:
            assert_message = (
                f"At {name}, atomic AG-GEMM overlap with `ring_exchange` shuffles GEMM chunk "
                "outputs, and RS-GEMM overlap un-shuffles them. When one of the GEMM-AG and "
                "GEMM-RS overlaps forming a TP block (e.g., qkv_fprop and proj_fprop) uses "
                "`atomic gemm` and `ring_exchange`, its pair must use the same overlap config "
                "for functionality."
            )
            if name in layers_atomic_ring_exchange:
                assert atomic_gemm and method == "ring_exchange", assert_message
            elif atomic_gemm and method == "ring_exchange":
                assert rs_ag_pairs[name] in layers_atomic_ring_exchange, assert_message

        buffer_dtype = torch.uint8 if (use_fp8 and fp8_buf) else dtype
        if method == "ring_exchange":
            ub_obj = tex.CommOverlapP2P(
                shape,  # Communication buffer shape
                buffer_dtype,  # Communication buffer data type
                helper,  # Helper for torch.distributed callbacks during bootstrapping
                tp_size,  # Tensor-parallel group size (may be different than local_size)
                tex.CommOverlapType.RS if is_reduce_scatter else tex.CommOverlapType.AG,
                num_max_streams=_NUM_MAX_UB_STREAMS,
                comm_cga_size=cga_size,
                num_comm_sm=num_sm,
                set_sm_margin=set_sm_margin,
                atomic_gemm=atomic_gemm,
                use_ce=use_ce,
                aggregate=aggregate,
            )
        else:
            ub_obj = tex.CommOverlap(
                shape,  # Communication buffer shape
                buffer_dtype,  # Communication buffer data type
                helper,  # Helper for torch.distributed callbacks during bootstrapping
                tp_size,  # Tensor-parallel group size (may be different than local_size)
                num_splits=num_splits,
                num_max_streams=_NUM_MAX_UB_STREAMS,
                comm_cga_size=cga_size,
                num_comm_sm=num_sm,
                set_sm_margin=set_sm_margin,
                atomic_gemm=atomic_gemm,
            )
        _ub_communicators[name] = ub_obj

    if ub_cfgs is not None:
        # A dgrad GEMM given a non-bulk method becomes a reduce-scatter overlap.
        for name in dgrad_reduce_scatter_overlap:
            if name in ub_cfgs and "method" in ub_cfgs[name] and ub_cfgs[name]["method"] != "bulk":
                wgrad_name = name.replace("dgrad", "wgrad")
                assert wgrad_name not in ub_cfgs
                layers_reduce_scatter_overlap.remove(wgrad_name)
                layers_all_gather_overlap.remove(name)
                layers_reduce_scatter_overlap.append(name)
                methods["bulk"].remove(name)
                new_method = ub_cfgs[name]["method"]
                methods[new_method].append(name)

    for name in methods["ring_exchange"] + methods["pipeline"] + methods["bulk"]:
        ub_cfg = get_default_config(name)
        if ub_cfgs is not None and name in ub_cfgs:
            fp8_buf = (name in layers_all_gather_overlap) or (
                ub_cfgs[name].get("fp8_buf", False) and name in methods["pipeline"]
            )
            ub_cfg.update(ub_cfgs[name])
            ub_cfg["fp8_buf"] = fp8_buf
        add_ub(name, **ub_cfg)
412
413
414
415
416
417
418
419


def get_ub(name: str):
    """Look up the Userbuffers communicator registered for GEMM layer ``name``.

    Raises ``AssertionError`` if ``initialize_ub()`` has not been called yet, or if no
    communicator was registered under ``name``.
    """
    registry = _ub_communicators
    assert registry is not None, "UB manager is not initialized."
    assert name in registry, f"UB for {name} is not registered."
    return registry[name]

420

421
422
423
424
425
426
427
def destroy_ub():
    """Reset the module-level Userbuffers state.

    Drops the references to every communicator created by ``initialize_ub()`` and clears
    the list of layers using atomic GEMM with ring-exchange, so that ``initialize_ub()``
    can be called again.
    """
    global _ub_communicators, layers_atomic_ring_exchange
    layers_atomic_ring_exchange = []
    _ub_communicators = None

428
429
430
431
432
433
434
435
436
437
438

class TransformerEngineBaseModule(torch.nn.Module, ABC):
    """Base TE module."""

    def __init__(self) -> None:
        """Set up default FP8, tensor-parallel and FSDP state for a TE module."""
        super().__init__()
        assert torch.cuda.is_available(), "TransformerEngine needs CUDA."

        # FP8 execution flags.
        self.fp8_initialized = False
        self.fp8 = False
        self.fp8_calibration = False

        # Per-module FP8 metadata; scaling tensors are added later by set_meta_tensor().
        self.fp8_meta = {
            "fp8_checkpoint": False,
            "fp8_group": None,
            "recipe": get_default_fp8_recipe(),
        }
        self.fp8_meta_tensors_initialized = False

        # Tensor-parallel defaults.
        self.tp_group = None
        self.tp_size = 1
        self.sequence_parallel = False

        self.param_init_meta = {}
        self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()

        # FSDP defaults.
        self.fsdp_wrapped = False
        self.fsdp_group = None

        self._fp8_workspaces: Dict[str, Float8Tensor] = {}
        self.activation_dtype: Optional[torch.dtype] = None

    # Attribute names that __setattr__ stores directly in self.__dict__, skipping the
    # parameter/buffer/submodule handling of torch.nn.Module.__setattr__. Only list names
    # that are never assigned a Parameter, buffer or Module, since such values would
    # bypass that registration.
    _fast_setattr_names: Set[str] = {
        "activation_dtype",
        "fp8",
        "fp8_initialized",
        "fp8_calibration",
        "fp8_parameters",
    }

    def __setattr__(self, name: str, value: Any) -> None:
        if name in TransformerEngineBaseModule._fast_setattr_names:
            # torch.nn.Module has a custom __setattr__ that handles
            # modules, parameters, and buffers. This is unnecessary
            # overhead when setting plain attrs.
            self.__dict__[name] = value
        else:
            # Default case
            super().__setattr__(name, value)
472

473
474
475
476
477
478
479
480
481
482
483
484
    def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None:
        """Increase or decrease size of amax history based on given `length`.

        .. warning::
            This changes the underlying amax memory location.
        """
        if fwd is None:
            fp8_meta_tensor_keys = ("scaling_fwd", "scaling_bwd")
        else:
            fp8_meta_tensor_keys = ("scaling_fwd" if fwd else "scaling_bwd",)

        for meta_key in fp8_meta_tensor_keys:
485
486
487
            if meta_key not in self.fp8_meta:
                # Handles non-parameter FP8 modules, e.g. DPA.
                continue
488
489
490
491
492
            curr_len = self.fp8_meta[meta_key].amax_history.shape[0]
            if length == curr_len:
                continue
            if length < curr_len:
                self.fp8_meta[meta_key].amax_history = (
493
494
                    self.fp8_meta[meta_key].amax_history[:length].clone()
                )
495
496
497
498
499
500
501
502
            elif length > curr_len:
                extra_rows = length - curr_len
                self.fp8_meta[meta_key].amax_history = F.pad(
                    self.fp8_meta[meta_key].amax_history, pad=(0, 0, 0, extra_rows)
                )

            # Update the global buffers with new amax and history pointers.
            if FP8GlobalStateManager.get_buffer_info() in self.fp8_meta:
503
504
505
                fwd_pos, fwd_key, bwd_pos, bwd_key = self.fp8_meta[
                    FP8GlobalStateManager.get_buffer_info()
                ]
506
507
508
509
510
                for pos, buffer_key in zip((fwd_pos, bwd_pos), (fwd_key, bwd_key)):
                    if buffer_key in FP8GlobalStateManager.global_amax_buffer:
                        assert (
                            buffer_key in FP8GlobalStateManager.global_amax_history_buffer
                        ), "TE internal error during amax history change."
511
512
513
                        FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = self.fp8_meta[
                            meta_key
                        ].amax_history[0]
514
                        FP8GlobalStateManager.global_amax_history_buffer[buffer_key][pos] = (
515
516
                            self.fp8_meta[meta_key].amax_history
                        )
517

518
519
520
521
522
523
    def set_meta_tensor(self, fwd: bool) -> None:
        """Init scales and amaxes for fwd | bwd."""
        fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd"

        if self.fp8_meta_tensors_initialized:
            # Handle changed amax history size.
524
            self.adjust_amax_history_length(self.fp8_meta["recipe"].amax_history_len, fwd=fwd)
525
526
527
528
            return

        # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and
        # 2 (grad_output and grad_input) for bwd
529
        num_fp8_tensors = self.fp8_meta["num_gemms"] * 3 if fwd else self.fp8_meta["num_gemms"] * 2
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550

        self.fp8_meta[fp8_meta_tensor_key] = tex.FP8TensorMeta()
        self.fp8_meta[fp8_meta_tensor_key].scale = torch.ones(
            num_fp8_tensors, dtype=torch.float32, device="cuda"
        )
        self.fp8_meta[fp8_meta_tensor_key].scale_inv = torch.ones(
            num_fp8_tensors, dtype=torch.float32, device="cuda"
        )
        self.fp8_meta[fp8_meta_tensor_key].amax_history = torch.zeros(
            self.fp8_meta["recipe"].amax_history_len,
            num_fp8_tensors,
            dtype=torch.float32,
            device="cuda",
        )

    def init_fp8_meta_tensors(self) -> None:
        """Init scales and amaxes."""
        self.set_meta_tensor(True)
        self.set_meta_tensor(False)
        self.fp8_meta_tensors_initialized = True

551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
    def get_fp8_meta_tensors(self) -> None:
        """Get scales and amaxes."""
        fwd_key, bwd_key = "scaling_fwd", "scaling_bwd"
        if fwd_key not in self.fp8_meta or bwd_key not in self.fp8_meta:
            return None

        fp8_meta_tensors = {fwd_key: [], bwd_key: []}
        with torch.no_grad():
            for key in (fwd_key, bwd_key):
                fp8_meta_tensors[key].append(self.fp8_meta[key].scale.clone())
                fp8_meta_tensors[key].append(self.fp8_meta[key].scale_inv.clone())
                fp8_meta_tensors[key].append(self.fp8_meta[key].amax_history.clone())
        return fp8_meta_tensors

    def reset_fp8_meta_tensors(self, fp8_meta_tensors=None) -> None:
        """Reset scales and amaxes."""
567

568
569
570
571
572
        def reset(key):
            if key in self.fp8_meta:
                if fp8_meta_tensors is None:
                    self.fp8_meta[key].scale.copy_(torch.ones_like(self.fp8_meta[key].scale))
                    self.fp8_meta[key].scale_inv.copy_(
573
574
                        torch.ones_like(self.fp8_meta[key].scale_inv)
                    )
575
                    self.fp8_meta[key].amax_history.copy_(
576
577
                        torch.zeros_like(self.fp8_meta[key].amax_history)
                    )
578
579
580
581
582
                else:
                    assert key in fp8_meta_tensors, "Cannot reset fp8 tensors."
                    self.fp8_meta[key].scale.copy_(fp8_meta_tensors[key][0])
                    self.fp8_meta[key].scale_inv.copy_(fp8_meta_tensors[key][1])
                    self.fp8_meta[key].amax_history.copy_(fp8_meta_tensors[key][2])
583

584
585
586
587
        with torch.no_grad():
            reset("scaling_fwd")
            reset("scaling_bwd")

588
589
590
    def get_extra_state(self) -> Union[torch.Tensor, io.BytesIO]:
        """Serialize the FP8 scaling state for checkpointing.

        When ``fp8_meta["fp8_checkpoint"]``, ``self.fp8`` or ``self.fp8_calibration`` is set,
        the state holds the forward/backward ``scale``, ``scale_inv`` and ``amax_history``
        tensors. It also holds the plain-Python (bool/int/float/str/tuple/list) entries of
        ``fp8_meta``, excluding ``buffer_index_and_autocast_key``. Otherwise the state is
        ``None``.

        Returns
        -------
        torch.Tensor or io.BytesIO
            In ONNX export mode, a ``uint8`` tensor holding the pickled state; otherwise an
            ``io.BytesIO`` written by ``torch.save``.
        """
        state = None

        fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration

        if fp8_checkpoint:
            state = {}
            state["scale_fwd"] = self.fp8_meta["scaling_fwd"].scale
            state["scale_inv_fwd"] = self.fp8_meta["scaling_fwd"].scale_inv
            state["amax_history_fwd"] = self.fp8_meta["scaling_fwd"].amax_history
            state["scale_bwd"] = self.fp8_meta["scaling_bwd"].scale
            state["scale_inv_bwd"] = self.fp8_meta["scaling_bwd"].scale_inv
            state["amax_history_bwd"] = self.fp8_meta["scaling_bwd"].amax_history

            # Store other picklable values.
            extra = {}
            for k, v in self.fp8_meta.items():
                if k != "buffer_index_and_autocast_key" and isinstance(
                    v, (bool, int, float, str, tuple, list)
                ):
                    extra[k] = v
            state["extra_fp8_variables"] = extra

        if is_in_onnx_export_mode():
            # Wrap in bytearray so torch.frombuffer gets a writable buffer and does not warn
            # about non-writable memory; the serialized bytes are unchanged.
            state_serialized = torch.frombuffer(bytearray(pickle.dumps(state)), dtype=torch.uint8)
        else:
            state_serialized = io.BytesIO()
            torch.save(state, state_serialized)

        return state_serialized
619
620
621
622
623
624
625
626

    def set_extra_state(self, state: torch.Tensor) -> None:
        """Load previous state."""
        if state is None:
            return

        if isinstance(state, torch.Tensor):
            state = pickle.loads(state.detach().cpu().numpy().tobytes())
627
628
        elif isinstance(state, io.BytesIO):
            state.seek(0)
629
            state = torch.load(state, map_location="cuda")
630
631
        else:
            raise RuntimeError("Unsupported checkpoint format.")
632
633
634

        if state is None:
            return
635
636
637
638
639
640
641
642
643
644
645
646
647

        # Load extra items.
        self.fp8_meta.update(state["extra_fp8_variables"])
        self.fp8_meta["recipe"].amax_history_len = state["amax_history_fwd"].shape[0]
        if "global_fp8_buffer_pos_fwd_recompute" in self.fp8_meta:
            del self.fp8_meta["global_fp8_buffer_pos_fwd_recompute"]

        # Initialize before loading.
        self.init_fp8_meta_tensors()
        self.fp8_meta["scaling_fwd"].scale.copy_(state["scale_fwd"])
        self.fp8_meta["scaling_fwd"].amax_history.copy_(state["amax_history_fwd"])
        self.fp8_meta["scaling_bwd"].scale.copy_(state["scale_bwd"])
        self.fp8_meta["scaling_bwd"].amax_history.copy_(state["amax_history_bwd"])
648
649
        self.fp8_meta["scaling_fwd"].scale_inv.copy_(state["scale_inv_fwd"])
        self.fp8_meta["scaling_bwd"].scale_inv.copy_(state["scale_inv_bwd"])
650
651
652
653
654
655
656
657
658

    def set_activation_dtype(self, inp: torch.Tensor) -> None:
        """Get activation data type for AMP."""
        # Native AMP (`torch.autocast`) gets highest priority
        if torch.is_autocast_enabled():
            self.activation_dtype = torch.get_autocast_gpu_dtype()
            return

        # All checks after this have already been performed once, thus skip
659
        if self.activation_dtype == inp.dtype:
660
661
            return

662
663
664
665
666
667
668
669
        dtype = inp.dtype
        for name, param in self.named_parameters():
            if param is not None:
                assert dtype == param.dtype, (
                    "Data types for parameters must match when outside of autocasted region. "
                    f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}"
                )
        self.activation_dtype = dtype
670
671

    def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
672
673
674
675
676
677
678
679
680
        """
        Set the tensor parallel group for the given
        module before executing the forward pass.

        Parameters
        ----------
        tp_group : ProcessGroup, default = `None`
                  tensor parallel process group.
        """
681
682
683
        self.tp_group = tp_group
        self.tp_group_initialized = True

684
685
686
    def _get_fp8_params(self) -> Union[List[torch.Tensor], None]:
        """returns the FP8 weights."""
        fp8_params = []
687
        for param in self.parameters(recurse=False):
688
689
690
691
692
693
            if isinstance(param, Float8Tensor) and param.requires_grad:
                fp8_params.append(param)
        if len(fp8_params) == 0:
            return None
        return fp8_params

694
695
    # This routine is shared across FP8 and FP8_calibration paths so should not actually
    # assume FP8 execution.
    def init_fp8_metadata(self, num_gemms: int = 1) -> None:
        """Initialize fp8 related metadata and tensors during fprop.

        Refreshes the module's FP8 flags from `FP8GlobalStateManager`. When FP8 or
        FP8 calibration is active, it (re)builds `fp8_meta` unless it was already
        initialized under the current recipe. Otherwise it clears `fp8_initialized`.

        Parameters
        ----------
        num_gemms : int, default = 1
            Stored in `fp8_meta["num_gemms"]` before meta tensors are allocated.
        """
        self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
        self.fp8 = FP8GlobalStateManager.is_fp8_enabled()
        self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
        self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration

        # With FP8 parameters, meta tensors are allocated on first use even if
        # neither FP8 nor calibration is currently active.
        if self.fp8_parameters and not self.fp8_initialized:
            self.fp8_meta["num_gemms"] = num_gemms
            self.init_fp8_meta_tensors()

        if not (self.fp8 or self.fp8_calibration):
            # If fp8 isn't enabled, turn off and return.
            self.fp8_initialized = False
            return

        # FP8 init has already been run and recipe is the same, don't do anything.
        if (
            self.fp8_initialized
            and FP8GlobalStateManager.get_fp8_recipe() == self.fp8_meta["recipe"]
        ):
            return

        # Record recipe, GEMM count and FP8 group for this module.
        recipe = FP8GlobalStateManager.get_fp8_recipe()
        self.fp8_meta["recipe"] = recipe
        self.fp8_meta["num_gemms"] = num_gemms
        self.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()

        # Per-direction FP8 maxima are taken from the recipe's FP8 format.
        fp8_format = recipe.fp8_format.value
        self.fp8_meta["fp8_max_fwd"] = fp8_format.max_fwd
        self.fp8_meta["fp8_max_bwd"] = fp8_format.max_bwd

        # Allocate scales and amaxes, then mark the module as initialized.
        self.init_fp8_meta_tensors()
        self.fp8_initialized = True

    @contextmanager
    def prepare_forward(
        self,
        inp: torch.Tensor,
        is_first_microbatch: Union[bool, None],  # pylint: disable=unused-argument
        num_gemms: int = 1,
        allow_non_contiguous: bool = False,
    ) -> Generator[torch.Tensor, None, None]:
        """Checks and prep for FWD.

        The context manager is needed because there isn't a way for a module to know
        if it's the last FP8 module in the forward autocast. It is useful
        to setup the forward aggregated amax reduction for every module
        just in case. The autocast exit will pick up the most recent one.

        Parameters
        ----------
        inp : torch.Tensor
            Module input. Outside the activation-recompute phase it must be a CUDA tensor.
        is_first_microbatch : bool or None
            Unused here; kept for interface compatibility.
        num_gemms : int, default = 1
            Forwarded to `init_fp8_metadata`.
        allow_non_contiguous : bool, default = False
            When False, a non-contiguous `inp` is made contiguous before being yielded.

        Yields
        ------
        torch.Tensor
            `inp`, made contiguous unless `allow_non_contiguous` is set.
        """
        # Activation recomputation is used and this is the second forward phase.
        if self.fp8 and in_fp8_activation_recompute_phase():
            FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta)
        else:
            assert inp.is_cuda, "TransformerEngine needs CUDA."

            if self.tp_size > 1:
                assert self.tp_group_initialized, "TP group not initialized."

            self.set_activation_dtype(inp)
            self.init_fp8_metadata(num_gemms=num_gemms)

            if self.fp8 and self.sequence_parallel:
                assert self.fp8_meta["recipe"].reduce_amax, (
                    "Amax reduction across tensor parallel group is "
                    "necessary when using sequence parallelism with FP8."
                )

            if self.fp8 and not FP8GlobalStateManager.fp8_graph_capturing():
                FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
                    self.fp8_meta, fp8_weights=self._get_fp8_params()
                )

            # Activation recomputation is used and this is the first forward phase.
            if self.fp8 and self.training and is_fp8_activation_recompute_enabled():
                FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)

        try:
            with torch.cuda.nvtx.range(self.__class__.__name__ + " forward"):
                if not allow_non_contiguous and not inp.is_contiguous():
                    inp = inp.contiguous()
                yield inp
        finally:
            # Run the restore even if the wrapped forward raised. Otherwise the FP8
            # meta tensors swapped in at the top of a recompute pass would stay in
            # place after the exception propagates.
            if self.fp8 and in_fp8_activation_recompute_phase():
                FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808

    def set_nccl_overlap_warning_if_tp(self) -> None:
        """When using TP, the NCCL communication needs to be scheduled
        before the GEMM for there to be a guaranteed overlap. From the
        host side in TE, the comm calls are always launched first, but
        to ensure that the GEMM isn't scheduled first, the environment
        variable `CUDA_DEVICE_MAX_CONNECTIONS` needs to be set to 1 to
        force a single channel.

        Nothing happens when `self.tp_size == 1`. Otherwise a warning is
        issued unless the variable parses as the integer 1. An unset or
        non-integer value also triggers the warning.
        """
        if self.tp_size == 1:
            return
        try:
            num_cuda_work_queues = int(os.getenv("CUDA_DEVICE_MAX_CONNECTIONS", "0"))
        except ValueError:
            # A malformed value cannot mean "1", so treat it like any other mismatch
            # and warn instead of crashing the forward setup.
            num_cuda_work_queues = 0
        if num_cuda_work_queues != 1:
            warnings.warn(
                "To guarantee overlapping TP and SP collectives with the backward "
                "GEMMs, set environment variable CUDA_DEVICE_MAX_CONNECTIONS = 1"
            )

    @staticmethod
    def grad_output_preprocess(
        ctx, grad_output: torch.Tensor, row_parallel_mode: bool
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Utility function for backward.
        Returns tuple in order (all optional/None based on training precision/recipe):
            R1: gathered `grad_output` in higher precision.
            R2: gathered `grad_output` in FP8.
            R3: R2 transposed.
            R4: bias gradient on R1.

        Notes
        -----
        - `grad_output` is first made contiguous. For a `Float8Tensor`, only its
          `_data` buffer is replaced (mutating the passed-in object).
        - A gather over `ctx.tp_group` along the first dimension happens only when
          `row_parallel_mode` and `ctx.sequence_parallel` are both true.
        - In the FP8 gather path without a wgrad precision override, R1 is the
          local (un-gathered) matrix; only R2 is gathered there.
        - With `ctx.ub_overlap_ag`, data is staged through `ctx.ub_obj_gradout`
          instead of `gather_along_first_dim`. In that case R3 is `None` in the
          FP8 gather path.
        """
        # Ensure a contiguous layout so the 2D view below is valid.
        if isinstance(grad_output, Float8Tensor):
            grad_output._data = grad_output._data.contiguous()
        else:
            grad_output = grad_output.contiguous()
        # Flatten all leading dims: (-1, hidden) matrix used by every branch below.
        grad_output_mat = grad_output.view(-1, grad_output.shape[-1])
        gather_grad_output = row_parallel_mode and ctx.sequence_parallel

        # No-FP8 case: bgrad is fused with wgrad for this case.
        if not ctx.fp8:
            if gather_grad_output:
                if not ctx.ub_overlap_ag:
                    grad_output_mat, _ = gather_along_first_dim(grad_output_mat, ctx.tp_group)
                else:
                    # UB overlap: copy the (un-flattened) grad into the userbuffer and
                    # take output 1 as R1 (presumably the full gathered view -- confirm).
                    ctx.ub_obj_gradout.copy_input_to_ubuf(grad_output, True)
                    grad_output_mat = ctx.ub_obj_gradout.get_ubuf_output(1)
            return grad_output_mat, None, None, None

        fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)

        # FP8 case with non-FP8 wgrad
        # (gather in high precision, then fall through to the no-gather FP8 path below)
        if gather_grad_output and ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
            assert (
                not ctx.ub_overlap_ag
            ), "override_linear_precision.wgrad not supported with UB AG overlap"
            grad_output_mat, _ = gather_along_first_dim(grad_output_mat, ctx.tp_group)
        # FP8 case with gather: unfused bgrad, cast, transpose for efficient gather
        elif gather_grad_output:
            # bgrad is taken from the local (pre-gather) rows.
            if ctx.use_bias:
                grad_bias = grad_output_mat.sum(dim=0)
            else:
                grad_bias = None
            # Destination for the FP8 cast: a UB buffer slot or a fresh uint8 tensor.
            if ctx.ub_overlap_ag:
                grad_output_c = ctx.ub_obj_gradout.get_ubuf_output(0)
            else:
                grad_output_c = torch.empty_like(grad_output_mat, dtype=torch.uint8)
            if not isinstance(grad_output_mat, Float8Tensor):
                cast_to_fp8(
                    grad_output_mat,
                    ctx.fp8_meta["scaling_bwd"],
                    tex.FP8BwdTensors.GRAD_OUTPUT1,
                    fp8_dtype_backward,
                    out=grad_output_c,
                )
            else:
                # Already FP8: reuse as-is (no copy into the UB slot above).
                grad_output_c = grad_output_mat
            if not ctx.ub_overlap_ag:
                # Gather the FP8 bytes, then transpose the gathered result.
                grad_output_c, _ = gather_along_first_dim(grad_output_c, ctx.tp_group)
                if not isinstance(grad_output_c, Float8Tensor):
                    grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
                else:
                    grad_output_t = grad_output_c.transpose_2d()
            else:
                # UB overlap: R2 is UB output 1; no transpose is produced here.
                grad_output_c = ctx.ub_obj_gradout.get_ubuf_output(1)
                grad_output_t = None

            return grad_output_mat, grad_output_c, grad_output_t, grad_bias

        # FP8 case without gather: cast, transpose, bgrad fused
        if ctx.use_bias:
            # The fused bgrad kernel takes a high-precision input, so dequantize first.
            grad_output_mat_no_fp8 = grad_output_mat
            if isinstance(grad_output_mat, Float8Tensor):
                grad_output_mat_no_fp8 = grad_output_mat.from_float8(grad_output_mat.dtype)
            grad_bias, grad_output_c, grad_output_t = fp8_cast_transpose_bgrad_fused(
                grad_output_mat_no_fp8,
                ctx.fp8_meta["scaling_bwd"],
                tex.FP8BwdTensors.GRAD_OUTPUT1,
                fp8_dtype_backward,
            )
        else:
            if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
                # FP8 wgrad needs both the FP8 data and its transpose.
                if isinstance(grad_output_mat, Float8Tensor):
                    grad_output_c = grad_output_mat
                    grad_output_t = grad_output_c.transpose_2d()
                else:
                    grad_output_c, grad_output_t = fp8_cast_transpose_fused(
                        grad_output_mat,
                        ctx.fp8_meta["scaling_bwd"],
                        tex.FP8BwdTensors.GRAD_OUTPUT1,
                        fp8_dtype_backward,
                    )
            else:
                # wgrad runs in high precision, so only the FP8 cast (for dgrad) is made.
                grad_output_t = None
                if not isinstance(grad_output_mat, Float8Tensor):
                    grad_output_c = cast_to_fp8(
                        grad_output_mat,
                        ctx.fp8_meta["scaling_bwd"],
                        tex.FP8BwdTensors.GRAD_OUTPUT1,
                        fp8_dtype_backward,
                    )
                else:
                    grad_output_c = grad_output_mat
            grad_bias = None

        return grad_output_mat, grad_output_c, grad_output_t, grad_bias

904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
    def register_parameter(self, name, param, **kwargs):
        """
        Register `param` under `name` with PyTorch, then record extra metadata.

        Any extra keyword arguments are packed into a `_ParameterInitMeta` and
        stored at `self.param_init_meta[name]`. `reset_parameters` later reads
        that record (init function, RNG tracker, FP8 meta index) for
        (possibly deferred) initialization.
        """
        super().register_parameter(name, param)
        init_meta = _ParameterInitMeta(**kwargs)
        self.param_init_meta[name] = init_meta

    def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
        """
        Reset all module parameters to initial values. Unless deferred initialization
        is specified, all parameters on a 'meta' device are also materialized on a real cuda
        device before the values are reset to initial.

        Parameters
        ----------
        defer_init : bool, default = `False`
            If true, return immediately without touching any parameter.

        Notes
        -----
        Each parameter is re-registered as a fresh `torch.nn.Parameter`. Its
        original `requires_grad` flag is preserved, so frozen parameters stay frozen.
        """
        if defer_init:
            return

        for name, param in self.named_parameters(recurse=False):
            # Capture trainability first: `torch.empty_like` below returns a plain
            # tensor, and `torch.nn.Parameter` would otherwise default to True.
            requires_grad = param.requires_grad

            # Ensure parameter is on a real device
            if param.device == torch.device("meta"):
                param = torch.empty_like(param, device="cuda")

            # Initialize the parameter values on device
            init_meta = self.param_init_meta[name]
            init_fn = init_meta.init_fn
            get_rng_state_tracker = init_meta.get_rng_state_tracker
            if get_rng_state_tracker is None:
                init_fn(param)
            else:
                rng_tracker_name = getattr(self, "rng_tracker_name", None)
                if rng_tracker_name:
                    with get_rng_state_tracker().fork(rng_tracker_name):
                        init_fn(param)
                else:
                    with get_rng_state_tracker().fork():
                        init_fn(param)

            # If primary weights are in fp8, wrap the parameter as Float8Tensor
            fp8_meta_index = init_meta.fp8_meta_index
            if self.primary_weights_in_fp8 and fp8_meta_index is not None:
                dummy_amax = torch.empty(
                    (1, 1),
                    dtype=torch.float32,
                    device=param.device,
                )  # Dummy buffer to avoid overwriting amax history
                param = Float8Tensor.to_float8(
                    param,
                    fp8_meta=self.fp8_meta,
                    fp8_meta_index=fp8_meta_index,
                    amax=dummy_amax,
                    with_transpose_cache=torch.is_grad_enabled(),
                )

            # Re-register as an nn.Parameter. This is required after the meta-device
            # or Float8Tensor paths above. `torch.nn.Parameter(p)` builds a new
            # Parameter object, so the original requires_grad is passed explicitly.
            setattr(self, name, torch.nn.Parameter(param, requires_grad=requires_grad))

961
962
963
    @abstractmethod
    def forward(self):
        """Abstract hook: concrete modules must implement their own forward pass."""
964

965
    def get_fp8_workspace(
        self,
        *,
        tensor: Optional[torch.Tensor] = None,
        fp8_meta_forward: Optional[bool] = None,
        fp8_meta_index: Optional[int] = None,
        cache_name: Optional[str] = None,
        update_workspace: bool = True,
        skip_update_flag: Optional[torch.Tensor] = None,
        fsdp_group: Optional[dist_group_type] = None,
    ) -> Float8Tensor:
        """Get FP8 workspace buffer and maybe update its values

        The workspace buffer may be cached for future function calls.

        Parameters
        ----------
        tensor : torch.Tensor, optional
            Values to copy into workspace. Required if the workspace
            is being constructed or updated.
        fp8_meta_forward: bool, optional
            Whether to access FP8 meta tensors for the forward pass or
            backward pass. Required if the workspace is being
            constructed.
        fp8_meta_index: int, optional
            Index to access in FP8 meta tensors. Required if the
            workspace is being constructed.
        cache_name: str, optional
            Key for caching.
        update_workspace: bool, default = `True`
            Update workspace with values from `tensor`.
        skip_update_flag: torch.Tensor, optional
            GPU flag to skip updating the workspace. Take precedence
            over `update_workspace` if provided. Not honoured during
            ONNX export.
        fsdp_group: ProcessGroup, default = `None`
            FSDP process group that the weights are distributed over.
            If given, and a cached workspace's data shape differs from
            `tensor`'s shape, the cached workspace is passed to
            `_fsdp_gather_tensors` with this group before reuse.

        Returns
        -------
        Float8Tensor
            The workspace, either freshly constructed or taken from the
            cache. A newly constructed workspace is always quantized from
            `tensor`.

        Raises
        ------
        ValueError
            If a workspace must be constructed but `tensor`,
            `fp8_meta_forward` or `fp8_meta_index` is missing, or if an
            update is requested without `tensor`.
        """

        # Try getting workspace from cache
        out = None
        if cache_name is not None:
            out = self._fp8_workspaces.get(cache_name, None)

        # Gather cached Fp8 workspace if it's distributed
        # NOTE: FSDP sharding is supported only for Fp8 buffers and will not work
        #       for models initialized with Fp8 primary weights.
        # NOTE: Workspaces cached by this method are always Float8Tensor, so a
        #       `not isinstance(out, Float8Tensor)` guard would make this branch
        #       unreachable. `tensor` must be present because its shape is the target.
        if (
            out is not None
            and tensor is not None
            and fsdp_group is not None
            and out._data.shape != tensor.data.shape
        ):
            _fsdp_gather_tensors(fsdp_group, [tensor.data.shape], out)

        # Construct workspace if needed
        if out is None:

            # FP8 data
            if tensor is None or fp8_meta_forward is None or fp8_meta_index is None:
                raise ValueError(
                    "tensor, fp8_meta_forward, and fp8_meta_index kwargs "
                    "must be provided to construct FP8 workspace"
                )
            fp8_dtype = get_fp8_te_dtype(
                self.fp8_meta["recipe"],
                fprop_tensor=fp8_meta_forward,
            )
            data = torch.empty_like(tensor, dtype=torch.uint8)
            scale_inv = torch.empty([1], dtype=torch.float32, device=tensor.device)

            # Transpose cache: allocated when autograd is enabled, or during the
            # first (non-recompute) forward phase of FP8 activation recompute.
            with_transpose_cache = torch.is_grad_enabled()
            if (
                not with_transpose_cache
                and is_fp8_activation_recompute_enabled()
                and not in_fp8_activation_recompute_phase()
            ):
                with_transpose_cache = True
            data_transpose = None
            if with_transpose_cache:
                data_transpose = torch.empty(
                    (tensor.size(-1), tensor.numel() // tensor.size(-1)),
                    dtype=torch.uint8,
                    device=tensor.device,
                )

            # Construct FP8 tensor
            out = Float8Tensor(
                data=data,
                fp8_meta=self.fp8_meta,
                fp8_meta_forward=fp8_meta_forward,
                fp8_meta_index=fp8_meta_index,
                fp8_dtype=fp8_dtype,
                fp8_scale_inv=scale_inv,
                dtype=tensor.dtype,
                data_transpose=data_transpose,
            )

            # Update cache
            if cache_name is not None:
                self._fp8_workspaces[cache_name] = out
            # The buffers were allocated with torch.empty*, so they must be filled now.
            update_workspace = True
            skip_update_flag = None

        # Update workspace if needed
        if skip_update_flag is not None:
            update_workspace = True
        if update_workspace:
            if tensor is None:
                raise ValueError("tensor kwarg must be provided to update FP8 workspace")
            if is_in_onnx_export_mode():
                # ONNX export does not support fused cast-transpose
                # kernel and requires that FP8 scales can be
                # represented with constant ops.
                transpose_cache = out._transpose
                out._transpose = None
                out.quantize_(tensor)
                out._scale_inv.fill_(out._scale_inv.item())
                out._transpose = transpose_cache
            else:
                out.quantize_(tensor, noop_flag=skip_update_flag)

        return out
1088

1089
1090
1091
    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        """
        Load tensors plus extra state (including FP8 metadata) from `state_dict`.

        With `self.primary_weights_in_fp8` set, this module's extra-state entry is
        applied through `set_extra_state` before PyTorch's default loader runs. The
        key is `prefix` followed by PyTorch's extra-state suffix. Copying into FP8
        tensors uses the `scale_inv` held in `fp8_meta`, so that metadata has to be
        in place first, not afterwards as the default loader would do it. Without
        FP8 primary weights this simply delegates to the parent implementation.
        """
        if self.primary_weights_in_fp8:
            extra_key = f"{prefix}{torch.nn.modules.module._EXTRA_STATE_KEY_SUFFIX}"
            if extra_key in state_dict:
                self.set_extra_state(state_dict[extra_key])
        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )