module.py 122 KB
Newer Older
1
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Przemek Tredak's avatar
Przemek Tredak committed
2
3
4
5
6
#
# See LICENSE for license information.

"""Top level Transformer Engine PyTorch modules"""
import os
7
import pickle
Przemek Tredak's avatar
Przemek Tredak committed
8
9
import warnings
from abc import ABC, abstractmethod
10
from typing import Union, Optional, Callable, Tuple, Dict, Any, Mapping, List
Przemek Tredak's avatar
Przemek Tredak committed
11
from functools import partial
12
from contextlib import contextmanager
Przemek Tredak's avatar
Przemek Tredak committed
13

14
import numpy as np
Przemek Tredak's avatar
Przemek Tredak committed
15
import torch
16
import torch.nn.functional as F
Przemek Tredak's avatar
Przemek Tredak committed
17
18
19
20
21
22
from torch.nn.parameter import Parameter
from torch.nn import init

import transformer_engine_extensions as tex
from .fp8 import (
    is_fp8_enabled,
schetlur-nv's avatar
schetlur-nv committed
23
    is_fp8_calibration,
Przemek Tredak's avatar
Przemek Tredak committed
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
    get_fp8_recipe,
    get_fp8_group,
    get_default_fp8_recipe,
    get_fp8_te_dtype,
    is_first_fp8_module,
    new_fp8_context_id,
    get_fp8_context_id,
    set_fp8_context_id,
    add_amax_to_global_buffer,
    copy_amax_from_global_buffer,
    global_amax_reduction,
    setup_amax_forward_global_reduce_func,
    amax_and_scale_update,
    get_global_fp8_buffer,
    set_global_fp8_buffer,
    set_amax_buffer_key_deletion,
    delete_key_from_amax_buffer,
41
42
43
    copy_forward_fp8_meta_tensors_for_recompute,
    get_old_fp8_meta_tensors_for_recompute,
    restore_fp8_meta_tensors,
Sangkug Lym's avatar
Sangkug Lym committed
44
    get_amax_reduce_handle_fwd,
Przemek Tredak's avatar
Przemek Tredak committed
45
46
47
48
49
50
51
52
53
54
55
)
from .jit import (
    bias_gelu_fused,
    bgrad_dgelu_fused,
    set_jit_fusion_options,
    warmup_jit_bias_gelu_all_dtypes,
)
from .utils import (
    divide,
    get_default_init_method,
    cast_if_needed,
56
    check_dim_for_fp8_forward_exec,
Przemek Tredak's avatar
Przemek Tredak committed
57
58
59
60
61
62
63
64
65
)
from .distributed import (
    set_tensor_model_parallel_attributes,
    get_distributed_world_size,
    allreduce,
    initialize_affine_weight_gpu,
    reduce_scatter_along_first_dim,
    gather_along_first_dim,
    gather_along_last_dim,
66
67
    is_fp8_activation_recompute_enabled,
    in_fp8_activation_recompute_phase,
Przemek Tredak's avatar
Przemek Tredak committed
68
69
70
71
72
73
74
75
76
)
from .cpp_extensions import (
    fp8_gemm,
    gemm,
    fp8_cast_transpose_fused,
    fp8_cast_transpose_bgrad_fused,
    fp8_gelu,
    fp8_cast_transpose_bgrad_dgelu_fused,
    layernorm_fwd_fp8,
77
78
    layernorm_fwd_fp8_inf,
    layernorm_fwd_inf,
Przemek Tredak's avatar
Przemek Tredak committed
79
80
81
82
83
84
85
86
87
    cast_to_fp8,
    cast_from_fp8,
)
from .constants import GemmParallelModes, dist_group_type, TE_DType

# Module-level boolean flags for the forward (fprop), dgrad and wgrad GEMM
# stages. Their consumers are not visible in this chunk -- presumably they are
# forwarded to the FP8 GEMM calls; TODO confirm.
_2X_ACC_FPROP = False
_2X_ACC_DGRAD = True
_2X_ACC_WGRAD = True

# cuBLAS workspace tensor, allocated lazily and then reused by `get_workspace`.
_cublas_workspace = None

# Handle returned by `global_amax_reduction(..., forward=False)` when
# `_prepare_backward` launches the backward amax reduction; it is waited on
# (and cleared) at the start of the next FP8 `_prepare_backward`.
_amax_reduce_handle_bwd = None


def get_cublas_workspace_size_bytes() -> int:
    """Return the cuBLAS workspace size in bytes for the current CUDA device.

    Devices whose compute-capability major version is >= 9 (Hopper and newer)
    get 32 MiB; every other architecture gets 4 MiB.
    """
    if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9:
        return 32 * 1024 * 1024  # 33_554_432 bytes
    return 4 * 1024 * 1024  # 4_194_304 bytes


def get_workspace() -> torch.Tensor:
    """Return the shared cuBLAS workspace tensor.

    On the first call an int8 tensor of `get_cublas_workspace_size_bytes()`
    elements is allocated on the "cuda" device and stored in the module-level
    `_cublas_workspace`; later calls return that same tensor.
    """
    global _cublas_workspace
    if _cublas_workspace is not None:
        return _cublas_workspace
    workspace_bytes = get_cublas_workspace_size_bytes()
    _cublas_workspace = torch.empty(workspace_bytes, dtype=torch.int8, device="cuda")
    return _cublas_workspace

107
@contextmanager
def _prepare_backward(
    fp8: bool,
    fp8_meta: Dict[str, Any],
    tp_group: dist_group_type,
    tp_size: int,
    name: str = ""
) -> None:
    """Checks and prep for BWD.

    Context manager wrapped around a module's backward computation.

    Before the body (FP8 only):
      * waits on, then clears, any pending backward amax-reduction handle;
      * updates amax and scale via ``amax_and_scale_update(fp8_meta, False)``;
      * if ``fp8_meta["recipe"].reduce_amax`` is set, additionally pulls the
        amaxes from the global buffer first, marks the buffer key for
        deletion, pops the next backward autocast id from
        ``fp8_meta["autocast_id_fwd_stack"]`` and registers this module's
        amaxes in the global buffer.

    The body runs inside an NVTX range named ``name + " backward"``.

    After the body, when FP8 and ``reduce_amax`` are enabled and this module
    was flagged ``fp8_meta["first_module"]``, the global backward amax
    reduction is launched and its handle is stored in the module-level
    ``_amax_reduce_handle_bwd`` for the next backward to wait on.
    """
    # Declared once up front: the handle is both read (pre-body) and written
    # (pre- and post-body) in this function.
    global _amax_reduce_handle_bwd

    if fp8:
        # Finish the reduction launched by an earlier backward before using amaxes.
        if _amax_reduce_handle_bwd is not None:
            _amax_reduce_handle_bwd.wait()
            _amax_reduce_handle_bwd = None

        # Update amax and scale; Skip all setup for global amax reduction
        if not fp8_meta["recipe"].reduce_amax:
            amax_and_scale_update(fp8_meta, False)
        else:
            # From previous iteration
            copy_amax_from_global_buffer(fp8_meta, forward=False)
            amax_and_scale_update(fp8_meta, False)
            set_amax_buffer_key_deletion(fp8_meta, forward=False)

            # Get new backward key (oldest pending forward id first).
            fp8_meta["autocast_id_bwd"] = fp8_meta["autocast_id_fwd_stack"].pop(0)

            add_amax_to_global_buffer(fp8_meta, forward=False)

    with torch.cuda.nvtx.range(name + " backward"):
        yield

    if fp8 and fp8_meta["recipe"].reduce_amax:
        if fp8_meta["first_module"]:
            _amax_reduce_handle_bwd = global_amax_reduction(
                fp8_meta,
                tp_group,
                tp_size,
                forward=False
            )
            delete_key_from_amax_buffer(forward=False)


class _NoopCat(torch.autograd.Function):
    """This class is a no-op replacement for `torch.cat`.

    ``forward`` returns a new tensor object that aliases the storage of
    ``full_param_buffer`` (nothing is copied). ``backward`` hands the buffer no
    gradient and gives each split parameter its equal-sized slice of the
    incoming gradient along dim 0.
    """

    @staticmethod
    def forward(ctx,
                full_param_buffer: torch.Tensor,
                *params_split: Tuple[torch.Tensor, ...],
    ) -> torch.Tensor:
        assert not full_param_buffer.requires_grad, "Buffers should not require gradient"
        num_splits = len(params_split)
        assert (
            full_param_buffer.shape[0] % num_splits == 0
        ), "Dimensions not compatible for concatenation"

        # Fresh tensor object sharing the buffer's storage, offset, size and stride.
        aliased = full_param_buffer.new()
        aliased.set_(
            full_param_buffer.storage(),
            full_param_buffer.storage_offset(),
            full_param_buffer.size(),
            full_param_buffer.stride(),
        )
        aliased.requires_grad = True

        ctx.save_for_backward(full_param_buffer, *params_split)
        return aliased

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
        full_param_buffer, *params_split = ctx.saved_tensors
        chunk = full_param_buffer.shape[0] // len(params_split)
        # No gradient for the buffer, then one contiguous slice per split parameter.
        return (None,) + tuple(
            grad_output[k * chunk : (k + 1) * chunk] for k in range(len(params_split))
        )


class TransformerEngineBaseModule(torch.nn.Module, ABC):
    """Base TE module.

    Owns the FP8 metadata dictionary (``self.fp8_meta``: recipe, scaling
    tensors, amax histories, autocast ids) and provides the helpers that
    concrete modules use around their forward and backward passes.
    """

    def __init__(self) -> None:
        super().__init__()
        assert torch.cuda.is_available(), "TransformerEngine needs CUDA."
        self.fp8_initialized = False
        self.fp8 = False
        self.fp8_calibration = False
        self.fp8_meta = {}
        self.fp8_meta["fp8_group"] = None
        self.fp8_meta["recipe"] = get_default_fp8_recipe()
        self.fp8_meta_tensors_initialized = False
        self.tp_group = None
        # Set to True by `set_tensor_parallel_group`; defaulted here so that
        # `prepare_forward` reports a clear assertion instead of AttributeError.
        self.tp_group_initialized = False
        self.tp_size = 1
        self.sequence_parallel = False
        self.fp8_weight_shapes = []
        self.fp8_meta["autocast_id_fwd_stack"] = []
        self.fp8_meta["async_amax_reduction"] = bool(
            int(os.getenv("NVTE_ASYNC_AMAX_REDUCTION", "0"))
        )

    def set_meta_tensor(self, fwd: bool) -> None:
        """Init scales and amaxes for fwd | bwd.

        If the meta tensors already exist, only the amax history is resized
        (truncated or zero-padded along dim 0) to the recipe's
        ``amax_history_len``.
        """
        fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd"

        if self.fp8_meta_tensors_initialized:
            # Handle changed amax history size.
            meta = self.fp8_meta[fp8_meta_tensor_key]
            curr_len = meta.amax_history.shape[0]
            need_len = self.fp8_meta["recipe"].amax_history_len
            if need_len < curr_len:
                meta.amax_history = meta.amax_history[:need_len].clone()
            elif need_len > curr_len:
                # Append zero rows at the end of the history.
                meta.amax_history = F.pad(
                    meta.amax_history, pad=(0, 0, 0, need_len - curr_len)
                )
            return

        # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and
        # 2 (grad_output and grad_input) for bwd
        num_fp8_tensors = (
            self.fp8_meta["num_gemms"] * 3 if fwd else self.fp8_meta["num_gemms"] * 2
        )

        self.fp8_meta[fp8_meta_tensor_key] = tex.FP8TensorMeta()
        self.fp8_meta[fp8_meta_tensor_key].scale = torch.ones(
            num_fp8_tensors, dtype=torch.float32, device="cuda"
        )
        self.fp8_meta[fp8_meta_tensor_key].scale_inv = torch.ones(
            num_fp8_tensors, dtype=torch.float32, device="cuda"
        )
        self.fp8_meta[fp8_meta_tensor_key].amax_history = torch.zeros(
            self.fp8_meta["recipe"].amax_history_len,
            num_fp8_tensors,
            dtype=torch.float32,
            device="cuda",
        )

        # Needed for calculation of scale inverses to
        # preserve scale_inv when caching FP8 weights
        if fwd:
            # [True, False, True]: -> [input, weight, output]
            self.fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"] = torch.BoolTensor(
                [True, False, True] * self.fp8_meta["num_gemms"]
            ).cuda()
        else:
            # [True, True]: -> [grad_output, grad_input]
            self.fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"] = torch.BoolTensor(
                [True, True] * self.fp8_meta["num_gemms"]
            ).cuda()

    def init_fp8_meta_tensors(self) -> None:
        """Init scales and amaxes for both the forward and backward passes."""
        self.set_meta_tensor(True)
        self.set_meta_tensor(False)
        self.fp8_meta_tensors_initialized = True

    def get_extra_state(self) -> torch.Tensor:
        """Save before checkpointing.

        Returns a uint8 tensor holding the pickled FP8 state (or pickled
        ``None`` when neither FP8 nor FP8 calibration is active).
        """
        state = None
        if self.fp8 or self.fp8_calibration:
            state = {}
            state["scale_fwd"] = self.fp8_meta["scaling_fwd"].scale
            state["scale_inv_fwd"] = self.fp8_meta["scaling_fwd"].scale_inv
            state["amax_history_fwd"] = self.fp8_meta["scaling_fwd"].amax_history
            state["scale_bwd"] = self.fp8_meta["scaling_bwd"].scale
            state["scale_inv_bwd"] = self.fp8_meta["scaling_bwd"].scale_inv
            state["amax_history_bwd"] = self.fp8_meta["scaling_bwd"].amax_history
            state["global_fp8_buffer"] = get_global_fp8_buffer()

            # Store other pickleable values.
            extra = {}
            for k, v in self.fp8_meta.items():
                if isinstance(v, (bool, int, float, str)):
                    extra[k] = v
            state["extra_fp8_variables"] = extra

        state_serialized = pickle.dumps(state)
        state_tensor = torch.tensor(np.frombuffer(state_serialized, dtype=np.uint8))

        return state_tensor

    def set_extra_state(self, state: torch.Tensor) -> None:
        """Load previous state.

        Accepts the legacy list format (v0.2.0 and older) or the pickled uint8
        tensor produced by `get_extra_state`.
        """
        if state is None:
            return

        # Maintain backward compatibility with v0.2.0 and older.
        if isinstance(state, list):
            warnings.warn(
                "This checkpoint format is deprecated and will be "
                "removed in a future release of Transformer Engine"
            )

            # Retrieve checkpointed items.
            scale_fwd = state[0]
            amax_history_fwd = state[1]
            scale_bwd = state[2]
            amax_history_bwd = state[3]
            self.fp8_meta["recipe"].amax_history_len = amax_history_fwd.shape[0]
            self.fp8_meta["num_gemms"] = (
                amax_history_fwd.shape[1] // 2
            )  # Two FWD tensors per GEMM

            # Initialize before loading
            self.init_fp8_meta_tensors()
            self.fp8_meta["scaling_fwd"].scale.copy_(scale_fwd)
            self.fp8_meta["scaling_fwd"].amax_history.copy_(amax_history_fwd)
            self.fp8_meta["scaling_bwd"].scale.copy_(scale_bwd)
            self.fp8_meta["scaling_bwd"].amax_history.copy_(amax_history_bwd)
            # Legacy checkpoints never stored scale_inv; derive it from the
            # loaded scale so it does not stay at the freshly-initialized 1.0.
            self.fp8_meta["scaling_fwd"].scale_inv.copy_(1.0 / scale_fwd)
            self.fp8_meta["scaling_bwd"].scale_inv.copy_(1.0 / scale_bwd)

            # Restore global FP8 buffer state.
            set_global_fp8_buffer(state[4])
            self.fp8_meta["update_amax_and_scale_fwd"] = state[5]
            self.fp8_meta["global_fp8_buffer_pos_fwd"] = state[6]
            self.fp8_meta["global_fp8_buffer_pos_bwd"] = state[7]
            self.fp8_meta["autocast_id_fwd"] = state[8]
            self.fp8_meta["autocast_id_bwd"] = state[9]
            return

        if isinstance(state, torch.Tensor):
            # NOTE(review): pickle.loads executes arbitrary code; only load
            # checkpoints from trusted sources.
            state = pickle.loads(state.detach().cpu().numpy().tobytes())
            if state is None:
                return

        # Restore global FP8 buffer states.
        set_global_fp8_buffer(state["global_fp8_buffer"])
        # Load extra items.
        self.fp8_meta.update(state["extra_fp8_variables"])
        self.fp8_meta["recipe"].amax_history_len = state["amax_history_fwd"].shape[0]
        if "global_fp8_buffer_pos_fwd_recompute" in self.fp8_meta:
            del self.fp8_meta["global_fp8_buffer_pos_fwd_recompute"]

        # Initialize before loading.
        self.init_fp8_meta_tensors()
        self.fp8_meta["scaling_fwd"].scale.copy_(state["scale_fwd"])
        self.fp8_meta["scaling_fwd"].amax_history.copy_(state["amax_history_fwd"])
        self.fp8_meta["scaling_bwd"].scale.copy_(state["scale_bwd"])
        self.fp8_meta["scaling_bwd"].amax_history.copy_(state["amax_history_bwd"])

        # Backwards compatibility: compute scale inv if it wasn't saved in the extra state.
        if "scale_inv_fwd" not in state or "scale_inv_bwd" not in state:
            assert (
                "scale_inv_fwd" not in state and "scale_inv_bwd" not in state
            ), "Invalid state, began saving scale_inv_fwd and scale_inv_bwd at the same time"
            self.fp8_meta["scaling_fwd"].scale_inv.copy_(1.0/state["scale_fwd"])
            self.fp8_meta["scaling_bwd"].scale_inv.copy_(1.0/state["scale_bwd"])
        else:
            self.fp8_meta["scaling_fwd"].scale_inv.copy_(state["scale_inv_fwd"])
            self.fp8_meta["scaling_bwd"].scale_inv.copy_(state["scale_inv_bwd"])

    def set_activation_dtype(self, inp: torch.Tensor) -> None:
        """Get activation data type for AMP."""
        # Native AMP (`torch.autocast`) gets highest priority
        if torch.is_autocast_enabled():
            self.activation_dtype = torch.get_autocast_gpu_dtype()
            return

        # All checks after this have already been performed once, thus skip
        # We assume that user doesn't change input types across iterations
        if hasattr(self, "activation_dtype"):
            return

        assert all(
            (
                (inp.dtype == param.dtype) if param is not None else True
                for param in self.parameters()
            )
        ), (
            "Data type for activations and weights must "
            "match when outside of autocasted region"
        )
        assert all(
            (
                (inp.dtype == buf.dtype) if buf is not None else True
                for buf in self.buffers()
            )
        ), (
            "Data type for activations and buffers must "
            "match when outside of autocasted region"
        )
        self.activation_dtype = inp.dtype

    def set_fp8_weights(self) -> None:
        """Initializes FP8 weights for the module as class attributes. These
        are not parameters or buffers since we do not want functions such as
        `.to(dtype)` or `.to(device)` to affect them. These also do not need
        to be checkpointed. During `init` phase of the module, the attribute
        `fp8_weight_shapes` must be populated with the tensor shapes for FP8
        weights. This function will iterate over those shapes and initialize
        respective attributes named `weight1_fp8`, `weight2_fp8`, ... (and the
        transposed `weight1_t_fp8`, `weight2_t_fp8`, ...).
        """
        if not self.fp8:
            return

        for i, shape in enumerate(self.fp8_weight_shapes, start=1):
            weight_cast_attr = f"weight{i}_fp8"
            weight_transpose_attr = f"weight{i}_t_fp8"

            # Already allocated with the right shape: check the next weight
            # (previously `return`, which skipped all remaining weights).
            if (
                hasattr(self, weight_cast_attr)
                and getattr(self, weight_cast_attr).shape == shape
            ):
                continue

            setattr(
                self,
                weight_cast_attr,
                torch.empty(
                    shape,
                    device=torch.cuda.current_device(),
                    dtype=torch.int8,
                ),
            )
            setattr(
                self,
                weight_transpose_attr,
                torch.empty(
                    shape[1],
                    shape[0],
                    device=torch.cuda.current_device(),
                    dtype=torch.int8,
                ),
            )

    def set_tensor_parallel_group(self, tp_group: "Union[dist_group_type, None]") -> None:
        """Set TP group."""
        self.tp_group = tp_group
        self.tp_group_initialized = True

    # This routine is shared across FP8 and FP8_calibration paths so should not actually
    # assume FP8 execution.
    def fp8_init(self, num_gemms: int = 1) -> None:
        """Initialize fp8 related metadata and tensors during fprop."""
        if not (is_fp8_enabled() or is_fp8_calibration()):
            # If fp8 isn't enabled, turn off and return.
            self.fp8_initialized = False
            return

        # FP8 init has already been run and recipe is the same, don't do anything.
        if self.fp8_initialized and get_fp8_recipe() == self.fp8_meta["recipe"]:
            return

        # Set FP8, recipe, and other FP8 metadata
        self.fp8 = is_fp8_enabled()
        self.fp8_calibration = is_fp8_calibration()
        self.fp8_meta["recipe"] = get_fp8_recipe()
        self.fp8_meta["num_gemms"] = num_gemms
        self.fp8_meta["fp8_group"] = get_fp8_group()

        # Set FP8_MAX per tensor according to recipe
        self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd
        self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd

        # Allocate scales and amaxes
        self.init_fp8_meta_tensors()
        self.fp8_initialized = True

    @contextmanager
    def prepare_forward(
        self,
        inp: torch.Tensor,
        is_first_microbatch: Union[bool, None],
        num_gemms: int = 1,
    ) -> None:
        """Checks and prep for FWD.
        The context manager is needed because there isn't a way for a module to know
        if it's the last FP8 module in the forward autocast. It is useful
        to setup the forward aggregated amax reduction for every module
        just in case. The autocast exit will pick up the most recent one.

        Yields ``inp.contiguous()`` inside an NVTX range named after the class.
        """
        # Activation recomputation is used and this is the second forward phase.
        if self.fp8 and in_fp8_activation_recompute_phase():
            get_old_fp8_meta_tensors_for_recompute(self.fp8_meta)
        else:
            assert inp.is_cuda, "TransformerEngine needs CUDA."

            if self.tp_size > 1:
                assert self.tp_group_initialized, "TP group not initialized."

            self.set_activation_dtype(inp)
            self.fp8_init(num_gemms=num_gemms)
            self.set_fp8_weights()

            update_weight_scale_inv = is_first_microbatch is None or is_first_microbatch

            if self.fp8 and self.sequence_parallel:
                assert self.fp8_meta["recipe"].reduce_amax, (
                    "Amax reduction across tensor parallel group is "
                    "necessary when using sequence parallelism with FP8."
                )

            # Previous iteration was grad_enabled
            if self.fp8_meta.get("update_amax_and_scale_fwd", False):
                reduce_amax = self.fp8_meta["recipe"].reduce_amax
                if reduce_amax:
                    copy_amax_from_global_buffer(self.fp8_meta, forward=True)
                amax_and_scale_update(
                    self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
                )
                if reduce_amax:
                    set_amax_buffer_key_deletion(self.fp8_meta, forward=True)

            if self.fp8 and self.training:
                # Setup for amax reduction
                if self.fp8_meta["recipe"].reduce_amax:
                    self.fp8_meta["first_module"] = is_first_fp8_module()
                    if self.fp8_meta["first_module"]:
                        # Wait for the prior AMAX reduction to finish
                        amax_reduce_handle_fwd = get_amax_reduce_handle_fwd()
                        if amax_reduce_handle_fwd is not None:
                            amax_reduce_handle_fwd.wait()
                        self.fp8_meta["autocast_id_fwd"] = new_fp8_context_id()
                        set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
                    else:
                        self.fp8_meta["autocast_id_fwd"] = get_fp8_context_id()
                    self.fp8_meta["autocast_id_fwd_stack"].append(
                        self.fp8_meta["autocast_id_fwd"]
                    )
                    add_amax_to_global_buffer(self.fp8_meta, forward=True)
                self.fp8_meta["update_amax_and_scale_fwd"] = True
            else:
                self.fp8_meta["update_amax_and_scale_fwd"] = False

            # Activation recomputation is used and this is the first forward phase.
            if (
                self.fp8
                and self.training
                and is_fp8_activation_recompute_enabled()
                and not in_fp8_activation_recompute_phase()
            ):
                copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)

        with torch.cuda.nvtx.range(self.__class__.__name__ + " forward"):
            yield inp.contiguous()

        if self.fp8 and in_fp8_activation_recompute_phase():
            restore_fp8_meta_tensors(self.fp8_meta)
            return

        if self.fp8 and self.training and self.fp8_meta["recipe"].reduce_amax:
            set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
            reduce_func = partial(
                global_amax_reduction,
                self.fp8_meta,
                self.tp_group,
                self.tp_size,
                forward=True
            )
            setup_amax_forward_global_reduce_func(reduce_func)

    def set_nccl_overlap_warning_if_tp(self) -> None:
        """When using TP, the NCCL communication needs to be scheduled
        before the GEMM for there to be a guaranteed overlap. From the
        host side in TE, the comm calls are always launched first, but
        to ensure that the GEMM isn't scheduled first, the environment
        variable `CUDA_DEVICE_MAX_CONNECTIONS` needs to be set to 1 to
        force a single channel.
        """
        if self.tp_size == 1:
            return
        num_cuda_work_queues = int(os.getenv("CUDA_DEVICE_MAX_CONNECTIONS", "0"))
        if num_cuda_work_queues != 1:
            warnings.warn(
                "To guarantee overlapping TP and SP collectives with the backward "
                "GEMMs, set environment variable CUDA_DEVICE_MAX_CONNECTIONS = 1"
            )

    @staticmethod
    def grad_output_preprocess(
        ctx, grad_output: torch.Tensor, row_parallel_mode: bool
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Utility function for backward.
        Returns tuple in order (all optional/None based on training precision/recipe):
            R1: gathered `grad_output` in higher precision.
            R2: gathered `grad_output` in FP8.
            R3: R2 transposed.
            R4: bias gradient on R1.

        """
        grad_output = grad_output.contiguous()
        grad_output_mat = grad_output.view((-1, grad_output.shape[-1]))
        gather_grad_output = row_parallel_mode and ctx.sequence_parallel

        # No-FP8 case: bgrad is fused with wgrad for this case.
        if not ctx.fp8:
            if gather_grad_output:
                grad_output_mat, _ = gather_along_first_dim(
                    grad_output_mat, ctx.tp_group
                )
            return grad_output_mat, None, None, None

        fp8_dtype_backward = get_fp8_te_dtype(
            ctx.fp8_meta["recipe"], fprop_tensor=False
        )

        # FP8 case with non-FP8 wgrad
        if (
            gather_grad_output
            and ctx.fp8_meta["recipe"].override_linear_precision.wgrad
        ):
            grad_output_mat, _ = gather_along_first_dim(grad_output_mat, ctx.tp_group)
        # FP8 case with gather: unfused bgrad, cast, transpose for efficient gather
        elif gather_grad_output:
            if ctx.use_bias:
                grad_bias = grad_output_mat.sum(dim=0)
            else:
                grad_bias = None
            grad_output_c = cast_to_fp8(
                grad_output_mat,
                ctx.fp8_meta["scaling_bwd"],
                tex.FP8BwdTensors.GRAD_OUTPUT1,
                fp8_dtype_backward,
            )
            grad_output_c, _ = gather_along_first_dim(grad_output_c, ctx.tp_group)
            grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)

            return grad_output_mat, grad_output_c, grad_output_t, grad_bias

        # FP8 case without gather: cast, transpose, bgrad fused
        if ctx.use_bias:
            grad_bias, grad_output_c, grad_output_t = fp8_cast_transpose_bgrad_fused(
                grad_output_mat,
                ctx.fp8_meta["scaling_bwd"],
                tex.FP8BwdTensors.GRAD_OUTPUT1,
                fp8_dtype_backward,
            )
        else:
            if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
                grad_output_c, grad_output_t = fp8_cast_transpose_fused(
                    grad_output_mat,
                    ctx.fp8_meta["scaling_bwd"],
                    tex.FP8BwdTensors.GRAD_OUTPUT1,
                    fp8_dtype_backward,
                )
            else:
                grad_output_t = None
                grad_output_c = cast_to_fp8(
                    grad_output_mat,
                    ctx.fp8_meta["scaling_bwd"],
                    tex.FP8BwdTensors.GRAD_OUTPUT1,
                    fp8_dtype_backward,
                )
            grad_bias = None

        return grad_output_mat, grad_output_c, grad_output_t, grad_bias

    def noop_cat(self, buffer_name: str, pnames: List[str]) -> torch.Tensor:
        """No-op replacement of `torch.cat`. The buffer and split parameters must occupy
           the same memory region. If this is not the case, then the split parameters
           are concatenated and the buffer is overwritten. The parameters' memory is then
           re-assigned to point to the buffer to avoid subsequent concatenations.
        """

        assert hasattr(self, buffer_name), f"No buffer named {buffer_name}"
        full_param_buffer = getattr(self, buffer_name)
        split_size = full_param_buffer.shape[0] // len(pnames)
        params = [getattr(self, name) for name in pnames]
        for i, p in enumerate(params):
            if p.data.data_ptr() != full_param_buffer[i*split_size : (i+1)*split_size].data_ptr():
                with torch.no_grad():
                    setattr(self, buffer_name, torch.cat(params))
                    # Fetch the new buffer once; parameters become views into it.
                    full_param_buffer = getattr(self, buffer_name)
                    for j, pname in enumerate(pnames):
                        setattr(self, pname,
                                Parameter(full_param_buffer[j*split_size : (j+1)*split_size]))
                break

        return _NoopCat.apply(getattr(self, buffer_name), *[getattr(self, name) for name in pnames])

    @abstractmethod
    def forward(self):
        """Needs override."""


class _LayerNormLinear(torch.autograd.Function):
    """LayerNormLinear semi-top level module
    Calls custom cuda extensions.

    Fuses a LayerNorm with the Linear (GEMM) that consumes its output. The
    forward pass can run in FP8 (``fp8=True``), can communicate across a
    tensor-parallel group (``parallel_mode`` / ``sequence_parallel`` /
    ``tensor_parallel``), and can additionally return the LayerNorm output.
    """

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        ln_weight: torch.Tensor,
        ln_bias: torch.Tensor,
        weight: torch.Tensor,
        weight_fp8: Union[torch.Tensor, None],
        weight_t_fp8: Union[torch.Tensor, None],
        bias: torch.Tensor,
        use_bias: bool,
        eps: float,
        is_first_microbatch: Union[bool, None],
        fp8: bool,
        fp8_calibration: bool,
        fp8_meta: Dict[str, Any],
        fuse_wgrad_accumulation: bool,
        tp_group: Union[dist_group_type, None],
        tp_size: int,
        sequence_parallel: bool,
        tensor_parallel: bool,
        activation_dtype: torch.dtype,
        parallel_mode: Union[str, None],
        return_layernorm_output: bool,
        is_grad_enabled: bool,
        fwd_ln_sm_margin: int,
        bwd_ln_sm_margin: int,
        zero_centered_gamma: bool,
    ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
        """Apply LayerNorm to ``inp`` and multiply the result by ``weight``.

        ``inp.shape[-1]`` must equal ``ln_weight.numel()``. All leading
        dimensions are flattened into one for the GEMM.

        In FP8 mode the FP8 copies of ``weight`` are refreshed only when
        ``is_first_microbatch`` is ``None`` or truthy. With grad enabled, the
        cast and its transpose are written into the provided ``weight_fp8`` /
        ``weight_t_fp8`` buffers. In inference, a fresh cast is made and no
        transpose is kept.

        Returns:
            ``out`` reshaped to ``[-1, *inp.shape[1:-1], out_features]``. If
            ``return_layernorm_output`` is set, returns ``(out, ln_out)`` where
            ``ln_out`` is viewed as ``inp``'s shape.
        """
        # Make sure input dimensions are compatible
        in_features = ln_weight.numel()
        assert inp.shape[-1] == in_features, "GEMM not possible"
        inputmat = inp.view((-1, in_features))
        assert (
            not fp8 or check_dim_for_fp8_forward_exec(inputmat, weight)
        ), "Input and weight dimensions are not compatible for FP8 execution."

        # Cached FP8 weights are reused on every microbatch except the first.
        update_fp8_weights = is_first_microbatch is None or is_first_microbatch

        # Cast for native AMP
        inputmat = cast_if_needed(inputmat, activation_dtype)
        ln_weight = cast_if_needed(ln_weight, activation_dtype)
        ln_bias = cast_if_needed(ln_bias, activation_dtype)

        # If residual connection is after LN, we need `ln_out`
        # tensor in higher precision, this comes at the cost
        # of an extra fp8 cast.
        if fp8:
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)

            if not return_layernorm_output:
                if is_grad_enabled:
                    ln_out, mu, rsigma = layernorm_fwd_fp8(
                        inputmat,
                        ln_weight,
                        ln_bias,
                        eps,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM1_INPUT,
                        fp8_dtype_forward,
                        fwd_ln_sm_margin,
                        zero_centered_gamma,
                    )
                else:
                    mu = rsigma = None
                    ln_out = layernorm_fwd_fp8_inf(
                        inputmat,
                        ln_weight,
                        ln_bias,
                        eps,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM1_INPUT,
                        fp8_dtype_forward,
                        zero_centered_gamma,
                    )
            else:
                if is_grad_enabled:
                    ln_out_return, mu, rsigma = tex.layernorm_fwd(
                        inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma
                    )
                else:
                    ln_out_return, mu, rsigma = layernorm_fwd_inf(
                        inputmat, ln_weight, ln_bias, eps, zero_centered_gamma
                    ), None, None

                ln_out = cast_to_fp8(
                    ln_out_return,
                    fp8_meta["scaling_fwd"],
                    tex.FP8FwdTensors.GEMM1_INPUT,
                    fp8_dtype_forward,
                )
        else:
            if is_grad_enabled:
                ln_out, mu, rsigma = tex.layernorm_fwd(
                    inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma
                )
            else:
                ln_out, mu, rsigma = layernorm_fwd_inf(
                    inputmat, ln_weight, ln_bias, eps, zero_centered_gamma
                ), None, None
            ln_out_return = ln_out

        # Column Parallel Linear
        if parallel_mode == "column" and sequence_parallel:
            ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
        else:
            ln_out_total = ln_out

        if fp8:
            bias_dtype = (
                torch.bfloat16
                if activation_dtype == torch.float32
                else activation_dtype
            )
            bias = cast_if_needed(bias, bias_dtype) if use_bias else bias

            if update_fp8_weights:
                if is_grad_enabled:
                    # Writes cast + transpose into the caller-provided buffers.
                    fp8_cast_transpose_fused(
                        weight,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM1_WEIGHT,
                        fp8_dtype_forward,
                        cast_out=weight_fp8,
                        transpose_out=weight_t_fp8,
                    )
                else:
                    # No backward pass, so the transpose is not needed.
                    weight_t_fp8 = None
                    weight_fp8 = cast_to_fp8(
                        weight,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM1_WEIGHT,
                        fp8_dtype_forward)

            out = fp8_gemm(
                weight_fp8,
                fp8_meta["scaling_fwd"].scale_inv,
                tex.FP8FwdTensors.GEMM1_WEIGHT,
                fp8_dtype_forward,
                ln_out_total,
                fp8_meta["scaling_fwd"].scale_inv,
                tex.FP8FwdTensors.GEMM1_INPUT,
                fp8_dtype_forward,
                activation_dtype,
                get_workspace(),
                bias=bias,
                use_bias=use_bias,
                use_split_accumulator=_2X_ACC_FPROP,
            )
        else:
            # Cast for native AMP
            weight = cast_if_needed(weight, activation_dtype)
            bias = cast_if_needed(bias, activation_dtype) if use_bias else bias

            if fp8_calibration:
                # The amax history is meant to hold the largest *magnitude*
                # (absolute maximum) seen. A plain `amax` would ignore large
                # negative values and under-estimate it, so take max(-min, max).
                # amax of input
                amin, amax = ln_out_total.aminmax()
                fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = \
                    torch.maximum(-amin, amax).float()
                # amax of weight
                amin, amax = weight.aminmax()
                fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = \
                    torch.maximum(-amin, amax).float()

            out, _, _ = gemm(
                weight,
                ln_out_total,
                activation_dtype,
                get_workspace(),
                bias=bias,
                use_bias=use_bias,
            )

        if is_grad_enabled:
            ctx.save_for_backward(
                inputmat,
                ln_weight,
                mu,
                rsigma,
                weight,
                weight_t_fp8,
                ln_out,
                # Snapshot of the forward scale-inverses used by the FP8 GEMMs
                # above, reused by the backward FP8 GEMMs.
                fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
            )

            ctx.activation_dtype = activation_dtype
            ctx.fp8 = fp8
            ctx.fp8_meta = fp8_meta
            ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
            ctx.is_first_microbatch = is_first_microbatch
            ctx.use_bias = use_bias
            ctx.sequence_parallel = sequence_parallel
            ctx.tensor_parallel = tensor_parallel
            ctx.inp_shape = inp.shape
            ctx.parallel_mode = parallel_mode
            ctx.tp_group = tp_group
            ctx.tp_size = tp_size
            ctx.return_layernorm_output = return_layernorm_output
            ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
            ctx.zero_centered_gamma = zero_centered_gamma
            ctx.requires_dgrad = inp.requires_grad

        # Row Parallel Linear
        if parallel_mode == "row" and sequence_parallel:
            out, _ = reduce_scatter_along_first_dim(out, tp_group)
        elif parallel_mode == "row" and tensor_parallel:
            out, _ = allreduce(out, tp_group)

        # [*, in_features] -> [*, out_features] except first dimension changes for SP
        out = out.view(-1, *inp.shape[1:-1], out.shape[-1])

        if return_layernorm_output:
            return out, ln_out_return.view_as(inp)
        return out

    @staticmethod
    def backward(
        ctx, *grad_outputs: Tuple[torch.Tensor, ...]
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Compute gradients for every ``forward`` input.

        ``grad_outputs[0]`` is the gradient of ``out``. When
        ``return_layernorm_output`` was set, ``grad_outputs[1]`` is the gradient
        of the returned LayerNorm output and is added to the LayerNorm
        gradient. Returns one entry per ``forward`` argument (25 in total).
        Non-differentiable arguments get ``None``.
        """
        with _prepare_backward(
            ctx.fp8, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_LayerNormLinear"
        ):
            (
                inputmat,
                ln_weight,
                mu,
                rsigma,
                weight,
                weight_t_fp8,
                ln_out,
                fwd_scale_inverses,
            ) = ctx.saved_tensors

            (
                grad_output,
                grad_output_c,
                grad_output_t,
                grad_bias,
            ) = TransformerEngineBaseModule.grad_output_preprocess(
                ctx, grad_outputs[0], ctx.parallel_mode == "row"
            )

            # Handle of the most recent async collective still in flight, if
            # any. Always bound so the final wait below is well-defined.
            handle = None

            # Column Parallel Linear
            # Overlap input AG with dgrad
            if ctx.parallel_mode == "column" and ctx.sequence_parallel:
                ln_out_total, handle = gather_along_first_dim(
                    ln_out, ctx.tp_group, async_op=True
                )
            else:
                ln_out_total = ln_out

            # With fused wgrad accumulation the weight gradient is written into
            # `weight.main_grad`. It is accumulated into it except on the first
            # microbatch, where it overwrites.
            if ctx.is_first_microbatch is not None:
                accumulate_wgrad_into_param_main_grad = (
                    ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
                )
            else:
                accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation

            if ctx.fp8:
                fp8_dtype_forward = get_fp8_te_dtype(
                    ctx.fp8_meta["recipe"], fprop_tensor=True
                )
                fp8_dtype_backward = get_fp8_te_dtype(
                    ctx.fp8_meta["recipe"], fprop_tensor=False
                )

                # DGRAD: Evaluated unconditionally to feed into Linear backward
                dgrad = fp8_gemm(
                    weight_t_fp8,
                    fwd_scale_inverses,
                    tex.FP8FwdTensors.GEMM1_WEIGHT,
                    fp8_dtype_forward,
                    grad_output_c,
                    ctx.fp8_meta["scaling_bwd"].scale_inv,
                    tex.FP8BwdTensors.GRAD_OUTPUT1,
                    fp8_dtype_backward,
                    ctx.activation_dtype,
                    get_workspace(),
                    use_split_accumulator=_2X_ACC_DGRAD,
                )
            else:
                # DGRAD: Evaluated unconditionally to feed into Linear backward
                dgrad, _, _ = gemm(
                    weight,
                    grad_output,
                    ctx.activation_dtype,
                    get_workspace(),
                    layout="NN",
                    grad=True,
                )

            # Overlap dgrad-RS/AR with wgrad
            if ctx.parallel_mode == "column" and ctx.sequence_parallel:
                # The input all-gather must finish before wgrad reads ln_out_total.
                handle.wait()
                dgrad, handle = reduce_scatter_along_first_dim(
                    dgrad, ctx.tp_group, async_op=True
                )
            elif ctx.parallel_mode == "column" and ctx.tensor_parallel:
                dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True)

            if weight.requires_grad:
                if ctx.fp8:
                    # WGRAD
                    if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
                        ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward)
                        wgrad = fp8_gemm(
                            ln_out_total_t,
                            fwd_scale_inverses,
                            tex.FP8FwdTensors.GEMM1_INPUT,
                            fp8_dtype_forward,
                            grad_output_t,
                            ctx.fp8_meta["scaling_bwd"].scale_inv,
                            tex.FP8BwdTensors.GRAD_OUTPUT1,
                            fp8_dtype_backward,
                            ctx.activation_dtype,
                            get_workspace(),
                            accumulate=accumulate_wgrad_into_param_main_grad,
                            out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                            use_split_accumulator=_2X_ACC_WGRAD,
                        )
                    else:
                        # The recipe asks for a higher-precision wgrad: de-quantize
                        # the saved FP8 LayerNorm output first.
                        ln_out_total_c = cast_from_fp8(
                            ln_out_total,
                            ctx.fp8_meta["scaling_fwd"],
                            tex.FP8FwdTensors.GEMM1_INPUT,
                            fp8_dtype_forward,
                            TE_DType[ctx.activation_dtype],
                        )
                        wgrad, _, _ = gemm(
                            ln_out_total_c,
                            grad_output,
                            ctx.activation_dtype,
                            get_workspace(),
                            layout="NT",
                            grad=True,
                            accumulate=accumulate_wgrad_into_param_main_grad,
                            out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                        )
                else:
                    # WGRAD (bias gradient produced by the same GEMM)
                    wgrad, grad_bias, _ = gemm(
                        ln_out_total,
                        grad_output,
                        ctx.activation_dtype,
                        get_workspace(),
                        layout="NT",
                        grad=True,
                        use_bias=ctx.use_bias,
                        accumulate=accumulate_wgrad_into_param_main_grad,
                        out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                    )

            # Column Parallel Linear: the dgrad collective must complete before
            # dgrad is consumed below. Never leave an async collective pending.
            if handle is not None:
                handle.wait()

            # LayerNorm gradient
            d_ln_out = dgrad.view(inputmat.shape)

            # Residual gradient
            if ctx.return_layernorm_output:
                d_ln_out = d_ln_out + grad_outputs[1].view_as(d_ln_out)

            dxmat, dgamma, dbeta = tex.layernorm_bwd(
                d_ln_out, inputmat, mu, rsigma, ln_weight,
                ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
            )

            if not ctx.use_bias:
                grad_bias = None

        return (
            dxmat.view(ctx.inp_shape) if ctx.requires_dgrad else None,  # inp
            dgamma,  # ln_weight
            dbeta,  # ln_bias
            wgrad if weight.requires_grad else None,  # weight
            None,  # weight_fp8
            None,  # weight_t_fp8
            grad_bias,  # bias
            None,  # use_bias
            None,  # eps
            None,  # is_first_microbatch
            None,  # fp8
            None,  # fp8_calibration
            None,  # fp8_meta
            None,  # fuse_wgrad_accumulation
            None,  # tp_group
            None,  # tp_size
            None,  # sequence_parallel
            None,  # tensor_parallel
            None,  # activation_dtype
            None,  # parallel_mode
            None,  # return_layernorm_output
            None,  # is_grad_enabled
            None,  # fwd_ln_sm_margin
            None,  # bwd_ln_sm_margin
            None,  # zero_centered_gamma
        )


class LayerNormLinear(TransformerEngineBaseModule):
    r"""
    Applies layer normalization followed by linear transformation to the incoming data.

    Parameters
    ----------
    in_features : int
                 size of each input sample.
    out_features : int
                  size of each output sample.
    eps : float, default = 1e-5
         a value added to the denominator of layer normalization for numerical stability.
    bias : bool, default = `True`
          if set to `False`, the layer will not learn an additive bias.
    init_method : Callable, default = `None`
                 used for initializing weights in the following way: `init_method(weight)`.
                 When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    return_layernorm_output : bool, default = `False`
                             if set to `True`, output of layernorm is returned from the forward
                             together with the output of the linear transformation.
                             Example use case: residual connection for transformer module is
                             taken post layernorm.
    parameters_split : Tuple[str, ...], default = None
                      if a tuple of strings is provided, the weight and bias parameters of the
                      module are exposed as `N` separate `torch.nn.parameter.Parameter`s each,
                      split along the first dimension, where `N` is the length of the argument
                      and the strings contained are the names of the split parameters.
    zero_centered_gamma : bool, default = `False`
                         if set to `True`, gamma parameter in LayerNorm is initialized to 0 and
                         the LayerNorm formula changes to

                         .. math::
                            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
                            (1 + \gamma) + \beta

    Parallelism parameters
    ----------------------
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    tp_size : int, default = 1
             used as TP (tensor parallel) world size when TP groups are not formed during
             initialization. In this case, users must call the
             `set_tensor_parallel_group(tp_group)` method on the initialized module before the
             forward pass to supply the tensor parallel group needed for tensor and sequence
             parallel collectives.
    parallel_mode : {None, 'column', 'row'}, default = `None`
                   used to decide whether this Linear layer is Column Parallel Linear or Row
                   Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
                   When set to `None`, no communication is performed.
    skip_weight_param_allocation: bool, default = `False`
                                 if set to `True`, weight parameter is not allocated and must be
                                 passed as a keyword argument `weight` during the forward pass.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = `False`
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient.
    return_bias : bool, default = `False`
                 when set to `True`, this module will not apply the additive bias itself, but
                 instead return the bias value during the forward pass together with the
                 output of the linear transformation :math:`y = xA^T`. This is useful when
                 the bias addition can be fused to subsequent operations.
    params_dtype : torch.dtype, default = `torch.float32`
                  it controls the type used to allocate the initial parameters. Useful when
                  the model is trained with lower precision and the original FP32 parameters
                  would not fit in GPU memory.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        eps: float = 1e-5,
        sequence_parallel: bool = False,
        fuse_wgrad_accumulation: bool = False,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        get_rng_state_tracker: Optional[Callable] = None,
        init_method: Optional[Callable] = None,
        bias: bool = True,
        return_bias: bool = False,
        params_dtype: torch.dtype = torch.float32,
        parallel_mode: Optional[str] = None,
        return_layernorm_output: bool = False,
        skip_weight_param_allocation: bool = False,
        parameters_split: Optional[Tuple[str, ...]] = None,
        zero_centered_gamma: bool = False,
    ) -> None:
        """Build the LayerNorm parameters and, unless skipped, the Linear weight/bias.

        Argument semantics are documented on the class. Parameters are
        registered in this order: LayerNorm weight, LayerNorm bias, then for
        each split ``<prefix>weight`` followed by ``<prefix>bias``.
        """
        super().__init__()

        # -- plain configuration ------------------------------------------------
        self.in_features = in_features
        self.out_features = out_features
        self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
        self.use_bias = bias
        self.return_bias = return_bias
        # The module adds the bias itself only if it is not handed back to the caller.
        self.apply_bias = bias and not return_bias
        self.return_layernorm_output = return_layernorm_output
        self.parameters_split = parameters_split
        self.zero_centered_gamma = zero_centered_gamma

        # -- tensor-parallel group / world size ----------------------------------
        if tp_group is not None:
            self.tp_size = get_distributed_world_size(tp_group)
            self.set_tensor_parallel_group(tp_group)
        else:
            # No group supplied: take the world size on trust. Only the trivial
            # single-rank case registers a (None) group right away.
            self.tp_size = tp_size
            if tp_size == 1:
                self.set_tensor_parallel_group(tp_group)
        self.set_nccl_overlap_warning_if_tp()

        self.parallel_mode = parallel_mode
        assert (
            self.parallel_mode in GemmParallelModes
        ), f"parallel_mode {parallel_mode} not supported"

        # Each rank keeps only its shard of the partitioned GEMM dimension.
        if self.parallel_mode == "column":
            self.out_features = divide(self.out_features, self.tp_size)
        elif self.parallel_mode == "row":
            self.in_features = divide(self.in_features, self.tp_size)

        init_method = get_default_init_method() if init_method is None else init_method

        self.sequence_parallel = (self.tp_size > 1) and sequence_parallel

        # -- LayerNorm affine parameters (full, unsharded in_features) -----------
        device = torch.cuda.current_device()
        self.eps = eps
        self.layer_norm_weight = Parameter(
            torch.empty(in_features, device=device, dtype=params_dtype)
        )
        self.layer_norm_bias = Parameter(
            torch.empty(in_features, device=device, dtype=params_dtype)
        )
        for ln_param in (self.layer_norm_weight, self.layer_norm_bias):
            setattr(ln_param, "sequence_parallel", self.sequence_parallel)
        self.reset_layer_norm_parameters()

        # -- Linear weight / bias ------------------------------------------------
        if not skip_weight_param_allocation:
            # Contiguous, non-persistent storage; the exposed (possibly split)
            # parameters below are views into it.
            self.register_buffer(
                "weight_tensor",
                torch.empty(
                    self.out_features,
                    self.in_features,
                    device=device,
                    dtype=params_dtype,
                ),
                persistent=False,
            )
            initialize_affine_weight_gpu(
                self.weight_tensor,
                init_method,
                get_rng_state_tracker,
                partition_dim=1 if self.parallel_mode == "row" else 0,
                stride=1,
            )

            if self.use_bias:
                bias_storage = torch.empty(
                    self.out_features, device=device, dtype=params_dtype
                )
            else:
                bias_storage = torch.Tensor().type(params_dtype)
            self.register_buffer("bias_tensor", bias_storage, persistent=False)

            with torch.no_grad():
                self.bias_tensor.zero_()

            prefixes = ("",) if parameters_split is None else parameters_split
            num_splits = len(prefixes)

            assert (
                self.out_features % num_splits == 0
            ), f"Weight and bias params cannot be split into {num_splits} parts"

            rows_per_split = self.out_features // num_splits

            self.weight_names = []
            self.bias_names = []

            for idx, prefix in enumerate(prefixes):
                rows = slice(idx * rows_per_split, (idx + 1) * rows_per_split)
                wname = prefix + "weight"
                bname = prefix + "bias"

                self.register_parameter(wname, Parameter(self.weight_tensor[rows]))
                set_tensor_model_parallel_attributes(
                    tensor=getattr(self, wname),
                    is_parallel=True,
                    dim=1 if parallel_mode == "row" else 0,
                    stride=1,
                )

                if self.use_bias:
                    self.register_parameter(bname, Parameter(self.bias_tensor[rows]))
                else:
                    self.register_buffer(bname, torch.Tensor().type(params_dtype), persistent=False)

                if parallel_mode == "column":
                    set_tensor_model_parallel_attributes(getattr(self, bname), True, 0, 1)

                self.weight_names.append(wname)
                self.bias_names.append(bname)

        self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features)))

        # For RPL, bias has to be added after TP collectives
        # So it cannot be fused with the GEMM
        self.gemm_bias_unfused_add = bool(self.parallel_mode == "row" and self.apply_bias)

        # SM counts subtracted from the device total when the forward/backward
        # LayerNorm kernels are launched (e.g. to leave SMs free for
        # communication overlap). Read once from the environment.
        self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
        self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))

    def reset_layer_norm_parameters(self) -> None:
        """Init LN params.

        ``layer_norm_bias`` is always zeroed. ``layer_norm_weight`` is filled
        with ones, or with zeros when ``zero_centered_gamma`` is set, because
        the LayerNorm then scales by ``(1 + gamma)``.
        """
        fill_gamma = init.zeros_ if self.zero_centered_gamma else init.ones_
        fill_gamma(self.layer_norm_weight)
        init.zeros_(self.layer_norm_bias)

    def forward(
        self,
        inp: torch.Tensor,
        weight: Optional[torch.Tensor] = None,
        bias: Optional[torch.Tensor] = None,
        is_first_microbatch: Optional[bool] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        """
        Apply layer normalization to the input followed by a linear transformation.

        Parameters
        ----------
        inp : torch.Tensor
             Input tensor.
        weight : torch.Tensor, default = None
                An optional weight tensor for the module. This argument is compulsory if module
                is initialized with `skip_weight_param_allocation=True`
        bias : torch.Tensor, default = None
              An optional bias tensor for the module. This argument is compulsory if module
              is initialized with `skip_weight_param_allocation=True` and one of `use_bias`
              or `return_bias`
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)

        Returns
        -------
        The linear output, optionally followed by the bias (when `return_bias`) and then
        the LayerNorm output (when `return_layernorm_output`).
        """

        with self.prepare_forward(inp, is_first_microbatch) as inp:
            # Tensor selection priority: explicit argument > the single `bias`/`weight`
            # parameter (no split) > the contiguous backing buffer when autograd is off >
            # `noop_cat` over the split parameters (helper defined elsewhere; presumably
            # reassembles the split Parameters into one tensor for autograd — TODO confirm).
            bias_tensor = (
                bias if bias is not None
                else self.bias if self.parameters_split is None
                else self.bias_tensor if not torch.is_grad_enabled()
                else self.noop_cat("bias_tensor", self.bias_names)
            )
            weight_tensor = (
                weight if weight is not None
                else self.weight if self.parameters_split is None
                else self.weight_tensor if not torch.is_grad_enabled()
                else self.noop_cat("weight_tensor", self.weight_names)
            )

            # With autograd disabled, call the Function's forward directly with
            # ctx=None so no autograd graph node is built.
            if torch.is_grad_enabled():
                fwd_fn = _LayerNormLinear.apply
                args = []
            else:
                fwd_fn = _LayerNormLinear.forward
                args = [None]
            # Positional order must match _LayerNormLinear.forward exactly.
            args += (
                inp,
                self.layer_norm_weight,
                self.layer_norm_bias,
                weight_tensor,
                self.weight1_fp8 if self.fp8 else None,
                self.weight1_t_fp8 if self.fp8 else None,
                bias_tensor,
                self.apply_bias and not self.gemm_bias_unfused_add,
                self.eps,
                is_first_microbatch,
                self.fp8,
                self.fp8_calibration,
                self.fp8_meta,
                self.fuse_wgrad_accumulation,
                self.tp_group,
                self.tp_size,
                self.sequence_parallel,
                self.tp_size > 1,
                self.activation_dtype,
                self.parallel_mode,
                self.return_layernorm_output,
                torch.is_grad_enabled(),
                self.fwd_ln_sm_margin,
                self.bwd_ln_sm_margin,
                self.zero_centered_gamma,
            )
            out = fwd_fn(*args)

        if self.return_layernorm_output:
            out, ln_out = out

        # Row-parallel: the bias is added here, after the TP collective inside
        # the Function, instead of being fused into the GEMM.
        if self.gemm_bias_unfused_add:
            out = out + cast_if_needed(bias_tensor, self.activation_dtype)

        if self.return_bias:
            if self.return_layernorm_output:
                return out, cast_if_needed(bias_tensor, self.activation_dtype), ln_out
            return out, cast_if_needed(bias_tensor, self.activation_dtype)
        if self.return_layernorm_output:
            return out, ln_out
        return out

class _Linear(torch.autograd.Function):
    """Linear semi-top level module
    Calls custom cuda extensions.

    ``Linear.forward`` invokes ``_Linear.forward`` directly with ``ctx=None``
    when autograd is disabled; every ``ctx`` access is therefore guarded by
    ``is_grad_enabled``.
    """

    @staticmethod
    def forward(
        ctx,
        weight: torch.Tensor,
        weight_fp8: Union[torch.Tensor, None],
        weight_t_fp8: Union[torch.Tensor, None],
        inp: torch.Tensor,
        bias: torch.Tensor,
        use_bias: bool,
        is_first_microbatch: Union[bool, None],
        fp8: bool,
        fp8_calibration: bool,
        fp8_meta: Dict[str, Any],
        fuse_wgrad_accumulation: bool,
        tp_group: Union[dist_group_type, None],
        tp_size: int,
        sequence_parallel: bool,
        tensor_parallel: bool,
        activation_dtype: torch.dtype,
        parallel_mode: Union[str, None],
        is_grad_enabled: bool,
    ) -> torch.Tensor:
        """Compute ``inp @ weight.T (+ bias)`` with optional FP8 and TP/SP collectives.

        Column-parallel + sequence-parallel all-gathers the input along the
        first dimension before the GEMM. Row-parallel reduce-scatters (SP) or
        all-reduces (TP) the output after the GEMM. The result is reshaped to
        ``[-1, *inp.shape[1:-1], out_features]``.
        """
        # Make sure input dimensions are compatible
        in_features = weight.shape[-1]
        assert inp.shape[-1] == in_features, "GEMM not possible"
        inputmat = inp.view((-1, in_features))
        assert (
            not fp8 or check_dim_for_fp8_forward_exec(inputmat, weight)
        ), "Input and weight dimensions are not compatible for FP8 execution."

        # FP8 weight casts are refreshed on the first microbatch (or always when
        # microbatching is not in use). Otherwise the caller-provided
        # weight_fp8 / weight_t_fp8 buffers are used as-is.
        update_fp8_weights = is_first_microbatch is None or is_first_microbatch

        # Cast for native AMP
        inputmat = cast_if_needed(inputmat, activation_dtype)
        # Higher-precision copy, kept for the non-FP8 wgrad path in backward.
        inputmat_no_fp8 = inputmat

        if fp8:
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)

            if not fp8_meta["recipe"].override_linear_precision.wgrad:
                if is_grad_enabled:
                    # The transpose is only needed for the FP8 wgrad GEMM in backward.
                    inputmat, inputmat_t = fp8_cast_transpose_fused(
                        inputmat,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM1_INPUT,
                        fp8_dtype_forward,
                    )
                else:
                    inputmat = cast_to_fp8(
                        inputmat,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM1_INPUT,
                        fp8_dtype_forward,
                    )
            else:
                inputmat, inputmat_t = cast_to_fp8(
                    inputmat,
                    fp8_meta["scaling_fwd"],
                    tex.FP8FwdTensors.GEMM1_INPUT,
                    fp8_dtype_forward,
                ), None

        # Column Parallel Linear
        if parallel_mode == "column" and sequence_parallel:
            inputmat_total, _ = gather_along_first_dim(inputmat, tp_group)
        else:
            inputmat_total = inputmat

        if fp8:
            bias_dtype = (
                torch.bfloat16
                if activation_dtype == torch.float32
                else activation_dtype
            )
            bias = cast_if_needed(bias, bias_dtype) if use_bias else bias

            if update_fp8_weights:
                if is_grad_enabled:
                    # Writes into the cached weight_fp8 / weight_t_fp8 buffers.
                    fp8_cast_transpose_fused(
                        weight,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM1_WEIGHT,
                        fp8_dtype_forward,
                        cast_out=weight_fp8,
                        transpose_out=weight_t_fp8,
                    )
                else:
                    weight_t_fp8 = None
                    weight_fp8 = cast_to_fp8(
                        weight,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM1_WEIGHT,
                        fp8_dtype_forward,
                    )

            # Use the (possibly all-gathered) input, as the non-FP8 path does.
            # Previously the un-gathered local shard was used here.
            out = fp8_gemm(
                weight_fp8,
                fp8_meta["scaling_fwd"].scale_inv,
                tex.FP8FwdTensors.GEMM1_WEIGHT,
                fp8_dtype_forward,
                inputmat_total,
                fp8_meta["scaling_fwd"].scale_inv,
                tex.FP8FwdTensors.GEMM1_INPUT,
                fp8_dtype_forward,
                activation_dtype,
                get_workspace(),
                bias=bias,
                use_bias=use_bias,
                use_split_accumulator=_2X_ACC_FPROP,
            )
        else:
            # Cast for native AMP
            weight = cast_if_needed(weight, activation_dtype)
            bias = cast_if_needed(bias, activation_dtype) if use_bias else bias

            if fp8_calibration:
                # amax is the maximum *absolute* value. Using max(-min, max)
                # avoids materializing an abs() copy of the tensor.
                # amax of input
                fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = \
                    torch.max(-torch.amin(inputmat_total), torch.amax(inputmat_total)).float()
                # amax of weight
                fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = \
                    torch.max(-torch.amin(weight), torch.amax(weight)).float()

            out, _, _ = gemm(
                weight,
                inputmat_total,
                activation_dtype,
                get_workspace(),
                bias=bias,
                use_bias=use_bias,
            )

        if is_grad_enabled:
            fp8_wgrad = fp8 and not fp8_meta["recipe"].override_linear_precision.wgrad
            # Save only the input flavour that the selected wgrad path will consume.
            ctx.save_for_backward(
                inputmat_no_fp8 if weight.requires_grad and not fp8_wgrad else None,
                inputmat_t if weight.requires_grad and fp8_wgrad else None,
                weight,
                weight_t_fp8 if fp8 else None,
                fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
            )
            ctx.activation_dtype = activation_dtype
            ctx.fp8 = fp8
            ctx.fp8_meta = fp8_meta
            ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
            ctx.is_first_microbatch = is_first_microbatch
            ctx.use_bias = use_bias
            ctx.sequence_parallel = sequence_parallel
            ctx.tensor_parallel = tensor_parallel
            ctx.inp_shape = inp.shape
            ctx.parallel_mode = parallel_mode
            ctx.tp_group = tp_group
            ctx.tp_size = tp_size
            ctx.requires_dgrad = inp.requires_grad

        # Row Parallel Linear
        if parallel_mode == "row" and sequence_parallel:
            out, _ = reduce_scatter_along_first_dim(out, tp_group)
        elif parallel_mode == "row" and tensor_parallel:
            out, _ = allreduce(out, tp_group)

        # [*, in_features] -> [*, out_features] except first dimension changes for SP
        return out.view(-1, *inp.shape[1:-1], out.shape[-1])

    @staticmethod
    def backward(
        ctx, grad_output: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Compute dgrad (input grad), wgrad (weight grad) and the bias grad.

        Returns one entry per ``forward`` argument; only ``weight``, ``inp``
        and ``bias`` receive gradients.
        """
        with _prepare_backward(
            ctx.fp8, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_Linear"
        ):
            (
                inputmat,
                inputmat_t,
                weight,
                weight_t_fp8,
                fwd_scale_inverses,
            ) = ctx.saved_tensors

            (
                grad_output,
                grad_output_c,
                grad_output_t,
                grad_bias,
            ) = TransformerEngineBaseModule.grad_output_preprocess(
                ctx, grad_output, ctx.parallel_mode == "row"
            )

            # Pending async communication handle, if any. It must be bound on every
            # path: the final wait below checks it whenever mode is column with
            # TP, even when no dgrad was computed and no collective was issued.
            handle = None

            # Column Parallel Linear
            # Overlap input AG with dgrad
            if ctx.parallel_mode == "column" and ctx.sequence_parallel:
                if ctx.fp8 and not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
                    inputmat_t_total, handle = gather_along_last_dim(
                        inputmat_t, ctx.tp_group, async_op=ctx.requires_dgrad
                    )
                else:
                    inputmat_total, handle = gather_along_first_dim(
                        inputmat, ctx.tp_group, async_op=ctx.requires_dgrad
                    )
            else:
                inputmat_t_total = inputmat_t
                inputmat_total = inputmat

            # On the first microbatch the fused main_grad buffer is overwritten
            # rather than accumulated into.
            if ctx.is_first_microbatch is not None:
                accumulate_wgrad_into_param_main_grad = (
                    ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
                )
            else:
                accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation

            if ctx.fp8:
                fp8_dtype_forward = get_fp8_te_dtype(
                    ctx.fp8_meta["recipe"], fprop_tensor=True
                )
                fp8_dtype_backward = get_fp8_te_dtype(
                    ctx.fp8_meta["recipe"], fprop_tensor=False
                )

            if ctx.requires_dgrad:
                if ctx.fp8:
                    dgrad = fp8_gemm(
                        weight_t_fp8,
                        fwd_scale_inverses,
                        tex.FP8FwdTensors.GEMM1_WEIGHT,
                        fp8_dtype_forward,
                        grad_output_c,
                        ctx.fp8_meta["scaling_bwd"].scale_inv,
                        tex.FP8BwdTensors.GRAD_OUTPUT1,
                        fp8_dtype_backward,
                        ctx.activation_dtype,
                        get_workspace(),
                        use_split_accumulator=_2X_ACC_DGRAD,
                    )
                else:
                    dgrad, _, _ = gemm(
                        weight,
                        grad_output,
                        ctx.activation_dtype,
                        get_workspace(),
                        layout="NN",
                        grad=True,
                    )

                # Overlap dgrad-RS/AR with wgrad
                if ctx.parallel_mode == "column" and ctx.sequence_parallel:
                    # The input all-gather must finish before wgrad consumes it.
                    handle.wait()
                    dgrad, handle = reduce_scatter_along_first_dim(
                        dgrad, ctx.tp_group, async_op=True
                    )
                elif ctx.parallel_mode == "column" and ctx.tensor_parallel:
                    dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True)

            if weight.requires_grad:
                if ctx.fp8:
                    # WGRAD
                    if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
                        wgrad = fp8_gemm(
                            inputmat_t_total,
                            fwd_scale_inverses,
                            tex.FP8FwdTensors.GEMM1_INPUT,
                            fp8_dtype_forward,
                            grad_output_t,
                            ctx.fp8_meta["scaling_bwd"].scale_inv,
                            tex.FP8BwdTensors.GRAD_OUTPUT1,
                            fp8_dtype_backward,
                            ctx.activation_dtype,
                            get_workspace(),
                            accumulate=accumulate_wgrad_into_param_main_grad,
                            out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                            use_split_accumulator=_2X_ACC_WGRAD,
                        )
                    else:
                        wgrad, _, _ = gemm(
                            inputmat_total,
                            grad_output,
                            ctx.activation_dtype,
                            get_workspace(),
                            layout="NT",
                            grad=True,
                            accumulate=accumulate_wgrad_into_param_main_grad,
                            out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                        )
                else:
                    # WGRAD (the bias grad is produced by the same GEMM)
                    wgrad, grad_bias, _ = gemm(
                        inputmat_total,
                        grad_output,
                        ctx.activation_dtype,
                        get_workspace(),
                        layout="NT",
                        grad=True,
                        use_bias=ctx.use_bias,
                        accumulate=accumulate_wgrad_into_param_main_grad,
                        out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                    )

            # Column Parallel Linear: wait for the outstanding dgrad RS/AR.
            if ctx.parallel_mode == "column" and ctx.tensor_parallel and handle is not None:
                handle.wait()

            if not ctx.use_bias:
                grad_bias = None

        return (
            wgrad if weight.requires_grad else None,  # weight
            None,  # weight_fp8
            None,  # weight_t_fp8
            dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,  # inp
            grad_bias,  # bias
            None,  # use_bias
            None,  # is_first_microbatch
            None,  # fp8
            None,  # fp8_calibration
            None,  # fp8_meta
            None,  # fuse_wgrad_accumulation
            None,  # tp_group
            None,  # tp_size
            None,  # sequence_parallel
            None,  # tensor_parallel
            None,  # activation_dtype
            None,  # parallel_mode
            None,  # is_grad_enabled
        )


class Linear(TransformerEngineBaseModule):
    """
    Applies a linear transformation to the incoming data :math:`y = xA^T + b`

    On NVIDIA GPUs it is a drop-in replacement for `torch.nn.Linear`.

    Parameters
    ----------
    in_features : int
                 size of each input sample.
    out_features : int
                  size of each output sample.
    bias : bool, default = `True`
          if set to `False`, the layer will not learn an additive bias.
    init_method : Callable, default = `None`
                 used for initializing weights in the following way: `init_method(weight)`.
                 When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    get_rng_state_tracker : Callable, default = `None`
                           forwarded to the weight initializer; presumably returns an RNG
                           state tracker used to make tensor-parallel initialization
                           reproducible.
    parameters_split : Tuple[str, ...], default = None
                      if a tuple of strings is provided, the weight and bias parameters of the
                      module are exposed as `N` separate `torch.nn.parameter.Parameter`s each,
                      split along the first dimension, where `N` is the length of the argument
                      and the strings contained are the names of the split parameters.

    Parallelism parameters
    ----------------------
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    tp_size : int, default = 1
             used as TP (tensor parallel) world size when TP groups are not formed during
             initialization. In this case, users must call the
             `set_tensor_parallel_group(tp_group)` method on the initialized module before the
             forward pass to supply the tensor parallel group needed for tensor and sequence
             parallel collectives.
    parallel_mode : {None, 'column', 'row'}, default = `None`
                   used to decide whether this Linear layer is Column Parallel Linear or Row
                   Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
                   When set to `None`, no communication is performed.
    skip_weight_param_allocation: bool, default = `False`
                                 if set to `True`, weight parameter is not allocated and must be
                                 passed as a keyword argument `weight` during the forward pass.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = `False`
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient. When enabled, it is assumed that the weights
                             have an additional `main_grad` attribute (used instead of the
                             regular `grad`) which is a pre-allocated buffer of the correct
                             size to accumulate gradients in.
    return_bias : bool, default = `False`
                 when set to `True`, this module will not apply the additive bias itself, but
                 instead return the bias value during the forward pass together with the
                 output of the linear transformation :math:`y = xA^T`. This is useful when
                 the bias addition can be fused to subsequent operations.
    params_dtype : torch.dtype, default = `torch.float32`
                  it controls the type used to allocate the initial parameters. Useful when
                  the model is trained with lower precision and the original FP32 parameters
                  would not fit in GPU memory.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        sequence_parallel: bool = False,
        fuse_wgrad_accumulation: bool = False,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        get_rng_state_tracker: Optional[Callable] = None,
        init_method: Optional[Callable] = None,
        bias: bool = True,
        return_bias: bool = False,
        params_dtype: torch.dtype = torch.float32,
        parallel_mode: Optional[str] = None,
        skip_weight_param_allocation: bool = False,
        parameters_split: Optional[Tuple[str, ...]] = None,
    ) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
        self.use_bias = bias
        self.return_bias = return_bias
        # The bias is applied inside this module only when it is not handed back to the caller.
        self.apply_bias = bias and not return_bias
        self.parameters_split = parameters_split

        if tp_group is None:
            self.tp_size = tp_size
            if tp_size == 1:
                self.set_tensor_parallel_group(tp_group)
        else:
            self.tp_size = get_distributed_world_size(tp_group)
            self.set_tensor_parallel_group(tp_group)
        self.set_nccl_overlap_warning_if_tp()

        self.parallel_mode = parallel_mode
        assert (
            self.parallel_mode in GemmParallelModes
        ), f"parallel_mode {parallel_mode} not supported"

        # Each TP rank holds a shard: output features for column-parallel,
        # input features for row-parallel.
        if self.parallel_mode == "column":
            self.out_features = divide(self.out_features, self.tp_size)
        elif self.parallel_mode == "row":
            self.in_features = divide(self.in_features, self.tp_size)

        if init_method is None:
            init_method = get_default_init_method()

        # Sequence parallelism is meaningless without more than one TP rank.
        self.sequence_parallel = (self.tp_size > 1) and sequence_parallel

        if not skip_weight_param_allocation:
            # Contiguous non-persistent backing storage; the Parameters registered
            # below are slices (views) of it.
            self.register_buffer("weight_tensor",
                                 torch.empty(
                                    self.out_features,
                                    self.in_features,
                                    device=torch.cuda.current_device(),
                                    dtype=params_dtype),
                                 persistent=False)

            initialize_affine_weight_gpu(
                self.weight_tensor,
                init_method,
                get_rng_state_tracker,
                partition_dim=1 if self.parallel_mode == "row" else 0,
                stride=1,
            )

            if self.use_bias:
                self.register_buffer("bias_tensor",
                                     torch.empty(
                                         self.out_features,
                                         device=torch.cuda.current_device(),
                                         dtype=params_dtype),
                                     persistent=False)
            else:
                self.register_buffer(
                    "bias_tensor", torch.Tensor().type(params_dtype), persistent=False
                )

            with torch.no_grad():
                self.bias_tensor.zero_()

            # No split requested: expose a single `weight` / `bias` pair.
            if parameters_split is None:
                parameters_split = ("",)

            assert (
                self.out_features % len(parameters_split) == 0
            ), f"Weight and bias params cannot be split into {len(parameters_split)} parts"

            split_size = self.out_features // len(parameters_split)

            self.weight_names = []
            self.bias_names = []

            for i, pname in enumerate(parameters_split):
                wname = pname + "weight"
                bname = pname + "bias"

                self.register_parameter(
                    wname, Parameter(self.weight_tensor[i * split_size : (i+1) * split_size])
                )

                set_tensor_model_parallel_attributes(
                    tensor=getattr(self, wname),
                    is_parallel=True,
                    dim=1 if parallel_mode == "row" else 0,
                    stride=1,
                )

                if self.use_bias:
                    self.register_parameter(
                        bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size])
                    )
                else:
                    self.register_buffer(bname, torch.Tensor().type(params_dtype), persistent=False)

                if parallel_mode == "column":
                    set_tensor_model_parallel_attributes(getattr(self, bname), True, 0, 1)

                self.weight_names.append(wname)
                self.bias_names.append(bname)

        self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features)))

        # For RPL, bias has to be added after TP collectives
        # So it cannot be fused with the GEMM
        if self.parallel_mode == "row" and self.apply_bias:
            self.gemm_bias_unfused_add = True
        else:
            self.gemm_bias_unfused_add = False

    def forward(
        self,
        inp: torch.Tensor,
        weight: Optional[torch.Tensor] = None,
        bias: Optional[torch.Tensor] = None,
        is_first_microbatch: Optional[bool] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        """
        Apply the linear transformation to the input.

        Parameters
        ----------
        inp : torch.Tensor
             Input tensor.
        weight : torch.Tensor, default = None
                An optional weight tensor for the module. This argument is compulsory if module
                is initialized with `skip_weight_param_allocation=True`
        bias : torch.Tensor, default = None
              An optional bias tensor for the module. This argument is compulsory if module
              is initialized with `skip_weight_param_allocation=True` and one of `use_bias`
              or `return_bias`
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)

        Returns
        -------
        The output tensor, or ``(output, bias)`` when `return_bias` is set.
        """

        with self.prepare_forward(inp, is_first_microbatch) as inp:
            # Tensor selection priority: explicit argument > the single `bias`/`weight`
            # parameter (no split) > the contiguous backing buffer when autograd is off >
            # `noop_cat` over the split parameters (helper defined elsewhere; presumably
            # reassembles the split Parameters into one tensor for autograd — TODO confirm).
            bias_tensor = (
                bias if bias is not None
                else self.bias if self.parameters_split is None
                else self.bias_tensor if not torch.is_grad_enabled()
                else self.noop_cat("bias_tensor", self.bias_names)
            )
            weight_tensor = (
                weight if weight is not None
                else self.weight if self.parameters_split is None
                else self.weight_tensor if not torch.is_grad_enabled()
                else self.noop_cat("weight_tensor", self.weight_names)
            )

            # With autograd disabled, call the Function's forward directly with
            # ctx=None so no autograd graph node is built.
            if torch.is_grad_enabled():
                linear_fn = _Linear.apply
                args = []
            else:
                linear_fn = _Linear.forward
                args = [None]
            # Positional order must match _Linear.forward exactly.
            args += (
                weight_tensor,
                self.weight1_fp8 if self.fp8 else None,
                self.weight1_t_fp8 if self.fp8 else None,
                inp,
                bias_tensor,
                self.apply_bias and not self.gemm_bias_unfused_add,
                is_first_microbatch,
                self.fp8,
                self.fp8_calibration,
                self.fp8_meta,
                self.fuse_wgrad_accumulation,
                self.tp_group,
                self.tp_size,
                self.sequence_parallel,
                self.tp_size > 1,
                self.activation_dtype,
                self.parallel_mode,
                torch.is_grad_enabled(),
            )
            out = linear_fn(*args)

        # Row-parallel: the bias is added after the TP collective, outside the GEMM.
        if self.gemm_bias_unfused_add:
            out = out + cast_if_needed(bias_tensor, self.activation_dtype)

        if self.return_bias:
            return out, cast_if_needed(bias_tensor, self.activation_dtype)
        return out


class _LayerNormMLP(torch.autograd.Function):
    """LayerNormMLP semi-top level module
    Calls custom cuda extensions.

    Forward computes LayerNorm -> FC1 (column parallel) -> GeLU -> FC2 (row
    parallel). It runs either through FP8 GEMMs (``fp8=True``) or through
    GEMMs in ``activation_dtype``. Backward consumes the tensors stashed with
    ``ctx.save_for_backward`` in ``forward``.
    """

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        ln_weight: torch.Tensor,
        ln_bias: torch.Tensor,
        fc1_weight: torch.Tensor,
        fc1_weight_fp8: Union[torch.Tensor, None],
        fc1_weight_t_fp8: Union[torch.Tensor, None],
        fc1_bias: torch.Tensor,
        use_fc1_bias: bool,
        fc2_weight: torch.Tensor,
        fc2_weight_fp8: Union[torch.Tensor, None],
        fc2_weight_t_fp8: Union[torch.Tensor, None],
        fc2_bias: torch.Tensor,
        use_fc2_bias: bool,
        eps: float,
        is_first_microbatch: Union[bool, None],
        fp8: bool,
        fp8_calibration: bool,
        fp8_meta: Dict[str, Any],
        fuse_wgrad_accumulation: bool,
        tp_group: Union[dist_group_type, None],
        tp_size: int,
        sequence_parallel: bool,
        tensor_parallel: bool,
        activation_dtype: torch.dtype,
        return_layernorm_output: bool,
        bias_gelu_nvfusion: bool,
        set_parallel_mode: bool,
        is_grad_enabled: bool,
        fwd_ln_sm_margin: int,
        bwd_ln_sm_margin: int,
        zero_centered_gamma: bool,
    ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
        """Run LayerNorm followed by the two-layer GeLU MLP.

        The input is flattened to ``(-1, in_features)``, where ``in_features``
        is ``ln_weight.numel()``.

        When ``fp8`` is set, the FP8 weight copies are rebuilt from the
        high-precision weights only if ``is_first_microbatch`` is ``None`` or
        truthy. Otherwise the ``*_weight_fp8`` / ``*_weight_t_fp8`` tensors
        passed in are used as-is. With grad enabled they are handed to
        ``fp8_cast_transpose_fused`` as ``cast_out`` / ``transpose_out``.

        When ``fp8_calibration`` is set (non-FP8 path only), the amax of the
        GEMM inputs and weights is written into
        ``fp8_meta["scaling_fwd"].amax_history[0]``.

        Returns:
            ``fc2_out`` viewed as ``(-1, *inp.shape[1:-1], out_features)``.
            When ``return_layernorm_output`` is set, returns the tuple
            ``(fc2_out, layernorm_output.view_as(inp))`` instead.
        """
        # Make sure input dimensions are compatible
        in_features = ln_weight.numel()
        assert inp.shape[-1] == in_features, "GEMM not possible"
        inputmat = inp.view((-1, in_features))
        assert (
            not fp8 or check_dim_for_fp8_forward_exec(inputmat, fc1_weight, fc2_weight)
        ), "Input and weight dimensions are not compatible for FP8 execution."

        # FP8 weight copies are refreshed on the first microbatch (or always
        # when microbatching is not in use, i.e. is_first_microbatch is None).
        update_fp8_weights = is_first_microbatch is None or is_first_microbatch

        # Cast for native AMP
        inputmat = cast_if_needed(inputmat, activation_dtype)
        ln_weight = cast_if_needed(ln_weight, activation_dtype)
        ln_bias = cast_if_needed(ln_bias, activation_dtype)

        # If residual connection is after LN, we need `ln_out`
        # tensor in higher precision, this comes at the cost
        # of an extra fp8 cast.
        if fp8:
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
            if not return_layernorm_output:
                if is_grad_enabled:
                    ln_out, mu, rsigma = layernorm_fwd_fp8(
                        inputmat,
                        ln_weight,
                        ln_bias,
                        eps,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM1_INPUT,
                        fp8_dtype_forward,
                        fwd_ln_sm_margin,
                        zero_centered_gamma,
                    )
                else:
                    # Inference: mu/rsigma are not needed (nothing is saved).
                    ln_out = layernorm_fwd_fp8_inf(
                        inputmat,
                        ln_weight,
                        ln_bias,
                        eps,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM1_INPUT,
                        fp8_dtype_forward,
                        zero_centered_gamma,
                    )
            else:
                # Keep the high-precision LN output for the caller, then cast
                # a separate FP8 copy for the FC1 GEMM.
                ln_out_return, mu, rsigma = tex.layernorm_fwd(
                    inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma
                )
                ln_out = cast_to_fp8(
                    ln_out_return,
                    fp8_meta["scaling_fwd"],
                    tex.FP8FwdTensors.GEMM1_INPUT,
                    fp8_dtype_forward,
                )
        else:
            if is_grad_enabled:
                ln_out, mu, rsigma = tex.layernorm_fwd(
                    inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma
                )
            else:
                ln_out, mu, rsigma = layernorm_fwd_inf(
                    inputmat, ln_weight, ln_bias, eps, zero_centered_gamma
                ), None, None

            ln_out_return = ln_out

        # Column Parallel Linear
        if set_parallel_mode and sequence_parallel:
            ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
        else:
            ln_out_total = ln_out

        if fp8:
            # FP32 activations use bf16 biases in the FP8 GEMMs.
            bias_dtype = (
                torch.bfloat16
                if activation_dtype == torch.float32
                else activation_dtype
            )
            fc1_bias = cast_if_needed(fc1_bias, bias_dtype) if use_fc1_bias else fc1_bias
            fc2_bias = cast_if_needed(fc2_bias, bias_dtype) if use_fc2_bias else fc2_bias

            if update_fp8_weights:
                if is_grad_enabled:
                    # Training needs both the FP8 weight and its transpose
                    # (the transpose is consumed by the DGRAD GEMMs).
                    fp8_cast_transpose_fused(
                        fc1_weight,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM1_WEIGHT,
                        fp8_dtype_forward,
                        cast_out=fc1_weight_fp8,
                        transpose_out=fc1_weight_t_fp8,
                    )
                    fp8_cast_transpose_fused(
                        fc2_weight,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM2_WEIGHT,
                        fp8_dtype_forward,
                        cast_out=fc2_weight_fp8,
                        transpose_out=fc2_weight_t_fp8,
                    )
                else:
                    # Inference: only the plain FP8 cast is required.
                    fc1_weight_t_fp8 = None
                    fc1_weight_fp8 = cast_to_fp8(
                        fc1_weight,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM1_WEIGHT,
                        fp8_dtype_forward,
                    )
                    fc2_weight_t_fp8 = None
                    fc2_weight_fp8 = cast_to_fp8(
                        fc2_weight,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM2_WEIGHT,
                        fp8_dtype_forward,
                    )

            fc1_out = fp8_gemm(
                fc1_weight_fp8,
                fp8_meta["scaling_fwd"].scale_inv,
                tex.FP8FwdTensors.GEMM1_WEIGHT,
                fp8_dtype_forward,
                ln_out_total,
                fp8_meta["scaling_fwd"].scale_inv,
                tex.FP8FwdTensors.GEMM1_INPUT,
                fp8_dtype_forward,
                activation_dtype,
                get_workspace(),
                bias=fc1_bias,
                use_bias=use_fc1_bias,
                use_split_accumulator=_2X_ACC_FPROP,
            )

            gelu_out = fp8_gelu(
                fc1_out,
                fp8_meta["scaling_fwd"],
                tex.FP8FwdTensors.GEMM2_INPUT,
                fp8_dtype_forward,
            )

            fc2_out = fp8_gemm(
                fc2_weight_fp8,
                fp8_meta["scaling_fwd"].scale_inv,
                tex.FP8FwdTensors.GEMM2_WEIGHT,
                fp8_dtype_forward,
                gelu_out,
                fp8_meta["scaling_fwd"].scale_inv,
                tex.FP8FwdTensors.GEMM2_INPUT,
                fp8_dtype_forward,
                activation_dtype,
                get_workspace(),
                bias=fc2_bias,
                use_bias=use_fc2_bias,
                use_split_accumulator=_2X_ACC_FPROP,
            )
        else:
            # Cast for native AMP
            fc1_weight = cast_if_needed(fc1_weight, activation_dtype)
            fc2_weight = cast_if_needed(fc2_weight, activation_dtype)
            fc1_bias = (
                cast_if_needed(fc1_bias, activation_dtype) if use_fc1_bias else fc1_bias
            )
            fc2_bias = (
                cast_if_needed(fc2_bias, activation_dtype) if use_fc2_bias else fc2_bias
            )

            if fp8_calibration:
                # amax of fc1 input
                fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = \
                    torch.amax(ln_out_total).float()
                # amax of fc1 weight
                fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = \
                    torch.amax(fc1_weight).float()

            # With bias_gelu_nvfusion the bias-add and GeLU are done by
            # bias_gelu_fused below instead of inside the GEMM epilogue.
            fc1_outputs = gemm(
                fc1_weight,
                ln_out_total,
                activation_dtype,
                get_workspace(),
                bias=fc1_bias,
                use_bias=(not bias_gelu_nvfusion) and use_fc1_bias,
                gelu=not bias_gelu_nvfusion,
            )

            if bias_gelu_nvfusion:
                fc1_out, _, _ = fc1_outputs
                gelu_out = bias_gelu_fused(fc1_out, fc1_bias)
            else:
                # GEMM epilogue returns the GeLU output first and the
                # pre-activation (needed for dgelu) as the third element.
                gelu_out, _, fc1_out = fc1_outputs

            if fp8_calibration:
                # amax of fc2 input
                fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM2_INPUT] = \
                    torch.amax(gelu_out).float()
                # amax of fc2 weight
                fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM2_WEIGHT] = \
                    torch.amax(fc2_weight).float()

            fc2_out, _, _ = gemm(
                fc2_weight,
                gelu_out,
                activation_dtype,
                get_workspace(),
                bias=fc2_bias,
                use_bias=use_fc2_bias,
            )

        if is_grad_enabled:
            ctx.save_for_backward(
                inputmat,
                ln_weight,
                mu,
                rsigma,
                ln_out,
                fc1_out,
                gelu_out,
                fc1_weight,
                fc1_weight_t_fp8,
                fc2_weight,
                fc2_weight_t_fp8,
                fc1_bias,
                # Snapshot of the forward scale inverses used by the FP8 GEMMs.
                fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
            )
            ctx.activation_dtype = activation_dtype
            ctx.fp8 = fp8
            ctx.fp8_meta = fp8_meta
            ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
            ctx.is_first_microbatch = is_first_microbatch
            ctx.use_fc1_bias = use_fc1_bias
            ctx.use_fc2_bias = use_fc2_bias
            ctx.sequence_parallel = sequence_parallel
            ctx.tensor_parallel = tensor_parallel
            ctx.inp_shape = inp.shape
            ctx.tp_group = tp_group
            ctx.tp_size = tp_size
            ctx.bias_gelu_nvfusion = bias_gelu_nvfusion
            ctx.return_layernorm_output = return_layernorm_output
            ctx.set_parallel_mode = set_parallel_mode
            ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
            ctx.zero_centered_gamma = zero_centered_gamma
            ctx.requires_dgrad = inp.requires_grad

        # Row Parallel Linear
        if set_parallel_mode and sequence_parallel:
            fc2_out, _ = reduce_scatter_along_first_dim(fc2_out, tp_group)
        elif set_parallel_mode and tensor_parallel:
            fc2_out, _ = allreduce(fc2_out, tp_group)

        # [*, in_features] -> [*, out_features] except first dimension changes for SP
        fc2_out = fc2_out.view(-1, *inp.shape[1:-1], fc2_out.shape[-1])

        if return_layernorm_output:
            return fc2_out, ln_out_return.view_as(inp)
        return fc2_out

    @staticmethod
    def backward(
        ctx, *grad_outputs: Tuple[torch.Tensor, ...]
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Compute gradients for the inputs of :meth:`forward`.

        ``grad_outputs[0]`` is the gradient of ``fc2_out``. When
        ``return_layernorm_output`` was set, ``grad_outputs[1]`` is the
        gradient of the returned LayerNorm output. It is added to the FC1
        input gradient before the LayerNorm backward.

        Returns one entry per ``forward`` argument (31 in total). Entries are
        ``None`` for non-tensor arguments, for the cached FP8 weight buffers,
        and for inputs/weights that do not require grad.
        """
        with _prepare_backward(
            ctx.fp8, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_LayerNormMLP"
        ):
            (
                inputmat,
                ln_weight,
                mu,
                rsigma,
                ln_out,
                fc1_out,
                gelu_out,
                fc1_weight,
                fc1_weight_t_fp8,
                fc2_weight,
                fc2_weight_t_fp8,
                fc1_bias,
                fwd_scale_inverses,
            ) = ctx.saved_tensors

            ctx.use_bias = ctx.use_fc2_bias  # For grad_output_preprocess
            (
                grad_output,
                grad_output_c,
                grad_output_t,
                fc2_bias_grad,
            ) = TransformerEngineBaseModule.grad_output_preprocess(
                ctx, grad_outputs[0], True
            )

            # Column Parallel Linear
            # Overlap input AG with dgrad
            if ctx.set_parallel_mode and ctx.sequence_parallel:
                ln_out_total, handle = gather_along_first_dim(
                    ln_out, ctx.tp_group, async_op=True
                )
            else:
                ln_out_total = ln_out

            # Accumulate into main_grad on every microbatch except the first.
            if ctx.is_first_microbatch is not None:
                accumulate_wgrad_into_param_main_grad = (
                    ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
                )
            else:
                accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation

            if ctx.fp8:
                fp8_dtype_forward = get_fp8_te_dtype(
                    ctx.fp8_meta["recipe"], fprop_tensor=True
                )
                fp8_dtype_backward = get_fp8_te_dtype(
                    ctx.fp8_meta["recipe"], fprop_tensor=False
                )

                # FC2 DGRAD; Unconditional
                fc2_dgrad = fp8_gemm(
                    fc2_weight_t_fp8,
                    fwd_scale_inverses,
                    tex.FP8FwdTensors.GEMM2_WEIGHT,
                    fp8_dtype_forward,
                    grad_output_c,
                    ctx.fp8_meta["scaling_bwd"].scale_inv,
                    tex.FP8BwdTensors.GRAD_OUTPUT1,
                    fp8_dtype_backward,
                    ctx.activation_dtype,
                    get_workspace(),
                    use_split_accumulator=_2X_ACC_DGRAD,
                )

                # FC2 WGRAD
                if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
                    if fc2_weight.requires_grad:
                        gelu_out_t = tex.fp8_transpose(gelu_out, fp8_dtype_forward)
                        fc2_wgrad = fp8_gemm(
                            gelu_out_t,
                            fwd_scale_inverses,
                            tex.FP8FwdTensors.GEMM2_INPUT,
                            fp8_dtype_forward,
                            grad_output_t,
                            ctx.fp8_meta["scaling_bwd"].scale_inv,
                            tex.FP8BwdTensors.GRAD_OUTPUT1,
                            fp8_dtype_backward,
                            ctx.activation_dtype,
                            get_workspace(),
                            accumulate=accumulate_wgrad_into_param_main_grad,
                            out=fc2_weight.main_grad
                            if ctx.fuse_wgrad_accumulation
                            else None,
                            use_split_accumulator=_2X_ACC_WGRAD,
                        )

                    fc1_bias_grad, dgelu, dgelu_t = fp8_cast_transpose_bgrad_dgelu_fused(
                        fc2_dgrad,
                        fc1_out,
                        ctx.fp8_meta["scaling_bwd"],
                        tex.FP8BwdTensors.GRAD_OUTPUT2,
                        fp8_dtype_backward,
                    )
                else:
                    # WGRAD runs in higher precision: dequantize the saved
                    # FP8 GeLU output before the GEMM.
                    if fc2_weight.requires_grad:
                        gelu_out_c = cast_from_fp8(
                            gelu_out,
                            ctx.fp8_meta["scaling_fwd"],
                            tex.FP8FwdTensors.GEMM2_INPUT,
                            fp8_dtype_forward,
                            TE_DType[ctx.activation_dtype],
                        )
                        fc2_wgrad, _, _ = gemm(
                            gelu_out_c,
                            grad_output,
                            ctx.activation_dtype,
                            get_workspace(),
                            layout="NT",
                            grad=True,
                            use_bias=False,
                            accumulate=accumulate_wgrad_into_param_main_grad,
                            out=fc2_weight.main_grad
                            if ctx.fuse_wgrad_accumulation
                            else None,
                        )

                    # NOTE(review): in the FP8 forward, fc1_out comes from
                    # fp8_gemm with bias=fc1_bias/use_bias=use_fc1_bias, and
                    # fc1_bias is passed here again. Confirm that
                    # bgrad_dgelu_fused does not re-add the bias on this path.
                    fc1_bias_grad, dgelu_no_fp8 = bgrad_dgelu_fused(
                        fc2_dgrad, fc1_out, fc1_bias
                    )

                    dgelu = cast_to_fp8(
                        dgelu_no_fp8,
                        ctx.fp8_meta["scaling_bwd"],
                        tex.FP8BwdTensors.GRAD_OUTPUT2,
                        fp8_dtype_backward,
                    )
                    dgelu_t = None

                # FC1 DGRAD: Unconditional
                fc1_dgrad = fp8_gemm(
                    fc1_weight_t_fp8,
                    fwd_scale_inverses,
                    tex.FP8FwdTensors.GEMM1_WEIGHT,
                    fp8_dtype_forward,
                    dgelu,
                    ctx.fp8_meta["scaling_bwd"].scale_inv,
                    tex.FP8BwdTensors.GRAD_OUTPUT2,
                    fp8_dtype_backward,
                    ctx.activation_dtype,
                    get_workspace(),
                    use_split_accumulator=_2X_ACC_DGRAD,
                )
            else:
                # FC2 DGRAD; Unconditional
                # Without nvfusion the GEMM also applies the GeLU backward
                # (gelu_input=fc1_out), so its output is already dgelu.
                fc2_dgrad, _, _ = gemm(
                    fc2_weight,
                    grad_output,
                    ctx.activation_dtype,
                    get_workspace(),
                    layout="NN",
                    gelu=not ctx.bias_gelu_nvfusion,
                    grad=True,
                    gelu_input=fc1_out,
                )

                # FC2 WGRAD
                if fc2_weight.requires_grad:
                    fc2_wgrad, fc2_bias_grad, _ = gemm(
                        gelu_out,
                        grad_output,
                        ctx.activation_dtype,
                        get_workspace(),
                        layout="NT",
                        grad=True,
                        use_bias=ctx.use_fc2_bias,
                        accumulate=accumulate_wgrad_into_param_main_grad,
                        out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                    )

                if ctx.bias_gelu_nvfusion:
                    fc1_bias_grad, dgelu = bgrad_dgelu_fused(fc2_dgrad, fc1_out, fc1_bias)
                else:
                    dgelu = fc2_dgrad

                # FC1 DGRAD: Unconditional
                fc1_dgrad, _, _ = gemm(
                    fc1_weight,
                    dgelu,
                    ctx.activation_dtype,
                    get_workspace(),
                    layout="NN",
                    grad=True,
                )

            # Overlap dgrad-RS/AR with wgrad
            if ctx.set_parallel_mode and ctx.sequence_parallel:
                handle.wait()
                fc1_dgrad, handle = reduce_scatter_along_first_dim(
                    fc1_dgrad, ctx.tp_group, async_op=True
                )
            elif ctx.set_parallel_mode and ctx.tensor_parallel:
                fc1_dgrad, handle = allreduce(fc1_dgrad, ctx.tp_group, async_op=True)

            if fc1_weight.requires_grad:
                if ctx.fp8:
                    # FC1 WGRAD
                    if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
                        ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward)
                        fc1_wgrad = fp8_gemm(
                            ln_out_total_t,
                            fwd_scale_inverses,
                            tex.FP8FwdTensors.GEMM1_INPUT,
                            fp8_dtype_forward,
                            dgelu_t,
                            ctx.fp8_meta["scaling_bwd"].scale_inv,
                            tex.FP8BwdTensors.GRAD_OUTPUT2,
                            fp8_dtype_backward,
                            ctx.activation_dtype,
                            get_workspace(),
                            accumulate=accumulate_wgrad_into_param_main_grad,
                            out=fc1_weight.main_grad
                            if ctx.fuse_wgrad_accumulation
                            else None,
                            use_split_accumulator=_2X_ACC_WGRAD,
                        )
                    else:
                        ln_out_total_c = cast_from_fp8(
                            ln_out_total,
                            ctx.fp8_meta["scaling_fwd"],
                            tex.FP8FwdTensors.GEMM1_INPUT,
                            fp8_dtype_forward,
                            TE_DType[ctx.activation_dtype],
                        )
                        fc1_wgrad, _, _ = gemm(
                            ln_out_total_c,
                            dgelu_no_fp8,
                            ctx.activation_dtype,
                            get_workspace(),
                            layout="NT",
                            grad=True,
                            accumulate=accumulate_wgrad_into_param_main_grad,
                            out=fc1_weight.main_grad
                            if ctx.fuse_wgrad_accumulation
                            else None,
                        )
                else:
                    # FC1 WGRAD
                    # Without nvfusion this GEMM also produces the FC1 bias grad.
                    fc1_wgrad_outputs = gemm(
                        ln_out_total,
                        dgelu,
                        ctx.activation_dtype,
                        get_workspace(),
                        layout="NT",
                        grad=True,
                        use_bias=not ctx.bias_gelu_nvfusion,
                        accumulate=accumulate_wgrad_into_param_main_grad,
                        out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                    )

                    if ctx.bias_gelu_nvfusion:
                        fc1_wgrad, _, _ = fc1_wgrad_outputs
                    else:
                        fc1_wgrad, fc1_bias_grad, _ = fc1_wgrad_outputs
            elif not ctx.fp8 and not ctx.bias_gelu_nvfusion and ctx.use_fc1_bias:
                # FC1 weight is frozen, so the WGRAD GEMM above, which is the
                # only producer of the FC1 bias grad on this path, is skipped.
                # Reduce dgelu over rows directly; otherwise `fc1_bias_grad`
                # would be unbound in the return statement below.
                fc1_bias_grad = dgelu.sum(dim=0)

            # Column Parallel Linear
            if ctx.set_parallel_mode and ctx.tensor_parallel and handle is not None:
                handle.wait()

            # LayerNorm gradient
            d_ln_out = fc1_dgrad.view(inputmat.shape)

            # Residual gradient
            if ctx.return_layernorm_output:
                d_ln_out = d_ln_out + grad_outputs[1].view_as(d_ln_out)

            dxmat, dgamma, dbeta = tex.layernorm_bwd(
                d_ln_out, inputmat, mu, rsigma, ln_weight,
                ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
            )

        return (
            dxmat.view(ctx.inp_shape) if ctx.requires_dgrad else None,  # inp
            dgamma,  # ln_weight
            dbeta,  # ln_bias
            fc1_wgrad if fc1_weight.requires_grad else None,  # fc1_weight
            None,  # fc1_weight_fp8
            None,  # fc1_weight_t_fp8
            fc1_bias_grad if ctx.use_fc1_bias else None,  # fc1_bias
            None,  # use_fc1_bias
            fc2_wgrad if fc2_weight.requires_grad else None,  # fc2_weight
            None,  # fc2_weight_fp8
            None,  # fc2_weight_t_fp8
            fc2_bias_grad if ctx.use_fc2_bias else None,  # fc2_bias
            None,  # use_fc2_bias
            None,  # eps
            None,  # is_first_microbatch
            None,  # fp8
            None,  # fp8_calibration
            None,  # fp8_meta
            None,  # fuse_wgrad_accumulation
            None,  # tp_group
            None,  # tp_size
            None,  # sequence_parallel
            None,  # tensor_parallel
            None,  # activation_dtype
            None,  # return_layernorm_output
            None,  # bias_gelu_nvfusion
            None,  # set_parallel_mode
            None,  # is_grad_enabled
            None,  # fwd_ln_sm_margin
            None,  # bwd_ln_sm_margin
            None,  # zero_centered_gamma
        )


class LayerNormMLP(TransformerEngineBaseModule):
    r"""
    Applies layer normalization on the input followed by the MLP module, consisting of
    2 successive linear transformations, separated by the GeLU activation.

    Parameters
    ----------
    hidden_size : int
                 size of each input sample.
    ffn_hidden_size : int
                     intermediate size to which input samples are projected.
    eps : float, default = 1e-5
         a value added to the denominator of layer normalization for numerical stability.
    bias : bool, default = `True`
          if set to `False`, the FC1 and FC2 layers will not learn an additive bias.
    init_method : Callable, default = `None`
                 used for initializing FC1 weights in the following way: `init_method(weight)`.
                 When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    output_layer_init_method : Callable, default = `None`
                              used for initializing FC2 weights in the following way:
                              `output_layer_init_method(weight)`. When set to `None`, defaults to
                              `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    return_layernorm_output : bool, default = `False`
                             if set to `True`, output of layernorm is returned from the forward
                             together with the output of the linear transformation.
                             Example use case: residual connection for transformer module
                             is taken post layernorm.
    zero_centered_gamma : bool, default = `False`
                         if set to `True`, gamma parameter in LayerNorm is initialized to 0 and
                         the LayerNorm formula changes to

                         .. math::
                            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
                            (1 + \gamma) + \beta

    Parallelism parameters
    ----------------------
    set_parallel_mode : bool, default = `False`
                      if set to `True`, FC1 is used as Column Parallel and FC2 is used as Row
                      Parallel as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism. Only takes effect when
                       the tensor parallel world size is greater than 1.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    tp_size : int, default = 1
             used as TP (tensor parallel) world size when TP groups are not formed during
             initialization. In this case, users must call the
             `set_tensor_parallel_group(tp_group)` method on the initialized module before the
             forward pass to supply the tensor parallel group needed for tensor and sequence
             parallel collectives.
    get_rng_state_tracker : Callable, default = `None`
                           passed to the weight initializer to obtain the random number
                           generator state tracker used when initializing FC1/FC2 weights.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = `False`
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient.
    return_bias : bool, default = `False`
                 when set to `True`, this module will not apply the additive bias for FC2, but
                 instead return the bias value during the forward pass together with the
                 output of the linear transformation :math:`y = xA^T`. This is useful when
                 the bias addition can be fused to subsequent operations.
    params_dtype : torch.dtype, default = `torch.float32`
                  it controls the type used to allocate the initial parameters. Useful when
                  the model is trained with lower precision and the original FP32 parameters
                  would not fit in GPU memory.
    seq_length : int, default = `None`
                sequence length of input samples. Needed for JIT Warmup, a technique where jit
                fused functions are warmed up before training to ensure same kernels are used
                for forward propagation and activation recompute phase. Warmup only runs when
                both `seq_length` and `micro_batch_size` are set.
    micro_batch_size : int, default = `None`
                      batch size per training step. Needed for JIT Warmup, a technique where jit
                      fused functions are warmed up before training to ensure same kernels are
                      used for forward propagation and activation recompute phase.
    """

    def __init__(
        self,
        hidden_size: int,
        ffn_hidden_size: int,
        eps: float = 1e-5,
        sequence_parallel: bool = False,
        return_bias: bool = False,
        get_rng_state_tracker: Optional[Callable] = None,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        init_method: Optional[Callable] = None,
        bias: bool = True,
        output_layer_init_method: Optional[Callable] = None,
        fuse_wgrad_accumulation: bool = False,
        params_dtype: torch.dtype = torch.float32,
        return_layernorm_output: bool = False,
        seq_length: Optional[int] = None,
        micro_batch_size: Optional[int] = None,
        set_parallel_mode: bool = False,
        zero_centered_gamma: bool = False,
    ) -> None:
        super().__init__()

        self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
        self.use_bias = bias
        self.return_bias = return_bias
        # The FC2 bias is applied inside this module only when it is not handed
        # back to the caller through ``return_bias``.
        self.apply_bias = bias and not return_bias
        self.return_layernorm_output = return_layernorm_output
        self.bias_gelu_nvfusion = bool(int(os.getenv("NVTE_BIAS_GELU_NVFUSION", "1")))
        self.set_parallel_mode = set_parallel_mode
        self.zero_centered_gamma = zero_centered_gamma

        if tp_group is None:
            self.tp_size = tp_size
            # With tp_size > 1 and no group yet, the caller is expected to call
            # set_tensor_parallel_group() before the first forward pass.
            if tp_size == 1:
                self.set_tensor_parallel_group(tp_group)
        else:
            self.tp_size = get_distributed_world_size(tp_group)
            self.set_tensor_parallel_group(tp_group)
        self.set_nccl_overlap_warning_if_tp()

        if init_method is None:
            init_method = get_default_init_method()
        if output_layer_init_method is None:
            output_layer_init_method = get_default_init_method()

        # Sequence parallelism is meaningless without tensor parallelism.
        self.sequence_parallel = (self.tp_size > 1) and sequence_parallel
        self.size_per_partition = divide(ffn_hidden_size, self.tp_size)

        # LN init
        self.eps = eps
        self.layer_norm_weight = Parameter(
            torch.empty(
                hidden_size,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            )
        )
        self.layer_norm_bias = Parameter(
            torch.empty(
                hidden_size,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            )
        )
        setattr(self.layer_norm_weight, "sequence_parallel", self.sequence_parallel)
        setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel)
        self.reset_layer_norm_parameters()

        # FC1 init: output features are split across TP ranks (partition_dim=0).
        self.fc1_weight = Parameter(
            torch.empty(
                self.size_per_partition,
                hidden_size,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            )
        )
        self.fp8_weight_shapes.append(self.fc1_weight.shape)

        initialize_affine_weight_gpu(
            self.fc1_weight,
            init_method,
            get_rng_state_tracker,
            partition_dim=0,
            stride=1,
        )

        if self.use_bias:
            self.fc1_bias = Parameter(
                torch.empty(
                    self.size_per_partition,
                    device=torch.cuda.current_device(),
                    dtype=params_dtype,
                )
            )
            set_tensor_model_parallel_attributes(self.fc1_bias, True, 0, 1)
        else:
            # Keep the attribute defined (as an empty, non-persistent buffer) so
            # downstream code can pass ``self.fc1_bias`` unconditionally.
            self.register_buffer("fc1_bias", torch.Tensor().type(params_dtype), persistent=False)

        with torch.no_grad():
            self.fc1_bias.zero_()

        # FC2 init: input features are split across TP ranks (partition_dim=1).
        self.fc2_weight = Parameter(
            torch.empty(
                hidden_size,
                self.size_per_partition,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            )
        )
        self.fp8_weight_shapes.append(self.fc2_weight.shape)

        initialize_affine_weight_gpu(
            self.fc2_weight,
            output_layer_init_method,
            get_rng_state_tracker,
            partition_dim=1,
            stride=1,
        )

        if self.use_bias:
            self.fc2_bias = Parameter(
                torch.empty(
                    hidden_size, device=torch.cuda.current_device(), dtype=params_dtype
                )
            )
        else:
            self.register_buffer("fc2_bias", torch.Tensor().type(params_dtype), persistent=False)

        # For RPL, bias has to be added after TP collectives
        # So it cannot be fused with the GEMM
        if self.set_parallel_mode and self.apply_bias:
            self.gemm_bias_unfused_add = True
        else:
            self.gemm_bias_unfused_add = False

        with torch.no_grad():
            self.fc2_bias.zero_()

        if self.bias_gelu_nvfusion:
            set_jit_fusion_options()
            # Warm up the fused bias+GeLU JIT kernels only when the input
            # geometry is known up front.
            if seq_length and micro_batch_size:
                warmup_jit_bias_gelu_all_dtypes(
                    self.size_per_partition, seq_length, micro_batch_size
                )

        # These many SMs are subtracted from the total SM count when calling forward
        # and backward LayerNorm C APIs. These envvars can be used to prevent the LN
        # kernels from using all SMs in the device. This is useful for cases such as
        # communication overlap with LN.
        self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
        self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))

    def reset_layer_norm_parameters(self) -> None:
        """Init LN params.

        Gamma is set to ones, or to zeros when ``zero_centered_gamma`` is enabled
        (the kernel then applies ``1 + gamma``). Beta is always set to zeros.
        """
        if not self.zero_centered_gamma:
            init.ones_(self.layer_norm_weight)
        else:
            init.zeros_(self.layer_norm_weight)
        init.zeros_(self.layer_norm_bias)

    def forward(
        self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        """
        Apply layer normalization to the input followed by a feedforward network (MLP Block).

        Parameters
        ----------
        inp : torch.Tensor
             Input tensor.
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)

        Returns
        -------
        The MLP output alone, or a tuple ``(out, [fc2_bias], [ln_out])``. The FC2 bias
        (cast to the activation dtype) is included when ``return_bias`` is set. The
        LayerNorm output is included, last, when ``return_layernorm_output`` is set.
        """

        with self.prepare_forward(inp, is_first_microbatch, num_gemms=2) as inp:
            if torch.is_grad_enabled():
                fwd_fn = _LayerNormMLP.apply
                args = []
            else:
                # Without autograd, call the static forward directly and pass
                # ``None`` in place of the autograd ctx.
                fwd_fn = _LayerNormMLP.forward
                args = [None]
            args += (
                inp,
                self.layer_norm_weight,
                self.layer_norm_bias,
                self.fc1_weight,
                self.weight1_fp8 if self.fp8 else None,
                self.weight1_t_fp8 if self.fp8 else None,
                self.fc1_bias,
                self.use_bias,
                self.fc2_weight,
                self.weight2_fp8 if self.fp8 else None,
                self.weight2_t_fp8 if self.fp8 else None,
                self.fc2_bias,
                # Fuse the FC2 bias into the GEMM only when it is not deferred
                # until after the TP collective (see gemm_bias_unfused_add).
                self.apply_bias and not self.gemm_bias_unfused_add,
                self.eps,
                is_first_microbatch,
                self.fp8,
                self.fp8_calibration,
                self.fp8_meta,
                self.fuse_wgrad_accumulation,
                self.tp_group,
                self.tp_size,
                self.sequence_parallel,
                self.tp_size > 1,
                self.activation_dtype,
                self.return_layernorm_output,
                self.bias_gelu_nvfusion,
                self.set_parallel_mode,
                torch.is_grad_enabled(),
                self.fwd_ln_sm_margin,
                self.bwd_ln_sm_margin,
                self.zero_centered_gamma,
            )
            out = fwd_fn(*args)

        if self.return_layernorm_output:
            out, ln_out = out

        if self.gemm_bias_unfused_add:
            out = out + cast_if_needed(self.fc2_bias, self.activation_dtype)

        if self.return_bias:
            if self.return_layernorm_output:
                return out, cast_if_needed(self.fc2_bias, self.activation_dtype), ln_out
            return out, cast_if_needed(self.fc2_bias, self.activation_dtype)
        if self.return_layernorm_output:
            return out, ln_out
        return out


class _LayerNorm(torch.autograd.Function):
    """functional LayerNorm"""

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        ln_weight: torch.Tensor,
        ln_bias: torch.Tensor,
        eps: float,
        fwd_ln_sm_margin: int,
        bwd_ln_sm_margin: int,
        zero_centered_gamma: bool,
    ) -> torch.Tensor:
        """Normalize ``inp`` over its last dimension using ``tex.layernorm_fwd``.

        The last dimension of ``inp`` must equal ``ln_weight.numel()``. All leading
        dimensions are flattened for the kernel. The result is reshaped back to
        ``inp``'s shape. The mean/rsigma statistics and the flattened input are
        saved for the backward pass.
        """
        # Make sure input dimensions are compatible
        in_features = ln_weight.numel()
        assert inp.is_cuda, "TransformerEngine needs CUDA."
        assert inp.shape[-1] == in_features, "LayerNorm not possible"
        # Collapse all leading dims so the kernel sees a (rows, in_features) matrix.
        inputmat = inp.view((-1, in_features))

        ln_out, mu, rsigma = tex.layernorm_fwd(inputmat, ln_weight,
                                               ln_bias, eps, fwd_ln_sm_margin,
                                               zero_centered_gamma)
        ctx.save_for_backward(inputmat, ln_weight, mu, rsigma)
        ctx.inp_shape = inp.shape
        ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
        ctx.zero_centered_gamma = zero_centered_gamma
        return ln_out.view_as(inp)

    @staticmethod
    def backward(
        ctx, grad_output: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Compute gradients w.r.t. the input, gamma and beta via ``tex.layernorm_bwd``.

        Returns one entry per ``forward`` argument. Only ``inp``, ``ln_weight`` and
        ``ln_bias`` receive gradients. The four non-tensor arguments (``eps``, both SM
        margins and ``zero_centered_gamma``) get ``None``.
        """
        inputmat, ln_weight, mu, rsigma = ctx.saved_tensors
        # The kernel expects a contiguous gradient laid out like the flattened input.
        grad_output = grad_output.contiguous()
        d_ln_out = grad_output.view(inputmat.shape)
        dxmat, dgamma, dbeta = tex.layernorm_bwd(
            d_ln_out, inputmat, mu, rsigma, ln_weight,
            ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
        )
        return dxmat.view(ctx.inp_shape), dgamma, dbeta, None, None, None, None
3017


class LayerNorm(torch.nn.Module):
    r"""
    Applies Layer Normalization over a mini-batch of inputs as described in
    the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma + \beta

    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
    size :attr:`hidden_size`

    Parameters
    ----------
    hidden_size : int
                size of each input sample.
    eps : float, default = 1e-5
        a value added to the denominator of layer normalization for numerical stability.
    sequence_parallel : bool, default = `False`
                        if set to `True`, uses sequence parallelism.
    params_dtype : torch.dtype, default = `torch.float32`
                    it controls the type used to allocate the initial parameters. Useful when
                    the model is trained with lower precision and the original FP32 parameters
                    would not fit in GPU memory.
    zero_centered_gamma : bool, default = `False`
                         if set to `True`, gamma parameter in LayerNorm is initialized to 0 and
                         the LayerNorm formula changes to

                         .. math::
                            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
                            (1 + \gamma) + \beta
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-5,
        sequence_parallel: bool = False,
        params_dtype: torch.dtype = torch.float32,
        zero_centered_gamma: bool = False,
    ) -> None:
        super().__init__()
        self.eps = eps
        self.zero_centered_gamma = zero_centered_gamma
        self.weight = Parameter(
            torch.empty(
                hidden_size,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            )
        )
        self.bias = Parameter(
            torch.empty(
                hidden_size,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            )
        )
        # Tag the parameters so code elsewhere can tell whether they are
        # sequence-parallel.
        setattr(self.weight, "sequence_parallel", sequence_parallel)
        setattr(self.bias, "sequence_parallel", sequence_parallel)
        self.reset_layer_norm_parameters()

        # These many SMs are subtracted from the total SM count when calling forward
        # and backward LayerNorm C APIs. These envvars can be used to prevent the LN
        # kernels from using all SMs in the device. This is useful for cases such as
        # communication overlap with LN.
        self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
        self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))

    def load_state_dict(
        self,
        state_dict: Mapping[str, Any],
        strict: bool = True,
    ) -> Any:
        """Override PyTorch loader to maintain backward compatibility
        with previous version of LayerNorm parameter names.

        Legacy keys ``layer_norm_weight`` / ``layer_norm_bias`` are renamed to
        ``weight`` / ``bias`` before delegating to
        :meth:`torch.nn.Module.load_state_dict`. Renaming is done on a shallow
        copy, and any ``_metadata`` attribute is carried over, so the caller's
        ``state_dict`` is never modified.

        Returns
        -------
        The value returned by :meth:`torch.nn.Module.load_state_dict`.
        """
        legacy_to_current = {"layer_norm_weight": "weight", "layer_norm_bias": "bias"}
        if any(old_key in state_dict for old_key in legacy_to_current):
            from collections import OrderedDict  # pylint: disable=import-outside-toplevel

            metadata = getattr(state_dict, "_metadata", None)
            state_dict = OrderedDict(state_dict)
            if metadata is not None:
                state_dict._metadata = metadata  # pylint: disable=protected-access
            for old_key, new_key in legacy_to_current.items():
                if old_key in state_dict:
                    state_dict[new_key] = state_dict.pop(old_key)

        return super().load_state_dict(state_dict, strict)

    def reset_layer_norm_parameters(self) -> None:
        """Init LN params.

        Gamma is set to ones, or to zeros when ``zero_centered_gamma`` is enabled
        (the kernel then applies ``1 + gamma``). Beta is always set to zeros.
        """
        if not self.zero_centered_gamma:
            init.ones_(self.weight)
        else:
            init.zeros_(self.weight)
        init.zeros_(self.bias)

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        """LayerNorm FWD"""
        # Maintain backward compatibility.
        if hasattr(self, "layer_norm_weight"):
            setattr(self, "weight", self.layer_norm_weight)
        if hasattr(self, "layer_norm_bias"):
            setattr(self, "bias", self.layer_norm_bias)

        return _LayerNorm.apply(
            inp,
            self.weight,
            self.bias,
            self.eps,
            self.fwd_ln_sm_margin,
            self.bwd_ln_sm_margin,
            self.zero_centered_gamma
        )