# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Linear API"""

import warnings
from typing import Union, Tuple, Dict, Any, Optional

import paddle
import paddle.nn.functional as F
from paddle.nn.initializer import Constant

from .base import (
    TransformerEngineBaseLayer,
    get_workspace,
    _2X_ACC_FPROP,
    _2X_ACC_DGRAD,
    _2X_ACC_WGRAD,
)

from ..constants import FP8FwdTensors, FP8BwdTensors, GemmParallelModes, dist_group_type
from ..cpp_extensions import gemm, fp8_gemm, cast_to_fp8, cast_transpose, transpose
from ..distributed import (
    allgather,
    allreduce,
    get_tp_group_and_world_size,
    identity,
    reduce_scatter,
    track_rng_state,
    set_tensor_dist_attr,
    set_weight_tensor_dist_attr,
    mark_as_sequence_parallel_parameter,
)
from ..fp8 import get_fp8_te_dtype
from ..utils import (
    assert_dim_for_fp8_forward_exec,
    cast_if_needed,
    cast_if_needed_inplace,
    divide,
    get_bias_dtype,
    save_for_backward_allow_none,
    saved_tensor_allow_none,
    clear_tensor_data,
)

__all__ = ["Linear"]


def _linear_fwd_fp8(
    inputmat: paddle.Tensor,
    inputmat_fp8_index: FP8FwdTensors,
    weight: paddle.Tensor,
    weight_fp8: Optional[paddle.Tensor],
    weight_t_fp8: Optional[paddle.Tensor],
    weight_fp8_index: FP8FwdTensors,
    bias: paddle.Tensor,
    use_bias: bool,
    fp8_meta: Dict[str, Any],
    activation_dtype: paddle.dtype,
    parallel_mode: Union[str, None],
    tensor_parallel: bool,
    sequence_parallel: bool,
    tp_group: Union[dist_group_type, None],
    is_grad_enabled: bool,
    is_first_microbatch: Optional[bool] = None,
):
    """FP8 path of Linear Fwd"""
    fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
    bias_dtype = get_bias_dtype(activation_dtype)
    bias = cast_if_needed(bias, bias_dtype)

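    # Under column parallelism with sequence parallelism, the input arrives
    # sharded along the sequence dimension, so gather the full activation
    # before the GEMM.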
    if parallel_mode == "column" and sequence_parallel:
        inputmat_total, _ = allgather(inputmat, tp_group)
    else:
        inputmat_total = inputmat

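    # Re-cast the FP8 weights on the first microbatch; when is_first_microbatch
    # is None, caching is disabled and the cast happens on every call.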
    update_fp8_weights = is_first_microbatch is None or is_first_microbatch
    if is_grad_enabled:
        if update_fp8_weights:
            weight_fp8, weight_t_fp8 = cast_transpose(
                weight,
                fp8_meta["scaling_fwd"],
                weight_fp8_index,
                fp8_dtype_forward,
                cast_out=weight_fp8,
                transpose_out=weight_t_fp8,
            )
    else:
        weight_t_fp8 = None
        if update_fp8_weights:
            weight_fp8 = cast_to_fp8(
                weight,
                fp8_meta["scaling_fwd"],
                weight_fp8_index,
                fp8_dtype_forward,
                out=weight_fp8,
            )
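
    # FP8 GEMM: out = inputmat_total @ weight^T, computed from the FP8
    # operands and returned in activation_dtype; bias addition is fused in
    # when use_bias is set.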
    out, _ = fp8_gemm(
        weight_fp8,
        fp8_meta["scaling_fwd"].scale_inv,
        weight_fp8_index,
        fp8_dtype_forward,
        inputmat_total,
        fp8_meta["scaling_fwd"].scale_inv,
        inputmat_fp8_index,
        fp8_dtype_forward,
        activation_dtype,
        get_workspace(),
        bias=bias,
        use_bias=use_bias,
        use_split_accumulator=_2X_ACC_FPROP,
    )

    if parallel_mode == "row" and sequence_parallel:
        out, _ = reduce_scatter(out, tp_group)
    elif parallel_mode == "row" and tensor_parallel:
        out, _ = allreduce(out, tp_group)

    return out, weight_t_fp8


def _linear_fwd_non_fp8(
    inputmat: paddle.Tensor,
    inputmat_fp8_index: FP8FwdTensors,
    weight: paddle.Tensor,
    weight_fp8_index: FP8FwdTensors,
    bias: paddle.Tensor,
    use_bias: bool,
    fp8_calibration: bool,
    fp8_meta: Dict[str, Any],
    activation_dtype: paddle.dtype,
    parallel_mode: Union[str, None],
    tensor_parallel: bool,
    sequence_parallel: bool,
    tp_group: Union[dist_group_type, None],
    activation: str = "",
):
    """Non-FP8 path of Linear Fwd"""

    if parallel_mode == "column" and sequence_parallel:
        inputmat_total, _ = allgather(inputmat, tp_group)
    else:
        inputmat_total = inputmat

    # Layer parameters are initialized as float32 dtype by default.
    # Cast the parameters to activation_dtype if the current dtype
    # does not match activation_dtype. The casting is inplace, so it
    # only needs to be performed once throughout the training process.
    weight = cast_if_needed_inplace(weight, activation_dtype)
    bias = cast_if_needed_inplace(bias, activation_dtype)

    if fp8_calibration:
        # amax of input
        fp8_meta["scaling_fwd"].amax_history[0, inputmat_fp8_index.value] = \
            paddle.max(paddle.abs(inputmat_total)).item()
        # amax of weight
        fp8_meta["scaling_fwd"].amax_history[0, weight_fp8_index.value] = \
            paddle.max(paddle.abs(weight)).item()
        fp8_meta["update_amax_and_scale_fwd"] = True
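        # The recorded amax statistics are later converted into FP8 scaling
        # factors, so calibration can gather statistics while the layer keeps
        # running in higher precision.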

    outputs = gemm(weight,
                   inputmat_total,
                   activation_dtype,
                   get_workspace(),
                   bias=bias,
                   use_bias=use_bias,
                   gelu=(activation == 'gelu'))

    if activation == 'gelu':
        gelu_out, _, out = outputs
        return out, gelu_out

    out, _, _ = outputs

    if parallel_mode == "row" and sequence_parallel:
        out, _ = reduce_scatter(out, tp_group)
    elif parallel_mode == "row" and tensor_parallel:
        out, _ = allreduce(out, tp_group)
    return out


def _linear_fwd(
    inputmat: paddle.Tensor,
    inputmat_fp8_index: FP8FwdTensors,
    weight: paddle.Tensor,
    weight_fp8: Optional[paddle.Tensor],
    weight_t_fp8: Optional[paddle.Tensor],
    weight_fp8_index: FP8FwdTensors,
    bias: paddle.Tensor,
    use_bias: bool,
    fp8_enabled: bool,
    fp8_calibration: bool,
    fp8_meta: Dict[str, Any],
    activation_dtype: paddle.dtype,
    parallel_mode: Union[str, None],
    tensor_parallel: bool,
    sequence_parallel: bool,
    tp_group: Union[dist_group_type, None],
    is_grad_enabled: bool,
    is_first_microbatch: Optional[bool] = None,
):
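    """Performs Linear Fwd, dispatching to the FP8 or the non-FP8 path."""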
    if fp8_enabled:
        out, weight_t_fp8 = _linear_fwd_fp8(
            inputmat,
            inputmat_fp8_index,
            weight,
            weight_fp8,
            weight_t_fp8,
            weight_fp8_index,
            bias,
            use_bias,
            fp8_meta,
            activation_dtype,
            parallel_mode,
            tensor_parallel,
            sequence_parallel,
            tp_group,
            is_grad_enabled,
            is_first_microbatch,
        )
    else:
        out = _linear_fwd_non_fp8(
            inputmat,
            inputmat_fp8_index,
            weight,
            weight_fp8_index,
            bias,
            use_bias,
            fp8_calibration,
            fp8_meta,
            activation_dtype,
            parallel_mode,
            tensor_parallel,
            sequence_parallel,
            tp_group,
        )
    return (
        out,
        weight_t_fp8 if fp8_enabled else None,
    )


def _linear_bwd_fp8(
    inputmat: paddle.Tensor,
    inputmat_t: paddle.Tensor,
    inputmat_fp8_index: FP8FwdTensors,
    weight: paddle.Tensor,
    weight_t_fp8: paddle.Tensor,
    weight_fp8_index: FP8FwdTensors,
    grad_output: paddle.Tensor,
    grad_output_c: paddle.Tensor,
    grad_output_t: paddle.Tensor,
    grad_output_fp8_index: FP8BwdTensors,
    fwd_scale_inverses: paddle.Tensor,
    fp8_meta: Dict[str, Any],
    requires_dgrad: bool,
    requires_wgrad: bool,
    activation_dtype: paddle.dtype,
    parallel_mode: Union[str, None],
    tensor_parallel: bool,
    sequence_parallel: bool,
    tp_group: Union[dist_group_type, None],
    fuse_wgrad_accumulation: bool,
    accumulate_wgrad_into_param_main_grad: bool,
):
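    """FP8 path of Linear Bwd"""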
    dgrad, wgrad, handle = None, None, None

    # Overlap input AG with dgrad
    inputmat_total = None
    inputmat_t_total = None
    if requires_wgrad and parallel_mode == "column" and sequence_parallel:
        inputmat_total, handle = allgather(inputmat, tp_group, sync_op=not requires_dgrad)
    else:
        inputmat_total = inputmat
        inputmat_t_total = inputmat_t

    fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
    fp8_dtype_backward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=False)
    if requires_dgrad:
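        # dgrad = grad_output @ weight, using the transposed FP8 weight saved
        # from the forward pass.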
        dgrad, _ = fp8_gemm(
            weight_t_fp8,
            fwd_scale_inverses,
            weight_fp8_index,
            fp8_dtype_forward,
            grad_output_c,
            fp8_meta["scaling_bwd"].scale_inv,
            grad_output_fp8_index,
            fp8_dtype_backward,
            activation_dtype,
            get_workspace(),
            use_split_accumulator=_2X_ACC_DGRAD,
        )
        clear_tensor_data(grad_output_c)

        # Overlap dgrad-RS/AR with wgrad
        if parallel_mode == "column" and sequence_parallel:
            if handle is not None:
                handle.wait()
            dgrad, handle = reduce_scatter(dgrad, tp_group, sync_op=False)
        elif parallel_mode == "column" and tensor_parallel:
            dgrad, handle = allreduce(dgrad, tp_group, sync_op=False)

    if requires_wgrad:
        if not fp8_meta["recipe"].override_linear_precision.wgrad:
            if inputmat_t_total is None:
                inputmat_t_total = transpose(inputmat_total, fp8_dtype_backward)
                clear_tensor_data(inputmat_total)
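
            # wgrad = grad_output^T @ input, with both operands taken from the
            # FP8 transposes so the GEMM runs in the TN layout FP8 requires.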
            wgrad, _ = fp8_gemm(
                inputmat_t_total,
                fwd_scale_inverses,
                inputmat_fp8_index,
                fp8_dtype_forward,
                grad_output_t,
                fp8_meta["scaling_bwd"].scale_inv,
                grad_output_fp8_index,
                fp8_dtype_backward,
                "float32" if fuse_wgrad_accumulation else activation_dtype,
                get_workspace(),
                accumulate=accumulate_wgrad_into_param_main_grad,
                out=weight.main_grad if fuse_wgrad_accumulation else None,
                use_split_accumulator=_2X_ACC_WGRAD,
            )
            clear_tensor_data(inputmat_t_total, grad_output_t)
        else:
            wgrad, _, _ = gemm(
                inputmat_total,
                grad_output,
                activation_dtype,
                get_workspace(),
                grad=True,
                accumulate=accumulate_wgrad_into_param_main_grad,
                layout="NT",
                out=weight.main_grad if fuse_wgrad_accumulation else None,
                out_dtype="float32" if fuse_wgrad_accumulation else None,
            )
            clear_tensor_data(inputmat_total)

        if fuse_wgrad_accumulation:
            weight.main_grad = wgrad

    if parallel_mode == "column" and tensor_parallel and handle is not None:
        handle.wait()

    return dgrad, wgrad


def _linear_bwd_non_fp8(
    inputmat: paddle.Tensor,
    weight: paddle.Tensor,
    grad_output: paddle.Tensor,
    requires_bgrad: bool,
    requires_dgrad: bool,
    requires_wgrad: bool,
    activation_dtype: paddle.dtype,
    parallel_mode: Union[str, None],
    tensor_parallel: bool,
    sequence_parallel: bool,
    tp_group: Union[dist_group_type, None],
    fuse_wgrad_accumulation: bool,
    accumulate_wgrad_into_param_main_grad: bool,
    gelu_input: Union[paddle.Tensor, None] = None,
    activation: str = "",
):
    """
    Performs Linear Backward. Optionally, fuses GELU backward and dbias.
    """
    dgrad, wgrad, bgrad, handle = None, None, None, None

    # Overlap input AG with dgrad
    inputmat_total = None
    if requires_wgrad and parallel_mode == "column" and sequence_parallel:
        inputmat_total, handle = allgather(inputmat, tp_group, sync_op=not requires_dgrad)
    else:
        inputmat_total = inputmat

    if requires_dgrad:
        dgrad, _, _ = gemm(
            weight,
            grad_output,
            activation_dtype,
            get_workspace(),
            layout="NN",
            gelu=(activation == 'gelu'),
            gelu_input=gelu_input,
            grad=True,
        )
        # Overlap dgrad-RS/AR with wgrad
        if parallel_mode == "column" and sequence_parallel:
            if handle is not None:
                handle.wait()
            dgrad, handle = reduce_scatter(dgrad, tp_group, sync_op=False)
        elif parallel_mode == "column" and tensor_parallel:
            dgrad, handle = allreduce(dgrad, tp_group, sync_op=False)

    if requires_wgrad:
        wgrad, bgrad, _ = gemm(
            inputmat_total,
            grad_output,
            activation_dtype,
            get_workspace(),
            grad=True,
            accumulate=accumulate_wgrad_into_param_main_grad,
            layout="NT",
            out=weight.main_grad if fuse_wgrad_accumulation else None,
            out_dtype="float32" if fuse_wgrad_accumulation else None,
            use_bias=requires_bgrad,
        )
        if fuse_wgrad_accumulation:
            weight.main_grad = wgrad

    elif requires_bgrad:
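        # No wgrad GEMM to fuse the bias gradient into, so reduce grad_output
        # over the flattened token dimension directly.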
        bgrad = grad_output.sum(axis=0)

    if parallel_mode == "column" and tensor_parallel and handle is not None:
        handle.wait()

    return dgrad, wgrad, bgrad


def _linear_bwd(
    inputmat: paddle.Tensor,
    inputmat_t: paddle.Tensor,
    inputmat_fp8_index: FP8FwdTensors,
    weight: paddle.Tensor,
    weight_t_fp8: paddle.Tensor,
    weight_fp8_index: FP8FwdTensors,
    grad_output: paddle.Tensor,
    grad_output_c: paddle.Tensor,
    grad_output_t: paddle.Tensor,
    grad_output_fp8_index: FP8BwdTensors,
    fwd_scale_inverses: paddle.Tensor,
    requires_bgrad: bool,
    fp8_enabled: bool,
    fp8_meta: Dict[str, Any],
    requires_dgrad: bool,
    requires_wgrad: bool,
    activation_dtype: paddle.dtype,
    parallel_mode: Union[str, None],
    tensor_parallel: bool,
    sequence_parallel: bool,
    tp_group: Union[dist_group_type, None],
    fuse_wgrad_accumulation: bool,
    accumulate_wgrad_into_param_main_grad: bool,
):
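    """Performs Linear Bwd, dispatching to the FP8 or the non-FP8 path."""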
    dgrad, wgrad, bgrad = None, None, None
    if fp8_enabled:
        dgrad, wgrad = _linear_bwd_fp8(
            inputmat,
            inputmat_t,
            inputmat_fp8_index,
            weight,
            weight_t_fp8,
            weight_fp8_index,
            grad_output,
            grad_output_c,
            grad_output_t,
            grad_output_fp8_index,
            fwd_scale_inverses,
            fp8_meta,
            requires_dgrad,
            requires_wgrad,
            activation_dtype,
            parallel_mode,
            tensor_parallel,
            sequence_parallel,
            tp_group,
            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
            accumulate_wgrad_into_param_main_grad=accumulate_wgrad_into_param_main_grad,
        )
    else:
        dgrad, wgrad, bgrad = _linear_bwd_non_fp8(
            inputmat,
            weight,
            grad_output,
            requires_bgrad,
            requires_dgrad,
            requires_wgrad,
            activation_dtype,
            parallel_mode,
            tensor_parallel,
            sequence_parallel,
            tp_group,
            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
            accumulate_wgrad_into_param_main_grad=accumulate_wgrad_into_param_main_grad,
        )
    return dgrad, wgrad, bgrad


class _Linear(paddle.autograd.PyLayer):
    """TE implementation of Linear"""

    @staticmethod
    def forward(
        ctx,
        weight: paddle.Tensor,
        weight_fp8: Optional[paddle.Tensor],
        weight_t_fp8: Optional[paddle.Tensor],
        inp: paddle.Tensor,
        bias: paddle.Tensor,
        use_bias: bool,
        fp8_enabled: bool,
        fp8_calibration: bool,
        fp8_meta: Dict[str, Any],
        activation_dtype: paddle.dtype,
        is_grad_enabled: bool,
        parallel_mode: Union[str, None],
        tensor_parallel: bool,
        sequence_parallel: bool,
        tp_group: Union[dist_group_type, None],
        tp_size: int,
        fuse_wgrad_accumulation: bool,
        is_first_microbatch: Optional[bool],
    ) -> paddle.Tensor:
        # Make sure input dimensions are compatible
        in_features = weight.shape[-1]
        assert inp.shape[-1] == in_features, "GEMM not possible"
        inputmat = inp.reshape((-1, in_features))
        if fp8_enabled:
            assert_dim_for_fp8_forward_exec(inputmat)
            assert_dim_for_fp8_forward_exec(weight)

        inputmat_no_fp8 = inputmat

        # FP8 casting
        inputmat_t = None
        if fp8_enabled:
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
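            # If FP8 wgrad is enabled, gradients are required, and the input is
            # not sequence-parallel, fuse the cast with a transpose so the
            # backward wgrad GEMM can consume the transposed input directly.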
            if (not fp8_meta["recipe"].override_linear_precision.wgrad and is_grad_enabled
                    and not sequence_parallel):
                inputmat, inputmat_t = cast_transpose(
                    inputmat,
                    fp8_meta["scaling_fwd"],
                    FP8FwdTensors.GEMM1_INPUT,
                    fp8_dtype_forward,
                )
            else:
                inputmat = cast_to_fp8(
                    inputmat,
                    fp8_meta["scaling_fwd"],
                    FP8FwdTensors.GEMM1_INPUT,
                    fp8_dtype_forward,
                )

        # GEMM Fwd
        out, weight_t_fp8 = _linear_fwd(
            inputmat,
            FP8FwdTensors.GEMM1_INPUT,
            weight,
            weight_fp8,
            weight_t_fp8,
            FP8FwdTensors.GEMM1_WEIGHT,
            bias,
            use_bias,
            fp8_enabled,
            fp8_calibration,
            fp8_meta,
            activation_dtype,
            parallel_mode,
            tensor_parallel,
            sequence_parallel,
            tp_group,
            is_grad_enabled,
            is_first_microbatch,
        )

        if is_grad_enabled:
            saved_inputmat = None
            if fp8_enabled and sequence_parallel:
                saved_inputmat = inputmat
            else:
                saved_inputmat = inputmat_no_fp8
            save_for_backward_allow_none(
                ctx,
                saved_inputmat,
                inputmat_t,
                weight,
                weight_t_fp8 if fp8_enabled else None,
                fp8_meta["scaling_fwd"].scale_inv.clone() if fp8_enabled else None,
            )
            ctx.activation_dtype = activation_dtype
            ctx.fp8_enabled = fp8_enabled
            ctx.fp8_meta = fp8_meta
            ctx.use_bias = use_bias
            ctx.inp_shape = inp.shape
            ctx.parallel_mode = parallel_mode
            ctx.tensor_parallel = tensor_parallel
            ctx.sequence_parallel = sequence_parallel
            ctx.tp_group = tp_group
            ctx.tp_size = tp_size
            ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
            ctx.requires_dgrad = not inp.stop_gradient
            ctx.requires_wgrad = not weight.stop_gradient
            ctx.requires_bgrad = use_bias and not bias.stop_gradient
            ctx.is_first_microbatch = is_first_microbatch

        return out.reshape((-1, *inp.shape[1:-1], out.shape[-1]))

    @staticmethod
    def backward(ctx, grad_output: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]:
        with TransformerEngineBaseLayer.prepare_backward(ctx.fp8_enabled,
                                                         ctx.fp8_meta,
                                                         ctx.tp_group,
                                                         ctx.tp_size,
                                                         name="_Linear"):

            (    # pylint: disable=unbalanced-tuple-unpacking
                inputmat,
                inputmat_t,
                weight,
                weight_t_fp8,
                fwd_scale_inverses,
            ) = saved_tensor_allow_none(ctx)

            (
                grad_output,
                grad_output_c,
                grad_output_t,
                bgrad,
            ) = TransformerEngineBaseLayer.grad_output_preprocess(ctx, grad_output,
                                                                  ctx.parallel_mode == "row")
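            # On microbatches after the first, the wgrad GEMM accumulates into
            # weight.main_grad; on the first it overwrites, so stale gradients
            # from the previous step are discarded.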
            if ctx.is_first_microbatch is not None:
                accumulate_wgrad_into_param_main_grad = (ctx.fuse_wgrad_accumulation
                                                         and not ctx.is_first_microbatch)
            else:
                accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation

            dgrad, wgrad, bgrad_ = _linear_bwd(
                inputmat,
                inputmat_t,
                FP8FwdTensors.GEMM1_INPUT,
                weight,
                weight_t_fp8,
                FP8FwdTensors.GEMM1_WEIGHT,
                grad_output,
                grad_output_c,
                grad_output_t,
                FP8BwdTensors.GRAD_OUTPUT1,
                fwd_scale_inverses,
                ctx.requires_bgrad,
                ctx.fp8_enabled,
                ctx.fp8_meta,
                ctx.requires_dgrad,
                ctx.requires_wgrad,
                ctx.activation_dtype,
                ctx.parallel_mode,
                ctx.tensor_parallel,
                ctx.sequence_parallel,
                ctx.tp_group,
                ctx.fuse_wgrad_accumulation,
                accumulate_wgrad_into_param_main_grad,
            )

            if not ctx.fp8_enabled:
                # bgrad is fused with gemm for non-FP8 path
                bgrad = bgrad_

            if not ctx.fp8_enabled or ctx.is_first_microbatch is None:
                weight_cache_grad = ()
            else:
                # weight_fp8 and weight_t_fp8 are stop_gradient tensors
                weight_cache_grad = (None, None)

            dgrad_return = dgrad.reshape(ctx.inp_shape) if ctx.requires_dgrad else None
            if not ctx.use_bias:
                bgrad_return = ()
            elif ctx.requires_bgrad:
                bgrad_return = (bgrad,)
            else:
                bgrad_return = (None,)

        if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
            wgrad = None

        return (
            wgrad if ctx.requires_wgrad else None,
            *weight_cache_grad,
            dgrad_return,
            *bgrad_return,
        )


class Linear(TransformerEngineBaseLayer):
    """
    Applies a linear transformation to the incoming data :math:`y = xA^T + b`

    Parameters
    ----------
    in_features : int
                 size of each input sample.
    out_features : int
                  size of each output sample.
    weight_attr: Union[paddle.ParamAttr, None], default = None
                optional `paddle.ParamAttr` for weight.
    bias_attr: Union[paddle.ParamAttr, None, bool], default = None
              optional `paddle.ParamAttr` for bias.
    backend: {'transformer_engine', 'paddle'}, default = 'transformer_engine'
             if set to 'paddle', a framework-only, non-FP8 path is executed with limited optimization.

    Parallelism parameters
    ----------------------
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    parallel_mode : {None, 'column', 'row'}, default = `None`
                   used to decide whether this Linear layer is Column Parallel Linear or Row
                   Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
                   When set to `None`, no communication is performed.
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = `False`
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient. When enabled, it is assumed that the weights
                             have an additional `main_grad` attribute (used instead of the
                             regular `grad`) which is a pre-allocated buffer of the correct
                             size to accumulate gradients in.
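
    Example
    -------
    A minimal usage sketch (for illustration; assumes a single-process,
    non-parallel setup):

    .. code-block:: python

        import paddle
        import transformer_engine.paddle as te

        layer = te.Linear(768, 3072)
        inp = paddle.rand((16, 768))
        out = layer(inp)    # shape: [16, 3072]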

    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        weight_attr: Union[paddle.ParamAttr, None] = None,
        bias_attr: Union[paddle.ParamAttr, None, bool] = None,
        parallel_mode: Optional[str] = None,
        sequence_parallel: bool = False,
        tp_group: Union[dist_group_type, None] = None,
        fuse_wgrad_accumulation: bool = False,
        backend: str = 'transformer_engine',
    ) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.backend = backend
        self._weight_attr = weight_attr
        self._bias_attr = bias_attr
        self._dtype = self._helper.get_default_dtype()

        # Set parallel configs
        self.tp_group, self.tp_size = get_tp_group_and_world_size(tp_group,
                                                                  enable_tp=parallel_mode
                                                                  is not None)
        self.tensor_parallel = self.tp_size > 1
        self.parallel_mode = parallel_mode
        assert (self.parallel_mode
                in GemmParallelModes), f"parallel_mode {parallel_mode} not supported"

        if self.parallel_mode == "column":
            self.out_features = divide(self.out_features, self.tp_size)
        elif self.parallel_mode == "row":
            self.in_features = divide(self.in_features, self.tp_size)

        self.sequence_parallel = self.tensor_parallel and sequence_parallel

        self.fuse_wgrad_accumulation = fuse_wgrad_accumulation

        # Initialize weight parameter
        with track_rng_state(enable=self.tensor_parallel):
            # TE linear weight is in column major
            self.weight = self.create_parameter(
                shape=[self.out_features, self.in_features]
                if self.backend == 'transformer_engine' else [self.in_features, self.out_features],
                attr=self._weight_attr,
                dtype=self._dtype,
                is_bias=False,
            )
        set_weight_tensor_dist_attr(self.weight, self.tensor_parallel, self.parallel_mode,
                                    self.backend)

        # Initialize bias parameter
        self.has_bias = self._bias_attr is not False
        use_default_bias = self._bias_attr is None or self._bias_attr is True
        if self.has_bias:
            self.bias = self.create_parameter(
                shape=[self.out_features],
                attr=self._bias_attr if not use_default_bias else paddle.ParamAttr(
                    initializer=Constant(value=0.0)),
                dtype=self._dtype,
                is_bias=True,
            )
            if parallel_mode == "column":
                set_tensor_dist_attr(self.bias, self.tensor_parallel, axis=0)
            if parallel_mode == "row" and self.sequence_parallel:
                mark_as_sequence_parallel_parameter(self.bias)
        else:
            self.bias = None

        self.fp8_weight_shapes.append(self.weight.shape)

        # For RPL, bias has to be added after TP collectives
        # So it cannot be fused with the GEMM
        if self.parallel_mode == "row" and self.tensor_parallel and self.has_bias:
            self.gemm_bias_fused_add = False
        else:
            self.gemm_bias_fused_add = True

    def _te_forward(
        self,
        inp: paddle.Tensor,
        is_first_microbatch: Optional[bool] = None,
    ) -> paddle.Tensor:
        """
        Apply the linear transformation to the input.
        """
        with self.prepare_forward(inp, is_first_microbatch=is_first_microbatch) as inp:
            # Layer input should be cast outside PyLayer, as performing an
            # inplace cast on input tensors may cause problems when used
            # together with Paddle native layers.
            inp = cast_if_needed(inp, self.activation_dtype)

            # Get persistent fp8 weight buffer. None if buffer does not exist.
            weight_fp8, weight_t_fp8 = self.get_fp8_weights_scratchpad(is_first_microbatch)

            out = _Linear.apply(
                self.weight,
                weight_fp8,
                weight_t_fp8,
                inp,
                self.bias if self.gemm_bias_fused_add else None,
                self.has_bias and self.gemm_bias_fused_add,
                self.fp8_enabled,
                self.fp8_calibration,
                self.fp8_meta,
                self.activation_dtype,
                paddle.is_grad_enabled(),
                self.parallel_mode,
                self.tensor_parallel,
                self.sequence_parallel,
                self.tp_group,
                self.tp_size,
                self.fuse_wgrad_accumulation,
                is_first_microbatch,
            )

        if not self.gemm_bias_fused_add:
            out = out + cast_if_needed_inplace(self.bias, self.activation_dtype)

        return out

    def _pd_forward(
        self,
        inp: paddle.Tensor,
        is_first_microbatch: Optional[bool] = None,
    ) -> paddle.Tensor:
        """Calls Paddle OP"""
        if is_first_microbatch is not None:
            warnings.warn(
                "`is_first_microbatch` is not supported for paddle backend and is ignored.")
        if self.parallel_mode == 'column' and self.tensor_parallel:
            inp = identity(inp, self.tp_group)
        out = F.linear(inp, self.weight, self.bias if self.gemm_bias_fused_add else None)
        if self.parallel_mode == 'row' and self.tensor_parallel:
            out, _ = allreduce(out, self.tp_group)
            out = out + self.bias if self.bias is not None else out
        return out

    def forward(self, *args, **kwargs):
        """
        Apply the linear transformation to the input.

        Parameters
        ----------
        inp : paddle.Tensor
             Input tensor.
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
        """
        if self.backend == 'transformer_engine':
            return self._te_forward(*args, **kwargs)
        if self.backend == 'paddle':
            return self._pd_forward(*args, **kwargs)
        raise AttributeError(f"Backend {self.backend} is not supported.")