# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Base modules and utilities for TransformerEngine PyTorch API"""
import io
import os
import pickle
import warnings
from abc import ABC, abstractmethod
from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union
from contextlib import contextmanager
from types import MethodType

import torch
import torch.nn.functional as F

import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe

from ._common import _ParameterInitMeta
from ..fp8 import (
    MXFP8BlockScalingRecipeState,
    DelayedScalingRecipeState,
    Float8CurrentScalingRecipeState,
    FP8GlobalStateManager,
    RecipeState,
)
from ..distributed import (
    gather_along_first_dim,
    is_fp8_activation_recompute_enabled,
    in_fp8_activation_recompute_phase,
    _fsdp_gather_tensors,
)
from ..constants import dist_group_type
from ..tensor import QuantizedTensor, Quantizer
from ..tensor._internal.float8_tensor_base import Float8TensorBase
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from torch.utils.cpp_extension import IS_HIP_EXTENSION

__all__ = ["initialize_ub", "destroy_ub"]

_2X_ACC_FPROP = False
_2X_ACC_DGRAD = True
_2X_ACC_WGRAD = True
_multi_stream_cublas_workspace = []
_multi_stream_cublas_batchgemm_workspace = []
_cublas_workspace = None
_ub_communicators = None
_NUM_MAX_UB_STREAMS = 2 if IS_HIP_EXTENSION else 3
_MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = None, None
layers_atomic_ring_exchange = []


def get_cublas_workspace_size_bytes() -> int:
    """Return the BLAS workspace size: 128 MiB on ROCm (hipBLASLt), 32 MiB on
    Hopper and newer, and 4 MiB for all other architectures."""
    # TODO: add an environment variable to control the workspace padding for blaslt.
    if IS_HIP_EXTENSION:
        return 134_217_728
    if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9:
        return 33_554_432
    return 4_194_304


def get_workspace() -> torch.Tensor:
    """Returns workspace for cublas."""
    global _cublas_workspace
    if _cublas_workspace is None:
        _cublas_workspace = torch.empty(
            get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda"
        )
    return _cublas_workspace
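
# A minimal usage sketch (illustrative; assumes a CUDA/ROCm device is available
# and that initialize_ub() has not yet enlarged the buffer):
#
#     ws = get_workspace()
#     assert ws.dtype == torch.uint8
#     assert ws.numel() == get_cublas_workspace_size_bytes()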


def get_multi_stream_cublas_workspace() -> List[torch.Tensor]:
    """Returns workspace for multi-stream cublas."""
    global _multi_stream_cublas_workspace
    if not _multi_stream_cublas_workspace:
        for _ in range(tex._num_cublas_streams):
            _multi_stream_cublas_workspace.append(
                torch.empty(get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda")
            )
    return _multi_stream_cublas_workspace

def get_multi_stream_cublas_batchgemm_workspace() -> List[torch.Tensor]:
    """Returns workspace for multi-stream cublas batched GEMM."""
    global _multi_stream_cublas_batchgemm_workspace
    if not _multi_stream_cublas_batchgemm_workspace:
        for _ in range(tex._num_cublas_batchgemm_streams):
            _multi_stream_cublas_batchgemm_workspace.append(
                torch.empty(128, dtype=torch.uint8, device="cuda")
            )
    return _multi_stream_cublas_batchgemm_workspace

if bool(int(os.getenv("NVTE_DISABLE_FC2_DGRAD_OVERLAP", "0"))):
    remove_ag_gemm_dgrad = ["fc2_dgrad"]
else:
    remove_ag_gemm_dgrad = []

def initialize_ub(
    shape: list,
    tp_size: int,
    use_fp8: bool = False,
    dtype: torch.dtype = torch.bfloat16,
    ub_cfgs: Optional[dict] = None,
    bootstrap_backend: Optional[Union[str, torch.distributed.Backend]] = None,
) -> None:
    r"""
    Initialize the Userbuffers communicator for overlapping tensor-parallel communications with
    GEMM compute in te.Linear, te.LayerNormLinear and te.LayerNormMLP modules.

    Parameters
    ----------
    shape : list
            shape of the communication buffer, typically set to be the same as the global shape of
            the input tensor to a te.TransformerLayer forward pass, with the sequence and batch
            dimensions collapsed together -- i.e.: `(sequence_length * batch_size, hidden_size)`
    tp_size : int
              number of GPUs in the tensor-parallel process group
    use_fp8 : bool = False
              allocate the communication buffer for FP8 GEMM inputs/outputs
    dtype : torch.dtype = torch.bfloat16
            non-FP8 data type of the communication buffer when `use_fp8 = False`
    ub_cfgs: dict = None
             Configuration dictionary with the structure
             ```
             {
                <gemm_name> : {
                    "method": <"ring_exchange" or "pipeline">,
                    "is_reduce_scatter": bool,
                    "num_sm": int,
                    "cga_size": int,
                    "set_sm_margin": bool,
                    "num_splits": int,
                    "aggregate": bool,
                    "atomic_gemm": bool,
                    "use_ce": bool,
                    "fp8_buf": bool,
                }
             }
             ```
             for `te.TransformerLayer` GEMM layers in `["qkv_fprop", "qkv_dgrad", "qkv_wgrad",
             "proj_fprop", "proj_dgrad", "fc1_fprop", "fc1_dgrad", "fc1_wgrad", "fc2_fprop",
             "fc2_dgrad"]`.
    bootstrap_backend : str = None
                        `torch.distributed` communication backend for the all-gather, broadcast and
                        barrier collectives during Userbuffers initialization. Not all backends are
                        valid for every cluster configuration and distributed launch method even if
                        they are available in PyTorch. When left unset, the initialization prefers
                        to use the MPI backend, falling back first on Gloo and then NCCL if MPI is
                        not available. Setting `NVTE_UB_WITH_MPI=1` when building TE overrides this
                        option and always initializes Userbuffers with direct MPI calls in C++,
                        which also requires `MPI_HOME=/path/to/mpi/root` to be set at compile time.
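
    A minimal usage sketch (illustrative only; the process-group setup and the
    `seq_len` / `batch_size` / `hidden_size` values are assumptions, not defaults):
    ```
    torch.distributed.init_process_group(backend="nccl")
    initialize_ub(
        shape=[seq_len * batch_size, hidden_size],
        tp_size=8,
        use_fp8=True,
        dtype=torch.bfloat16,
    )
    ```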
    """
    if not tex.device_supports_multicast():
        assert bool(int(os.getenv("UB_SKIPMC", "0"))), (
            "CUDA device, driver and/or toolkit version does not support comm+GEMM overlap with "
            + "CUDA Multicast. Launch app with UB_SKIPMC=1 to try CUDA IPC instead."
        )

    global _ub_communicators
    assert _ub_communicators is None, "UB communicators are already initialized."
    _ub_communicators = {}

    if tex.ubuf_built_with_mpi():
        # We're bootstrapping with direct calls to MPI in Userbuffers code so we need to force
        # an MPI_Init() here by creating a new MPI process group...
        assert torch.distributed.is_mpi_available()
        _ = torch.distributed.new_group(backend="mpi")
        helper = tex.CommOverlapHelper()
    else:
        # Bootstrapping with torch.distributed API, so check backend and construct
        # intra/inter-node process groups...
        assert (
            torch.distributed.is_initialized()
        ), "torch.distributed must be initialized before Userbuffers"
        if bootstrap_backend is None:
            bootstrap_backend = "nccl"
            if torch.distributed.is_mpi_available():
                bootstrap_backend = "mpi"
            elif torch.distributed.is_gloo_available():
                bootstrap_backend = "gloo"
        else:
            assert bootstrap_backend in [
                "gloo",
                "mpi",
                "nccl",
            ], "Invalid torch.distributed backend for bootstrapping Userbuffers!"
            assert torch.distributed.is_backend_available(bootstrap_backend), (
                f"PyTorch must be compiled with '{bootstrap_backend}' support in order to "
                f"bootstrap Userbuffers with '{bootstrap_backend}' collectives."
            )

        world_group = torch.distributed.new_group(backend=bootstrap_backend)
        world_rank = torch.distributed.get_rank(world_group)
        world_size = torch.distributed.get_world_size(world_group)

        num_domains = world_size // tp_size
        mydomain_idx = world_rank // tp_size
        if num_domains > 1:
            ranks_per_domain_list = [
                [i * tp_size + t for t in range(tp_size)] for i in range(num_domains)
            ]
            tp_domain_group, _ = torch.distributed.new_subgroups_by_enumeration(
                ranks_per_domain_list, backend=bootstrap_backend
            )
            local_rank = torch.distributed.get_rank(tp_domain_group)
            tp_domain_ranks = torch.distributed.get_process_group_ranks(tp_domain_group)

            helper = tex.CommOverlapHelper(world_group, tp_domain_group)
        else:
            # TP model on single NVLink domain, no replication, no data-parallelism
            mydomain_idx = 0
            local_rank = world_rank
            tp_domain_ranks = list(range(world_size))

            helper = tex.CommOverlapHelper(world_group)

        if world_rank == 0:
            print(f"!!! [UB] Number of TP domains: {num_domains}\n", end="", flush=True)
        if local_rank == 0:
            print(
                f"!!! [UB] Global ranks on TP domain {mydomain_idx}: {tp_domain_ranks}\n",
                end="",
                flush=True,
            )

    # Increase the workspace by the number of maximum concurrent streams
    global _cublas_workspace
    _cublas_workspace = get_workspace().repeat(_NUM_MAX_UB_STREAMS)

    # Default buffer precision: AllGather buffers use fp8 when using fp8 recipe
    layers_all_gather_overlap = [
        "qkv_fprop",
        "qkv_dgrad",
        "proj_dgrad",
        "fc1_fprop",
        "fc1_dgrad",
        "fc2_dgrad",
    ]
    layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"]
    dgrad_reduce_scatter_overlap = ["qkv_dgrad", "fc1_dgrad"]
    # Default overlap methods for layers
    if bool(int(os.getenv("NVTE_NO_PIPELINE_OVERLAP", "0"))):
        methods = {
            "ring_exchange": [
                "qkv_fprop",
                "fc1_fprop",
                "proj_dgrad",
                "fc2_dgrad",
                "proj_fprop",
                "fc2_fprop",
            ],
            "pipeline": [],
            "bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"],
        }
    else:
        methods = {
            "ring_exchange": ["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad"],
            "pipeline": ["proj_fprop", "fc2_fprop"],
            "bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"],
        }

    # AG-RS overlap pairs of layers forming a tensor-parallel block
    ag_rs_pairs = {"qkv_fprop": "proj_fprop", "fc1_fprop": "fc2_fprop"}
    rs_ag_pairs = {v: k for k, v in ag_rs_pairs.items()}
    global layers_atomic_ring_exchange
    layers_atomic_ring_exchange = []

    def get_method(name):
        for method, names in methods.items():
            if name in names:
                return method
        raise KeyError(f"Given layer name {name} does not exist.")

    def get_default_config(name):
        global _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY
        method = get_method(name)
        is_reduce_scatter = name in layers_reduce_scatter_overlap
        if _MIN_STREAM_PRIORITY is None or _MAX_STREAM_PRIORITY is None:
            _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = tex.get_stream_priority_range()
        default_cfg = {
            "method": method,
            "is_reduce_scatter": is_reduce_scatter,
            "num_sm": 1 if method == "ring_exchange" else 8,
            "cga_size": 1 if method == "ring_exchange" else 2,
            "set_sm_margin": method != "ring_exchange",
            "num_splits": tp_size if method == "ring_exchange" else 4,
            "aggregate": False,
            "atomic_gemm": False,
            "use_ce": True,
            "fp8_buf": name in layers_all_gather_overlap,
            "comm_priority": _MAX_STREAM_PRIORITY,
            "gemm_priority": _MIN_STREAM_PRIORITY,
            "pipeline_rs_overlap_first_gemm": False,
        }
        return default_cfg

    def add_ub(
        name: str,
        method: str,
        is_reduce_scatter: bool,
        num_sm: int = 16,
        cga_size: int = 2,
        set_sm_margin: bool = False,
        num_splits: int = 0,
        aggregate: bool = False,
        atomic_gemm: bool = False,
        use_ce: bool = True,
        fp8_buf: bool = False,
        comm_priority: int = 0,
        gemm_priority: int = 0,
        pipeline_rs_overlap_first_gemm: bool = False,
    ) -> None:
        if atomic_gemm:
            warnings.warn(
                "Atomic GEMM uses a beta API from cublas and is not tested for all use cases."
            )
            assert use_fp8, "Atomic GEMM overlap supported only for FP8 GEMM."
            if method == "bulk":
                warnings.warn(
                    f"At {name}, atomic GEMM is not supported for a bulk overlap. "
                    "Defaulting to `atomic_gemm=False`."
                )
                atomic_gemm = False
        if not is_reduce_scatter and method == "pipeline":
            raise ValueError(
                f"At {name}, `pipeline` overlap method is not supported for AllGather."
            )
        # Check if both AG and RS overlaps use `atomic GEMM` + `p2p ring-exchange`.
        # Using atomic GEMM + p2p ring-exchange in only one of the pair breaks functionality.
        global layers_atomic_ring_exchange
        if atomic_gemm and method == "ring_exchange" and name in ag_rs_pairs:
            layers_atomic_ring_exchange += [name, ag_rs_pairs[name]]
        if name in rs_ag_pairs:
            assert_message = (
                f"At {name}, atomic AG-GEMM overlap with `ring_exchange` shuffles GEMM chunk "
                "outputs, and RS-GEMM overlap un-shuffles them. When one of the GEMM-AG and "
                "GEMM-RS overlaps forming a TP block (e.g., qkv_fprop and proj_fprop) uses "
                "`atomic gemm` and `ring_exchange`, its pair must use the same overlap config "
                "for functionality."
            )
            if name in layers_atomic_ring_exchange:
                assert atomic_gemm and method == "ring_exchange", assert_message
            else:
                if atomic_gemm and method == "ring_exchange":
                    assert rs_ag_pairs[name] in layers_atomic_ring_exchange, assert_message

        buffer_dtype = torch.uint8 if (use_fp8 and fp8_buf) else dtype
        if method == "ring_exchange":
            ub_obj = tex.CommOverlapP2P(
                shape,  # Communication buffer shape
                buffer_dtype,  # Communication buffer data type
                helper,  # Helper for torch.distributed callbacks during bootstrapping
                tp_size,  # Tensor-parallel group size (may be different than local_size)
                tex.CommOverlapType.RS if is_reduce_scatter else tex.CommOverlapType.AG,
                num_max_streams=_NUM_MAX_UB_STREAMS,
                comm_cga_size=cga_size,
                num_comm_sm=num_sm,
                set_sm_margin=set_sm_margin,
                atomic_gemm=atomic_gemm,
                use_ce=use_ce,
                aggregate=aggregate,
                gemm_priority=gemm_priority,
                comm_priority=comm_priority,
            )
        else:
            ub_obj = tex.CommOverlap(
                shape,  # Communication buffer shape
                buffer_dtype,  # Communication buffer data type
                helper,  # Helper for torch.distributed callbacks during bootstrapping
                tp_size,  # Tensor-parallel group size (may be different than local_size)
                num_splits=num_splits,
                num_max_streams=_NUM_MAX_UB_STREAMS,
                comm_cga_size=cga_size,
                num_comm_sm=num_sm,
                set_sm_margin=set_sm_margin,
                atomic_gemm=atomic_gemm,
                gemm_priority=gemm_priority,
                comm_priority=comm_priority,
                rs_overlap_first_gemm=pipeline_rs_overlap_first_gemm,
            )
        _ub_communicators[name] = ub_obj

    if ub_cfgs is not None:
        for name in dgrad_reduce_scatter_overlap:
            if name in ub_cfgs and "method" in ub_cfgs[name] and ub_cfgs[name]["method"] != "bulk":
                wgrad_name = name.replace("dgrad", "wgrad")
                assert wgrad_name not in ub_cfgs
                layers_reduce_scatter_overlap.remove(wgrad_name)
                layers_all_gather_overlap.remove(name)
                layers_reduce_scatter_overlap.append(name)
                methods["bulk"].remove(name)
                new_method = ub_cfgs[name]["method"]
                methods[new_method].append(name)

    for name in methods["ring_exchange"] + methods["pipeline"] + methods["bulk"]:
        if name in remove_ag_gemm_dgrad:
            continue
        ub_cfg = get_default_config(name)
        if ub_cfgs is not None and name in ub_cfgs:
            fp8_buf = (name in layers_all_gather_overlap) or (
                ub_cfgs[name].get("fp8_buf", False) and name in methods["pipeline"]
            )
            ub_cfg.update(ub_cfgs[name])
            ub_cfg["fp8_buf"] = fp8_buf
        add_ub(name, **ub_cfg)


def get_ub(name: str):
    """Get the Userbuffers communicator corresponding to the given key."""
    assert _ub_communicators is not None, "UB manager is not initialized."
    # assert name in _ub_communicators, f"UB for {name} is not registered."
    if name in remove_ag_gemm_dgrad:
        return None
    return _ub_communicators[name]


def destroy_ub():
    """Destroy all allocated userbuffer communicators."""
    global _ub_communicators
    _ub_communicators = None
    global layers_atomic_ring_exchange
    layers_atomic_ring_exchange = []
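
# Illustrative lifecycle sketch (assumes a distributed job; `seq_len`, `batch`,
# and `hidden` are placeholders, and tp_size=8 is an assumption):
#
#     initialize_ub([seq_len * batch, hidden], tp_size=8, use_fp8=True)
#     ub_obj = get_ub("qkv_fprop")  # communicator for a given GEMM layer
#     ...                           # run training with comm+GEMM overlap
#     destroy_ub()                  # release all communicators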


class TransformerEngineBaseModule(torch.nn.Module, ABC):
    """Base TE module."""

    def __init__(self) -> None:
        super().__init__()
        assert torch.cuda.is_available(), "TransformerEngine needs CUDA."
        self.fp8_initialized = False
        self.fp8 = False
        self.fp8_calibration = False
        self.fp8_meta = {}
        self.fp8_meta["fp8_checkpoint"] = False
        self.fp8_meta["fp8_group"] = None
        self.fp8_meta_tensors_initialized = False
        self.quantizers = {"scaling_fwd": {}, "scaling_bwd": {}}
        self.tp_group = None
        self.tp_size = 1
        self.sequence_parallel = False
        self.param_init_meta = {}
        self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
        self.preserve_high_precision_init_val = FP8GlobalStateManager.with_high_precision_init_val()
        self.fsdp_wrapped = False
        self.fsdp_group = None
        self._fp8_workspaces: Dict[str, QuantizedTensor] = {}
        self.activation_dtype: Optional[torch.dtype] = None

    # Names of attributes that can be set quickly (see __setattr__
    # method)
    _fast_setattr_names: Set[str] = {
        "activation_dtype",
        "fp8",
        "fp8_initialized",
        "fp8_calibration",
        "fp8_parameters",
    }

    def __setattr__(self, name: str, value: Any) -> None:
        if name in TransformerEngineBaseModule._fast_setattr_names:
            # torch.nn.Module has a custom __setattr__ that handles
            # modules, parameters, and buffers. This is unnecessary
            # overhead when setting plain attrs.
            self.__dict__[name] = value
        else:
            # Default case
            super().__setattr__(name, value)

    def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None:
        """
        Delayed scaling only.

        Increase or decrease size of amax history based on given `length`.

        .. warning::
            This changes the underlying amax memory location.
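
        Illustrative sketch (assumes the module has already initialized its
        FP8 state under a delayed-scaling recipe)::

            module.adjust_amax_history_length(1024)          # grow fwd and bwd histories
            module.adjust_amax_history_length(16, fwd=True)  # shrink fwd history only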
        """
        if fwd is None:
            fp8_meta_tensor_keys = ("scaling_fwd", "scaling_bwd")
        else:
            fp8_meta_tensor_keys = ("scaling_fwd" if fwd else "scaling_bwd",)

        for meta_key in fp8_meta_tensor_keys:
            if meta_key not in self.fp8_meta:
                # Handles non-parameter FP8 modules, e.g. DPA.
                continue
            curr_len = self.fp8_meta[meta_key].amax_history.shape[0]
            if length == curr_len:
                continue
            if length < curr_len:
                self.fp8_meta[meta_key].amax_history = (
                    self.fp8_meta[meta_key].amax_history[:length].clone()
                )
            elif length > curr_len:
                extra_rows = length - curr_len
                self.fp8_meta[meta_key].amax_history = F.pad(
                    self.fp8_meta[meta_key].amax_history, pad=(0, 0, 0, extra_rows)
                )

            # Update quantizers with new amax pointers.
            self.quantizers[meta_key] = self.fp8_meta[meta_key].make_quantizers()

            # Update the global buffers with new amax and history pointers.
            if FP8GlobalStateManager.get_buffer_info() in self.fp8_meta:
                fwd_pos, fwd_key, bwd_pos, bwd_key = self.fp8_meta[
                    FP8GlobalStateManager.get_buffer_info()
                ]
                for pos, buffer_key in zip((fwd_pos, bwd_pos), (fwd_key, bwd_key)):
                    if buffer_key in FP8GlobalStateManager.global_amax_buffer:
                        assert (
                            buffer_key in FP8GlobalStateManager.global_amax_history_buffer
                        ), "TE internal error during amax history change."
                        FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = self.fp8_meta[
                            meta_key
                        ].amax_history[0]
                        FP8GlobalStateManager.global_amax_history_buffer[buffer_key][pos] = (
                            self.fp8_meta[meta_key].amax_history
                        )

    def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
        """Init scales and amaxes for fwd | bwd."""
        fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd"

        # Return early if recipe state matches recipe
        if self.fp8_meta_tensors_initialized:
            recipe_state = self.fp8_meta[fp8_meta_tensor_key]
            if recipe.delayed() and isinstance(recipe_state, DelayedScalingRecipeState):
                self.adjust_amax_history_length(recipe.amax_history_len, fwd=fwd)
                return
            if recipe.mxfp8() and isinstance(recipe_state, MXFP8BlockScalingRecipeState):
                return
            if recipe.float8_current_scaling() and isinstance(
                recipe_state, Float8CurrentScalingRecipeState
            ):
                return

        # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and
        # 2 (grad_output and grad_input) for bwd
        num_fp8_tensors = self.fp8_meta["num_gemms"] * 3 if fwd else self.fp8_meta["num_gemms"] * 2

        # Initialize recipe state and quantizers
        recipe_state = RecipeState.create(
            recipe,
            mode=("forward" if fwd else "backward"),
            num_quantizers=num_fp8_tensors,
        )

        self.fp8_meta[fp8_meta_tensor_key] = recipe_state
        self.quantizers[fp8_meta_tensor_key] = recipe_state.make_quantizers()

    def init_fp8_meta_tensors(self, recipe: Recipe) -> None:
        """Init scales and amaxes."""
        self.set_meta_tensor(True, recipe)
        self.set_meta_tensor(False, recipe)

        self.fp8_meta_tensors_initialized = True

    def get_fp8_meta_tensors(self) -> Optional[Dict[str, List[torch.Tensor]]]:
        """Get scales and amaxes."""
        fwd_key, bwd_key = "scaling_fwd", "scaling_bwd"
        if fwd_key not in self.fp8_meta or bwd_key not in self.fp8_meta:
            return None

        fp8_meta_tensors = {fwd_key: [], bwd_key: []}
        with torch.no_grad():
            for key in (fwd_key, bwd_key):
                fp8_meta_tensors[key].append(self.fp8_meta[key].scale.clone())
                fp8_meta_tensors[key].append(self.fp8_meta[key].amax_history.clone())
        return fp8_meta_tensors
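
    # Illustrative sketch of snapshotting and restoring FP8 scaling state around
    # an experiment (delayed scaling; `module` is any initialized TE module):
    #
    #     saved = module.get_fp8_meta_tensors()
    #     ...  # forward/backward passes mutate scales and amax histories
    #     module.reset_fp8_meta_tensors(saved)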

    def reset_fp8_meta_tensors(self, fp8_meta_tensors=None) -> None:
        """Reset scales and amaxes."""

        def reset(key):
            if key in self.fp8_meta:
                if fp8_meta_tensors is None:
                    self.fp8_meta[key].scale.copy_(torch.ones_like(self.fp8_meta[key].scale))
                    self.fp8_meta[key].amax_history.copy_(
                        torch.zeros_like(self.fp8_meta[key].amax_history)
                    )
                else:
                    assert key in fp8_meta_tensors, "Cannot reset fp8 tensors."
                    self.fp8_meta[key].scale.copy_(fp8_meta_tensors[key][0])
                    self.fp8_meta[key].amax_history.copy_(fp8_meta_tensors[key][1])

        with torch.no_grad():
            reset("scaling_fwd")
            reset("scaling_bwd")

    def get_extra_state(self) -> torch.Tensor:
        """Save before checkpointing."""

        # This implementation is working around a few issues:
        #
        # (1) PyTorch's "extra state" infrastructure might be able to
        #     support any picklable type, but they make no guarantees.
        #     We have experienced problems (e.g. in ONNX export) with
        #     non-tensor extra state.
        # (2) PyTorch's checkpointing infrastructure does not remap
        #     devices for "extra state" like it does for "state dict".
        #     Thus, we want to avoid putting extra state on the GPU
        #     since it may be loaded on the wrong device.
        # (3) The extra state consists of many small tensors. If we
        #     want to copy them all to CPU, then we need to avoid the
        #     overhead of many GPU-CPU memory transfers.
        #
        # See: https://github.com/NVIDIA/TransformerEngine/pull/351
        # See: https://github.com/NVIDIA/TransformerEngine/pull/363

        def to_cpu(src: torch.Tensor) -> torch.Tensor:
            """Helper function to make CPU copy of tensor

            Memory transfer is asynchronous w.r.t. host, so GPU should
            be synchronized before using result.

            """
            dst = torch.empty_like(src, device="cpu")
            dst.copy_(src, non_blocking=True)
            return dst

        # Store FP8 state if needed
        state = None
        fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration
        if fp8_checkpoint:

            # Copy tensors to CPU and store
            state = {}
            state["recipe"] = self.fp8_meta["recipe"]
            if state["recipe"].delayed():
                state["scale_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].scale)
                state["amax_history_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].amax_history)
                state["scale_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].scale)
                state["amax_history_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].amax_history)

            # Store other picklable values
            extra = {}
            for k, v in self.fp8_meta.items():
                if k != "buffer_index_and_autocast_key" and isinstance(
                    v, (bool, int, float, str, tuple, list)
                ):
                    extra[k] = v
            state["extra_fp8_variables"] = extra

        # Serialize state into byte tensor
        torch.cuda.synchronize()
        state_serialized = bytearray(pickle.dumps(state))
        state_serialized = torch.frombuffer(state_serialized, dtype=torch.uint8)
        return state_serialized
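
    # Illustrative round-trip sketch for get_extra_state/set_extra_state
    # (normally driven implicitly by torch.nn.Module state_dict() /
    # load_state_dict(); `module` is an initialized TE module):
    #
    #     blob = module.get_extra_state()   # uint8 tensor with pickled FP8 state
    #     module.set_extra_state(blob)      # restores scales and amax histories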

    def set_extra_state(self, state: torch.Tensor) -> None:
        """Load previous state."""
        if state is None:
            return

        # Load state
        if isinstance(state, torch.Tensor):
            # Default format: byte tensor with pickled data
            state = pickle.loads(state.detach().cpu().numpy().tobytes())
        elif isinstance(state, io.BytesIO):
            # Deprecated format with io.BytesIO
            state.seek(0)
            state = torch.load(state, map_location="cuda")
        else:
            raise RuntimeError("Unsupported checkpoint format.")

        if state is None:
            return

        # Load extra items
        self.fp8_meta.update(state["extra_fp8_variables"])
        self.fp8_meta["recipe"] = state["recipe"]
        if "global_fp8_buffer_pos_fwd_recompute" in self.fp8_meta:
            del self.fp8_meta["global_fp8_buffer_pos_fwd_recompute"]

        # Initialize before loading
        self.init_fp8_meta_tensors(self.fp8_meta["recipe"])

        def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None:
            """Helper function to copy tensor from CPU

            Memory transfer is asynchronous w.r.t. host, so GPU should
            be synchronized before using result.

            """
            dst.copy_(src, non_blocking=True)

        # Load tensors
        if self.fp8_meta["recipe"].delayed():
            copy_tensor(state["scale_fwd"], self.fp8_meta["scaling_fwd"].scale)
            copy_tensor(state["amax_history_fwd"], self.fp8_meta["scaling_fwd"].amax_history)
            copy_tensor(state["scale_bwd"], self.fp8_meta["scaling_bwd"].scale)
            copy_tensor(state["amax_history_bwd"], self.fp8_meta["scaling_bwd"].amax_history)
        torch.cuda.synchronize()

    def set_activation_dtype(self, inp: torch.Tensor) -> None:
        """Get activation data type for AMP."""
        # Native AMP (`torch.autocast`) gets highest priority
        if torch.is_autocast_enabled():
            self.activation_dtype = torch.get_autocast_gpu_dtype()
            return

        # All checks after this have already been performed once, thus skip
        if self.activation_dtype == inp.dtype:
            return

        dtype = inp.dtype
        for name, param in self.named_parameters():
            if param is not None:
                assert dtype == param.dtype, (
                    "Data types for parameters must match when outside of autocasted region. "
                    f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}"
                )
        self.activation_dtype = dtype

    def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
        """
        Set the tensor parallel group for the given
        module before executing the forward pass.

        Parameters
        ----------
        tp_group : ProcessGroup, default = `None`
                  tensor parallel process group.
        """
        self.tp_group = tp_group
        self.tp_group_initialized = True

    def _get_fp8_params(self) -> Union[List[torch.Tensor], None]:
        """Return the FP8 weights."""
        fp8_params = []
        for param in self.parameters(recurse=False):
            if isinstance(param, QuantizedTensor) and param.requires_grad:
                fp8_params.append(param)
        if len(fp8_params) == 0:
            return None
        return fp8_params

    # This routine is shared across FP8 and FP8_calibration paths so should not actually
    # assume FP8 execution.
    def init_fp8_metadata(self, num_gemms: int = 1) -> None:
        """Initialize fp8 related metadata and tensors during fprop."""
        self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
        self.fp8 = FP8GlobalStateManager.is_fp8_enabled()
        self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
        fp8_enabled = self.fp8 or self.fp8_calibration
        self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration

        if self.fp8_parameters or fp8_enabled:
            if (
                self.fp8_initialized
                and FP8GlobalStateManager.get_fp8_recipe() == self.fp8_meta["recipe"]
            ):
                # FP8 init has already been run and recipe is the same, don't do anything.
                return
            self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
        else:
            # If fp8 isn't enabled, turn off and return.
            self.fp8_initialized = False
            return

        if self.fp8_parameters and not self.fp8_initialized:
            self.fp8_meta["num_gemms"] = num_gemms
            self.init_fp8_meta_tensors(self.fp8_meta["recipe"])

        if fp8_enabled:
            # Set FP8 and other FP8 metadata
            self.fp8_meta["num_gemms"] = num_gemms
            self.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()

            # Set FP8_MAX per tensor according to recipe
            self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd
            self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd

            # Allocate scales and amaxes
            self.init_fp8_meta_tensors(self.fp8_meta["recipe"])
            self.fp8_initialized = True

            self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()

    @contextmanager
    def prepare_forward(
        self,
        inp: torch.Tensor,
        num_gemms: int = 1,
        allow_non_contiguous: bool = False,
    ) -> Generator[torch.Tensor, None, None]:
        """Checks and prep for FWD.
        The context manager is needed because there isn't a way for a module to know
        if it's the last FP8 module in the forward autocast. It is useful
        to setup the forward aggregated amax reduction for every module
        just in case. The autocast exit will pick up the most recent one.
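
        A hypothetical subclass forward pass would use it roughly as::

            def forward(self, inp):
                with self.prepare_forward(inp, num_gemms=1) as inp:
                    ...  # launch GEMMs on the contiguous `inp`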
        """
        # Activation recomputation is used and this is the second forward phase.
        if self.fp8 and in_fp8_activation_recompute_phase():
            FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta)
        else:
            assert inp.is_cuda, "TransformerEngine needs CUDA."

            if self.tp_size > 1:
                assert self.tp_group_initialized, "TP group not initialized."

            self.set_activation_dtype(inp)
            self.init_fp8_metadata(num_gemms=num_gemms)

            if self.fp8 and self.sequence_parallel and self.fp8_meta["recipe"].delayed():
                assert self.fp8_meta["recipe"].reduce_amax, (
                    "Amax reduction across tensor parallel group is "
                    "necessary when using sequence parallelism with FP8."
                )

            if self.fp8 and not FP8GlobalStateManager.fp8_graph_capturing():
                FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(self.fp8_meta)

            # Activation recomputation is used and this is the first forward phase.
            if self.fp8 and self.training and is_fp8_activation_recompute_enabled():
                FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)

        with torch.cuda.nvtx.range(self.__class__.__name__ + " forward"):
            if not allow_non_contiguous and not inp.is_contiguous():
                inp = inp.contiguous()
            yield inp

        if self.fp8 and in_fp8_activation_recompute_phase():
            FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)

    def set_nccl_overlap_warning_if_tp(self) -> None:
        """When using TP, the NCCL communication needs to be scheduled
        before the GEMM for there to be a guaranteed overlap. From the
        host side in TE, the comm calls are always launched first, but
        to ensure that the GEMM isn't scheduled first, the environment
        variable `CUDA_DEVICE_MAX_CONNECTIONS` needs to be set to 1 to
        force a single channel.
        """
        if self.tp_size == 1:
            return
        num_cuda_work_queues = int(os.getenv("CUDA_DEVICE_MAX_CONNECTIONS", "0"))
        if num_cuda_work_queues != 1:
            warnings.warn(
                "To guarantee overlapping TP and SP collectives with the backward"
                "GEMMs, set environment variable CUDA_DEVICE_MAX_CONNECTIONS = 1"
            )
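
    # A minimal sketch of satisfying the check above: export the variable before
    # CUDA is initialized (e.g. at the top of the training script or in the
    # launcher environment):
    #
    #     import os
    #     os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"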

    @staticmethod
    def grad_output_preprocess(
        ctx,
        grad_output: torch.Tensor,
        row_parallel_mode: bool,
        quantizer: Optional[Quantizer],
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Utility function for backward.
        Returns tuple in order (all optional/None based on training precision/recipe):
            R1: gathered `grad_output`.
            R2: bias gradient on R1.

        """
        grad_output = grad_output.reshape((-1, grad_output.shape[-1]))
        grad_output = grad_output.contiguous()
        gather_grad_output = row_parallel_mode and ctx.sequence_parallel

        # Non-FP8 case: bgrad is fused with wgrad for this case.
        if not ctx.fp8:
            if gather_grad_output:
                if not ctx.ub_overlap_ag or ctx.ub_obj_gradout is None:
                    grad_output, _ = gather_along_first_dim(grad_output, ctx.tp_group)
                else:
                    ctx.ub_obj_gradout.copy_into_buffer(grad_output, quantizer, local_chunk=True)
                    grad_output = ctx.ub_obj_gradout.get_buffer(quantizer)
            return grad_output, None

        # FP8 with all-gather: unfused bgrad, fused cast + transpose
        if gather_grad_output:
            grad_bias = None
            if ctx.use_bias:
                grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
            if ctx.ub_overlap_ag and ctx.ub_obj_gradout is not None:
                # Quantize the gradient if needed
                if not isinstance(
                    grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase)
                ):
                    grad_output = quantizer(grad_output)

                # Copy into communication buffer, and replace original gradient with it
                ctx.ub_obj_gradout.copy_into_buffer(grad_output, quantizer, local_chunk=True)
                grad_output = ctx.ub_obj_gradout.get_buffer(quantizer)
            else:
                grad_output, _ = gather_along_first_dim(
                    grad_output,
                    ctx.tp_group,
                    quantizer=quantizer,
                )
            return grad_output, grad_bias

        # FP8 without all-gather: fused bgrad + cast + transpose
        grad_bias = None
        if ctx.use_bias:
            if isinstance(grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase)):
                grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0)
            else:
                grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer)
        if not isinstance(grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase)):
            grad_output = quantizer(grad_output)
        return grad_output, grad_bias

    def register_parameter(self, name, param, **kwargs):
        """
        Thin wrapper around PyTorch parameter registration to stash additional parameter
        metadata used in deferred initialization.
        """
        super().register_parameter(name, param)
        self.param_init_meta[name] = _ParameterInitMeta(**kwargs)

    def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
        """
        Reset all module parameters to initial values. Unless deferred initialization
        is specified, all parameters on a 'meta' device are also materialized on a real cuda
        device before the values are reset to initial.
        """
        if defer_init:
            return

        for name, param in self.named_parameters(recurse=False):
            # Ensure parameter is on a real device
            if param.device == torch.device("meta"):
                param = torch.empty_like(param, device="cuda")

            # Initialize the parameter values on device
            init_fn = self.param_init_meta[name].init_fn
            get_rng_state_tracker = self.param_init_meta[name].get_rng_state_tracker
            if get_rng_state_tracker is None:
                init_fn(param)
            else:
                if hasattr(self, "rng_tracker_name") and self.rng_tracker_name:
                    with get_rng_state_tracker().fork(self.rng_tracker_name):
                        init_fn(param)
                else:
                    with get_rng_state_tracker().fork():
                        init_fn(param)

            # If primary weights are in fp8, wrap the parameter as FP8Tensor
            fp8_meta_index = self.param_init_meta[name].fp8_meta_index
            high_precision_init_val = None
            if self.primary_weights_in_fp8 and fp8_meta_index is not None:
                if self.preserve_high_precision_init_val:
                    high_precision_init_val = param.detach().cpu()

                quantizer = self.quantizers["scaling_fwd"][fp8_meta_index]
                assert (
                    quantizer is not None
                )  # to use primary fp8 weight one needs to use FP8 autocast with specific recipe.
                quantizer.internal = False
                param = quantizer(param)

            # Redo parameter wrap in case we broke it above
            # NOTE: Currently this can only be broken when primary weights are in Fp8 but
            #       re-applying the nn.Parameter() wrap is a no-op when the input is already
            #       a parameter so we always re-apply it just for extra safety.
            param = torch.nn.Parameter(param)
            if high_precision_init_val is not None:

                # - Master weights are initialized from model weights, if we use fp8 primary
                #   weights to initialize master weights, the numerical values of master weights
                #   are not consistent with the numerical values when we initialize them from
                #   bf16/fp16 weights.
                # - So we add a `_high_precision_init_val` attribute to each model weight to store
                #   the original bf16/fp16 weight on cpu before casting it to fp8. And users can
                #   use `get_high_precision_init_val` to get this cpu tensor.
                # - This cpu tensor is not needed once the master weight is initialized, so users
                #   should call `clear_high_precision_init_val` to remove it after master weight
                #   is initialized.

                def get(self):
                    if hasattr(self, "_high_precision_init_val"):
                        return self._high_precision_init_val
                    return None

                def clear(self):
                    if hasattr(self, "_high_precision_init_val"):
                        del self._high_precision_init_val

                param._high_precision_init_val = high_precision_init_val
                param.get_high_precision_init_val = MethodType(get, param)
                param.clear_high_precision_init_val = MethodType(clear, param)

            setattr(self, name, param)
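
    # Illustrative sketch (assumes the module was created under
    # fp8_model_init(..., preserve_high_precision_init_val=True)):
    #
    #     hp = module.weight.get_high_precision_init_val()  # CPU bf16/fp16 copy
    #     master_weight = hp.float() if hp is not None else module.weight.float()
    #     module.weight.clear_high_precision_init_val()     # free the CPU copy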

    @abstractmethod
    def forward(self):
        """Needs override."""

    def get_weight_workspace(
        self,
        *,
        tensor: Optional[torch.Tensor] = None,
        quantizer: Optional[Quantizer] = None,
        cache_name: Optional[str] = None,
        update_workspace: bool = True,
        skip_update_flag: Optional[torch.Tensor] = None,
        fsdp_group: Optional[dist_group_type] = None,
    ) -> QuantizedTensor:
        """Get FP8 workspace buffer and maybe update its values

        The workspace buffer may be cached for future function calls.

        Parameters
        ----------
        tensor : torch.Tensor, optional
            Values to copy into workspace. Required if the workspace
            is being constructed or updated.
        quantizer: Quantizer, optional
            Quantizer used to cast the weights. Required if the
            workspace is being constructed or updated.
        cache_name: str, optional
            Key for caching.
        update_workspace: bool, default = `True`
            Update workspace with values from `tensor`.
        skip_update_flag: torch.Tensor, optional
            GPU flag to skip updating the workspace. Takes precedence
            over `update_workspace` if provided.
        fsdp_group: ProcessGroup, default = None
            FSDP process group that the weights are distributed over.
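
        Illustrative call (hypothetical names: `weight` is a high-precision
        tensor and `quantizer` a configured Quantizer)::

            workspace = module.get_weight_workspace(
                tensor=weight,
                quantizer=quantizer,
                cache_name="weight",
            )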
        """

        # FP8 primary weights
        if isinstance(tensor, QuantizedTensor):
            if update_workspace and quantizer is not None:
                tensor.update_usage(
                    rowwise_usage=quantizer.rowwise_usage,
                    columnwise_usage=quantizer.columnwise_usage,
                )
            return tensor

        # Try getting workspace from cache
        out = None
        if cache_name is not None:
            out = self._fp8_workspaces.get(cache_name, None)
            if quantizer is not None and isinstance(out, MXFP8TensorBase):
                if quantizer.rowwise_usage and out._rowwise_data is None:
                    out = None
                    del self._fp8_workspaces[cache_name]
                elif quantizer.columnwise_usage and out._columnwise_data is None:
                    out = None
                    del self._fp8_workspaces[cache_name]

        # Gather cached Fp8 workspace if it's distributed
        # NOTE: FSDP sharding is supported only for Fp8 buffers and will not work
        #       for models initialized with Fp8 primary weights.
        if (
            out is not None
            and tensor is not None
            and fsdp_group is not None
            and out.data.shape != tensor.data.shape
        ):
            _fsdp_gather_tensors(fsdp_group, [tensor.data.shape], out)

        # Construct workspace if needed
        if out is None:
            if tensor is None or quantizer is None:
                raise ValueError(
                    "tensor and quantizer kwargs must be provided to construct FP8 workspace"
                )
            out = quantizer(tensor)

            # Update cache
            if cache_name is not None:
                self._fp8_workspaces[cache_name] = out
            return out

        # Update workspace if needed
        if skip_update_flag is not None:
            update_workspace = True
        if update_workspace:
            if tensor is None:
                raise ValueError("tensor kwarg must be provided to update FP8 workspace")
            if hasattr(out, "quantize_"):
                out.quantize_(tensor, noop_flag=skip_update_flag)
            else:
                tex.quantize(tensor, quantizer, out, skip_update_flag)

        return out

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        """
        This function loads tensors and extra state including fp8 metadata.
        This metadata is essential for copying fp8 tensors, as the copy_ function
        uses the scale_inv parameter from fp8_meta to set the correct scaling factor
        for the new tensor.
        Hence, this extra state must be loaded before the tensor copying process,
        not after, as is typically done in _load_from_state_dict.
        Tensors are copied into fp8 tensors only when self.primary_weights_in_fp8=True,
        otherwise, this behavior is not required.
        """
        if self.primary_weights_in_fp8:
            extra_state_key = prefix + torch.nn.modules.module._EXTRA_STATE_KEY_SUFFIX
            if extra_state_key in state_dict:
                self.set_extra_state(state_dict[extra_state_key])
        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )