# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Linear API"""
from functools import reduce
from operator import mul as multiply_op
from typing import Any, Callable, Dict, Optional, Tuple, Union

import torch

import transformer_engine_torch as tex

from .base import (
    get_workspace,
    get_ub,
    TransformerEngineBaseModule,
    _2X_ACC_FPROP,
    _2X_ACC_DGRAD,
    _2X_ACC_WGRAD,
)
from ._common import _noop_cat
from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager
from ..utils import (
    divide,
    cast_if_needed,
    assert_dim_for_fp8_exec,
    clear_tensor_data,
    init_method_constant,
    requires_grad,
)
from ..distributed import (
    set_tensor_model_parallel_attributes,
    get_distributed_world_size,
    allreduce,
    reduce_scatter_along_first_dim,
    gather_along_first_dim,
    in_fp8_activation_recompute_phase,
    _fsdp_scatter_tensors,
    _fsdp_gather_tensors,
)
from ..cpp_extensions import (
    fp8_gemm,
    gemm,
    fp8_cast_transpose_fused,
    cast_to_fp8,
)
from ..constants import GemmParallelModes, dist_group_type, TE_DType
from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ..float8_tensor import Float8Tensor
from ..export import is_in_onnx_export_mode
from ..tensor import QuantizedTensor
from ..cpu_offload import is_cpu_offload_enabled

__all__ = ["Linear"]


class _Linear(torch.autograd.Function):
    """Linear semi-top level module
    Calls custom cuda extensions.
    """

    @staticmethod
    def forward(
        ctx,
        weight: Union[Float8Tensor, torch.Tensor],
        weight_fp8: Optional[Float8Tensor],
        inp: torch.Tensor,
        bias: torch.Tensor,
        use_bias: bool,
        is_first_microbatch: Union[bool, None],
        fp8: bool,
        fp8_calibration: bool,
        fp8_meta: Dict[str, Any],
        fuse_wgrad_accumulation: bool,
        cpu_offloading: bool,
        tp_group: Union[dist_group_type, None],
        tp_size: int,
        sequence_parallel: bool,
        tensor_parallel: bool,
        activation_dtype: torch.dtype,
        parallel_mode: Union[str, None],
        is_grad_enabled: bool,
        ub_overlap_rs_fprop: bool,
        ub_overlap_ag_dgrad: bool,
        ub_overlap_ag_fprop: bool,
        ub_overlap_rs_dgrad: bool,
        ub_bulk_dgrad: bool,
        ub_bulk_wgrad: bool,
        ub_name: str,
        fp8_output: bool,
        fsdp_group: Union[dist_group_type, None],
    ) -> torch.Tensor:
        # pylint: disable=missing-function-docstring
        is_input_fp8 = isinstance(inp, Float8Tensor)

        # Make sure input dimensions are compatible
        out_features, in_features = weight.shape
        inp_shape = inp.shape
        assert inp_shape[-1] == in_features, "GEMM not possible"
        inputmat = inp.view(-1, in_features)
        if fp8:
            assert_dim_for_fp8_exec(inputmat)
            assert_dim_for_fp8_exec(weight)

        tp_world_size = get_distributed_world_size(tp_group)
        ub_overlap_ag_fprop = False if tp_world_size == 1 else ub_overlap_ag_fprop
        ub_overlap_rs_fprop = False if tp_world_size == 1 else ub_overlap_rs_fprop

        # Cast input to expected dtype
        inputmat = cast_if_needed(inputmat, activation_dtype)
        inputmat_t = None
        inputmat_no_fp8 = inputmat
        inputmat_scale_inv = None

        if fp8:
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
            if isinstance(inputmat, Float8Tensor):
                inputmat_scale_inv = inputmat._scale_inv
            else:
                inputmat_scale_inv = torch.empty([1], dtype=torch.float32, device=inputmat.device)
                if (
                    not fp8_meta["recipe"].override_linear_precision.wgrad
                    and is_grad_enabled
                    and weight.requires_grad
                    and not sequence_parallel
                ):
                    # FP8 input for forward, FP8 input transpose for backward wgrad
                    inputmat, inputmat_t = fp8_cast_transpose_fused(
                        inputmat,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM1_INPUT,
                        fp8_dtype_forward,
                        scale_inv=inputmat_scale_inv,
                    )
                else:
                    # FP8 input for forward
                    inputmat = cast_to_fp8(
                        inputmat,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM1_INPUT,
                        fp8_dtype_forward,
                        scale_inv=inputmat_scale_inv,
                    )

            # Hack for ONNX export
            # Note: ONNX models are represented as a graph of tensor
            # operations, so the in-place scale-inv update doesn't fit
            # very well. We work around this by making it look like
            # the scale-inv tensor is initialized with a copy.
            # Note: ONNX export expects FP8 scales can be represented
            # with constant ops. However, copying into a buffer
            # involves an expand op for array broadcasting. We work
            # around this by filling the buffer instead.
            if is_in_onnx_export_mode():
                inputmat_scale_inv.fill_(inputmat_scale_inv.item())

        # Column Parallel Linear
        if parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop:
            inputmat_total, _ = gather_along_first_dim(inputmat, tp_group)
        else:
            inputmat_total = inputmat

        if fp8:
            bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype
            bias = cast_if_needed(bias, bias_dtype) if use_bias else bias

            # Use FP8 weights
            if weight_fp8 is None:
                weight_fp8 = weight

            assert isinstance(weight_fp8, Float8Tensor)

            if fp8_output:
                out_index, meta_tensor, out_tedtype, out_pttype = (
                    tex.FP8FwdTensors.GEMM1_OUTPUT,
                    fp8_meta["scaling_fwd"],
                    fp8_dtype_forward,
                    torch.uint8,
                )
            else:
                out_index, meta_tensor, out_tedtype, out_pttype = (
                    None,
                    None,
                    None,
                    activation_dtype,
                )

            ub_obj = None
            ub_algo = None
            rs_out = None
            inputmat_data = (
                inputmat_total._data if isinstance(inputmat_total, Float8Tensor) else inputmat_total
            )
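            # Userbuffers TP overlap for fprop: reduce-scatter the GEMM output (row-parallel)
            # or all-gather the input (column-parallel with sequence parallelism) while the
            # GEMM runs; otherwise allocate a regular output buffer.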
            if ub_overlap_rs_fprop:
                ub_obj = get_ub(ub_name + "_fprop")
                out = ub_obj.get_ubuf_output(1)
                dim_size = list(inputmat_total.size())
                dim_size[0] = dim_size[0] // tp_world_size
                dim_size[1] = out_features
                rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
                if ub_obj.is_p2p_overlap():
                    if ub_obj.is_atomic_gemm():
                        ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P
                    else:
                        ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P
                else:
                    if ub_obj.is_atomic_gemm():
                        ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS
                    else:
                        ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS
                if ub_obj.is_fp8_ubuf():
                    out_index = tex.FP8FwdTensors.GEMM1_OUTPUT
                    meta_tensor = fp8_meta["scaling_fwd"]
                    out_tedtype = fp8_dtype_forward
                    out_pttype = torch.uint8
                    ub_obj.set_ubuf_scale_inv(meta_tensor.scale_inv[out_index])

            elif ub_overlap_ag_fprop:
                ub_obj = get_ub(ub_name + "_fprop")
                assert ub_obj.is_fp8_ubuf(), "AG overlap with FP8 GEMM requires FP8 buffer."
                ub_obj.copy_input_to_ubuf(inputmat_data, True)
                ub_obj.set_ubuf_scale_inv(inputmat_scale_inv)
                if ub_obj.is_atomic_gemm():
                    ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P
                else:
                    ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P
                out_tedtype = TE_DType[activation_dtype]
                out_pttype = activation_dtype
                dim_size = list(inputmat_total.size())
                dim_size[0] *= tp_size
                dim_size[1] = out_features
                out = torch.empty(dim_size, dtype=out_pttype, device=inputmat_total.device)

            else:
                dim_size = list(inputmat_total.size())
                dim_size[1] = out_features
                out = torch.empty(dim_size, dtype=out_pttype, device=inputmat_total.device)

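            # FP8 fprop GEMM: out = inputmat_total @ weight_fp8^T (+ bias), written into
            # the preallocated output (or Userbuffers) buffer.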
            _ = fp8_gemm(
                weight_fp8._data,
                weight_fp8._scale_inv,
                0,
                weight_fp8._fp8_dtype,
                inputmat_data,
                inputmat_scale_inv,
                0,
                fp8_dtype_forward,
                out_pttype,
                get_workspace(),
                bias=bias,
                use_bias=use_bias,
                use_split_accumulator=_2X_ACC_FPROP,
                out=out,
                ub_algo=ub_algo,
                ub=ub_obj,
                extra_output_tensor=rs_out,
                out_index=out_index,
                fp8_meta_tensor=meta_tensor,
                D_dtype=out_tedtype,
            )
            if fp8_output:
                out = Float8Tensor(
                    data=out,
                    fp8_meta=fp8_meta,
                    fp8_meta_forward=True,
                    fp8_meta_index=tex.FP8FwdTensors.GEMM1_OUTPUT,
                    fp8_dtype=fp8_dtype_forward,
                    dtype=activation_dtype,
                )
        else:
            # Cast for native AMP
            weight = cast_if_needed(weight, activation_dtype)
            bias = cast_if_needed(bias, activation_dtype) if use_bias else bias

            if fp8_calibration:
                # amax of input
                amin, amax = inputmat_total.aminmax()
                fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = torch.max(
                    -amin, amax
                ).float()
                # amax of weight
                amin, amax = weight.aminmax()
                fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = torch.max(
                    -amin, amax
                ).float()

            ub_obj = None
            ub_algo = None
            rs_out = None
            if ub_overlap_rs_fprop:
                ub_obj = get_ub(ub_name + "_fprop")
                out = ub_obj.get_ubuf_output(1)
                dim_size = list(inputmat_total.size())
                dim_size[0] = dim_size[0] // tp_world_size
                dim_size[1] = out_features
                rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
                if ub_obj.is_p2p_overlap():
                    ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P
                else:
                    ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS

            elif ub_overlap_ag_fprop:
                ub_obj = get_ub(ub_name + "_fprop")
                ub_obj.copy_input_to_ubuf(inputmat_total, True)
                dim_size = list(inputmat_total.size())
                dim_size[0] *= tp_size  # all-gathered sequence length
                dim_size[1] = out_features
                out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
                ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P

            else:
                dim_size = list(inputmat_total.size())
                dim_size[1] = out_features
                out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)

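            # High-precision fprop GEMM: out = inputmat_total @ weight^T (+ bias).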
            _ = gemm(
                weight,
                inputmat_total,
                activation_dtype,
                get_workspace(),
                bias=bias,
                use_bias=use_bias,
                out=out,
                ub_algo=ub_algo,
                ub=ub_obj,
                extra_output_tensor=rs_out,
            )

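        # Stash tensors needed by backward: the (possibly FP8 or transposed) input for wgrad,
        # the weights, and the weight's main_grad buffer when CPU offloading is combined with
        # fused wgrad accumulation.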
        if is_grad_enabled:
            saved_inputmat = None
            saved_inputmat_t = None
            if weight.requires_grad:
                if fp8 and not fp8_meta["recipe"].override_linear_precision.wgrad:
                    if inputmat_t is None:
                        saved_inputmat = inputmat
                    else:
                        saved_inputmat_t = inputmat_t
                        if cpu_offloading:
                            saved_inputmat_t.activation_offloading = True
                else:
                    saved_inputmat = inputmat_no_fp8

                if cpu_offloading:
                    if fp8 and weight_fp8 is not None:
                        weight_fp8.weight_offloading = True
                    weight.weight_offloading = True

                    if saved_inputmat is not None:
                        saved_inputmat.activation_offloading = True

            # Scatter intermediate/activation tensors saved for the backward pass
            # NOTE: FSDP sharding is not valid for models initialized with primary Fp8 weights
            ctx.fsdp_group = fsdp_group
            ctx.fsdp_shapes = _fsdp_scatter_tensors(
                fsdp_group,
                saved_inputmat,  # None if fp8 == False
                saved_inputmat_t,  # None if fp8 == False AND not is_grad_enabled
                weight_fp8 if fp8 and not isinstance(weight, Float8Tensor) else None,
            )

            ctx.save_for_backward(
                saved_inputmat,
                saved_inputmat_t,
                inputmat_scale_inv,
                weight,
                weight_fp8,
                weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None,
            )

            ctx.activation_dtype = activation_dtype
            ctx.fp8 = fp8
            ctx.fp8_meta = fp8_meta
            ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
            ctx.cpu_offloading = cpu_offloading
            ctx.is_first_microbatch = is_first_microbatch
            ctx.use_bias = use_bias
            ctx.sequence_parallel = sequence_parallel
            ctx.tensor_parallel = tensor_parallel
            ctx.inp_shape = inp_shape
            ctx.parallel_mode = parallel_mode
            ctx.tp_group = tp_group
            ctx.ub_overlap_ag = ub_overlap_ag_dgrad
            ctx.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad
            ctx.ub_bulk_dgrad = ub_bulk_dgrad
            ctx.ub_bulk_wgrad = ub_bulk_wgrad
            ctx.ub_name = ub_name
            ctx.tp_size = tp_size
            ctx.requires_dgrad = inp.requires_grad
            ctx.is_input_fp8 = is_input_fp8
            ctx.reduce_and_update_bwd_fp8_tensors = False
            if ctx.fp8 and requires_grad(inp, weight, bias):
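                # Only the first FP8 module in the iteration triggers the backward amax/scale
                # reduction; restore the flag during activation recompute so it is not
                # consumed twice.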
                _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE
                ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module()
                if in_fp8_activation_recompute_phase():
                    FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module

        # Row Parallel Linear
        if parallel_mode == "row":
            if ub_overlap_rs_fprop:
                out = rs_out
            elif sequence_parallel:
                out, _ = reduce_scatter_along_first_dim(out, tp_group)
            elif tensor_parallel:
                out, _ = allreduce(out, tp_group)

        # [*, in_features] -> [*, out_features] except first dimension changes for SP
        return out.view(-1, *inp_shape[1:-1], out_features)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
        # pylint: disable=missing-function-docstring
        if isinstance(grad_output, Float8Tensor):
            ctx.fp8_meta["scaling_bwd"].scale_inv[
                tex.FP8BwdTensors.GRAD_OUTPUT1
            ] = grad_output._scale_inv

        with torch.cuda.nvtx.range("_Linear_backward"):
            (
                inputmat,
                inputmat_t,
                inputmat_scale_inv,
                weight,
                weight_fp8,
                main_grad,
            ) = ctx.saved_tensors

            # Gather intermediate/activation tensors if needed
            # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.distributed.FSDP already
            #       shards/unshards the base weights so we don't do it ourselves
            _fsdp_gather_tensors(
                ctx.fsdp_group,
                ctx.fsdp_shapes,
                inputmat,
                inputmat_t,
                weight_fp8 if ctx.fp8 and not isinstance(weight, Float8Tensor) else None,
            )

            if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
                weight = torch.nn.Parameter(weight, weight.requires_grad)
                weight.main_grad = main_grad

            tp_world_size = get_distributed_world_size(ctx.tp_group)
            ctx.ub_overlap_ag = False if tp_world_size == 1 else ctx.ub_overlap_ag
            ctx.ub_overlap_rs_dgrad = False if tp_world_size == 1 else ctx.ub_overlap_rs_dgrad
            ctx.ub_bulk_dgrad = False if tp_world_size == 1 else ctx.ub_bulk_dgrad
            ctx.ub_bulk_wgrad = False if tp_world_size == 1 else ctx.ub_bulk_wgrad

            ctx.ub_obj_gradout = None
            ub_obj_wgrad = None
            ub_algo_wgrad = None
            ub_algo_dgrad = None
            rs_out = None
            dgrad = None
            dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]]
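            # Choose the Userbuffers overlap scheme for dgrad: all-gather of grad_output,
            # reduce-scatter of dgrad, or bulk overlaps with the wgrad GEMM.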
            if ctx.ub_overlap_ag:
                # Overlap grad_output all-gather with dgrad compute
                ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
                if ctx.ub_obj_gradout.is_atomic_gemm():
                    ub_algo_dgrad = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P
                else:
                    ub_algo_dgrad = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P
                dgrad = torch.empty(
                    dgrad_shape, dtype=ctx.activation_dtype, device=grad_output.device
                )

            elif ctx.ub_overlap_rs_dgrad:
                # Overlap dgrad reduce-scatter with dgrad compute
                ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
                dgrad = ctx.ub_obj_gradout.get_ubuf_output(1)
                if ctx.ub_obj_gradout.is_p2p_overlap():
                    if ctx.ub_obj_gradout.is_atomic_gemm():
                        ub_algo_dgrad = tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P
                    else:
                        ub_algo_dgrad = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P
                else:
                    if ctx.ub_obj_gradout.is_atomic_gemm():
                        ub_algo_dgrad = tex.CommOverlapAlgo.ATOMIC_GEMM_RS
                    else:
                        ub_algo_dgrad = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS
                rs_out = torch.empty(
                    dgrad_shape, dtype=ctx.activation_dtype, device=grad_output.device
                )
                ctx.ub_bulk_dgrad = False
                ctx.ub_bulk_wgrad = False

            else:
                if ctx.ub_bulk_dgrad:
                    # Overlap inputmat all-gather with dgrad compute
                    ub_algo_dgrad = tex.CommOverlapAlgo.BULK_OVERLAP_AG
                    ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
                    inputmat_data = (
                        inputmat._data if isinstance(inputmat, Float8Tensor) else inputmat
                    )
                    ctx.ub_obj_gradout.copy_input_to_ubuf(inputmat_data, True)
                    inputmat_ubuf = ctx.ub_obj_gradout.get_ubuf_output(1)
                    if isinstance(inputmat, Float8Tensor):
                        inputmat._data = inputmat_ubuf
                    else:
                        inputmat = inputmat_ubuf

                if ctx.ub_bulk_wgrad:
                    # Overlap dgrad reduce-scatter with wgrad compute
                    ub_algo_wgrad = tex.CommOverlapAlgo.BULK_OVERLAP_RS
                    ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad")
                    dgrad = ub_obj_wgrad.get_ubuf_output(1)

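            # Preprocess grad_output: gather it across TP ranks if needed and, under FP8,
            # produce casted/transposed copies along with the fused bias gradient.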
            (
                grad_output,
                grad_output_c,
                grad_output_t,
                grad_bias,
            ) = TransformerEngineBaseModule.grad_output_preprocess(
                ctx, grad_output, ctx.parallel_mode == "row"
            )

            # Overlap inputmat AG with dgrad via NCCL async comms (no TP overlap via Userbuffers)
            inputmat_total = None
            inputmat_t_total = None
            inputmat_gather_handle = None
            if (
                weight.requires_grad
                and ctx.parallel_mode == "column"
                and ctx.sequence_parallel
                and not ctx.ub_bulk_dgrad
            ):
                inputmat_total, inputmat_gather_handle = gather_along_first_dim(
                    inputmat, ctx.tp_group, async_op=ctx.requires_dgrad
                )
            else:
                inputmat_total = inputmat
                inputmat_t_total = inputmat_t

            if ctx.is_first_microbatch is not None:
                accumulate_wgrad_into_param_main_grad = (
                    ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
                )
            else:
                accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation

            if ctx.fp8:
                fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
                fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)

            output_dtype = ctx.activation_dtype
            if ctx.requires_dgrad:
                if ctx.fp8:
                    if ctx.is_input_fp8 or (
                        ctx.ub_overlap_rs_dgrad and ctx.ub_obj_gradout.is_fp8_ubuf()
                    ):
                        out_index, meta_tensor, output_te_dtype, output_dtype = (
                            tex.FP8BwdTensors.GRAD_INPUT1,
                            ctx.fp8_meta["scaling_bwd"],
                            fp8_dtype_backward,
                            torch.uint8,
                        )
                        if ctx.ub_overlap_rs_dgrad and ctx.ub_obj_gradout.is_fp8_ubuf():
                            ctx.ub_obj_gradout.set_ubuf_scale_inv(meta_tensor.scale_inv[out_index])
                    else:
                        out_index, meta_tensor, output_te_dtype, output_dtype = (
                            None,
                            None,
                            None,
                            ctx.activation_dtype,
                        )

                if dgrad is None:
                    if ctx.parallel_mode == "column" and ctx.sequence_parallel:
                        dgrad_shape[0] = dgrad_shape[0] * tp_world_size
                    dgrad = torch.empty(dgrad_shape, dtype=output_dtype, device=grad_output.device)

            if ctx.requires_dgrad:
                if ctx.fp8:
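                    # dgrad GEMM (FP8): grad_input = grad_output @ weight.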
                    _ = fp8_gemm(
                        weight_fp8.transpose_2d(),
                        weight_fp8._scale_inv,
                        0,
                        weight_fp8._fp8_dtype,
                        grad_output_c,
                        ctx.fp8_meta["scaling_bwd"].scale_inv,
                        tex.FP8BwdTensors.GRAD_OUTPUT1,
                        fp8_dtype_backward,
                        output_dtype,
                        get_workspace(),
                        use_split_accumulator=_2X_ACC_DGRAD,
                        ub_algo=ub_algo_dgrad,
                        ub=ctx.ub_obj_gradout,
                        out=dgrad,
                        out_index=out_index,
                        fp8_meta_tensor=meta_tensor,
                        D_dtype=output_te_dtype,
                        extra_output_tensor=rs_out,
                    )

                    if ctx.ub_overlap_rs_dgrad:
                        dgrad = rs_out
                    elif output_dtype == torch.uint8:
                        dgrad = Float8Tensor(
                            data=dgrad,
                            fp8_meta=ctx.fp8_meta,
                            fp8_meta_forward=False,
                            fp8_meta_index=tex.FP8BwdTensors.GRAD_INPUT1,
                            fp8_dtype=fp8_dtype_backward,
                            dtype=ctx.activation_dtype,
                        )
                else:
                    _ = gemm(
                        weight,
                        grad_output,
                        ctx.activation_dtype,
                        get_workspace(),
                        layout="NN",
                        grad=True,
                        ub_algo=ub_algo_dgrad,
                        ub=ctx.ub_obj_gradout,
                        out=dgrad,
                        extra_output_tensor=rs_out,
                    )

                    if ctx.ub_overlap_rs_dgrad:
                        dgrad = rs_out

            if inputmat_gather_handle is not None:
                inputmat_gather_handle.wait()

            # Overlap dgrad RS/AR with wgrad via NCCL async comms (no TP overlap via Userbuffers)
            dgrad_reduce_handle = None
            if ctx.requires_dgrad and ctx.parallel_mode == "column":
                if ctx.sequence_parallel and not (ctx.ub_overlap_rs_dgrad or ctx.ub_bulk_wgrad):
                    dgrad, dgrad_reduce_handle = reduce_scatter_along_first_dim(
                        dgrad, ctx.tp_group, async_op=True
                    )
                elif ctx.tensor_parallel and not ctx.sequence_parallel:
                    dgrad, dgrad_reduce_handle = allreduce(dgrad, ctx.tp_group, async_op=True)

            wgrad = None
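            # wgrad GEMM: weight gradient = grad_output^T @ input, optionally accumulated
            # directly into weight.main_grad.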
            if weight.requires_grad:
                if ctx.fp8:
                    # WGRAD
                    if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
                        if ctx.ub_overlap_ag:
                            if isinstance(grad_output_c, Float8Tensor):
                                grad_output_t = grad_output_c.transpose_2d()
                            else:
                                grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
                        if inputmat_t_total is None:
                            if isinstance(inputmat_total, Float8Tensor):
                                inputmat_t_total = inputmat_total.transpose_2d()
                            else:
                                inputmat_t_total = tex.fp8_transpose(
                                    inputmat_total, fp8_dtype_backward
                                )
                        wgrad, _ = fp8_gemm(
                            (
                                inputmat_t_total._data
                                if isinstance(inputmat_t_total, Float8Tensor)
                                else inputmat_t_total
                            ),
                            inputmat_scale_inv,
                            0,
                            fp8_dtype_forward,
                            grad_output_t,
                            ctx.fp8_meta["scaling_bwd"].scale_inv,
                            tex.FP8BwdTensors.GRAD_OUTPUT1,
                            fp8_dtype_backward,
                            ctx.activation_dtype,
                            get_workspace(),
                            accumulate=accumulate_wgrad_into_param_main_grad,
                            out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                            use_split_accumulator=_2X_ACC_WGRAD,
                            ub=ub_obj_wgrad,
                            ub_algo=ub_algo_wgrad,
                        )
                    else:
                        wgrad, _, _ = gemm(
                            inputmat_total,
                            grad_output,
                            ctx.activation_dtype,
                            get_workspace(),
                            layout="NT",
                            grad=True,
                            accumulate=accumulate_wgrad_into_param_main_grad,
                            out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                            ub=ub_obj_wgrad,
                            ub_algo=ub_algo_wgrad,
                        )
                else:
                    # WGRAD
                    wgrad, grad_bias, _ = gemm(
                        inputmat_total,
                        grad_output,
                        ctx.activation_dtype,
                        get_workspace(),
                        layout="NT",
                        grad=True,
                        use_bias=ctx.use_bias,
                        accumulate=accumulate_wgrad_into_param_main_grad,
                        out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                        ub=ub_obj_wgrad,
                        ub_algo=ub_algo_wgrad,
                    )

                if ctx.ub_bulk_wgrad:
                    dgrad = ub_obj_wgrad.get_ubuf_output(0)

                # Deallocate input tensor
                clear_tensor_data(inputmat_total)
                clear_tensor_data(inputmat_t_total)
            # Wait for dgrad reduce-scatter or all-reduce
            if dgrad_reduce_handle is not None:
                dgrad_reduce_handle.wait()
            if not ctx.use_bias:
                grad_bias = None

        if weight.requires_grad:
            # Handle custom DDP from mcore.
            if ctx.fuse_wgrad_accumulation and hasattr(weight, "grad_added_to_main_grad"):
                weight.grad_added_to_main_grad = True
                if getattr(weight, "zero_out_wgrad", False):
                    wgrad = torch.zeros(
                        weight.main_grad.shape,
                        dtype=weight.dtype,
                        device=torch.cuda.current_device(),
                        requires_grad=False,
                    )
                else:
                    wgrad = torch.empty(
                        weight.main_grad.shape,
                        dtype=weight.dtype,
                        device=torch.cuda.current_device(),
                        requires_grad=False,
                    )
            elif ctx.fuse_wgrad_accumulation:
                wgrad = None
        else:
            wgrad = None
        if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
            FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)

        # Scatter fp8 weight buffers
        if ctx.fp8 and not isinstance(weight, Float8Tensor):
            _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8)

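        # Gradients are returned in the same order as the arguments of forward();
        # non-tensor arguments receive None.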
        return (
            wgrad,
            None,  # weight_fp8
            dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
            grad_bias,
            None,  # use_bias
            None,  # is_first_microbatch
            None,  # fp8
            None,  # fp8_calibration
            None,  # fp8_meta
            None,  # fuse_wgrad_accumulation
            None,  # cpu_offloading
            None,  # tp_group
            None,  # tp_size
            None,  # sequence_parallel
            None,  # tensor_parallel
            None,  # activation_dtype
            None,  # parallel_mode
            None,  # is_grad_enabled
            None,  # ub_overlap_rs_fprop
            None,  # ub_overlap_ag_dgrad
            None,  # ub_overlap_ag_fprop
            None,  # ub_overlap_rs_dgrad
            None,  # ub_bulk_dgrad
            None,  # ub_bulk_wgrad
            None,  # ub_name
            None,  # fp8_output
            None,  # fsdp_group
        )


class Linear(TransformerEngineBaseModule):
    """Applies a linear transformation to the incoming data :math:`y = xA^T + b`

    On NVIDIA GPUs it is a drop-in replacement for `torch.nn.Linear`.

    Parameters
    ----------
    in_features : int
                 size of each input sample.
    out_features : int
                  size of each output sample.
    bias : bool, default = `True`
          if set to `False`, the layer will not learn an additive bias.
    init_method : Callable, default = `None`
                 used for initializing weights in the following way: `init_method(weight)`.
                 When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    get_rng_state_tracker : Callable, default = `None`
                 used to get the random number generator state tracker for initializing weights.
    rng_tracker_name : str, default = `None`
                 the param passed to get_rng_state_tracker to get the specific rng tracker.
    parameters_split : Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None
                      Configuration for splitting the weight and bias tensors along dim 0 into
                      multiple PyTorch parameters. If a list or tuple of strings is provided,
                      they are used to make the names of equally-sized parameters. If a dict
                      (preferably an OrderedDict) is provided, the keys are used as names and
                      values as split sizes along dim 0. The resulting parameters will have
                      names that end in `_weight` or `_bias`, so trailing underscores are
                      stripped from any provided names.
    device : Union[torch.device, str], default = "cuda"
          The device on which the parameters of the model will be allocated. It is the user's
          responsibility to ensure all parameters are moved to the GPU before running the
          forward pass.

    Parallelism parameters
    ----------------------
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    tp_size : int, default = 1
             used as TP (tensor parallel) world size when TP groups are not formed during
             initialization. In this case, users must call the
             `set_tensor_parallel_group(tp_group)` method on the initialized module before the
             forward pass to supply the tensor parallel group needed for tensor and sequence
             parallel collectives.
    parallel_mode : {None, 'column', 'row'}, default = `None`
                   used to decide whether this Linear layer is Column Parallel Linear or Row
                   Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
                   When set to `None`, no communication is performed.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = 'False'
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient. When enabled, it is assumed that the weights
                             have an additional `main_grad` attribute (used instead of the
                             regular `grad`) which is a pre-allocated buffer of the correct
                             size to accumulate gradients in.
    return_bias : bool, default = `False`
                 when set to `True`, this module will not apply the additive bias itself, but
                 instead return the bias value during the forward pass together with the
                 output of the linear transformation :math:`y = xA^T`. This is useful when
                 the bias addition can be fused to subsequent operations.
    params_dtype : torch.dtype, default = `torch.get_default_dtype()`
                  it controls the type used to allocate the initial parameters. Useful when
                  the model is trained with lower precision and the original FP32 parameters
                  would not fit in GPU memory.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        sequence_parallel: bool = False,
        fuse_wgrad_accumulation: bool = False,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        get_rng_state_tracker: Optional[Callable] = None,
        rng_tracker_name: Optional[str] = None,
        init_method: Optional[Callable] = None,
        bias: bool = True,
        return_bias: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        parallel_mode: Optional[str] = None,
        parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None,
        device: Union[torch.device, str] = "cuda",
        ub_overlap_ag: bool = False,
        ub_overlap_rs: bool = False,
        ub_bulk_dgrad: bool = False,
        ub_bulk_wgrad: bool = False,
        ub_name: Optional[str] = None,
    ) -> None:
        super().__init__()

        params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
        self.in_features = in_features
        self.out_features = out_features
        self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
        self.use_bias = bias
        self.return_bias = return_bias
        self.apply_bias = bias and not return_bias
881

        if device == "meta":
            assert parameters_split is None, "Cannot split module parameters on 'meta' device."
        if tp_group is None:
            self.tp_size = tp_size
            if tp_size == 1:
                self.set_tensor_parallel_group(tp_group)
        else:
            self.tp_size = get_distributed_world_size(tp_group)
            self.set_tensor_parallel_group(tp_group)
        self.set_nccl_overlap_warning_if_tp()

        self.parallel_mode = parallel_mode
        assert (
            self.parallel_mode in GemmParallelModes
        ), f"parallel_mode {parallel_mode} not supported"

        if self.parallel_mode == "column":
            self.out_features = divide(self.out_features, self.tp_size)
        elif self.parallel_mode == "row":
            self.in_features = divide(self.in_features, self.tp_size)

        self.sequence_parallel = (self.tp_size > 1) and sequence_parallel

        # Column parallel TP overlap options
        self.ub_overlap_ag_fprop = parallel_mode == "column" and sequence_parallel and ub_overlap_ag
        self.ub_overlap_rs_dgrad = parallel_mode == "column" and sequence_parallel and ub_overlap_rs
        self.ub_bulk_dgrad = parallel_mode == "column" and sequence_parallel and ub_bulk_dgrad
        self.ub_bulk_wgrad = parallel_mode == "column" and sequence_parallel and ub_bulk_wgrad
        if self.ub_overlap_rs_dgrad:
            self.ub_bulk_dgrad = False
            self.ub_bulk_wgrad = False

        # Row parallel TP overlap options
        self.ub_overlap_rs_fprop = parallel_mode == "row" and sequence_parallel and ub_overlap_rs
        self.ub_overlap_ag_dgrad = parallel_mode == "row" and sequence_parallel and ub_overlap_ag

        if any(
            [
                self.ub_overlap_rs_fprop,
                self.ub_overlap_ag_dgrad,
                self.ub_overlap_ag_fprop,
                self.ub_overlap_rs_dgrad,
                self.ub_bulk_dgrad,
                self.ub_bulk_wgrad,
            ]
        ):
            assert ub_name is not None, f"Comm+GEMM overlap layer '{ub_name}' is not initialized."
        self.ub_name = ub_name

        assert not (
            self.ub_overlap_rs_fprop and self.ub_overlap_ag_fprop
        ), "Cannot enable AG+GEMM and GEMM+RS overlaps at the same time."
        assert not (
            self.ub_overlap_rs_dgrad and self.ub_bulk_dgrad
        ), "Cannot enable DGRAD+RS and bulk DGRAD overlaps at the same time."
        assert not (
            self.ub_overlap_ag_dgrad and (self.ub_overlap_rs_dgrad or self.ub_bulk_dgrad)
        ), "Cannot enable AG+DGRAD and DGRAD+RS or bulk DGRAD overlaps at the same time."

        self.get_rng_state_tracker = get_rng_state_tracker
        self.rng_tracker_name = rng_tracker_name

        # Initialize params in FP8
        with_fp8_params = FP8GlobalStateManager.with_fp8_parameters()

        # Contiguous buffers for params
        weight_tensor = torch.empty(
            self.out_features,
            self.in_features,
            device=device,
            dtype=params_dtype,
        )
        bias_tensor = None
        if self.use_bias:
            bias_tensor = torch.empty(
                self.out_features,
                device=device,
                dtype=params_dtype,
            )
        # Configure parameter splits
        self.weight_names = []
        self.bias_names = []
        self.parameter_split_sizes = []
        if parameters_split is None:
            # Split into a single parameter by default
            self.weight_names = ["weight"]
            self.bias_names = ["bias"]
            self.parameter_split_sizes = [out_features]
        elif not parameters_split:
            raise ValueError("Cannot split weight buffer into 0 parameters")
        elif isinstance(parameters_split, dict):
            # Split parameters with provided sizes
            for name, split_size in parameters_split.items():
                self.weight_names.append(f"{name.rstrip('_')}_weight")
                self.bias_names.append(f"{name.rstrip('_')}_bias")
                self.parameter_split_sizes.append(split_size)
        elif all(isinstance(name, str) for name in parameters_split):
            # Split parameters evenly
            split_size = out_features // len(parameters_split)
            for name in parameters_split:
                self.weight_names.append(f"{name.rstrip('_')}_weight")
                self.bias_names.append(f"{name.rstrip('_')}_bias")
                self.parameter_split_sizes.append(split_size)
        else:
            raise TypeError("Invalid configuration for parameters split")

        # Make sure parameter splits are valid
        if sum(self.parameter_split_sizes) != out_features:
            raise ValueError(
                f"Trying to split weight buffer ({out_features=}) "
                f"with split sizes {self.parameter_split_sizes}"
            )
        # Adjust parameter splits for tensor-parallel distribution
        if self.parallel_mode == "column":
            for i, size in enumerate(self.parameter_split_sizes):
                if size % self.tp_size != 0:
                    raise RuntimeError(
                        f"Attempting to distribute a parameter with out_features={size} "
                        f"between {self.tp_size} tensor-parallel processes"
                    )
                self.parameter_split_sizes[i] = size // self.tp_size

        # Construct weight parameters
        # Note: Register weights together so that they are adjacent to
        # each other in Linear.parameters(). This makes it more likely
        # that they will stay contiguous if the weights are
        # manipulated externally, e.g. by FSDP.
        offset = 0
        for i, split_size in enumerate(self.parameter_split_sizes):
            split_start = offset
            offset += split_size
            split_end = offset

            # Check if parameters are subviews of buffers
            is_subview = (split_start, split_end) != (0, self.out_features)
            if is_subview and with_fp8_params:
                raise RuntimeError("Splitting Float8Tensor into multiple params is not supported")

            # Construct weight parameter
            self.register_parameter(
                self.weight_names[i],
                torch.nn.Parameter(weight_tensor[split_start:split_end]),
                init_fn=init_method,
                get_rng_state_tracker=get_rng_state_tracker,
                fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
            )
        # Construct bias parameters if needed
        if self.use_bias:
            offset = 0
            for i, split_size in enumerate(self.parameter_split_sizes):
                split_start = offset
                offset += split_size
                split_end = offset
                self.register_parameter(
                    self.bias_names[i],
                    torch.nn.Parameter(bias_tensor[split_start:split_end]),
                    init_fn=init_method_constant(0.0),
                )
        else:
            for name in self.bias_names:
                bias = torch.Tensor().to(dtype=params_dtype, device=device)
                setattr(self, name, bias)
        if with_fp8_params:
            self.init_fp8_metadata()

        self.reset_parameters(defer_init=device == "meta")

        # For RPL, bias has to be added after TP collectives
        # So it cannot be fused with the GEMM
        if self.parallel_mode == "row" and self.apply_bias:
            self.gemm_bias_unfused_add = True
        else:
            self.gemm_bias_unfused_add = False

    def reset_parameters(self, defer_init=False):
        super().reset_parameters(defer_init=defer_init)

        if not defer_init:
            # Set parallelism attributes for linear weights
            for weight in self.weight_names:
                set_tensor_model_parallel_attributes(
                    tensor=getattr(self, weight),
                    is_parallel=True,
                    dim=1 if self.parallel_mode == "row" else 0,
                    stride=1,
                )

            # Set parallelism attributes for linear biases
            if self.use_bias:
                for bias in self.bias_names:
                    if self.parallel_mode == "row":
                        setattr(getattr(self, bias), "sequence_parallel", self.sequence_parallel)
                    elif self.parallel_mode == "column":
                        set_tensor_model_parallel_attributes(getattr(self, bias), True, 0, 1)

    @no_torch_dynamo()
    def forward(
        self,
        inp: torch.Tensor,
        is_first_microbatch: Optional[bool] = None,
        fp8_output: Optional[bool] = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        """
        Apply the linear transformation to the input.

        Parameters
        ----------
        inp : torch.Tensor
             Input tensor.
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)
        """
        if FP8GlobalStateManager.fp8_graph_capturing():
            skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
        else:
            skip_fp8_weight_update = None
        if skip_fp8_weight_update is not None:
            is_first_microbatch = False

        with self.prepare_forward(
            inp,
            is_first_microbatch,
            allow_non_contiguous=isinstance(inp, QuantizedTensor),
        ) as inp:

            # Get concatenated weight and bias tensors
            unfused_weights = [getattr(self, name) for name in self.weight_names]
            if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
                if self.fp8:
                    if len(unfused_weights) != 1:
                        raise RuntimeError(
                            "Splitting QuantizedTensor into multiple params is not supported"
                        )
                else:
                    unfused_weights = [w.dequantize() for w in unfused_weights]
            weight_tensor = _noop_cat(unfused_weights)
            if self.use_bias:
                bias_tensor = _noop_cat([getattr(self, name) for name in self.bias_names])
            else:
                bias_tensor = getattr(self, self.bias_names[0])  # Unused

            # Initialize FP8 weights if needed
            weight_fp8 = None
            if self.fp8:
                if isinstance(weight_tensor, Float8Tensor):
                    # Make sure transpose cache is valid, if present
                    # Note: Transpose cache may have been invalidated
                    # externally, e.g. by optimizer.
                    if weight_tensor._transpose is not None:
                        weight_tensor.transpose_2d(
                            fill_cache=True,
                            noop_flag=skip_fp8_weight_update,
                        )
                else:
                    # FP8 cast to workspace buffer
                    update_workspace = is_first_microbatch is None or is_first_microbatch
                    weight_fp8 = self.get_fp8_workspace(
                        tensor=weight_tensor,
                        fp8_meta_forward=True,
                        fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
                        cache_name=(None if is_first_microbatch is None else "weight"),
                        update_workspace=update_workspace,
                        skip_update_flag=skip_fp8_weight_update,
                        fsdp_group=self.fsdp_group,
                    )

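            # When autograd is enabled, go through _Linear.apply so backward is recorded;
            # otherwise call _Linear.forward directly with ctx=None.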
            if torch.is_grad_enabled():
                linear_fn = _Linear.apply
                args = []
            else:
                linear_fn = _Linear.forward
                args = [None]
            args += (
                weight_tensor,
                weight_fp8,
                inp,
                bias_tensor,
                self.apply_bias and not self.gemm_bias_unfused_add,
                is_first_microbatch,
                self.fp8,
                self.fp8_calibration,
                self.fp8_meta,
                self.fuse_wgrad_accumulation,
                is_cpu_offload_enabled(),
                self.tp_group,
                self.tp_size,
                self.sequence_parallel,
                self.tp_size > 1,
                self.activation_dtype,
                self.parallel_mode,
                torch.is_grad_enabled(),
                self.ub_overlap_rs_fprop,
                self.ub_overlap_ag_dgrad,
                self.ub_overlap_ag_fprop,
                self.ub_overlap_rs_dgrad,
                self.ub_bulk_dgrad,
                self.ub_bulk_wgrad,
                self.ub_name,
                fp8_output,
                self.fsdp_group,
            )
            out = linear_fn(*args)

        if self.gemm_bias_unfused_add:
            out = out + cast_if_needed(bias_tensor, self.activation_dtype)

        if self.return_bias:
            return out, cast_if_needed(bias_tensor, self.activation_dtype)
        return out