# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Linear API"""
6
from typing import Any, Callable, Dict, Optional, Tuple, Union
7
8
9

import torch

10
import transformer_engine_torch as tex
11
12
13
14
15
16
17
18
19

from .base import (
    get_workspace,
    get_ub,
    TransformerEngineBaseModule,
    _2X_ACC_FPROP,
    _2X_ACC_DGRAD,
    _2X_ACC_WGRAD,
)
20
from ._common import _noop_cat
21
from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager
22
23
24
from ..utils import (
    divide,
    cast_if_needed,
25
    assert_dim_for_fp8_exec,
26
    clear_tensor_data,
27
    init_method_constant,
28
    requires_grad,
29
30
31
32
33
34
35
)
from ..distributed import (
    set_tensor_model_parallel_attributes,
    get_distributed_world_size,
    allreduce,
    reduce_scatter_along_first_dim,
    gather_along_first_dim,
36
37
    _fsdp_scatter_tensors,
    _fsdp_gather_tensors,
38
39
40
41
42
43
44
45
)
from ..cpp_extensions import (
    fp8_gemm,
    gemm,
    fp8_cast_transpose_fused,
    cast_to_fp8,
)
from ..constants import GemmParallelModes, dist_group_type
46
from ..jit import no_torch_dynamo
47
from ..graph import is_graph_capturing
48
from ..float8_tensor import Float8Tensor
49
from ..export import is_in_onnx_export_mode
50
from ..tensor import QuantizedTensor
51
from ..cpu_offload import is_cpu_offload_enabled

# Public API of this module: only the ``Linear`` module is exported
# (``_Linear`` below is an internal autograd function).
__all__ = ["Linear"]


class _Linear(torch.autograd.Function):
    """Autograd function implementing the linear layer ``y = x W^T + b``.

    Semi-top-level building block used by the ``Linear`` module. It calls the
    custom CUDA extensions (``gemm`` / ``fp8_gemm`` and the FP8 cast kernels),
    performs the tensor-parallel communication required by ``parallel_mode``,
    and can overlap that communication with the GEMMs through userbuffers
    (``ub_overlap_rs`` in forward, ``ub_overlap_ag`` in backward).
    """

    @staticmethod
    def forward(
        ctx,
        weight: Union[Float8Tensor, torch.Tensor],
        weight_fp8: Optional[Float8Tensor],
        inp: torch.Tensor,
        bias: torch.Tensor,
        use_bias: bool,
        is_first_microbatch: Union[bool, None],
        fp8: bool,
        fp8_calibration: bool,
        fp8_meta: Dict[str, Any],
        fuse_wgrad_accumulation: bool,
        cpu_offloading: bool,
        tp_group: Union[dist_group_type, None],
        tp_size: int,
        sequence_parallel: bool,
        tensor_parallel: bool,
        activation_dtype: torch.dtype,
        parallel_mode: Union[str, None],
        is_grad_enabled: bool,
        ub_overlap_rs: bool,
        ub_overlap_ag: bool,
        ub_name: str,
        fp8_output: bool,
        fsdp_group: Union[dist_group_type, None],
    ) -> torch.Tensor:
        """Run the forward GEMM plus optional FP8 casts and TP communication.

        Args (selected):
            weight: primary weight of shape ``(out_features, in_features)``.
            weight_fp8: FP8 copy of the weight. When ``None`` and ``fp8`` is set,
                ``weight`` itself must already be a ``Float8Tensor``.
            inp: input whose last dimension equals ``in_features``; it is
                flattened to ``(-1, in_features)`` for the GEMM.
            fp8_calibration: in the non-FP8 path, record input/weight amax into
                ``fp8_meta["scaling_fwd"].amax_history``.
            parallel_mode: ``"column"`` all-gathers the input along dim 0 under
                sequence parallelism; ``"row"`` reduce-scatters (sequence
                parallel) or all-reduces (tensor parallel) the output.
            fp8_output: FP8 path only; wrap the output in a ``Float8Tensor``
                instead of returning ``activation_dtype`` data.

        Returns:
            Output of shape ``(-1, *inp.shape[1:-1], out_features)``; the first
            dimension differs from the input's under sequence parallelism.
        """
        is_input_fp8 = isinstance(inp, Float8Tensor)

        # Make sure input dimensions are compatible
        out_features, in_features = weight.shape
        inp_shape = inp.shape
        assert inp_shape[-1] == in_features, "GEMM not possible"
        inputmat = inp.view(-1, in_features)
        if fp8:
            assert_dim_for_fp8_exec(inputmat)
            assert_dim_for_fp8_exec(weight)

        tp_world_size = get_distributed_world_size(tp_group)
        # There is nothing to reduce-scatter with a single tensor-parallel rank.
        ub_overlap_rs = False if tp_world_size == 1 else ub_overlap_rs

        # Cast input to expected dtype
        inputmat = cast_if_needed(inputmat, activation_dtype)
        inputmat_t = None
        # Higher-precision input, saved for wgrad when FP8 wgrad is not used.
        inputmat_no_fp8 = inputmat
        inputmat_scale_inv = None

        if fp8:
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
            if isinstance(inputmat, Float8Tensor):
                # Input is already FP8: reuse its scale-inverse, no cast needed.
                inputmat_scale_inv = inputmat._scale_inv
            else:
                inputmat_scale_inv = torch.empty([1], dtype=torch.float32, device=inputmat.device)
                # Under sequence parallelism the transpose is not precomputed
                # here; backward derives it from the (possibly gathered) input.
                if (
                    not fp8_meta["recipe"].override_linear_precision.wgrad
                    and is_grad_enabled
                    and weight.requires_grad
                    and not sequence_parallel
                ):
                    # FP8 input for forward, FP8 input transpose for backward wgrad
                    inputmat, inputmat_t = fp8_cast_transpose_fused(
                        inputmat,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM1_INPUT,
                        fp8_dtype_forward,
                        scale_inv=inputmat_scale_inv,
                    )
                else:
                    # FP8 input for forward
                    inputmat = cast_to_fp8(
                        inputmat,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM1_INPUT,
                        fp8_dtype_forward,
                        scale_inv=inputmat_scale_inv,
                    )

            # Hack for ONNX export
            # Note: ONNX models are represented as a graph of tensor
            # operations, so the in-place scale-inv update doesn't fit
            # very well. We work around this by making it look like
            # the scale-inv tensor is initialized with a copy.
            # Note: ONNX export expects FP8 scales can be represented
            # with constant ops. However, copying into a buffer
            # involves an expand op for array broadcasting. We work
            # around this by filling the buffer instead.
            if is_in_onnx_export_mode():
                inputmat_scale_inv.fill_(inputmat_scale_inv.item())

        # Column Parallel Linear
        if parallel_mode == "column" and sequence_parallel:
            inputmat_total, _ = gather_along_first_dim(inputmat, tp_group)
        else:
            inputmat_total = inputmat

        # Userbuffers reduce-scatter state; only populated when ub_overlap_rs.
        ub_obj_projout = None
        ub_algo = None
        rs_out = None

        if fp8:
            # NOTE(review): presumably the FP8 GEMM does not accept an FP32 bias,
            # hence BF16 when activations are FP32 -- confirm.
            bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype
            bias = cast_if_needed(bias, bias_dtype) if use_bias else bias

            # Use FP8 weights
            if weight_fp8 is None:
                weight_fp8 = weight

            assert isinstance(weight_fp8, Float8Tensor)

            if fp8_output:
                proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = (
                    tex.FP8FwdTensors.GEMM1_OUTPUT,
                    fp8_meta["scaling_fwd"],
                    fp8_dtype_forward,
                    torch.uint8,
                )
            else:
                proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = (
                    None,
                    None,
                    None,
                    activation_dtype,
                )

            if ub_overlap_rs:
                ub_obj_projout = get_ub(ub_name + "_fprop")
                out = ub_obj_projout.get_ubuf_output(1)
                # rs_out receives this rank's 1/tp_world_size slice of the rows.
                dim_size = list(inputmat_total.size())
                dim_size[0] = dim_size[0] // tp_world_size
                dim_size[1] = out_features
                rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
                if ub_obj_projout.is_p2p_overlap():
                    if ub_obj_projout.is_atomic_gemm():
                        ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P
                    else:
                        ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P
                else:
                    if ub_obj_projout.is_atomic_gemm():
                        ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS
                    else:
                        ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS
                if ub_obj_projout.is_fp8_ubuf():
                    # The userbuffer holds FP8 data, so the GEMM must emit FP8.
                    proj_out_index = tex.FP8FwdTensors.GEMM1_OUTPUT
                    meta_tensor = fp8_meta["scaling_fwd"]
                    proj_out_tetype = fp8_dtype_forward
                    proj_out_pttype = torch.uint8
                    ub_obj_projout.set_ubuf_scale_inv(meta_tensor.scale_inv[proj_out_index])
            else:
                dim_size = list(inputmat_total.size())
                dim_size[1] = out_features
                out = torch.empty(dim_size, dtype=proj_out_pttype, device=inputmat_total.device)

            _ = fp8_gemm(
                weight_fp8._data,
                weight_fp8._scale_inv,
                0,
                weight_fp8._fp8_dtype,
                (
                    inputmat_total._data
                    if isinstance(inputmat_total, Float8Tensor)
                    else inputmat_total
                ),
                inputmat_scale_inv,
                0,
                fp8_dtype_forward,
                proj_out_pttype,
                get_workspace(),
                bias=bias,
                use_bias=use_bias,
                use_split_accumulator=_2X_ACC_FPROP,
                out=out,
                # All three are None unless ub_overlap_rs is active.
                ub_algo=ub_algo,
                ub=ub_obj_projout,
                extra_output_tensor=rs_out,
                out_index=proj_out_index,
                fp8_meta_tensor=meta_tensor,
                D_dtype=proj_out_tetype,
            )
            if fp8_output:
                out = Float8Tensor(
                    data=out,
                    fp8_meta=fp8_meta,
                    fp8_meta_forward=True,
                    fp8_meta_index=tex.FP8FwdTensors.GEMM1_OUTPUT,
                    fp8_dtype=fp8_dtype_forward,
                    dtype=activation_dtype,
                )
        else:
            # Cast for native AMP
            weight = cast_if_needed(weight, activation_dtype)
            bias = cast_if_needed(bias, activation_dtype) if use_bias else bias

            if fp8_calibration:
                # amax of input
                amin, amax = inputmat_total.aminmax()
                fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = torch.max(
                    -amin, amax
                ).float()
                # amax of weight
                amin, amax = weight.aminmax()
                fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = torch.max(
                    -amin, amax
                ).float()

            if ub_overlap_rs:
                ub_obj_projout = get_ub(ub_name + "_fprop")
                out = ub_obj_projout.get_ubuf_output(1)
                dim_size = list(inputmat_total.size())
                dim_size[0] = dim_size[0] // tp_world_size
                dim_size[1] = out_features
                rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
                if ub_obj_projout.is_p2p_overlap():
                    ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P
                else:
                    ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS
            else:
                dim_size = list(inputmat_total.size())
                dim_size[1] = out_features
                out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)

            _ = gemm(
                weight,
                inputmat_total,
                activation_dtype,
                get_workspace(),
                bias=bias,
                use_bias=use_bias,
                out=out,
                # All three are None unless ub_overlap_rs is active.
                ub_algo=ub_algo,
                ub=ub_obj_projout,
                extra_output_tensor=rs_out,
            )

        if is_grad_enabled:
            # Pick what wgrad will need: either the FP8 input, its cached FP8
            # transpose, or the higher-precision input.
            saved_inputmat = None
            saved_inputmat_t = None
            if weight.requires_grad:
                if fp8 and not fp8_meta["recipe"].override_linear_precision.wgrad:
                    if inputmat_t is None:
                        saved_inputmat = inputmat
                    else:
                        saved_inputmat_t = inputmat_t
                        if cpu_offloading:
                            saved_inputmat_t.activation_offloading = True
                else:
                    saved_inputmat = inputmat_no_fp8

                if cpu_offloading:
                    if fp8 and weight_fp8 is not None:
                        weight_fp8.weight_offloading = True
                    weight.weight_offloading = True

                    if saved_inputmat is not None:
                        saved_inputmat.activation_offloading = True

            # Scatter intermediate/activation tensors saved for the backward pass
            # NOTE: FSDP sharding is not valid for models initialized with primary Fp8 weights
            ctx.fsdp_group = fsdp_group
            ctx.fsdp_shapes = _fsdp_scatter_tensors(
                fsdp_group,
                saved_inputmat,  # None if weight needs no grad or the FP8 transpose is cached
                saved_inputmat_t,  # None unless the fused FP8 cast-transpose ran
                weight_fp8 if fp8 and not isinstance(weight, Float8Tensor) else None,
            )

            ctx.save_for_backward(
                saved_inputmat,
                saved_inputmat_t,
                inputmat_scale_inv,
                weight,
                weight_fp8,
                # main_grad is saved explicitly because backward re-wraps the
                # weight in a new Parameter when CPU offloading is enabled.
                weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None,
            )

            ctx.activation_dtype = activation_dtype
            ctx.fp8 = fp8
            ctx.fp8_meta = fp8_meta
            ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
            ctx.cpu_offloading = cpu_offloading
            ctx.is_first_microbatch = is_first_microbatch
            ctx.use_bias = use_bias
            ctx.sequence_parallel = sequence_parallel
            ctx.tensor_parallel = tensor_parallel
            ctx.inp_shape = inp_shape
            ctx.parallel_mode = parallel_mode
            ctx.tp_group = tp_group
            ctx.ub_overlap_ag = ub_overlap_ag
            ctx.ub_name = ub_name
            ctx.tp_size = tp_size
            ctx.requires_dgrad = inp.requires_grad
            ctx.is_input_fp8 = is_input_fp8
            ctx.reduce_and_update_bwd_fp8_tensors = False
            if ctx.fp8 and requires_grad(inp, weight, bias):
                ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module()

        # Row Parallel Linear
        if ub_overlap_rs:
            out = rs_out
        elif parallel_mode == "row" and sequence_parallel:
            out, _ = reduce_scatter_along_first_dim(out, tp_group)
        elif parallel_mode == "row" and tensor_parallel:
            out, _ = allreduce(out, tp_group)

        # [*, in_features] -> [*, out_features] except first dimension changes for SP
        return out.view(-1, *inp_shape[1:-1], out_features)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
        """Compute gradients with respect to ``weight``, ``inp`` and ``bias``.

        Returns one entry per ``forward`` argument (after ``ctx``):
            * the weight gradient, or ``None`` / a dummy tensor shaped like
              ``weight.main_grad`` when it was accumulated into ``main_grad``,
            * ``None`` for ``weight_fp8``,
            * the input gradient reshaped to the original input shape (``None``
              if the input did not require grad),
            * the bias gradient (``None`` if ``use_bias`` is False),
            * ``None`` for every remaining non-tensor argument.
        """
        if isinstance(grad_output, Float8Tensor):
            ctx.fp8_meta["scaling_bwd"].scale_inv[
                tex.FP8BwdTensors.GRAD_OUTPUT1
            ] = grad_output._scale_inv

        dgrad = None
        with torch.cuda.nvtx.range("_Linear_backward"):
            (
                inputmat,
                inputmat_t,
                inputmat_scale_inv,
                weight,
                weight_fp8,
                main_grad,
            ) = ctx.saved_tensors

            # Gather intermediate/activation tensors if needed
            # NOTE: the FP8 weight copy is only gathered here when the primary
            #       weight is not itself a Float8Tensor; torch.distributed.FSDP
            #       already shards/unshards the base weights so we don't do it ourselves
            _fsdp_gather_tensors(
                ctx.fsdp_group,
                ctx.fsdp_shapes,
                inputmat,
                inputmat_t,
                weight_fp8 if ctx.fp8 and not isinstance(weight, Float8Tensor) else None,
            )

            if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
                # Re-wrap the saved weight and re-attach main_grad, which forward
                # saved as a separate tensor.
                weight = torch.nn.Parameter(weight, weight.requires_grad)
                weight.main_grad = main_grad

            tp_world_size = get_distributed_world_size(ctx.tp_group)
            # There is nothing to all-gather with a single tensor-parallel rank.
            ctx.ub_overlap_ag = False if tp_world_size == 1 else ctx.ub_overlap_ag
            ub_algo = None
            if ctx.ub_overlap_ag:
                ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
                if ctx.ub_obj_gradout.is_atomic_gemm():
                    ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P
                else:
                    ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P

            (
                grad_output,
                grad_output_c,
                grad_output_t,
                grad_bias,
            ) = TransformerEngineBaseModule.grad_output_preprocess(
                ctx, grad_output, ctx.parallel_mode == "row"
            )

            # Column Parallel Linear
            # Overlap input AG with dgrad: the gather is asynchronous only when a
            # dgrad GEMM follows to hide it behind.
            inputmat_total = None
            inputmat_t_total = None
            handle = None
            if weight.requires_grad and ctx.parallel_mode == "column" and ctx.sequence_parallel:
                inputmat_total, handle = gather_along_first_dim(
                    inputmat, ctx.tp_group, async_op=ctx.requires_dgrad
                )
            else:
                inputmat_total = inputmat
                inputmat_t_total = inputmat_t

            if ctx.is_first_microbatch is not None:
                accumulate_wgrad_into_param_main_grad = (
                    ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
                )
            else:
                accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation

            if ctx.fp8:
                fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
                fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)

            if ctx.requires_dgrad:
                if ctx.fp8:
                    if ctx.is_input_fp8:
                        # FP8 input => emit FP8 dgrad (wrapped in a Float8Tensor below).
                        out_index, meta_tensor, output_te_dtype, output_dtype = (
                            tex.FP8BwdTensors.GRAD_INPUT1,
                            ctx.fp8_meta["scaling_bwd"],
                            fp8_dtype_backward,
                            torch.uint8,
                        )
                    else:
                        out_index, meta_tensor, output_te_dtype, output_dtype = (
                            None,
                            None,
                            None,
                            ctx.activation_dtype,
                        )
                    dgrad, _ = fp8_gemm(
                        weight_fp8.transpose_2d(),
                        weight_fp8._scale_inv,
                        0,
                        weight_fp8._fp8_dtype,
                        grad_output_c,
                        ctx.fp8_meta["scaling_bwd"].scale_inv,
                        tex.FP8BwdTensors.GRAD_OUTPUT1,
                        fp8_dtype_backward,
                        output_dtype,
                        get_workspace(),
                        use_split_accumulator=_2X_ACC_DGRAD,
                        ub_algo=ub_algo if ctx.ub_overlap_ag else None,
                        ub=ctx.ub_obj_gradout if ctx.ub_overlap_ag else None,
                        out_index=out_index,
                        fp8_meta_tensor=meta_tensor,
                        D_dtype=output_te_dtype,
                    )
                    if output_dtype == torch.uint8:
                        dgrad = Float8Tensor(
                            data=dgrad,
                            fp8_meta=ctx.fp8_meta,
                            fp8_meta_forward=False,
                            fp8_meta_index=tex.FP8BwdTensors.GRAD_INPUT1,
                            fp8_dtype=fp8_dtype_backward,
                            dtype=ctx.activation_dtype,
                        )
                else:
                    dgrad, _, _ = gemm(
                        weight,
                        grad_output,
                        ctx.activation_dtype,
                        get_workspace(),
                        layout="NN",
                        grad=True,
                        ub_algo=(
                            tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P
                            if ctx.ub_overlap_ag
                            else None
                        ),
                        ub=ctx.ub_obj_gradout if ctx.ub_overlap_ag else None,
                    )

                # Overlap dgrad-RS/AR with wgrad
                if ctx.parallel_mode == "column" and ctx.sequence_parallel:
                    # Finish the async input all-gather before `handle` is reused.
                    if handle is not None:
                        handle.wait()
                    dgrad, handle = reduce_scatter_along_first_dim(
                        dgrad, ctx.tp_group, async_op=True
                    )
                elif ctx.parallel_mode == "column" and ctx.tensor_parallel:
                    dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True)

            wgrad = None
            if weight.requires_grad:
                if ctx.fp8:
                    # WGRAD
                    if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
                        if ctx.ub_overlap_ag:
                            if isinstance(grad_output_c, Float8Tensor):
                                grad_output_t = grad_output_c.transpose_2d()
                            else:
                                grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
                        if inputmat_t_total is None:
                            if isinstance(inputmat_total, Float8Tensor):
                                inputmat_t_total = inputmat_total.transpose_2d()
                            else:
                                # The input was quantized in forward with the
                                # forward FP8 dtype, not the backward one.
                                inputmat_t_total = tex.fp8_transpose(
                                    inputmat_total, fp8_dtype_forward
                                )
                        wgrad, _ = fp8_gemm(
                            (
                                inputmat_t_total._data
                                if isinstance(inputmat_t_total, Float8Tensor)
                                else inputmat_t_total
                            ),
                            inputmat_scale_inv,
                            0,
                            fp8_dtype_forward,
                            grad_output_t,
                            ctx.fp8_meta["scaling_bwd"].scale_inv,
                            tex.FP8BwdTensors.GRAD_OUTPUT1,
                            fp8_dtype_backward,
                            ctx.activation_dtype,
                            get_workspace(),
                            accumulate=accumulate_wgrad_into_param_main_grad,
                            out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                            use_split_accumulator=_2X_ACC_WGRAD,
                        )
                    else:
                        # Recipe forces higher-precision wgrad.
                        wgrad, _, _ = gemm(
                            inputmat_total,
                            grad_output,
                            ctx.activation_dtype,
                            get_workspace(),
                            layout="NT",
                            grad=True,
                            accumulate=accumulate_wgrad_into_param_main_grad,
                            out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                        )
                else:
                    # WGRAD (the GEMM also produces the bias gradient)
                    wgrad, grad_bias, _ = gemm(
                        inputmat_total,
                        grad_output,
                        ctx.activation_dtype,
                        get_workspace(),
                        layout="NT",
                        grad=True,
                        use_bias=ctx.use_bias,
                        accumulate=accumulate_wgrad_into_param_main_grad,
                        out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                    )

                # Deallocate input tensor
                clear_tensor_data(inputmat_total)
                clear_tensor_data(inputmat_t_total)

            # Column Parallel Linear
            if ctx.parallel_mode == "column" and ctx.tensor_parallel and handle is not None:
                handle.wait()

            if not ctx.use_bias:
                grad_bias = None

        if weight.requires_grad:
            # Handle custom DDP from mcore.
            # wgrad was written into weight.main_grad; a dummy tensor shaped like
            # main_grad is returned instead (zeros if weight.zero_out_wgrad is set).
            if ctx.fuse_wgrad_accumulation and hasattr(weight, "grad_added_to_main_grad"):
                weight.grad_added_to_main_grad = True
                if getattr(weight, "zero_out_wgrad", False):
                    wgrad = torch.zeros(
                        weight.main_grad.shape,
                        dtype=weight.dtype,
                        device=torch.cuda.current_device(),
                        requires_grad=False,
                    )
                else:
                    wgrad = torch.empty(
                        weight.main_grad.shape,
                        dtype=weight.dtype,
                        device=torch.cuda.current_device(),
                        requires_grad=False,
                    )
            elif ctx.fuse_wgrad_accumulation:
                wgrad = None
        else:
            wgrad = None

        if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
            FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)

        # Scatter fp8 weight buffers
        if ctx.fp8 and not isinstance(weight, Float8Tensor):
            _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8)

        return (
            wgrad,
            None,  # weight_fp8
            dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
            grad_bias,
            None,  # use_bias
            None,  # is_first_microbatch
            None,  # fp8
            None,  # fp8_calibration
            None,  # fp8_meta
            None,  # fuse_wgrad_accumulation
            None,  # cpu_offloading
            None,  # tp_group
            None,  # tp_size
            None,  # sequence_parallel
            None,  # tensor_parallel
            None,  # activation_dtype
            None,  # parallel_mode
            None,  # is_grad_enabled
            None,  # ub_overlap_rs
            None,  # ub_overlap_ag
            None,  # ub_name
            None,  # fp8_output
            None,  # fsdp_group
        )


class Linear(TransformerEngineBaseModule):
    """Applies a linear transformation to the incoming data :math:`y = xA^T + b`

    On NVIDIA GPUs it is a drop-in replacement for `torch.nn.Linear`.

    Parameters
    ----------
    in_features : int
                 size of each input sample.
    out_features : int
                  size of each output sample.
    bias : bool, default = `True`
          if set to `False`, the layer will not learn an additive bias.
    init_method : Callable, default = `None`
                 used for initializing weights in the following way: `init_method(weight)`.
                 When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    get_rng_state_tracker : Callable, default = `None`
                 used to get the random number generator state tracker for initializing weights.
    rng_tracker_name : str, default = `None`
                 the param passed to get_rng_state_tracker to get the specific rng tracker.
    parameters_split : Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None
                      Configuration for splitting the weight and bias tensors along dim 0 into
                      multiple PyTorch parameters. If a list or tuple of strings is provided,
                      they are used to make the names of equally-sized parameters. If a dict
                      (preferably an OrderedDict) is provided, the keys are used as names and
                      values as split sizes along dim 0. The resulting parameters will have
                      names that end in `_weight` or `_bias`, so trailing underscores are
                      stripped from any provided names.
    device : Union[torch.device, str], default = "cuda"
          The device on which the parameters of the model will be allocated. It is the user's
          responsibility to ensure all parameters are moved to the GPU before running the
          forward pass. If the device type is `meta`, parameter initialization is deferred
          and `parameters_split` must be `None`.

    Parallelism parameters
    ----------------------
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism. Only takes effect when
                       the tensor-parallel world size is greater than 1.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    tp_size : int, default = 1
             used as TP (tensor parallel) world size when TP groups are not formed during
             initialization. In this case, users must call the
             `set_tensor_parallel_group(tp_group)` method on the initialized module before the
             forward pass to supply the tensor parallel group needed for tensor and sequence
             parallel collectives.
    parallel_mode : {None, 'column', 'row'}, default = `None`
                   used to decide whether this Linear layer is Column Parallel Linear or Row
                   Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
                   When set to `None`, no communication is performed.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = `False`
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient. When enabled, it is assumed that the weights
                             have an additional `main_grad` attribute (used instead of the
                             regular `grad`) which is a pre-allocated buffer of the correct
                             size to accumulate gradients in.
    return_bias : bool, default = `False`
                 when set to `True`, this module will not apply the additive bias itself, but
                 instead return the bias value during the forward pass together with the
                 output of the linear transformation :math:`y = xA^T`. This is useful when
                 the bias addition can be fused to subsequent operations.
    params_dtype : torch.dtype, default = `torch.get_default_dtype()`
                  it controls the type used to allocate the initial parameters. Useful when
                  the model is trained with lower precision and the original FP32 parameters
                  would not fit in GPU memory.
    ub_overlap_rs : bool, default = `False`
                   userbuffer communication-overlap flag, forwarded to the fused linear
                   autograd function. Requires `ub_name` to be set when `True`.
    ub_overlap_ag : bool, default = `False`
                   userbuffer communication-overlap flag, forwarded to the fused linear
                   autograd function. Requires `ub_name` to be set when `True`.
    ub_name : str, default = `None`
             name of the userbuffer to use; must be provided whenever `ub_overlap_rs` or
             `ub_overlap_ag` is `True`.

    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        sequence_parallel: bool = False,
        fuse_wgrad_accumulation: bool = False,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        get_rng_state_tracker: Optional[Callable] = None,
        rng_tracker_name: Optional[str] = None,
        init_method: Optional[Callable] = None,
        bias: bool = True,
        return_bias: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        parallel_mode: Optional[str] = None,
        parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None,
        device: Union[torch.device, str] = "cuda",
        ub_overlap_rs: bool = False,
        ub_overlap_ag: bool = False,
        ub_name: Optional[str] = None,
    ) -> None:
        super().__init__()

        params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
        self.in_features = in_features
        self.out_features = out_features
        self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
        self.use_bias = bias
        self.return_bias = return_bias
        # With `return_bias`, the bias is handed back to the caller instead of being added.
        self.apply_bias = bias and not return_bias
        self.ub_overlap_rs = ub_overlap_rs
        self.ub_overlap_ag = ub_overlap_ag
        if ub_overlap_rs or ub_overlap_ag:
            assert ub_name is not None, "Userbuffer name [string] is not set."
        self.ub_name = ub_name
        self.get_rng_state_tracker = get_rng_state_tracker
        self.rng_tracker_name = rng_tracker_name

        # `device` may be a string or a `torch.device`. Compare the device *type* so that
        # `torch.device("meta")` is recognized just like the string "meta" (a plain
        # `device == "meta"` comparison is False for a `torch.device` object).
        is_meta_device = (
            isinstance(device, (str, torch.device)) and torch.device(device).type == "meta"
        )
        if is_meta_device:
            assert parameters_split is None, "Cannot split module parameters on 'meta' device."

        if tp_group is None:
            self.tp_size = tp_size
            if tp_size == 1:
                self.set_tensor_parallel_group(tp_group)
        else:
            self.tp_size = get_distributed_world_size(tp_group)
            self.set_tensor_parallel_group(tp_group)
        self.set_nccl_overlap_warning_if_tp()

        self.parallel_mode = parallel_mode
        assert (
            self.parallel_mode in GemmParallelModes
        ), f"parallel_mode {parallel_mode} not supported"

        # Shard the partitioned dimension across tensor-parallel ranks.
        if self.parallel_mode == "column":
            self.out_features = divide(self.out_features, self.tp_size)
        elif self.parallel_mode == "row":
            self.in_features = divide(self.in_features, self.tp_size)

        # Sequence parallelism is only meaningful with more than one TP rank.
        self.sequence_parallel = (self.tp_size > 1) and sequence_parallel

        # Initialize params in FP8
        with_fp8_params = FP8GlobalStateManager.with_fp8_parameters()

        # Contiguous buffers for params
        weight_tensor = torch.empty(
            self.out_features,
            self.in_features,
            device=device,
            dtype=params_dtype,
        )
        bias_tensor = None
        if self.use_bias:
            bias_tensor = torch.empty(
                self.out_features,
                device=device,
                dtype=params_dtype,
            )

        # Configure parameter splits
        self.weight_names = []
        self.bias_names = []
        self.parameter_split_sizes = []
        if parameters_split is None:
            # Split into a single parameter by default
            self.weight_names = ["weight"]
            self.bias_names = ["bias"]
            self.parameter_split_sizes = [out_features]
        elif not parameters_split:
            raise ValueError("Cannot split weight buffer into 0 parameters")
        elif isinstance(parameters_split, dict):
            # Split parameters with provided sizes
            for name, split_size in parameters_split.items():
                self.weight_names.append(f"{name.rstrip('_')}_weight")
                self.bias_names.append(f"{name.rstrip('_')}_bias")
                self.parameter_split_sizes.append(split_size)
        elif all(isinstance(name, str) for name in parameters_split):
            # Split parameters evenly
            split_size = out_features // len(parameters_split)
            for name in parameters_split:
                self.weight_names.append(f"{name.rstrip('_')}_weight")
                self.bias_names.append(f"{name.rstrip('_')}_bias")
                self.parameter_split_sizes.append(split_size)
        else:
            raise TypeError("Invalid configuration for parameters split")

        # Make sure parameter splits are valid (also catches uneven equal-size splits,
        # since the floor division above would leave a remainder unaccounted for).
        if sum(self.parameter_split_sizes) != out_features:
            raise ValueError(
                f"Trying to split weight buffer ({out_features=}) "
                f"with split sizes {self.parameter_split_sizes}"
            )

        # Adjust parameter splits for tensor-parallel distribution
        if self.parallel_mode == "column":
            for i, size in enumerate(self.parameter_split_sizes):
                if size % self.tp_size != 0:
                    raise RuntimeError(
                        f"Attempting to distribute a parameter with out_features={size} "
                        f"between {self.tp_size} tensor-parallel processes"
                    )
                self.parameter_split_sizes[i] = size // self.tp_size

        # Construct weight parameters
        # Note: Register weights together so that they are adjacent to
        # each other in Linear.parameters(). This makes it more likely
        # that they will stay contiguous if the weights are
        # manipulated externally, e.g. by FSDP.
        offset = 0
        for i, split_size in enumerate(self.parameter_split_sizes):
            split_start = offset
            offset += split_size
            split_end = offset

            # Check if parameters are subviews of buffers
            is_subview = (split_start, split_end) != (0, self.out_features)
            if is_subview and with_fp8_params:
                raise RuntimeError("Splitting Float8Tensor into multiple params is not supported")

            # Construct weight parameter
            self.register_parameter(
                self.weight_names[i],
                torch.nn.Parameter(weight_tensor[split_start:split_end]),
                init_fn=init_method,
                get_rng_state_tracker=get_rng_state_tracker,
                fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
            )

        # Construct bias parameters if needed
        if self.use_bias:
            offset = 0
            for i, split_size in enumerate(self.parameter_split_sizes):
                split_start = offset
                offset += split_size
                split_end = offset
                self.register_parameter(
                    self.bias_names[i],
                    torch.nn.Parameter(bias_tensor[split_start:split_end]),
                    init_fn=init_method_constant(0.0),
                )
        else:
            # Keep empty placeholder tensors so the bias attributes always exist.
            for name in self.bias_names:
                bias = torch.Tensor().to(dtype=params_dtype, device=device)
                setattr(self, name, bias)

        if with_fp8_params:
            self.init_fp8_metadata()

        self.reset_parameters(defer_init=is_meta_device)

        # For RPL, bias has to be added after TP collectives
        # So it cannot be fused with the GEMM
        if self.parallel_mode == "row" and self.apply_bias:
            self.gemm_bias_unfused_add = True
        else:
            self.gemm_bias_unfused_add = False

    def reset_parameters(self, defer_init=False):
        """Re-initialize parameters and (unless deferred) tag them with parallelism attributes.

        Parameters
        ----------
        defer_init : bool, default = `False`
                    passed to the base-class `reset_parameters`; when `True`, the
                    tensor-parallel attributes below are not set either.
        """
        super().reset_parameters(defer_init=defer_init)

        if defer_init:
            return

        # Set parallelism attributes for linear weights: row-parallel shards dim 1,
        # every other mode is tagged on dim 0.
        partition_dim = 1 if self.parallel_mode == "row" else 0
        for weight in self.weight_names:
            set_tensor_model_parallel_attributes(
                tensor=getattr(self, weight),
                is_parallel=True,
                dim=partition_dim,
                stride=1,
            )

        # Set parallelism attributes for linear biases
        if self.use_bias:
            for bias in self.bias_names:
                if self.parallel_mode == "row":
                    getattr(self, bias).sequence_parallel = self.sequence_parallel
                elif self.parallel_mode == "column":
                    set_tensor_model_parallel_attributes(getattr(self, bias), True, 0, 1)

    @no_torch_dynamo()
    def forward(
        self,
        inp: torch.Tensor,
        is_first_microbatch: Optional[bool] = None,
        fp8_output: Optional[bool] = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        """
        Apply the linear transformation to the input.

        Parameters
        ----------
        inp : torch.Tensor
             Input tensor.
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)
        fp8_output : Optional[bool], default = `False`
                    forwarded unchanged to the fused linear autograd function as its
                    `fp8_output` argument.

        Returns
        -------
        torch.Tensor or Tuple[torch.Tensor, torch.Tensor]
            The layer output. When `return_bias` is `True`, a tuple `(out, bias)` is
            returned instead, where `bias` is cast to the activation dtype and has not
            been added to `out`.
        """

        # When a skip-update flag tensor is supplied, treat this call as a non-first
        # microbatch; the flag itself is forwarded to the FP8 weight update calls below.
        skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
        if skip_fp8_weight_update is not None:
            is_first_microbatch = False

        with self.prepare_forward(
            inp,
            is_first_microbatch,
            allow_non_contiguous=isinstance(inp, QuantizedTensor),
        ) as inp:

            # Get concatenated weight and bias tensors
            unfused_weights = [self._fast_get_param(name) for name in self.weight_names]
            if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
                if self.fp8:
                    if len(unfused_weights) != 1:
                        raise RuntimeError(
                            "Splitting QuantizedTensor into multiple params is not supported"
                        )
                else:
                    # Outside FP8 execution, run the GEMM on high-precision weights.
                    unfused_weights = [w.dequantize() for w in unfused_weights]
            weight_tensor = _noop_cat(unfused_weights)
            if self.use_bias:
                bias_tensor = _noop_cat([self._fast_get_param(name) for name in self.bias_names])
            else:
                bias_tensor = self._fast_get_param(self.bias_names[0])  # Unused

            # Initialize FP8 weights if needed
            weight_fp8 = None
            if self.fp8:
                if isinstance(weight_tensor, Float8Tensor):
                    # Make sure transpose cache is valid, if present
                    # Note: Transpose cache may have been invalidated
                    # externally, e.g. by optimizer.
                    if weight_tensor._transpose is not None:
                        weight_tensor.transpose_2d(
                            fill_cache=True,
                            noop_flag=skip_fp8_weight_update,
                        )
                else:
                    # FP8 cast to workspace buffer
                    update_workspace = is_first_microbatch is None or is_first_microbatch
                    weight_fp8 = self.get_fp8_workspace(
                        tensor=weight_tensor,
                        fp8_meta_forward=True,
                        fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
                        cache_name=(None if is_first_microbatch is None else "weight"),
                        update_workspace=update_workspace,
                        skip_update_flag=skip_fp8_weight_update,
                        fsdp_group=self.fsdp_group,
                    )

            # Without autograd, call `forward` directly; `None` stands in for `ctx`.
            is_grad_enabled = torch.is_grad_enabled()
            if is_grad_enabled:
                linear_fn = _Linear.apply
                args = []
            else:
                linear_fn = _Linear.forward
                args = [None]
            args += (
                weight_tensor,
                weight_fp8,
                inp,
                bias_tensor,
                self.apply_bias and not self.gemm_bias_unfused_add,
                is_first_microbatch,
                self.fp8,
                self.fp8_calibration,
                self.fp8_meta,
                self.fuse_wgrad_accumulation,
                is_cpu_offload_enabled(),
                self.tp_group,
                self.tp_size,
                self.sequence_parallel,
                self.tp_size > 1,
                self.activation_dtype,
                self.parallel_mode,
                is_grad_enabled,
                self.ub_overlap_rs,
                self.ub_overlap_ag,
                self.ub_name,
                fp8_output,
                self.fsdp_group,
            )
            out = linear_fn(*args)

        # Row-parallel bias is added after the TP collectives (see `__init__`).
        if self.gemm_bias_unfused_add:
            out = out + cast_if_needed(bias_tensor, self.activation_dtype)

        if self.return_bias:
            return out, cast_if_needed(bias_tensor, self.activation_dtype)
        return out