# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Linear API"""

import warnings
from typing import Union, Tuple, Dict, Any, Optional

import paddle
import paddle.nn.functional as F
from paddle.nn.initializer import Constant

from .base import (
    TransformerEngineBaseLayer,
    get_workspace,
    _2X_ACC_FPROP,
    _2X_ACC_DGRAD,
    _2X_ACC_WGRAD,
)
from ..constants import FP8FwdTensors, FP8BwdTensors, GemmParallelModes, dist_group_type
from ..cpp_extensions import gemm, fp8_gemm, cast_to_fp8, cast_transpose, transpose
from ..distributed import (
    allgather,
    allreduce,
    get_tp_group_and_world_size,
    identity,
    reduce_scatter,
    track_rng_state,
    set_tensor_dist_attr,
    set_weight_tensor_dist_attr,
    mark_as_sequence_parallel_parameter,
)
from ..fp8 import get_fp8_te_dtype
from ..utils import (
    assert_dim_for_fp8_forward_exec,
    cast_if_needed,
    cast_if_needed_inplace,
    divide,
    get_bias_dtype,
    save_for_backward_allow_none,
    saved_tensor_allow_none,
    clear_tensor_data,
)

__all__ = ["Linear"]


def _linear_fwd_fp8(
    inputmat: paddle.Tensor,
    inputmat_fp8_index: FP8FwdTensors,
    weight: paddle.Tensor,
    weight_fp8: Optional[paddle.Tensor],
    weight_t_fp8: Optional[paddle.Tensor],
    weight_fp8_index: FP8FwdTensors,
    bias: paddle.Tensor,
    use_bias: bool,
    fp8_meta: Dict[str, Any],
    activation_dtype: paddle.dtype,
    parallel_mode: Union[str, None],
    tensor_parallel: bool,
    sequence_parallel: bool,
    tp_group: Union[dist_group_type, None],
    is_grad_enabled: bool,
    is_first_microbatch: Optional[bool] = None,
):
    """FP8 path of Linear Fwd"""
    fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
    bias_dtype = get_bias_dtype(activation_dtype)
    bias = cast_if_needed(bias, bias_dtype)

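    # In column-parallel + sequence-parallel mode, the input arrives sharded along
    # the sequence dimension, so gather the full activation before the GEMM.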
    if parallel_mode == "column" and sequence_parallel:
        inputmat_total, _ = allgather(inputmat, tp_group)
    else:
        inputmat_total = inputmat

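    # Refresh the FP8 weight copies on the first microbatch, or on every call when
    # no microbatch information is given; later microbatches reuse the cached casts.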
    update_fp8_weights = is_first_microbatch is None or is_first_microbatch
    if is_grad_enabled:
        if update_fp8_weights:
            weight_fp8, weight_t_fp8 = cast_transpose(
                weight,
                fp8_meta["scaling_fwd"],
                weight_fp8_index,
                fp8_dtype_forward,
                cast_out=weight_fp8,
                transpose_out=weight_t_fp8,
            )
    else:
        weight_t_fp8 = None
        if update_fp8_weights:
            weight_fp8 = cast_to_fp8(
                weight,
                fp8_meta["scaling_fwd"],
                weight_fp8_index,
                fp8_dtype_forward,
                out=weight_fp8,
            )

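    # The FP8 GEMM below computes, in effect (a sketch; amax/scale bookkeeping elided):
    #   out ~= (inputmat_total * scale_inv[input]) @ (weight_fp8 * scale_inv[weight]).T + bias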
    out, _ = fp8_gemm(
        weight_fp8,
        fp8_meta["scaling_fwd"].scale_inv,
        weight_fp8_index,
        fp8_dtype_forward,
        inputmat_total,
        fp8_meta["scaling_fwd"].scale_inv,
        inputmat_fp8_index,
        fp8_dtype_forward,
        activation_dtype,
        get_workspace(),
        bias=bias,
        use_bias=use_bias,
        use_split_accumulator=_2X_ACC_FPROP,
    )

    if parallel_mode == "row" and sequence_parallel:
        out, _ = reduce_scatter(out, tp_group)
    elif parallel_mode == "row" and tensor_parallel:
        out, _ = allreduce(out, tp_group)

    return out, weight_t_fp8


def _linear_fwd_non_fp8(
    inputmat: paddle.Tensor,
    inputmat_fp8_index: FP8FwdTensors,
    weight: paddle.Tensor,
    weight_fp8_index: FP8FwdTensors,
    bias: paddle.Tensor,
    use_bias: bool,
    fp8_calibration: bool,
    fp8_meta: Dict[str, Any],
    activation_dtype: paddle.dtype,
    parallel_mode: Union[str, None],
    tensor_parallel: bool,
    sequence_parallel: bool,
    tp_group: Union[dist_group_type, None],
    activation: str = "",
):
    """Non-FP8 path of Linear Fwd"""

    if parallel_mode == "column" and sequence_parallel:
        inputmat_total, _ = allgather(inputmat, tp_group)
    else:
        inputmat_total = inputmat

    # Layer parameters are initialized as float32 dtype by default.
    # Cast the parameters to activation_dtype if the current dtype
    # does not match activation_dtype. The casting is inplace, so it
    # only needs to be performed once throughout the training process.
    weight = cast_if_needed_inplace(weight, activation_dtype)
    bias = cast_if_needed_inplace(bias, activation_dtype)

    if fp8_calibration:
        # amax of input
        fp8_meta["scaling_fwd"].amax_history[0, inputmat_fp8_index.value] = paddle.max(
            paddle.abs(inputmat_total)
        ).item()
        # amax of weight
        fp8_meta["scaling_fwd"].amax_history[0, weight_fp8_index.value] = paddle.max(
            paddle.abs(weight)
        ).item()
        fp8_meta["update_amax_and_scale_fwd"] = True

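    # Single GEMM with optionally fused bias; with activation="gelu" the call also
    # returns the intermediate tensor needed for the fused GELU backward.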
    outputs = gemm(
        weight,
        inputmat_total,
        activation_dtype,
        get_workspace(),
        bias=bias,
        use_bias=use_bias,
        gelu=(activation == "gelu"),
    )

    if activation == "gelu":
        gelu_out, _, out = outputs
        return out, gelu_out

    out, _, _ = outputs

    if parallel_mode == "row" and sequence_parallel:
        out, _ = reduce_scatter(out, tp_group)
    elif parallel_mode == "row" and tensor_parallel:
        out, _ = allreduce(out, tp_group)
    return out


def _linear_fwd(
    inputmat: paddle.Tensor,
    inputmat_fp8_index: FP8FwdTensors,
    weight: paddle.Tensor,
    weight_fp8: Optional[paddle.Tensor],
    weight_t_fp8: Optional[paddle.Tensor],
    weight_fp8_index: FP8FwdTensors,
    bias: paddle.Tensor,
    use_bias: bool,
    fp8_enabled: bool,
    fp8_calibration: bool,
    fp8_meta: Dict[str, Any],
    activation_dtype: paddle.dtype,
    parallel_mode: Union[str, None],
    tensor_parallel: bool,
    sequence_parallel: bool,
    tp_group: Union[dist_group_type, None],
    is_grad_enabled: bool,
    is_first_microbatch: Optional[bool] = None,
):
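    """Linear forward, dispatching to the FP8 or non-FP8 implementation"""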
    if fp8_enabled:
        out, weight_t_fp8 = _linear_fwd_fp8(
            inputmat,
            inputmat_fp8_index,
            weight,
            weight_fp8,
            weight_t_fp8,
            weight_fp8_index,
            bias,
            use_bias,
            fp8_meta,
            activation_dtype,
            parallel_mode,
            tensor_parallel,
            sequence_parallel,
            tp_group,
            is_grad_enabled,
            is_first_microbatch,
        )
    else:
        out = _linear_fwd_non_fp8(
            inputmat,
            inputmat_fp8_index,
            weight,
            weight_fp8_index,
            bias,
            use_bias,
            fp8_calibration,
            fp8_meta,
            activation_dtype,
            parallel_mode,
            tensor_parallel,
            sequence_parallel,
            tp_group,
        )
    return (
        out,
        weight_t_fp8 if fp8_enabled else None,
    )


def _linear_bwd_fp8(
    inputmat: paddle.Tensor,
    inputmat_t: paddle.Tensor,
    inputmat_fp8_index: FP8FwdTensors,
    weight: paddle.Tensor,
    weight_t_fp8: paddle.Tensor,
    weight_fp8_index: FP8FwdTensors,
    grad_output: paddle.Tensor,
    grad_output_c: paddle.Tensor,
    grad_output_t: paddle.Tensor,
    grad_output_fp8_index: FP8BwdTensors,
    fwd_scale_inverses: paddle.Tensor,
    fp8_meta: Dict[str, Any],
    requires_dgrad: bool,
    requires_wgrad: bool,
    activation_dtype: paddle.dtype,
    parallel_mode: Union[str, None],
    tensor_parallel: bool,
    sequence_parallel: bool,
    tp_group: Union[dist_group_type, None],
    fuse_wgrad_accumulation: bool,
    accumulate_wgrad_into_param_main_grad: bool,
):
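    """FP8 path of Linear Bwd"""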
    dgrad, wgrad, handle = None, None, None

    # Overlap input AG with dgrad
    inputmat_total = None
    inputmat_t_total = None
    if requires_wgrad and parallel_mode == "column" and sequence_parallel:
        inputmat_total, handle = allgather(inputmat, tp_group, sync_op=not requires_dgrad)
    else:
        inputmat_total = inputmat
        inputmat_t_total = inputmat_t

    fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
    fp8_dtype_backward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=False)
    if requires_dgrad:
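        # dgrad = grad_output @ weight, using the cached FP8 weight transpose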
        dgrad, _ = fp8_gemm(
            weight_t_fp8,
            fwd_scale_inverses,
            weight_fp8_index,
            fp8_dtype_forward,
            grad_output_c,
            fp8_meta["scaling_bwd"].scale_inv,
            grad_output_fp8_index,
            fp8_dtype_backward,
            activation_dtype,
            get_workspace(),
            use_split_accumulator=_2X_ACC_DGRAD,
        )
        clear_tensor_data(grad_output_c)

        # Overlap dgrad-RS/AR with wgrad
        if parallel_mode == "column" and sequence_parallel:
            if handle is not None:
                handle.wait()
            dgrad, handle = reduce_scatter(dgrad, tp_group, sync_op=False)
        elif parallel_mode == "column" and tensor_parallel:
            dgrad, handle = allreduce(dgrad, tp_group, sync_op=False)

    if requires_wgrad:
        if not fp8_meta["recipe"].override_linear_precision.wgrad:
            if inputmat_t_total is None:
                inputmat_t_total = transpose(inputmat_total, fp8_dtype_backward)
                clear_tensor_data(inputmat_total)

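            # wgrad = grad_output.T @ inputmat, with both operands in FP8;
            # optionally accumulated directly into weight.main_grad.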
            wgrad, _ = fp8_gemm(
                inputmat_t_total,
                fwd_scale_inverses,
                inputmat_fp8_index,
                fp8_dtype_forward,
                grad_output_t,
                fp8_meta["scaling_bwd"].scale_inv,
                grad_output_fp8_index,
                fp8_dtype_backward,
                "float32" if fuse_wgrad_accumulation else activation_dtype,
                get_workspace(),
                accumulate=accumulate_wgrad_into_param_main_grad,
                out=weight.main_grad if fuse_wgrad_accumulation else None,
                use_split_accumulator=_2X_ACC_WGRAD,
            )
            clear_tensor_data(inputmat_t_total, grad_output_t)
        else:
            wgrad, _, _ = gemm(
                inputmat_total,
                grad_output,
                activation_dtype,
                get_workspace(),
                grad=True,
                accumulate=accumulate_wgrad_into_param_main_grad,
                layout="NT",
                out=weight.main_grad if fuse_wgrad_accumulation else None,
                out_dtype="float32" if fuse_wgrad_accumulation else None,
            )
            clear_tensor_data(inputmat_total)

        if fuse_wgrad_accumulation:
            weight.main_grad = wgrad

    if parallel_mode == "column" and tensor_parallel and handle is not None:
        handle.wait()

    return dgrad, wgrad


def _linear_bwd_non_fp8(
    inputmat: paddle.Tensor,
    weight: paddle.Tensor,
    grad_output: paddle.Tensor,
    requires_bgrad: bool,
    requires_dgrad: bool,
    requires_wgrad: bool,
    activation_dtype: paddle.dtype,
    parallel_mode: Union[str, None],
    tensor_parallel: bool,
    sequence_parallel: bool,
    tp_group: Union[dist_group_type, None],
    fuse_wgrad_accumulation: bool,
    accumulate_wgrad_into_param_main_grad: bool,
    gelu_input: Union[paddle.Tensor, None] = None,
    activation: str = "",
):
    """
    Performs Linear Backward. Optionally, fuses GELU backward and dbias.
    """
    dgrad, wgrad, bgrad, handle = None, None, None, None

    # Overlap input AG with dgrad
    inputmat_total = None
    if requires_wgrad and parallel_mode == "column" and sequence_parallel:
        inputmat_total, handle = allgather(inputmat, tp_group, sync_op=not requires_dgrad)
    else:
        inputmat_total = inputmat

    if requires_dgrad:
        dgrad, _, _ = gemm(
            weight,
            grad_output,
            activation_dtype,
            get_workspace(),
            layout="NN",
            gelu=(activation == "gelu"),
            gelu_input=gelu_input,
            grad=True,
        )
        # Overlap dgrad-RS/AR with wgrad
        if parallel_mode == "column" and sequence_parallel:
            if handle is not None:
                handle.wait()
            dgrad, handle = reduce_scatter(dgrad, tp_group, sync_op=False)
        elif parallel_mode == "column" and tensor_parallel:
            dgrad, handle = allreduce(dgrad, tp_group, sync_op=False)

    if requires_wgrad:
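        # wgrad = grad_output.T @ inputmat ("NT" layout); bgrad is fused into the
        # same GEMM when required.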
        wgrad, bgrad, _ = gemm(
            inputmat_total,
            grad_output,
            activation_dtype,
            get_workspace(),
            grad=True,
            accumulate=accumulate_wgrad_into_param_main_grad,
            layout="NT",
            out=weight.main_grad if fuse_wgrad_accumulation else None,
            out_dtype="float32" if fuse_wgrad_accumulation else None,
            use_bias=requires_bgrad,
        )
        if fuse_wgrad_accumulation:
            weight.main_grad = wgrad

    elif requires_bgrad:
        bgrad = grad_output.sum(axis=0)

    if parallel_mode == "column" and tensor_parallel and handle is not None:
        handle.wait()

    return dgrad, wgrad, bgrad


def _linear_bwd(
    inputmat: paddle.Tensor,
    inputmat_t: paddle.Tensor,
    inputmat_fp8_index: FP8FwdTensors,
    weight: paddle.Tensor,
    weight_t_fp8: paddle.Tensor,
    weight_fp8_index: FP8FwdTensors,
    grad_output: paddle.Tensor,
    grad_output_c: paddle.Tensor,
    grad_output_t: paddle.Tensor,
    grad_output_fp8_index: FP8BwdTensors,
    fwd_scale_inverses: paddle.Tensor,
    requires_bgrad: bool,
    fp8_enabled: bool,
    fp8_meta: Dict[str, Any],
    requires_dgrad: bool,
    requires_wgrad: bool,
    activation_dtype: paddle.dtype,
    parallel_mode: Union[str, None],
    tensor_parallel: bool,
    sequence_parallel: bool,
    tp_group: Union[dist_group_type, None],
    fuse_wgrad_accumulation: bool,
    accumulate_wgrad_into_param_main_grad: bool,
):
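    """Linear backward, dispatching to the FP8 or non-FP8 implementation"""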
    dgrad, wgrad, bgrad = None, None, None
    if fp8_enabled:
        dgrad, wgrad = _linear_bwd_fp8(
            inputmat,
            inputmat_t,
            inputmat_fp8_index,
            weight,
            weight_t_fp8,
            weight_fp8_index,
            grad_output,
            grad_output_c,
            grad_output_t,
            grad_output_fp8_index,
            fwd_scale_inverses,
            fp8_meta,
            requires_dgrad,
            requires_wgrad,
            activation_dtype,
            parallel_mode,
            tensor_parallel,
            sequence_parallel,
            tp_group,
            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
            accumulate_wgrad_into_param_main_grad=accumulate_wgrad_into_param_main_grad,
        )
    else:
        dgrad, wgrad, bgrad = _linear_bwd_non_fp8(
            inputmat,
            weight,
            grad_output,
            requires_bgrad,
            requires_dgrad,
            requires_wgrad,
            activation_dtype,
            parallel_mode,
            tensor_parallel,
            sequence_parallel,
            tp_group,
            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
            accumulate_wgrad_into_param_main_grad=accumulate_wgrad_into_param_main_grad,
        )
    return dgrad, wgrad, bgrad


class _Linear(paddle.autograd.PyLayer):
    """TE implementation of Linear"""

    @staticmethod
    def forward(
        ctx,
        weight: paddle.Tensor,
        weight_fp8: Optional[paddle.Tensor],
        weight_t_fp8: Optional[paddle.Tensor],
        inp: paddle.Tensor,
        bias: paddle.Tensor,
        use_bias: bool,
        fp8_enabled: bool,
        fp8_calibration: bool,
        fp8_meta: Dict[str, Any],
        activation_dtype: paddle.dtype,
        is_grad_enabled: bool,
        parallel_mode: Union[str, None],
        tensor_parallel: bool,
        sequence_parallel: bool,
        tp_group: Union[dist_group_type, None],
        tp_size: int,
        fuse_wgrad_accumulation: bool,
        is_first_microbatch: Optional[bool],
    ) -> paddle.Tensor:
        # Make sure input dimensions are compatible
        in_features = weight.shape[-1]
        assert inp.shape[-1] == in_features, "GEMM not possible"
        inputmat = inp.reshape((-1, in_features))
        if fp8_enabled:
            assert_dim_for_fp8_forward_exec(inputmat)
            assert_dim_for_fp8_forward_exec(weight)

        inputmat_no_fp8 = inputmat

        # FP8 casting
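        # When the FP8 wgrad GEMM will need the input transpose and no gather
        # follows, cast and transpose the input in one fused kernel; otherwise
        # only cast.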
        inputmat_t = None
        if fp8_enabled:
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
            if (
                not fp8_meta["recipe"].override_linear_precision.wgrad
                and is_grad_enabled
                and not sequence_parallel
            ):
                inputmat, inputmat_t = cast_transpose(
                    inputmat,
                    fp8_meta["scaling_fwd"],
                    FP8FwdTensors.GEMM1_INPUT,
                    fp8_dtype_forward,
                )
            else:
                inputmat = cast_to_fp8(
                    inputmat,
                    fp8_meta["scaling_fwd"],
                    FP8FwdTensors.GEMM1_INPUT,
                    fp8_dtype_forward,
                )

        # GEMM Fwd
        out, weight_t_fp8 = _linear_fwd(
            inputmat,
            FP8FwdTensors.GEMM1_INPUT,
            weight,
            weight_fp8,
            weight_t_fp8,
            FP8FwdTensors.GEMM1_WEIGHT,
            bias,
            use_bias,
            fp8_enabled,
            fp8_calibration,
            fp8_meta,
            activation_dtype,
            parallel_mode,
            tensor_parallel,
            sequence_parallel,
            tp_group,
            is_grad_enabled,
            is_first_microbatch,
        )

        if is_grad_enabled:
            saved_inputmat = None
            if fp8_enabled and sequence_parallel:
                saved_inputmat = inputmat
            else:
                saved_inputmat = inputmat_no_fp8
            save_for_backward_allow_none(
                ctx,
                saved_inputmat,
                inputmat_t,
                weight,
                weight_t_fp8 if fp8_enabled else None,
                fp8_meta["scaling_fwd"].scale_inv.clone() if fp8_enabled else None,
            )
            ctx.activation_dtype = activation_dtype
            ctx.fp8_enabled = fp8_enabled
            ctx.fp8_meta = fp8_meta
            ctx.use_bias = use_bias
            ctx.inp_shape = inp.shape
            ctx.parallel_mode = parallel_mode
            ctx.tensor_parallel = tensor_parallel
            ctx.sequence_parallel = sequence_parallel
            ctx.tp_group = tp_group
            ctx.tp_size = tp_size
            ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
            ctx.requires_dgrad = not inp.stop_gradient
            ctx.requires_wgrad = not weight.stop_gradient
            ctx.requires_bgrad = use_bias and not bias.stop_gradient
            ctx.is_first_microbatch = is_first_microbatch

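        # Restore the original leading dimensions; the input was flattened to 2-D
        # for the GEMM.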
        return out.reshape((-1, *inp.shape[1:-1], out.shape[-1]))

    @staticmethod
    def backward(ctx, grad_output: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]:
        with TransformerEngineBaseLayer.prepare_backward(
            ctx.fp8_enabled, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_Linear"
        ):

            (  # pylint: disable=unbalanced-tuple-unpacking
                inputmat,
                inputmat_t,
                weight,
                weight_t_fp8,
                fwd_scale_inverses,
            ) = saved_tensor_allow_none(ctx)

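            # Preprocess grad_output: flatten to 2-D and, on the FP8 path, produce
            # the casted copy and transpose used by dgrad/wgrad.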
            (
                grad_output,
                grad_output_c,
                grad_output_t,
                bgrad,
            ) = TransformerEngineBaseLayer.grad_output_preprocess(
                ctx, grad_output, ctx.parallel_mode == "row"
            )
            if ctx.is_first_microbatch is not None:
                accumulate_wgrad_into_param_main_grad = (
                    ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
                )
            else:
                accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation

            dgrad, wgrad, bgrad_ = _linear_bwd(
                inputmat,
                inputmat_t,
                FP8FwdTensors.GEMM1_INPUT,
                weight,
                weight_t_fp8,
                FP8FwdTensors.GEMM1_WEIGHT,
                grad_output,
                grad_output_c,
                grad_output_t,
                FP8BwdTensors.GRAD_OUTPUT1,
                fwd_scale_inverses,
                ctx.requires_bgrad,
                ctx.fp8_enabled,
                ctx.fp8_meta,
                ctx.requires_dgrad,
                ctx.requires_wgrad,
                ctx.activation_dtype,
                ctx.parallel_mode,
                ctx.tensor_parallel,
                ctx.sequence_parallel,
                ctx.tp_group,
                ctx.fuse_wgrad_accumulation,
                accumulate_wgrad_into_param_main_grad,
            )

            if not ctx.fp8_enabled:
                # bgrad is fused with gemm for non-FP8 path
                bgrad = bgrad_

            if not ctx.fp8_enabled or ctx.is_first_microbatch is None:
                weight_cache_grad = ()
            else:
                # weight_fp8 and weight_t_fp8 are stop_gradient tensors
                weight_cache_grad = (None, None)

            dgrad_return = dgrad.reshape(ctx.inp_shape) if ctx.requires_dgrad else None
            if not ctx.use_bias:
                bgrad_return = ()
            elif ctx.requires_bgrad:
                bgrad_return = (bgrad,)
            else:
                bgrad_return = (None,)

        if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
            wgrad = None

        return (
            wgrad if ctx.requires_wgrad else None,
            *weight_cache_grad,
            dgrad_return,
            *bgrad_return,
        )


class Linear(TransformerEngineBaseLayer):
    """
    Applies a linear transformation to the incoming data :math:`y = xA^T + b`

    Parameters
    ----------
    in_features : int
                 size of each input sample.
    out_features : int
                  size of each output sample.
    weight_attr: Union[paddle.ParamAttr, None], default = None
                optional `paddle.ParamAttr` for weight.
    bias_attr: Union[paddle.ParamAttr, None, bool], default = None
              optional `paddle.ParamAttr` for bias.
    backend: {'transformer_engine', 'paddle'}, default = 'transformer_engine'
             if set to 'paddle', a framework-only, non-FP8 path is executed
             with limited optimization.

    Parallelism parameters
    ----------------------
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    parallel_mode : {None, 'column', 'row'}, default = `None`
                   used to decide whether this Linear layer is Column Parallel Linear or Row
                   Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
                   When set to `None`, no communication is performed.
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = `False`
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient. When enabled, it is assumed that the weights
                             have an additional `main_grad` attribute (used instead of the
                             regular `grad`) which is a pre-allocated buffer of the correct
                             size to accumulate gradients in.
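
    Example
    -------
    A minimal usage sketch (illustrative only; assumes a single process with no
    tensor parallelism and the `transformer_engine.paddle` import path):

    .. code-block:: python

        import paddle
        import transformer_engine.paddle as te

        layer = te.Linear(in_features=768, out_features=3072)
        inp = paddle.randn([16, 768])
        out = layer(inp)  # shape [16, 3072]
        out.mean().backward()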

    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        weight_attr: Union[paddle.ParamAttr, None] = None,
        bias_attr: Union[paddle.ParamAttr, None, bool] = None,
        parallel_mode: Optional[str] = None,
        sequence_parallel: bool = False,
        tp_group: Union[dist_group_type, None] = None,
        fuse_wgrad_accumulation: bool = False,
        backend: str = "transformer_engine",
    ) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.backend = backend
        self._weight_attr = weight_attr
        self._bias_attr = bias_attr
        self._dtype = self._helper.get_default_dtype()

        # Set parallel configs
        self.tp_group, self.tp_size = get_tp_group_and_world_size(
            tp_group, enable_tp=parallel_mode is not None
        )
        self.tensor_parallel = self.tp_size > 1
        self.parallel_mode = parallel_mode
        assert (
            self.parallel_mode in GemmParallelModes
        ), f"parallel_mode {parallel_mode} not supported"

        if self.parallel_mode == "column":
            self.out_features = divide(self.out_features, self.tp_size)
        elif self.parallel_mode == "row":
            self.in_features = divide(self.in_features, self.tp_size)
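        # Example: with tp_size=4, a column-parallel layer with out_features=4096
        # holds a [1024, in_features] weight shard on each rank.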

        self.sequence_parallel = self.tensor_parallel and sequence_parallel

        self.fuse_wgrad_accumulation = fuse_wgrad_accumulation

        # Initialize weight parameter
        with track_rng_state(enable=self.tensor_parallel):
            # TE linear weight is in column major
            self.weight = self.create_parameter(
                shape=(
                    [self.out_features, self.in_features]
                    if self.backend == "transformer_engine"
                    else [self.in_features, self.out_features]
                ),
                attr=self._weight_attr,
                dtype=self._dtype,
                is_bias=False,
            )
        set_weight_tensor_dist_attr(
            self.weight, self.tensor_parallel, self.parallel_mode, self.backend
        )

        # Initialize bias parameter
        self.has_bias = self._bias_attr is not False
        use_default_bias = self._bias_attr is None or self._bias_attr is True
        if self.has_bias:
            self.bias = self.create_parameter(
                shape=[self.out_features],
                attr=(
                    self._bias_attr
                    if not use_default_bias
                    else paddle.ParamAttr(initializer=Constant(value=0.0))
                ),
                dtype=self._dtype,
                is_bias=True,
            )
            if parallel_mode == "column":
                set_tensor_dist_attr(self.bias, self.tensor_parallel, axis=0)
            if parallel_mode == "row" and self.sequence_parallel:
                mark_as_sequence_parallel_parameter(self.bias)
        else:
            self.bias = None

        self.fp8_weight_shapes.append(self.weight.shape)

        # For row-parallel linear (RPL), the bias has to be added after the TP
        # collectives, so it cannot be fused with the GEMM.
        if self.parallel_mode == "row" and self.tensor_parallel and self.has_bias:
            self.gemm_bias_fused_add = False
        else:
            self.gemm_bias_fused_add = True

    def _te_forward(
        self,
        inp: paddle.Tensor,
        is_first_microbatch: Optional[bool] = None,
    ) -> paddle.Tensor:
        """
        Apply the linear transformation to the input.
        """
        with self.prepare_forward(inp, is_first_microbatch=is_first_microbatch) as inp:
            # Layer input should be cast outside the PyLayer, as performing an
            # inplace cast on input tensors may cause problems when used
            # together with Paddle native layers.
            inp = cast_if_needed(inp, self.activation_dtype)

            # Get persistent fp8 weight buffer. None if buffer does not exist.
            weight_fp8, weight_t_fp8 = self.get_fp8_weights_scratchpad(is_first_microbatch)

            out = _Linear.apply(
                self.weight,
                weight_fp8,
                weight_t_fp8,
                inp,
                self.bias if self.gemm_bias_fused_add else None,
                self.has_bias and self.gemm_bias_fused_add,
                self.fp8_enabled,
                self.fp8_calibration,
                self.fp8_meta,
                self.activation_dtype,
                paddle.is_grad_enabled(),
                self.parallel_mode,
                self.tensor_parallel,
                self.sequence_parallel,
                self.tp_group,
                self.tp_size,
                self.fuse_wgrad_accumulation,
                is_first_microbatch,
            )

        if not self.gemm_bias_fused_add:
            out = out + cast_if_needed_inplace(self.bias, self.activation_dtype)

        return out

    def _pd_forward(
        self,
        inp: paddle.Tensor,
        is_first_microbatch: Optional[bool] = None,
    ) -> paddle.Tensor:
        """Calls Paddle OP"""
        if is_first_microbatch is not None:
            warnings.warn(
                "`is_first_microbatch` is not supported for paddle backend and is ignored."
            )
        if self.parallel_mode == "column" and self.tensor_parallel:
            inp = identity(inp, self.tp_group)
        out = F.linear(inp, self.weight, self.bias if self.gemm_bias_fused_add else None)
        if self.parallel_mode == "row" and self.tensor_parallel:
            out, _ = allreduce(out, self.tp_group)
            out = (out + self.bias) if self.bias is not None else out
        return out

    def forward(self, *args, **kwargs):
        """
        Apply the linear transformation to the input.

        Parameters
        ----------
        inp : paddle.Tensor
             Input tensor.
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism, a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
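
        A sketch of the calling pattern under gradient accumulation (illustrative
        only; `layer` and `microbatches` are placeholders):

        .. code-block:: python

            for i, microbatch in enumerate(microbatches):
                out = layer(microbatch, is_first_microbatch=(i == 0))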
        """
        if self.backend == "transformer_engine":
            return self._te_forward(*args, **kwargs)
        if self.backend == "paddle":
            return self._pd_forward(*args, **kwargs)
        raise AttributeError(f"Backend {self.backend} is not supported.")