# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Top level Transformer Engine PyTorch modules"""
import os
import pickle
import warnings
from abc import ABC, abstractmethod
from typing import Union, Optional, Callable, Tuple, Dict, Any, Mapping, List, Generator
from functools import partial
from contextlib import contextmanager

import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.nn import init

import transformer_engine_extensions as tex
from .fp8 import (
    is_fp8_enabled,
    is_fp8_calibration,
    get_fp8_recipe,
    get_fp8_group,
    get_default_fp8_recipe,
    get_fp8_te_dtype,
    is_first_fp8_module,
    new_fp8_context_id,
    get_fp8_context_id,
    set_fp8_context_id,
    add_amax_to_global_buffer,
    copy_amax_from_global_buffer,
    global_amax_reduction,
    setup_amax_forward_global_reduce_func,
    amax_and_scale_update,
    get_global_fp8_buffer,
    set_global_fp8_buffer,
    set_amax_buffer_key_deletion,
    delete_key_from_amax_buffer,
    copy_forward_fp8_meta_tensors_for_recompute,
    get_old_fp8_meta_tensors_for_recompute,
    restore_fp8_meta_tensors,
    get_amax_reduce_handle_fwd,
)
from .jit import (
    bias_gelu_fused,
    bgrad_dgelu_fused,
    set_jit_fusion_options,
    warmup_jit_bias_gelu_all_dtypes,
)
from .utils import (
    divide,
    get_default_init_method,
    cast_if_needed,
    check_dim_for_fp8_forward_exec,
)
from .distributed import (
    set_tensor_model_parallel_attributes,
    get_distributed_world_size,
    allreduce,
    initialize_affine_weight_gpu,
    reduce_scatter_along_first_dim,
    gather_along_first_dim,
    gather_along_last_dim,
    is_fp8_activation_recompute_enabled,
    in_fp8_activation_recompute_phase,
)
from .cpp_extensions import (
    fp8_gemm,
    gemm,
    fp8_cast_transpose_fused,
    fp8_cast_transpose_bgrad_fused,
    fp8_gelu,
    fp8_cast_transpose_bgrad_dgelu_fused,
    layernorm_fwd_fp8,
    layernorm_fwd_fp8_inf,
    layernorm_fwd_inf,
    cast_to_fp8,
    cast_from_fp8,
)
from .constants import GemmParallelModes, dist_group_type, TE_DType

# GEMM accumulation flags for the forward (fprop), data-gradient (dgrad) and
# weight-gradient (wgrad) GEMMs.
# NOTE(review): consumed outside this view; presumably forwarded as the
# split-accumulator option of the FP8 GEMM calls.
_2X_ACC_FPROP = False
_2X_ACC_DGRAD = True
_2X_ACC_WGRAD = True

# cuBLAS workspace shared by every module; allocated lazily by get_workspace()
# and enlarged by initialize_ub().
_cublas_workspace = None

# Userbuffer communicators keyed by layer name (filled in by initialize_ub()).
_ub_communicators = None
# Maximum number of concurrent GEMM streams for userbuffer overlap; also the
# factor by which initialize_ub() repeats the cuBLAS workspace.
_NUM_MAX_UB_STREAMS = 3

# Handle returned by the backward global_amax_reduction() in _prepare_backward();
# waited on (and cleared) at the start of the next FP8 backward.
_amax_reduce_handle_bwd = None


def get_cublas_workspace_size_bytes() -> int:
    """Return the cuBLAS workspace size in bytes for the current CUDA device.

    Devices whose compute capability major version is 9 or higher (Hopper and
    newer) get 32 MiB; every other architecture gets 4 MiB.
    """
    major = torch.cuda.get_device_properties(torch.cuda.current_device()).major
    if major >= 9:
        return 33_554_432  # 32 MiB
    return 4_194_304  # 4 MiB


def get_workspace() -> torch.Tensor:
    """Return the shared cuBLAS workspace, allocating it on first use.

    On the first call, allocates a flat ``torch.uint8`` tensor on ``"cuda"``
    sized by ``get_cublas_workspace_size_bytes()`` and stores it in the
    module-level cache. Once the cache is set, that cached tensor is returned
    as-is.
    """
    global _cublas_workspace
    if _cublas_workspace is not None:
        return _cublas_workspace
    _cublas_workspace = torch.empty(
        get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda"
    )
    return _cublas_workspace


@contextmanager
def _prepare_backward(
    fp8: bool,
    fp8_meta: Dict[str, Any],
    tp_group: dist_group_type,
    tp_size: int,
    name: str = ""
) -> Generator[None, None, None]:
    """Context manager wrapping a module's backward pass.

    On entry (FP8 only), waits for any outstanding backward amax reduction and
    updates the backward amax/scale state. When the recipe enables
    ``reduce_amax``, it also pulls the reduced amaxes from the global buffer
    and registers this module's backward amaxes under the oldest pending
    forward autocast id. The body runs inside an NVTX range named
    ``"<name> backward"``. On normal exit, the module flagged as
    ``first_module`` launches the global backward amax reduction and keeps its
    handle for the next backward.

    Args:
        fp8: whether the module ran in FP8.
        fp8_meta: the module's FP8 metadata dictionary.
        tp_group: tensor-parallel process group passed to the amax reduction.
        tp_size: tensor-parallel world size passed to the amax reduction.
        name: prefix for the NVTX range label.
    """
    global _amax_reduce_handle_bwd

    if fp8:
        # Finish the previously launched backward amax reduction before the
        # amax/scale state is read below.
        if _amax_reduce_handle_bwd is not None:
            _amax_reduce_handle_bwd.wait()
            _amax_reduce_handle_bwd = None

        reduce_amax = fp8_meta["recipe"].reduce_amax
        if reduce_amax:
            # Reduced amaxes from the previous iteration.
            copy_amax_from_global_buffer(fp8_meta, forward=False)
        amax_and_scale_update(fp8_meta, False)
        if reduce_amax:
            set_amax_buffer_key_deletion(fp8_meta, forward=False)
            # Take the oldest forward autocast id (appended by prepare_forward).
            fp8_meta["autocast_id_bwd"] = fp8_meta["autocast_id_fwd_stack"].pop(0)
            add_amax_to_global_buffer(fp8_meta, forward=False)

    with torch.cuda.nvtx.range(name + " backward"):
        yield

    # Not reached if the backward body raised.
    if fp8 and fp8_meta["recipe"].reduce_amax and fp8_meta["first_module"]:
        _amax_reduce_handle_bwd = global_amax_reduction(
            fp8_meta,
            tp_group,
            tp_size,
            forward=False
        )
        delete_key_from_amax_buffer(forward=False)


def initialize_ub(
    shape: list,
    tp_size: int,
    use_fp8: bool = False,
    ub_cfgs: Optional[dict] = None
) -> None:
    """Initialize communicators for TP comm overlap using userbuffers.

    Creates one userbuffer communicator per supported GEMM (``qkv_fprop``,
    ``fc2_dgrad``, ...) and stores it in the module-level registry read by
    ``get_ub``. The cuBLAS workspace is repeated ``_NUM_MAX_UB_STREAMS`` times
    so each concurrent GEMM stream can use its own slice.

    Args:
        shape: shape of the sample buffer allocated for every communicator.
        tp_size: tensor-parallel world size.
        use_fp8: when True, the all-gather buffers listed in ``fp8_buf`` are
            allocated as ``torch.uint8``; all other buffers are bf16.
        ub_cfgs: optional per-layer overrides keyed by layer name. Each value
            may set ``method`` (``"ring_exchange"``, ``"pipeline"`` or
            ``"bulk"``), ``num_sm``, ``cga_size``, ``num_splits``,
            ``set_sm_margin`` and ``aggregate``. Unset keys fall back to the
            same defaults used for a layer without any config.

    Raises:
        AssertionError: if the communicators were already initialized.
    """
    global _ub_communicators
    global _cublas_workspace

    assert _ub_communicators is None, "UB communicators are already initialized."
    _ub_communicators = {}
    rank_id = torch.distributed.get_rank()

    # Increase the workspace by the number of maximum concurrent streams
    _cublas_workspace = get_workspace().repeat(_NUM_MAX_UB_STREAMS)

    # Default buffer precision: AllGather buffers use fp8 when using fp8 recipe
    fp8_buf = [
        "qkv_fprop", "qkv_dgrad", "proj_dgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad"
    ]
    # Default overlap methods for layers
    methods = {
        "ring_exchange": ["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad"],
        "pipeline": ["proj_fprop", "fc2_fprop"],
        "bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"],
    }

    def get_method(name):
        for method, names in methods.items():
            if name in names:
                return method
        raise KeyError(f"Given layer name {name} does not exist.")

    def add_ub(
        name: str,
        method: str,
        num_sm: int = 16,
        cga_size: int = 2,
        set_sm_margin: int = 0,
        num_splits: int = 4,
        aggregate: int = 0,
    ) -> None:
        dtype = torch.uint8 if (use_fp8 and name in fp8_buf) else torch.bfloat16
        sample_buffer = torch.empty(shape, dtype=dtype, device='cuda')
        if method == 'ring_exchange':
            ub_obj = tex.UbufP2PCommOverlap(
                    sample_buffer,          # Sample userbuffer
                    rank_id,                # Rank id
                    tp_size,                # TP size
                    aggregate,              # Aggregate 2X GEMM chunks
                    _NUM_MAX_UB_STREAMS,    # Max concurrent GEMM streams
                )
        else:
            ub_obj = tex.UbufCommOverlap(
                    sample_buffer,          # Sample userbuffer
                    rank_id,                # Rank id
                    tp_size,                # TP size
                    num_sm,                 # Number of communication SMs
                    cga_size,               # CGA cluster size
                    num_splits,             # Number of communication splits
                    set_sm_margin,          # Set SM margin
                    _NUM_MAX_UB_STREAMS,    # Max concurrent GEMM streams
                )
        _ub_communicators[name] = ub_obj

    for name in methods["ring_exchange"] + methods["pipeline"] + methods["bulk"]:
        if ub_cfgs is not None and name in ub_cfgs:
            ub_cfg = ub_cfgs[name]
        else:
            ub_cfg = {}
        method = ub_cfg.get("method", get_method(name))
        add_ub(
            name,
            method,
            num_sm=ub_cfg.get("num_sm", 16),
            cga_size=ub_cfg.get("cga_size", 2),
            set_sm_margin=ub_cfg.get("set_sm_margin", 0),
            # Same default as a layer with no config: pipeline overlap uses 4
            # splits, ring-exchange and bulk use 0.
            num_splits=ub_cfg.get("num_splits", 4 if method == "pipeline" else 0),
            aggregate=ub_cfg.get("aggregate", 0),
        )


def get_ub(name: str):
    """Return the userbuffer communicator registered under ``name``.

    Raises:
        AssertionError: if ``initialize_ub`` has not run yet, or if nothing is
            registered for ``name``.
    """
    registry = _ub_communicators
    assert registry is not None, "UB manager is not initialized."
    assert name in registry, f"UB for {name} is not registered."
    return registry[name]


class _NoopCat(torch.autograd.Function):
    """This class is a no-op replacement for `torch.cat`."""

    @staticmethod
    def forward(ctx,
                full_param_buffer: torch.Tensor,
                *params_split: Tuple[torch.Tensor, ...],
    ) -> torch.Tensor:
        assert not full_param_buffer.requires_grad, "Buffers should not require gradient"
        assert (
            full_param_buffer.shape[0] % len(params_split) == 0
        ), "Dimensions not compatible for concatenation"

        param_temp = full_param_buffer.new()
        param_temp.set_(full_param_buffer.storage(),
                        full_param_buffer.storage_offset(),
                        full_param_buffer.size(),
                        full_param_buffer.stride())
        param_temp.requires_grad = True

        ctx.save_for_backward(full_param_buffer, *params_split)
        return param_temp

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
        full_param_buffer, *params_split = ctx.saved_tensors

        split_size = full_param_buffer.shape[0] // len(params_split)
        grads = []

        for i, _ in enumerate(params_split):
            grads.append(grad_output[i * split_size : (i+1) * split_size])

        return None, *grads


Przemek Tredak's avatar
Przemek Tredak committed
287
288
289
290
291
292
class TransformerEngineBaseModule(torch.nn.Module, ABC):
    """Base TE module."""

    def __init__(self) -> None:
        """Set up the default (non-FP8, no tensor parallelism) module state.

        Raises:
            AssertionError: if CUDA is not available.
        """
        super().__init__()
        assert torch.cuda.is_available(), "TransformerEngine needs CUDA."

        # FP8 execution flags; refreshed on every forward by fp8_init().
        self.fp8_initialized = False
        self.fp8 = False
        self.fp8_calibration = False

        # FP8 metadata. Scaling tensors are added later by set_meta_tensor().
        # A nonzero integer in NVTE_ASYNC_AMAX_REDUCTION sets
        # "async_amax_reduction" to True.
        self.fp8_meta = {
            "fp8_group": None,
            "recipe": get_default_fp8_recipe(),
            "autocast_id_fwd_stack": [],
            "async_amax_reduction": bool(
                int(os.getenv("NVTE_ASYNC_AMAX_REDUCTION", "0"))
            ),
        }
        self.fp8_meta_tensors_initialized = False

        # Tensor-parallel defaults: no group, world size 1, no sequence parallelism.
        self.tp_group = None
        self.tp_size = 1
        self.sequence_parallel = False

        # Shapes of the FP8 weight copies; filled in by subclasses and consumed
        # by set_fp8_weights().
        self.fp8_weight_shapes = []

    def set_meta_tensor(self, fwd: bool) -> None:
        """Init scales and amaxes for fwd | bwd."""
        fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd"
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328

        if self.fp8_meta_tensors_initialized:
            # Handle changed amax history size.
            curr_len = self.fp8_meta[fp8_meta_tensor_key].amax_history.shape[0]
            need_len = self.fp8_meta["recipe"].amax_history_len
            if need_len < curr_len:
                self.fp8_meta[fp8_meta_tensor_key].amax_history = (
                    self.fp8_meta[fp8_meta_tensor_key]
                    .amax_history[: self.fp8_meta["recipe"].amax_history_len].clone()
                )
            elif need_len > curr_len:
                extra_rows = need_len - curr_len
                self.fp8_meta[fp8_meta_tensor_key].amax_history = F.pad(
                    self.fp8_meta[fp8_meta_tensor_key].amax_history, pad=(0, 0, 0, extra_rows)
                )
            return

329
330
        # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and
        # 2 (grad_output and grad_input) for bwd
Przemek Tredak's avatar
Przemek Tredak committed
331
        num_fp8_tensors = (
332
            self.fp8_meta["num_gemms"] * 3 if fwd else self.fp8_meta["num_gemms"] * 2
Przemek Tredak's avatar
Przemek Tredak committed
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
        )

        self.fp8_meta[fp8_meta_tensor_key] = tex.FP8TensorMeta()
        self.fp8_meta[fp8_meta_tensor_key].scale = torch.ones(
            num_fp8_tensors, dtype=torch.float32, device="cuda"
        )
        self.fp8_meta[fp8_meta_tensor_key].scale_inv = torch.ones(
            num_fp8_tensors, dtype=torch.float32, device="cuda"
        )
        self.fp8_meta[fp8_meta_tensor_key].amax_history = torch.zeros(
            self.fp8_meta["recipe"].amax_history_len,
            num_fp8_tensors,
            dtype=torch.float32,
            device="cuda",
        )

349
350
351
        # Needed for calculation of scale inverses to
        # preserve scale_inv when caching FP8 weights
        if fwd:
352
            # [True, False, True]: -> [input, weight, output]
353
            self.fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"] = torch.BoolTensor(
354
                [True, False, True] * self.fp8_meta["num_gemms"]
355
356
            ).cuda()
        else:
357
            # [True, True]: -> [grad_output, grad_input]
358
            self.fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"] = torch.BoolTensor(
359
                [True, True] * self.fp8_meta["num_gemms"]
360
361
            ).cuda()

Przemek Tredak's avatar
Przemek Tredak committed
362
363
364
365
    def init_fp8_meta_tensors(self) -> None:
        """Init scales and amaxes."""
        self.set_meta_tensor(True)
        self.set_meta_tensor(False)
366
        self.fp8_meta_tensors_initialized = True
Przemek Tredak's avatar
Przemek Tredak committed
367

368
    def get_extra_state(self) -> torch.Tensor:
Przemek Tredak's avatar
Przemek Tredak committed
369
        """Save before checkpointing."""
370
        state = None
schetlur-nv's avatar
schetlur-nv committed
371
        if self.fp8 or self.fp8_calibration:
372
373
            state = {}
            state["scale_fwd"] = self.fp8_meta["scaling_fwd"].scale
374
            state["scale_inv_fwd"] = self.fp8_meta["scaling_fwd"].scale_inv
375
376
            state["amax_history_fwd"] = self.fp8_meta["scaling_fwd"].amax_history
            state["scale_bwd"] = self.fp8_meta["scaling_bwd"].scale
377
            state["scale_inv_bwd"] = self.fp8_meta["scaling_bwd"].scale_inv
378
379
380
381
382
383
384
385
386
387
            state["amax_history_bwd"] = self.fp8_meta["scaling_bwd"].amax_history
            state["global_fp8_buffer"] = get_global_fp8_buffer()

            # Store other pickelable values.
            extra = {}
            for k, v in self.fp8_meta.items():
                if isinstance(v, (bool, int, float, str)):
                    extra[k] = v
            state["extra_fp8_variables"] = extra

388
389
        state_serialized = pickle.dumps(state)
        state_tensor = torch.tensor(np.frombuffer(state_serialized, dtype=np.uint8))
Przemek Tredak's avatar
Przemek Tredak committed
390

391
392
393
        return state_tensor

    def set_extra_state(self, state: torch.Tensor) -> None:
Przemek Tredak's avatar
Przemek Tredak committed
394
395
396
397
        """Load previous state."""
        if state is None:
            return

398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
        # Maintain backward compatibility with v0.2.0 and older.
        if isinstance(state, list):
            warnings.warn(
                "This checkpoint format is deprecated and will be"
                "removed in a future release of Transformer Engine"
            )

            # Retrieve checkpointed items.
            scale_fwd = state[0]
            amax_history_fwd = state[1]
            scale_bwd = state[2]
            amax_history_bwd = state[3]
            self.fp8_meta["recipe"].amax_history_len = amax_history_fwd.shape[0]
            self.fp8_meta["num_gemms"] = (
                amax_history_fwd.shape[1] // 2
            )  # Two FWD tensors per GEMM

            # Initialize before loading
            self.init_fp8_meta_tensors()
            self.fp8_meta["scaling_fwd"].scale.copy_(scale_fwd)
            self.fp8_meta["scaling_fwd"].amax_history.copy_(amax_history_fwd)
            self.fp8_meta["scaling_bwd"].scale.copy_(scale_bwd)
            self.fp8_meta["scaling_bwd"].amax_history.copy_(amax_history_bwd)

            # Restore global FP8 buffer state.
            set_global_fp8_buffer(state[4])
            self.fp8_meta["update_amax_and_scale_fwd"] = state[5]
            self.fp8_meta["global_fp8_buffer_pos_fwd"] = state[6]
            self.fp8_meta["global_fp8_buffer_pos_bwd"] = state[7]
            self.fp8_meta["autocast_id_fwd"] = state[8]
            self.fp8_meta["autocast_id_bwd"] = state[9]
            return

431
        if isinstance(state, torch.Tensor):
432
            state = pickle.loads(state.detach().cpu().numpy().tobytes())
433
434
435
            if state is None:
                return

436
437
438
439
440
        # Restore global FP8 buffer states.
        set_global_fp8_buffer(state["global_fp8_buffer"])
        # Load extra items.
        self.fp8_meta.update(state["extra_fp8_variables"])
        self.fp8_meta["recipe"].amax_history_len = state["amax_history_fwd"].shape[0]
441
442
        if "global_fp8_buffer_pos_fwd_recompute" in self.fp8_meta:
            del self.fp8_meta["global_fp8_buffer_pos_fwd_recompute"]
443
444

        # Initialize before loading.
Przemek Tredak's avatar
Przemek Tredak committed
445
        self.init_fp8_meta_tensors()
446
447
448
449
        self.fp8_meta["scaling_fwd"].scale.copy_(state["scale_fwd"])
        self.fp8_meta["scaling_fwd"].amax_history.copy_(state["amax_history_fwd"])
        self.fp8_meta["scaling_bwd"].scale.copy_(state["scale_bwd"])
        self.fp8_meta["scaling_bwd"].amax_history.copy_(state["amax_history_bwd"])
Przemek Tredak's avatar
Przemek Tredak committed
450

451
452
453
454
455
456
457
458
459
460
461
462
        # Backwards compatibility: compute scale inv if it wasn't saved in the extra state.
        if "scale_inv_fwd" not in state or "scale_inv_bwd" not in state:
            assert (
                "scale_inv_fwd" not in state and "scale_inv_bwd" not in state
            ), "Invalid state, began saving scale_inv_fwd and scale_inv_bwd at the same time"
            self.fp8_meta["scaling_fwd"].scale_inv.copy_(1.0/state["scale_fwd"])
            self.fp8_meta["scaling_bwd"].scale_inv.copy_(1.0/state["scale_bwd"])
        else:
            self.fp8_meta["scaling_fwd"].scale_inv.copy_(state["scale_inv_fwd"])
            self.fp8_meta["scaling_bwd"].scale_inv.copy_(state["scale_inv_bwd"])


Przemek Tredak's avatar
Przemek Tredak committed
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
    def set_activation_dtype(self, inp: torch.Tensor) -> None:
        """Get activation data type for AMP."""
        # Native AMP (`torch.autocast`) gets highest priority
        if torch.is_autocast_enabled():
            self.activation_dtype = torch.get_autocast_gpu_dtype()
            return

        # All checks after this have already been performed once, thus skip
        # We assume that user doesn't change input types across iterations
        if hasattr(self, "activation_dtype"):
            return

        assert all(
            (
                (inp.dtype == param.dtype) if param is not None else True
                for param in self.parameters()
            )
        ), (
            "Data type for activations and weights must "
            "match when outside of autocasted region"
        )
        assert all(
            (
                (inp.dtype == buf.dtype) if buf is not None else True
                for buf in self.buffers()
            )
        ), (
            "Data type for activations and buffers must "
            "match when outside of autocasted region"
        )
        self.activation_dtype = inp.dtype

    def set_fp8_weights(self) -> None:
        """Initializes FP8 weights for the module as class attributes. These
        are not parameters or buffers since we do not want functions such as
        `.to(dtype)` or `.to(device)` to effect them. These also do not need
        to be checkpointed. During `init` phase of the module, the attribute
        `fp8_weight_shapes` must be populated with the tensor shapes for FP8
        weights. This function will iterate over those shapes and initialize
        respective attributed named `weight1_fp8`, `weight2_fp8`, ...
        """
504
505
506
        if not self.fp8:
            return

Przemek Tredak's avatar
Przemek Tredak committed
507
508
509
        for i, shape in enumerate(self.fp8_weight_shapes, start=1):
            weight_cast_attr = f"weight{i}_fp8"
            weight_transpose_attr = f"weight{i}_t_fp8"
510
511
512
513
514
515
516
517
518
519
520
521
522

            if (
                hasattr(self, weight_cast_attr)
                and getattr(self, weight_cast_attr).shape == shape
            ):
                return

            setattr(
                self,
                weight_cast_attr,
                torch.empty(
                    shape,
                    device=torch.cuda.current_device(),
cyanguwa's avatar
cyanguwa committed
523
                    dtype=torch.uint8,
524
525
526
527
528
529
530
531
532
                ),
            )
            setattr(
                self,
                weight_transpose_attr,
                torch.empty(
                    shape[1],
                    shape[0],
                    device=torch.cuda.current_device(),
cyanguwa's avatar
cyanguwa committed
533
                    dtype=torch.uint8,
534
535
                ),
            )
Przemek Tredak's avatar
Przemek Tredak committed
536
537
538
539
540
541

    def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None:
        """Store the tensor-parallel process group and flag it as initialized."""
        self.tp_group_initialized = True
        self.tp_group = tp_group

    # Shared by the FP8 and FP8-calibration paths, so this routine must not
    # assume FP8 execution.
    def fp8_init(self, num_gemms: int = 1) -> None:
        """Refresh the FP8 flags and (re)build FP8 metadata during fprop.

        Re-reads the FP8 and FP8-calibration state. If neither is active, the
        module is marked uninitialized. If either is active and the module is
        not yet initialized, or the active recipe differs from the stored one,
        the following are stored in ``fp8_meta`` and the scaling tensors are
        allocated:

        * the recipe and ``num_gemms``;
        * the FP8 group;
        * the recipe format's ``max_fwd`` / ``max_bwd``.
        """
        self.fp8 = is_fp8_enabled()
        self.fp8_calibration = is_fp8_calibration()

        if not (self.fp8 or self.fp8_calibration):
            # If fp8 isn't enabled, turn off and return.
            self.fp8_initialized = False
            return

        # Already initialized with the same recipe: nothing to do.
        if self.fp8_initialized and get_fp8_recipe() == self.fp8_meta["recipe"]:
            return

        recipe = get_fp8_recipe()
        self.fp8_meta["recipe"] = recipe
        self.fp8_meta["num_gemms"] = num_gemms
        self.fp8_meta["fp8_group"] = get_fp8_group()

        # Per-direction FP8 max values taken from the recipe's format.
        fmt = recipe.fp8_format
        self.fp8_meta["fp8_max_fwd"] = fmt.value.max_fwd
        self.fp8_meta["fp8_max_bwd"] = fmt.value.max_bwd

        # Allocate scales and amaxes
        self.init_fp8_meta_tensors()
        self.fp8_initialized = True

    @contextmanager
    def prepare_forward(
        self,
        inp: torch.Tensor,
        is_first_microbatch: Union[bool, None],
        num_gemms: int = 1,
    ) -> Generator[torch.Tensor, None, None]:
        """Checks and prep for FWD.

        The context manager is needed because there isn't a way for a module to know
        if it's the last FP8 module in the forward autocast. It is useful
        to setup the forward aggregated amax reduction for every module
        just in case. The autocast exit will pick up the most recent one.

        Yields ``inp.contiguous()`` inside an NVTX range named
        ``"<ClassName> forward"``.

        Args:
            inp: module input. Outside the activation-recompute phase, it is
                asserted to be a CUDA tensor.
            is_first_microbatch: ``None`` or ``True`` passes
                ``update_weight_scale_inv=True`` to the amax/scale update;
                ``False`` passes ``False``.
            num_gemms: number of GEMMs in the module, forwarded to ``fp8_init``.
        """
        # Activation recomputation is used and this is the second forward phase.
        if self.fp8 and in_fp8_activation_recompute_phase():
            get_old_fp8_meta_tensors_for_recompute(self.fp8_meta)
        else:
            assert inp.is_cuda, "TransformerEngine needs CUDA."

            if self.tp_size > 1:
                assert self.tp_group_initialized, "TP group not initialized."

            self.set_activation_dtype(inp)
            self.fp8_init(num_gemms=num_gemms)
            self.set_fp8_weights()

            update_weight_scale_inv = is_first_microbatch is None or is_first_microbatch

            if self.fp8 and self.sequence_parallel:
                assert self.fp8_meta["recipe"].reduce_amax, (
                    "Amax reduction across tensor parallel group is "
                    "necessary when using sequence parallelism with FP8."
                )

            # The previous forward ran in FP8 training mode (flag is set below).
            if self.fp8_meta.get("update_amax_and_scale_fwd", False):
                reduce_amax = self.fp8_meta["recipe"].reduce_amax
                if reduce_amax:
                    copy_amax_from_global_buffer(self.fp8_meta, forward=True)
                amax_and_scale_update(
                    self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv
                )
                if reduce_amax:
                    set_amax_buffer_key_deletion(self.fp8_meta, forward=True)

            if self.fp8 and self.training:
                # Setup for amax reduction
                if self.fp8_meta["recipe"].reduce_amax:
                    self.fp8_meta["first_module"] = is_first_fp8_module()
                    if self.fp8_meta["first_module"]:
                        # Wait for the prior AMAX reduction to finish
                        amax_reduce_handle_fwd = get_amax_reduce_handle_fwd()
                        if amax_reduce_handle_fwd is not None:
                            amax_reduce_handle_fwd.wait()

                        self.fp8_meta["autocast_id_fwd"] = new_fp8_context_id()
                        set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
                    else:
                        self.fp8_meta["autocast_id_fwd"] = get_fp8_context_id()
                    # Consumed in FIFO order by _prepare_backward.
                    self.fp8_meta["autocast_id_fwd_stack"].append(
                        self.fp8_meta["autocast_id_fwd"]
                    )
                    add_amax_to_global_buffer(self.fp8_meta, forward=True)
                self.fp8_meta["update_amax_and_scale_fwd"] = True
            else:
                self.fp8_meta["update_amax_and_scale_fwd"] = False

            # Activation recomputation is used and this is the first forward phase.
            if (
                self.fp8
                and self.training
                and is_fp8_activation_recompute_enabled()
                and not in_fp8_activation_recompute_phase()
            ):
                copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)

        with torch.cuda.nvtx.range(self.__class__.__name__ + " forward"):
            yield inp.contiguous()

        if self.fp8 and in_fp8_activation_recompute_phase():
            restore_fp8_meta_tensors(self.fp8_meta)
            return

        if self.fp8 and self.training and self.fp8_meta["recipe"].reduce_amax:
            set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
            reduce_func = partial(
                global_amax_reduction,
                self.fp8_meta,
                self.tp_group,
                self.tp_size,
                forward=True
            )
            setup_amax_forward_global_reduce_func(reduce_func)

    def set_nccl_overlap_warning_if_tp(self) -> None:
        """Warn when TP collectives may not overlap with the backward GEMMs.

        When using TP, the NCCL communication needs to be scheduled
        before the GEMM for there to be a guaranteed overlap. From the
        host side in TE, the comm calls are always launched first, but
        to ensure that the GEMM isn't scheduled first, the environment
        variable `CUDA_DEVICE_MAX_CONNECTIONS` needs to be set to 1 to
        force a single channel.

        Does nothing when ``self.tp_size == 1``. Otherwise emits a
        ``UserWarning`` unless ``CUDA_DEVICE_MAX_CONNECTIONS`` equals 1
        (an unset variable is treated as 0).
        """
        if self.tp_size == 1:
            return
        num_cuda_work_queues = int(os.getenv("CUDA_DEVICE_MAX_CONNECTIONS", "0"))
        if num_cuda_work_queues != 1:
            # Adjacent literals are concatenated, so the first one needs a
            # trailing space to keep "backward" and "GEMMs" apart.
            warnings.warn(
                "To guarantee overlapping TP and SP collectives with the backward "
                "GEMMs, set environment variable CUDA_DEVICE_MAX_CONNECTIONS = 1"
            )

    @staticmethod
    def grad_output_preprocess(
        ctx, grad_output: torch.Tensor, row_parallel_mode: bool
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Utility function for backward.

        Prepares the incoming output gradient for the dgrad/wgrad GEMMs.
        Returns tuple in order (all optional/None based on training precision/recipe):
            R1: gathered `grad_output` in higher precision.
            R2: gathered `grad_output` in FP8.
            R3: R2 transposed.
            R4: bias gradient on R1.

        Args:
            ctx: autograd context of the calling function. Always reads
                `fp8` and `sequence_parallel`. Reads `ub_split_ag` and
                `tp_group` when gathering, and `ub_obj_gradout` when
                `ub_split_ag` is set. The FP8 paths also read `fp8_meta`,
                and `use_bias` is read whenever FP8 is enabled.
            grad_output: gradient w.r.t. the module output. It is made
                contiguous and viewed as 2D by merging all leading dimensions.
            row_parallel_mode: combined with `ctx.sequence_parallel`, decides
                whether `grad_output` is gathered along its first dimension
                across `ctx.tp_group`.
        """
        grad_output = grad_output.contiguous()
        # Merge all leading dimensions into one: 2D view of the gradient.
        grad_output_mat = grad_output.view((-1, grad_output.shape[-1]))
        gather_grad_output = row_parallel_mode and ctx.sequence_parallel

        # No-FP8 case: bgrad is fused with wgrad for this case.
        if not ctx.fp8:
            if gather_grad_output:
                if not ctx.ub_split_ag:
                    grad_output_mat, _ = gather_along_first_dim(
                        grad_output_mat, ctx.tp_group
                    )
                else:
                    # Stage the local gradient in the userbuffer and use its
                    # output slot 1 as the gathered gradient (presumably filled
                    # by the userbuffers all-gather; verify against `get_ub`).
                    ctx.ub_obj_gradout.copy_input_to_ubuf(grad_output, True)
                    grad_output_mat = ctx.ub_obj_gradout.get_ubuf_output(1)
            return grad_output_mat, None, None, None

        fp8_dtype_backward = get_fp8_te_dtype(
            ctx.fp8_meta["recipe"], fprop_tensor=False
        )

        # FP8 case with non-FP8 wgrad
        if (
            gather_grad_output
            and ctx.fp8_meta["recipe"].override_linear_precision.wgrad
        ):
            assert (
                not ctx.ub_split_ag
            ), "override_linear_precision.wgrad not supported with ub_split_ag"
            # No return here: the gathered high-precision gradient falls
            # through to the "without gather" cast path at the bottom.
            grad_output_mat, _ = gather_along_first_dim(grad_output_mat, ctx.tp_group)
        # FP8 case with gather: unfused bgrad, cast, transpose for efficient gather
        elif gather_grad_output:
            # bgrad is taken on the local (pre-gather) gradient.
            if ctx.use_bias:
                grad_bias = grad_output_mat.sum(dim=0)
            else:
                grad_bias = None
            # Destination of the FP8 cast: userbuffer slot 0, or a fresh uint8 tensor.
            if ctx.ub_split_ag:
                grad_output_c = ctx.ub_obj_gradout.get_ubuf_output(0)
            else:
                grad_output_c = torch.empty_like(grad_output_mat, dtype=torch.uint8)
            cast_to_fp8(
                grad_output_mat,
                ctx.fp8_meta["scaling_bwd"],
                tex.FP8BwdTensors.GRAD_OUTPUT1,
                fp8_dtype_backward,
                out=grad_output_c,
            )
            # Gather the FP8 bytes (cheaper than gathering the high-precision
            # tensor), then transpose. With userbuffers, slot 1 is returned as
            # the gathered tensor and no transpose is produced here.
            if not ctx.ub_split_ag:
                grad_output_c, _ = gather_along_first_dim(grad_output_c, ctx.tp_group)
                grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
            else:
                grad_output_c = ctx.ub_obj_gradout.get_ubuf_output(1)
                grad_output_t = None

            # Note: R1 (`grad_output_mat`) is the un-gathered local gradient here.
            return grad_output_mat, grad_output_c, grad_output_t, grad_bias

        # FP8 case without gather: cast, transpose, bgrad fused
        if ctx.use_bias:
            grad_bias, grad_output_c, grad_output_t = fp8_cast_transpose_bgrad_fused(
                grad_output_mat,
                ctx.fp8_meta["scaling_bwd"],
                tex.FP8BwdTensors.GRAD_OUTPUT1,
                fp8_dtype_backward,
            )
        else:
            if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
                grad_output_c, grad_output_t = fp8_cast_transpose_fused(
                    grad_output_mat,
                    ctx.fp8_meta["scaling_bwd"],
                    tex.FP8BwdTensors.GRAD_OUTPUT1,
                    fp8_dtype_backward,
                )
            else:
                # wgrad runs in higher precision, so the FP8 transpose is not needed.
                grad_output_t = None
                grad_output_c = cast_to_fp8(
                    grad_output_mat,
                    ctx.fp8_meta["scaling_bwd"],
                    tex.FP8BwdTensors.GRAD_OUTPUT1,
                    fp8_dtype_backward,
                )
            grad_bias = None

        return grad_output_mat, grad_output_c, grad_output_t, grad_bias

777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
    def noop_cat(self, buffer_name: str, pnames: List[str]) -> torch.Tensor:
        """Concatenate split parameters without copying when they alias a buffer.

        The buffer ``buffer_name`` is expected to hold the parameters named in
        ``pnames`` back to back, as equal-sized chunks along dim 0. If any
        parameter does not start where its chunk of the buffer starts, the
        parameters are concatenated into a new buffer once. Each parameter is
        then re-created as a view into that buffer, so later calls take the
        cheap path.

        Args:
            buffer_name: attribute name of the full, concatenated tensor.
            pnames: attribute names of the split parameters, in buffer order.

        Returns:
            The result of ``_NoopCat.apply`` on the (possibly rebuilt) buffer
            and the current split parameters.
        """
        assert hasattr(self, buffer_name), f"No buffer named {buffer_name}"
        buffer = getattr(self, buffer_name)
        chunk = buffer.shape[0] // len(pnames)
        split_params = [getattr(self, name) for name in pnames]

        # Short-circuits on the first parameter that is not a view of its chunk.
        already_aliased = all(
            param.data.data_ptr() == buffer[idx * chunk : (idx + 1) * chunk].data_ptr()
            for idx, param in enumerate(split_params)
        )

        if not already_aliased:
            with torch.no_grad():
                setattr(self, buffer_name, torch.cat(split_params))
                rebuilt = getattr(self, buffer_name)
                for idx, name in enumerate(pnames):
                    setattr(self, name, Parameter(rebuilt[idx * chunk : (idx + 1) * chunk]))

        return _NoopCat.apply(getattr(self, buffer_name), *[getattr(self, name) for name in pnames])

Przemek Tredak's avatar
Przemek Tredak committed
800
801
802
803
804
    @abstractmethod
    def forward(self):
        """Abstract forward pass; concrete subclasses must provide it."""


805

Przemek Tredak's avatar
Przemek Tredak committed
806
807
808
809
810
811
812
813
814
815
816
817
class _LayerNormLinear(torch.autograd.Function):
    """LayerNormLinear semi-top level module.

    Autograd function that applies LayerNorm to ``inp`` and feeds the result
    into a linear GEMM. Both passes call the custom CUDA extensions (``tex``).
    It supports FP8 execution, tensor/sequence-parallel communication, and
    userbuffers ("ub") overlap of communication with the GEMMs.
    """

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        ln_weight: torch.Tensor,
        ln_bias: torch.Tensor,
        weight: torch.Tensor,
        weight_fp8: Union[torch.Tensor, None],
        weight_t_fp8: Union[torch.Tensor, None],
        bias: torch.Tensor,
        use_bias: bool,
        eps: float,
        is_first_microbatch: Union[bool, None],
        fp8: bool,
        fp8_calibration: bool,
        fp8_meta: Dict[str, Any],
        fuse_wgrad_accumulation: bool,
        tp_group: Union[dist_group_type, None],
        tp_size: int,
        sequence_parallel: bool,
        tensor_parallel: bool,
        activation_dtype: torch.dtype,
        parallel_mode: Union[str, None],
        return_layernorm_output: bool,
        is_grad_enabled: bool,
        fwd_ln_sm_margin: int,
        bwd_ln_sm_margin: int,
        zero_centered_gamma: bool,
        ub_bulk_wgrad: bool,
        ub_bulk_dgrad: bool,
        ub_split_ag: bool,
    ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
        """LayerNorm followed by a (possibly FP8, possibly parallel) GEMM.

        Returns the GEMM output, reshaped to ``[-1, *inp.shape[1:-1], out_features]``.
        With ``return_layernorm_output``, it returns a tuple that also includes
        the LayerNorm output viewed like ``inp``.
        """
        # Make sure input dimensions are compatible
        in_features = ln_weight.numel()
        assert inp.shape[-1] == in_features, "GEMM not possible"
        inputmat = inp.view((-1, in_features))
        assert (
            not fp8 or check_dim_for_fp8_forward_exec(inputmat, weight)
        ), "Input and weight dimensions are not compatible for FP8 execution."

        # FP8 weight copies are refreshed unless this is a later microbatch.
        update_fp8_weights = is_first_microbatch is None or is_first_microbatch

        # Cast for native AMP
        inputmat = cast_if_needed(inputmat, activation_dtype)
        ln_weight = cast_if_needed(ln_weight, activation_dtype)
        ln_bias = cast_if_needed(ln_bias, activation_dtype)
        # If residual connection is after LN, we need `ln_out`
        # tensor in higher precision, this comes at the cost
        # of an extra fp8 cast.

        # The userbuffers all-gather of `ln_out` is used only for TP > 1,
        # with grad enabled, and when the LN output is not returned.
        if ub_split_ag:
            tp_world_size = get_distributed_world_size(tp_group)
            if tp_world_size == 1 or (not is_grad_enabled) or return_layernorm_output:
                ub_split_ag = False
        if ub_split_ag:
            ub_obj_lnout = get_ub("qkv_fprop")
            # LN writes its output directly into userbuffer slot 0.
            ln_out = ub_obj_lnout.get_ubuf_output(0)

        if fp8:
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)

            if not return_layernorm_output:
                if is_grad_enabled:
                    if not ub_split_ag:
                        ln_out = torch.empty_like(inputmat, dtype=torch.uint8)
                    _, mu, rsigma = layernorm_fwd_fp8(
                        inputmat,
                        ln_weight,
                        ln_bias,
                        eps,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM1_INPUT,
                        fp8_dtype_forward,
                        fwd_ln_sm_margin,
                        zero_centered_gamma,
                        ln_out=ln_out,
                    )
                else:
                    mu = rsigma = None
                    ln_out = layernorm_fwd_fp8_inf(
                        inputmat,
                        ln_weight,
                        ln_bias,
                        eps,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM1_INPUT,
                        fp8_dtype_forward,
                        zero_centered_gamma,
                    )
            else:
                # The high-precision LN output is needed for the return value,
                # so LN runs unfused and is cast to FP8 separately.
                if is_grad_enabled:
                    ln_out_return, mu, rsigma = tex.layernorm_fwd(
                        inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma
                    )
                else:
                    ln_out_return, mu, rsigma = layernorm_fwd_inf(
                        inputmat, ln_weight, ln_bias, eps, zero_centered_gamma
                    ), None, None

                ln_out = cast_to_fp8(
                    ln_out_return,
                    fp8_meta["scaling_fwd"],
                    tex.FP8FwdTensors.GEMM1_INPUT,
                    fp8_dtype_forward,
                )
        else:
            if is_grad_enabled:
                if ub_split_ag:
                    # Write into the preallocated userbuffer slot.
                    _, mu, rsigma = tex.layernorm_fwd_noalloc(
                        inputmat, ln_weight, ln_bias, ln_out, eps,
                        fwd_ln_sm_margin, zero_centered_gamma
                    )
                else:
                    ln_out, mu, rsigma = tex.layernorm_fwd(
                        inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma
                    )
            else:
                ln_out, mu, rsigma = layernorm_fwd_inf(
                    inputmat, ln_weight, ln_bias, eps, zero_centered_gamma
                ), None, None
            ln_out_return = ln_out

        # Column Parallel Linear
        if ub_split_ag:
            # Slot 1 is the gathered input for the GEMM. The local `ln_out` gets
            # a fresh tensor that is passed as `extra_output_tensor` below
            # (presumably refilled by the overlapped GEMM; verify in `gemm`).
            ln_out_total = ub_obj_lnout.get_ubuf_output(1)
            ln_out = torch.empty_like(ln_out)
        elif parallel_mode == "column" and sequence_parallel:
            ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
        else:
            ln_out_total = ln_out

        if fp8:
            bias_dtype = (
                torch.bfloat16
                if activation_dtype == torch.float32
                else activation_dtype
            )
            bias = cast_if_needed(bias, bias_dtype) if use_bias else bias

            if update_fp8_weights:
                if is_grad_enabled:
                    # Refresh the cached FP8 weight and its transpose in place.
                    fp8_cast_transpose_fused(
                        weight,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM1_WEIGHT,
                        fp8_dtype_forward,
                        cast_out=weight_fp8,
                        transpose_out=weight_t_fp8,
                    )
                else:
                    # Inference: the transpose (dgrad-only) is not needed.
                    weight_t_fp8 = None
                    weight_fp8 = cast_to_fp8(
                        weight,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM1_WEIGHT,
                        fp8_dtype_forward)

            out = fp8_gemm(
                weight_fp8,
                fp8_meta["scaling_fwd"].scale_inv,
                tex.FP8FwdTensors.GEMM1_WEIGHT,
                fp8_dtype_forward,
                ln_out_total,
                fp8_meta["scaling_fwd"].scale_inv,
                tex.FP8FwdTensors.GEMM1_INPUT,
                fp8_dtype_forward,
                activation_dtype,
                get_workspace(),
                bias=bias,
                use_bias=use_bias,
                use_split_accumulator=_2X_ACC_FPROP,
                ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None,
                ub=ub_obj_lnout if ub_split_ag else None,
                extra_output_tensor=ln_out if ub_split_ag else None,
            )
        else:
            # Cast for native AMP
            weight = cast_if_needed(weight, activation_dtype)
            bias = cast_if_needed(bias, activation_dtype) if use_bias else bias

            if fp8_calibration:
                # amax of input
                fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = \
                    torch.amax(ln_out_total).float()
                # amax of weight
                fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = \
                    torch.amax(weight).float()

            out, _, _ = gemm(
                weight,
                ln_out_total,
                activation_dtype,
                get_workspace(),
                bias=bias,
                use_bias=use_bias,
                ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None,
                ub=ub_obj_lnout if ub_split_ag else None,
                extra_output_tensor=ln_out if ub_split_ag else None,
            )

        if is_grad_enabled:
            # The layout here must match the unpacking in `backward`.
            ctx.save_for_backward(
                inputmat,
                ln_weight,
                mu,
                rsigma,
                weight,
                weight_t_fp8,
                ln_out,
                fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
            )

            ctx.activation_dtype = activation_dtype
            ctx.fp8 = fp8
            ctx.fp8_meta = fp8_meta
            ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
            ctx.is_first_microbatch = is_first_microbatch
            ctx.use_bias = use_bias
            ctx.sequence_parallel = sequence_parallel
            ctx.tensor_parallel = tensor_parallel
            ctx.inp_shape = inp.shape
            ctx.parallel_mode = parallel_mode
            ctx.tp_group = tp_group
            ctx.tp_size = tp_size
            ctx.return_layernorm_output = return_layernorm_output
            ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
            ctx.zero_centered_gamma = zero_centered_gamma
            ctx.ub_bulk_wgrad = ub_bulk_wgrad
            ctx.ub_bulk_dgrad = ub_bulk_dgrad
            ctx.requires_dgrad = inp.requires_grad

        # Row Parallel Linear
        if parallel_mode == "row" and sequence_parallel:
            out, _ = reduce_scatter_along_first_dim(out, tp_group)
        elif parallel_mode == "row" and tensor_parallel:
            out, _ = allreduce(out, tp_group)

        # [*, in_features] -> [*, out_features] except first dimension changes for SP
        out = out.view(-1, *inp.shape[1:-1], out.shape[-1])

        if return_layernorm_output:
            return out, ln_out_return.view_as(inp)
        return out

    @staticmethod
    def backward(
        ctx, *grad_outputs: Tuple[torch.Tensor, ...]
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Gradients for `forward`.

        Returns one entry per `forward` input (28 in total). Only `inp`,
        `ln_weight`, `ln_bias`, `weight` and `bias` can receive a tensor;
        every other entry is None.
        """
        with _prepare_backward(
            ctx.fp8, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_LayerNormLinear"
        ):
            (
                inputmat,
                ln_weight,
                mu,
                rsigma,
                weight,
                weight_t_fp8,
                ln_out,
                fwd_scale_inverses,
            ) = ctx.saved_tensors

            # Userbuffers all-gather of `ln_out`, overlapped with the dgrad GEMM.
            if ctx.ub_bulk_dgrad:
                tp_world_size = get_distributed_world_size(ctx.tp_group)
                if tp_world_size == 1:
                    ctx.ub_bulk_dgrad = False
            if ctx.ub_bulk_dgrad:
                ub_obj_lnout = get_ub("qkv_dgrad")
                ub_obj_lnout.copy_input_to_ubuf(ln_out, 1)

            # `grad_output_preprocess` reads `ctx.ub_split_ag` whenever it
            # gathers the gradient (row-parallel + sequence-parallel). This
            # function has no grad-output userbuffer (`ub_obj_gradout`), so
            # select the plain gather path explicitly. Before this, that
            # configuration failed with AttributeError.
            ctx.ub_split_ag = False
            (
                grad_output,
                grad_output_c,
                grad_output_t,
                grad_bias,
            ) = TransformerEngineBaseModule.grad_output_preprocess(
                ctx, grad_outputs[0], ctx.parallel_mode == "row"
            )

            if ctx.ub_bulk_wgrad:
                tp_world_size = get_distributed_world_size(ctx.tp_group)
                if tp_world_size == 1:
                    ctx.ub_bulk_wgrad = False

            # Pending async collective (input all-gather, then dgrad
            # reduce-scatter or all-reduce); waited on before its result is used.
            handle = None

            # Column Parallel Linear
            # Overlap input AG with dgrad
            if (not ctx.ub_bulk_dgrad) and ctx.parallel_mode == "column" and ctx.sequence_parallel:
                ln_out_total, handle = gather_along_first_dim(
                    ln_out, ctx.tp_group, async_op=True
                )
            else:
                ln_out_total = ln_out

            if ctx.is_first_microbatch is not None:
                accumulate_wgrad_into_param_main_grad = (
                    ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
                )
            else:
                accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation

            dgrad_size = list(grad_output.size())
            dgrad_size[1] = weight.size(1)
            if ctx.ub_bulk_wgrad:  # allocate dgrad output
                ub_obj_dgrad = get_ub("qkv_wgrad")
                dgrad = ub_obj_dgrad.get_ubuf_output(1)  # AllGather output
            else:
                dgrad = torch.empty(dgrad_size, dtype=ctx.activation_dtype, device=weight.device)

            if ctx.fp8:
                fp8_dtype_forward = get_fp8_te_dtype(
                    ctx.fp8_meta["recipe"], fprop_tensor=True
                )
                fp8_dtype_backward = get_fp8_te_dtype(
                    ctx.fp8_meta["recipe"], fprop_tensor=False
                )

                # DGRAD: Evaluated unconditionally to feed into Linear backward
                _ = fp8_gemm(
                    weight_t_fp8,
                    fwd_scale_inverses,
                    tex.FP8FwdTensors.GEMM1_WEIGHT,
                    fp8_dtype_forward,
                    grad_output_c,
                    ctx.fp8_meta["scaling_bwd"].scale_inv,
                    tex.FP8BwdTensors.GRAD_OUTPUT1,
                    fp8_dtype_backward,
                    ctx.activation_dtype,
                    get_workspace(),
                    out=dgrad,
                    use_split_accumulator=_2X_ACC_DGRAD,
                    ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_AG if ctx.ub_bulk_dgrad else None,
                    ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None
                )
            else:
                # DGRAD: Evaluated unconditionally to feed into Linear backward
                _, _, _ = gemm(
                    weight,
                    grad_output,
                    ctx.activation_dtype,
                    get_workspace(),
                    out=dgrad,
                    layout="NN",
                    grad=True,
                    ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_AG if ctx.ub_bulk_dgrad else None,
                    ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None
                )
            if ctx.ub_bulk_dgrad:
                # Gathered `ln_out` produced alongside the dgrad GEMM.
                ln_out_total = ub_obj_lnout.get_ubuf_output(1)

            # Overlap dgrad-RS/AR with wgrad
            if ctx.parallel_mode == "column" and ctx.sequence_parallel:
                if not ctx.ub_bulk_dgrad:
                    handle.wait()
                if not ctx.ub_bulk_wgrad:
                    dgrad, handle = reduce_scatter_along_first_dim(
                        dgrad, ctx.tp_group, async_op=True
                    )
            elif ctx.parallel_mode == "column" and ctx.tensor_parallel:
                dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True)

            if weight.requires_grad:
                if ctx.fp8:
                    # WGRAD
                    if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
                        ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward)
                        wgrad = fp8_gemm(
                            ln_out_total_t,
                            fwd_scale_inverses,
                            tex.FP8FwdTensors.GEMM1_INPUT,
                            fp8_dtype_forward,
                            grad_output_t,
                            ctx.fp8_meta["scaling_bwd"].scale_inv,
                            tex.FP8BwdTensors.GRAD_OUTPUT1,
                            fp8_dtype_backward,
                            ctx.activation_dtype,
                            get_workspace(),
                            accumulate=accumulate_wgrad_into_param_main_grad,
                            out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                            use_split_accumulator=_2X_ACC_WGRAD,
                            ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS
                            if ctx.ub_bulk_wgrad else None,
                            ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None
                        )
                    else:
                        # Higher-precision wgrad: dequantize the saved FP8 LN output.
                        ln_out_total_c = cast_from_fp8(
                            ln_out_total,
                            ctx.fp8_meta["scaling_fwd"],
                            tex.FP8FwdTensors.GEMM1_INPUT,
                            fp8_dtype_forward,
                            TE_DType[ctx.activation_dtype],
                        )
                        wgrad, _, _ = gemm(
                            ln_out_total_c,
                            grad_output,
                            ctx.activation_dtype,
                            get_workspace(),
                            layout="NT",
                            grad=True,
                            accumulate=accumulate_wgrad_into_param_main_grad,
                            out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                            ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS
                            if ctx.ub_bulk_wgrad else None,
                            ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None
                        )
                else:
                    # WGRAD (bgrad is fused into this GEMM in the non-FP8 case)
                    wgrad, grad_bias, _ = gemm(
                        ln_out_total,
                        grad_output,
                        ctx.activation_dtype,
                        get_workspace(),
                        layout="NT",
                        grad=True,
                        use_bias=ctx.use_bias,
                        accumulate=accumulate_wgrad_into_param_main_grad,
                        out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                        ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None,
                        ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None
                    )

            if ctx.ub_bulk_wgrad:
                dgrad = ub_obj_dgrad.get_ubuf_output(0)  # Reduce-scatter output
            # Column Parallel Linear
            elif ctx.parallel_mode == "column" and ctx.tensor_parallel and handle is not None:
                handle.wait()

            # LayerNorm gradient
            d_ln_out = dgrad.view(inputmat.shape)

            # Residual gradient
            if ctx.return_layernorm_output:
                d_ln_out = d_ln_out + grad_outputs[1].view_as(d_ln_out)

            dxmat, dgamma, dbeta = tex.layernorm_bwd(
                d_ln_out, inputmat, mu, rsigma, ln_weight,
                ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
            )

            if not ctx.use_bias:
                grad_bias = None

        return (
            dxmat.view(ctx.inp_shape) if ctx.requires_dgrad else None,  # inp
            dgamma,  # ln_weight
            dbeta,  # ln_bias
            wgrad if weight.requires_grad else None,  # weight
            None,  # weight_fp8
            None,  # weight_t_fp8
            grad_bias,  # bias
            None,  # use_bias
            None,  # eps
            None,  # is_first_microbatch
            None,  # fp8
            None,  # fp8_calibration
            None,  # fp8_meta
            None,  # fuse_wgrad_accumulation
            None,  # tp_group
            None,  # tp_size
            None,  # sequence_parallel
            None,  # tensor_parallel
            None,  # activation_dtype
            None,  # parallel_mode
            None,  # return_layernorm_output
            None,  # is_grad_enabled
            None,  # fwd_ln_sm_margin
            None,  # bwd_ln_sm_margin
            None,  # zero_centered_gamma
            None,  # ub_bulk_wgrad
            None,  # ub_bulk_dgrad
            None,  # ub_split_ag
        )


class LayerNormLinear(TransformerEngineBaseModule):
    r"""
    Applies layer normalization followed by linear transformation to the incoming data.

    Parameters
    ----------
    in_features : int
                 size of each input sample.
    out_features : int
                  size of each output sample.
    eps : float, default = 1e-5
         a value added to the denominator of layer normalization for numerical stability.
    bias : bool, default = `True`
          if set to `False`, the layer will not learn an additive bias.
    init_method : Callable, default = `None`
                 used for initializing weights in the following way: `init_method(weight)`.
                 When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    return_layernorm_output : bool, default = `False`
                             if set to `True`, output of layernorm is returned from the forward
                             together with the output of the linear transformation.
                             Example use case: residual connection for transformer module is
                             taken post layernorm.
    parameters_split : Tuple[str, ...], default = None
                      if a tuple of strings is provided, the weight and bias parameters of the
                      module are exposed as `N` separate `torch.nn.parameter.Parameter`s each,
                      split along the first dimension, where `N` is the length of the argument
                      and the strings contained are the names of the split parameters.
    zero_centered_gamma : bool, default = `False`
                         if set to `True`, gamma parameter in LayerNorm is initialized to 0 and
                         the LayerNorm formula changes to

                         .. math::
                            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
                            (1 + \gamma) + \beta

    Parallelism parameters
    ----------------------
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    tp_size : int, default = 1
             used as TP (tensor parallel) world size when TP groups are not formed during
             initialization. In this case, users must call the
             `set_tensor_parallel_group(tp_group)` method on the initialized module before the
             forward pass to supply the tensor parallel group needed for tensor and sequence
             parallel collectives.
    parallel_mode : {None, 'column', 'row'}, default = `None`
                   used to decide whether this Linear layer is Column Parallel Linear or Row
                   Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
                   When set to `None`, no communication is performed. The values are
                   case-sensitive and must be lowercase.
    skip_weight_param_allocation: bool, default = `False`
                                 if set to `True`, weight parameter is not allocated and must be
                                 passed as a keyword argument `weight` during the forward pass.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = `False`
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient.
    return_bias : bool, default = `False`
                 when set to `True`, this module will not apply the additive bias itself, but
                 instead return the bias value during the forward pass together with the
                 output of the linear transformation :math:`y = xA^T`. This is useful when
                 the bias addition can be fused to subsequent operations.
    params_dtype : torch.dtype, default = `torch.float32`
                  it controls the type used to allocate the initial parameters. Useful when
                  the model is trained with lower precision and the original FP32 parameters
                  would not fit in GPU memory.
    ub_bulk_wgrad, ub_bulk_dgrad, ub_split_ag : bool, default = `False`
                  Userbuffer communication options forwarded to the fused LayerNorm + GEMM
                  computation. Setting any of them to `True` requires the Userbuffer
                  communication backend to be available.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        eps: float = 1e-5,
        sequence_parallel: bool = False,
        fuse_wgrad_accumulation: bool = False,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        get_rng_state_tracker: Optional[Callable] = None,
        init_method: Optional[Callable] = None,
        bias: bool = True,
        return_bias: bool = False,
        params_dtype: torch.dtype = torch.float32,
        parallel_mode: Optional[str] = None,
        return_layernorm_output: bool = False,
        skip_weight_param_allocation: bool = False,
        parameters_split: Optional[Tuple[str, ...]] = None,
        zero_centered_gamma: bool = False,
        ub_bulk_wgrad: bool = False,
        ub_bulk_dgrad: bool = False,
        ub_split_ag: bool = False,
    ) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
        self.use_bias = bias
        self.return_bias = return_bias
        # With `return_bias` the caller adds the bias, so the module must not apply it.
        self.apply_bias = bias and not return_bias
        self.return_layernorm_output = return_layernorm_output
        self.parameters_split = parameters_split
        self.zero_centered_gamma = zero_centered_gamma
        # Userbuffer communication options; forwarded unchanged to _LayerNormLinear.
        self.ub_bulk_wgrad = ub_bulk_wgrad
        self.ub_bulk_dgrad = ub_bulk_dgrad
        self.ub_split_ag = ub_split_ag

        if ub_bulk_wgrad or ub_bulk_dgrad or ub_split_ag:
            assert (
                tex.userbuf_comm_available()
            ), "Userbuffer communication backend not available."

        if tp_group is None:
            self.tp_size = tp_size
            if tp_size == 1:
                self.set_tensor_parallel_group(tp_group)
        else:
            self.tp_size = get_distributed_world_size(tp_group)
            self.set_tensor_parallel_group(tp_group)
        self.set_nccl_overlap_warning_if_tp()

        self.parallel_mode = parallel_mode
        assert (
            self.parallel_mode in GemmParallelModes
        ), f"parallel_mode {parallel_mode} not supported"

        # Partition the GEMM dimension owned by this rank.
        if self.parallel_mode == "column":
            self.out_features = divide(self.out_features, self.tp_size)
        elif self.parallel_mode == "row":
            self.in_features = divide(self.in_features, self.tp_size)

        if init_method is None:
            init_method = get_default_init_method()

        # Sequence parallelism is only meaningful with more than one TP rank.
        self.sequence_parallel = (self.tp_size > 1) and sequence_parallel

        self.eps = eps
        # NOTE(review): LayerNorm params are sized with the *un-partitioned* `in_features`
        # argument, whereas in "row" mode the GEMM weight uses the partitioned
        # `self.in_features` -- confirm row-parallel LayerNormLinear is intended to work.
        self.layer_norm_weight = Parameter(
            torch.empty(
                in_features,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            )
        )
        self.layer_norm_bias = Parameter(
            torch.empty(
                in_features,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            )
        )
        setattr(self.layer_norm_weight, "sequence_parallel", self.sequence_parallel)
        setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel)
        self.reset_layer_norm_parameters()

        if not skip_weight_param_allocation:
            # Contiguous backing storage; the registered parameters below are views into it.
            self.register_buffer("weight_tensor",
                                 torch.empty(
                                    self.out_features,
                                    self.in_features,
                                    device=torch.cuda.current_device(),
                                    dtype=params_dtype),
                                 persistent=False)

            initialize_affine_weight_gpu(
                self.weight_tensor,
                init_method,
                get_rng_state_tracker,
                partition_dim=1 if self.parallel_mode == "row" else 0,
                stride=1,
            )

            if self.use_bias:
                self.register_buffer("bias_tensor",
                                     torch.empty(
                                         self.out_features,
                                         device=torch.cuda.current_device(),
                                         dtype=params_dtype),
                                     persistent=False)
            else:
                self.register_buffer(
                    "bias_tensor", torch.Tensor().type(params_dtype), persistent=False
                )

            with torch.no_grad():
                self.bias_tensor.zero_()

            if parameters_split is None:
                parameters_split = ("",)

            assert (
                self.out_features % len(parameters_split) == 0
            ), f"Weight and bias params cannot be split into {len(parameters_split)} parts"

            split_size = self.out_features // len(parameters_split)

            self.weight_names = []
            self.bias_names = []

            for i, pname in enumerate(parameters_split):
                wname = pname + "weight"
                bname = pname + "bias"

                # Each exposed parameter is a slice (view) of the shared weight/bias storage.
                self.register_parameter(
                    wname, Parameter(self.weight_tensor[i * split_size : (i+1) * split_size])
                )

                set_tensor_model_parallel_attributes(
                    tensor=getattr(self, wname),
                    is_parallel=True,
                    dim=1 if parallel_mode == "row" else 0,
                    stride=1,
                )

                if self.use_bias:
                    self.register_parameter(
                        bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size])
                    )
                else:
                    self.register_buffer(bname, torch.Tensor().type(params_dtype), persistent=False)

                if parallel_mode == "column":
                    set_tensor_model_parallel_attributes(getattr(self, bname), True, 0, 1)

                self.weight_names.append(wname)
                self.bias_names.append(bname)

        self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features)))

        # For RPL, bias has to be added after TP collectives
        # So it cannot be fused with the GEMM
        if self.parallel_mode == "row" and self.apply_bias:
            self.gemm_bias_unfused_add = True
        else:
            self.gemm_bias_unfused_add = False

        # These many SMs are subtracted from the total SM count when calling forward
        # and backward LayerNorm C APIs. These envvars can be used to prevent the LN
        # kernels from using all SMs in the device. This is useful for cases such as
        # communication overlap with LN.
        self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
        self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))

    def reset_layer_norm_parameters(self) -> None:
        """Init LN params"""
        # With zero_centered_gamma the effective scale is (1 + gamma), so gamma starts at 0.
        if not self.zero_centered_gamma:
            init.ones_(self.layer_norm_weight)
        else:
            init.zeros_(self.layer_norm_weight)
        init.zeros_(self.layer_norm_bias)

    def forward(
        self,
        inp: torch.Tensor,
        weight: Optional[torch.Tensor] = None,
        bias: Optional[torch.Tensor] = None,
        is_first_microbatch: Optional[bool] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        """
        Apply layer normalization to the input followed by a linear transformation.

        Parameters
        ----------
        inp : torch.Tensor
             Input tensor.
        weight : torch.Tensor, default = None
                An optional weight tensor for the module. This argument is compulsory if module
                is initialized with `skip_weight_param_allocation=True`
        bias : torch.Tensor, default = None
              An optional bias tensor for the module. This argument is compulsory if module
              is initialized with `skip_weight_param_allocation=True` and one of `use_bias`
              or `return_bias`
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)
        """

        with self.prepare_forward(inp, is_first_microbatch) as inp:
            # Explicit arguments win; otherwise use the single parameter, the raw backing
            # buffer (no autograd needed), or a concatenation of the split parameters.
            bias_tensor = (
                bias if bias is not None
                else self.bias if self.parameters_split is None
                else self.bias_tensor if not torch.is_grad_enabled()
                else self.noop_cat("bias_tensor", self.bias_names)
            )
            weight_tensor = (
                weight if weight is not None
                else self.weight if self.parameters_split is None
                else self.weight_tensor if not torch.is_grad_enabled()
                else self.noop_cat("weight_tensor", self.weight_names)
            )

            # Without autograd, call forward directly with `None` in place of ctx.
            if torch.is_grad_enabled():
                fwd_fn = _LayerNormLinear.apply
                args = []
            else:
                fwd_fn = _LayerNormLinear.forward
                args = [None]
            args += (
                inp,
                self.layer_norm_weight,
                self.layer_norm_bias,
                weight_tensor,
                self.weight1_fp8 if self.fp8 else None,
                self.weight1_t_fp8 if self.fp8 else None,
                bias_tensor,
                self.apply_bias and not self.gemm_bias_unfused_add,
                self.eps,
                is_first_microbatch,
                self.fp8,
                self.fp8_calibration,
                self.fp8_meta,
                self.fuse_wgrad_accumulation,
                self.tp_group,
                self.tp_size,
                self.sequence_parallel,
                self.tp_size > 1,
                self.activation_dtype,
                self.parallel_mode,
                self.return_layernorm_output,
                torch.is_grad_enabled(),
                self.fwd_ln_sm_margin,
                self.bwd_ln_sm_margin,
                self.zero_centered_gamma,
                self.ub_bulk_wgrad,
                self.ub_bulk_dgrad,
                self.ub_split_ag,
            )
            out = fwd_fn(*args)

        if self.return_layernorm_output:
            out, ln_out = out

        # Row-parallel bias is added here, after the TP collective inside the GEMM op.
        if self.gemm_bias_unfused_add:
            out = out + cast_if_needed(bias_tensor, self.activation_dtype)

        if self.return_bias:
            if self.return_layernorm_output:
                return out, cast_if_needed(bias_tensor, self.activation_dtype), ln_out
            return out, cast_if_needed(bias_tensor, self.activation_dtype)
        if self.return_layernorm_output:
            return out, ln_out
        return out

class _Linear(torch.autograd.Function):
    """Linear semi-top level module
    Calls custom cuda extensions.
    """

    @staticmethod
    def forward(
        ctx,
        weight: torch.Tensor,
        weight_fp8: Union[torch.Tensor, None],
        weight_t_fp8: Union[torch.Tensor, None],
        inp: torch.Tensor,
        bias: torch.Tensor,
        use_bias: bool,
        is_first_microbatch: Union[bool, None],
        fp8: bool,
        fp8_calibration: bool,
        fp8_meta: Dict[str, Any],
        fuse_wgrad_accumulation: bool,
        tp_group: Union[dist_group_type, None],
        tp_size: int,
        sequence_parallel: bool,
        tensor_parallel: bool,
        activation_dtype: torch.dtype,
        parallel_mode: Union[str, None],
        is_grad_enabled: bool,
        ub_split_rs: bool,
        ub_split_ag: bool,
    ) -> torch.Tensor:
        """Compute the (optionally FP8, tensor/sequence-parallel) linear projection.

        The input is flattened to ``(-1, in_features)``, optionally cast to FP8, gathered
        along the first dimension for column-parallel + sequence-parallel mode, and
        multiplied by ``weight`` (plus ``bias`` when ``use_bias``). In row-parallel mode the
        result is reduce-scattered (sequence parallel), all-reduced (tensor parallel), or
        produced by the Userbuffers reduce-scatter overlap when ``ub_split_rs`` is active.
        When ``is_grad_enabled`` is set, the tensors and flags needed by ``backward`` are
        stored on ``ctx``.

        Returns the output reshaped to ``(-1, *inp.shape[1:-1], out_features)``. The
        leading dimension may differ from ``inp`` under sequence parallelism.
        """
        # Make sure input dimensions are compatible
        in_features = weight.shape[-1]
        assert inp.shape[-1] == in_features, "GEMM not possible"
        inputmat = inp.view((-1, in_features))
        assert (
            not fp8 or check_dim_for_fp8_forward_exec(inputmat, weight)
        ), "Input and weight dimensions are not compatible for FP8 execution."

        update_fp8_weights = is_first_microbatch is None or is_first_microbatch

        # The Userbuffers reduce-scatter overlap is disabled when there is only one TP rank.
        if ub_split_rs:
            tp_world_size = get_distributed_world_size(tp_group)
            if tp_world_size == 1:
                ub_split_rs = False

        # Cast for native AMP
        inputmat = cast_if_needed(inputmat, activation_dtype)
        inputmat_no_fp8 = inputmat

        if fp8:
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)

            if not fp8_meta["recipe"].override_linear_precision.wgrad:
                if is_grad_enabled:
                    # The transposed FP8 input is only needed for the FP8 wgrad GEMM.
                    inputmat, inputmat_t = fp8_cast_transpose_fused(
                        inputmat,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM1_INPUT,
                        fp8_dtype_forward,
                    )
                else:
                    inputmat = cast_to_fp8(
                        inputmat,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM1_INPUT,
                        fp8_dtype_forward,
                    )
            else:
                inputmat, inputmat_t = cast_to_fp8(
                    inputmat,
                    fp8_meta["scaling_fwd"],
                    tex.FP8FwdTensors.GEMM1_INPUT,
                    fp8_dtype_forward,
                ), None

        # Column Parallel Linear
        if parallel_mode == "column" and sequence_parallel:
            inputmat_total, _ = gather_along_first_dim(inputmat, tp_group)
        else:
            inputmat_total = inputmat

        # Output buffers (identical for the FP8 and non-FP8 GEMMs). With the Userbuffers
        # reduce-scatter overlap, the GEMM writes into the communicator's buffer and the
        # per-rank reduced result lands in `rs_out`.
        ub_obj_projout = None
        rs_out = None
        dim_size = list(inputmat_total.size())
        dim_size[1] = weight.size(0)
        if ub_split_rs:
            ub_obj_projout = get_ub("proj_fprop")
            out = ub_obj_projout.get_ubuf_output(1)
            dim_size[0] = dim_size[0] // tp_world_size
            rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
        else:
            out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)

        if fp8:
            bias_dtype = (
                torch.bfloat16
                if activation_dtype == torch.float32
                else activation_dtype
            )
            bias = cast_if_needed(bias, bias_dtype) if use_bias else bias

            if update_fp8_weights:
                if is_grad_enabled:
                    # Refresh the cached FP8 weight and its transpose in place.
                    fp8_cast_transpose_fused(
                        weight,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM1_WEIGHT,
                        fp8_dtype_forward,
                        cast_out=weight_fp8,
                        transpose_out=weight_t_fp8,
                    )
                else:
                    weight_t_fp8 = None
                    weight_fp8 = cast_to_fp8(
                        weight,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM1_WEIGHT,
                        fp8_dtype_forward,
                    )

            _ = fp8_gemm(
                weight_fp8,
                fp8_meta["scaling_fwd"].scale_inv,
                tex.FP8FwdTensors.GEMM1_WEIGHT,
                fp8_dtype_forward,
                inputmat_total,
                fp8_meta["scaling_fwd"].scale_inv,
                tex.FP8FwdTensors.GEMM1_INPUT,
                fp8_dtype_forward,
                activation_dtype,
                get_workspace(),
                bias=bias,
                use_bias=use_bias,
                use_split_accumulator=_2X_ACC_FPROP,
                out=out,
                ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else None,
                ub=ub_obj_projout,
                extra_output_tensor=rs_out,
            )
        else:
            # Cast for native AMP
            weight = cast_if_needed(weight, activation_dtype)
            bias = cast_if_needed(bias, activation_dtype) if use_bias else bias

            if fp8_calibration:
                # FP8 scaling needs the maximum *absolute* value. torch.amax alone returns
                # the largest signed value, which underestimates (or even makes negative)
                # the amax when the largest-magnitude entry is negative.
                # amax of input
                amin, amax = inputmat_total.aminmax()
                fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = \
                    torch.max(-amin, amax).float()
                # amax of weight
                amin, amax = weight.aminmax()
                fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = \
                    torch.max(-amin, amax).float()

            _, _, _ = gemm(
                weight,
                inputmat_total,
                activation_dtype,
                get_workspace(),
                bias=bias,
                use_bias=use_bias,
                out=out,
                ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else None,
                ub=ub_obj_projout,
                extra_output_tensor=rs_out,
            )

        if is_grad_enabled:
            # Save the un-gathered input: FP8 transposed copy for FP8 wgrad, otherwise the
            # high-precision input (backward re-gathers it when sequence parallel).
            fp8_wgrad = fp8 and not fp8_meta["recipe"].override_linear_precision.wgrad
            ctx.save_for_backward(
                inputmat_no_fp8 if weight.requires_grad and not fp8_wgrad else None,
                inputmat_t if weight.requires_grad and fp8_wgrad else None,
                weight,
                weight_t_fp8 if fp8 else None,
                fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
            )
            ctx.activation_dtype = activation_dtype
            ctx.fp8 = fp8
            ctx.fp8_meta = fp8_meta
            ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
            ctx.is_first_microbatch = is_first_microbatch
            ctx.use_bias = use_bias
            ctx.sequence_parallel = sequence_parallel
            ctx.tensor_parallel = tensor_parallel
            ctx.inp_shape = inp.shape
            ctx.parallel_mode = parallel_mode
            ctx.tp_group = tp_group
            ctx.ub_split_ag = ub_split_ag
            ctx.tp_size = tp_size
            ctx.requires_dgrad = inp.requires_grad

        # Row Parallel Linear
        if ub_split_rs:
            out = rs_out
        elif parallel_mode == "row" and sequence_parallel:
            out, _ = reduce_scatter_along_first_dim(out, tp_group)
        elif parallel_mode == "row" and tensor_parallel:
            out, _ = allreduce(out, tp_group)

        # [*, in_features] -> [*, out_features] except first dimension changes for SP
        return out.view(-1, *inp.shape[1:-1], out.shape[-1])

1852

Przemek Tredak's avatar
Przemek Tredak committed
1853
1854
1855
1856
    @staticmethod
    def backward(
        ctx, grad_output: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:
Sangkug Lym's avatar
Sangkug Lym committed
1857
1858
1859
        with _prepare_backward(
            ctx.fp8, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_Linear"
        ):
1860
1861
1862
1863
1864
1865
1866
            (
                inputmat,
                inputmat_t,
                weight,
                weight_t_fp8,
                fwd_scale_inverses,
            ) = ctx.saved_tensors
Przemek Tredak's avatar
Przemek Tredak committed
1867

1868
1869
1870
1871
1872
1873
1874
1875
            if ctx.ub_split_ag:
                tp_world_size = get_distributed_world_size(ctx.tp_group)
                if tp_world_size == 1:
                    ctx.ub_split_ag = False
            if ctx.ub_split_ag:
                dim_size = list(grad_output.size())
                dim_size[0] = dim_size[0] * tp_world_size
                ctx.ub_obj_gradout = get_ub("proj_dgrad")
1876
1877
1878
1879
1880
1881
1882
1883
            (
                grad_output,
                grad_output_c,
                grad_output_t,
                grad_bias,
            ) = TransformerEngineBaseModule.grad_output_preprocess(
                ctx, grad_output, ctx.parallel_mode == "row"
            )
Przemek Tredak's avatar
Przemek Tredak committed
1884

1885
1886
1887
1888
1889
            # Column Parallel Linear
            # Overlap input AG with dgrad
            if ctx.parallel_mode == "column" and ctx.sequence_parallel:
                if ctx.fp8 and not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
                    inputmat_t_total, handle = gather_along_last_dim(
1890
                        inputmat_t, ctx.tp_group, async_op=ctx.requires_dgrad
1891
1892
1893
                    )
                else:
                    inputmat_total, handle = gather_along_first_dim(
1894
                        inputmat, ctx.tp_group, async_op=ctx.requires_dgrad
1895
1896
1897
1898
                    )
            else:
                inputmat_t_total = inputmat_t
                inputmat_total = inputmat
Przemek Tredak's avatar
Przemek Tredak committed
1899

1900
1901
1902
            if ctx.is_first_microbatch is not None:
                accumulate_wgrad_into_param_main_grad = (
                    ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
Przemek Tredak's avatar
Przemek Tredak committed
1903
1904
                )
            else:
1905
                accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
Przemek Tredak's avatar
Przemek Tredak committed
1906

1907
1908
1909
1910
1911
1912
1913
            if ctx.fp8:
                fp8_dtype_forward = get_fp8_te_dtype(
                    ctx.fp8_meta["recipe"], fprop_tensor=True
                )
                fp8_dtype_backward = get_fp8_te_dtype(
                    ctx.fp8_meta["recipe"], fprop_tensor=False
                )
Przemek Tredak's avatar
Przemek Tredak committed
1914

1915
1916
1917
1918
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928
            if ctx.requires_dgrad:
                if ctx.fp8:
                    dgrad = fp8_gemm(
                        weight_t_fp8,
                        fwd_scale_inverses,
                        tex.FP8FwdTensors.GEMM1_WEIGHT,
                        fp8_dtype_forward,
                        grad_output_c,
                        ctx.fp8_meta["scaling_bwd"].scale_inv,
                        tex.FP8BwdTensors.GRAD_OUTPUT1,
                        fp8_dtype_backward,
                        ctx.activation_dtype,
                        get_workspace(),
                        use_split_accumulator=_2X_ACC_DGRAD,
1929
1930
                        ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None,
                        ub=ctx.ub_obj_gradout if ctx.ub_split_ag else None,
1931
1932
1933
1934
1935
1936
1937
1938
1939
                    )
                else:
                    dgrad, _, _ = gemm(
                        weight,
                        grad_output,
                        ctx.activation_dtype,
                        get_workspace(),
                        layout="NN",
                        grad=True,
1940
1941
                        ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None,
                        ub=ctx.ub_obj_gradout if ctx.ub_split_ag else None,
1942
                    )
Przemek Tredak's avatar
Przemek Tredak committed
1943

1944
1945
1946
1947
1948
1949
1950
1951
                # Overlap dgrad-RS/AR with wgrad
                if ctx.parallel_mode == "column" and ctx.sequence_parallel:
                    handle.wait()
                    dgrad, handle = reduce_scatter_along_first_dim(
                        dgrad, ctx.tp_group, async_op=True
                    )
                elif ctx.parallel_mode == "column" and ctx.tensor_parallel:
                    dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True)
1952

1953
            if weight.requires_grad:
1954
1955
1956
                if ctx.fp8:
                    # WGRAD
                    if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
1957
1958
                        if ctx.ub_split_ag:
                            grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
1959
1960
                        wgrad = fp8_gemm(
                            inputmat_t_total,
1961
1962
                            fwd_scale_inverses,
                            tex.FP8FwdTensors.GEMM1_INPUT,
1963
1964
                            fp8_dtype_forward,
                            grad_output_t,
1965
1966
                            ctx.fp8_meta["scaling_bwd"].scale_inv,
                            tex.FP8BwdTensors.GRAD_OUTPUT1,
1967
1968
1969
1970
1971
1972
1973
1974
1975
1976
1977
1978
1979
1980
1981
1982
1983
1984
                            fp8_dtype_backward,
                            ctx.activation_dtype,
                            get_workspace(),
                            accumulate=accumulate_wgrad_into_param_main_grad,
                            out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                            use_split_accumulator=_2X_ACC_WGRAD,
                        )
                    else:
                        wgrad, _, _ = gemm(
                            inputmat_total,
                            grad_output,
                            ctx.activation_dtype,
                            get_workspace(),
                            layout="NT",
                            grad=True,
                            accumulate=accumulate_wgrad_into_param_main_grad,
                            out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                        )
schetlur-nv's avatar
schetlur-nv committed
1985
                else:
1986
1987
                    # WGRAD
                    wgrad, grad_bias, _ = gemm(
schetlur-nv's avatar
schetlur-nv committed
1988
1989
1990
1991
1992
1993
                        inputmat_total,
                        grad_output,
                        ctx.activation_dtype,
                        get_workspace(),
                        layout="NT",
                        grad=True,
1994
                        use_bias=ctx.use_bias,
schetlur-nv's avatar
schetlur-nv committed
1995
1996
1997
                        accumulate=accumulate_wgrad_into_param_main_grad,
                        out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                    )
Przemek Tredak's avatar
Przemek Tredak committed
1998

1999
2000
2001
            # Column Parallel Linear
            if ctx.parallel_mode == "column" and ctx.tensor_parallel and handle is not None:
                handle.wait()
Przemek Tredak's avatar
Przemek Tredak committed
2002

2003
2004
            if not ctx.use_bias:
                grad_bias = None
Przemek Tredak's avatar
Przemek Tredak committed
2005
2006

        return (
2007
            wgrad if weight.requires_grad else None,
Przemek Tredak's avatar
Przemek Tredak committed
2008
2009
            None,
            None,
2010
            dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
Przemek Tredak's avatar
Przemek Tredak committed
2011
2012
2013
2014
2015
2016
2017
2018
2019
2020
            grad_bias,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
2021
            None,
2022
            None,
schetlur-nv's avatar
schetlur-nv committed
2023
            None,
Sangkug Lym's avatar
Sangkug Lym committed
2024
            None,
2025
2026
            None,
            None,
Przemek Tredak's avatar
Przemek Tredak committed
2027
2028
2029
2030
2031
2032
2033
2034
2035
2036
2037
2038
2039
2040
2041
2042
2043
2044
2045
2046
        )


class Linear(TransformerEngineBaseModule):
    """
    Applies a linear transformation to the incoming data :math:`y = xA^T + b`

    On NVIDIA GPUs it is a drop-in replacement for `torch.nn.Linear`.

    Parameters
    ----------
    in_features : int
                 size of each input sample.
    out_features : int
                  size of each output sample.
    bias : bool, default = `True`
          if set to `False`, the layer will not learn an additive bias.
    init_method : Callable, default = `None`
                 used for initializing weights in the following way: `init_method(weight)`.
                 When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    get_rng_state_tracker : Callable, default = `None`
                           forwarded to the weight initializer together with `init_method`;
                           used to obtain the RNG state tracker when initializing the weights.
    parameters_split : Tuple[str, ...], default = None
                      if a tuple of strings is provided, the weight and bias parameters of the
                      module are exposed as `N` separate `torch.nn.parameter.Parameter`s each,
                      split along the first dimension, where `N` is the length of the argument
                      and the strings contained are the names of the split parameters.

    Parallelism parameters
    ----------------------
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    tp_size : int, default = 1
             used as TP (tensor parallel) world size when TP groups are not formed during
             initialization. In this case, users must call the
             `set_tensor_parallel_group(tp_group)` method on the initialized module before the
             forward pass to supply the tensor parallel group needed for tensor and sequence
             parallel collectives.
    parallel_mode : {None, 'column', 'row'}, default = `None`
                   used to decide whether this Linear layer is Column Parallel Linear or Row
                   Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
                   The values are case-sensitive. When set to `None`, no communication is
                   performed.
    skip_weight_param_allocation: bool, default = `False`
                                 if set to `True`, weight parameter is not allocated and must be
                                 passed as a keyword argument `weight` during the forward pass.
                                 If the module also uses a bias (`bias=True` or
                                 `return_bias=True`), the bias must likewise be passed as the
                                 keyword argument `bias`.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = `False`
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient. When enabled, it is assumed that the weights
                             have an additional `main_grad` attribute (used instead of the
                             regular `grad`) which is a pre-allocated buffer of the correct
                             size to accumulate gradients in.
    return_bias : bool, default = `False`
                 when set to `True`, this module will not apply the additive bias itself, but
                 instead return the bias value during the forward pass together with the
                 output of the linear transformation :math:`y = xA^T`. This is useful when
                 the bias addition can be fused to subsequent operations.
    params_dtype : torch.dtype, default = `torch.float32`
                  it controls the type used to allocate the initial parameters. Useful when
                  the model is trained with lower precision and the original FP32 parameters
                  would not fit in GPU memory.
    ub_split_rs : bool, default = `False`
                 if set to `True`, overlaps the reduce-scatter communication with the GEMM
                 using the userbuffer communication backend, which must be available.
    ub_split_ag : bool, default = `False`
                 if set to `True`, overlaps the all-gather communication with the GEMM
                 using the userbuffer communication backend, which must be available.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        sequence_parallel: bool = False,
        fuse_wgrad_accumulation: bool = False,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        get_rng_state_tracker: Optional[Callable] = None,
        init_method: Optional[Callable] = None,
        bias: bool = True,
        return_bias: bool = False,
        params_dtype: torch.dtype = torch.float32,
        parallel_mode: Optional[str] = None,
        skip_weight_param_allocation: bool = False,
        parameters_split: Optional[Tuple[str, ...]] = None,
        ub_split_rs: bool = False,
        ub_split_ag: bool = False,
    ) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
        self.use_bias = bias
        self.return_bias = return_bias
        self.apply_bias = bias and not return_bias
        self.parameters_split = parameters_split
        # Remembered so forward() knows weight/bias must come from the caller.
        self.skip_weight_param_allocation = skip_weight_param_allocation
        self.ub_split_rs = ub_split_rs
        self.ub_split_ag = ub_split_ag

        if ub_split_rs or ub_split_ag:
            assert (
                tex.userbuf_comm_available()
            ), "Userbuffer communication backend not available."

        if tp_group is None:
            self.tp_size = tp_size
            if tp_size == 1:
                self.set_tensor_parallel_group(tp_group)
        else:
            self.tp_size = get_distributed_world_size(tp_group)
            self.set_tensor_parallel_group(tp_group)
        self.set_nccl_overlap_warning_if_tp()

        self.parallel_mode = parallel_mode
        assert (
            self.parallel_mode in GemmParallelModes
        ), f"parallel_mode {parallel_mode} not supported"

        # Shard the partitioned dimension across the tensor-parallel ranks.
        if self.parallel_mode == "column":
            self.out_features = divide(self.out_features, self.tp_size)
        elif self.parallel_mode == "row":
            self.in_features = divide(self.in_features, self.tp_size)

        if init_method is None:
            init_method = get_default_init_method()

        self.sequence_parallel = (self.tp_size > 1) and sequence_parallel

        if not skip_weight_param_allocation:
            # Contiguous backing storage; the exposed Parameters below are views into it.
            self.register_buffer("weight_tensor",
                                 torch.empty(
                                    self.out_features,
                                    self.in_features,
                                    device=torch.cuda.current_device(),
                                    dtype=params_dtype),
                                 persistent=False)

            initialize_affine_weight_gpu(
                self.weight_tensor,
                init_method,
                get_rng_state_tracker,
                partition_dim=1 if self.parallel_mode == "row" else 0,
                stride=1,
            )

            if self.use_bias:
                self.register_buffer("bias_tensor",
                                     torch.empty(
                                         self.out_features,
                                         device=torch.cuda.current_device(),
                                         dtype=params_dtype),
                                     persistent=False)
            else:
                self.register_buffer(
                    "bias_tensor", torch.Tensor().type(params_dtype), persistent=False
                )

            with torch.no_grad():
                self.bias_tensor.zero_()

            if parameters_split is None:
                parameters_split = ("",)

            assert (
                self.out_features % len(parameters_split) == 0
            ), f"Weight and bias params cannot be split into {len(parameters_split)} parts"

            split_size = self.out_features // len(parameters_split)

            self.weight_names = []
            self.bias_names = []

            for i, pname in enumerate(parameters_split):
                wname = pname + "weight"
                bname = pname + "bias"

                self.register_parameter(
                    wname, Parameter(self.weight_tensor[i * split_size : (i+1) * split_size])
                )

                set_tensor_model_parallel_attributes(
                    tensor=getattr(self, wname),
                    is_parallel=True,
                    dim=1 if parallel_mode == "row" else 0,
                    stride=1,
                )

                if self.use_bias:
                    self.register_parameter(
                        bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size])
                    )
                else:
                    self.register_buffer(bname, torch.Tensor().type(params_dtype), persistent=False)

                if parallel_mode == "column":
                    set_tensor_model_parallel_attributes(getattr(self, bname), True, 0, 1)

                self.weight_names.append(wname)
                self.bias_names.append(bname)
        else:
            # No parameters are owned by the module: the caller supplies `weight` (and `bias`
            # when one is used) to forward(). Keep the same empty, non-persistent placeholder
            # that the allocating path uses for a disabled bias, so a bias-free forward works.
            self.register_buffer(
                "bias_tensor", torch.Tensor().type(params_dtype), persistent=False
            )

        self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features)))

        # For RPL, bias has to be added after TP collectives
        # So it cannot be fused with the GEMM
        if self.parallel_mode == "row" and self.apply_bias:
            self.gemm_bias_unfused_add = True
        else:
            self.gemm_bias_unfused_add = False

    def forward(
        self,
        inp: torch.Tensor,
        weight: Optional[torch.Tensor] = None,
        bias: Optional[torch.Tensor] = None,
        is_first_microbatch: Optional[bool] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        """
        Apply the linear transformation to the input.

        Parameters
        ----------
        inp : torch.Tensor
             Input tensor.
        weight : torch.Tensor, default = None
                An optional weight tensor for the module. This argument is compulsory if module
                is initialized with `skip_weight_param_allocation=True`
        bias : torch.Tensor, default = None
              An optional bias tensor for the module. This argument is compulsory if module
              is initialized with `skip_weight_param_allocation=True` and one of `use_bias`
              or `return_bias`
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)

        Raises
        ------
        ValueError
            if the module was initialized with `skip_weight_param_allocation=True` and a
            required `weight` or `bias` tensor was not supplied.
        """
        # Validate caller-supplied tensors up front, before any FP8 state is touched.
        if self.skip_weight_param_allocation:
            if weight is None:
                raise ValueError(
                    "`weight` must be passed to forward() when the module is initialized "
                    "with skip_weight_param_allocation=True."
                )
            if bias is None and (self.use_bias or self.return_bias):
                raise ValueError(
                    "`bias` must be passed to forward() when the module is initialized "
                    "with skip_weight_param_allocation=True and uses a bias."
                )

        with self.prepare_forward(inp, is_first_microbatch) as inp:
            # Resolve the bias: explicit argument > module-owned parameter(s). When params
            # are split and autograd is on, the split views are re-joined via noop_cat.
            if bias is not None:
                bias_tensor = bias
            elif self.skip_weight_param_allocation:
                # Bias unused here (validated above); pass the empty placeholder.
                bias_tensor = self.bias_tensor
            elif self.parameters_split is None:
                bias_tensor = self.bias
            elif not torch.is_grad_enabled():
                bias_tensor = self.bias_tensor
            else:
                bias_tensor = self.noop_cat("bias_tensor", self.bias_names)

            # Same resolution order for the weight (always given in skip mode, see above).
            if weight is not None:
                weight_tensor = weight
            elif self.parameters_split is None:
                weight_tensor = self.weight
            elif not torch.is_grad_enabled():
                weight_tensor = self.weight_tensor
            else:
                weight_tensor = self.noop_cat("weight_tensor", self.weight_names)

            # Without autograd, call the Function's forward directly; `None` stands in for ctx.
            if torch.is_grad_enabled():
                linear_fn = _Linear.apply
                args = []
            else:
                linear_fn = _Linear.forward
                args = [None]
            args += (
                weight_tensor,
                self.weight1_fp8 if self.fp8 else None,
                self.weight1_t_fp8 if self.fp8 else None,
                inp,
                bias_tensor,
                self.apply_bias and not self.gemm_bias_unfused_add,
                is_first_microbatch,
                self.fp8,
                self.fp8_calibration,
                self.fp8_meta,
                self.fuse_wgrad_accumulation,
                self.tp_group,
                self.tp_size,
                self.sequence_parallel,
                self.tp_size > 1,
                self.activation_dtype,
                self.parallel_mode,
                torch.is_grad_enabled(),
                self.ub_split_rs,
                self.ub_split_ag,
            )
            out = linear_fn(*args)

        # Row-parallel bias is added only after the TP reduction (see __init__).
        if self.gemm_bias_unfused_add:
            out = out + cast_if_needed(bias_tensor, self.activation_dtype)

        if self.return_bias:
            return out, cast_if_needed(bias_tensor, self.activation_dtype)
        return out


class _LayerNormMLP(torch.autograd.Function):
    """LayerNormMLP semi-top level module
    Calls custom cuda extensions.
    """

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        ln_weight: torch.Tensor,
        ln_bias: torch.Tensor,
        fc1_weight: torch.Tensor,
2332
2333
        fc1_weight_fp8: Union[torch.Tensor, None],
        fc1_weight_t_fp8: Union[torch.Tensor, None],
Przemek Tredak's avatar
Przemek Tredak committed
2334
        fc1_bias: torch.Tensor,
ngoyal2707's avatar
ngoyal2707 committed
2335
        use_fc1_bias: bool,
Przemek Tredak's avatar
Przemek Tredak committed
2336
        fc2_weight: torch.Tensor,
2337
2338
        fc2_weight_fp8: Union[torch.Tensor, None],
        fc2_weight_t_fp8: Union[torch.Tensor, None],
Przemek Tredak's avatar
Przemek Tredak committed
2339
        fc2_bias: torch.Tensor,
ngoyal2707's avatar
ngoyal2707 committed
2340
        use_fc2_bias: bool,
Przemek Tredak's avatar
Przemek Tredak committed
2341
2342
2343
        eps: float,
        is_first_microbatch: Union[bool, None],
        fp8: bool,
schetlur-nv's avatar
schetlur-nv committed
2344
        fp8_calibration: bool,
Przemek Tredak's avatar
Przemek Tredak committed
2345
2346
2347
        fp8_meta: Dict[str, Any],
        fuse_wgrad_accumulation: bool,
        tp_group: Union[dist_group_type, None],
Sangkug Lym's avatar
Sangkug Lym committed
2348
        tp_size: int,
Przemek Tredak's avatar
Przemek Tredak committed
2349
        sequence_parallel: bool,
2350
        tensor_parallel: bool,
Przemek Tredak's avatar
Przemek Tredak committed
2351
2352
2353
2354
        activation_dtype: torch.dtype,
        return_layernorm_output: bool,
        bias_gelu_nvfusion: bool,
        set_parallel_mode: bool,
2355
        is_grad_enabled: bool,
2356
2357
        fwd_ln_sm_margin: int,
        bwd_ln_sm_margin: int,
2358
        zero_centered_gamma: bool,
2359
2360
2361
2362
        ub_bulk_wgrad: bool,
        ub_bulk_dgrad: bool,
        ub_split_rs: bool,
        ub_split_ag: bool,
Przemek Tredak's avatar
Przemek Tredak committed
2363
2364
2365
2366
2367
    ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
        # Make sure input dimensions are compatible
        in_features = ln_weight.numel()
        assert inp.shape[-1] == in_features, "GEMM not possible"
        inputmat = inp.view((-1, in_features))
2368
        assert (
2369
2370
            not fp8 or check_dim_for_fp8_forward_exec(inputmat, fc1_weight, fc2_weight)
        ), "Input and weight dimensions are not compatible for FP8 execution."
Przemek Tredak's avatar
Przemek Tredak committed
2371
2372
2373
2374
2375
2376
2377
2378

        update_fp8_weights = is_first_microbatch is None or is_first_microbatch

        # Cast for native AMP
        inputmat = cast_if_needed(inputmat, activation_dtype)
        ln_weight = cast_if_needed(ln_weight, activation_dtype)
        ln_bias = cast_if_needed(ln_bias, activation_dtype)

2379
2380
2381
2382
2383
2384
2385
2386
2387
2388
2389
2390
        if ub_split_ag:
            tp_world_size = get_distributed_world_size(tp_group)
            if tp_world_size == 1 or (not is_grad_enabled) or return_layernorm_output:
                ub_split_ag = False
        if ub_split_ag:
            ub_obj_lnout = get_ub("fc1_fprop")
            ln_out = ub_obj_lnout.get_ubuf_output(0)
        if ub_split_rs:
            tp_world_size = get_distributed_world_size(tp_group)
            if tp_world_size == 1:
                ub_split_rs = False

Przemek Tredak's avatar
Przemek Tredak committed
2391
2392
2393
2394
2395
2396
        # If residual connection is after LN, we need `ln_out`
        # tensor in higher precision, this comes at the cost
        # of an extra fp8 cast.
        if fp8:
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
            if not return_layernorm_output:
2397
                if is_grad_enabled:
2398
2399
2400
                    if not ub_split_ag:
                        ln_out = torch.empty_like(inputmat, dtype=torch.uint8)
                    _, mu, rsigma = layernorm_fwd_fp8(
2401
2402
2403
2404
2405
2406
2407
                        inputmat,
                        ln_weight,
                        ln_bias,
                        eps,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM1_INPUT,
                        fp8_dtype_forward,
2408
                        fwd_ln_sm_margin,
2409
                        zero_centered_gamma,
2410
                        ln_out = ln_out,
2411
2412
2413
2414
2415
2416
2417
2418
2419
2420
                    )
                else:
                    ln_out = layernorm_fwd_fp8_inf(
                        inputmat,
                        ln_weight,
                        ln_bias,
                        eps,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM1_INPUT,
                        fp8_dtype_forward,
2421
                        zero_centered_gamma,
2422
                    )
Przemek Tredak's avatar
Przemek Tredak committed
2423
2424
            else:
                ln_out_return, mu, rsigma = tex.layernorm_fwd(
2425
                    inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma
Przemek Tredak's avatar
Przemek Tredak committed
2426
2427
2428
2429
2430
2431
2432
2433
                )
                ln_out = cast_to_fp8(
                    ln_out_return,
                    fp8_meta["scaling_fwd"],
                    tex.FP8FwdTensors.GEMM1_INPUT,
                    fp8_dtype_forward,
                )
        else:
2434
            if is_grad_enabled:
2435
2436
2437
2438
2439
2440
2441
2442
2443
                if ub_split_ag:
                    _, mu, rsigma = tex.layernorm_fwd_noalloc(
                        inputmat, ln_weight, ln_bias, ln_out, eps,
                        fwd_ln_sm_margin, zero_centered_gamma
                    )
                else:
                    ln_out, mu, rsigma = tex.layernorm_fwd(
                        inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma
                    )
2444
2445
            else:
                ln_out, mu, rsigma = layernorm_fwd_inf(
2446
                        inputmat, ln_weight, ln_bias, eps, zero_centered_gamma
2447
                        ), None, None
Przemek Tredak's avatar
Przemek Tredak committed
2448

2449
            ln_out_return = ln_out
Przemek Tredak's avatar
Przemek Tredak committed
2450
        # Column Parallel Linear
2451
2452
2453
2454
        if ub_split_ag:
            ln_out_total = ub_obj_lnout.get_ubuf_output(1)
            ln_out = torch.empty_like(ln_out)
        elif set_parallel_mode and sequence_parallel:
Przemek Tredak's avatar
Przemek Tredak committed
2455
2456
2457
2458
2459
2460
2461
2462
2463
2464
            ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
        else:
            ln_out_total = ln_out

        if fp8:
            bias_dtype = (
                torch.bfloat16
                if activation_dtype == torch.float32
                else activation_dtype
            )
ngoyal2707's avatar
ngoyal2707 committed
2465
2466
            fc1_bias = cast_if_needed(fc1_bias, bias_dtype) if use_fc1_bias else fc1_bias
            fc2_bias = cast_if_needed(fc2_bias, bias_dtype) if use_fc2_bias else fc2_bias
Przemek Tredak's avatar
Przemek Tredak committed
2467
2468

            if update_fp8_weights:
2469
                if is_grad_enabled:
2470
2471
2472
2473
2474
2475
2476
2477
                    fp8_cast_transpose_fused(
                        fc1_weight,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM1_WEIGHT,
                        fp8_dtype_forward,
                        cast_out=fc1_weight_fp8,
                        transpose_out=fc1_weight_t_fp8,
                    )
Przemek Tredak's avatar
Przemek Tredak committed
2478

2479
2480
2481
2482
2483
2484
2485
2486
2487
2488
2489
2490
2491
2492
2493
2494
2495
2496
2497
2498
2499
2500
2501
                    fp8_cast_transpose_fused(
                        fc2_weight,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM2_WEIGHT,
                        fp8_dtype_forward,
                        cast_out=fc2_weight_fp8,
                        transpose_out=fc2_weight_t_fp8,
                    )
                else:
                    fc1_weight_t_fp8 = None
                    fc1_weight_fp8 = cast_to_fp8(
                        fc1_weight,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM1_WEIGHT,
                        fp8_dtype_forward,
                    )
                    fc2_weight_t_fp8 = None
                    fc2_weight_fp8 = cast_to_fp8(
                        fc2_weight,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM2_WEIGHT,
                        fp8_dtype_forward,
                    )
Przemek Tredak's avatar
Przemek Tredak committed
2502
2503
2504

            fc1_out = fp8_gemm(
                fc1_weight_fp8,
2505
2506
                fp8_meta["scaling_fwd"].scale_inv,
                tex.FP8FwdTensors.GEMM1_WEIGHT,
Przemek Tredak's avatar
Przemek Tredak committed
2507
2508
                fp8_dtype_forward,
                ln_out_total,
2509
2510
                fp8_meta["scaling_fwd"].scale_inv,
                tex.FP8FwdTensors.GEMM1_INPUT,
Przemek Tredak's avatar
Przemek Tredak committed
2511
2512
2513
2514
                fp8_dtype_forward,
                activation_dtype,
                get_workspace(),
                bias=fc1_bias,
ngoyal2707's avatar
ngoyal2707 committed
2515
                use_bias=use_fc1_bias,
Przemek Tredak's avatar
Przemek Tredak committed
2516
                use_split_accumulator=_2X_ACC_FPROP,
2517
2518
2519
                ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None,
                ub=ub_obj_lnout if ub_split_ag else None,
                extra_output_tensor=ln_out if ub_split_ag else None,
Przemek Tredak's avatar
Przemek Tredak committed
2520
2521
2522
2523
2524
2525
2526
2527
2528
            )

            gelu_out = fp8_gelu(
                fc1_out,
                fp8_meta["scaling_fwd"],
                tex.FP8FwdTensors.GEMM2_INPUT,
                fp8_dtype_forward,
            )

2529
2530
2531
2532
2533
2534
2535
2536
2537
2538
2539
2540
2541
            if ub_split_rs:
                ub_obj_fc2out = get_ub("fc2_fprop")
                fc2_out = ub_obj_fc2out.get_ubuf_output(1)
                dim_size = list(gelu_out.size())
                dim_size[0] = dim_size[0] // tp_world_size
                dim_size[1] = fc2_weight.size(0)
                rs_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device)
            else:
                dim_size = list(gelu_out.size())
                dim_size[1] = fc2_weight.size(0)
                fc2_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device)

            _ = fp8_gemm(
Przemek Tredak's avatar
Przemek Tredak committed
2542
                fc2_weight_fp8,
2543
2544
                fp8_meta["scaling_fwd"].scale_inv,
                tex.FP8FwdTensors.GEMM2_WEIGHT,
Przemek Tredak's avatar
Przemek Tredak committed
2545
2546
                fp8_dtype_forward,
                gelu_out,
2547
2548
                fp8_meta["scaling_fwd"].scale_inv,
                tex.FP8FwdTensors.GEMM2_INPUT,
Przemek Tredak's avatar
Przemek Tredak committed
2549
2550
2551
2552
                fp8_dtype_forward,
                activation_dtype,
                get_workspace(),
                bias=fc2_bias,
ngoyal2707's avatar
ngoyal2707 committed
2553
                use_bias=use_fc2_bias,
Przemek Tredak's avatar
Przemek Tredak committed
2554
                use_split_accumulator=_2X_ACC_FPROP,
2555
2556
2557
2558
                out=fc2_out,
                ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else None,
                ub=ub_obj_fc2out if ub_split_rs else None,
                extra_output_tensor=rs_out if ub_split_rs else None,
Przemek Tredak's avatar
Przemek Tredak committed
2559
2560
2561
2562
2563
            )
        else:
            # Cast for native AMP
            fc1_weight = cast_if_needed(fc1_weight, activation_dtype)
            fc2_weight = cast_if_needed(fc2_weight, activation_dtype)
ngoyal2707's avatar
ngoyal2707 committed
2564
2565
2566
            fc1_bias = (
                cast_if_needed(fc1_bias, activation_dtype) if use_fc1_bias else fc1_bias
            )
Przemek Tredak's avatar
Przemek Tredak committed
2567
            fc2_bias = (
ngoyal2707's avatar
ngoyal2707 committed
2568
                cast_if_needed(fc2_bias, activation_dtype) if use_fc2_bias else fc2_bias
Przemek Tredak's avatar
Przemek Tredak committed
2569
2570
            )

schetlur-nv's avatar
schetlur-nv committed
2571
2572
2573
2574
2575
2576
2577
2578
            if fp8_calibration:
                # amax of fc1 input
                fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = \
                    torch.amax(ln_out_total).float()
                # amax of fc1 weight
                fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = \
                    torch.amax(fc1_weight).float()

Przemek Tredak's avatar
Przemek Tredak committed
2579
2580
2581
2582
2583
2584
            fc1_outputs = gemm(
                fc1_weight,
                ln_out_total,
                activation_dtype,
                get_workspace(),
                bias=fc1_bias,
ngoyal2707's avatar
ngoyal2707 committed
2585
                use_bias=(not bias_gelu_nvfusion) and use_fc1_bias,
Przemek Tredak's avatar
Przemek Tredak committed
2586
                gelu=not bias_gelu_nvfusion,
2587
2588
2589
                ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None,
                ub=ub_obj_lnout if ub_split_ag else None,
                extra_output_tensor=ln_out if ub_split_ag else None,
Przemek Tredak's avatar
Przemek Tredak committed
2590
2591
            )

2592
            if bias_gelu_nvfusion:
Przemek Tredak's avatar
Przemek Tredak committed
2593
                fc1_out, _, _ = fc1_outputs
ngoyal2707's avatar
ngoyal2707 committed
2594

Przemek Tredak's avatar
Przemek Tredak committed
2595
2596
2597
2598
                gelu_out = bias_gelu_fused(fc1_out, fc1_bias)
            else:
                gelu_out, _, fc1_out = fc1_outputs

schetlur-nv's avatar
schetlur-nv committed
2599
2600
2601
2602
2603
2604
2605
2606
            if fp8_calibration:
                # amax of fc2 input
                fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM2_INPUT] = \
                    torch.amax(gelu_out).float()
                # amax of fc2 weight
                fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM2_WEIGHT] = \
                    torch.amax(fc2_weight).float()

2607
2608
2609
2610
2611
2612
2613
2614
2615
2616
2617
2618
            if ub_split_rs:
                ub_obj_fc2out = get_ub("fc2_fprop")
                fc2_out = ub_obj_fc2out.get_ubuf_output(1)
                dim_size = list(gelu_out.size())
                dim_size[0] = dim_size[0] // tp_world_size
                dim_size[1] = fc2_weight.size(0)
                rs_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device)
            else:
                dim_size = list(gelu_out.size())
                dim_size[1] = fc2_weight.size(0)
                fc2_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device)
            _, _, _ = gemm(
Przemek Tredak's avatar
Przemek Tredak committed
2619
2620
2621
2622
2623
                fc2_weight,
                gelu_out,
                activation_dtype,
                get_workspace(),
                bias=fc2_bias,
ngoyal2707's avatar
ngoyal2707 committed
2624
                use_bias=use_fc2_bias,
2625
2626
2627
2628
                out=fc2_out,
                ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else None,
                ub=ub_obj_fc2out if ub_split_rs else None,
                extra_output_tensor=rs_out if ub_split_rs else None,
Przemek Tredak's avatar
Przemek Tredak committed
2629
            )
2630

2631
        if is_grad_enabled:
2632
2633
2634
2635
2636
2637
2638
2639
2640
2641
2642
2643
2644
2645
2646
2647
2648
2649
2650
2651
            ctx.save_for_backward(
                inputmat,
                ln_weight,
                mu,
                rsigma,
                ln_out,
                fc1_out,
                gelu_out,
                fc1_weight,
                fc1_weight_t_fp8,
                fc2_weight,
                fc2_weight_t_fp8,
                fc1_bias,
                fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None,
            )
            ctx.activation_dtype = activation_dtype
            ctx.fp8 = fp8
            ctx.fp8_meta = fp8_meta
            ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
            ctx.is_first_microbatch = is_first_microbatch
ngoyal2707's avatar
ngoyal2707 committed
2652
2653
            ctx.use_fc1_bias = use_fc1_bias
            ctx.use_fc2_bias = use_fc2_bias
2654
2655
2656
2657
            ctx.sequence_parallel = sequence_parallel
            ctx.tensor_parallel = tensor_parallel
            ctx.inp_shape = inp.shape
            ctx.tp_group = tp_group
Sangkug Lym's avatar
Sangkug Lym committed
2658
            ctx.tp_size = tp_size
2659
2660
2661
            ctx.bias_gelu_nvfusion = bias_gelu_nvfusion
            ctx.return_layernorm_output = return_layernorm_output
            ctx.set_parallel_mode = set_parallel_mode
2662
            ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
2663
            ctx.zero_centered_gamma = zero_centered_gamma
2664
2665
2666
            ctx.ub_bulk_wgrad = ub_bulk_wgrad
            ctx.ub_bulk_dgrad = ub_bulk_dgrad
            ctx.ub_split_ag = ub_split_ag
2667
            ctx.requires_dgrad = inp.requires_grad
Przemek Tredak's avatar
Przemek Tredak committed
2668
2669

        # Row Parallel Linear
2670
2671
2672
        if ub_split_rs:
            fc2_out = rs_out
        elif set_parallel_mode and sequence_parallel:
Przemek Tredak's avatar
Przemek Tredak committed
2673
            fc2_out, _ = reduce_scatter_along_first_dim(fc2_out, tp_group)
2674
        elif set_parallel_mode and tensor_parallel:
Przemek Tredak's avatar
Przemek Tredak committed
2675
2676
2677
2678
2679
2680
2681
2682
2683
            fc2_out, _ = allreduce(fc2_out, tp_group)

        # [*, in_features] -> [*, out_features] except first dimension changes for SP
        fc2_out = fc2_out.view(-1, *inp.shape[1:-1], fc2_out.shape[-1])

        if return_layernorm_output:
            return fc2_out, ln_out_return.view_as(inp)
        return fc2_out

2684

Przemek Tredak's avatar
Przemek Tredak committed
2685
2686
2687
2688
    @staticmethod
    def backward(
        ctx, *grad_outputs: Tuple[torch.Tensor, ...]
    ) -> Tuple[Union[torch.Tensor, None], ...]:
Sangkug Lym's avatar
Sangkug Lym committed
2689
2690
2691
        with _prepare_backward(
            ctx.fp8, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_LayerNormMLP"
        ):
2692
2693
2694
2695
2696
2697
2698
2699
2700
2701
2702
2703
2704
2705
2706
            (
                inputmat,
                ln_weight,
                mu,
                rsigma,
                ln_out,
                fc1_out,
                gelu_out,
                fc1_weight,
                fc1_weight_t_fp8,
                fc2_weight,
                fc2_weight_t_fp8,
                fc1_bias,
                fwd_scale_inverses,
            ) = ctx.saved_tensors
Przemek Tredak's avatar
Przemek Tredak committed
2707

2708
2709
2710
2711
2712
2713
2714
2715
2716
2717
2718
2719
2720
2721
2722
2723
2724
2725
            if ctx.ub_bulk_dgrad:
                tp_world_size = get_distributed_world_size(ctx.tp_group)
                if tp_world_size == 1:
                    ctx.ub_bulk_dgrad = False
            if ctx.ub_bulk_dgrad:
                dim_size = list(ln_out.size())
                dim_size[0] = dim_size[0] * tp_world_size
                ub_obj_lnout = get_ub("fc1_dgrad")
                ub_obj_lnout.copy_input_to_ubuf(ln_out, 1)
            if ctx.ub_split_ag:
                tp_world_size = get_distributed_world_size(ctx.tp_group)
                if tp_world_size == 1:
                    ctx.ub_split_ag = False
            if ctx.ub_split_ag:
                dim_size = list(grad_outputs[0].size())
                dim_size[0] = dim_size[0] * tp_world_size
                ctx.ub_obj_gradout = get_ub("fc2_dgrad")

ngoyal2707's avatar
ngoyal2707 committed
2726
            ctx.use_bias = ctx.use_fc2_bias # For grad_output_preprocess
2727
2728
2729
2730
2731
2732
2733
2734
            (
                grad_output,
                grad_output_c,
                grad_output_t,
                fc2_bias_grad,
            ) = TransformerEngineBaseModule.grad_output_preprocess(
                ctx, grad_outputs[0], True
            )
2735
2736
2737
2738
            if ctx.ub_bulk_wgrad:
                tp_world_size = get_distributed_world_size(ctx.tp_group)
                if tp_world_size == 1:
                    ctx.ub_bulk_wgrad = False
2739
2740
            # Column Parallel Linear
            # Overlap input AG with dgrad
2741
            if (not ctx.ub_bulk_dgrad) and ctx.set_parallel_mode and ctx.sequence_parallel:
2742
2743
2744
2745
2746
                ln_out_total, handle = gather_along_first_dim(
                    ln_out, ctx.tp_group, async_op=True
                )
            else:
                ln_out_total = ln_out
Przemek Tredak's avatar
Przemek Tredak committed
2747

2748
2749
2750
2751
2752
2753
            if ctx.is_first_microbatch is not None:
                accumulate_wgrad_into_param_main_grad = (
                    ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
                )
            else:
                accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
Przemek Tredak's avatar
Przemek Tredak committed
2754

2755
2756
2757
2758
2759
2760
2761
            if ctx.fp8:
                fp8_dtype_forward = get_fp8_te_dtype(
                    ctx.fp8_meta["recipe"], fprop_tensor=True
                )
                fp8_dtype_backward = get_fp8_te_dtype(
                    ctx.fp8_meta["recipe"], fprop_tensor=False
                )
Przemek Tredak's avatar
Przemek Tredak committed
2762

2763
2764
2765
                # FC2 DGRAD; Unconditional
                fc2_dgrad = fp8_gemm(
                    fc2_weight_t_fp8,
2766
2767
                    fwd_scale_inverses,
                    tex.FP8FwdTensors.GEMM2_WEIGHT,
2768
2769
                    fp8_dtype_forward,
                    grad_output_c,
2770
2771
                    ctx.fp8_meta["scaling_bwd"].scale_inv,
                    tex.FP8BwdTensors.GRAD_OUTPUT1,
2772
2773
2774
2775
                    fp8_dtype_backward,
                    ctx.activation_dtype,
                    get_workspace(),
                    use_split_accumulator=_2X_ACC_DGRAD,
2776
2777
                    ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None,
                    ub=ctx.ub_obj_gradout if ctx.ub_split_ag else None,
2778
                )
2779
2780
                if ctx.ub_split_ag:
                    grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
2781
2782
2783
2784
2785
2786
                # FC2 WGRAD
                if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
                    if fc2_weight.requires_grad:
                        gelu_out_t = tex.fp8_transpose(gelu_out, fp8_dtype_forward)
                        fc2_wgrad = fp8_gemm(
                            gelu_out_t,
2787
2788
                            fwd_scale_inverses,
                            tex.FP8FwdTensors.GEMM2_INPUT,
2789
2790
                            fp8_dtype_forward,
                            grad_output_t,
2791
2792
                            ctx.fp8_meta["scaling_bwd"].scale_inv,
                            tex.FP8BwdTensors.GRAD_OUTPUT1,
2793
2794
2795
2796
2797
2798
2799
2800
2801
2802
2803
2804
2805
2806
2807
2808
2809
2810
2811
2812
2813
2814
2815
2816
2817
2818
2819
2820
2821
2822
2823
2824
2825
                            fp8_dtype_backward,
                            ctx.activation_dtype,
                            get_workspace(),
                            accumulate=accumulate_wgrad_into_param_main_grad,
                            out=fc2_weight.main_grad
                            if ctx.fuse_wgrad_accumulation
                            else None,
                            use_split_accumulator=_2X_ACC_WGRAD,
                        )

                    fc1_bias_grad, dgelu, dgelu_t = fp8_cast_transpose_bgrad_dgelu_fused(
                        fc2_dgrad,
                        fc1_out,
                        ctx.fp8_meta["scaling_bwd"],
                        tex.FP8BwdTensors.GRAD_OUTPUT2,
                        fp8_dtype_backward,
                    )
                else:
                    if fc2_weight.requires_grad:
                        gelu_out_c = cast_from_fp8(
                            gelu_out,
                            ctx.fp8_meta["scaling_fwd"],
                            tex.FP8FwdTensors.GEMM2_INPUT,
                            fp8_dtype_forward,
                            TE_DType[ctx.activation_dtype],
                        )
                        fc2_wgrad, _, _ = gemm(
                            gelu_out_c,
                            grad_output,
                            ctx.activation_dtype,
                            get_workspace(),
                            layout="NT",
                            grad=True,
ngoyal2707's avatar
ngoyal2707 committed
2826
                            use_bias=False,
2827
2828
2829
2830
2831
2832
2833
2834
2835
                            accumulate=accumulate_wgrad_into_param_main_grad,
                            out=fc2_weight.main_grad
                            if ctx.fuse_wgrad_accumulation
                            else None,
                        )

                    fc1_bias_grad, dgelu_no_fp8 = bgrad_dgelu_fused(
                        fc2_dgrad, fc1_out, fc1_bias
                    )
Przemek Tredak's avatar
Przemek Tredak committed
2836

2837
2838
2839
2840
                    dgelu = cast_to_fp8(
                        dgelu_no_fp8,
                        ctx.fp8_meta["scaling_bwd"],
                        tex.FP8BwdTensors.GRAD_OUTPUT2,
schetlur-nv's avatar
schetlur-nv committed
2841
2842
                        fp8_dtype_backward,
                    )
2843
                    dgelu_t = None
Przemek Tredak's avatar
Przemek Tredak committed
2844

2845
2846
2847
2848
2849
2850
2851
2852
2853
                fc1_dgrad_size = list(dgelu.size())
                fc1_dgrad_size[1] = fc1_weight.size(1)
                if ctx.ub_bulk_wgrad: # allocate dgrad output
                    ub_obj_dgrad = get_ub("fc1_wgrad")
                    fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output
                else:
                    fc1_dgrad = torch.empty(
                        fc1_dgrad_size, dtype=ctx.activation_dtype, device=fc1_weight.device
                    )
2854
                # FC1 DGRAD: Unconditional
2855
                _ = fp8_gemm(
2856
                    fc1_weight_t_fp8,
2857
2858
                    fwd_scale_inverses,
                    tex.FP8FwdTensors.GEMM1_WEIGHT,
2859
2860
                    fp8_dtype_forward,
                    dgelu,
2861
2862
                    ctx.fp8_meta["scaling_bwd"].scale_inv,
                    tex.FP8BwdTensors.GRAD_OUTPUT2,
Przemek Tredak's avatar
Przemek Tredak committed
2863
                    fp8_dtype_backward,
2864
2865
                    ctx.activation_dtype,
                    get_workspace(),
2866
                    out=fc1_dgrad,
2867
                    use_split_accumulator=_2X_ACC_DGRAD,
2868
2869
                    ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_AG if ctx.ub_bulk_dgrad else None,
                    ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None
Przemek Tredak's avatar
Przemek Tredak committed
2870
2871
                )
            else:
2872
2873
2874
2875
2876
2877
2878
2879
2880
2881
                # FC2 DGRAD; Unconditional
                fc2_dgrad, _, _ = gemm(
                    fc2_weight,
                    grad_output,
                    ctx.activation_dtype,
                    get_workspace(),
                    layout="NN",
                    gelu=not ctx.bias_gelu_nvfusion,
                    grad=True,
                    gelu_input=fc1_out,
2882
2883
                    ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None,
                    ub=ctx.ub_obj_gradout if ctx.ub_split_ag else None,
2884
2885
2886
                )

                # FC2 WGRAD
schetlur-nv's avatar
schetlur-nv committed
2887
                if fc2_weight.requires_grad:
2888
                    fc2_wgrad, fc2_bias_grad, _ = gemm(
schetlur-nv's avatar
schetlur-nv committed
2889
2890
2891
2892
2893
2894
                        gelu_out,
                        grad_output,
                        ctx.activation_dtype,
                        get_workspace(),
                        layout="NT",
                        grad=True,
ngoyal2707's avatar
ngoyal2707 committed
2895
                        use_bias=ctx.use_fc2_bias,
schetlur-nv's avatar
schetlur-nv committed
2896
                        accumulate=accumulate_wgrad_into_param_main_grad,
2897
                        out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
schetlur-nv's avatar
schetlur-nv committed
2898
                    )
Przemek Tredak's avatar
Przemek Tredak committed
2899

2900
2901
2902
2903
                if ctx.bias_gelu_nvfusion:
                    fc1_bias_grad, dgelu = bgrad_dgelu_fused(fc2_dgrad, fc1_out, fc1_bias)
                else:
                    dgelu = fc2_dgrad
Przemek Tredak's avatar
Przemek Tredak committed
2904

2905
2906
2907
2908
2909
2910
2911
2912
2913
                fc1_dgrad_size = list(dgelu.size())
                fc1_dgrad_size[1] = fc1_weight.size(1)
                if ctx.ub_bulk_wgrad: # allocate dgrad output
                    ub_obj_dgrad = get_ub("fc1_wgrad")
                    fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output
                else:
                    fc1_dgrad = torch.empty(
                        fc1_dgrad_size, dtype=ctx.activation_dtype, device=fc1_weight.device
                    )
2914
                # FC1 DGRAD: Unconditional
2915
                _, _, _ = gemm(
2916
2917
                    fc1_weight,
                    dgelu,
schetlur-nv's avatar
schetlur-nv committed
2918
2919
                    ctx.activation_dtype,
                    get_workspace(),
2920
                    out=fc1_dgrad,
2921
                    layout="NN",
schetlur-nv's avatar
schetlur-nv committed
2922
                    grad=True,
2923
2924
                    ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_AG if ctx.ub_bulk_dgrad else None,
                    ub=ub_obj_lnout if ctx.ub_bulk_dgrad else None
schetlur-nv's avatar
schetlur-nv committed
2925
                )
Przemek Tredak's avatar
Przemek Tredak committed
2926

2927
2928
            if ctx.ub_bulk_dgrad:
                ln_out_total = ub_obj_lnout.get_ubuf_output(1)
2929
2930
            # Overlap dgrad-RS/AR with wgrad
            if ctx.set_parallel_mode and ctx.sequence_parallel:
2931
2932
2933
2934
2935
2936
                if not ctx.ub_bulk_dgrad:
                    handle.wait()
                if not ctx.ub_bulk_wgrad:
                    fc1_dgrad, handle = reduce_scatter_along_first_dim(
                        fc1_dgrad, ctx.tp_group, async_op=True
                    )
2937
2938
2939
2940
2941
2942
2943
2944
2945
2946
            elif ctx.set_parallel_mode and ctx.tensor_parallel:
                fc1_dgrad, handle = allreduce(fc1_dgrad, ctx.tp_group, async_op=True)

            if fc1_weight.requires_grad:
                if ctx.fp8:
                    # FC1 WGRAD
                    if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
                        ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward)
                        fc1_wgrad = fp8_gemm(
                            ln_out_total_t,
2947
2948
                            fwd_scale_inverses,
                            tex.FP8FwdTensors.GEMM1_INPUT,
2949
2950
                            fp8_dtype_forward,
                            dgelu_t,
2951
2952
                            ctx.fp8_meta["scaling_bwd"].scale_inv,
                            tex.FP8BwdTensors.GRAD_OUTPUT2,
2953
2954
2955
2956
2957
2958
2959
2960
                            fp8_dtype_backward,
                            ctx.activation_dtype,
                            get_workspace(),
                            accumulate=accumulate_wgrad_into_param_main_grad,
                            out=fc1_weight.main_grad
                            if ctx.fuse_wgrad_accumulation
                            else None,
                            use_split_accumulator=_2X_ACC_WGRAD,
2961
2962
2963
                            ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS
                            if ctx.ub_bulk_wgrad else None,
                            ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None,
2964
2965
2966
2967
2968
2969
2970
2971
2972
2973
2974
2975
2976
2977
2978
2979
2980
2981
2982
2983
                        )
                    else:
                        ln_out_total_c = cast_from_fp8(
                            ln_out_total,
                            ctx.fp8_meta["scaling_fwd"],
                            tex.FP8FwdTensors.GEMM1_INPUT,
                            fp8_dtype_forward,
                            TE_DType[ctx.activation_dtype],
                        )
                        fc1_wgrad, _, _ = gemm(
                            ln_out_total_c,
                            dgelu_no_fp8,
                            ctx.activation_dtype,
                            get_workspace(),
                            layout="NT",
                            grad=True,
                            accumulate=accumulate_wgrad_into_param_main_grad,
                            out=fc1_weight.main_grad
                            if ctx.fuse_wgrad_accumulation
                            else None,
2984
2985
2986
                            ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS
                            if ctx.ub_bulk_wgrad else None,
                            ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None,
2987
                        )
schetlur-nv's avatar
schetlur-nv committed
2988
                else:
2989
2990
                    # FC1 WGRAD
                    fc1_wgrad_outputs = gemm(
schetlur-nv's avatar
schetlur-nv committed
2991
                        ln_out_total,
2992
                        dgelu,
schetlur-nv's avatar
schetlur-nv committed
2993
2994
2995
2996
                        ctx.activation_dtype,
                        get_workspace(),
                        layout="NT",
                        grad=True,
2997
                        use_bias=not ctx.bias_gelu_nvfusion,
schetlur-nv's avatar
schetlur-nv committed
2998
                        accumulate=accumulate_wgrad_into_param_main_grad,
2999
                        out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
3000
3001
                        ub_algo=tex.UbufOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None,
                        ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None
schetlur-nv's avatar
schetlur-nv committed
3002
                    )
Przemek Tredak's avatar
Przemek Tredak committed
3003

3004
3005
3006
3007
                    if ctx.bias_gelu_nvfusion:
                        fc1_wgrad, _, _ = fc1_wgrad_outputs
                    else:
                        fc1_wgrad, fc1_bias_grad, _ = fc1_wgrad_outputs
Przemek Tredak's avatar
Przemek Tredak committed
3008

3009
            # Column Parallel Linear
3010
3011
3012
            if ctx.ub_bulk_wgrad:
                fc1_dgrad = ub_obj_dgrad.get_ubuf_output(0) # Reduce-scatter output
            elif ctx.set_parallel_mode and ctx.tensor_parallel and handle is not None:
3013
                handle.wait()
Przemek Tredak's avatar
Przemek Tredak committed
3014

3015
3016
            # LayerNorm gradient
            d_ln_out = fc1_dgrad.view(inputmat.shape)
Przemek Tredak's avatar
Przemek Tredak committed
3017

3018
3019
3020
            # Residual gradient
            if ctx.return_layernorm_output:
                d_ln_out = d_ln_out + grad_outputs[1].view_as(d_ln_out)
Przemek Tredak's avatar
Przemek Tredak committed
3021

3022
            dxmat, dgamma, dbeta = tex.layernorm_bwd(
3023
3024
                d_ln_out, inputmat, mu, rsigma, ln_weight,
                ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
3025
            )
Przemek Tredak's avatar
Przemek Tredak committed
3026
3027

        return (
3028
            dxmat.view(ctx.inp_shape) if ctx.requires_dgrad else None,
Przemek Tredak's avatar
Przemek Tredak committed
3029
3030
            dgamma,
            dbeta,
schetlur-nv's avatar
schetlur-nv committed
3031
            fc1_wgrad if fc1_weight.requires_grad else None,
Przemek Tredak's avatar
Przemek Tredak committed
3032
3033
            None,
            None,
ngoyal2707's avatar
ngoyal2707 committed
3034
3035
            fc1_bias_grad if ctx.use_fc1_bias else None,
            None,
schetlur-nv's avatar
schetlur-nv committed
3036
            fc2_wgrad if fc2_weight.requires_grad else None,
Przemek Tredak's avatar
Przemek Tredak committed
3037
3038
            None,
            None,
ngoyal2707's avatar
ngoyal2707 committed
3039
            fc2_bias_grad if ctx.use_fc2_bias else None,
Przemek Tredak's avatar
Przemek Tredak committed
3040
3041
3042
3043
3044
3045
3046
3047
3048
3049
3050
3051
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
3052
            None,
3053
            None,
3054
3055
            None,
            None,
3056
            None,
3057
            None,
Sangkug Lym's avatar
Sangkug Lym committed
3058
            None,
3059
3060
3061
3062
            None,
            None,
            None,
            None,
Przemek Tredak's avatar
Przemek Tredak committed
3063
3064
3065
3066
        )


class LayerNormMLP(TransformerEngineBaseModule):
    r"""
    Applies layer normalization on the input followed by the MLP module, consisting of
    2 successive linear transformations, separated by the GeLU activation.

    Parameters
    ----------
    hidden_size : int
                 size of each input sample.
    ffn_hidden_size : int
                     intermediate size to which input samples are projected.
    eps : float, default = 1e-5
         a value added to the denominator of layer normalization for numerical stability.
    bias : bool, default = `True`
          if set to `False`, the FC1 and FC2 layers will not learn an additive bias.
    init_method : Callable, default = `None`
                 used for initializing FC1 weights in the following way: `init_method(weight)`.
                 When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    output_layer_init_method : Callable, default = `None`
                              used for initializing FC2 weights in the following way:
                              `output_layer_init_method(weight)`. When set to `None`, defaults to
                              `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    get_rng_state_tracker : Callable, default = `None`
                           forwarded to the weight initializer of both FC1 and FC2, together
                           with `init_method` / `output_layer_init_method`.
    return_layernorm_output : bool, default = `False`
                             if set to `True`, output of layernorm is returned from the forward
                             together with the output of the linear transformation.
                             Example use case: residual connection for transformer module
                             is taken post layernorm.
    zero_centered_gamma : bool, default = `False`
                         if set to `True`, gamma parameter in LayerNorm is initialized to 0 and
                         the LayerNorm formula changes to

                         .. math::
                            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
                            (1 + \gamma) + \beta

    Parallelism parameters
    ----------------------
    set_parallel_mode : bool, default = `False`
                      if set to `True`, FC1 is used as Column Parallel and FC2 is used as Row
                      Parallel as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism. Only takes effect when the
                       tensor parallel world size is greater than 1.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    tp_size : int, default = 1
             used as TP (tensor parallel) world size when TP groups are not formed during
             initialization. In this case, users must call the
             `set_tensor_parallel_group(tp_group)` method on the initialized module before the
             forward pass to supply the tensor parallel group needed for tensor and sequence
             parallel collectives.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = `False`
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient.
    return_bias : bool, default = `False`
                 when set to `True`, this module will not apply the additive bias for FC2, but
                 instead return the bias value during the forward pass together with the
                 output of the linear transformation :math:`y = xA^T`. This is useful when
                 the bias addition can be fused to subsequent operations.
    params_dtype : torch.dtype, default = `torch.float32`
                  it controls the type used to allocate the initial parameters. Useful when
                  the model is trained with lower precision and the original FP32 parameters
                  would not fit in GPU memory.
    seq_length : int, default = `None`
                sequence length of input samples. Needed for JIT Warmup, a technique where jit fused
                functions are warmed up before training to ensure same kernels are used for forward
                propagation and activation recompute phase. Warmup only runs when both
                `seq_length` and `micro_batch_size` are given and the bias-GeLU nvfuser path is
                enabled (environment variable `NVTE_BIAS_GELU_NVFUSION`, default `1`).
    micro_batch_size : int, default = `None`
                      batch size per training step. Needed for JIT Warmup, a technique where jit
                      fused functions are warmed up before training to ensure same kernels are
                      used for forward propagation and activation recompute phase.
    ub_bulk_wgrad : bool, default = `False`
    ub_bulk_dgrad : bool, default = `False`
    ub_split_rs : bool, default = `False`
    ub_split_ag : bool, default = `False`
                 userbuffer (UB) communication-overlap options, forwarded unchanged to the
                 fused LayerNorm-MLP autograd function. Enabling any of them requires the
                 userbuffer communication backend to be available.
    """

    def __init__(
        self,
        hidden_size: int,
        ffn_hidden_size: int,
        eps: float = 1e-5,
        sequence_parallel: bool = False,
        return_bias: bool = False,
        get_rng_state_tracker: Optional[Callable] = None,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        init_method: Optional[Callable] = None,
        bias: bool = True,
        output_layer_init_method: Optional[Callable] = None,
        fuse_wgrad_accumulation: bool = False,
        params_dtype: torch.dtype = torch.float32,
        return_layernorm_output: bool = False,
        seq_length: Optional[int] = None,
        micro_batch_size: Optional[int] = None,
        set_parallel_mode: bool = False,
        zero_centered_gamma: bool = False,
        ub_bulk_wgrad: bool = False,
        ub_bulk_dgrad: bool = False,
        ub_split_rs: bool = False,
        ub_split_ag: bool = False,
    ) -> None:
        super().__init__()

        self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
        self.use_bias = bias
        self.return_bias = return_bias
        # The FC2 bias is added by this module only when it is not handed back
        # to the caller via `return_bias`.
        self.apply_bias = bias and not return_bias
        self.return_layernorm_output = return_layernorm_output
        self.bias_gelu_nvfusion = bool(int(os.getenv("NVTE_BIAS_GELU_NVFUSION", "1")))
        self.set_parallel_mode = set_parallel_mode
        self.zero_centered_gamma = zero_centered_gamma
        self.ub_bulk_wgrad = ub_bulk_wgrad
        self.ub_bulk_dgrad = ub_bulk_dgrad
        self.ub_split_rs = ub_split_rs
        self.ub_split_ag = ub_split_ag

        if ub_bulk_wgrad or ub_bulk_dgrad or ub_split_rs or ub_split_ag:
            assert (
                tex.userbuf_comm_available()
            ), "Userbuffer communication backend not available."

        # Without an explicit group, trust the user-supplied `tp_size`; the group
        # must then be provided later via `set_tensor_parallel_group` if tp_size > 1.
        if tp_group is None:
            self.tp_size = tp_size
            if tp_size == 1:
                self.set_tensor_parallel_group(tp_group)
        else:
            self.tp_size = get_distributed_world_size(tp_group)
            self.set_tensor_parallel_group(tp_group)
        self.set_nccl_overlap_warning_if_tp()

        if init_method is None:
            init_method = get_default_init_method()
        if output_layer_init_method is None:
            output_layer_init_method = get_default_init_method()

        # Sequence parallelism is meaningless without tensor parallelism.
        self.sequence_parallel = (self.tp_size > 1) and sequence_parallel
        # FFN hidden dimension owned by this TP rank.
        self.size_per_partition = divide(ffn_hidden_size, self.tp_size)

        # LN init
        self.eps = eps
        self.layer_norm_weight = Parameter(
            torch.empty(
                hidden_size,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            )
        )
        self.layer_norm_bias = Parameter(
            torch.empty(
                hidden_size,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            )
        )
        setattr(self.layer_norm_weight, "sequence_parallel", self.sequence_parallel)
        setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel)
        self.reset_layer_norm_parameters()

        # FC1 init: weight is (size_per_partition, hidden_size), split along dim 0.
        self.fc1_weight = Parameter(
            torch.empty(
                self.size_per_partition,
                hidden_size,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            )
        )
        self.fp8_weight_shapes.append(self.fc1_weight.shape)

        initialize_affine_weight_gpu(
            self.fc1_weight,
            init_method,
            get_rng_state_tracker,
            partition_dim=0,
            stride=1,
        )

        if self.use_bias:
            self.fc1_bias = Parameter(
                torch.empty(
                    self.size_per_partition,
                    device=torch.cuda.current_device(),
                    dtype=params_dtype,
                )
            )
            set_tensor_model_parallel_attributes(self.fc1_bias, True, 0, 1)
        else:
            # Empty, non-persistent placeholder so `self.fc1_bias` is always a tensor.
            self.register_buffer("fc1_bias", torch.Tensor().type(params_dtype), persistent=False)

        with torch.no_grad():
            self.fc1_bias.zero_()

        # FC2 init: weight is (hidden_size, size_per_partition), split along dim 1.
        self.fc2_weight = Parameter(
            torch.empty(
                hidden_size,
                self.size_per_partition,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            )
        )
        self.fp8_weight_shapes.append(self.fc2_weight.shape)

        initialize_affine_weight_gpu(
            self.fc2_weight,
            output_layer_init_method,
            get_rng_state_tracker,
            partition_dim=1,
            stride=1,
        )

        if self.use_bias:
            self.fc2_bias = Parameter(
                torch.empty(
                    hidden_size, device=torch.cuda.current_device(), dtype=params_dtype
                )
            )
        else:
            # Empty, non-persistent placeholder so `self.fc2_bias` is always a tensor.
            self.register_buffer("fc2_bias", torch.Tensor().type(params_dtype), persistent=False)

        # For RPL, bias has to be added after TP collectives
        # So it cannot be fused with the GEMM
        self.gemm_bias_unfused_add = bool(self.set_parallel_mode and self.apply_bias)

        with torch.no_grad():
            self.fc2_bias.zero_()

        if self.bias_gelu_nvfusion:
            set_jit_fusion_options()
            if seq_length and micro_batch_size:
                warmup_jit_bias_gelu_all_dtypes(
                    self.size_per_partition, seq_length, micro_batch_size
                )

        # These many SMs are subtracted from the total SM count when calling forward
        # and backward LayerNorm C APIs. These envvars can be used to prevent the LN
        # kernels from using all SMs in the device. This is useful for cases such as
        # communication overlap with LN.
        self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
        self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))

    def reset_layer_norm_parameters(self) -> None:
        """Init LN params: gamma to 1 (or 0 when zero-centered) and beta to 0."""
        if not self.zero_centered_gamma:
            init.ones_(self.layer_norm_weight)
        else:
            init.zeros_(self.layer_norm_weight)
        init.zeros_(self.layer_norm_bias)

    def forward(
        self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        """
        Apply layer normalization to the input followed by a feedforward network (MLP Block).

        Parameters
        ----------
        inp : torch.Tensor
             Input tensor.
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)

        Returns
        -------
        The MLP output. When `return_bias` is set, the FC2 bias (cast to the activation
        dtype) is appended; when `return_layernorm_output` is set, the layernorm output is
        appended last, i.e. ``out``, ``(out, bias)``, ``(out, ln_out)`` or
        ``(out, bias, ln_out)``.
        """

        with self.prepare_forward(inp, is_first_microbatch, num_gemms=2) as inp:
            if torch.is_grad_enabled():
                fwd_fn = _LayerNormMLP.apply
                args = []
            else:
                # Calling the static forward directly bypasses autograd, so a
                # placeholder must be passed for its leading `ctx` argument.
                fwd_fn = _LayerNormMLP.forward
                args = [None]
            args += (
                inp,
                self.layer_norm_weight,
                self.layer_norm_bias,
                self.fc1_weight,
                self.weight1_fp8 if self.fp8 else None,
                self.weight1_t_fp8 if self.fp8 else None,
                self.fc1_bias,
                self.use_bias,
                self.fc2_weight,
                self.weight2_fp8 if self.fp8 else None,
                self.weight2_t_fp8 if self.fp8 else None,
                self.fc2_bias,
                # FC2 bias is fused into the GEMM only when it is neither returned
                # nor deferred until after the row-parallel collectives.
                self.apply_bias and not self.gemm_bias_unfused_add,
                self.eps,
                is_first_microbatch,
                self.fp8,
                self.fp8_calibration,
                self.fp8_meta,
                self.fuse_wgrad_accumulation,
                self.tp_group,
                self.tp_size,
                self.sequence_parallel,
                self.tp_size > 1,
                self.activation_dtype,
                self.return_layernorm_output,
                self.bias_gelu_nvfusion,
                self.set_parallel_mode,
                torch.is_grad_enabled(),
                self.fwd_ln_sm_margin,
                self.bwd_ln_sm_margin,
                self.zero_centered_gamma,
                self.ub_bulk_wgrad,
                self.ub_bulk_dgrad,
                self.ub_split_rs,
                self.ub_split_ag,
            )
            out = fwd_fn(*args)

        if self.return_layernorm_output:
            out, ln_out = out

        if self.gemm_bias_unfused_add:
            out = out + cast_if_needed(self.fc2_bias, self.activation_dtype)

        if self.return_bias:
            if self.return_layernorm_output:
                return out, cast_if_needed(self.fc2_bias, self.activation_dtype), ln_out
            return out, cast_if_needed(self.fc2_bias, self.activation_dtype)
        if self.return_layernorm_output:
            return out, ln_out
        return out


class _LayerNorm(torch.autograd.Function):
    """functional LayerNorm

    Autograd wrapper around the `tex.layernorm_fwd` / `tex.layernorm_bwd`
    extension kernels. The input is flattened to 2-D over its last dimension
    for the kernels and restored to its original shape on the way out.
    """

    @staticmethod
    def forward(
        ctx,
        inp: torch.Tensor,
        ln_weight: torch.Tensor,
        ln_bias: torch.Tensor,
        eps: float,
        fwd_ln_sm_margin: int,
        bwd_ln_sm_margin: int,
        zero_centered_gamma: bool,
    ) -> torch.Tensor:
        """Normalize `inp` over its last dimension (size `ln_weight.numel()`).

        `fwd_ln_sm_margin` / `bwd_ln_sm_margin` are passed to the forward and
        backward kernels respectively; `zero_centered_gamma` is passed to both.
        Returns a tensor with the same shape as `inp`.
        """
        # Make sure input dimensions are compatible
        in_features = ln_weight.numel()
        assert inp.is_cuda, "TransformerEngine needs CUDA."
        assert inp.shape[-1] == in_features, "LayerNorm not possible"
        # Collapse all leading dims into rows: (N, in_features).
        inputmat = inp.view((-1, in_features))

        # mu / rsigma are presumably per-row mean and reciprocal std saved for backward.
        ln_out, mu, rsigma = tex.layernorm_fwd(
            inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma
        )
        ctx.save_for_backward(inputmat, ln_weight, mu, rsigma)
        ctx.inp_shape = inp.shape
        ctx.bwd_ln_sm_margin = bwd_ln_sm_margin
        ctx.zero_centered_gamma = zero_centered_gamma
        return ln_out.view_as(inp)

    @staticmethod
    def backward(
        ctx, grad_output: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Return gradients for (inp, ln_weight, ln_bias); the four
        non-tensor forward arguments get `None`."""
        inputmat, ln_weight, mu, rsigma = ctx.saved_tensors
        # `view` below needs a contiguous gradient.
        grad_output = grad_output.contiguous()
        d_ln_out = grad_output.view(inputmat.shape)
        dxmat, dgamma, dbeta = tex.layernorm_bwd(
            d_ln_out, inputmat, mu, rsigma, ln_weight,
            ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
        )
        return dxmat.view(ctx.inp_shape), dgamma, dbeta, None, None, None, None


class LayerNorm(torch.nn.Module):
    r"""
    Applies Layer Normalization over a mini-batch of inputs as described in
    the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`__

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma + \beta

    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
    size :attr:`hidden_size`

    Parameters
    ----------
    hidden_size : int
                size of each input sample.
    eps : float, default = 1e-5
        a value added to the denominator of layer normalization for numerical stability.
    sequence_parallel : bool, default = `False`
                        if set to `True`, uses sequence parallelism (the value is recorded as
                        a `sequence_parallel` attribute on the `weight` and `bias` parameters).
    params_dtype : torch.dtype, default = `torch.float32`
                    it controls the type used to allocate the initial parameters. Useful when
                    the model is trained with lower precision and the original FP32 parameters
                    would not fit in GPU memory.
    zero_centered_gamma : bool, default = `False`
                         if set to `True`, gamma parameter in LayerNorm is initialized to 0 and
                         the LayerNorm formula changes to

                         .. math::
                            y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} *
                            (1 + \gamma) + \beta
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-5,
        sequence_parallel: bool = False,
        params_dtype: torch.dtype = torch.float32,
        zero_centered_gamma: bool = False,
    ) -> None:
        super().__init__()
        self.eps = eps
        self.zero_centered_gamma = zero_centered_gamma
        self.weight = Parameter(
            torch.empty(
                hidden_size,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            )
        )
        self.bias = Parameter(
            torch.empty(
                hidden_size,
                device=torch.cuda.current_device(),
                dtype=params_dtype,
            )
        )
        setattr(self.weight, "sequence_parallel", sequence_parallel)
        setattr(self.bias, "sequence_parallel", sequence_parallel)
        self.reset_layer_norm_parameters()

        # These many SMs are subtracted from the total SM count when calling forward
        # and backward LayerNorm C APIs. These envvars can be used to prevent the LN
        # kernels from using all SMs in the device. This is useful for cases such as
        # communication overlap with LN.
        self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0"))
        self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))

    def load_state_dict(
        self,
        state_dict: Mapping[str, Any],
        strict: bool = True,
    ) -> Any:
        """Override PyTorch loader to maintain backward compatibility
        with previous version of LayerNorm parameter names.

        Legacy keys ``layer_norm_weight`` / ``layer_norm_bias`` are mapped to
        ``weight`` / ``bias`` (a legacy key wins if both spellings are present)
        before delegating to :meth:`torch.nn.Module.load_state_dict`. The
        caller's ``state_dict`` is left unmodified; renaming happens on a copy
        that keeps the original ``_metadata`` attribute, if any.

        Returns whatever :meth:`torch.nn.Module.load_state_dict` returns
        (the missing / unexpected keys).
        """
        # Local import: only needed on this (rare) code path.
        from collections import OrderedDict  # pylint: disable=import-outside-toplevel

        legacy_names = {"layer_norm_weight": "weight", "layer_norm_bias": "bias"}
        if any(old_key in state_dict for old_key in legacy_names):
            metadata = getattr(state_dict, "_metadata", None)
            renamed = OrderedDict(state_dict)
            if metadata is not None:
                # Preserve version metadata that torch uses while loading.
                renamed._metadata = metadata  # pylint: disable=protected-access
            for old_key, new_key in legacy_names.items():
                if old_key in renamed:
                    renamed[new_key] = renamed.pop(old_key)
            state_dict = renamed

        return super().load_state_dict(state_dict, strict)

    def reset_layer_norm_parameters(self) -> None:
        """Init LN params: gamma to 1 (or 0 when zero-centered) and beta to 0."""
        if not self.zero_centered_gamma:
            init.ones_(self.weight)
        else:
            init.zeros_(self.weight)
        init.zeros_(self.bias)

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        """LayerNorm FWD: normalize `inp` over its last dimension."""
        # Maintain backward compatibility: older instances may carry the
        # parameters under their legacy attribute names.
        if hasattr(self, "layer_norm_weight"):
            setattr(self, "weight", self.layer_norm_weight)
        if hasattr(self, "layer_norm_bias"):
            setattr(self, "bias", self.layer_norm_bias)

        return _LayerNorm.apply(
            inp,
            self.weight,
            self.bias,
            self.eps,
            self.fwd_ln_sm_margin,
            self.bwd_ln_sm_margin,
            self.zero_centered_gamma,
        )