# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Linear API"""
from typing import Any, Callable, Dict, Optional, Tuple, Union

import torch

import transformer_engine_torch as tex

from .base import (
    get_workspace,
    get_ub,
    TransformerEngineBaseModule,
    _2X_ACC_FPROP,
    _2X_ACC_DGRAD,
    _2X_ACC_WGRAD,
)
from ._common import _noop_cat
from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager
from ..utils import (
    divide,
    cast_if_needed,
    assert_dim_for_fp8_exec,
    clear_tensor_data,
    init_method_constant,
    requires_grad,
)
from ..distributed import (
    set_tensor_model_parallel_attributes,
    get_distributed_world_size,
    allreduce,
    reduce_scatter_along_first_dim,
    gather_along_first_dim,
    _fsdp_scatter_tensors,
    _fsdp_gather_tensors,
)
from ..cpp_extensions import (
    fp8_gemm,
    gemm,
    fp8_cast_transpose_fused,
    cast_to_fp8,
)
from ..constants import GemmParallelModes, dist_group_type
from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ..float8_tensor import Float8Tensor
from ..export import is_in_onnx_export_mode
from ..tensor import QuantizedTensor

__all__ = ["Linear"]


class _Linear(torch.autograd.Function):
    """Linear semi-top level module
    Calls custom cuda extensions.
    """

    @staticmethod
    def forward(
        ctx,
        weight: Union[Float8Tensor, torch.Tensor],
        weight_fp8: Optional[Float8Tensor],
        inp: torch.Tensor,
        bias: torch.Tensor,
        use_bias: bool,
        is_first_microbatch: Union[bool, None],
        fp8: bool,
        fp8_calibration: bool,
        fp8_meta: Dict[str, Any],
        fuse_wgrad_accumulation: bool,
        cpu_offloading: bool,
        tp_group: Union[dist_group_type, None],
        tp_size: int,
        sequence_parallel: bool,
        tensor_parallel: bool,
        activation_dtype: torch.dtype,
        parallel_mode: Union[str, None],
        is_grad_enabled: bool,
        ub_overlap_rs: bool,
        ub_overlap_ag: bool,
        ub_name: str,
        fp8_output: bool,
        fsdp_group: Union[dist_group_type, None],
    ) -> torch.Tensor:
        is_input_fp8 = isinstance(inp, Float8Tensor)

        # Make sure input dimensions are compatible
        in_features = weight.shape[-1]
        assert inp.shape[-1] == in_features, "GEMM not possible"
        inputmat = inp.view(-1, in_features)
        if fp8:
            assert_dim_for_fp8_exec(inputmat)
            assert_dim_for_fp8_exec(weight)

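        # Userbuffer reduce-scatter overlap is only meaningful with more than one
        # tensor-parallel rank, so it is disabled for single-rank groups.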
        tp_world_size = get_distributed_world_size(tp_group)
        ub_overlap_rs = False if tp_world_size == 1 else ub_overlap_rs

        # Cast input to expected dtype
        inputmat = cast_if_needed(inputmat, activation_dtype)
        inputmat_t = None
        inputmat_no_fp8 = inputmat
        inputmat_scale_inv = None

        if fp8:
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
            if isinstance(inputmat, Float8Tensor):
                inputmat_scale_inv = inputmat._scale_inv
            else:
                inputmat_scale_inv = torch.empty([1], dtype=torch.float32, device=inputmat.device)
                if (
                    not fp8_meta["recipe"].override_linear_precision.wgrad
                    and is_grad_enabled
                    and weight.requires_grad
                    and not sequence_parallel
                ):
                    # FP8 input for forward, FP8 input transpose for backward wgrad
                    inputmat, inputmat_t = fp8_cast_transpose_fused(
                        inputmat,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM1_INPUT,
                        fp8_dtype_forward,
                        scale_inv=inputmat_scale_inv,
                    )
                else:
                    # FP8 input for forward
                    inputmat = cast_to_fp8(
                        inputmat,
                        fp8_meta["scaling_fwd"],
                        tex.FP8FwdTensors.GEMM1_INPUT,
                        fp8_dtype_forward,
                        scale_inv=inputmat_scale_inv,
                    )

            # Hack for ONNX export
            # Note: ONNX models are represented as a graph of tensor
            # operations, so the in-place scale-inv update doesn't fit
            # very well. We work around this by making it look like
            # the scale-inv tensor is initialized with a copy.
            # Note: ONNX export expects FP8 scales can be represented
            # with constant ops. However, copying into a buffer
            # involves an expand op for array broadcasting. We work
            # around this by filling the buffer instead.
            if is_in_onnx_export_mode():
                inputmat_scale_inv.fill_(inputmat_scale_inv.item())

        # Column Parallel Linear
        if parallel_mode == "column" and sequence_parallel:
            inputmat_total, _ = gather_along_first_dim(inputmat, tp_group)
        else:
            inputmat_total = inputmat
        if fp8:
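            # The fused GEMM bias cannot be FP32 in the FP8 path, so fall back to a
            # BF16 bias when the activation dtype is FP32.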
            bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype
            bias = cast_if_needed(bias, bias_dtype) if use_bias else bias

            # Use FP8 weights
            if weight_fp8 is None:
                weight_fp8 = weight

            assert isinstance(weight_fp8, Float8Tensor)

            if fp8_output:
                proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = (
                    tex.FP8FwdTensors.GEMM1_OUTPUT,
                    fp8_meta["scaling_fwd"],
                    fp8_dtype_forward,
                    torch.uint8,
                )
            else:
                proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = (
                    None,
                    None,
                    None,
                    activation_dtype,
                )

            if ub_overlap_rs:
                ub_obj_projout = get_ub(ub_name + "_fprop")
                out = ub_obj_projout.get_ubuf_output(1)
                dim_size = list(inputmat_total.size())
                dim_size[0] = dim_size[0] // tp_world_size
                dim_size[1] = weight_fp8.size(0)
                rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
                if ub_obj_projout.is_p2p_overlap():
                    if ub_obj_projout.is_atomic_gemm():
                        ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P
                    else:
                        ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P
                else:
                    if ub_obj_projout.is_atomic_gemm():
                        ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS
                    else:
                        ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS
                if ub_obj_projout.is_fp8_ubuf():
                    proj_out_index = tex.FP8FwdTensors.GEMM1_OUTPUT
                    meta_tensor = fp8_meta["scaling_fwd"]
                    proj_out_tetype = fp8_dtype_forward
                    proj_out_pttype = torch.uint8
                    ub_obj_projout.set_ubuf_scale_inv(meta_tensor.scale_inv[proj_out_index])
            else:
                dim_size = list(inputmat_total.size())
                dim_size[1] = weight_fp8.size(0)
                out = torch.empty(dim_size, dtype=proj_out_pttype, device=inputmat_total.device)

            _ = fp8_gemm(
                weight_fp8._data,
                weight_fp8._scale_inv,
                0,
                weight_fp8._fp8_dtype,
                (
                    inputmat_total._data
                    if isinstance(inputmat_total, Float8Tensor)
                    else inputmat_total
                ),
                inputmat_scale_inv,
                0,
                fp8_dtype_forward,
                proj_out_pttype,
                get_workspace(),
                bias=bias,
                use_bias=use_bias,
                use_split_accumulator=_2X_ACC_FPROP,
                out=out,
                ub_algo=ub_algo if ub_overlap_rs else None,
                ub=ub_obj_projout if ub_overlap_rs else None,
                extra_output_tensor=rs_out if ub_overlap_rs else None,
                out_index=proj_out_index,
                fp8_meta_tensor=meta_tensor,
                D_dtype=proj_out_tetype,
            )
            if fp8_output:
                out = Float8Tensor(
                    data=out,
                    fp8_meta=fp8_meta,
                    fp8_meta_forward=True,
                    fp8_meta_index=tex.FP8FwdTensors.GEMM1_OUTPUT,
                    fp8_dtype=fp8_dtype_forward,
                    dtype=activation_dtype,
                )
        else:
            # Cast for native AMP
            weight = cast_if_needed(weight, activation_dtype)
            bias = cast_if_needed(bias, activation_dtype) if use_bias else bias

            if fp8_calibration:
                # amax of input
                amin, amax = inputmat_total.aminmax()
                fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = torch.max(
                    -amin, amax
                ).float()
                # amax of weight
                amin, amax = weight.aminmax()
                fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = torch.max(
                    -amin, amax
                ).float()

            if ub_overlap_rs:
                ub_obj_projout = get_ub(ub_name + "_fprop")
                out = ub_obj_projout.get_ubuf_output(1)
                dim_size = list(inputmat_total.size())
                dim_size[0] = dim_size[0] // get_distributed_world_size(tp_group)
                dim_size[1] = weight.size(0)
                rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
                if ub_obj_projout.is_p2p_overlap():
                    ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P
                else:
                    ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS
            else:
                dim_size = list(inputmat_total.size())
                dim_size[1] = weight.size(0)
                out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)

            _ = gemm(
                weight,
                inputmat_total,
                activation_dtype,
                get_workspace(),
                bias=bias,
                use_bias=use_bias,
                out=out,
                ub_algo=ub_algo if ub_overlap_rs else None,
                ub=ub_obj_projout if ub_overlap_rs else None,
                extra_output_tensor=rs_out if ub_overlap_rs else None,
            )

        if is_grad_enabled:
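            # Save whichever representation of the input the wgrad GEMM will need
            # in the backward pass.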
            saved_inputmat = None
            saved_inputmat_t = None
            if weight.requires_grad:
                if fp8 and not fp8_meta["recipe"].override_linear_precision.wgrad:
                    if inputmat_t is None:
                        saved_inputmat = inputmat
                    else:
                        saved_inputmat_t = inputmat_t
                        if cpu_offloading:
                            saved_inputmat_t.activation_offloading = True
                else:
                    saved_inputmat = inputmat_no_fp8

                if cpu_offloading:
                    if fp8 and weight_fp8 is not None:
                        weight_fp8.weight_offloading = True
                    weight.weight_offloading = True

                    if saved_inputmat is not None:
                        saved_inputmat.activation_offloading = True

            # Scatter intermediate/activation tensors saved for the backward pass
            # NOTE: FSDP sharding is not valid for models initialized with primary Fp8 weights
            ctx.fsdp_group = fsdp_group
            ctx.fsdp_shapes = _fsdp_scatter_tensors(
                fsdp_group,
                saved_inputmat,  # None if fp8 == False
                saved_inputmat_t,  # None if fp8 == False AND not is_grad_enabled
                weight_fp8 if fp8 and not isinstance(weight, Float8Tensor) else None,
            )

            ctx.save_for_backward(
                saved_inputmat,
                saved_inputmat_t,
                inputmat_scale_inv,
                weight,
                weight_fp8,
                weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None,
            )

            ctx.activation_dtype = activation_dtype
            ctx.fp8 = fp8
            ctx.fp8_meta = fp8_meta
            ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
            ctx.cpu_offloading = cpu_offloading
            ctx.is_first_microbatch = is_first_microbatch
            ctx.use_bias = use_bias
            ctx.sequence_parallel = sequence_parallel
            ctx.tensor_parallel = tensor_parallel
            ctx.inp_shape = inp.shape
            ctx.parallel_mode = parallel_mode
            ctx.tp_group = tp_group
            ctx.ub_overlap_ag = ub_overlap_ag
            ctx.ub_name = ub_name
            ctx.tp_size = tp_size
            ctx.requires_dgrad = inp.requires_grad
            ctx.is_input_fp8 = is_input_fp8
            ctx.reduce_and_update_bwd_fp8_tensors = False
            if ctx.fp8 and requires_grad(inp, weight, bias):
                ctx.reduce_and_update_bwd_fp8_tensors = (
                    ctx.reduce_and_update_bwd_fp8_tensors
                    or FP8GlobalStateManager.is_first_fp8_module()
                )

        # Row Parallel Linear
        if ub_overlap_rs:
            out = rs_out
        elif parallel_mode == "row" and sequence_parallel:
            out, _ = reduce_scatter_along_first_dim(out, tp_group)
        elif parallel_mode == "row" and tensor_parallel:
            out, _ = allreduce(out, tp_group)

        # [*, in_features] -> [*, out_features] except first dimension changes for SP
        return out.view(-1, *inp.shape[1:-1], out.shape[-1])

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
        if isinstance(grad_output, Float8Tensor):
            ctx.fp8_meta["scaling_bwd"].scale_inv[
                tex.FP8BwdTensors.GRAD_OUTPUT1
            ] = grad_output._scale_inv

        with torch.cuda.nvtx.range("_Linear_backward"):
            (
                inputmat,
                inputmat_t,
                inputmat_scale_inv,
                weight,
                weight_fp8,
                main_grad,
            ) = ctx.saved_tensors

            # Gather intermediate/activation tensors if needed
            # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.distributed.FSDP already
            #       shards/unshards the base weights so we don't do it ourselves
            _fsdp_gather_tensors(
                ctx.fsdp_group,
                ctx.fsdp_shapes,
                inputmat,
                inputmat_t,
                weight_fp8 if ctx.fp8 and not isinstance(weight, Float8Tensor) else None,
            )

            if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
                weight = torch.nn.Parameter(weight, weight.requires_grad)
                weight.main_grad = main_grad

            tp_world_size = get_distributed_world_size(ctx.tp_group)
            ctx.ub_overlap_ag = False if tp_world_size == 1 else ctx.ub_overlap_ag
            if ctx.ub_overlap_ag:
                dim_size = list(grad_output.size())
                dim_size[0] = dim_size[0] * tp_world_size
                ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
                if ctx.ub_obj_gradout.is_atomic_gemm():
                    ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P
                else:
                    ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P

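            # Cast/transpose grad_output as needed for the backward GEMMs and
            # extract the bias gradient when it can be fused at this stage.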
            (
                grad_output,
                grad_output_c,
                grad_output_t,
                grad_bias,
            ) = TransformerEngineBaseModule.grad_output_preprocess(
                ctx, grad_output, ctx.parallel_mode == "row"
            )

            # Column Parallel Linear
            # Overlap input AG with dgrad
            inputmat_total = None
            inputmat_t_total = None
            handle = None
            if weight.requires_grad and ctx.parallel_mode == "column" and ctx.sequence_parallel:
                inputmat_total, handle = gather_along_first_dim(
                    inputmat, ctx.tp_group, async_op=ctx.requires_dgrad
                )
            else:
                inputmat_total = inputmat
                inputmat_t_total = inputmat_t

            if ctx.is_first_microbatch is not None:
                accumulate_wgrad_into_param_main_grad = (
                    ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
                )
            else:
                accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation

            if ctx.fp8:
                fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
                fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)

            if ctx.requires_dgrad:
                if ctx.fp8:
                    if ctx.is_input_fp8:
                        out_index, meta_tensor, output_te_dtype, output_dtype = (
                            tex.FP8BwdTensors.GRAD_INPUT1,
                            ctx.fp8_meta["scaling_bwd"],
                            fp8_dtype_backward,
                            torch.uint8,
                        )
                    else:
                        out_index, meta_tensor, output_te_dtype, output_dtype = (
                            None,
                            None,
                            None,
                            ctx.activation_dtype,
                        )
                    dgrad, _ = fp8_gemm(
                        weight_fp8.transpose_2d(),
                        weight_fp8._scale_inv,
                        0,
                        weight_fp8._fp8_dtype,
                        grad_output_c,
                        ctx.fp8_meta["scaling_bwd"].scale_inv,
                        tex.FP8BwdTensors.GRAD_OUTPUT1,
                        fp8_dtype_backward,
                        output_dtype,
                        get_workspace(),
                        use_split_accumulator=_2X_ACC_DGRAD,
                        ub_algo=ub_algo if ctx.ub_overlap_ag else None,
                        ub=ctx.ub_obj_gradout if ctx.ub_overlap_ag else None,
                        out_index=out_index,
                        fp8_meta_tensor=meta_tensor,
                        D_dtype=output_te_dtype,
                    )
                    if output_dtype == torch.uint8:
                        dgrad = Float8Tensor(
                            data=dgrad,
                            fp8_meta=ctx.fp8_meta,
                            fp8_meta_forward=False,
                            fp8_meta_index=tex.FP8BwdTensors.GRAD_INPUT1,
                            fp8_dtype=fp8_dtype_backward,
                            dtype=ctx.activation_dtype,
                        )
                else:
                    dgrad, _, _ = gemm(
                        weight,
                        grad_output,
                        ctx.activation_dtype,
                        get_workspace(),
                        layout="NN",
                        grad=True,
                        ub_algo=(
                            tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P
                            if ctx.ub_overlap_ag
                            else None
                        ),
                        ub=ctx.ub_obj_gradout if ctx.ub_overlap_ag else None,
                    )

                # Overlap dgrad-RS/AR with wgrad
                if ctx.parallel_mode == "column" and ctx.sequence_parallel:
                    if handle is not None:
                        handle.wait()
                    dgrad, handle = reduce_scatter_along_first_dim(
                        dgrad, ctx.tp_group, async_op=True
                    )
                elif ctx.parallel_mode == "column" and ctx.tensor_parallel:
                    dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True)

            if weight.requires_grad:
                if ctx.fp8:
                    # WGRAD
                    if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad:
                        if ctx.ub_overlap_ag:
                            if isinstance(grad_output_c, Float8Tensor):
                                grad_output_t = grad_output_c.transpose_2d()
                            else:
                                grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward)
                        if inputmat_t_total is None:
                            if isinstance(inputmat_total, Float8Tensor):
                                inputmat_t_total = inputmat_total.transpose_2d()
                            else:
                                inputmat_t_total = tex.fp8_transpose(
                                    inputmat_total, fp8_dtype_backward
                                )
                        wgrad, _ = fp8_gemm(
                            (
                                inputmat_t_total._data
                                if isinstance(inputmat_t_total, Float8Tensor)
                                else inputmat_t_total
                            ),
                            inputmat_scale_inv,
                            0,
                            fp8_dtype_forward,
                            grad_output_t,
                            ctx.fp8_meta["scaling_bwd"].scale_inv,
                            tex.FP8BwdTensors.GRAD_OUTPUT1,
                            fp8_dtype_backward,
                            ctx.activation_dtype,
                            get_workspace(),
                            accumulate=accumulate_wgrad_into_param_main_grad,
                            out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                            use_split_accumulator=_2X_ACC_WGRAD,
                        )
                    else:
                        wgrad, _, _ = gemm(
                            inputmat_total,
                            grad_output,
                            ctx.activation_dtype,
                            get_workspace(),
                            layout="NT",
                            grad=True,
                            accumulate=accumulate_wgrad_into_param_main_grad,
                            out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                        )
                else:
                    # WGRAD
                    wgrad, grad_bias, _ = gemm(
                        inputmat_total,
                        grad_output,
                        ctx.activation_dtype,
                        get_workspace(),
                        layout="NT",
                        grad=True,
                        use_bias=ctx.use_bias,
                        accumulate=accumulate_wgrad_into_param_main_grad,
                        out=weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                    )

                # Deallocate input tensor
                clear_tensor_data(inputmat_total)
                clear_tensor_data(inputmat_t_total)

            # Column Parallel Linear
            if ctx.parallel_mode == "column" and ctx.tensor_parallel and handle is not None:
                handle.wait()

            if not ctx.use_bias:
                grad_bias = None

        if weight.requires_grad:
            # Handle custom DDP from mcore.
            if ctx.fuse_wgrad_accumulation and hasattr(weight, "grad_added_to_main_grad"):
                weight.grad_added_to_main_grad = True
                if getattr(weight, "zero_out_wgrad", False):
                    wgrad = torch.zeros(
                        weight.main_grad.shape,
                        dtype=weight.dtype,
                        device=torch.cuda.current_device(),
                        requires_grad=False,
                    )
                else:
                    wgrad = torch.empty(
                        weight.main_grad.shape,
                        dtype=weight.dtype,
                        device=torch.cuda.current_device(),
                        requires_grad=False,
                    )
            elif ctx.fuse_wgrad_accumulation:
                wgrad = None
        else:
            wgrad = None

        if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
            FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)

        # Scatter fp8 weight buffers
        if ctx.fp8 and not isinstance(weight, Float8Tensor):
            _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8)

        return (
            wgrad,
            None,  # weight_fp8
            dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
            grad_bias,
            None,  # use_bias
            None,  # is_first_microbatch
            None,  # fp8
            None,  # fp8_calibration
            None,  # fp8_meta
            None,  # fuse_wgrad_accumulation
            None,  # cpu_offloading
            None,  # tp_group
            None,  # tp_size
            None,  # sequence_parallel
            None,  # tensor_parallel
            None,  # activation_dtype
            None,  # parallel_mode
            None,  # is_grad_enabled
            None,  # ub_overlap_rs
            None,  # ub_overlap_ag
            None,  # ub_name
            None,  # fp8_output
            None,  # fsdp_group
        )


class Linear(TransformerEngineBaseModule):
    """Applies a linear transformation to the incoming data :math:`y = xA^T + b`

    On NVIDIA GPUs it is a drop-in replacement for `torch.nn.Linear`.
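
    A minimal usage sketch (illustrative only; assumes a CUDA device is available,
    and the FP8 autocast context is optional):

    .. code-block:: python

        import torch
        import transformer_engine.pytorch as te

        layer = te.Linear(768, 3072, bias=True)
        inp = torch.randn(16, 768, device="cuda")
        with te.fp8_autocast(enabled=True):
            out = layer(inp)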

    Parameters
    ----------
    in_features : int
                 size of each input sample.
    out_features : int
                  size of each output sample.
    bias : bool, default = `True`
          if set to `False`, the layer will not learn an additive bias.
    init_method : Callable, default = `None`
                 used for initializing weights in the following way: `init_method(weight)`.
                 When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
    get_rng_state_tracker : Callable, default = `None`
                 used to get the random number generator state tracker for initializing weights.
    rng_tracker_name : str, default = `None`
                 the param passed to get_rng_state_tracker to get the specific rng tracker.
    parameters_split : Optional[Union[Tuple[str, ...], Dict[str, int]]], default = None
                      Configuration for splitting the weight and bias tensors along dim 0 into
                      multiple PyTorch parameters. If a list or tuple of strings is provided,
                      they are used to make the names of equally-sized parameters. If a dict
                      (preferably an OrderedDict) is provided, the keys are used as names and
                      values as split sizes along dim 0. The resulting parameters will have
                      names that end in `_weight` or `_bias`, so trailing underscores are
                      stripped from any provided names.
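                      For example (illustrative), ``parameters_split=("query", "key", "value")``
                      on a layer with ``out_features=3072`` registers three equally sized
                      parameters named ``query_weight``, ``key_weight`` and ``value_weight``
                      (plus matching ``*_bias`` parameters when ``bias=True``).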
    device : Union[torch.device, str], default = "cuda"
          The device on which the parameters of the model will be allocated. It is the user's
          responsibility to ensure all parameters are moved to the GPU before running the
          forward pass.

    Parallelism parameters
    ----------------------
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    tp_size : int, default = 1
             used as TP (tensor parallel) world size when TP groups are not formed during
             initialization. In this case, users must call the
             `set_tensor_parallel_group(tp_group)` method on the initialized module before the
             forward pass to supply the tensor parallel group needed for tensor and sequence
             parallel collectives.
    parallel_mode : {None, 'Column', 'Row'}, default = `None`
                   used to decide whether this Linear layer is Column Parallel Linear or Row
                   Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
                   When set to `None`, no communication is performed.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = 'False'
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient. When enabled, it is assumed that the weights
                             have an additional `main_grad` attribute (used instead of the
                             regular `grad`) which is a pre-allocated buffer of the correct
                             size to accumulate gradients in.
    return_bias : bool, default = `False`
                 when set to `True`, this module will not apply the additive bias itself, but
                 instead return the bias value during the forward pass together with the
                 output of the linear transformation :math:`y = xA^T`. This is useful when
                 the bias addition can be fused to subsequent operations.
    params_dtype : torch.dtype, default = `torch.get_default_dtype()`
                  it controls the type used to allocate the initial parameters. Useful when
                  the model is trained with lower precision and the original FP32 parameters
                  would not fit in GPU memory.

    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        sequence_parallel: bool = False,
        fuse_wgrad_accumulation: bool = False,
        tp_group: Optional[dist_group_type] = None,
        tp_size: int = 1,
        get_rng_state_tracker: Optional[Callable] = None,
        rng_tracker_name: Optional[str] = None,
        init_method: Optional[Callable] = None,
        bias: bool = True,
        return_bias: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        parallel_mode: Optional[str] = None,
        parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None,
        device: Union[torch.device, str] = "cuda",
        ub_overlap_rs: bool = False,
        ub_overlap_ag: bool = False,
        ub_name: Optional[str] = None,
    ) -> None:
        super().__init__()

        params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
        self.in_features = in_features
        self.out_features = out_features
        self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
        self.use_bias = bias
        self.return_bias = return_bias
        self.apply_bias = bias and not return_bias
        self.ub_overlap_rs = ub_overlap_rs
        self.ub_overlap_ag = ub_overlap_ag
        if ub_overlap_rs or ub_overlap_ag:
            assert ub_name is not None, "Userbuffer name [string] is not set."
        self.ub_name = ub_name
        self.get_rng_state_tracker = get_rng_state_tracker
        self.rng_tracker_name = rng_tracker_name

        if device == "meta":
            assert parameters_split is None, "Cannot split module parameters on 'meta' device."
        if tp_group is None:
            self.tp_size = tp_size
            if tp_size == 1:
                self.set_tensor_parallel_group(tp_group)
        else:
            self.tp_size = get_distributed_world_size(tp_group)
            self.set_tensor_parallel_group(tp_group)
        self.set_nccl_overlap_warning_if_tp()

        self.parallel_mode = parallel_mode
        assert (
            self.parallel_mode in GemmParallelModes
        ), f"parallel_mode {parallel_mode} not supported"

        if self.parallel_mode == "column":
            self.out_features = divide(self.out_features, self.tp_size)
        elif self.parallel_mode == "row":
            self.in_features = divide(self.in_features, self.tp_size)

        self.sequence_parallel = (self.tp_size > 1) and sequence_parallel

        # Initialize params in FP8
        with_fp8_params = FP8GlobalStateManager.with_fp8_parameters()

        # Contiguous buffers for params
        weight_tensor = torch.empty(
            self.out_features,
            self.in_features,
            device=device,
            dtype=params_dtype,
        )
        bias_tensor = None
        if self.use_bias:
            bias_tensor = torch.empty(
                self.out_features,
                device=device,
                dtype=params_dtype,
            )

        # Configure parameter splits
        self.weight_names = []
        self.bias_names = []
        self.parameter_split_sizes = []
        if parameters_split is None:
            # Split into a single parameter by default
            self.weight_names = ["weight"]
            self.bias_names = ["bias"]
            self.parameter_split_sizes = [out_features]
        elif not parameters_split:
            raise ValueError("Cannot split weight buffer into 0 parameters")
        elif isinstance(parameters_split, dict):
            # Split parameters with provided sizes
            for name, split_size in parameters_split.items():
                self.weight_names.append(f"{name.rstrip('_')}_weight")
                self.bias_names.append(f"{name.rstrip('_')}_bias")
                self.parameter_split_sizes.append(split_size)
        elif all(isinstance(name, str) for name in parameters_split):
            # Split parameters evenly
            split_size = out_features // len(parameters_split)
            for name in parameters_split:
                self.weight_names.append(f"{name.rstrip('_')}_weight")
                self.bias_names.append(f"{name.rstrip('_')}_bias")
                self.parameter_split_sizes.append(split_size)
        else:
            raise TypeError("Invalid configuration for parameters split")

        # Make sure parameter splits are valid
        if sum(self.parameter_split_sizes) != out_features:
            raise ValueError(
                f"Trying to split weight buffer ({out_features=}) "
                f"with split sizes {self.parameter_split_sizes}"
            )

        # Adjust parameter splits for tensor-parallel distribution
        if self.parallel_mode == "column":
            for i, size in enumerate(self.parameter_split_sizes):
                if size % self.tp_size != 0:
                    raise RuntimeError(
                        f"Attempting to distribute a parameter with out_features={size} "
                        f"between {self.tp_size} tensor-parallel processes"
                    )
                self.parameter_split_sizes[i] = size // self.tp_size

        # Construct weight parameters
        # Note: Register weights together so that they are adjacent to
        # each other in Linear.parameters(). This makes it more likely
        # that they will stay contiguous if the weights are
        # manipulated externally, e.g. by FSDP.
        offset = 0
        for i, split_size in enumerate(self.parameter_split_sizes):
            split_start = offset
            offset += split_size
            split_end = offset

            # Check if parameters are subviews of buffers
            is_subview = (split_start, split_end) != (0, self.out_features)
            if is_subview and with_fp8_params:
                raise RuntimeError("Splitting Float8Tensor into multiple params is not supported")

            # Construct weight parameter
            self.register_parameter(
                self.weight_names[i],
                torch.nn.Parameter(weight_tensor[split_start:split_end]),
                init_fn=init_method,
                get_rng_state_tracker=get_rng_state_tracker,
                fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
            )

        # Construct bias parameters if needed
        if self.use_bias:
            offset = 0
            for i, split_size in enumerate(self.parameter_split_sizes):
                split_start = offset
                offset += split_size
                split_end = offset
                self.register_parameter(
                    self.bias_names[i],
                    torch.nn.Parameter(bias_tensor[split_start:split_end]),
                    init_fn=init_method_constant(0.0),
                )
        else:
            for name in self.bias_names:
                bias = torch.Tensor().to(dtype=params_dtype, device=device)
                setattr(self, name, bias)

        if with_fp8_params:
            self.init_fp8_metadata()

        self.reset_parameters(defer_init=(device == "meta"))

        # For RPL, bias has to be added after TP collectives
        # So it cannot be fused with the GEMM
        if self.parallel_mode == "row" and self.apply_bias:
            self.gemm_bias_unfused_add = True
        else:
            self.gemm_bias_unfused_add = False

    def reset_parameters(self, defer_init=False):
        super().reset_parameters(defer_init=defer_init)

        if not defer_init:
            # Set parallelism attributes for linear weights
            for weight in self.weight_names:
                set_tensor_model_parallel_attributes(
                    tensor=getattr(self, weight),
                    is_parallel=True,
                    dim=1 if self.parallel_mode == "row" else 0,
                    stride=1,
                )

            # Set parallelism attributes for linear biases
            if self.use_bias:
                for bias in self.bias_names:
                    if self.parallel_mode == "row":
                        setattr(getattr(self, bias), "sequence_parallel", self.sequence_parallel)
                    elif self.parallel_mode == "column":
                        set_tensor_model_parallel_attributes(getattr(self, bias), True, 0, 1)

    @no_torch_dynamo()
    def forward(
        self,
        inp: torch.Tensor,
        is_first_microbatch: Optional[bool] = None,
        fp8_output: Optional[bool] = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
        """
        Apply the linear transformation to the input.

        Parameters
        ----------
        inp : torch.Tensor
             Input tensor.
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
                             * it also allows skipping gradient accumulation during the
                               first microbatch (since it is the first gradient being
                               produced)
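
                             A typical gradient-accumulation loop looks like the following
                             (illustrative sketch; the layer and microbatch names are
                             placeholders)::

                                 out = layer(microbatch_0, is_first_microbatch=True)
                                 out = layer(microbatch_1, is_first_microbatch=False)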
        """

        skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
        if skip_fp8_weight_update is not None:
            is_first_microbatch = False

        with self.prepare_forward(
            inp,
            is_first_microbatch,
            allow_non_contiguous=isinstance(inp, QuantizedTensor),
        ) as inp:

            # Get concatenated weight and bias tensors
            unfused_weights = [getattr(self, name) for name in self.weight_names]
            if any(isinstance(w, QuantizedTensor) for w in unfused_weights):
                if self.fp8:
                    if len(unfused_weights) != 1:
                        raise RuntimeError(
                            "Splitting QuantizedTensor into multiple params is not supported"
                        )
                else:
                    unfused_weights = [w.dequantize() for w in unfused_weights]
            weight_tensor = _noop_cat(unfused_weights)
            if self.use_bias:
                bias_tensor = _noop_cat(
                    [getattr(self, name) for name in self.bias_names],
                )
            else:
                bias_tensor = getattr(self, self.bias_names[0])  # Unused

            # Initialize FP8 weights if needed
            weight_fp8 = None
            if self.fp8:
                if isinstance(weight_tensor, Float8Tensor):
                    # Make sure transpose cache is valid, if present
                    # Note: Transpose cache may have been invalidated
                    # externally, e.g. by optimizer.
                    if weight_tensor._transpose is not None:
                        weight_tensor.transpose_2d(
                            fill_cache=True,
                            noop_flag=skip_fp8_weight_update,
                        )
                else:
                    # FP8 cast to workspace buffer
                    update_workspace = is_first_microbatch is None or is_first_microbatch
                    weight_fp8 = self.get_fp8_workspace(
                        tensor=weight_tensor,
                        fp8_meta_forward=True,
                        fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
                        cache_name=(None if is_first_microbatch is None else "weight"),
                        update_workspace=update_workspace,
                        skip_update_flag=skip_fp8_weight_update,
                        fsdp_group=self.fsdp_group,
                    )

            from ..cpu_offload import CPUOffloadEnabled

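            # When autograd is active, go through _Linear.apply so the custom
            # backward is recorded; otherwise call _Linear.forward directly with
            # a placeholder ctx and skip building a graph.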
            if torch.is_grad_enabled():
                linear_fn = _Linear.apply
                args = []
            else:
                linear_fn = _Linear.forward
                args = [None]
            args += (
                weight_tensor,
                weight_fp8,
                inp,
                bias_tensor,
                self.apply_bias and not self.gemm_bias_unfused_add,
                is_first_microbatch,
                self.fp8,
                self.fp8_calibration,
                self.fp8_meta,
                self.fuse_wgrad_accumulation,
                CPUOffloadEnabled,
                self.tp_group,
                self.tp_size,
                self.sequence_parallel,
                self.tp_size > 1,
                self.activation_dtype,
                self.parallel_mode,
                torch.is_grad_enabled(),
                self.ub_overlap_rs,
                self.ub_overlap_ag,
                self.ub_name,
                fp8_output,
                self.fsdp_group,
            )
            out = linear_fn(*args)

        if self.gemm_bias_unfused_add:
            out = out + cast_if_needed(bias_tensor, self.activation_dtype)

        if self.return_bias:
            return out, cast_if_needed(bias_tensor, self.activation_dtype)
        return out