# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Linear API"""

import warnings
from typing import Union, Tuple, Dict, Any, Optional

import paddle
import paddle.nn.functional as F
from paddle.nn.initializer import Constant

from .base import (
    TransformerEngineBaseLayer,
    get_workspace,
    _2X_ACC_FPROP,
    _2X_ACC_DGRAD,
    _2X_ACC_WGRAD,
)

from ..constants import FP8FwdTensors, FP8BwdTensors, GemmParallelModes, dist_group_type
from ..cpp_extensions import gemm, fp8_gemm, cast_to_fp8, cast_transpose, transpose
from ..distributed import (
    allgather,
    allreduce,
    get_tp_group_and_world_size,
    identity,
    reduce_scatter,
    track_rng_state,
    set_tensor_dist_attr,
    set_weight_tensor_dist_attr,
    mark_as_sequence_parallel_parameter,
)
from ..fp8 import get_fp8_te_dtype, get_global_fp8_state
from ..utils import (
    assert_dim_for_fp8_forward_exec,
    cast_if_needed,
    cast_if_needed_inplace,
    divide,
    get_bias_dtype,
    save_for_backward_allow_none,
    saved_tensor_allow_none,
    clear_tensor_data,
)

__all__ = ["Linear"]


def _linear_fwd_fp8(
    inputmat: paddle.Tensor,
    inputmat_fp8_index: FP8FwdTensors,
    weight: paddle.Tensor,
    weight_fp8: Optional[paddle.Tensor],
    weight_t_fp8: Optional[paddle.Tensor],
    weight_fp8_index: FP8FwdTensors,
    bias: paddle.Tensor,
    use_bias: bool,
    fp8_meta: Dict[str, Any],
    activation_dtype: paddle.dtype,
    parallel_mode: Union[str, None],
    tensor_parallel: bool,
    sequence_parallel: bool,
    tp_group: Union[dist_group_type, None],
    is_grad_enabled: bool,
    is_first_microbatch: bool = None,
):
    """FP8 path of Linear Fwd"""
    fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
    bias_dtype = get_bias_dtype(activation_dtype)
    bias = cast_if_needed(bias, bias_dtype)

    if parallel_mode == "column" and sequence_parallel:
        inputmat_total, _ = allgather(inputmat, tp_group)
    else:
        inputmat_total = inputmat

    if not get_global_fp8_state().is_cudagraph_enabled():
        # If CUDA graph is not enabled, cast the weight to FP8 here
        update_fp8_weights = is_first_microbatch is None or is_first_microbatch
        if is_grad_enabled:
            if update_fp8_weights:
                weight_fp8, weight_t_fp8 = cast_transpose(
                    weight,
                    fp8_meta["scaling_fwd"],
                    weight_fp8_index,
                    fp8_dtype_forward,
                    cast_out=weight_fp8,
                    transpose_out=weight_t_fp8,
                )
        else:
            weight_t_fp8 = None
            if update_fp8_weights:
                weight_fp8 = cast_to_fp8(
                    weight,
                    fp8_meta["scaling_fwd"],
                    weight_fp8_index,
                    fp8_dtype_forward,
                    out=weight_fp8,
                )

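    # FP8 GEMM for the forward pass: out = inputmat_total @ weight^T (+ bias),
    # with the output produced in activation_dtype.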
    out, _ = fp8_gemm(
        weight_fp8,
        fp8_meta["scaling_fwd"].scale_inv,
        weight_fp8_index,
        fp8_dtype_forward,
        inputmat_total,
        fp8_meta["scaling_fwd"].scale_inv,
        inputmat_fp8_index,
        fp8_dtype_forward,
        activation_dtype,
        get_workspace(),
        bias=bias,
        use_bias=use_bias,
        use_split_accumulator=_2X_ACC_FPROP,
    )

    if parallel_mode == "row" and sequence_parallel:
        out, _ = reduce_scatter(out, tp_group)
    elif parallel_mode == "row" and tensor_parallel:
        out, _ = allreduce(out, tp_group)

    return out, weight_t_fp8


def _linear_fwd_non_fp8(
    inputmat: paddle.Tensor,
    inputmat_fp8_index: FP8FwdTensors,
    weight: paddle.Tensor,
    weight_fp8_index: FP8FwdTensors,
    bias: paddle.Tensor,
    use_bias: bool,
    fp8_calibration: bool,
    fp8_meta: Dict[str, Any],
    activation_dtype: paddle.dtype,
    parallel_mode: Union[str, None],
    tensor_parallel: bool,
    sequence_parallel: bool,
    tp_group: Union[dist_group_type, None],
    activation: str = "",
):
    """Non-FP8 path of Linear Fwd"""

    if parallel_mode == "column" and sequence_parallel:
        inputmat_total, _ = allgather(inputmat, tp_group)
    else:
        inputmat_total = inputmat

    # Layer parameters are initialized as float32 dtype by default.
    # Cast the parameters to activation_dtype if the current dtype
    # does not match activation_dtype. The cast is in-place, so it
    # only needs to be performed once throughout the training process.
    weight = cast_if_needed_inplace(weight, activation_dtype)
    bias = cast_if_needed_inplace(bias, activation_dtype)

    if fp8_calibration:
        # amax of input
        fp8_meta["scaling_fwd"].amax_history[0, inputmat_fp8_index.value] = paddle.max(
            paddle.abs(inputmat_total)
        ).item()
        # amax of weight
        fp8_meta["scaling_fwd"].amax_history[0, weight_fp8_index.value] = paddle.max(
            paddle.abs(weight)
        ).item()
        fp8_meta["update_amax_and_scale_fwd"] = True

    outputs = gemm(
        weight,
        inputmat_total,
        activation_dtype,
        get_workspace(),
        bias=bias,
        use_bias=use_bias,
        gelu=(activation == "gelu"),
    )

    if activation == "gelu":
        gelu_out, _, out = outputs
        return out, gelu_out

    out, _, _ = outputs

    if parallel_mode == "row" and sequence_parallel:
        out, _ = reduce_scatter(out, tp_group)
    elif parallel_mode == "row" and tensor_parallel:
        out, _ = allreduce(out, tp_group)
    return out


def _linear_fwd(
    inputmat: paddle.Tensor,
    inputmat_fp8_index: FP8FwdTensors,
    weight: paddle.Tensor,
    weight_fp8: Optional[paddle.Tensor],
    weight_t_fp8: Optional[paddle.Tensor],
    weight_fp8_index: FP8FwdTensors,
    bias: paddle.Tensor,
    use_bias: bool,
    fp8_enabled: bool,
    fp8_calibration: bool,
    fp8_meta: Dict[str, Any],
    activation_dtype: paddle.dtype,
    parallel_mode: Union[str, None],
    tensor_parallel: bool,
    sequence_parallel: bool,
    tp_group: Union[dist_group_type, None],
    is_grad_enabled: bool,
    is_first_microbatch: bool = None,
):
    if fp8_enabled:
        out, weight_t_fp8 = _linear_fwd_fp8(
            inputmat,
            inputmat_fp8_index,
            weight,
            weight_fp8,
            weight_t_fp8,
            weight_fp8_index,
            bias,
            use_bias,
            fp8_meta,
            activation_dtype,
            parallel_mode,
            tensor_parallel,
            sequence_parallel,
            tp_group,
            is_grad_enabled,
            is_first_microbatch,
        )
    else:
        out = _linear_fwd_non_fp8(
            inputmat,
            inputmat_fp8_index,
            weight,
            weight_fp8_index,
            bias,
            use_bias,
            fp8_calibration,
            fp8_meta,
            activation_dtype,
            parallel_mode,
            tensor_parallel,
            sequence_parallel,
            tp_group,
        )
    return (
        out,
        weight_t_fp8 if fp8_enabled else None,
    )


def _linear_bwd_fp8(
    inputmat: paddle.Tensor,
    inputmat_t: paddle.Tensor,
    inputmat_fp8_index: FP8FwdTensors,
    weight: paddle.Tensor,
    weight_t_fp8: paddle.Tensor,
    weight_fp8_index: FP8FwdTensors,
    grad_output: paddle.Tensor,
    grad_output_c: paddle.Tensor,
    grad_output_t: paddle.Tensor,
    grad_output_fp8_index: FP8BwdTensors,
    fwd_scale_inverses: paddle.Tensor,
    fp8_meta: Dict[str, Any],
    requires_dgrad: bool,
    requires_wgrad: bool,
    activation_dtype: paddle.dtype,
    parallel_mode: Union[str, None],
    tensor_parallel: bool,
    sequence_parallel: bool,
    tp_group: Union[dist_group_type, None],
    fuse_wgrad_accumulation: bool,
    accumulate_wgrad_into_param_main_grad: bool,
):
    dgrad, wgrad, handle = None, None, None

    # Overlap input AG with dgrad
    inputmat_total = None
    inputmat_t_total = None
    if requires_wgrad and parallel_mode == "column" and sequence_parallel:
        inputmat_total, handle = allgather(inputmat, tp_group, sync_op=not requires_dgrad)
    else:
        inputmat_total = inputmat
        inputmat_t_total = inputmat_t

    fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
    fp8_dtype_backward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=False)
    if requires_dgrad:
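        # dgrad = grad_output @ weight, computed in FP8 using the weight transpose
        # cached from the forward pass.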
        dgrad, _ = fp8_gemm(
            weight_t_fp8,
            fwd_scale_inverses,
            weight_fp8_index,
            fp8_dtype_forward,
            grad_output_c,
            fp8_meta["scaling_bwd"].scale_inv,
            grad_output_fp8_index,
            fp8_dtype_backward,
            activation_dtype,
            get_workspace(),
            use_split_accumulator=_2X_ACC_DGRAD,
        )
        clear_tensor_data(grad_output_c)

        # Overlap dgrad-RS/AR with wgrad
        if parallel_mode == "column" and sequence_parallel:
            if handle is not None:
                handle.wait()
            dgrad, handle = reduce_scatter(dgrad, tp_group, sync_op=False)
        elif parallel_mode == "column" and tensor_parallel:
            dgrad, handle = allreduce(dgrad, tp_group, sync_op=False)

    if requires_wgrad:
        if not fp8_meta["recipe"].override_linear_precision.wgrad:
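            # FP8 wgrad: make sure the transposed FP8 input is available, then compute
            # wgrad = grad_output^T @ input with both operands in FP8.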
            if inputmat_t_total is None:
                inputmat_t_total = transpose(inputmat_total, fp8_dtype_backward)
                clear_tensor_data(inputmat_total)

            wgrad, _ = fp8_gemm(
                inputmat_t_total,
                fwd_scale_inverses,
                inputmat_fp8_index,
                fp8_dtype_forward,
                grad_output_t,
                fp8_meta["scaling_bwd"].scale_inv,
                grad_output_fp8_index,
                fp8_dtype_backward,
                "float32" if fuse_wgrad_accumulation else activation_dtype,
                get_workspace(),
                accumulate=accumulate_wgrad_into_param_main_grad,
                out=weight.main_grad if fuse_wgrad_accumulation else None,
                use_split_accumulator=_2X_ACC_WGRAD,
            )
            clear_tensor_data(inputmat_t_total, grad_output_t)
        else:
            wgrad, _, _ = gemm(
                inputmat_total,
                grad_output,
                activation_dtype,
                get_workspace(),
                grad=True,
                accumulate=accumulate_wgrad_into_param_main_grad,
                layout="NT",
                out=weight.main_grad if fuse_wgrad_accumulation else None,
                out_dtype="float32" if fuse_wgrad_accumulation else None,
            )
            clear_tensor_data(inputmat_total)

        if fuse_wgrad_accumulation:
            weight.main_grad = wgrad

    if parallel_mode == "column" and tensor_parallel and handle is not None:
        handle.wait()
    if parallel_mode == "column" and sequence_parallel and handle is not None:
        handle.wait()

    return dgrad, wgrad


def _linear_bwd_non_fp8(
    inputmat: paddle.Tensor,
    weight: paddle.Tensor,
    grad_output: paddle.Tensor,
    requires_bgrad: bool,
    requires_dgrad: bool,
    requires_wgrad: bool,
    activation_dtype: paddle.dtype,
    parallel_mode: Union[str, None],
    tensor_parallel: bool,
    sequence_parallel: bool,
    tp_group: Union[dist_group_type, None],
    fuse_wgrad_accumulation: bool,
    accumulate_wgrad_into_param_main_grad: bool,
    gelu_input: Union[paddle.Tensor, None] = None,
    activation: str = "",
):
    """
    Performs Linear Backward. Optionally, fuses GELU backward and dbias.
    """
    dgrad, wgrad, bgrad, handle = None, None, None, None

    # Overlap input AG with dgrad
    inputmat_total = None
    if requires_wgrad and parallel_mode == "column" and sequence_parallel:
        inputmat_total, handle = allgather(inputmat, tp_group, sync_op=not requires_dgrad)
    else:
        inputmat_total = inputmat

    if requires_dgrad:
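        # dgrad = grad_output @ weight; dGELU is fused into the GEMM when
        # activation == "gelu".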
        dgrad, _, _ = gemm(
            weight,
            grad_output,
            activation_dtype,
            get_workspace(),
            layout="NN",
            gelu=(activation == "gelu"),
            gelu_input=gelu_input,
            grad=True,
        )
        # Overlap dgrad-RS/AR with wgrad
        if parallel_mode == "column" and sequence_parallel:
            if handle is not None:
                handle.wait()
            dgrad, handle = reduce_scatter(dgrad, tp_group, sync_op=False)
        elif parallel_mode == "column" and tensor_parallel:
            dgrad, handle = allreduce(dgrad, tp_group, sync_op=False)

    if requires_wgrad:
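        # wgrad = grad_output^T @ input ("NT" layout); bgrad is fused into the same
        # GEMM when requires_bgrad is set.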
        wgrad, bgrad, _ = gemm(
            inputmat_total,
            grad_output,
            activation_dtype,
            get_workspace(),
            grad=True,
            accumulate=accumulate_wgrad_into_param_main_grad,
            layout="NT",
            out=weight.main_grad if fuse_wgrad_accumulation else None,
            out_dtype="float32" if fuse_wgrad_accumulation else None,
            use_bias=requires_bgrad,
        )
        if fuse_wgrad_accumulation:
            weight.main_grad = wgrad

    elif requires_bgrad:
        bgrad = grad_output.sum(axis=0)
    if parallel_mode == "column" and tensor_parallel and handle is not None:
        handle.wait()
    if parallel_mode == "column" and sequence_parallel and handle is not None:
        handle.wait()

    return dgrad, wgrad, bgrad


def _linear_bwd(
    inputmat: paddle.Tensor,
    inputmat_t: paddle.Tensor,
    inputmat_fp8_index: FP8FwdTensors,
    weight: paddle.Tensor,
    weight_t_fp8: paddle.Tensor,
    weight_fp8_index: FP8FwdTensors,
    grad_output: paddle.Tensor,
    grad_output_c: paddle.Tensor,
    grad_output_t: paddle.Tensor,
    grad_output_fp8_index: FP8BwdTensors,
    fwd_scale_inverses: paddle.Tensor,
    requires_bgrad: bool,
    fp8_enabled: bool,
    fp8_meta: Dict[str, Any],
    requires_dgrad: bool,
    requires_wgrad: bool,
    activation_dtype: paddle.dtype,
    parallel_mode: Union[str, None],
    tensor_parallel: bool,
    sequence_parallel: bool,
    tp_group: Union[dist_group_type, None],
    fuse_wgrad_accumulation: bool,
    accumulate_wgrad_into_param_main_grad: bool,
):
    dgrad, wgrad, bgrad = None, None, None
    if fp8_enabled:
        dgrad, wgrad = _linear_bwd_fp8(
            inputmat,
            inputmat_t,
            inputmat_fp8_index,
            weight,
            weight_t_fp8,
            weight_fp8_index,
            grad_output,
            grad_output_c,
            grad_output_t,
            grad_output_fp8_index,
            fwd_scale_inverses,
            fp8_meta,
            requires_dgrad,
            requires_wgrad,
            activation_dtype,
            parallel_mode,
            tensor_parallel,
            sequence_parallel,
            tp_group,
            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
            accumulate_wgrad_into_param_main_grad=accumulate_wgrad_into_param_main_grad,
        )
    else:
        dgrad, wgrad, bgrad = _linear_bwd_non_fp8(
            inputmat,
            weight,
            grad_output,
            requires_bgrad,
            requires_dgrad,
            requires_wgrad,
            activation_dtype,
            parallel_mode,
            tensor_parallel,
            sequence_parallel,
            tp_group,
            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
            accumulate_wgrad_into_param_main_grad=accumulate_wgrad_into_param_main_grad,
        )
    return dgrad, wgrad, bgrad


class _Linear(paddle.autograd.PyLayer):
    """TE implementation of Linear"""

    @staticmethod
    def forward(
        ctx,
        weight: paddle.Tensor,
        weight_fp8: Optional[paddle.Tensor],
        weight_t_fp8: Optional[paddle.Tensor],
        inp: paddle.Tensor,
        bias: paddle.Tensor,
        use_bias: bool,
        fp8_enabled: bool,
        fp8_calibration: bool,
        fp8_meta: Dict[str, Any],
        activation_dtype: paddle.dtype,
        is_grad_enabled: bool,
        parallel_mode: Union[str, None],
        tensor_parallel: bool,
        sequence_parallel: bool,
        tp_group: Union[dist_group_type, None],
        tp_size: int,
        fuse_wgrad_accumulation: bool,
        is_first_microbatch: bool,
    ) -> paddle.Tensor:
        # Make sure input dimensions are compatible
        in_features = weight.shape[-1]
        assert inp.shape[-1] == in_features, "GEMM not possible"
        inputmat = inp.reshape((-1, in_features))
        if fp8_enabled:
            assert_dim_for_fp8_forward_exec(inputmat)
            assert_dim_for_fp8_forward_exec(weight)

        inputmat_no_fp8 = inputmat

        # FP8 casting
        inputmat_t = None
        if fp8_enabled:
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
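            # When wgrad will be computed in FP8, autograd is enabled, and sequence
            # parallelism is off, produce both the FP8 cast and its transpose now so
            # the backward wgrad GEMM can reuse the transposed input.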
            if (
                not fp8_meta["recipe"].override_linear_precision.wgrad
                and is_grad_enabled
                and not sequence_parallel
            ):
                inputmat, inputmat_t = cast_transpose(
                    inputmat,
                    fp8_meta["scaling_fwd"],
                    FP8FwdTensors.GEMM1_INPUT,
                    fp8_dtype_forward,
                )
            else:
                inputmat = cast_to_fp8(
                    inputmat,
                    fp8_meta["scaling_fwd"],
                    FP8FwdTensors.GEMM1_INPUT,
                    fp8_dtype_forward,
                )

        # GEMM Fwd
        out, weight_t_fp8 = _linear_fwd(
            inputmat,
            FP8FwdTensors.GEMM1_INPUT,
            weight,
            weight_fp8,
            weight_t_fp8,
            FP8FwdTensors.GEMM1_WEIGHT,
            bias,
            use_bias,
            fp8_enabled,
            fp8_calibration,
            fp8_meta,
            activation_dtype,
            parallel_mode,
            tensor_parallel,
            sequence_parallel,
            tp_group,
            is_grad_enabled,
            is_first_microbatch,
        )

        if is_grad_enabled:
            saved_inputmat = None
            if fp8_enabled and sequence_parallel:
                saved_inputmat = inputmat
            else:
                saved_inputmat = inputmat_no_fp8
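            # Stash everything backward needs: the (possibly FP8-cast) input, its
            # transpose, the weight, the FP8 weight transpose, and a copy of the
            # forward scale_inv.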
            save_for_backward_allow_none(
                ctx,
                saved_inputmat,
                inputmat_t,
                weight,
                weight_t_fp8 if fp8_enabled else None,
                fp8_meta["scaling_fwd"].scale_inv.clone() if fp8_enabled else None,
            )
            ctx.activation_dtype = activation_dtype
            ctx.fp8_enabled = fp8_enabled
            ctx.fp8_meta = fp8_meta
            ctx.use_bias = use_bias
            ctx.inp_shape = inp.shape
            ctx.parallel_mode = parallel_mode
            ctx.tensor_parallel = tensor_parallel
            ctx.sequence_parallel = sequence_parallel
            ctx.tp_group = tp_group
            ctx.tp_size = tp_size
            ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
            ctx.requires_dgrad = not inp.stop_gradient
            ctx.requires_wgrad = not weight.stop_gradient
            ctx.requires_bgrad = use_bias and not bias.stop_gradient
            ctx.is_first_microbatch = is_first_microbatch

        return out.reshape((-1, *inp.shape[1:-1], out.shape[-1]))

    @staticmethod
    def backward(ctx, grad_output: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]:
        with TransformerEngineBaseLayer.prepare_backward(
            ctx.fp8_enabled, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_Linear"
        ):

            (  # pylint: disable=unbalanced-tuple-unpacking
                inputmat,
                inputmat_t,
                weight,
                weight_t_fp8,
                fwd_scale_inverses,
            ) = saved_tensor_allow_none(ctx)

            (
                grad_output,
                grad_output_c,
                grad_output_t,
                bgrad,
            ) = TransformerEngineBaseLayer.grad_output_preprocess(
                ctx, grad_output, ctx.parallel_mode == "row"
            )
            if ctx.is_first_microbatch is not None:
                accumulate_wgrad_into_param_main_grad = (
                    ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
                )
            else:
                accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation

            dgrad, wgrad, bgrad_ = _linear_bwd(
                inputmat,
                inputmat_t,
                FP8FwdTensors.GEMM1_INPUT,
                weight,
                weight_t_fp8,
                FP8FwdTensors.GEMM1_WEIGHT,
                grad_output,
                grad_output_c,
                grad_output_t,
                FP8BwdTensors.GRAD_OUTPUT1,
                fwd_scale_inverses,
                ctx.requires_bgrad,
                ctx.fp8_enabled,
                ctx.fp8_meta,
                ctx.requires_dgrad,
                ctx.requires_wgrad,
                ctx.activation_dtype,
                ctx.parallel_mode,
                ctx.tensor_parallel,
                ctx.sequence_parallel,
                ctx.tp_group,
                ctx.fuse_wgrad_accumulation,
                accumulate_wgrad_into_param_main_grad,
            )

            if not ctx.fp8_enabled:
                # bgrad is fused with gemm for non-FP8 path
                bgrad = bgrad_

            if not ctx.fp8_enabled or ctx.is_first_microbatch is None:
                weight_cache_grad = ()
            else:
                # weight_fp8 and weight_t_fp8 are stop_gradient tensors
                weight_cache_grad = (None, None)

            dgrad_return = dgrad.reshape(ctx.inp_shape) if ctx.requires_dgrad else None
            if not ctx.use_bias:
                bgrad_return = ()
            elif ctx.requires_bgrad:
                bgrad_return = (bgrad,)
            else:
                bgrad_return = (None,)

        if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
            wgrad = None

        return (
            wgrad if ctx.requires_wgrad else None,
            *weight_cache_grad,
            dgrad_return,
            *bgrad_return,
        )


class Linear(TransformerEngineBaseLayer):
    """
    Applies a linear transformation to the incoming data :math:`y = xA^T + b`

    Parameters
    ----------
    in_features : int
                 size of each input sample.
    out_features : int
                  size of each output sample.
    weight_attr: Union[paddle.ParamAttr, None], default = None
                optional `paddle.ParamAttr` for weight.
    bias_attr: Union[paddle.ParamAttr, None, bool], default = None
              optional `paddle.ParamAttr` for bias.
    backend: {'transformer_engine', 'paddle'}, default = 'transformer_engine'
             if set to 'paddle', a framework-only, non-FP8 path is executed with limited optimization.

    Parallelism parameters
    ----------------------
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    parallel_mode : {None, 'column', 'row'}, default = `None`
                   used to decide whether this Linear layer is Column Parallel Linear or Row
                   Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
                   When set to `None`, no communication is performed.
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = 'False'
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient. When enabled, it is assumed that the weights
                             have an additional `main_grad` attribute (used instead of the
                             regular `grad`) which is a pre-allocated buffer of the correct
                             size to accumulate gradients in.
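
    Examples
    --------
    A minimal usage sketch (illustrative only; the import alias, feature sizes,
    and input shape below are assumptions rather than values taken from this file):

    >>> import paddle
    >>> import transformer_engine.paddle as te
    >>> layer = te.Linear(in_features=1024, out_features=1024)
    >>> inp = paddle.randn([16, 1024])
    >>> out = layer(inp)  # shape: [16, 1024]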

    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        weight_attr: Union[paddle.ParamAttr, None] = None,
        bias_attr: Union[paddle.ParamAttr, None, bool] = None,
        parallel_mode: Optional[str] = None,
        sequence_parallel: bool = False,
        tp_group: Union[dist_group_type, None] = None,
        fuse_wgrad_accumulation: bool = False,
        backend: str = "transformer_engine",
    ) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.backend = backend
        self._weight_attr = weight_attr
        self._bias_attr = bias_attr
        self._dtype = self._helper.get_default_dtype()

        # Set parallel configs
        self.tp_group, self.tp_size = get_tp_group_and_world_size(
            tp_group, enable_tp=parallel_mode is not None
        )
        self.tensor_parallel = self.tp_size > 1
        self.parallel_mode = parallel_mode
        assert (
            self.parallel_mode in GemmParallelModes
        ), f"parallel_mode {parallel_mode} not supported"

        if self.parallel_mode == "column":
            self.out_features = divide(self.out_features, self.tp_size)
        elif self.parallel_mode == "row":
            self.in_features = divide(self.in_features, self.tp_size)

        self.sequence_parallel = self.tensor_parallel and sequence_parallel

        self.fuse_wgrad_accumulation = fuse_wgrad_accumulation

        # Initialize weight parameter
        with track_rng_state(enable=self.tensor_parallel):
            # TE linear weight is in column major
            self.weight = self.create_parameter(
                shape=(
                    [self.out_features, self.in_features]
                    if self.backend == "transformer_engine"
                    else [self.in_features, self.out_features]
                ),
                attr=self._weight_attr,
                dtype=self._dtype,
                is_bias=False,
            )
        set_weight_tensor_dist_attr(
            self.weight, self.tensor_parallel, self.parallel_mode, self.backend
        )

        # Initialize bias parameter
        self.has_bias = self._bias_attr is not False
        use_default_bias = self._bias_attr is None or self._bias_attr is True
        if self.has_bias:
            self.bias = self.create_parameter(
                shape=[self.out_features],
                attr=(
                    self._bias_attr
                    if not use_default_bias
                    else paddle.ParamAttr(initializer=Constant(value=0.0))
                ),
                dtype=self._dtype,
                is_bias=True,
            )
            if parallel_mode == "column":
                set_tensor_dist_attr(self.bias, self.tensor_parallel, axis=0)
            if parallel_mode == "row" and self.sequence_parallel:
                mark_as_sequence_parallel_parameter(self.bias)
        else:
            self.bias = None

        self.fp8_weights.append(self.weight)

        # For RPL, bias has to be added after TP collectives
        # So it cannot be fused with the GEMM
        if self.parallel_mode == "row" and self.tensor_parallel and self.has_bias:
            self.gemm_bias_fused_add = False
        else:
            self.gemm_bias_fused_add = True

    def _te_forward(
        self,
        inp: paddle.Tensor,
        is_first_microbatch: Optional[bool] = None,
    ) -> paddle.Tensor:
        """
        Apply the linear transformation to the input.
        """
        with self.prepare_forward(inp, is_first_microbatch=is_first_microbatch) as inp:
            # The layer input should be cast outside the PyLayer, since an in-place
            # cast on input tensors may cause problems when used together with
            # Paddle native layers.
            inp = cast_if_needed(inp, self.activation_dtype)

            # Get persistent fp8 weight buffer. None if buffer does not exist.
            weight_fp8, weight_t_fp8 = self.get_fp8_weights_scratchpad_and_cast(is_first_microbatch)

            out = _Linear.apply(
                self.weight,
                weight_fp8,
                weight_t_fp8,
                inp,
                self.bias if self.gemm_bias_fused_add else None,
                self.has_bias and self.gemm_bias_fused_add,
                self.fp8_enabled,
                self.fp8_calibration,
                self.fp8_meta,
                self.activation_dtype,
                paddle.is_grad_enabled(),
                self.parallel_mode,
                self.tensor_parallel,
                self.sequence_parallel,
                self.tp_group,
                self.tp_size,
                self.fuse_wgrad_accumulation,
                is_first_microbatch,
            )

        if not self.gemm_bias_fused_add:
            out = out + cast_if_needed_inplace(self.bias, self.activation_dtype)

        return out

    def _pd_forward(
        self,
        inp: paddle.Tensor,
        is_first_microbatch: Optional[bool] = None,
    ) -> paddle.Tensor:
        """Calls Paddle OP"""
        if is_first_microbatch is not None:
            warnings.warn(
                "`is_first_microbatch` is not supported for paddle backend and is ignored."
            )
        if self.parallel_mode == "column" and self.tensor_parallel:
            inp = identity(inp, self.tp_group)
        out = F.linear(inp, self.weight, self.bias if self.gemm_bias_fused_add else None)
        if self.parallel_mode == "row" and self.tensor_parallel:
            out, _ = allreduce(out, self.tp_group)
            out = out + self.bias if self.bias is not None else out
        return out

    def forward(self, *args, **kwargs):
        """
        Apply the linear transformation to the input.

        Parameters
        ----------
        inp : paddle.Tensor
             Input tensor.
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
        """
        if self.backend == "transformer_engine":
            return self._te_forward(*args, **kwargs)
        if self.backend == "paddle":
            return self._pd_forward(*args, **kwargs)
        raise AttributeError(f"Backend {self.backend} is not supported.")