"vscode:/vscode.git/clone" did not exist on "e5464ee484450c2671dd0226516c99c60ce70d9d"
module.py 109 KB
Newer Older
1
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Przemek Tredak's avatar
Przemek Tredak committed
2
3
4
5
6
#
# See LICENSE for license information.

"""Top level Transformer Engine PyTorch modules"""
import os
7
import pickle
Przemek Tredak's avatar
Przemek Tredak committed
8
9
import warnings
from abc import ABC, abstractmethod
10
from typing import Union, Optional, Callable, Tuple, Dict, Any, Mapping
Przemek Tredak's avatar
Przemek Tredak committed
11
from functools import partial
12
from contextlib import contextmanager
Przemek Tredak's avatar
Przemek Tredak committed
13

14
import numpy as np
Przemek Tredak's avatar
Przemek Tredak committed
15
16
17
18
19
20
21
import torch
from torch.nn.parameter import Parameter
from torch.nn import init

import transformer_engine_extensions as tex
from .fp8 import (
    is_fp8_enabled,
schetlur-nv's avatar
schetlur-nv committed
22
    is_fp8_calibration,
Przemek Tredak's avatar
Przemek Tredak committed
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
    get_fp8_recipe,
    get_fp8_group,
    get_default_fp8_recipe,
    get_fp8_te_dtype,
    is_first_fp8_module,
    new_fp8_context_id,
    get_fp8_context_id,
    set_fp8_context_id,
    add_amax_to_global_buffer,
    copy_amax_from_global_buffer,
    global_amax_reduction,
    setup_amax_forward_global_reduce_func,
    amax_and_scale_update,
    get_global_fp8_buffer,
    set_global_fp8_buffer,
    set_amax_buffer_key_deletion,
    delete_key_from_amax_buffer,
40
41
42
    copy_forward_fp8_meta_tensors_for_recompute,
    get_old_fp8_meta_tensors_for_recompute,
    restore_fp8_meta_tensors,
Przemek Tredak's avatar
Przemek Tredak committed
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
)
from .jit import (
    bias_gelu_fused,
    bgrad_dgelu_fused,
    set_jit_fusion_options,
    warmup_jit_bias_gelu_all_dtypes,
)
from .utils import (
    divide,
    get_default_init_method,
    cast_if_needed,
)
from .distributed import (
    set_tensor_model_parallel_attributes,
    get_distributed_world_size,
    allreduce,
    initialize_affine_weight_gpu,
    reduce_scatter_along_first_dim,
    gather_along_first_dim,
    gather_along_last_dim,
63
64
    is_fp8_activation_recompute_enabled,
    in_fp8_activation_recompute_phase,
Przemek Tredak's avatar
Przemek Tredak committed
65
66
67
68
69
70
71
72
73
)
from .cpp_extensions import (
    fp8_gemm,
    gemm,
    fp8_cast_transpose_fused,
    fp8_cast_transpose_bgrad_fused,
    fp8_gelu,
    fp8_cast_transpose_bgrad_dgelu_fused,
    layernorm_fwd_fp8,
74
75
    layernorm_fwd_fp8_inf,
    layernorm_fwd_inf,
Przemek Tredak's avatar
Przemek Tredak committed
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
    cast_to_fp8,
    cast_from_fp8,
)
from .constants import GemmParallelModes, dist_group_type, TE_DType

# Module-wide accumulation switches, one per GEMM role (names suggest forward
# propagation, data-gradient and weight-gradient GEMMs). They are not read in
# this part of the file; presumably they are forwarded to the FP8 GEMM calls
# as a "use split / 2x accumulator" option — TODO confirm at the call sites.
_2X_ACC_FPROP = False
_2X_ACC_DGRAD = True
_2X_ACC_WGRAD = True
# Lazily allocated cuBLAS workspace tensor; created and cached by `get_workspace()`.
_cublas_workspace = None


def get_cublas_workspace_size_bytes() -> int:
    """Return the cuBLAS workspace size, in bytes, for the current CUDA device.

    Returns:
        32 MiB (33_554_432 bytes) when the current device's compute capability
        major version is >= 9 (Hopper), otherwise 4 MiB (4_194_304 bytes).
    """
    if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9:
        return 33_554_432
    return 4_194_304


def get_workspace() -> torch.Tensor:
    """Return the process-wide cuBLAS workspace, allocating it on first use.

    The workspace is a flat ``int8`` CUDA tensor whose length comes from
    ``get_cublas_workspace_size_bytes()``; it is cached in the module global
    ``_cublas_workspace`` and the same tensor is returned on later calls.
    """
    global _cublas_workspace
    if _cublas_workspace is not None:
        return _cublas_workspace

    workspace_bytes = get_cublas_workspace_size_bytes()
    _cublas_workspace = torch.empty(workspace_bytes, dtype=torch.int8, device="cuda")
    return _cublas_workspace

103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
@contextmanager
def _prepare_backward(fp8: bool,
                      fp8_meta: Dict[str, Any],
                      reduce_amax_across_tp_group: bool,
                      tp_group: Union[dist_group_type, None],
                      name: str = ""):
    """Checks and prep for BWD.

    Context manager wrapped around a module's backward computation.

    Before the body (FP8 only):
      * without amax reduction: update backward amax/scale;
      * with amax reduction (``fp8_meta["recipe"].reduce_amax``): copy the
        previous iteration's backward amaxes from the global buffer, update
        amax/scale, mark the old buffer key for deletion, pop the oldest
        forward autocast id (appended during the forward pass) as this
        backward's key, and register this module's amaxes under it.

    The body runs inside an NVTX range named ``name + " backward"``.

    After the body (FP8 with amax reduction, first FP8 module only): run the
    global backward amax reduction and delete the buffered key. If the body
    raises, this post-body step is not executed.

    Args:
        fp8: whether the backward pass runs in FP8.
        fp8_meta: the module's FP8 metadata dictionary.
        reduce_amax_across_tp_group: forwarded to ``global_amax_reduction``.
        tp_group: tensor-parallel process group, forwarded to the reduction.
        name: prefix for the NVTX range label.
    """
    if fp8:
        # Update amax and scale; Skip all setup for global amax reduction
        if not fp8_meta["recipe"].reduce_amax:
            amax_and_scale_update(fp8_meta, False)
        else:
            # From previous iteration
            copy_amax_from_global_buffer(fp8_meta, forward=False)
            amax_and_scale_update(fp8_meta, False)
            set_amax_buffer_key_deletion(fp8_meta, forward=False)

            # Get new backward key: forward ids are consumed in FIFO order.
            fp8_meta["autocast_id_bwd"] = fp8_meta["autocast_id_fwd_stack"].pop(0)

            add_amax_to_global_buffer(fp8_meta, forward=False)

    with torch.cuda.nvtx.range(name + " backward"):
        yield

    if not fp8 or not fp8_meta["recipe"].reduce_amax:
        return

    # Only the first FP8 module performs the (global) reduction.
    if fp8_meta["first_module"]:
        global_amax_reduction(
            fp8_meta, reduce_amax_across_tp_group, tp_group, forward=False
        )
        delete_key_from_amax_buffer(forward=False)

Przemek Tredak's avatar
Przemek Tredak committed
137
138
139
140
141
142
143

class TransformerEngineBaseModule(torch.nn.Module, ABC):
    """Base TE module.

    Holds the state shared by Transformer Engine modules: the ``fp8_meta``
    dictionary (recipe, FP8 scaling tensors, amax-reduction bookkeeping),
    tensor-parallel settings, and the FP8 weight-cache attributes.
    Subclasses must implement ``forward``.
    """

    def __init__(self) -> None:
        super().__init__()
        assert torch.cuda.is_available(), "TransformerEngine needs CUDA."
        self.fp8_initialized = False
        self.fp8 = False
        self.fp8_calibration = False
        self.fp8_meta = {}
        self.fp8_meta["fp8_group"] = None
        self.fp8_meta["recipe"] = get_default_fp8_recipe()
        self.fp8_meta_tensors_initialized = False
        self.tp_group = None
        self.tp_group_initialized = False
        self.tp_size = 1
        self.sequence_parallel = False
        # Shapes of the weights that need FP8 copies; consumed by `set_fp8_weights`.
        self.fp8_weight_shapes = []
        # Forward autocast ids appended in `prepare_forward` and popped from the
        # front by `_prepare_backward`.
        self.fp8_meta["autocast_id_fwd_stack"] = []

    def set_meta_tensor(self, fwd: bool) -> None:
        """Init scales and amaxes for fwd | bwd.

        Creates a fresh ``tex.FP8TensorMeta`` under ``fp8_meta["scaling_fwd"]``
        (``fwd=True``) or ``fp8_meta["scaling_bwd"]`` with all-ones ``scale`` and
        ``scale_inv`` and an all-zeros ``amax_history`` of shape
        ``(recipe.amax_history_len, num_fp8_tensors)``, plus a boolean
        ``<key>_non_weight_mask`` that is False only at weight slots.
        Requires ``fp8_meta["num_gemms"]`` to be set.
        """
        fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd"
        num_gemms = self.fp8_meta["num_gemms"]

        # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and
        # 2 (grad_output and grad_input) for bwd
        num_fp8_tensors = num_gemms * 3 if fwd else num_gemms * 2

        self.fp8_meta[fp8_meta_tensor_key] = tex.FP8TensorMeta()
        meta = self.fp8_meta[fp8_meta_tensor_key]
        meta.scale = torch.ones(num_fp8_tensors, dtype=torch.float32, device="cuda")
        meta.scale_inv = torch.ones(num_fp8_tensors, dtype=torch.float32, device="cuda")
        meta.amax_history = torch.zeros(
            self.fp8_meta["recipe"].amax_history_len,
            num_fp8_tensors,
            dtype=torch.float32,
            device="cuda",
        )

        # Needed for calculation of scale inverses to
        # preserve scale_inv when caching FP8 weights
        if fwd:
            # [True, False, True]: -> [input, weight, output]
            per_gemm_mask = [True, False, True]
        else:
            # [True, True]: -> [grad_output, grad_input]
            per_gemm_mask = [True, True]
        self.fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"] = torch.BoolTensor(
            per_gemm_mask * num_gemms
        ).cuda()

    def init_fp8_meta_tensors(self) -> None:
        """Init scales and amaxes (forward and backward).

        No-op when the tensors were already initialized, e.g. by loading a
        checkpoint through ``set_extra_state``.
        """
        # Checkpoint loaded
        if self.fp8_meta_tensors_initialized:
            return

        self.set_meta_tensor(True)
        self.set_meta_tensor(False)

    def get_extra_state(self) -> torch.Tensor:
        """Save before checkpointing.

        Returns a ``uint8`` tensor holding a pickled dict with the forward and
        backward scales/amax histories, the global FP8 buffer, and every
        ``fp8_meta`` entry whose value is a bool/int/float/str. When neither
        FP8 nor FP8 calibration is active, the pickled payload is ``None``.
        """
        state = None
        if self.fp8 or self.fp8_calibration:
            state = {}
            state["scale_fwd"] = self.fp8_meta["scaling_fwd"].scale
            state["amax_history_fwd"] = self.fp8_meta["scaling_fwd"].amax_history
            state["scale_bwd"] = self.fp8_meta["scaling_bwd"].scale
            state["amax_history_bwd"] = self.fp8_meta["scaling_bwd"].amax_history
            state["global_fp8_buffer"] = get_global_fp8_buffer()

            # Store other pickelable values.
            extra = {}
            for k, v in self.fp8_meta.items():
                if isinstance(v, (bool, int, float, str)):
                    extra[k] = v
            state["extra_fp8_variables"] = extra

        state_serialized = pickle.dumps(state)
        state_tensor = torch.tensor(np.frombuffer(state_serialized, dtype=np.uint8))

        return state_tensor

    def set_extra_state(self, state: torch.Tensor) -> None:
        """Load previous state.

        Accepts ``None`` (nothing to load), the legacy v0.2.0 list format, or
        the pickled-dict tensor produced by ``get_extra_state``.
        """
        if state is None:
            return

        # Maintain backward compatibility with v0.2.0 and older.
        if isinstance(state, list):
            warnings.warn(
                "This checkpoint format is deprecated and will be "
                "removed in a future release of Transformer Engine"
            )

            # Retrieve checkpointed items.
            scale_fwd = state[0]
            amax_history_fwd = state[1]
            scale_bwd = state[2]
            amax_history_bwd = state[3]
            self.fp8_meta["recipe"].amax_history_len = amax_history_fwd.shape[0]
            self.fp8_meta["num_gemms"] = (
                amax_history_fwd.shape[1] // 2
            )  # Two FWD tensors per GEMM
            # NOTE(review): `set_meta_tensor` allocates 3 forward tensors per GEMM,
            # so copying a 2-per-GEMM `scale_fwd` below looks like it would hit a
            # shape mismatch — confirm against a real v0.2.0 checkpoint.

            # Initialize before loading
            self.init_fp8_meta_tensors()
            self.fp8_meta["scaling_fwd"].scale.copy_(scale_fwd)
            self.fp8_meta["scaling_fwd"].amax_history.copy_(amax_history_fwd)
            self.fp8_meta["scaling_bwd"].scale.copy_(scale_bwd)
            self.fp8_meta["scaling_bwd"].amax_history.copy_(amax_history_bwd)
            self.fp8_meta_tensors_initialized = True

            # Restore global FP8 buffer state.
            set_global_fp8_buffer(state[4])
            self.fp8_meta["update_amax_and_scale_fwd"] = state[5]
            self.fp8_meta["global_fp8_buffer_pos_fwd"] = state[6]
            self.fp8_meta["global_fp8_buffer_pos_bwd"] = state[7]
            self.fp8_meta["autocast_id_fwd"] = state[8]
            self.fp8_meta["autocast_id_bwd"] = state[9]
            return

        if isinstance(state, torch.Tensor):
            # The serialized bytes may live on a CUDA device when the checkpoint
            # was loaded there; `.numpy()` only works on CPU tensors.
            # SECURITY: `pickle.loads` can execute arbitrary code embedded in the
            # payload — only load checkpoints from trusted sources.
            state = pickle.loads(state.detach().cpu().numpy().tobytes())
            if state is None:
                return

        # Restore global FP8 buffer states.
        set_global_fp8_buffer(state["global_fp8_buffer"])
        # Load extra items.
        self.fp8_meta.update(state["extra_fp8_variables"])
        self.fp8_meta["recipe"].amax_history_len = state["amax_history_fwd"].shape[0]
        if "global_fp8_buffer_pos_fwd_recompute" in self.fp8_meta:
            del self.fp8_meta["global_fp8_buffer_pos_fwd_recompute"]

        # Initialize before loading.
        self.init_fp8_meta_tensors()
        self.fp8_meta["scaling_fwd"].scale.copy_(state["scale_fwd"])
        self.fp8_meta["scaling_fwd"].amax_history.copy_(state["amax_history_fwd"])
        self.fp8_meta["scaling_bwd"].scale.copy_(state["scale_bwd"])
        self.fp8_meta["scaling_bwd"].amax_history.copy_(state["amax_history_bwd"])
        self.fp8_meta_tensors_initialized = True

    def set_activation_dtype(self, inp: torch.Tensor) -> None:
        """Get activation data type for AMP.

        Sets ``self.activation_dtype``. Inside ``torch.autocast`` it is the
        autocast GPU dtype (re-read on every call). Otherwise it is
        ``inp.dtype``, determined once (after asserting that all parameters
        and buffers share that dtype) and then cached.
        """
        # Native AMP (`torch.autocast`) gets highest priority
        if torch.is_autocast_enabled():
            self.activation_dtype = torch.get_autocast_gpu_dtype()
            return

        # All checks after this have already been performed once, thus skip
        # We assume that user doesn't change input types across iterations
        if hasattr(self, "activation_dtype"):
            return

        assert all(
            (
                (inp.dtype == param.dtype) if param is not None else True
                for param in self.parameters()
            )
        ), (
            "Data type for activations and weights must "
            "match when outside of autocasted region"
        )
        assert all(
            (
                (inp.dtype == buf.dtype) if buf is not None else True
                for buf in self.buffers()
            )
        ), (
            "Data type for activations and buffers must "
            "match when outside of autocasted region"
        )
        self.activation_dtype = inp.dtype

    def set_fp8_weights(self) -> None:
        """Initializes FP8 weights for the module as class attributes. These
        are not parameters or buffers since we do not want functions such as
        `.to(dtype)` or `.to(device)` to affect them. These also do not need
        to be checkpointed. During `init` phase of the module, the attribute
        `fp8_weight_shapes` must be populated with the tensor shapes for FP8
        weights. This function will iterate over those shapes and initialize
        respective attributes named `weight1_fp8`, `weight2_fp8`, ... (plus
        the transposed `weight1_t_fp8`, `weight2_t_fp8`, ...). An existing
        attribute whose shape already matches is kept as-is.
        """
        if not self.fp8:
            return

        for i, shape in enumerate(self.fp8_weight_shapes, start=1):
            weight_cast_attr = f"weight{i}_fp8"
            weight_transpose_attr = f"weight{i}_t_fp8"

            # Keep an already-allocated buffer of the right shape, but still
            # check the remaining weights (a `return` here would skip them).
            if (
                hasattr(self, weight_cast_attr)
                and getattr(self, weight_cast_attr).shape == shape
            ):
                continue

            setattr(
                self,
                weight_cast_attr,
                torch.empty(
                    shape,
                    device=torch.cuda.current_device(),
                    dtype=torch.int8,
                ),
            )
            setattr(
                self,
                weight_transpose_attr,
                torch.empty(
                    shape[1],
                    shape[0],
                    device=torch.cuda.current_device(),
                    dtype=torch.int8,
                ),
            )

    def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
        """Set TP group and mark it as initialized."""
        self.tp_group = tp_group
        self.tp_group_initialized = True

    # This routine is shared across FP8 and FP8_calibration paths so should not actually
    # assume FP8 execution.
    def fp8_init(self, num_gemms: int = 1) -> None:
        """Initialize fp8 related metadata and tensors during fprop.

        ``self.fp8`` and ``self.fp8_calibration`` are refreshed from the current
        autocast state on every call. When either is active and the module is
        not yet initialized for the current recipe, the recipe, GEMM count,
        FP8 group and FP8 max values are stored in ``fp8_meta`` and the meta
        tensors are allocated. When neither is active, ``fp8_initialized`` is
        reset so the next FP8 region re-initializes.
        """
        # Always mirror the current autocast state. Updating these only on the
        # (re)initialization path left `self.fp8` stuck at True after leaving an
        # FP8 region, and ignored an fp8 <-> calibration switch under the same recipe.
        self.fp8 = is_fp8_enabled()
        self.fp8_calibration = is_fp8_calibration()

        if not (self.fp8 or self.fp8_calibration):
            # If fp8 isn't enabled, turn off and return.
            self.fp8_initialized = False
            return

        # FP8 init has already been run and recipe is the same, don't do anything.
        if self.fp8_initialized and get_fp8_recipe() == self.fp8_meta["recipe"]:
            return

        # Set recipe and other FP8 metadata
        self.fp8_meta["recipe"] = get_fp8_recipe()
        self.fp8_meta["num_gemms"] = num_gemms
        self.fp8_meta["fp8_group"] = get_fp8_group()

        # Set FP8_MAX per tensor according to recipe
        self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd
        self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd

        # Allocate scales and amaxes
        self.init_fp8_meta_tensors()
        self.fp8_initialized = True

    @contextmanager
    def prepare_forward(
        self,
        inp: torch.Tensor,
        is_first_microbatch: Union[bool, None],
        num_gemms: int = 1,
    ) -> None:
        """Checks and prep for FWD.
        The context manager is needed because there isn't a way for a module to know
        if it's the last FP8 module in the forward autocast. It is useful
        to setup the forward aggregated amax reduction for every module
        just in case. The autocast exit will pick up the most recent one.

        Yields ``inp.contiguous()`` inside an NVTX range named
        ``"<ClassName> forward"``.
        """

        # Activation recomputation is used and this is the second forward phase.
        if self.fp8 and in_fp8_activation_recompute_phase():
            get_old_fp8_meta_tensors_for_recompute(self.fp8_meta)
        else:
            assert inp.is_cuda, "TransformerEngine needs CUDA."

            if self.tp_size > 1:
                assert self.tp_group_initialized, "TP group not initialized."

            self.set_activation_dtype(inp)
            self.fp8_init(num_gemms=num_gemms)
            self.set_fp8_weights()

            update_weight_scale_inv = is_first_microbatch is None or is_first_microbatch

            # Previous iteration was grad_enabled
            if self.fp8_meta.get("update_amax_and_scale_fwd", False):
                if self.fp8_meta["recipe"].reduce_amax:
                    copy_amax_from_global_buffer(self.fp8_meta, forward=True)
                    amax_and_scale_update(
                        self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
                    )
                    set_amax_buffer_key_deletion(self.fp8_meta, forward=True)
                else:
                    amax_and_scale_update(
                        self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
                    )

            if self.fp8 and self.training:
                # Setup for amax reduction
                if self.fp8_meta["recipe"].reduce_amax:
                    self.fp8_meta["first_module"] = is_first_fp8_module()
                    if self.fp8_meta["first_module"]:
                        self.fp8_meta["autocast_id_fwd"] = new_fp8_context_id()
                        set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
                    else:
                        self.fp8_meta["autocast_id_fwd"] = get_fp8_context_id()
                    self.fp8_meta["autocast_id_fwd_stack"].append(
                        self.fp8_meta["autocast_id_fwd"]
                    )
                    add_amax_to_global_buffer(self.fp8_meta, forward=True)
                self.fp8_meta["update_amax_and_scale_fwd"] = True
            else:
                self.fp8_meta["update_amax_and_scale_fwd"] = False

            # Activation recomputation is used and this is the first forward phase.
            if (
                self.fp8
                and self.training
                and is_fp8_activation_recompute_enabled()
                and not in_fp8_activation_recompute_phase()
            ):
                copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)

        with torch.cuda.nvtx.range(self.__class__.__name__ + " forward"):
            yield inp.contiguous()

        if self.fp8 and in_fp8_activation_recompute_phase():
            restore_fp8_meta_tensors(self.fp8_meta)
            return

        if self.fp8 and self.training and self.fp8_meta["recipe"].reduce_amax:
            set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
            reduce_func = partial(
                global_amax_reduction,
                self.fp8_meta,
                self.sequence_parallel,
                self.tp_group,
                forward=True,
            )
            setup_amax_forward_global_reduce_func(reduce_func)

    def set_nccl_overlap_warning_if_tp(self) -> None:
        """When using TP, the NCCL communication needs to be scheduled
        before the GEMM for there to be a guaranteed overlap. From the
        host side in TE, the comm calls are always launched first, but
        to ensure that the GEMM isn't scheduled first, the environment
        variable `CUDA_DEVICE_MAX_CONNECTIONS` needs to be set to 1 to
        force a single channel.
        """
        if self.tp_size == 1:
            return
        num_cuda_work_queues = int(os.getenv("CUDA_DEVICE_MAX_CONNECTIONS", "0"))
        if num_cuda_work_queues != 1:
            warnings.warn(
                "To guarantee overlapping TP and SP collectives with the backward "
                "GEMMs, set environment variable CUDA_DEVICE_MAX_CONNECTIONS = 1"
            )

    @staticmethod
    def grad_output_preprocess(
        ctx, grad_output: torch.Tensor, row_parallel_mode: bool
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Utility function for backward.
        Returns tuple in order (all optional/None based on training precision/recipe):
            R1: gathered `grad_output` in higher precision.
            R2: gathered `grad_output` in FP8.
            R3: R2 transposed.
            R4: bias gradient on R1.

        Note: on the FP8 gather path with an FP8 wgrad, R1 is the local
        (un-gathered) 2D view; only the FP8 copy R2 is gathered.
        """
        grad_output = grad_output.contiguous()
        grad_output_mat = grad_output.view((-1, grad_output.shape[-1]))
        gather_grad_output = row_parallel_mode and ctx.sequence_parallel

        # No-FP8 case: bgrad is fused with wgrad for this case.
        if not ctx.fp8:
            if gather_grad_output:
                grad_output_mat, _ = gather_along_first_dim(
                    grad_output_mat, ctx.tp_group
                )
            return grad_output_mat, None, None, None

        fp8_dtype_backward = get_fp8_te_dtype(
            ctx.fp8_meta["recipe"], fprop_tensor=False
        )

        # FP8 case with non-FP8 wgrad
        if (
            gather_grad_output
            and ctx.fp8_meta["recipe"].override_linear_precision.wgrad
        ):
            grad_output_mat, _ = gather_along_first_dim(grad_output_mat, ctx.tp_group)
        # FP8 case with gather: unfused bgrad, cast, transpose for efficient gather
        elif gather_grad_output:
            if ctx.use_bias:
                grad_bias = grad_output_mat.sum(dim=0)
            else:
                grad_bias = None
            grad_output_c = cast_to_fp8(
                grad_output_mat,
                ctx.fp8_meta["scaling_bwd"],
                tex.FP8BwdTensors.GRAD_OUTPUT1,
                fp8_dtype_backward,
            )
            grad_output_c, _ = gather_along_first_dim(grad_output_c, ctx.tp_group)
            grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)

            return grad_output_mat, grad_output_c, grad_output_t, grad_bias

        # FP8 case without gather: cast, transpose, bgrad fused
        if ctx.use_bias:
            grad_bias, grad_output_c, grad_output_t = fp8_cast_transpose_bgrad_fused(
                grad_output_mat,
                ctx.fp8_meta["scaling_bwd"],
                tex.FP8BwdTensors.GRAD_OUTPUT1,
                fp8_dtype_backward,
            )
        else:
            if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
                grad_output_c, grad_output_t = fp8_cast_transpose_fused(
                    grad_output_mat,
                    ctx.fp8_meta["scaling_bwd"],
                    tex.FP8BwdTensors.GRAD_OUTPUT1,
                    fp8_dtype_backward,
                )
            else:
                # Higher-precision wgrad: no FP8 transpose is needed.
                grad_output_t = None
                grad_output_c = cast_to_fp8(
                    grad_output_mat,
                    ctx.fp8_meta["scaling_bwd"],
                    tex.FP8BwdTensors.GRAD_OUTPUT1,
                    fp8_dtype_backward,
                )
            grad_bias = None

        return grad_output_mat, grad_output_c, grad_output_t, grad_bias

    @abstractmethod
    def forward(self):
        """Needs override."""


class _LayerNormLinear(torch.autograd.Function):
    """LayerNormLinear semi-top level module
    Calls custom cuda extensions.

    Applies LayerNorm over the last dimension of the input and feeds the result
    into a GEMM with ``weight`` (plus optional bias), optionally in FP8 and with
    tensor-/sequence-parallel communication around the GEMM.
    """

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        ln_weight: torch.Tensor,
        ln_bias: torch.Tensor,
        weight: torch.Tensor,
        weight_fp8: Union[torch.Tensor, None],
        weight_t_fp8: Union[torch.Tensor, None],
        bias: torch.Tensor,
        use_bias: bool,
        eps: float,
        is_first_microbatch: Union[bool, None],
        fp8: bool,
        fp8_calibration: bool,
        fp8_meta: Dict[str, Any],
        fuse_wgrad_accumulation: bool,
        tp_group: Union[dist_group_type, None],
        sequence_parallel: bool,
        tensor_parallel: bool,
        activation_dtype: torch.dtype,
        parallel_mode: Union[str, None],
        return_layernorm_output: bool,
        is_training: bool,
        fwd_ln_sm_margin: int,
        bwd_ln_sm_margin: int,
    ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
        """Run LayerNorm followed by the linear GEMM.

        ``inp.shape[-1]`` must equal ``ln_weight.numel()``; the input is
        flattened to 2-D for the LayerNorm and GEMM and the output is reshaped
        back (the leading dimension may change under sequence parallelism).

        Returns ``out``, or ``(out, ln_out)`` when ``return_layernorm_output``
        is set, where ``ln_out`` is the higher-precision LayerNorm output
        viewed like ``inp``.

        ``ctx`` is only accessed when ``is_training`` is True, which lets the
        inference path call this function directly with ``ctx=None``.
        """
        # Make sure input dimensions are compatible
        in_features = ln_weight.numel()
        assert inp.shape[-1] == in_features, "GEMM not possible"
        inputmat = inp.view((-1, in_features))

        # FP8 weight copies are refreshed unless this is a non-first microbatch,
        # in which case the `weight_fp8` / `weight_t_fp8` passed in are reused.
        update_fp8_weights = is_first_microbatch is None or is_first_microbatch

        # Cast for native AMP
        inputmat = cast_if_needed(inputmat, activation_dtype)
        ln_weight = cast_if_needed(ln_weight, activation_dtype)
        ln_bias = cast_if_needed(ln_bias, activation_dtype)

        # If residual connection is after LN, we need `ln_out`
        # tensor in higher precision, this comes at the cost
        # of an extra fp8 cast.
        if fp8:
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)

            if not return_layernorm_output:
                # LayerNorm produces its output directly in FP8.
                if is_training:
                    ln_out, mu, rsigma = layernorm_fwd_fp8(
                        inputmat,
                        ln_weight,
                        ln_bias,
                        eps,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM1_INPUT,
                        fp8_dtype_forward,
                        fwd_ln_sm_margin,
                    )
                else:
                    mu = rsigma = None
                    ln_out = layernorm_fwd_fp8_inf(
                        inputmat,
                        ln_weight,
                        ln_bias,
                        eps,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM1_INPUT,
                        fp8_dtype_forward,
                    )
            else:
                # Keep the high-precision LN output for the caller and cast a
                # separate FP8 copy for the GEMM.
                if is_training:
                    ln_out_return, mu, rsigma = tex.layernorm_fwd(
                        inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin
                    )
                else:
                    ln_out_return, mu, rsigma = layernorm_fwd_inf(
                        inputmat, ln_weight, ln_bias, eps
                    ), None, None

                ln_out = cast_to_fp8(
                    ln_out_return,
                    fp8_meta["scaling_fwd"],
                    tex.FP8FwdTensors.GEMM1_INPUT,
                    fp8_dtype_forward,
                )
        else:
            if is_training:
                ln_out, mu, rsigma = tex.layernorm_fwd(
                    inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin
                )
            else:
                ln_out, mu, rsigma = layernorm_fwd_inf(
                    inputmat, ln_weight, ln_bias, eps
                ), None, None
            ln_out_return = ln_out

        # Column Parallel Linear
        if parallel_mode == "column" and sequence_parallel:
            ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
        else:
            ln_out_total = ln_out

        if fp8:
            bias_dtype = (
                torch.bfloat16
                if activation_dtype == torch.float32
                else activation_dtype
            )
            bias = cast_if_needed(bias, bias_dtype) if use_bias else bias

            if update_fp8_weights:
                if is_training:
                    # Refresh the cached FP8 weight and its transpose in place.
                    fp8_cast_transpose_fused(
                        weight,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM1_WEIGHT,
                        fp8_dtype_forward,
                        cast_out=weight_fp8,
                        transpose_out=weight_t_fp8,
                    )
                else:
                    # The transposed weight is only consumed by the backward
                    # dgrad GEMM, so inference skips it.
                    weight_t_fp8 = None
                    weight_fp8 = cast_to_fp8(
                        weight,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM1_WEIGHT,
                        fp8_dtype_forward)

            out = fp8_gemm(
                weight_fp8,
                fp8_meta["scaling_fwd"].scale_inv,
                tex.FP8FwdTensors.GEMM1_WEIGHT,
                fp8_dtype_forward,
                ln_out_total,
                fp8_meta["scaling_fwd"].scale_inv,
                tex.FP8FwdTensors.GEMM1_INPUT,
                fp8_dtype_forward,
                activation_dtype,
                get_workspace(),
                bias=bias,
                use_bias=use_bias,
                use_split_accumulator=_2X_ACC_FPROP,
            )
        else:
            # Cast for native AMP
            weight = cast_if_needed(weight, activation_dtype)
            bias = cast_if_needed(bias, activation_dtype) if use_bias else bias

            if fp8_calibration:
                # amax must be the largest *absolute* value: a signed max would
                # under-report tensors dominated by large negative entries and
                # produce an FP8 scale that overflows them.
                # amax of input
                amin, amax = ln_out_total.aminmax()
                fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = \
                    torch.max(-amin, amax).float()
                # amax of weight
                amin, amax = weight.aminmax()
                fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = \
                    torch.max(-amin, amax).float()

            out, _, _ = gemm(
                weight,
                ln_out_total,
                activation_dtype,
                get_workspace(),
                bias=bias,
                use_bias=use_bias,
            )

        if is_training:
            ctx.save_for_backward(
                inputmat,
                ln_weight,
                mu,
                rsigma,
                weight,
                weight_t_fp8,
                ln_out,
                fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
            )

            ctx.activation_dtype = activation_dtype
            ctx.fp8 = fp8
            ctx.fp8_meta = fp8_meta
            ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
            ctx.is_first_microbatch = is_first_microbatch
            ctx.use_bias = use_bias
            ctx.sequence_parallel = sequence_parallel
            ctx.tensor_parallel = tensor_parallel
            ctx.inp_shape = inp.shape
            ctx.parallel_mode = parallel_mode
            ctx.tp_group = tp_group
            ctx.return_layernorm_output = return_layernorm_output
            ctx.bwd_ln_sm_margin = bwd_ln_sm_margin

        # Row Parallel Linear
        if parallel_mode == "row" and sequence_parallel:
            out, _ = reduce_scatter_along_first_dim(out, tp_group)
        elif parallel_mode == "row" and tensor_parallel:
            out, _ = allreduce(out, tp_group)

        # [*, in_features] -> [*, out_features] except first dimension changes for SP
        out = out.view(-1, *inp.shape[1:-1], out.shape[-1])

        if return_layernorm_output:
            return out, ln_out_return.view_as(inp)
        return out

    @staticmethod
    def backward(
        ctx, *grad_outputs: Tuple[torch.Tensor, ...]
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Compute gradients for ``_LayerNormLinear.forward``.

        ``grad_outputs[0]`` is the gradient of ``out``. When the forward also
        returned the LayerNorm output, ``grad_outputs[1]`` is that output's
        gradient and is added to the gradient fed into the LayerNorm backward.

        Returns one entry per forward argument: gradients for ``inp``,
        ``ln_weight``, ``ln_bias``, ``weight`` and ``bias``; ``None`` for all
        other arguments.
        """
        with _prepare_backward(ctx.fp8, ctx.fp8_meta, ctx.sequence_parallel, ctx.tp_group,
                               name="_LayerNormLinear"):
            (
                inputmat,
                ln_weight,
                mu,
                rsigma,
                weight,
                weight_t_fp8,
                ln_out,
                fwd_scale_inverses,
            ) = ctx.saved_tensors

            (
                grad_output,
                grad_output_c,
                grad_output_t,
                grad_bias,
            ) = TransformerEngineBaseModule.grad_output_preprocess(
                ctx, grad_outputs[0], ctx.parallel_mode == "row"
            )

            # Column Parallel Linear
            # Overlap input AG with dgrad
            if ctx.parallel_mode == "column" and ctx.sequence_parallel:
                ln_out_total, handle = gather_along_first_dim(
                    ln_out, ctx.tp_group, async_op=True
                )
            else:
                ln_out_total = ln_out

            # On non-first microbatches, wgrad is accumulated into main_grad
            # instead of overwriting it.
            if ctx.is_first_microbatch is not None:
                accumulate_wgrad_into_param_main_grad = (
                    ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
                )
            else:
                accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation

            if ctx.fp8:
                fp8_dtype_forward = get_fp8_te_dtype(
                    ctx.fp8_meta["recipe"], fprop_tensor=True
                )
                fp8_dtype_backward = get_fp8_te_dtype(
                    ctx.fp8_meta["recipe"], fprop_tensor=False
                )

                # DGRAD: Evaluated unconditionally to feed into Linear backward
                dgrad = fp8_gemm(
                    weight_t_fp8,
                    fwd_scale_inverses,
                    tex.FP8FwdTensors.GEMM1_WEIGHT,
                    fp8_dtype_forward,
                    grad_output_c,
                    ctx.fp8_meta["scaling_bwd"].scale_inv,
                    tex.FP8BwdTensors.GRAD_OUTPUT1,
                    fp8_dtype_backward,
                    ctx.activation_dtype,
                    get_workspace(),
                    use_split_accumulator=_2X_ACC_DGRAD,
                )
            else:
                # DGRAD: Evaluated unconditionally to feed into Linear backward
                dgrad, _, _ = gemm(
                    weight,
                    grad_output,
                    ctx.activation_dtype,
                    get_workspace(),
                    layout="NN",
                    grad=True,
                )

            # Overlap dgrad-RS/AR with wgrad
            if ctx.parallel_mode == "column" and ctx.sequence_parallel:
                # The gathered input is needed for wgrad below.
                handle.wait()
                dgrad, handle = reduce_scatter_along_first_dim(
                    dgrad, ctx.tp_group, async_op=True
                )
            elif ctx.parallel_mode == "column" and ctx.tensor_parallel:
                dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True)

            if weight.requires_grad:
                if ctx.fp8:
                    # WGRAD
                    if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
                        ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward)
                        wgrad = fp8_gemm(
                            ln_out_total_t,
                            fwd_scale_inverses,
                            tex.FP8FwdTensors.GEMM1_INPUT,
                            fp8_dtype_forward,
                            grad_output_t,
                            ctx.fp8_meta["scaling_bwd"].scale_inv,
                            tex.FP8BwdTensors.GRAD_OUTPUT1,
                            fp8_dtype_backward,
                            ctx.activation_dtype,
                            get_workspace(),
                            accumulate=accumulate_wgrad_into_param_main_grad,
                            fp32_output=ctx.fuse_wgrad_accumulation,
                            out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                            use_split_accumulator=_2X_ACC_WGRAD,
                        )
                    else:
                        # Recipe requests a higher-precision wgrad: dequantize
                        # the saved FP8 LN output first.
                        ln_out_total_c = cast_from_fp8(
                            ln_out_total,
                            ctx.fp8_meta["scaling_fwd"],
                            tex.FP8FwdTensors.GEMM1_INPUT,
                            fp8_dtype_forward,
                            TE_DType[ctx.activation_dtype],
                        )
                        wgrad, _, _ = gemm(
                            ln_out_total_c,
                            grad_output,
                            ctx.activation_dtype,
                            get_workspace(),
                            layout="NT",
                            grad=True,
                            accumulate=accumulate_wgrad_into_param_main_grad,
                            fp32_output=ctx.fuse_wgrad_accumulation,
                            out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                        )
                else:
                    # WGRAD (the GEMM also produces the bias gradient here)
                    wgrad, grad_bias, _ = gemm(
                        ln_out_total,
                        grad_output,
                        ctx.activation_dtype,
                        get_workspace(),
                        layout="NT",
                        grad=True,
                        use_bias=ctx.use_bias,
                        accumulate=accumulate_wgrad_into_param_main_grad,
                        fp32_output=ctx.fuse_wgrad_accumulation,
                        out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                    )

            # Column Parallel Linear
            if ctx.parallel_mode == "column" and ctx.tensor_parallel and handle is not None:
                handle.wait()

            # LayerNorm gradient
            d_ln_out = dgrad.view(inputmat.shape)

            # Residual gradient
            if ctx.return_layernorm_output:
                d_ln_out = d_ln_out + grad_outputs[1].view_as(d_ln_out)

            dxmat, dgamma, dbeta = tex.layernorm_bwd(
                d_ln_out, inputmat, mu, rsigma, ln_weight, ctx.bwd_ln_sm_margin
            )

            if not ctx.use_bias:
                grad_bias = None

        # One entry per argument of `forward` (after `ctx`), in order.
        return (
            dxmat.view(ctx.inp_shape),  # inp
            dgamma,  # ln_weight
            dbeta,  # ln_bias
            wgrad if weight.requires_grad else None,  # weight
            None,  # weight_fp8
            None,  # weight_t_fp8
            grad_bias,  # bias
            None,  # use_bias
            None,  # eps
            None,  # is_first_microbatch
            None,  # fp8
            None,  # fp8_calibration
            None,  # fp8_meta
            None,  # fuse_wgrad_accumulation
            None,  # tp_group
            None,  # sequence_parallel
            None,  # tensor_parallel
            None,  # activation_dtype
            None,  # parallel_mode
            None,  # return_layernorm_output
            None,  # is_training
            None,  # fwd_ln_sm_margin
            None,  # bwd_ln_sm_margin
        )


class LayerNormLinear(TransformerEngineBaseModule):
    """
    Applies layer normalization followed by linear transformation to the incoming data.

    Parameters
    ----------
    in_features : int
                 size of each input sample.
    out_features : int
                  size of each output sample.
    eps : float, default = 1e-5
         a value added to the denominator of layer normalization for numerical stability.
    bias : bool, default = `True`
          if set to `False`, the layer will not learn an additive bias.
    init_method : Callable, default = `None`
                 used for initializing weights in the following way: `init_method(weight)`.
                 When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    get_rng_state_tracker : Callable, default = `None`
                           forwarded to `initialize_affine_weight_gpu` when the weight is
                           initialized.
    return_layernorm_output : bool, default = `False`
                             if set to `True`, output of layernorm is returned from the forward
                             together with the output of the linear transformation.
                             Example use case: residual connection for transformer module is
                             taken post layernorm.

    Parallelism parameters
    ----------------------
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism. Only takes effect when
                       the tensor parallel world size is greater than 1.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    tp_size : int, default = 1
             used as TP (tensor parallel) world size when TP groups are not formed during
             initialization. In this case, users must call the
             `set_tensor_parallel_group(tp_group)` method on the initialized module before the
             forward pass to supply the tensor parallel group needed for tensor and sequence
             parallel collectives.
    parallel_mode : {None, 'column', 'row'}, default = `None`
                   used to decide whether this Linear layer is Column Parallel Linear or Row
                   Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
                   When set to `None`, no communication is performed.
    skip_weight_param_allocation: bool, default = `False`
                                 if set to `True`, weight parameter is not allocated and must be
                                 passed as a keyword argument `weight` during the forward pass.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = `False`
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient.
    return_bias : bool, default = `False`
                 when set to `True`, this module will not apply the additive bias itself, but
                 instead return the bias value during the forward pass together with the
                 output of the linear transformation :math:`y = xA^T`. This is useful when
                 the bias addition can be fused to subsequent operations. A bias parameter
                 is allocated in this case even when `bias=False`.
    params_dtype : torch.dtype, default = `torch.float32`
                  it controls the type used to allocate the initial parameters. Useful when
                  the model is trained with lower precision and the original FP32 parameters
                  would not fit in GPU memory.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        eps: float = 1e-5,
        sequence_parallel: bool = False,
        fuse_wgrad_accumulation: bool = False,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        get_rng_state_tracker: Optional[Callable] = None,
        init_method: Optional[Callable] = None,
        bias: bool = True,
        return_bias: bool = False,
        params_dtype: torch.dtype = torch.float32,
        parallel_mode: Optional[str] = None,
        return_layernorm_output: bool = False,
        skip_weight_param_allocation: bool = False,
    ) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
        # `use_bias` means "this module adds the bias to its output". With
        # `return_bias` the caller adds the bias, so applying it here as well
        # would add it twice.
        self.use_bias = bias and not return_bias
        self.return_bias = return_bias
        self.return_layernorm_output = return_layernorm_output
        self.skip_weight_param_allocation = skip_weight_param_allocation

        if tp_group is None:
            self.tp_size = tp_size
            if tp_size == 1:
                self.set_tensor_parallel_group(tp_group)
        else:
            self.tp_size = get_distributed_world_size(tp_group)
            self.set_tensor_parallel_group(tp_group)
        self.set_nccl_overlap_warning_if_tp()

        self.parallel_mode = parallel_mode
        assert (
            self.parallel_mode in GemmParallelModes
        ), f"parallel_mode {parallel_mode} not supported"

        if self.parallel_mode == "column":
            self.out_features = divide(self.out_features, self.tp_size)
        elif self.parallel_mode == "row":
            self.in_features = divide(self.in_features, self.tp_size)

        if init_method is None:
            init_method = get_default_init_method()

        self.sequence_parallel = (self.tp_size > 1) and sequence_parallel

        self.eps = eps
        self.layer_norm_weight = Parameter(
            torch.empty(
                in_features,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            )
        )
        self.layer_norm_bias = Parameter(
            torch.empty(
                in_features,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            )
        )
        setattr(self.layer_norm_weight, "sequence_parallel", self.sequence_parallel)
        setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel)
        self.reset_layer_norm_parameters()

        # A bias parameter exists whenever it is either applied or returned.
        allocate_bias = bias or return_bias

        if not skip_weight_param_allocation:
            self.weight = Parameter(
                torch.empty(
                    self.out_features,
                    self.in_features,
                    device=torch.cuda.current_device(),
                    dtype=params_dtype,
                )
            )

            initialize_affine_weight_gpu(
                self.weight,
                init_method,
                get_rng_state_tracker,
                partition_dim=1 if self.parallel_mode == "row" else 0,
                stride=1,
            )

            if allocate_bias:
                self.bias = Parameter(
                    torch.empty(
                        self.out_features,
                        device=torch.cuda.current_device(),
                        dtype=params_dtype,
                    )
                )
                if self.parallel_mode == "column":
                    set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
            else:
                self.register_buffer("bias", torch.Tensor().type(params_dtype), persistent=False)

            with torch.no_grad():
                self.bias.zero_()
        elif not allocate_bias:
            # No bias argument is required in forward() for this configuration,
            # and forward() falls back to `self.bias`, so provide the same empty
            # placeholder used in the allocating path.
            self.register_buffer("bias", torch.Tensor().type(params_dtype), persistent=False)

        self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features)))

        # For RPL, bias has to be added after TP collectives
        # So it cannot be fused with the GEMM
        if self.parallel_mode == "row" and self.use_bias:
            self.gemm_bias_unfused_add = True
            self.use_bias = False
        else:
            self.gemm_bias_unfused_add = False

        # These many SMs are subtracted from the total SM count when calling forward
        # and backward LayerNorm C APIs. These envvars can be used to prevent the LN
        # kernels from using all SMs in the device. This is useful for cases such as
        # communication overlap with LN.
        self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
        self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))

    def reset_layer_norm_parameters(self) -> None:
        """Init LN params"""
        init.ones_(self.layer_norm_weight)
        init.zeros_(self.layer_norm_bias)

    def forward(
        self,
        inp: torch.Tensor,
        weight: Optional[torch.Tensor] = None,
        bias: Optional[torch.Tensor] = None,
        is_first_microbatch: Optional[bool] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        """
        Apply layer normalization to the input followed by a linear transformation.

        Parameters
        ----------
        inp : torch.Tensor
             Input tensor.
        weight : torch.Tensor, default = None
                An optional weight tensor for the module. This argument is compulsory if module
                is initialized with `skip_weight_param_allocation=True`
        bias : torch.Tensor, default = None
              An optional bias tensor for the module. This argument is compulsory if module
              is initialized with `skip_weight_param_allocation=True` and one of `bias`
              or `return_bias`
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)

        Returns
        -------
        The linear output, followed by the bias (cast to the activation dtype) when
        `return_bias` is set, followed by the layernorm output when
        `return_layernorm_output` is set; a bare tensor when neither is set.
        """

        with self.prepare_forward(inp, is_first_microbatch) as inp:
            bias_tensor = bias if bias is not None else self.bias

            # In eval mode the autograd Function is bypassed and its forward is
            # called directly with `ctx=None`.
            if self.training:
                fwd_fn = _LayerNormLinear.apply
                args = []
            else:
                fwd_fn = _LayerNormLinear.forward
                args = [None]
            args += (
                inp,
                self.layer_norm_weight,
                self.layer_norm_bias,
                weight if weight is not None else self.weight,
                self.weight1_fp8 if self.fp8 else None,
                self.weight1_t_fp8 if self.fp8 else None,
                bias_tensor,
                self.use_bias,
                self.eps,
                is_first_microbatch,
                self.fp8,
                self.fp8_calibration,
                self.fp8_meta,
                self.fuse_wgrad_accumulation,
                self.tp_group,
                self.sequence_parallel,
                self.tp_size > 1,
                self.activation_dtype,
                self.parallel_mode,
                self.return_layernorm_output,
                self.training,
                self.fwd_ln_sm_margin,
                self.bwd_ln_sm_margin,
            )
            out = fwd_fn(*args)

        if self.return_layernorm_output:
            out, ln_out = out

        # Row-parallel bias is added here, after the TP reduction in the GEMM path.
        if self.gemm_bias_unfused_add:
            out = out + cast_if_needed(bias_tensor, self.activation_dtype)

        if self.return_bias:
            if self.return_layernorm_output:
                return out, cast_if_needed(bias_tensor, self.activation_dtype), ln_out
            return out, cast_if_needed(bias_tensor, self.activation_dtype)
        if self.return_layernorm_output:
            return out, ln_out
        return out

class _Linear(torch.autograd.Function):
    """Linear semi-top level module.

    Autograd function implementing ``y = x A^T + b`` on top of the custom CUDA
    extensions (FP8 and non-FP8 GEMMs), together with the tensor-/sequence-
    parallel collectives needed by column- and row-parallel linear layers.
    """

    @staticmethod
    def forward(
        ctx,
        weight: torch.Tensor,
        weight_fp8: Union[torch.Tensor, None],
        weight_t_fp8: Union[torch.Tensor, None],
        inp: torch.Tensor,
        bias: torch.Tensor,
        use_bias: bool,
        is_first_microbatch: Union[bool, None],
        fp8: bool,
        fp8_calibration: bool,
        fp8_meta: Dict[str, Any],
        fuse_wgrad_accumulation: bool,
        tp_group: Union[dist_group_type, None],
        sequence_parallel: bool,
        tensor_parallel: bool,
        activation_dtype: torch.dtype,
        parallel_mode: Union[str, None],
        is_training: bool,
    ) -> torch.Tensor:
        """Run the forward GEMM (plus any tensor-parallel collectives).

        ``ctx`` is only touched when ``is_training`` is true. ``Linear.forward``
        calls this method directly with ``ctx=None`` in evaluation mode.

        Parameters
        ----------
        weight : torch.Tensor
            High-precision weight. ``weight.shape[-1]`` must equal
            ``inp.shape[-1]``.
        weight_fp8, weight_t_fp8 : torch.Tensor or None
            FP8 weight and FP8 transposed-weight buffers. During FP8 training,
            when the weights are being refreshed, they are passed to
            ``fp8_cast_transpose_fused`` as ``cast_out``/``transpose_out``.
            Otherwise the values they already hold are used.
        inp : torch.Tensor
            Input. It is flattened to ``[-1, in_features]`` for the GEMM.
        bias : torch.Tensor
            Bias. The GEMM adds it only when ``use_bias`` is true.
        is_first_microbatch : bool or None
            ``None`` or ``True`` re-casts the weights to FP8. ``False`` reuses
            the FP8 weights passed in.
        fp8_calibration : bool
            On the non-FP8 path, record the absolute maximum of the input and
            of the weight in ``fp8_meta["scaling_fwd"].amax_history[0]``.
        tp_group, sequence_parallel, tensor_parallel, parallel_mode
            Select the collectives. Column parallel with sequence parallelism
            gathers the input along the first dimension. Row parallel
            reduce-scatters (sequence parallel) or all-reduces (tensor
            parallel) the output.

        Returns
        -------
        torch.Tensor
            Output viewed as ``[-1, *inp.shape[1:-1], out_features]``. The
            first dimension may differ from ``inp``'s under sequence
            parallelism.
        """
        # Make sure input dimensions are compatible
        in_features = weight.shape[-1]
        assert inp.shape[-1] == in_features, "GEMM not possible"
        inputmat = inp.view((-1, in_features))

        update_fp8_weights = is_first_microbatch is None or is_first_microbatch

        # Cast for native AMP
        inputmat = cast_if_needed(inputmat, activation_dtype)
        # Keep the high-precision input: the non-FP8 wgrad path needs it in backward.
        inputmat_no_fp8 = inputmat
        # Transposed FP8 input; only produced on the FP8-wgrad training path.
        inputmat_t = None

        if fp8:
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)

            if not fp8_meta["recipe"].override_linear_precision.wgrad:
                if is_training:
                    # The transpose is only consumed by the FP8 wgrad GEMM in backward.
                    inputmat, inputmat_t = fp8_cast_transpose_fused(
                        inputmat,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM1_INPUT,
                        fp8_dtype_forward,
                    )
                else:
                    inputmat = cast_to_fp8(
                        inputmat,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM1_INPUT,
                        fp8_dtype_forward,
                    )
            else:
                inputmat, inputmat_t = cast_to_fp8(
                    inputmat,
                    fp8_meta["scaling_fwd"],
                    tex.FP8FwdTensors.GEMM1_INPUT,
                    fp8_dtype_forward,
                ), None

        # Column Parallel Linear
        if parallel_mode == "column" and sequence_parallel:
            inputmat_total, _ = gather_along_first_dim(inputmat, tp_group)
        else:
            inputmat_total = inputmat

        if fp8:
            # An FP32 activation dtype gets a bf16 bias on the FP8 path
            # (presumably a constraint of the FP8 GEMM kernel -- TODO confirm).
            bias_dtype = (
                torch.bfloat16
                if activation_dtype == torch.float32
                else activation_dtype
            )
            bias = cast_if_needed(bias, bias_dtype) if use_bias else bias

            if update_fp8_weights:
                if is_training:
                    fp8_cast_transpose_fused(
                        weight,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM1_WEIGHT,
                        fp8_dtype_forward,
                        cast_out=weight_fp8,
                        transpose_out=weight_t_fp8,
                    )
                else:
                    # The transposed weight is only needed by the backward dgrad GEMM.
                    weight_t_fp8 = None
                    weight_fp8 = cast_to_fp8(
                        weight,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM1_WEIGHT,
                        fp8_dtype_forward,
                    )

            out = fp8_gemm(
                weight_fp8,
                fp8_meta["scaling_fwd"].scale_inv,
                tex.FP8FwdTensors.GEMM1_WEIGHT,
                fp8_dtype_forward,
                # Use the (possibly sequence-parallel gathered) input, exactly
                # like the non-FP8 branch below.
                inputmat_total,
                fp8_meta["scaling_fwd"].scale_inv,
                tex.FP8FwdTensors.GEMM1_INPUT,
                fp8_dtype_forward,
                activation_dtype,
                get_workspace(),
                bias=bias,
                use_bias=use_bias,
                use_split_accumulator=_2X_ACC_FPROP,
            )
        else:
            # Cast for native AMP
            weight = cast_if_needed(weight, activation_dtype)
            bias = cast_if_needed(bias, activation_dtype) if use_bias else bias

            if fp8_calibration:
                # FP8 scaling needs the maximum *absolute* value. A plain
                # torch.amax would ignore large negative entries and
                # under-estimate the required range.
                # amax of input
                fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = \
                    torch.amax(torch.abs(inputmat_total)).float()
                # amax of weight
                fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = \
                    torch.amax(torch.abs(weight)).float()

            out, _, _ = gemm(
                weight,
                inputmat_total,
                activation_dtype,
                get_workspace(),
                bias=bias,
                use_bias=use_bias,
            )

        if is_training:
            fp8_wgrad = fp8 and not fp8_meta["recipe"].override_linear_precision.wgrad
            ctx.save_for_backward(
                inputmat_no_fp8 if weight.requires_grad and not fp8_wgrad else None,
                inputmat_t if weight.requires_grad and fp8_wgrad else None,
                weight,
                weight_t_fp8 if fp8 else None,
                # Clone so later scale updates do not alter what backward sees.
                fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
            )
            ctx.activation_dtype = activation_dtype
            ctx.fp8 = fp8
            ctx.fp8_meta = fp8_meta
            ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
            ctx.is_first_microbatch = is_first_microbatch
            ctx.use_bias = use_bias
            ctx.sequence_parallel = sequence_parallel
            ctx.tensor_parallel = tensor_parallel
            ctx.inp_shape = inp.shape
            ctx.parallel_mode = parallel_mode
            ctx.tp_group = tp_group
            ctx.requires_wgrad = weight.requires_grad

        # Row Parallel Linear
        if parallel_mode == "row" and sequence_parallel:
            out, _ = reduce_scatter_along_first_dim(out, tp_group)
        elif parallel_mode == "row" and tensor_parallel:
            out, _ = allreduce(out, tp_group)

        # [*, in_features] -> [*, out_features] except first dimension changes for SP
        return out.view(-1, *inp.shape[1:-1], out.shape[-1])

    @staticmethod
    def backward(
        ctx, grad_output: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Compute the gradients with respect to ``weight``, ``inp`` and ``bias``.

        Returns one entry per ``forward`` argument after ``ctx``. Every
        argument other than ``weight``, ``inp`` and ``bias`` receives ``None``.
        """
        with _prepare_backward(ctx.fp8, ctx.fp8_meta, ctx.sequence_parallel, ctx.tp_group,
                               name="_Linear"):
            (
                inputmat,
                inputmat_t,
                weight,
                weight_t_fp8,
                fwd_scale_inverses,
            ) = ctx.saved_tensors

            (
                grad_output,
                grad_output_c,
                grad_output_t,
                grad_bias,
            ) = TransformerEngineBaseModule.grad_output_preprocess(
                ctx, grad_output, ctx.parallel_mode == "row"
            )

            # Handle of the pending async collective, if any. Initialised so the
            # final wait never depends on which branch happened to run.
            handle = None

            # Column Parallel Linear
            # Overlap input AG with dgrad
            if ctx.parallel_mode == "column" and ctx.sequence_parallel:
                if ctx.fp8 and not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
                    inputmat_t_total, handle = gather_along_last_dim(
                        inputmat_t, ctx.tp_group, async_op=True
                    )
                else:
                    inputmat_total, handle = gather_along_first_dim(
                        inputmat, ctx.tp_group, async_op=True
                    )
            else:
                inputmat_t_total = inputmat_t
                inputmat_total = inputmat

            if ctx.is_first_microbatch is not None:
                # The first microbatch overwrites main_grad instead of accumulating.
                accumulate_wgrad_into_param_main_grad = (
                    ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
                )
            else:
                accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation

            if ctx.fp8:
                fp8_dtype_forward = get_fp8_te_dtype(
                    ctx.fp8_meta["recipe"], fprop_tensor=True
                )
                fp8_dtype_backward = get_fp8_te_dtype(
                    ctx.fp8_meta["recipe"], fprop_tensor=False
                )

                # DGRAD
                dgrad = fp8_gemm(
                    weight_t_fp8,
                    fwd_scale_inverses,
                    tex.FP8FwdTensors.GEMM1_WEIGHT,
                    fp8_dtype_forward,
                    grad_output_c,
                    ctx.fp8_meta["scaling_bwd"].scale_inv,
                    tex.FP8BwdTensors.GRAD_OUTPUT1,
                    fp8_dtype_backward,
                    ctx.activation_dtype,
                    get_workspace(),
                    use_split_accumulator=_2X_ACC_DGRAD,
                )
            else:
                # DGRAD
                dgrad, _, _ = gemm(
                    weight,
                    grad_output,
                    ctx.activation_dtype,
                    get_workspace(),
                    layout="NN",
                    grad=True,
                )

            # Overlap dgrad-RS/AR with wgrad
            if ctx.parallel_mode == "column" and ctx.sequence_parallel:
                # The wgrad GEMM below needs the gathered input.
                handle.wait()
                dgrad, handle = reduce_scatter_along_first_dim(
                    dgrad, ctx.tp_group, async_op=True
                )
            elif ctx.parallel_mode == "column" and ctx.tensor_parallel:
                dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True)

            if ctx.requires_wgrad:
                if ctx.fp8:
                    # WGRAD
                    if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
                        wgrad = fp8_gemm(
                            inputmat_t_total,
                            fwd_scale_inverses,
                            tex.FP8FwdTensors.GEMM1_INPUT,
                            fp8_dtype_forward,
                            grad_output_t,
                            ctx.fp8_meta["scaling_bwd"].scale_inv,
                            tex.FP8BwdTensors.GRAD_OUTPUT1,
                            fp8_dtype_backward,
                            ctx.activation_dtype,
                            get_workspace(),
                            accumulate=accumulate_wgrad_into_param_main_grad,
                            fp32_output=ctx.fuse_wgrad_accumulation,
                            out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                            use_split_accumulator=_2X_ACC_WGRAD,
                        )
                    else:
                        wgrad, _, _ = gemm(
                            inputmat_total,
                            grad_output,
                            ctx.activation_dtype,
                            get_workspace(),
                            layout="NT",
                            grad=True,
                            accumulate=accumulate_wgrad_into_param_main_grad,
                            fp32_output=ctx.fuse_wgrad_accumulation,
                            out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                        )
                else:
                    # WGRAD (this GEMM also produces grad_bias)
                    wgrad, grad_bias, _ = gemm(
                        inputmat_total,
                        grad_output,
                        ctx.activation_dtype,
                        get_workspace(),
                        layout="NT",
                        grad=True,
                        use_bias=ctx.use_bias,
                        accumulate=accumulate_wgrad_into_param_main_grad,
                        fp32_output=ctx.fuse_wgrad_accumulation,
                        out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                    )

            # Column Parallel Linear
            if ctx.parallel_mode == "column" and ctx.tensor_parallel and handle is not None:
                handle.wait()

            if not ctx.use_bias:
                grad_bias = None

        return (
            wgrad if ctx.requires_wgrad else None,  # weight
            None,  # weight_fp8
            None,  # weight_t_fp8
            dgrad.view(ctx.inp_shape),  # inp
            grad_bias,  # bias
            None,  # use_bias
            None,  # is_first_microbatch
            None,  # fp8
            None,  # fp8_calibration
            None,  # fp8_meta
            None,  # fuse_wgrad_accumulation
            None,  # tp_group
            None,  # sequence_parallel
            None,  # tensor_parallel
            None,  # activation_dtype
            None,  # parallel_mode
            None,  # is_training
        )


class Linear(TransformerEngineBaseModule):
    """
    Applies a linear transformation to the incoming data :math:`y = xA^T + b`

    On NVIDIA GPUs it is a drop-in replacement for `torch.nn.Linear`.

    Parameters
    ----------
    in_features : int
                 size of each input sample.
    out_features : int
                  size of each output sample.
    bias : bool, default = `True`
          if set to `False`, the layer will not learn an additive bias.
    init_method : Callable, default = `None`
                 used for initializing weights in the following way: `init_method(weight)`.
                 When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.

    Parallelism parameters
    ----------------------
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    tp_size : int, default = 1
             used as TP (tensor parallel) world size when TP groups are not formed during
             initialization. In this case, users must call the
             `set_tensor_parallel_group(tp_group)` method on the initialized module before the
             forward pass to supply the tensor parallel group needed for tensor and sequence
             parallel collectives.
    parallel_mode : {None, 'column', 'row'}, default = `None`
                   used to decide whether this Linear layer is Column Parallel Linear or Row
                   Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
                   When set to `None`, no communication is performed.
    skip_weight_param_allocation: bool, default = `False`
                                 if set to `True`, the weight and bias parameters are not
                                 allocated. The weight must then be passed as a keyword argument
                                 `weight` during the forward pass, and so must `bias` if
                                 `bias` or `return_bias` is set.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = `False`
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient. When enabled, it is assumed that the weights
                             have an additional `main_grad` attribute (used instead of the
                             regular `grad`) which is a pre-allocated buffer of the correct
                             size to accumulate gradients in.
    return_bias : bool, default = `False`
                 when set to `True`, this module will not apply the additive bias itself, but
                 instead return the bias value during the forward pass together with the
                 output of the linear transformation :math:`y = xA^T`. This is useful when
                 the bias addition can be fused to subsequent operations.
    params_dtype : torch.dtype, default = `torch.float32`
                  it controls the type used to allocate the initial parameters. Useful when
                  the model is trained with lower precision and the original FP32 parameters
                  would not fit in GPU memory.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        sequence_parallel: bool = False,
        fuse_wgrad_accumulation: bool = False,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        get_rng_state_tracker: Optional[Callable] = None,
        init_method: Optional[Callable] = None,
        bias: bool = True,
        return_bias: bool = False,
        params_dtype: torch.dtype = torch.float32,
        parallel_mode: Optional[str] = None,
        skip_weight_param_allocation: bool = False,
    ) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
        self.use_bias = bias
        self.return_bias = return_bias
        self.skip_weight_param_allocation = skip_weight_param_allocation

        if tp_group is None:
            self.tp_size = tp_size
            if tp_size == 1:
                self.set_tensor_parallel_group(tp_group)
        else:
            self.tp_size = get_distributed_world_size(tp_group)
            self.set_tensor_parallel_group(tp_group)
        self.set_nccl_overlap_warning_if_tp()

        self.parallel_mode = parallel_mode
        assert (
            self.parallel_mode in GemmParallelModes
        ), f"parallel_mode {parallel_mode} not supported"

        # Each rank holds a shard: output dim for column, input dim for row parallel.
        if self.parallel_mode == "column":
            self.out_features = divide(self.out_features, self.tp_size)
        elif self.parallel_mode == "row":
            self.in_features = divide(self.in_features, self.tp_size)

        if init_method is None:
            init_method = get_default_init_method()

        self.sequence_parallel = (self.tp_size > 1) and sequence_parallel

        if not skip_weight_param_allocation:
            self.weight = Parameter(
                torch.empty(
                    self.out_features,
                    self.in_features,
                    device=torch.cuda.current_device(),
                    dtype=params_dtype,
                )
            )

            initialize_affine_weight_gpu(
                self.weight,
                init_method,
                get_rng_state_tracker,
                partition_dim=1 if self.parallel_mode == "row" else 0,
                stride=1,
            )

            if self.use_bias or self.return_bias:
                self.bias = Parameter(
                    torch.empty(
                        self.out_features,
                        device=torch.cuda.current_device(),
                        dtype=params_dtype,
                    )
                )
                if self.parallel_mode == "column":
                    set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
            else:
                self.register_buffer("bias", torch.Tensor().type(params_dtype), persistent=False)

            with torch.no_grad():
                self.bias.zero_()
        elif not (self.use_bias or self.return_bias):
            # No bias will be supplied by the caller. Register the same empty,
            # non-persistent placeholder used above so that `forward` can fall
            # back to `self.bias` instead of raising AttributeError.
            self.register_buffer("bias", torch.Tensor().type(params_dtype), persistent=False)

        self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features)))

        # For RPL, bias has to be added after TP collectives
        # So it cannot be fused with the GEMM
        if self.parallel_mode == "row" and self.use_bias:
            self.gemm_bias_unfused_add = True
            self.use_bias = False
        else:
            self.gemm_bias_unfused_add = False

    def forward(
        self,
        inp: torch.Tensor,
        weight: Optional[torch.Tensor] = None,
        bias: Optional[torch.Tensor] = None,
        is_first_microbatch: Optional[bool] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        """
        Apply the linear transformation to the input.

        Parameters
        ----------
        inp : torch.Tensor
             Input tensor.
        weight : torch.Tensor, default = None
                An optional weight tensor for the module. This argument is compulsory if module
                is initialized with `skip_weight_param_allocation=True`
        bias : torch.Tensor, default = None
              An optional bias tensor for the module. This argument is compulsory if module
              is initialized with `skip_weight_param_allocation=True` and one of `use_bias`
              or `return_bias`
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)

        Returns
        -------
        The output tensor, or ``(output, bias)`` when `return_bias` is set.
        """

        with self.prepare_forward(inp, is_first_microbatch) as inp:
            bias_tensor = bias if bias is not None else self.bias

            # Training goes through autograd. Inference calls the raw forward
            # with ctx=None, which _Linear only touches when is_training is set.
            if self.training:
                linear_fn = _Linear.apply
                args = []
            else:
                linear_fn = _Linear.forward
                args = [None]
            args += (
                weight if weight is not None else self.weight,
                self.weight1_fp8 if self.fp8 else None,
                self.weight1_t_fp8 if self.fp8 else None,
                inp,
                bias_tensor,
                self.use_bias,
                is_first_microbatch,
                self.fp8,
                self.fp8_calibration,
                self.fp8_meta,
                self.fuse_wgrad_accumulation,
                self.tp_group,
                self.sequence_parallel,
                self.tp_size > 1,
                self.activation_dtype,
                self.parallel_mode,
                self.training,
            )
            out = linear_fn(*args)

        # Row-parallel bias is added here, after the output collective.
        if self.gemm_bias_unfused_add:
            out = out + cast_if_needed(bias_tensor, self.activation_dtype)

        if self.return_bias:
            return out, cast_if_needed(bias_tensor, self.activation_dtype)
        return out


class _LayerNormMLP(torch.autograd.Function):
    """LayerNormMLP semi-top level module
    Calls custom cuda extensions.

    ``forward`` computes ``LayerNorm -> FC1 (+bias) -> GeLU -> FC2 (+optional bias)``
    with optional FP8 casting, optional FP8 amax calibration and optional
    column/row-parallel communication. ``backward`` returns one gradient slot
    per ``forward`` argument (28 in total); non-differentiable arguments
    receive ``None``.
    """

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        ln_weight: torch.Tensor,
        ln_bias: torch.Tensor,
        fc1_weight: torch.Tensor,
        fc1_weight_fp8: Union[torch.Tensor, None],
        fc1_weight_t_fp8: Union[torch.Tensor, None],
        fc1_bias: torch.Tensor,
        fc2_weight: torch.Tensor,
        fc2_weight_fp8: Union[torch.Tensor, None],
        fc2_weight_t_fp8: Union[torch.Tensor, None],
        fc2_bias: torch.Tensor,
        use_bias: bool,
        eps: float,
        is_first_microbatch: Union[bool, None],
        fp8: bool,
        fp8_calibration: bool,
        fp8_meta: Dict[str, Any],
        fuse_wgrad_accumulation: bool,
        tp_group: Union[dist_group_type, None],
        sequence_parallel: bool,
        tensor_parallel: bool,
        activation_dtype: torch.dtype,
        return_layernorm_output: bool,
        bias_gelu_nvfusion: bool,
        set_parallel_mode: bool,
        is_training: bool,
        fwd_ln_sm_margin: int,
        bwd_ln_sm_margin: int,
    ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
        """Run LayerNorm followed by FC1 -> GeLU -> FC2.

        Selected arguments:
            inp: input whose last dimension must equal ``ln_weight.numel()``;
                it is flattened to 2D before the GEMMs.
            fc1_weight_fp8 / fc1_weight_t_fp8 / fc2_weight_fp8 / fc2_weight_t_fp8:
                FP8 weight buffers. In training, when the weights are being
                refreshed, the fused cast-transpose writes into them; otherwise
                the passed-in buffers are used as-is (cached FP8 weights).
            use_bias: whether FC2 adds ``fc2_bias`` (FC1 bias is always used).
            is_first_microbatch: ``None`` or truthy refreshes the FP8 weight
                casts; ``False`` reuses the provided FP8 buffers.
            fp8_calibration: in the non-FP8 path, record amax values of GEMM
                inputs/weights into ``fp8_meta["scaling_fwd"].amax_history[0]``.
            bias_gelu_nvfusion: apply FC1 bias + GeLU with ``bias_gelu_fused``
                instead of inside the FC1 GEMM (non-FP8 path only).
            set_parallel_mode: enable the gather (column parallel) before FC1
                and the reduce-scatter/all-reduce (row parallel) after FC2.
            is_training: when False, ``ctx`` is not touched (it may be None)
                and inference-only kernels are used.

        Returns:
            ``fc2_out``, or ``(fc2_out, ln_out_return.view_as(inp))`` when
            ``return_layernorm_output`` is set.
        """
        # Make sure input dimensions are compatible
        in_features = ln_weight.numel()
        assert inp.shape[-1] == in_features, "GEMM not possible"
        inputmat = inp.view((-1, in_features))

        # FP8 weight casts are recomputed unless this is a later microbatch
        # (is_first_microbatch is False) that can reuse the cached casts.
        update_fp8_weights = is_first_microbatch is None or is_first_microbatch

        # Cast for native AMP
        inputmat = cast_if_needed(inputmat, activation_dtype)
        ln_weight = cast_if_needed(ln_weight, activation_dtype)
        ln_bias = cast_if_needed(ln_bias, activation_dtype)

        # If residual connection is after LN, we need `ln_out`
        # tensor in higher precision, this comes at the cost
        # of an extra fp8 cast.
        if fp8:
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
            if not return_layernorm_output:
                if is_training:
                    ln_out, mu, rsigma = layernorm_fwd_fp8(
                        inputmat,
                        ln_weight,
                        ln_bias,
                        eps,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM1_INPUT,
                        fp8_dtype_forward,
                        fwd_ln_sm_margin,
                    )
                else:
                    ln_out = layernorm_fwd_fp8_inf(
                        inputmat,
                        ln_weight,
                        ln_bias,
                        eps,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM1_INPUT,
                        fp8_dtype_forward,
                    )
            else:
                # Keep the high-precision LN output for the caller and cast a
                # separate FP8 copy for the FC1 GEMM.
                ln_out_return, mu, rsigma = tex.layernorm_fwd(
                    inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin
                )
                ln_out = cast_to_fp8(
                    ln_out_return,
                    fp8_meta["scaling_fwd"],
                    tex.FP8FwdTensors.GEMM1_INPUT,
                    fp8_dtype_forward,
                )
        else:
            if is_training:
                ln_out, mu, rsigma = tex.layernorm_fwd(
                    inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin
                )
            else:
                # Inference LN does not produce the statistics needed by backward.
                ln_out = layernorm_fwd_inf(inputmat, ln_weight, ln_bias, eps)
                mu = None
                rsigma = None

            ln_out_return = ln_out

        # Column Parallel Linear
        if set_parallel_mode and sequence_parallel:
            ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
        else:
            ln_out_total = ln_out

        if fp8:
            # FP32 activations use BF16 biases in the FP8 GEMMs.
            bias_dtype = (
                torch.bfloat16
                if activation_dtype == torch.float32
                else activation_dtype
            )
            fc1_bias = cast_if_needed(fc1_bias, bias_dtype)
            fc2_bias = cast_if_needed(fc2_bias, bias_dtype) if use_bias else fc2_bias

            if update_fp8_weights:
                if is_training:
                    # Fill the caller-provided FP8 buffers (cast + transpose);
                    # the transposes are needed by the backward DGRAD GEMMs.
                    fp8_cast_transpose_fused(
                        fc1_weight,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM1_WEIGHT,
                        fp8_dtype_forward,
                        cast_out=fc1_weight_fp8,
                        transpose_out=fc1_weight_t_fp8,
                    )

                    fp8_cast_transpose_fused(
                        fc2_weight,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM2_WEIGHT,
                        fp8_dtype_forward,
                        cast_out=fc2_weight_fp8,
                        transpose_out=fc2_weight_t_fp8,
                    )
                else:
                    # Inference: no backward, so no transposes are needed.
                    fc1_weight_t_fp8 = None
                    fc1_weight_fp8 = cast_to_fp8(
                        fc1_weight,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM1_WEIGHT,
                        fp8_dtype_forward,
                    )
                    fc2_weight_t_fp8 = None
                    fc2_weight_fp8 = cast_to_fp8(
                        fc2_weight,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM2_WEIGHT,
                        fp8_dtype_forward,
                    )

            fc1_out = fp8_gemm(
                fc1_weight_fp8,
                fp8_meta["scaling_fwd"].scale_inv,
                tex.FP8FwdTensors.GEMM1_WEIGHT,
                fp8_dtype_forward,
                ln_out_total,
                fp8_meta["scaling_fwd"].scale_inv,
                tex.FP8FwdTensors.GEMM1_INPUT,
                fp8_dtype_forward,
                activation_dtype,
                get_workspace(),
                bias=fc1_bias,
                use_bias=True,
                use_split_accumulator=_2X_ACC_FPROP,
            )

            gelu_out = fp8_gelu(
                fc1_out,
                fp8_meta["scaling_fwd"],
                tex.FP8FwdTensors.GEMM2_INPUT,
                fp8_dtype_forward,
            )

            fc2_out = fp8_gemm(
                fc2_weight_fp8,
                fp8_meta["scaling_fwd"].scale_inv,
                tex.FP8FwdTensors.GEMM2_WEIGHT,
                fp8_dtype_forward,
                gelu_out,
                fp8_meta["scaling_fwd"].scale_inv,
                tex.FP8FwdTensors.GEMM2_INPUT,
                fp8_dtype_forward,
                activation_dtype,
                get_workspace(),
                bias=fc2_bias,
                use_bias=use_bias,
                use_split_accumulator=_2X_ACC_FPROP,
            )
        else:
            # Cast for native AMP
            fc1_weight = cast_if_needed(fc1_weight, activation_dtype)
            fc2_weight = cast_if_needed(fc2_weight, activation_dtype)
            fc1_bias = cast_if_needed(fc1_bias, activation_dtype)
            fc2_bias = (
                cast_if_needed(fc2_bias, activation_dtype) if use_bias else fc2_bias
            )

            if fp8_calibration:
                # amax of fc1 input
                fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = \
                    torch.amax(ln_out_total).float()
                # amax of fc1 weight
                fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = \
                    torch.amax(fc1_weight).float()

            # With bias_gelu_nvfusion the GEMM skips bias and GeLU; they are
            # applied below by bias_gelu_fused instead.
            fc1_outputs = gemm(
                fc1_weight,
                ln_out_total,
                activation_dtype,
                get_workspace(),
                bias=fc1_bias,
                use_bias=not bias_gelu_nvfusion,
                gelu=not bias_gelu_nvfusion,
            )

            # Must match the GEMM flags above regardless of is_training:
            # otherwise, in inference, the raw (bias-less, pre-GeLU) matmul
            # output would be treated as gelu_out.
            if bias_gelu_nvfusion:
                fc1_out, _, _ = fc1_outputs
                gelu_out = bias_gelu_fused(fc1_out, fc1_bias)
            else:
                gelu_out, _, fc1_out = fc1_outputs

            if fp8_calibration:
                # amax of fc2 input
                fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM2_INPUT] = \
                    torch.amax(gelu_out).float()
                # amax of fc2 weight
                fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM2_WEIGHT] = \
                    torch.amax(fc2_weight).float()

            fc2_out, _, _ = gemm(
                fc2_weight,
                gelu_out,
                activation_dtype,
                get_workspace(),
                bias=fc2_bias,
                use_bias=use_bias,
            )

        if is_training:
            # The forward scale_inv is cloned (snapshot), presumably so backward
            # sees the same scales used here even if they are updated later.
            ctx.save_for_backward(
                inputmat,
                ln_weight,
                mu,
                rsigma,
                ln_out,
                fc1_out,
                gelu_out,
                fc1_weight,
                fc1_weight_t_fp8,
                fc2_weight,
                fc2_weight_t_fp8,
                fc1_bias,
                fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
            )
            ctx.activation_dtype = activation_dtype
            ctx.fp8 = fp8
            ctx.fp8_meta = fp8_meta
            ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
            ctx.is_first_microbatch = is_first_microbatch
            ctx.use_bias = use_bias
            ctx.sequence_parallel = sequence_parallel
            ctx.tensor_parallel = tensor_parallel
            ctx.inp_shape = inp.shape
            ctx.tp_group = tp_group
            ctx.bias_gelu_nvfusion = bias_gelu_nvfusion
            ctx.return_layernorm_output = return_layernorm_output
            ctx.set_parallel_mode = set_parallel_mode
            ctx.bwd_ln_sm_margin = bwd_ln_sm_margin

        # Row Parallel Linear
        if set_parallel_mode and sequence_parallel:
            fc2_out, _ = reduce_scatter_along_first_dim(fc2_out, tp_group)
        elif set_parallel_mode and tensor_parallel:
            fc2_out, _ = allreduce(fc2_out, tp_group)

        # [*, in_features] -> [*, out_features] except first dimension changes for SP
        fc2_out = fc2_out.view(-1, *inp.shape[1:-1], fc2_out.shape[-1])

        if return_layernorm_output:
            return fc2_out, ln_out_return.view_as(inp)
        return fc2_out

    @staticmethod
    def backward(
        ctx, *grad_outputs: Tuple[torch.Tensor, ...]
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Compute gradients for every ``forward`` input, in argument order.

        ``grad_outputs[0]`` is the gradient of ``fc2_out``; when
        ``return_layernorm_output`` was set, ``grad_outputs[1]`` is the
        gradient of the returned LayerNorm output and is added to the
        LayerNorm input gradient.
        """
        with _prepare_backward(ctx.fp8, ctx.fp8_meta, ctx.sequence_parallel, ctx.tp_group,
                               name="_LayerNormMLP"):
            (
                inputmat,
                ln_weight,
                mu,
                rsigma,
                ln_out,
                fc1_out,
                gelu_out,
                fc1_weight,
                fc1_weight_t_fp8,
                fc2_weight,
                fc2_weight_t_fp8,
                fc1_bias,
                fwd_scale_inverses,
            ) = ctx.saved_tensors

            (
                grad_output,
                grad_output_c,
                grad_output_t,
                fc2_bias_grad,
            ) = TransformerEngineBaseModule.grad_output_preprocess(
                ctx, grad_outputs[0], True
            )

            # Pending async communication handle (None when nothing is in flight).
            handle = None

            # Column Parallel Linear
            # Overlap input AG with dgrad
            if ctx.set_parallel_mode and ctx.sequence_parallel:
                ln_out_total, handle = gather_along_first_dim(
                    ln_out, ctx.tp_group, async_op=True
                )
            else:
                ln_out_total = ln_out

            if ctx.is_first_microbatch is not None:
                accumulate_wgrad_into_param_main_grad = (
                    ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
                )
            else:
                accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation

            if ctx.fp8:
                fp8_dtype_forward = get_fp8_te_dtype(
                    ctx.fp8_meta["recipe"], fprop_tensor=True
                )
                fp8_dtype_backward = get_fp8_te_dtype(
                    ctx.fp8_meta["recipe"], fprop_tensor=False
                )

                # FC2 DGRAD; Unconditional
                fc2_dgrad = fp8_gemm(
                    fc2_weight_t_fp8,
                    fwd_scale_inverses,
                    tex.FP8FwdTensors.GEMM2_WEIGHT,
                    fp8_dtype_forward,
                    grad_output_c,
                    ctx.fp8_meta["scaling_bwd"].scale_inv,
                    tex.FP8BwdTensors.GRAD_OUTPUT1,
                    fp8_dtype_backward,
                    ctx.activation_dtype,
                    get_workspace(),
                    use_split_accumulator=_2X_ACC_DGRAD,
                )

                # FC2 WGRAD
                if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
                    if fc2_weight.requires_grad:
                        gelu_out_t = tex.fp8_transpose(gelu_out, fp8_dtype_forward)
                        fc2_wgrad = fp8_gemm(
                            gelu_out_t,
                            fwd_scale_inverses,
                            tex.FP8FwdTensors.GEMM2_INPUT,
                            fp8_dtype_forward,
                            grad_output_t,
                            ctx.fp8_meta["scaling_bwd"].scale_inv,
                            tex.FP8BwdTensors.GRAD_OUTPUT1,
                            fp8_dtype_backward,
                            ctx.activation_dtype,
                            get_workspace(),
                            accumulate=accumulate_wgrad_into_param_main_grad,
                            fp32_output=ctx.fuse_wgrad_accumulation,
                            out=fc2_weight.main_grad
                            if ctx.fuse_wgrad_accumulation
                            else None,
                            use_split_accumulator=_2X_ACC_WGRAD,
                        )

                    fc1_bias_grad, dgelu, dgelu_t = fp8_cast_transpose_bgrad_dgelu_fused(
                        fc2_dgrad,
                        fc1_out,
                        ctx.fp8_meta["scaling_bwd"],
                        tex.FP8BwdTensors.GRAD_OUTPUT2,
                        fp8_dtype_backward,
                    )
                else:
                    # WGRAD precision override: run FC2 WGRAD in activation dtype.
                    if fc2_weight.requires_grad:
                        gelu_out_c = cast_from_fp8(
                            gelu_out,
                            ctx.fp8_meta["scaling_fwd"],
                            tex.FP8FwdTensors.GEMM2_INPUT,
                            fp8_dtype_forward,
                            TE_DType[ctx.activation_dtype],
                        )
                        fc2_wgrad, _, _ = gemm(
                            gelu_out_c,
                            grad_output,
                            ctx.activation_dtype,
                            get_workspace(),
                            layout="NT",
                            grad=True,
                            use_bias=ctx.use_bias,
                            accumulate=accumulate_wgrad_into_param_main_grad,
                            fp32_output=ctx.fuse_wgrad_accumulation,
                            out=fc2_weight.main_grad
                            if ctx.fuse_wgrad_accumulation
                            else None,
                        )

                    fc1_bias_grad, dgelu_no_fp8 = bgrad_dgelu_fused(
                        fc2_dgrad, fc1_out, fc1_bias
                    )

                    # FP8 copy for FC1 DGRAD; the high-precision copy is kept
                    # for FC1 WGRAD below.
                    dgelu = cast_to_fp8(
                        dgelu_no_fp8,
                        ctx.fp8_meta["scaling_bwd"],
                        tex.FP8BwdTensors.GRAD_OUTPUT2,
                        fp8_dtype_backward,
                    )
                    dgelu_t = None

                # FC1 DGRAD: Unconditional
                fc1_dgrad = fp8_gemm(
                    fc1_weight_t_fp8,
                    fwd_scale_inverses,
                    tex.FP8FwdTensors.GEMM1_WEIGHT,
                    fp8_dtype_forward,
                    dgelu,
                    ctx.fp8_meta["scaling_bwd"].scale_inv,
                    tex.FP8BwdTensors.GRAD_OUTPUT2,
                    fp8_dtype_backward,
                    ctx.activation_dtype,
                    get_workspace(),
                    use_split_accumulator=_2X_ACC_DGRAD,
                )
            else:
                # FC2 DGRAD; Unconditional
                # Without nvfusion the GEMM also applies the GeLU backward
                # using the saved pre-activation fc1_out.
                fc2_dgrad, _, _ = gemm(
                    fc2_weight,
                    grad_output,
                    ctx.activation_dtype,
                    get_workspace(),
                    layout="NN",
                    gelu=not ctx.bias_gelu_nvfusion,
                    grad=True,
                    gelu_input=fc1_out,
                )

                # FC2 WGRAD
                if fc2_weight.requires_grad:
                    fc2_wgrad, fc2_bias_grad, _ = gemm(
                        gelu_out,
                        grad_output,
                        ctx.activation_dtype,
                        get_workspace(),
                        layout="NT",
                        grad=True,
                        use_bias=ctx.use_bias,
                        accumulate=accumulate_wgrad_into_param_main_grad,
                        fp32_output=ctx.fuse_wgrad_accumulation,
                        out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                    )

                if ctx.bias_gelu_nvfusion:
                    fc1_bias_grad, dgelu = bgrad_dgelu_fused(fc2_dgrad, fc1_out, fc1_bias)
                else:
                    dgelu = fc2_dgrad

                # FC1 DGRAD: Unconditional
                fc1_dgrad, _, _ = gemm(
                    fc1_weight,
                    dgelu,
                    ctx.activation_dtype,
                    get_workspace(),
                    layout="NN",
                    grad=True,
                )

            # Overlap dgrad-RS/AR with wgrad
            if ctx.set_parallel_mode and ctx.sequence_parallel:
                # The input all-gather must finish before ln_out_total is used
                # by WGRAD. The handle may be None (e.g. single-rank group —
                # assumption about gather_along_first_dim, hence the guard).
                if handle is not None:
                    handle.wait()
                fc1_dgrad, handle = reduce_scatter_along_first_dim(
                    fc1_dgrad, ctx.tp_group, async_op=True
                )
            elif ctx.set_parallel_mode and ctx.tensor_parallel:
                fc1_dgrad, handle = allreduce(fc1_dgrad, ctx.tp_group, async_op=True)

            if fc1_weight.requires_grad:
                if ctx.fp8:
                    # FC1 WGRAD
                    if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
                        ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward)
                        fc1_wgrad = fp8_gemm(
                            ln_out_total_t,
                            fwd_scale_inverses,
                            tex.FP8FwdTensors.GEMM1_INPUT,
                            fp8_dtype_forward,
                            dgelu_t,
                            ctx.fp8_meta["scaling_bwd"].scale_inv,
                            tex.FP8BwdTensors.GRAD_OUTPUT2,
                            fp8_dtype_backward,
                            ctx.activation_dtype,
                            get_workspace(),
                            accumulate=accumulate_wgrad_into_param_main_grad,
                            fp32_output=ctx.fuse_wgrad_accumulation,
                            out=fc1_weight.main_grad
                            if ctx.fuse_wgrad_accumulation
                            else None,
                            use_split_accumulator=_2X_ACC_WGRAD,
                        )
                    else:
                        ln_out_total_c = cast_from_fp8(
                            ln_out_total,
                            ctx.fp8_meta["scaling_fwd"],
                            tex.FP8FwdTensors.GEMM1_INPUT,
                            fp8_dtype_forward,
                            TE_DType[ctx.activation_dtype],
                        )
                        fc1_wgrad, _, _ = gemm(
                            ln_out_total_c,
                            dgelu_no_fp8,
                            ctx.activation_dtype,
                            get_workspace(),
                            layout="NT",
                            grad=True,
                            accumulate=accumulate_wgrad_into_param_main_grad,
                            fp32_output=ctx.fuse_wgrad_accumulation,
                            out=fc1_weight.main_grad
                            if ctx.fuse_wgrad_accumulation
                            else None,
                        )
                else:
                    # FC1 WGRAD; without nvfusion it also yields the FC1 bias grad.
                    fc1_wgrad_outputs = gemm(
                        ln_out_total,
                        dgelu,
                        ctx.activation_dtype,
                        get_workspace(),
                        layout="NT",
                        grad=True,
                        use_bias=not ctx.bias_gelu_nvfusion,
                        accumulate=accumulate_wgrad_into_param_main_grad,
                        fp32_output=ctx.fuse_wgrad_accumulation,
                        out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                    )

                    if ctx.bias_gelu_nvfusion:
                        fc1_wgrad, _, _ = fc1_wgrad_outputs
                    else:
                        fc1_wgrad, fc1_bias_grad, _ = fc1_wgrad_outputs
            elif not ctx.fp8 and not ctx.bias_gelu_nvfusion:
                # FC1 weight is frozen, so the WGRAD GEMM that would have
                # produced the FC1 bias gradient was skipped; reduce dgelu
                # over the token dimension instead (what the GEMM's bias-grad
                # computes). Previously this path raised NameError.
                fc1_bias_grad = dgelu.sum(dim=0)

            # Column Parallel Linear
            # fc1_dgrad is consumed right below, so any outstanding
            # reduce-scatter / all-reduce must complete first.
            if handle is not None:
                handle.wait()

            # LayerNorm gradient
            d_ln_out = fc1_dgrad.view(inputmat.shape)

            # Residual gradient
            if ctx.return_layernorm_output:
                d_ln_out = d_ln_out + grad_outputs[1].view_as(d_ln_out)

            dxmat, dgamma, dbeta = tex.layernorm_bwd(
                d_ln_out, inputmat, mu, rsigma, ln_weight, ctx.bwd_ln_sm_margin
            )

            if not ctx.use_bias:
                fc2_bias_grad = None

        # One slot per forward() argument, in the same order.
        return (
            dxmat.view(ctx.inp_shape),  # inp
            dgamma,  # ln_weight
            dbeta,  # ln_bias
            fc1_wgrad if fc1_weight.requires_grad else None,  # fc1_weight
            None,  # fc1_weight_fp8
            None,  # fc1_weight_t_fp8
            fc1_bias_grad,  # fc1_bias
            fc2_wgrad if fc2_weight.requires_grad else None,  # fc2_weight
            None,  # fc2_weight_fp8
            None,  # fc2_weight_t_fp8
            fc2_bias_grad,  # fc2_bias
            None,  # use_bias
            None,  # eps
            None,  # is_first_microbatch
            None,  # fp8
            None,  # fp8_calibration
            None,  # fp8_meta
            None,  # fuse_wgrad_accumulation
            None,  # tp_group
            None,  # sequence_parallel
            None,  # tensor_parallel
            None,  # activation_dtype
            None,  # return_layernorm_output
            None,  # bias_gelu_nvfusion
            None,  # set_parallel_mode
            None,  # is_training
            None,  # fwd_ln_sm_margin
            None,  # bwd_ln_sm_margin
        )


class LayerNormMLP(TransformerEngineBaseModule):
    """
    Applies layer normalization on the input followed by the MLP module, consisting of
    2 successive linear transformations, separated by the GeLU activation.

    Parameters
    ----------
    hidden_size : int
                 size of each input sample.
    ffn_hidden_size : int
                     intermediate size to which input samples are projected.
    eps : float, default = 1e-5
         a value added to the denominator of layer normalization for numerical stability.
    bias : bool, default = `True`
          if set to `False`, the FC2 layer will not learn an additive bias, unless
          `return_bias` is `True` (the bias is then still created so that it can be
          returned). The FC1 layer always has a bias.
    get_rng_state_tracker : Callable, default = `None`
                           forwarded to the weight initializer of FC1 and FC2.
    init_method : Callable, default = `None`
                 used for initializing FC1 weights in the following way: `init_method(weight)`.
                 When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    output_layer_init_method : Callable, default = `None`
                              used for initializing FC2 weights in the following way:
                              `output_layer_init_method(weight)`. When set to `None`, defaults to
                              `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    return_layernorm_output : bool, default = `False`
                             if set to `True`, output of layernorm is returned from the forward
                             together with the output of the linear transformation.
                             Example use case: residual connection for transformer module
                             is taken post layernorm.

    Parallelism parameters
    ----------------------
    set_parallel_mode : bool, default = `False`
                      if set to `True`, FC1 is used as Column Parallel and FC2 is used as Row
                      Parallel as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism. Only takes effect when the
                       tensor parallel world size is greater than 1.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    tp_size : int, default = 1
             used as TP (tensor parallel) world size when TP groups are not formed during
             initialization. In this case, users must call the
             `set_tensor_parallel_group(tp_group)` method on the initialized module before the
             forward pass to supply the tensor parallel group needed for tensor and sequence
             parallel collectives.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = `False`
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient.
    return_bias : bool, default = `False`
                 when set to `True`, this module will not apply the additive FC2 bias itself,
                 but instead return the bias value during the forward pass together with the
                 output of the linear transformation :math:`y = xA^T`. This is useful when
                 the bias addition can be fused to subsequent operations.
    params_dtype : torch.dtype, default = `torch.float32`
                  it controls the type used to allocate the initial parameters. Useful when
                  the model is trained with lower precision and the original FP32 parameters
                  would not fit in GPU memory.
    seq_length: int, default = `None`
               sequence length of input samples. Needed for JIT Warmup, a technique where jit fused
               functions are warmed up before training to ensure same kernels are used for forward
               propagation and activation recompute phase. Warmup is skipped unless both
               `seq_length` and `micro_batch_size` are given.
    micro_batch_size: int, default = `None`
                     batch size per training step. Needed for JIT Warmup, a technique where jit
                     fused functions are warmed up before training to ensure same kernels are
                     used for forward propagation and activation recompute phase.
    """

    def __init__(
        self,
        hidden_size: int,
        ffn_hidden_size: int,
        eps: float = 1e-5,
        sequence_parallel: bool = False,
        return_bias: bool = False,
        get_rng_state_tracker: Optional[Callable] = None,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        init_method: Optional[Callable] = None,
        bias: bool = True,
        output_layer_init_method: Optional[Callable] = None,
        fuse_wgrad_accumulation: bool = False,
        params_dtype: torch.dtype = torch.float32,
        return_layernorm_output: bool = False,
        seq_length: Optional[int] = None,
        micro_batch_size: Optional[int] = None,
        set_parallel_mode: bool = False,
    ) -> None:
        super().__init__()

        self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
        self.use_bias = bias
        self.return_bias = return_bias
        # A bias that is handed back to the caller must not also be added here,
        # otherwise it ends up being applied twice downstream.
        self.apply_bias = bias and not return_bias
        self.return_layernorm_output = return_layernorm_output
        self.bias_gelu_nvfusion = bool(int(os.getenv("NVTE_BIAS_GELU_NVFUSION", "1")))
        self.set_parallel_mode = set_parallel_mode

        if tp_group is None:
            self.tp_size = tp_size
            if tp_size == 1:
                self.set_tensor_parallel_group(tp_group)
        else:
            self.tp_size = get_distributed_world_size(tp_group)
            self.set_tensor_parallel_group(tp_group)
        self.set_nccl_overlap_warning_if_tp()

        if init_method is None:
            init_method = get_default_init_method()
        if output_layer_init_method is None:
            output_layer_init_method = get_default_init_method()

        self.sequence_parallel = (self.tp_size > 1) and sequence_parallel
        self.size_per_partition = divide(ffn_hidden_size, self.tp_size)

        # LN init
        self.eps = eps
        self.layer_norm_weight = Parameter(
            torch.empty(
                hidden_size,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            )
        )
        self.layer_norm_bias = Parameter(
            torch.empty(
                hidden_size,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            )
        )
        setattr(self.layer_norm_weight, "sequence_parallel", self.sequence_parallel)
        setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel)
        self.reset_layer_norm_parameters()

        # FC1 init: weight is partitioned along the output (row) dimension.
        self.fc1_weight = Parameter(
            torch.empty(
                self.size_per_partition,
                hidden_size,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            )
        )
        self.fp8_weight_shapes.append(self.fc1_weight.shape)

        initialize_affine_weight_gpu(
            self.fc1_weight,
            init_method,
            get_rng_state_tracker,
            partition_dim=0,
            stride=1,
        )

        self.fc1_bias = Parameter(
            torch.empty(
                self.size_per_partition,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            )
        )
        set_tensor_model_parallel_attributes(self.fc1_bias, True, 0, 1)

        with torch.no_grad():
            self.fc1_bias.zero_()

        # FC2 init: weight is partitioned along the input (column) dimension.
        self.fc2_weight = Parameter(
            torch.empty(
                hidden_size,
                self.size_per_partition,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            )
        )
        self.fp8_weight_shapes.append(self.fc2_weight.shape)

        initialize_affine_weight_gpu(
            self.fc2_weight,
            output_layer_init_method,
            get_rng_state_tracker,
            partition_dim=1,
            stride=1,
        )

        # The FC2 bias is learned whenever it is either applied or returned.
        if bias or return_bias:
            self.fc2_bias = Parameter(
                torch.empty(
                    hidden_size, device=torch.cuda.current_device(), dtype=params_dtype
                )
            )
        else:
            self.register_buffer("fc2_bias", torch.Tensor().type(params_dtype), persistent=False)

        # For RPL, bias has to be added after TP collectives
        # So it cannot be fused with the GEMM.
        # After this point `self.use_bias` means "add the bias inside the FC2 GEMM";
        # it is the flag forward() passes down.
        if self.set_parallel_mode and self.apply_bias:
            self.gemm_bias_unfused_add = True
            self.use_bias = False
        else:
            self.gemm_bias_unfused_add = False
            self.use_bias = self.apply_bias

        with torch.no_grad():
            self.fc2_bias.zero_()

        if self.bias_gelu_nvfusion:
            set_jit_fusion_options()
            if seq_length and micro_batch_size:
                warmup_jit_bias_gelu_all_dtypes(
                    self.size_per_partition, seq_length, micro_batch_size
                )

        # These many SMs are subtracted from the total SM count when calling forward
        # and backward LayerNorm C APIs. These envvars can be used to prevent the LN
        # kernels from using all SMs in the device. This is useful for cases such as
        # communication overlap with LN.
        self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
        self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))

    def reset_layer_norm_parameters(self) -> None:
        """Init LN params"""
        init.ones_(self.layer_norm_weight)
        init.zeros_(self.layer_norm_bias)

    def forward(
        self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        """
        Apply layer normalization to the input followed by a feedforward network (MLP Block).

        Parameters
        ----------
        inp : torch.Tensor
             Input tensor.
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)

        Returns
        -------
        The MLP output, followed by the FC2 bias if `return_bias` is set, followed by
        the layernorm output if `return_layernorm_output` is set.
        """

        with self.prepare_forward(inp, is_first_microbatch, num_gemms=2) as inp:
            if self.training:
                fwd_fn = _LayerNormMLP.apply
                args = []
            else:
                # Outside training, bypass `apply` and call forward directly with ctx=None.
                fwd_fn = _LayerNormMLP.forward
                args = [None]
            args += (
                inp,
                self.layer_norm_weight,
                self.layer_norm_bias,
                self.fc1_weight,
                self.weight1_fp8 if self.fp8 else None,
                self.weight1_t_fp8 if self.fp8 else None,
                self.fc1_bias,
                self.fc2_weight,
                self.weight2_fp8 if self.fp8 else None,
                self.weight2_t_fp8 if self.fp8 else None,
                self.fc2_bias,
                self.use_bias,
                self.eps,
                is_first_microbatch,
                self.fp8,
                self.fp8_calibration,
                self.fp8_meta,
                self.fuse_wgrad_accumulation,
                self.tp_group,
                self.sequence_parallel,
                self.tp_size > 1,
                self.activation_dtype,
                self.return_layernorm_output,
                self.bias_gelu_nvfusion,
                self.set_parallel_mode,
                self.training,
                self.fwd_ln_sm_margin,
                self.bwd_ln_sm_margin,
            )
            out = fwd_fn(*args)

        if self.return_layernorm_output:
            out, ln_out = out

        # Row-parallel FC2: the bias is added here, after the TP collectives.
        if self.gemm_bias_unfused_add:
            out = out + cast_if_needed(self.fc2_bias, self.activation_dtype)

        if self.return_bias:
            if self.return_layernorm_output:
                return out, cast_if_needed(self.fc2_bias, self.activation_dtype), ln_out
            return out, cast_if_needed(self.fc2_bias, self.activation_dtype)
        if self.return_layernorm_output:
            return out, ln_out
        return out


class _LayerNorm(torch.autograd.Function):
    """Autograd wrapper around the fused LayerNorm forward/backward extension kernels."""

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        ln_weight: torch.Tensor,
        ln_bias: torch.Tensor,
        eps: float,
        fwd_ln_sm_margin: int,
        bwd_ln_sm_margin: int,
    ) -> torch.Tensor:
        """Normalize over the last dimension of `inp` (whose size must equal `ln_weight.numel()`).

        The input is flattened to 2-D for the kernel and the result is reshaped back
        to `inp`'s shape. The SM margins are forwarded to the forward/backward kernels.
        """
        hidden = ln_weight.numel()
        assert inp.is_cuda, "TransformerEngine needs CUDA."
        assert inp.shape[-1] == hidden, "LayerNorm not possible"
        flat_inp = inp.view((-1, hidden))

        # `mu` / `rsigma` are auxiliary outputs of the kernel that backward needs
        # (presumably per-row mean and reciprocal std, judging by their names).
        normed, mu, rsigma = tex.layernorm_fwd(
            flat_inp, ln_weight, ln_bias, eps, fwd_ln_sm_margin
        )

        ctx.save_for_backward(flat_inp, ln_weight, mu, rsigma)
        ctx.inp_shape = inp.shape
        ctx.bwd_ln_sm_margin = bwd_ln_sm_margin

        return normed.view_as(inp)

    @staticmethod
    def backward(
        ctx, grad_output: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Return gradients for (inp, ln_weight, ln_bias); the remaining inputs get None."""
        flat_inp, ln_weight, mu, rsigma = ctx.saved_tensors
        grad_mat = grad_output.contiguous().view(flat_inp.shape)
        grad_inp, grad_gamma, grad_beta = tex.layernorm_bwd(
            grad_mat, flat_inp, mu, rsigma, ln_weight, ctx.bwd_ln_sm_margin
        )
        # eps and the two SM margins are not differentiable.
        return (grad_inp.view(ctx.inp_shape), grad_gamma, grad_beta) + (None,) * 3


class LayerNorm(torch.nn.Module):
    r"""
    Applies Layer Normalization over a mini-batch of inputs as described in
    the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
    size :attr:`hidden_size`

    Parameters
    ----------
    hidden_size : int
                size of each input sample.
    eps : float, default = 1e-5
        a value added to the denominator of layer normalization for numerical stability.
    sequence_parallel : bool, default = `False`
                        if set to `True`, uses sequence parallelism (the `weight` and `bias`
                        parameters are tagged with a `sequence_parallel` attribute).
    params_dtype : torch.dtype, default = `torch.float32`
                    it controls the type used to allocate the initial parameters. Useful when
                    the model is trained with lower precision and the original FP32 parameters
                    would not fit in GPU memory.

    State dicts that use the legacy parameter names ``layer_norm_weight`` and
    ``layer_norm_bias`` are accepted and mapped onto ``weight`` and ``bias``, both when
    this module is loaded directly and when it is nested inside a larger model.
    """

    # (legacy name, current name) pairs for the affine parameters.
    _LEGACY_PARAM_NAMES = (("layer_norm_weight", "weight"), ("layer_norm_bias", "bias"))

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-5,
        sequence_parallel: bool = False,
        params_dtype: torch.dtype = torch.float32,
    ) -> None:
        super().__init__()
        self.eps = eps
        self.weight = Parameter(
            torch.empty(
                hidden_size,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            )
        )
        self.bias = Parameter(
            torch.empty(
                hidden_size,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            )
        )
        setattr(self.weight, "sequence_parallel", sequence_parallel)
        setattr(self.bias, "sequence_parallel", sequence_parallel)
        self.reset_layer_norm_parameters()

        # These many SMs are subtracted from the total SM count when calling forward
        # and backward LayerNorm C APIs. These envvars can be used to prevent the LN
        # kernels from using all SMs in the device. This is useful for cases such as
        # communication overlap with LN.
        self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
        self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))

    def _load_from_state_dict(
        self,
        state_dict: Dict[str, Any],
        prefix: str,
        local_metadata: Dict[str, Any],
        strict: bool,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ) -> None:
        """Map legacy parameter keys onto the current names, then load normally.

        This hook is invoked by `torch.nn.Module.load_state_dict` for every module
        in the tree (with the module's key `prefix`), so the remapping also works when
        this LayerNorm is a submodule. PyTorch passes its own copy of the user's
        mapping here, so the caller's state dict is not modified, and the missing /
        unexpected key report returned by `load_state_dict` stays intact.
        A legacy key takes precedence over a current-name key, as before.
        """
        for legacy_name, current_name in self._LEGACY_PARAM_NAMES:
            legacy_key = prefix + legacy_name
            if legacy_key in state_dict:
                state_dict[prefix + current_name] = state_dict.pop(legacy_key)
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def reset_layer_norm_parameters(self) -> None:
        """Init LN params"""
        init.ones_(self.weight)
        init.zeros_(self.bias)

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        """LayerNorm FWD"""
        # Maintain backward compatibility: instances that carry the legacy
        # `layer_norm_weight` / `layer_norm_bias` attributes (presumably objects
        # restored from an older version of this class) are aliased onto the new names.
        if hasattr(self, "layer_norm_weight"):
            setattr(self, "weight", self.layer_norm_weight)
        if hasattr(self, "layer_norm_bias"):
            setattr(self, "bias", self.layer_norm_bias)

        return _LayerNorm.apply(
            inp,
            self.weight,
            self.bias,
            self.eps,
            self.fwd_ln_sm_margin,
            self.bwd_ln_sm_margin,
        )