base.py 37.9 KB
Newer Older
1
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
#
# See LICENSE for license information.

"""Base modules and utilities for TransformerEngine PyTorch API"""
6
import io
7
8
9
10
import os
import pickle
import warnings
from abc import ABC, abstractmethod
11
from typing import Dict, Generator, List, Optional, Tuple, Union
12
13
14
15
16
from contextlib import contextmanager

import torch
import torch.nn.functional as F

17
import transformer_engine_torch as tex
18
from ._common import _ParameterInitMeta
19
from ..export import is_in_onnx_export_mode
20
21
22
from ..fp8 import (
    get_default_fp8_recipe,
    get_fp8_te_dtype,
23
    FP8GlobalStateManager,
24
25
26
27
28
)
from ..distributed import (
    gather_along_first_dim,
    is_fp8_activation_recompute_enabled,
    in_fp8_activation_recompute_phase,
29
    _fsdp_gather_tensors,
30
31
32
33
34
35
36
)
from ..cpp_extensions import (
    fp8_cast_transpose_fused,
    fp8_cast_transpose_bgrad_fused,
    cast_to_fp8,
)
from ..constants import dist_group_type
37
from ..float8_tensor import Float8Tensor
38
39
40
41
42
43
44

# Per-pass flags for FP8 GEMMs; presumably select split ("2X") accumulation in the
# forward, dgrad and wgrad GEMMs (inferred from the names — not used in this chunk).
_2X_ACC_FPROP = False
_2X_ACC_DGRAD = True
_2X_ACC_WGRAD = True
# cuBLAS workspace tensor; lazily allocated by `get_workspace` and replicated by
# `initialize_ub` once per concurrent userbuffer stream.
_cublas_workspace = None
# Layer name -> userbuffer communicator; filled by `initialize_ub`, read by `get_ub`.
_ub_communicators = None
# Max concurrent GEMM streams for userbuffer overlap (passed to the communicators and
# used as the workspace replication factor in `initialize_ub`).
_NUM_MAX_UB_STREAMS = 3
# Layer names whose AG/RS pair uses atomic GEMM with `ring_exchange`; rebuilt by
# `initialize_ub`.
layers_atomic_ring_exchange = []


def get_cublas_workspace_size_bytes() -> int:
    """Return the cuBLAS workspace size in bytes for the current CUDA device.

    Devices with compute capability major version 9 or newer (Hopper and later) get
    32 MiB; all other architectures get 4 MiB.
    """
    hopper_or_newer_bytes = 32 * 1024 * 1024  # 33_554_432
    default_bytes = 4 * 1024 * 1024  # 4_194_304
    if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9:
        return hopper_or_newer_bytes
    return default_bytes


def get_workspace() -> torch.Tensor:
    """Return the module-wide cuBLAS workspace, allocating it on the first call.

    The buffer is a 1-D `torch.uint8` CUDA tensor sized by
    `get_cublas_workspace_size_bytes()` and is cached in `_cublas_workspace`.
    """
    global _cublas_workspace
    if _cublas_workspace is not None:
        return _cublas_workspace
    workspace_bytes = get_cublas_workspace_size_bytes()
    _cublas_workspace = torch.empty(workspace_bytes, dtype=torch.uint8, device="cuda")
    return _cublas_workspace


def initialize_ub(
    shape: list,
    tp_size: int,
    use_fp8: bool = False,
    dtype: torch.dtype = torch.bfloat16,
    ub_cfgs: Optional[dict] = None
) -> None:
    """Initialize communicators for TP comm overlap using userbuffers.

    One communicator is created for every layer name in the default overlap-method
    table below and stored so it can be fetched with `get_ub`.

    Parameters
    ----------
    shape : list
        Shape of the sample buffer allocated for each communicator.
    tp_size : int
        Tensor-parallel world size.
    use_fp8 : bool, default = `False`
        When `True`, buffers flagged as FP8 are allocated as `torch.uint8`. Atomic-GEMM
        overlap requires this to be `True`.
    dtype : torch.dtype, default = `torch.bfloat16`
        Buffer dtype for non-FP8 buffers.
    ub_cfgs : dict, optional
        Per-layer overrides keyed by layer name. Recognized keys: "method", "num_sm",
        "cga_size", "num_splits", "set_sm_margin", "aggregate", "atomic_gemm", "fp8_buf".

    Raises
    ------
    AssertionError
        If the communicators are already initialized or an override combination is
        invalid.
    ValueError
        If the `pipeline` method is requested for an AllGather overlap.
    """
    global _ub_communicators, _cublas_workspace, layers_atomic_ring_exchange
    assert _ub_communicators is None, "UB communicators are already initialized."
    _ub_communicators = {}
    rank_id = torch.distributed.get_rank()

    # Increase the workspace by the number of maximum concurrent streams
    _cublas_workspace = get_workspace().repeat(_NUM_MAX_UB_STREAMS)

    # Default buffer precision: AllGather buffers use fp8 when using fp8 recipe
    layers_all_gather_overlap = [
        "qkv_fprop", "qkv_dgrad", "proj_dgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad"
    ]
    layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"]
    dgrad_reduce_scatter_overlap = ["qkv_dgrad", "fc1_dgrad"]
    # Default overlap methods for layers
    methods = {
        "ring_exchange": ["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad"],
        "pipeline": ["proj_fprop", "fc2_fprop"],
        "bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"],
    }

    # AG-RS overlap pairs of layers forming a tensor-parallel block
    ag_rs_pairs = {"qkv_fprop": "proj_fprop", "fc1_fprop": "fc2_fprop"}
    rs_ag_pairs = {v: k for k, v in ag_rs_pairs.items()}
    layers_atomic_ring_exchange = []

    def get_method(name):
        """Return the default overlap method for layer `name`."""
        for method, names in methods.items():
            if name in names:
                return method
        raise KeyError(f"Given layer name {name} does not exist.")

    def add_ub(
        name: str,
        method: str,
        is_reduce_scatter: int,
        num_sm: int = 16,
        cga_size: int = 2,
        set_sm_margin: int = 0,
        num_splits: int = 0,
        aggregate: int = 0,
        atomic_gemm: int = 0,
        fp8_buf: bool = False,
    ) -> None:
        """Validate one layer's overlap config and register its communicator."""
        if atomic_gemm:
            warnings.warn(
                "Atomic GEMM uses a beta API from cublas and is not tested for all use cases."
            )
            assert use_fp8, "Atomic GEMM overlap supported only for FP8 GEMM."
            if method == 'bulk':
                warnings.warn(
                    f"At {name}, atomic GEMM is not supported for a bulk overlap. "
                    "Defaulting to `atomic_gemm=False`."
                )
                atomic_gemm = 0
        if not is_reduce_scatter and method == 'pipeline':
            raise ValueError(
                f"At {name}, `pipeline` overlap method is not supported for AllGather."
            )
        # Check if both AG and RS overlaps use `atomic GEMM` + `p2p ring-exchange`.
        # Using atomic GEMM + p2p ring-exchange in only one of the pair breaks functionality.
        # AG layers are registered before their RS partners (see the loop order below),
        # so the AG side records the pair and the RS side checks it.
        global layers_atomic_ring_exchange
        if atomic_gemm and method == "ring_exchange" and name in ag_rs_pairs:
            layers_atomic_ring_exchange += [name, ag_rs_pairs[name]]
        if name in rs_ag_pairs:
            assert_message = (
                f"At {name}, atomic AG-GEMM overlap with `ring_exchange` shuffles GEMM chunk "
                "outputs, and RS-GEMM overlap un-shuffles them. When one of the GEMM-AG and "
                "GEMM-RS overlaps forming a TP block (e.g., qkv_fprop and proj_fprop) uses "
                "`atomic gemm` and `ring_exchange`, its pair must use the same overlap config "
                "for functionality."
            )
            if name in layers_atomic_ring_exchange:
                assert atomic_gemm and method == "ring_exchange", assert_message
            else:
                if atomic_gemm and method == "ring_exchange":
                    assert rs_ag_pairs[name] in layers_atomic_ring_exchange, assert_message

        sample_buffer = torch.empty(
            shape,
            dtype=torch.uint8 if (use_fp8 and fp8_buf) else dtype,
            device='cuda')
        if method == 'ring_exchange':
            ub_obj = tex.UbufP2PCommOverlap(
                    sample_buffer,          # Sample userbuffer
                    rank_id,                # Rank id
                    tp_size,                # TP size
                    num_sm,                 # Number of communication SMs
                    cga_size,               # CGA cluster size
                    set_sm_margin,          # Set SM margin
                    aggregate,              # Aggregate 2X GEMM chunks
                    _NUM_MAX_UB_STREAMS,    # Max concurrent GEMM streams
                    is_reduce_scatter,      # overlap with reduce scatter
                    atomic_gemm,            # use a single GEMM with atomic-counters
                    torch.Tensor(),         # empty tensor to pass to counters
                )
        else:
            ub_obj = tex.UbufCommOverlap(
                    sample_buffer,          # Sample userbuffer
                    rank_id,                # Rank id
                    tp_size,                # TP size
                    num_sm,                 # Number of communication SMs
                    cga_size,               # CGA cluster size
                    num_splits,             # Number of communication splits
                    set_sm_margin,          # Set SM margin
                    _NUM_MAX_UB_STREAMS,    # Max concurrent GEMM streams
                    atomic_gemm,            # use a single GEMM with atomic-counters
                    torch.Tensor(),         # empty tensor to pass to counters
                )
        _ub_communicators[name] = ub_obj

    if ub_cfgs is not None:
        # A non-bulk dgrad overlap takes over the ReduceScatter role from the matching
        # wgrad overlap, which must then not be configured explicitly.
        for name in dgrad_reduce_scatter_overlap:
            if name in ub_cfgs and 'method' in ub_cfgs[name] and ub_cfgs[name]['method'] != 'bulk':
                wgrad_name = name.replace('dgrad', 'wgrad')
                assert wgrad_name not in ub_cfgs, (
                    f"At {name}, a non-bulk dgrad overlap performs the ReduceScatter, so "
                    f"{wgrad_name} must not be configured in `ub_cfgs`."
                )
                layers_reduce_scatter_overlap.remove(wgrad_name)
                layers_reduce_scatter_overlap.append(name)

    for name in methods["ring_exchange"] + methods["pipeline"] + methods["bulk"]:
        if ub_cfgs is not None and name in ub_cfgs:
            ub_cfg = ub_cfgs[name]
            method = ub_cfg.get("method", get_method(name))
            num_sm = ub_cfg.get("num_sm", 16)
            cga_size = ub_cfg.get("cga_size", 2)
            num_splits = ub_cfg.get("num_splits", 4 if method == "pipeline" else 0)
            set_sm_margin = ub_cfg.get("set_sm_margin", 0)
            aggregate = ub_cfg.get("aggregate", 0)
            atomic_gemm = ub_cfg.get("atomic_gemm", 0)
            is_reduce_scatter = 1 if name in layers_reduce_scatter_overlap else 0
            # Support FP8 userbuffer when (1) AllGather and (2) FP8-GEMM output ReduceScatter
            fp8_buf = ((name in layers_all_gather_overlap) or
                       (ub_cfg.get("fp8_buf", False) and name in methods["pipeline"]))
            add_ub(
                name,
                method,
                is_reduce_scatter,
                num_sm,
                cga_size,
                set_sm_margin,
                num_splits,
                aggregate,
                atomic_gemm,
                fp8_buf,
            )
        else:
            method = get_method(name)
            add_ub(
                name,
                method=method,
                is_reduce_scatter=1 if name in layers_reduce_scatter_overlap else 0,
                num_splits=4 if method == "pipeline" else 0,
                fp8_buf=name in layers_all_gather_overlap,
            )


def get_ub(name: str):
    """Look up the userbuffer communicator registered under `name` by `initialize_ub`.

    Raises AssertionError if `initialize_ub` has not run or `name` was not registered.
    """
    global _ub_communicators
    registry = _ub_communicators
    assert registry is not None, "UB manager is not initialized."
    assert name in registry, f"UB for {name} is not registered."
    return registry[name]


class TransformerEngineBaseModule(torch.nn.Module, ABC):
    """Base TE module."""

    def __init__(self) -> None:
        """Set up the default FP8, tensor-parallel and FSDP state shared by TE modules.

        Raises
        ------
        AssertionError
            If CUDA is not available.
        """
        super().__init__()
        assert torch.cuda.is_available(), "TransformerEngine needs CUDA."
        # FP8 state; `init_fp8_metadata` refreshes the flags on every forward.
        self.fp8_initialized = False
        self.fp8 = False
        self.fp8_calibration = False
        self.fp8_meta = {}
        self.fp8_meta["fp8_checkpoint"] = False
        self.fp8_meta["fp8_group"] = None
        self.fp8_meta["recipe"] = get_default_fp8_recipe()
        self.fp8_meta_tensors_initialized = False
        # Tensor-parallel state. `tp_group_initialized` is set by
        # `set_tensor_parallel_group` and checked in `prepare_forward` when tp_size > 1;
        # defaulting it here turns a missing call into a clear assertion instead of an
        # AttributeError.
        self.tp_group = None
        self.tp_group_initialized = False
        self.tp_size = 1
        self.sequence_parallel = False
        # Per-parameter metadata recorded by `register_parameter` for deferred init.
        self.param_init_meta = {}
        self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
        self.fsdp_wrapped = False
        self.fsdp_group = None
        self._fp8_workspaces: Dict[str, Float8Tensor] = {}
    def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None:
        """Increase or decrease size of amax history based on given `length`.

        Parameters
        ----------
        length : int
            Desired number of rows in the amax history.
        fwd : Optional[bool], default = `None`
            `True` adjusts only "scaling_fwd", `False` only "scaling_bwd",
            `None` adjusts both.

        .. warning::
            This changes the underlying amax memory location.
        """
        if fwd is None:
            fp8_meta_tensor_keys = ("scaling_fwd", "scaling_bwd")
        else:
            fp8_meta_tensor_keys = ("scaling_fwd" if fwd else "scaling_bwd",)

        for meta_key in fp8_meta_tensor_keys:
            if meta_key not in self.fp8_meta:
                # Handles non-parameter FP8 modules, e.g. DPA.
                continue
            curr_len = self.fp8_meta[meta_key].amax_history.shape[0]
            if length == curr_len:
                continue
            if length < curr_len:
                # Truncate to the first `length` rows.
                self.fp8_meta[meta_key].amax_history = (
                    self.fp8_meta[meta_key].amax_history[:length].clone())
            else:
                # Append zero rows at the end of the history.
                extra_rows = length - curr_len
                self.fp8_meta[meta_key].amax_history = F.pad(
                    self.fp8_meta[meta_key].amax_history, pad=(0, 0, 0, extra_rows)
                )

            # Update the global buffers with new amax and history pointers.
            buffer_info = FP8GlobalStateManager.get_buffer_info()
            if buffer_info in self.fp8_meta:
                fwd_pos, fwd_key, bwd_pos, bwd_key = self.fp8_meta[buffer_info]
                # Only the slot belonging to `meta_key` may point at this history: the fwd
                # and bwd histories differ in width (3 vs 2 tensors per GEMM), so writing
                # one into the other's slot would corrupt the global amax reduction.
                if meta_key == "scaling_fwd":
                    pos, buffer_key = fwd_pos, fwd_key
                else:
                    pos, buffer_key = bwd_pos, bwd_key
                if buffer_key in FP8GlobalStateManager.global_amax_buffer:
                    assert (
                        buffer_key in FP8GlobalStateManager.global_amax_history_buffer
                    ), "TE internal error during amax history change."
                    FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = (
                        self.fp8_meta[meta_key].amax_history[0])
                    FP8GlobalStateManager.global_amax_history_buffer[buffer_key][pos] = (
                        self.fp8_meta[meta_key].amax_history)
    def set_meta_tensor(self, fwd: bool) -> None:
        """Init scales and amaxes for fwd | bwd.

        If the meta tensors already exist, only the amax history length is synced with
        the current recipe. Otherwise a fresh `tex.FP8TensorMeta` is created under
        "scaling_fwd" (`fwd=True`) or "scaling_bwd" (`fwd=False`) with scales and
        inverse scales set to one and a zeroed
        `(recipe.amax_history_len, num_fp8_tensors)` amax history.
        """
        fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd"

        if self.fp8_meta_tensors_initialized:
            # Handle changed amax history size.
            self.adjust_amax_history_length(self.fp8_meta["recipe"].amax_history_len, fwd=fwd)
            return

        # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and
        # 2 (grad_output and grad_input) for bwd
        num_fp8_tensors = self.fp8_meta["num_gemms"] * (3 if fwd else 2)

        meta = self.fp8_meta[fp8_meta_tensor_key] = tex.FP8TensorMeta()
        meta.scale = torch.ones(num_fp8_tensors, dtype=torch.float32, device="cuda")
        meta.scale_inv = torch.ones(num_fp8_tensors, dtype=torch.float32, device="cuda")
        meta.amax_history = torch.zeros(
            self.fp8_meta["recipe"].amax_history_len,
            num_fp8_tensors,
            dtype=torch.float32,
            device="cuda",
        )
    def init_fp8_meta_tensors(self) -> None:
        """Init scales and amaxes for both the forward and backward passes.

        Calls `set_meta_tensor` for forward, then backward, and then marks the meta
        tensors as initialized.
        """
        self.set_meta_tensor(True)
        self.set_meta_tensor(False)
        self.fp8_meta_tensors_initialized = True
    def get_fp8_meta_tensors(self) -> Optional[Dict[str, List[torch.Tensor]]]:
        """Get cloned scales and amaxes.

        Returns
        -------
        dict or None
            `None` if either "scaling_fwd" or "scaling_bwd" is missing. Otherwise a dict
            mapping each of those keys to `[scale, scale_inv, amax_history]` clones,
            the layout expected by `reset_fp8_meta_tensors`.
        """
        fwd_key, bwd_key = "scaling_fwd", "scaling_bwd"
        if fwd_key not in self.fp8_meta or bwd_key not in self.fp8_meta:
            return None

        with torch.no_grad():
            return {
                key: [
                    self.fp8_meta[key].scale.clone(),
                    self.fp8_meta[key].scale_inv.clone(),
                    self.fp8_meta[key].amax_history.clone(),
                ]
                for key in (fwd_key, bwd_key)
            }
    def reset_fp8_meta_tensors(self, fp8_meta_tensors=None) -> None:
        """Reset FP8 scales, inverse scales and amax histories in place.

        With no argument, scales and inverse scales become ones and amax histories
        become zeros. Otherwise each present key is restored from `fp8_meta_tensors`,
        using the `[scale, scale_inv, amax_history]` per-key layout produced by
        `get_fp8_meta_tensors`. Keys absent from `self.fp8_meta` are skipped.
        """
        with torch.no_grad():
            for meta_key in ("scaling_fwd", "scaling_bwd"):
                if meta_key not in self.fp8_meta:
                    continue
                meta = self.fp8_meta[meta_key]
                if fp8_meta_tensors is None:
                    meta.scale.copy_(torch.ones_like(meta.scale))
                    meta.scale_inv.copy_(torch.ones_like(meta.scale_inv))
                    meta.amax_history.copy_(torch.zeros_like(meta.amax_history))
                    continue
                assert meta_key in fp8_meta_tensors, "Cannot reset fp8 tensors."
                saved = fp8_meta_tensors[meta_key]
                meta.scale.copy_(saved[0])
                meta.scale_inv.copy_(saved[1])
                meta.amax_history.copy_(saved[2])
    def get_extra_state(self) -> Union[torch.Tensor, io.BytesIO]:
        """Serialize FP8 metadata for checkpointing.

        The state is `None` unless FP8 checkpointing, FP8 or FP8 calibration is enabled.
        In that case it holds the forward/backward scales, inverse scales and amax
        histories, plus every bool/int/float/str/tuple/list entry of `self.fp8_meta`
        under "extra_fp8_variables".

        Returns
        -------
        torch.Tensor or io.BytesIO
            In ONNX export mode, a `torch.uint8` tensor of the pickled state; otherwise
            a `BytesIO` written by `torch.save`.
        """
        state = None

        fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration

        if fp8_checkpoint:
            state = {}
            state["scale_fwd"] = self.fp8_meta["scaling_fwd"].scale
            state["scale_inv_fwd"] = self.fp8_meta["scaling_fwd"].scale_inv
            state["amax_history_fwd"] = self.fp8_meta["scaling_fwd"].amax_history
            state["scale_bwd"] = self.fp8_meta["scaling_bwd"].scale
            state["scale_inv_bwd"] = self.fp8_meta["scaling_bwd"].scale_inv
            state["amax_history_bwd"] = self.fp8_meta["scaling_bwd"].amax_history

            # Store other picklable values.
            state["extra_fp8_variables"] = {
                k: v
                for k, v in self.fp8_meta.items()
                if isinstance(v, (bool, int, float, str, tuple, list))
            }

        if is_in_onnx_export_mode():
            # `bytearray` gives `torch.frombuffer` a writable buffer; immutable `bytes`
            # triggers a non-writable-buffer warning.
            state_serialized = torch.frombuffer(
                bytearray(pickle.dumps(state)), dtype=torch.uint8
            )
        else:
            state_serialized = io.BytesIO()
            torch.save(state, state_serialized)

        return state_serialized
    def set_extra_state(self, state: Optional[Union[torch.Tensor, io.BytesIO]]) -> None:
        """Load FP8 metadata previously produced by `get_extra_state`.

        Parameters
        ----------
        state : torch.Tensor, io.BytesIO or None
            A `torch.uint8` tensor of pickle bytes, a `BytesIO` written by `torch.save`,
            or `None` (nothing to load).

        Raises
        ------
        RuntimeError
            If `state` is of any other type.

        .. warning::
            Both formats are deserialized with `pickle` / `torch.load`, which can run
            arbitrary code; only load checkpoints from trusted sources.
        """
        if state is None:
            return

        if isinstance(state, torch.Tensor):
            state = pickle.loads(state.detach().cpu().numpy().tobytes())
        elif isinstance(state, io.BytesIO):
            state.seek(0)
            state = torch.load(state, map_location='cuda')
        else:
            raise RuntimeError("Unsupported checkpoint format.")

        # The checkpoint was written without FP8 metadata.
        if state is None:
            return

        # Load extra items.
        self.fp8_meta.update(state["extra_fp8_variables"])
        self.fp8_meta["recipe"].amax_history_len = state["amax_history_fwd"].shape[0]
        # NOTE(review): the recompute buffer position restored above is discarded;
        # presumably it is only meaningful inside the run that saved it — confirm.
        if "global_fp8_buffer_pos_fwd_recompute" in self.fp8_meta:
            del self.fp8_meta["global_fp8_buffer_pos_fwd_recompute"]

        # Initialize before loading.
        self.init_fp8_meta_tensors()
        self.fp8_meta["scaling_fwd"].scale.copy_(state["scale_fwd"])
        self.fp8_meta["scaling_fwd"].amax_history.copy_(state["amax_history_fwd"])
        self.fp8_meta["scaling_bwd"].scale.copy_(state["scale_bwd"])
        self.fp8_meta["scaling_bwd"].amax_history.copy_(state["amax_history_bwd"])
        self.fp8_meta["scaling_fwd"].scale_inv.copy_(state["scale_inv_fwd"])
        self.fp8_meta["scaling_bwd"].scale_inv.copy_(state["scale_inv_bwd"])
    def set_activation_dtype(self, inp: torch.Tensor) -> None:
        """Get activation data type for AMP.

        Sets `self.activation_dtype`. Inside `torch.autocast` it becomes the autocast
        GPU dtype. Otherwise it becomes `inp.dtype`, after checking that every
        parameter and buffer of this module has that dtype.

        Raises
        ------
        AssertionError
            Outside autocast, if a parameter or buffer dtype differs from `inp.dtype`.
        """
        # Native AMP (`torch.autocast`) gets highest priority
        if torch.is_autocast_enabled():
            self.activation_dtype = torch.get_autocast_gpu_dtype()
            return

        # All checks after this have already been performed once, thus skip
        if hasattr(self, "activation_dtype") and self.activation_dtype == inp.dtype:
            return

        dtype = inp.dtype
        for name, param in self.named_parameters():
            if param is not None:
                assert dtype == param.dtype, (
                    "Data types for parameters must match when outside of autocasted region. "
                    f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}"
                )
        for name, buf in self.named_buffers():
            if buf is not None:
                assert dtype == buf.dtype, (
                    "Data types for buffers must match when outside of autocasted region. "
                    f" Found input dtype: {dtype} and {name!r} dtype: {buf.dtype}"
                )
        self.activation_dtype = dtype
    def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
        """
        Set the tensor parallel group for the given
        module before executing the forward pass.

        Parameters
        ----------
        tp_group : ProcessGroup, default = `None`
                  tensor parallel process group.
        """
        self.tp_group = tp_group
        # `prepare_forward` asserts this flag when `tp_size > 1`.
        self.tp_group_initialized = True
    def _get_fp8_params(self) -> Union[List[torch.Tensor], None]:
        """returns the FP8 weights."""
        fp8_params = []
474
        for param in self.parameters(recurse=False):
475
476
477
478
479
480
            if isinstance(param, Float8Tensor) and param.requires_grad:
                fp8_params.append(param)
        if len(fp8_params) == 0:
            return None
        return fp8_params

481
482
    # This routine is shared across FP8 and FP8_calibration paths so should not actually
    # assume FP8 execution.
    def init_fp8_metadata(self, num_gemms: int = 1) -> None:
        """Initialize fp8 related metadata and tensors during fprop.

        Refreshes the FP8 flags from `FP8GlobalStateManager`. When FP8 or FP8
        calibration is enabled, (re)builds the recipe-dependent metadata and scaling
        tensors unless they were already built for the same recipe.

        Parameters
        ----------
        num_gemms : int, default = 1
            Number of GEMMs in the module; sizes the scaling tensors.
        """
        self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
        self.fp8 = FP8GlobalStateManager.is_fp8_enabled()
        self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration()
        self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration

        # FP8 parameters get scaling tensors even when FP8 compute is not enabled.
        if self.fp8_parameters and not self.fp8_initialized:
            self.fp8_meta["num_gemms"] = num_gemms
            self.init_fp8_meta_tensors()

        if not (self.fp8 or self.fp8_calibration):
            # If fp8 isn't enabled, turn off and return.
            self.fp8_initialized = False
            return

        # FP8 init has already been run and recipe is the same, don't do anything.
        if (self.fp8_initialized
                and FP8GlobalStateManager.get_fp8_recipe() == self.fp8_meta["recipe"]):
            return

        # Set FP8, recipe, and other FP8 metadata
        self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
        self.fp8_meta["num_gemms"] = num_gemms
        self.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()

        # Set FP8_MAX per tensor according to recipe
        self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd
        self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd

        # Allocate scales and amaxes
        self.init_fp8_meta_tensors()
        self.fp8_initialized = True
    @contextmanager
    def prepare_forward(
        self,
        inp: torch.Tensor,
        is_first_microbatch: Union[bool, None],  # pylint: disable=unused-argument
        num_gemms: int = 1,
        allow_non_contiguous: bool = False,
    ) -> Generator[torch.Tensor, None, None]:
        """Checks and prep for FWD.
        The context manager is needed because there isn't a way for a module to know
        if it's the last FP8 module in the forward autocast. It is useful
        to setup the forward aggregated amax reduction for every module
        just in case. The autocast exit will pick up the most recent one.

        Parameters
        ----------
        inp : torch.Tensor
            Module input; yielded as-is if `allow_non_contiguous`, else `.contiguous()`.
        is_first_microbatch : bool or None
            Unused here.
        num_gemms : int, default = 1
            Forwarded to `init_fp8_metadata`.
        allow_non_contiguous : bool, default = `False`
            Skip the `.contiguous()` call on `inp`.
        """
        # Activation recomputation is used and this is the second forward phase.
        if self.fp8 and in_fp8_activation_recompute_phase():
            FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta)
        else:
            assert inp.is_cuda, "TransformerEngine needs CUDA."

            if self.tp_size > 1:
                assert self.tp_group_initialized, "TP group not initialized."

            self.set_activation_dtype(inp)
            self.init_fp8_metadata(num_gemms=num_gemms)

            if self.fp8 and self.sequence_parallel:
                assert self.fp8_meta["recipe"].reduce_amax, (
                    "Amax reduction across tensor parallel group is "
                    "necessary when using sequence parallelism with FP8."
                )

            if self.fp8 and not FP8GlobalStateManager.fp8_graph_capturing():
                FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
                    self.fp8_meta, fp8_weights=self._get_fp8_params())

            # Activation recomputation is used and this is the first forward phase.
            if self.fp8 and self.training and is_fp8_activation_recompute_enabled():
                FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)

        try:
            with torch.cuda.nvtx.range(self.__class__.__name__ + " forward"):
                yield inp if allow_non_contiguous else inp.contiguous()
        finally:
            # Restore the meta tensors swapped out for recompute even if the forward
            # body raised, so the module is not left holding the old scales.
            if self.fp8 and in_fp8_activation_recompute_phase():
                FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
    def set_nccl_overlap_warning_if_tp(self) -> None:
        """When using TP, the NCCL communication needs to be scheduled
        before the GEMM for there to be a guaranteed overlap. From the
        host side in TE, the comm calls are always launched first, but
        to ensure that the GEMM isn't scheduled first, the environment
        variable `CUDA_DEVICE_MAX_CONNECTIONS` needs to be set to 1 to
        force a single channel.

        Emits a warning when `tp_size > 1` and the variable is unset or not 1.
        """
        if self.tp_size == 1:
            return
        num_cuda_work_queues = int(os.getenv("CUDA_DEVICE_MAX_CONNECTIONS", "0"))
        if num_cuda_work_queues != 1:
            warnings.warn(
                "To guarantee overlapping TP and SP collectives with the backward "
                "GEMMs, set environment variable CUDA_DEVICE_MAX_CONNECTIONS = 1"
            )
    @staticmethod
    def grad_output_preprocess(
        ctx, grad_output: torch.Tensor, row_parallel_mode: bool
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Utility function for backward.
        Returns tuple in order (all optional/None based on training precision/recipe):
            R1: gathered `grad_output` in higher precision.
            R2: gathered `grad_output` in FP8.
            R3: R2 transposed.
            R4: bias gradient on R1.

        The gather along the first dim happens only when `row_parallel_mode` and
        `ctx.sequence_parallel` are both set.
        """
        if isinstance(grad_output, Float8Tensor):
            grad_output._data = grad_output._data.contiguous()
        else:
            grad_output = grad_output.contiguous()
        grad_output_mat = grad_output.view(-1, grad_output.shape[-1])
        gather_grad_output = row_parallel_mode and ctx.sequence_parallel

        # No-FP8 case: bgrad is fused with wgrad for this case.
        if not ctx.fp8:
            if gather_grad_output:
                if not ctx.ub_overlap_ag:
                    grad_output_mat, _ = gather_along_first_dim(
                        grad_output_mat, ctx.tp_group
                    )
                else:
                    ctx.ub_obj_gradout.copy_input_to_ubuf(grad_output, True)
                    grad_output_mat = ctx.ub_obj_gradout.get_ubuf_output(1)
            return grad_output_mat, None, None, None

        fp8_dtype_backward = get_fp8_te_dtype(
            ctx.fp8_meta["recipe"], fprop_tensor=False
        )

        # FP8 case with non-FP8 wgrad: gather in high precision, then fall through to
        # the un-gathered cast path below with the gathered matrix.
        if (
            gather_grad_output
            and ctx.fp8_meta["recipe"].override_linear_precision.wgrad
        ):
            assert (
                not ctx.ub_overlap_ag
            ), "override_linear_precision.wgrad not supported with UB AG overlap"
            grad_output_mat, _ = gather_along_first_dim(grad_output_mat, ctx.tp_group)
        # FP8 case with gather: unfused bgrad, cast, transpose for efficient gather
        elif gather_grad_output:
            grad_bias = grad_output_mat.sum(dim=0) if ctx.use_bias else None
            if ctx.ub_overlap_ag:
                grad_output_c = ctx.ub_obj_gradout.get_ubuf_output(0)
            else:
                grad_output_c = torch.empty_like(grad_output_mat, dtype=torch.uint8)
            if not isinstance(grad_output_mat, Float8Tensor):
                cast_to_fp8(
                    grad_output_mat,
                    ctx.fp8_meta["scaling_bwd"],
                    tex.FP8BwdTensors.GRAD_OUTPUT1,
                    fp8_dtype_backward,
                    out=grad_output_c,
                )
            else:
                grad_output_c = grad_output_mat
            if not ctx.ub_overlap_ag:
                grad_output_c, _ = gather_along_first_dim(grad_output_c, ctx.tp_group)
                if not isinstance(grad_output_c, Float8Tensor):
                    grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
                else:
                    grad_output_t = grad_output_c.transpose_2d()
            else:
                grad_output_c = ctx.ub_obj_gradout.get_ubuf_output(1)
                grad_output_t = None

            return grad_output_mat, grad_output_c, grad_output_t, grad_bias

        # FP8 case without gather: cast, transpose, bgrad fused
        if ctx.use_bias:
            grad_output_mat_no_fp8 = grad_output_mat
            if isinstance(grad_output_mat, Float8Tensor):
                grad_output_mat_no_fp8 = grad_output_mat.from_float8(grad_output_mat.dtype)
            grad_bias, grad_output_c, grad_output_t = fp8_cast_transpose_bgrad_fused(
                grad_output_mat_no_fp8,
                ctx.fp8_meta["scaling_bwd"],
                tex.FP8BwdTensors.GRAD_OUTPUT1,
                fp8_dtype_backward,
            )
        else:
            if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
                if isinstance(grad_output_mat, Float8Tensor):
                    grad_output_c = grad_output_mat
                    grad_output_t = grad_output_c.transpose_2d()
                else:
                    grad_output_c, grad_output_t = fp8_cast_transpose_fused(
                        grad_output_mat,
                        ctx.fp8_meta["scaling_bwd"],
                        tex.FP8BwdTensors.GRAD_OUTPUT1,
                        fp8_dtype_backward,
                    )
            else:
                # High-precision wgrad: no FP8 transpose is produced.
                grad_output_t = None
                if not isinstance(grad_output_mat, Float8Tensor):
                    grad_output_c = cast_to_fp8(
                        grad_output_mat,
                        ctx.fp8_meta["scaling_bwd"],
                        tex.FP8BwdTensors.GRAD_OUTPUT1,
                        fp8_dtype_backward,
                    )
                else:
                    grad_output_c = grad_output_mat
            grad_bias = None

        return grad_output_mat, grad_output_c, grad_output_t, grad_bias
    def register_parameter(self, name, param, **kwargs):
        """
        Thin wrapper around PyTorch parameter registration to stash additional parameter
        metadata used in deferred initialization.

        ``kwargs`` are forwarded to ``_ParameterInitMeta`` and stored in
        ``self.param_init_meta[name]``.

        ``torch.nn.Module.__setattr__`` also routes through this method whenever an
        ``nn.Parameter`` is assigned to an attribute. Examples are the re-wrap at the end
        of ``reset_parameters`` and a re-registration done by a wrapper such as FSDP.
        Those calls carry no kwargs. Rebuilding the metadata unconditionally would
        replace it with a default-constructed ``_ParameterInitMeta()``. That would
        discard the ``init_fn``, ``get_rng_state_tracker`` and ``fp8_meta_index`` that
        ``reset_parameters`` relies on. The metadata is therefore only (re)built when
        explicit kwargs are supplied or when none exists yet for ``name``.
        """
        super().register_parameter(name, param)
        if kwargs or name not in self.param_init_meta:
            self.param_init_meta[name] = _ParameterInitMeta(**kwargs)

    def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
        """
        (Re)initialize every parameter registered directly on this module.

        The steps for each parameter are:

        1. A parameter on the 'meta' device first gets real (uninitialized) CUDA
           storage.
        2. It is filled by the ``init_fn`` stashed for it in ``param_init_meta``. This
           runs inside a fork of the RNG state tracker when one was registered.
        3. It is cast to ``Float8Tensor`` when primary weights are kept in FP8 and an
           FP8 meta index is known.
        4. It is re-registered as an ``nn.Parameter``.

        Parameters
        ----------
        defer_init : bool, default = `False`
            If true, return immediately without materializing or initializing anything.
        """
        if defer_init:
            return

        on_meta = torch.device('meta')
        for name, param in self.named_parameters(recurse=False):
            # Give meta-device parameters real CUDA storage before filling them.
            if param.device == on_meta:
                param = torch.empty_like(param, device='cuda')

            init_meta = self.param_init_meta[name]
            init_fn = init_meta.init_fn
            tracker_getter = init_meta.get_rng_state_tracker

            # Fill in initial values. When a tracker is provided, this happens under a
            # forked RNG state. The fork is named after `self.rng_tracker_name` when that
            # attribute exists and is truthy; otherwise the tracker's default fork is used.
            if tracker_getter is not None:
                tracker_name = getattr(self, "rng_tracker_name", None)
                if tracker_name:
                    fork_ctx = tracker_getter().fork(tracker_name)
                else:
                    fork_ctx = tracker_getter().fork()
                with fork_ctx:
                    init_fn(param)
            else:
                init_fn(param)

            # With FP8 primary weights, replace eligible parameters by FP8 copies. A
            # throw-away amax tensor is passed so the cast does not overwrite the
            # module's amax history.
            fp8_meta_index = init_meta.fp8_meta_index
            wrap_as_fp8 = self.primary_weights_in_fp8 and fp8_meta_index is not None
            if wrap_as_fp8:
                param = Float8Tensor.to_float8(
                    param,
                    fp8_meta=self.fp8_meta,
                    fp8_meta_index=fp8_meta_index,
                    amax=torch.empty(1, device="cuda"),
                )

            # Re-register under the same name. Both branches above can leave `param` as a
            # plain tensor / Float8Tensor rather than an nn.Parameter, so it is always
            # re-wrapped. Wrapping a tensor that already is an nn.Parameter is not an
            # identity operation. It yields a new Parameter object that shares the same
            # storage and has requires_grad=True.
            setattr(self, name, torch.nn.Parameter(param))

752
753
754
    @abstractmethod
    def forward(self):
        """Forward computation of the module; every concrete subclass must override it."""
755

756
    def get_fp8_workspace(
        self,
        *,
        tensor: Optional[torch.Tensor] = None,
        fp8_meta_forward: Optional[bool] = None,
        fp8_meta_index: Optional[int] = None,
        cache_name: Optional[str] = None,
        update_workspace: bool = True,
        skip_update_flag: Optional[torch.Tensor] = None,
        with_transpose: bool = False,
        fsdp_group: dist_group_type = None,
    ) -> Float8Tensor:
        """Get FP8 workspace buffer and maybe update its values

        The workspace buffer may be cached for future function calls.

        Parameters
        ----------
        tensor : torch.Tensor, optional
            Values to copy into workspace. Required if the workspace
            is being constructed or updated.
        fp8_meta_forward: bool, optional
            Whether to access FP8 meta tensors for the forward pass or
            backward pass. Required if the workspace is being
            constructed.
        fp8_meta_index: int, optional
            Index to access in FP8 meta tensors. Required if the
            workspace is being constructed.
        cache_name: str, optional
            Key for caching.
        update_workspace: bool, default = `True`
            Update workspace with values from `tensor`.
        skip_update_flag: torch.Tensor, optional
            GPU flag to skip updating the workspace. Takes precedence
            over `update_workspace` if provided.
        with_transpose: bool, default = `False`
            Whether to initialize cached transpose in workspace.
        fsdp_group: dist_group_type, default = None
            FSDP process group that the weights are distributed over.
            When given, a cached workspace whose data shape differs
            from `tensor`'s shape is passed to `_fsdp_gather_tensors`
            with `tensor`'s shape (presumably an all-gather over the
            group to restore the full buffer).

        Returns
        -------
        Float8Tensor
            The cached or newly constructed workspace, updated if
            requested.

        Raises
        ------
        ValueError
            If the workspace must be constructed but `tensor`,
            `fp8_meta_forward` or `fp8_meta_index` is missing, or if
            it must be updated but `tensor` is missing.
        """

        # Try getting workspace from cache
        out = None
        if cache_name is not None:
            out = self._fp8_workspaces.get(cache_name, None)

            # Gather the cached FP8 workspace if it is distributed, i.e. its data no
            # longer has the full shape of `tensor`.
            # NOTE: FSDP sharding is supported only for Fp8 buffers and will not work
            #       for models initialized with Fp8 primary weights.
            # Only an existing cached Float8Tensor can be gathered, because its `_data`
            # is inspected. `tensor` is required because it supplies the target shape.
            # Guarding on `not isinstance(out, Float8Tensor)` instead would dereference
            # `None` on a cache miss and never gather a real cached workspace.
            if (
                isinstance(out, Float8Tensor)
                and fsdp_group is not None
                and tensor is not None
                and out._data.shape != tensor.data.shape
            ):
                _fsdp_gather_tensors(fsdp_group, [tensor.data.shape], out)

        # Construct workspace if needed
        if out is None:
            if (
                tensor is None
                or fp8_meta_forward is None
                or fp8_meta_index is None
            ):
                raise ValueError(
                    "tensor, fp8_meta_forward, and fp8_meta_index kwargs "
                    "must be provided to construct FP8 workspace"
                )
            fp8_dtype = get_fp8_te_dtype(
                self.fp8_meta["recipe"],
                fprop_tensor=fp8_meta_forward,
            )
            # Left uninitialized here; the update step below (which is forced for a
            # new workspace) is expected to populate it.
            scale_inv = torch.empty(
                [1],
                dtype=torch.float32,
                device=tensor.device
            )
            out = Float8Tensor(
                data=torch.empty_like(tensor, dtype=torch.uint8),
                fp8_meta=self.fp8_meta,
                fp8_meta_forward=fp8_meta_forward,
                fp8_meta_index=fp8_meta_index,
                fp8_dtype=fp8_dtype,
                fp8_scale_inv=scale_inv,
                dtype=tensor.dtype,
            )
            if cache_name is not None:
                self._fp8_workspaces[cache_name] = out
            # The freshly allocated data is uninitialized (empty_like), so it must be
            # filled unconditionally; a skip flag must not suppress that.
            update_workspace = True
            skip_update_flag = None

        # Update workspace if needed. A skip flag forces the update path; the flag
        # itself then decides on the GPU whether the update actually happens.
        if skip_update_flag is not None:
            update_workspace = True
        if update_workspace:
            if tensor is None:
                raise ValueError(
                    "tensor kwarg must be provided to update FP8 workspace"
                )
            if with_transpose:
                out.cast_transpose_(
                    tensor,
                    noop_flag=skip_update_flag,
                )
            else:
                # NOTE(review): skip_update_flag is not forwarded on this path, so the
                # cast always runs when with_transpose=False. Confirm this is intended.
                fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
                    forward=out._fp8_meta_forward,
                )
                fp8_meta = out._fp8_meta[fp8_meta_key]
                fp8_meta_index = out._fp8_meta_index
                cast_to_fp8(
                    tensor,
                    fp8_meta,
                    fp8_meta_index,
                    out._fp8_dtype,
                    out=out._data,
                )
                if is_in_onnx_export_mode():
                    # ONNX export expects FP8 scales can be
                    # represented with constant ops. However, copying
                    # into a buffer involves an expand op for array
                    # broadcasting. We work around this by filling the
                    # buffer instead.
                    out._scale_inv.fill_(fp8_meta.scale_inv[fp8_meta_index].item())
                else:
                    out._scale_inv.copy_(fp8_meta.scale_inv[fp8_meta_index])

        return out
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """
        Load this module's entries from ``state_dict``. When primary weights are kept in
        FP8, the module's extra state (FP8 metadata) is restored before any tensor is
        copied.

        Copying into an FP8 tensor uses the ``scale_inv`` stored in ``fp8_meta`` to pick
        the scaling factor for the new values. That metadata must therefore be in place
        before the base-class loader copies tensors, rather than after, which is where
        extra state is normally applied.

        Without FP8 primary weights this ordering is irrelevant, and loading is delegated
        to the base class unchanged.
        """
        if self.primary_weights_in_fp8:
            extra_state_key = prefix + torch.nn.modules.module._EXTRA_STATE_KEY_SUFFIX
            has_extra_state = extra_state_key in state_dict
            if has_extra_state:
                self.set_extra_state(state_dict[extra_state_key])

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )