# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Linear API"""

import warnings
from typing import Union, Tuple, Dict, Any, Optional

import paddle
import paddle.nn.functional as F
from paddle.nn.initializer import Constant

from .base import (
    TransformerEngineBaseLayer,
    get_workspace,
    _2X_ACC_FPROP,
    _2X_ACC_DGRAD,
    _2X_ACC_WGRAD,
)

from ..constants import FP8FwdTensors, FP8BwdTensors, GemmParallelModes, dist_group_type
from ..cpp_extensions import gemm, fp8_gemm, cast_to_fp8, cast_transpose, transpose
from ..distributed import (
    allgather,
    allreduce,
    get_tp_group_and_world_size,
    identity,
    reduce_scatter,
    track_rng_state,
    set_tensor_dist_attr,
    set_weight_tensor_dist_attr,
    mark_as_sequence_parallel_parameter,
)
from ..fp8 import get_fp8_te_dtype, get_global_fp8_state
from ..utils import (
    assert_dim_for_fp8_forward_exec,
    cast_if_needed,
    cast_if_needed_inplace,
    divide,
    get_bias_dtype,
    save_for_backward_allow_none,
    saved_tensor_allow_none,
    clear_tensor_data,
)

__all__ = ["Linear"]


def _linear_fwd_fp8(
    inputmat: paddle.Tensor,
    inputmat_fp8_index: FP8FwdTensors,
    weight: paddle.Tensor,
    weight_fp8: Optional[paddle.Tensor],
    weight_t_fp8: Optional[paddle.Tensor],
    weight_fp8_index: FP8FwdTensors,
    bias: paddle.Tensor,
    use_bias: bool,
    fp8_meta: Dict[str, Any],
    activation_dtype: paddle.dtype,
    parallel_mode: Union[str, None],
    tensor_parallel: bool,
    sequence_parallel: bool,
    tp_group: Union[dist_group_type, None],
    is_grad_enabled: bool,
    is_first_microbatch: Optional[bool] = None,
):
    """FP8 path of Linear Fwd"""
    fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
    bias_dtype = get_bias_dtype(activation_dtype)
    bias = cast_if_needed(bias, bias_dtype)

    if parallel_mode == "column" and sequence_parallel:
        inputmat_total, _ = allgather(inputmat, tp_group)
    else:
        inputmat_total = inputmat

    if not get_global_fp8_state().is_cudagraph_enabled():
        # if CUDA graphs are not enabled, we cast the weight to FP8 here
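        # Weights are (re-)cast on every call when `is_first_microbatch` is None;
        # with microbatching, only the first microbatch re-casts and later
        # microbatches reuse the cached FP8 copies.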
        update_fp8_weights = is_first_microbatch is None or is_first_microbatch
        if is_grad_enabled:
            if update_fp8_weights:
                weight_fp8, weight_t_fp8 = cast_transpose(
                    weight,
                    fp8_meta["scaling_fwd"],
                    weight_fp8_index,
                    fp8_dtype_forward,
                    cast_out=weight_fp8,
                    transpose_out=weight_t_fp8,
                )
        else:
            weight_t_fp8 = None
            if update_fp8_weights:
                weight_fp8 = cast_to_fp8(
                    weight,
                    fp8_meta["scaling_fwd"],
                    weight_fp8_index,
                    fp8_dtype_forward,
                    out=weight_fp8,
                )

    out, _ = fp8_gemm(
        weight_fp8,
        fp8_meta["scaling_fwd"].scale_inv,
        weight_fp8_index,
        fp8_dtype_forward,
        inputmat_total,
        fp8_meta["scaling_fwd"].scale_inv,
        inputmat_fp8_index,
        fp8_dtype_forward,
        activation_dtype,
        get_workspace(),
        bias=bias,
        use_bias=use_bias,
        use_split_accumulator=_2X_ACC_FPROP,
    )

    if parallel_mode == "row" and sequence_parallel:
        out, _ = reduce_scatter(out, tp_group)
    elif parallel_mode == "row" and tensor_parallel:
        out, _ = allreduce(out, tp_group)

    return out, weight_t_fp8


def _linear_fwd_non_fp8(
    inputmat: paddle.Tensor,
    inputmat_fp8_index: FP8FwdTensors,
    weight: paddle.Tensor,
    weight_fp8_index: FP8FwdTensors,
    bias: paddle.Tensor,
    use_bias: bool,
    fp8_calibration: bool,
    fp8_meta: Dict[str, Any],
    activation_dtype: paddle.dtype,
    parallel_mode: Union[str, None],
    tensor_parallel: bool,
    sequence_parallel: bool,
    tp_group: Union[dist_group_type, None],
    activation: str = "",
):
    """Non-FP8 path of Linear Fwd"""

    if parallel_mode == "column" and sequence_parallel:
        inputmat_total, _ = allgather(inputmat, tp_group)
    else:
        inputmat_total = inputmat

    # Layer parameters are initialized as float32 dtype by default.
    # Cast the parameters to activation_dtype if the current dtype
    # does not match activation_dtype. The casting is inplace, so it
    # only needs to be performed once throughout the training process.
    weight = cast_if_needed_inplace(weight, activation_dtype)
    bias = cast_if_needed_inplace(bias, activation_dtype)

    if fp8_calibration:
        # amax of input
        fp8_meta["scaling_fwd"].amax_history[0, inputmat_fp8_index.value] = paddle.max(
            paddle.abs(inputmat_total)
        ).item()
        # amax of weight
        fp8_meta["scaling_fwd"].amax_history[0, weight_fp8_index.value] = paddle.max(
            paddle.abs(weight)
        ).item()
        fp8_meta["update_amax_and_scale_fwd"] = True

    outputs = gemm(
        weight,
        inputmat_total,
        activation_dtype,
        get_workspace(),
        bias=bias,
        use_bias=use_bias,
        gelu=(activation == "gelu"),
    )

    if activation == "gelu":
        gelu_out, _, out = outputs
        return out, gelu_out

    out, _, _ = outputs

    if parallel_mode == "row" and sequence_parallel:
        out, _ = reduce_scatter(out, tp_group)
    elif parallel_mode == "row" and tensor_parallel:
        out, _ = allreduce(out, tp_group)
    return out


def _linear_fwd(
    inputmat: paddle.Tensor,
    inputmat_fp8_index: FP8FwdTensors,
    weight: paddle.Tensor,
    weight_fp8: Optional[paddle.Tensor],
    weight_t_fp8: Optional[paddle.Tensor],
    weight_fp8_index: FP8FwdTensors,
    bias: paddle.Tensor,
    use_bias: bool,
    fp8_enabled: bool,
    fp8_calibration: bool,
    fp8_meta: Dict[str, Any],
    activation_dtype: paddle.dtype,
    parallel_mode: Union[str, None],
    tensor_parallel: bool,
    sequence_parallel: bool,
    tp_group: Union[dist_group_type, None],
    is_grad_enabled: bool,
    is_first_microbatch: Optional[bool] = None,
    gather_output: bool = False,
):
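    """Dispatches to the FP8 or non-FP8 path of Linear Fwd"""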
    if fp8_enabled:
        out, weight_t_fp8 = _linear_fwd_fp8(
            inputmat,
            inputmat_fp8_index,
            weight,
            weight_fp8,
            weight_t_fp8,
            weight_fp8_index,
            bias,
            use_bias,
            fp8_meta,
            activation_dtype,
            parallel_mode,
            tensor_parallel,
            sequence_parallel,
            tp_group,
            is_grad_enabled,
            is_first_microbatch,
        )
    else:
        out = _linear_fwd_non_fp8(
            inputmat,
            inputmat_fp8_index,
            weight,
            weight_fp8_index,
            bias,
            use_bias,
            fp8_calibration,
            fp8_meta,
            activation_dtype,
            parallel_mode,
            tensor_parallel,
            sequence_parallel,
            tp_group,
        )
    if gather_output and tensor_parallel and parallel_mode == "column":
        out, _ = allgather(out, tp_group, axis=-1)

    return (
        out,
        weight_t_fp8 if fp8_enabled else None,
    )


def _linear_bwd_fp8(
    inputmat: paddle.Tensor,
    inputmat_t: paddle.Tensor,
    inputmat_fp8_index: FP8FwdTensors,
    weight: paddle.Tensor,
    weight_t_fp8: paddle.Tensor,
    weight_fp8_index: FP8FwdTensors,
    grad_output: paddle.Tensor,
    grad_output_c: paddle.Tensor,
    grad_output_t: paddle.Tensor,
    grad_output_fp8_index: FP8BwdTensors,
    fwd_scale_inverses: paddle.Tensor,
    fp8_meta: Dict[str, Any],
    requires_dgrad: bool,
    requires_wgrad: bool,
    activation_dtype: paddle.dtype,
    parallel_mode: Union[str, None],
    tensor_parallel: bool,
    sequence_parallel: bool,
    tp_group: Union[dist_group_type, None],
    fuse_wgrad_accumulation: bool,
    accumulate_wgrad_into_param_main_grad: bool,
):
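    """FP8 path of Linear Bwd"""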
    dgrad, wgrad, handle = None, None, None

    # Overlap input AG with dgrad
    inputmat_total = None
    inputmat_t_total = None
    if requires_wgrad and parallel_mode == "column" and sequence_parallel:
        inputmat_total, handle = allgather(inputmat, tp_group, sync_op=not requires_dgrad)
    else:
        inputmat_total = inputmat
        inputmat_t_total = inputmat_t

    fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
    fp8_dtype_backward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=False)
    if requires_dgrad:
        dgrad, _ = fp8_gemm(
            weight_t_fp8,
            fwd_scale_inverses,
            weight_fp8_index,
            fp8_dtype_forward,
            grad_output_c,
            fp8_meta["scaling_bwd"].scale_inv,
            grad_output_fp8_index,
            fp8_dtype_backward,
            activation_dtype,
            get_workspace(),
            use_split_accumulator=_2X_ACC_DGRAD,
        )
        clear_tensor_data(grad_output_c)

        # Overlap dgrad-RS/AR with wgrad
        if parallel_mode == "column" and sequence_parallel:
            if handle is not None:
                handle.wait()
            dgrad, handle = reduce_scatter(dgrad, tp_group, sync_op=False)
        elif parallel_mode == "column" and tensor_parallel:
            dgrad, handle = allreduce(dgrad, tp_group, sync_op=False)

    if requires_wgrad:
        if not fp8_meta["recipe"].override_linear_precision.wgrad:
            if inputmat_t_total is None:
                inputmat_t_total = transpose(inputmat_total, fp8_dtype_backward)
                clear_tensor_data(inputmat_total)

            wgrad, _ = fp8_gemm(
                inputmat_t_total,
                fwd_scale_inverses,
                inputmat_fp8_index,
                fp8_dtype_forward,
                grad_output_t,
                fp8_meta["scaling_bwd"].scale_inv,
                grad_output_fp8_index,
                fp8_dtype_backward,
                "float32" if fuse_wgrad_accumulation else activation_dtype,
                get_workspace(),
                accumulate=accumulate_wgrad_into_param_main_grad,
                out=weight.main_grad if fuse_wgrad_accumulation else None,
                use_split_accumulator=_2X_ACC_WGRAD,
            )
            clear_tensor_data(inputmat_t_total, grad_output_t)
        else:
            wgrad, _, _ = gemm(
                inputmat_total,
                grad_output,
                activation_dtype,
                get_workspace(),
                grad=True,
                accumulate=accumulate_wgrad_into_param_main_grad,
                layout="NT",
                out=weight.main_grad if fuse_wgrad_accumulation else None,
                out_dtype="float32" if fuse_wgrad_accumulation else None,
            )
            clear_tensor_data(inputmat_total)

        if fuse_wgrad_accumulation:
            weight.main_grad = wgrad

    if parallel_mode == "column" and tensor_parallel and handle is not None:
        handle.wait()
    if parallel_mode == "column" and sequence_parallel and handle is not None:
        handle.wait()

    return dgrad, wgrad


def _linear_bwd_non_fp8(
    inputmat: paddle.Tensor,
    weight: paddle.Tensor,
    grad_output: paddle.Tensor,
    requires_bgrad: bool,
    requires_dgrad: bool,
    requires_wgrad: bool,
    activation_dtype: paddle.dtype,
    parallel_mode: Union[str, None],
    tensor_parallel: bool,
    sequence_parallel: bool,
    tp_group: Union[dist_group_type, None],
    fuse_wgrad_accumulation: bool,
    accumulate_wgrad_into_param_main_grad: bool,
    gelu_input: Union[paddle.Tensor, None] = None,
    activation: str = "",
):
    """
    Performs Linear Backward. Optionally, fuses GELU backward and dbias.
    """
    dgrad, wgrad, bgrad, handle = None, None, None, None

    # Overlap input AG with dgrad
    inputmat_total = None
    if requires_wgrad and parallel_mode == "column" and sequence_parallel:
        inputmat_total, handle = allgather(inputmat, tp_group, sync_op=not requires_dgrad)
    else:
        inputmat_total = inputmat

    if requires_dgrad:
        dgrad, _, _ = gemm(
            weight,
            grad_output,
            activation_dtype,
            get_workspace(),
            layout="NN",
            gelu=(activation == "gelu"),
            gelu_input=gelu_input,
            grad=True,
        )
        # Overlap dgrad-RS/AR with wgrad
        if parallel_mode == "column" and sequence_parallel:
            if handle is not None:
                handle.wait()
            dgrad, handle = reduce_scatter(dgrad, tp_group, sync_op=False)
        elif parallel_mode == "column" and tensor_parallel:
            dgrad, handle = allreduce(dgrad, tp_group, sync_op=False)

    if requires_wgrad:
        wgrad, bgrad, _ = gemm(
            inputmat_total,
            grad_output,
            activation_dtype,
            get_workspace(),
            grad=True,
            accumulate=accumulate_wgrad_into_param_main_grad,
            layout="NT",
            out=weight.main_grad if fuse_wgrad_accumulation else None,
            out_dtype="float32" if fuse_wgrad_accumulation else None,
            use_bias=requires_bgrad,
        )
        if fuse_wgrad_accumulation:
            weight.main_grad = wgrad

    elif requires_bgrad:
        bgrad = grad_output.sum(axis=0)
    if parallel_mode == "column" and tensor_parallel and handle is not None:
        handle.wait()
    if parallel_mode == "column" and sequence_parallel and handle is not None:
        handle.wait()

    return dgrad, wgrad, bgrad


def _linear_bwd(
    inputmat: paddle.Tensor,
    inputmat_t: paddle.Tensor,
    inputmat_fp8_index: FP8FwdTensors,
    weight: paddle.Tensor,
    weight_t_fp8: paddle.Tensor,
    weight_fp8_index: FP8FwdTensors,
    grad_output: paddle.Tensor,
    grad_output_c: paddle.Tensor,
    grad_output_t: paddle.Tensor,
    grad_output_fp8_index: FP8BwdTensors,
    fwd_scale_inverses: paddle.Tensor,
    requires_bgrad: bool,
    fp8_enabled: bool,
    fp8_meta: Dict[str, Any],
    requires_dgrad: bool,
    requires_wgrad: bool,
    activation_dtype: paddle.dtype,
    parallel_mode: Union[str, None],
    tensor_parallel: bool,
    sequence_parallel: bool,
    tp_group: Union[dist_group_type, None],
    fuse_wgrad_accumulation: bool,
    accumulate_wgrad_into_param_main_grad: bool,
):
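    """Dispatches to the FP8 or non-FP8 path of Linear Bwd"""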
    dgrad, wgrad, bgrad = None, None, None
    if fp8_enabled:
        dgrad, wgrad = _linear_bwd_fp8(
            inputmat,
            inputmat_t,
            inputmat_fp8_index,
            weight,
            weight_t_fp8,
            weight_fp8_index,
            grad_output,
            grad_output_c,
            grad_output_t,
            grad_output_fp8_index,
            fwd_scale_inverses,
            fp8_meta,
            requires_dgrad,
            requires_wgrad,
            activation_dtype,
            parallel_mode,
            tensor_parallel,
            sequence_parallel,
            tp_group,
            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
            accumulate_wgrad_into_param_main_grad=accumulate_wgrad_into_param_main_grad,
        )
    else:
        dgrad, wgrad, bgrad = _linear_bwd_non_fp8(
            inputmat,
            weight,
            grad_output,
            requires_bgrad,
            requires_dgrad,
            requires_wgrad,
            activation_dtype,
            parallel_mode,
            tensor_parallel,
            sequence_parallel,
            tp_group,
            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
            accumulate_wgrad_into_param_main_grad=accumulate_wgrad_into_param_main_grad,
        )
    return dgrad, wgrad, bgrad


class _Linear(paddle.autograd.PyLayer):
    """TE implementation of Linear"""

    @staticmethod
    def forward(
        ctx,
        weight: paddle.Tensor,
        weight_fp8: Optional[paddle.Tensor],
        weight_t_fp8: Optional[paddle.Tensor],
        inp: paddle.Tensor,
        bias: paddle.Tensor,
        use_bias: bool,
        fp8_enabled: bool,
        fp8_calibration: bool,
        fp8_meta: Dict[str, Any],
        activation_dtype: paddle.dtype,
        is_grad_enabled: bool,
        parallel_mode: Union[str, None],
        tensor_parallel: bool,
        sequence_parallel: bool,
        tp_group: Union[dist_group_type, None],
        tp_size: int,
        fuse_wgrad_accumulation: bool,
        is_first_microbatch: bool,
        gather_output: bool,
    ) -> paddle.Tensor:
        # Make sure input dimensions are compatible
        in_features = weight.shape[-1]
        assert inp.shape[-1] == in_features, "GEMM not possible"
        inputmat = inp.reshape((-1, in_features))
        if fp8_enabled:
            assert_dim_for_fp8_forward_exec(inputmat)
            assert_dim_for_fp8_forward_exec(weight)

        inputmat_no_fp8 = inputmat

        # FP8 casting
        inputmat_t = None
        if fp8_enabled:
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
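            # When autograd is enabled, the wgrad GEMM will run in FP8, and the
            # input will not be re-gathered in the backward pass (no sequence
            # parallelism), also compute the FP8 transpose of the input here so
            # the backward wgrad GEMM can consume it directly.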
            if (
                not fp8_meta["recipe"].override_linear_precision.wgrad
                and is_grad_enabled
                and not sequence_parallel
            ):
                inputmat, inputmat_t = cast_transpose(
                    inputmat,
                    fp8_meta["scaling_fwd"],
                    FP8FwdTensors.GEMM1_INPUT,
                    fp8_dtype_forward,
                )
            else:
                inputmat = cast_to_fp8(
                    inputmat,
                    fp8_meta["scaling_fwd"],
                    FP8FwdTensors.GEMM1_INPUT,
                    fp8_dtype_forward,
                )

        # GEMM Fwd
        out, weight_t_fp8 = _linear_fwd(
            inputmat,
            FP8FwdTensors.GEMM1_INPUT,
            weight,
            weight_fp8,
            weight_t_fp8,
            FP8FwdTensors.GEMM1_WEIGHT,
            bias,
            use_bias,
            fp8_enabled,
            fp8_calibration,
            fp8_meta,
            activation_dtype,
            parallel_mode,
            tensor_parallel,
            sequence_parallel,
            tp_group,
            is_grad_enabled,
            is_first_microbatch,
            gather_output,
        )

        if is_grad_enabled:
            saved_inputmat = None
            if fp8_enabled and sequence_parallel:
                saved_inputmat = inputmat
            else:
                saved_inputmat = inputmat_no_fp8
            save_for_backward_allow_none(
                ctx,
                saved_inputmat,
                inputmat_t,
                weight,
                weight_t_fp8 if fp8_enabled else None,
                fp8_meta["scaling_fwd"].scale_inv.clone() if fp8_enabled else None,
            )
            ctx.activation_dtype = activation_dtype
            ctx.fp8_enabled = fp8_enabled
            ctx.fp8_meta = fp8_meta
            ctx.use_bias = use_bias
            ctx.inp_shape = inp.shape
            ctx.parallel_mode = parallel_mode
            ctx.tensor_parallel = tensor_parallel
            ctx.sequence_parallel = sequence_parallel
            ctx.tp_group = tp_group
            ctx.tp_size = tp_size
            ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
            ctx.requires_dgrad = not inp.stop_gradient
            ctx.requires_wgrad = not weight.stop_gradient
            ctx.requires_bgrad = use_bias and not bias.stop_gradient
            ctx.is_first_microbatch = is_first_microbatch
            ctx.reduce_scatter_output = gather_output

        return out.reshape((-1, *inp.shape[1:-1], out.shape[-1]))

    @staticmethod
    def backward(ctx, grad_output: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]:
        with TransformerEngineBaseLayer.prepare_backward(
            ctx.fp8_enabled, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_Linear"
        ):

            (  # pylint: disable=unbalanced-tuple-unpacking
                inputmat,
                inputmat_t,
                weight,
                weight_t_fp8,
                fwd_scale_inverses,
            ) = saved_tensor_allow_none(ctx)

            (
                grad_output,
                grad_output_c,
                grad_output_t,
                bgrad,
            ) = TransformerEngineBaseLayer.grad_output_preprocess(
                ctx, grad_output, ctx.parallel_mode == "row"
            )
            if ctx.is_first_microbatch is not None:
                accumulate_wgrad_into_param_main_grad = (
                    ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
                )
            else:
                accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation

            dgrad, wgrad, bgrad_ = _linear_bwd(
                inputmat,
                inputmat_t,
                FP8FwdTensors.GEMM1_INPUT,
                weight,
                weight_t_fp8,
                FP8FwdTensors.GEMM1_WEIGHT,
                grad_output,
                grad_output_c,
                grad_output_t,
                FP8BwdTensors.GRAD_OUTPUT1,
                fwd_scale_inverses,
                ctx.requires_bgrad,
                ctx.fp8_enabled,
                ctx.fp8_meta,
                ctx.requires_dgrad,
                ctx.requires_wgrad,
                ctx.activation_dtype,
                ctx.parallel_mode,
                ctx.tensor_parallel,
                ctx.sequence_parallel,
                ctx.tp_group,
                ctx.fuse_wgrad_accumulation,
                accumulate_wgrad_into_param_main_grad,
            )

            if not ctx.fp8_enabled:
                # bgrad is fused with gemm for non-FP8 path
                bgrad = bgrad_

            if ctx.reduce_scatter_output:
                wgrad, _ = reduce_scatter(wgrad, ctx.tp_group)
                bgrad, _ = reduce_scatter(bgrad, ctx.tp_group)

            if not ctx.fp8_enabled or ctx.is_first_microbatch is None:
                weight_cache_grad = ()
            else:
                # weight_fp8 and weight_t_fp8 are stop_gradient tensors
                weight_cache_grad = (None, None)

            dgrad_return = dgrad.reshape(ctx.inp_shape) if ctx.requires_dgrad else None
            if not ctx.use_bias:
                bgrad_return = ()
            elif ctx.requires_bgrad:
                bgrad_return = (bgrad,)
            else:
                bgrad_return = (None,)

        if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
            wgrad = None

        return (
            wgrad if ctx.requires_wgrad else None,
            *weight_cache_grad,
            dgrad_return,
            *bgrad_return,
        )


class Linear(TransformerEngineBaseLayer):
    """
    Applies a linear transformation to the incoming data :math:`y = xA^T + b`

    Parameters
    ----------
    in_features : int
                 size of each input sample.
    out_features : int
                  size of each output sample.
    weight_attr: Union[paddle.ParamAttr, None], default = None
                optional `paddle.ParamAttr` for weight.
    bias_attr: Union[paddle.ParamAttr, None, bool], default = None
              optional `paddle.ParamAttr` for bias.
    backend: {'transformer_engine', 'paddle'}, default = 'transformer_engine'
             if set to 'paddle', a framework-only, non-FP8 path is executed with limited optimization.

    Parallelism parameters
    ----------------------
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    parallel_mode : {None, 'column', 'row'}, default = `None`
                   used to decide whether this Linear layer is Column Parallel Linear or Row
                   Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
                   When set to `None`, no communication is performed.
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.

    Optimization parameters
    -----------------------
    fuse_wgrad_accumulation : bool, default = `False`
                             if set to `True`, enables fusing of creation and accumulation of
                             the weight gradient. When enabled, it is assumed that the weights
                             have an additional `main_grad` attribute (used instead of the
                             regular `grad`) which is a pre-allocated buffer of the correct
                             size to accumulate gradients in.
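
    Example
    -------
    A minimal usage sketch (single GPU, no tensor parallelism; the
    `transformer_engine.paddle` import path is assumed here):

    .. code-block:: python

        import paddle
        import transformer_engine.paddle as te

        layer = te.Linear(in_features=768, out_features=3072)
        inp = paddle.randn([16, 768])   # [tokens, in_features]
        out = layer(inp)                # [tokens, out_features]

    FP8 execution additionally requires running the forward pass inside the
    library's `fp8_autocast` context.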

    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        weight_attr: Union[paddle.ParamAttr, None] = None,
        bias_attr: Union[paddle.ParamAttr, None, bool] = None,
        parallel_mode: Optional[str] = None,
        sequence_parallel: bool = False,
        tp_group: Union[dist_group_type, None] = None,
        fuse_wgrad_accumulation: bool = False,
        gather_output: bool = False,
        backend: str = "transformer_engine",
    ) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.backend = backend
        self._weight_attr = weight_attr
        self._bias_attr = bias_attr
        self._dtype = self._helper.get_default_dtype()
        self.gather_output = gather_output

        # Set parallel configs
        self.tp_group, self.tp_size = get_tp_group_and_world_size(
            tp_group, enable_tp=parallel_mode is not None
        )
        self.tensor_parallel = self.tp_size > 1
        self.parallel_mode = parallel_mode
        assert (
            self.parallel_mode in GemmParallelModes
        ), f"parallel_mode {parallel_mode} not supported"

        if self.parallel_mode == "column":
            self.out_features = divide(self.out_features, self.tp_size)
        elif self.parallel_mode == "row":
            self.in_features = divide(self.in_features, self.tp_size)

        self.sequence_parallel = self.tensor_parallel and sequence_parallel

        self.fuse_wgrad_accumulation = fuse_wgrad_accumulation

        # Initialize weight parameter
        with track_rng_state(enable=self.tensor_parallel):
            # TE linear weight is in column major
            self.weight = self.create_parameter(
                shape=(
                    [self.out_features, self.in_features]
                    if self.backend == "transformer_engine"
                    else [self.in_features, self.out_features]
                ),
                attr=self._weight_attr,
                dtype=self._dtype,
                is_bias=False,
            )
        set_weight_tensor_dist_attr(
            self.weight, self.tensor_parallel, self.parallel_mode, self.backend
        )

        # Initialize bias parameter
        self.has_bias = self._bias_attr is not False
        use_default_bias = self._bias_attr is None or self._bias_attr is True
        if self.has_bias:
            self.bias = self.create_parameter(
                shape=[self.out_features],
                attr=(
                    self._bias_attr
                    if not use_default_bias
                    else paddle.ParamAttr(initializer=Constant(value=0.0))
                ),
                dtype=self._dtype,
                is_bias=True,
            )
            if parallel_mode == "column":
                set_tensor_dist_attr(self.bias, self.tensor_parallel, axis=0)
            if parallel_mode == "row" and self.sequence_parallel:
                mark_as_sequence_parallel_parameter(self.bias)
        else:
            self.bias = None

        self.fp8_weights.append(self.weight)

        # For RPL, bias has to be added after TP collectives
        # So it cannot be fused with the GEMM
        if self.parallel_mode == "row" and self.tensor_parallel and self.has_bias:
            self.gemm_bias_fused_add = False
        else:
            self.gemm_bias_fused_add = True

    def _te_forward(
        self,
        inp: paddle.Tensor,
        is_first_microbatch: Optional[bool] = None,
    ) -> paddle.Tensor:
        """
        Apply the linear transformation to the input.
        """
        with self.prepare_forward(inp, is_first_microbatch=is_first_microbatch) as inp:
            # Layer input should be cast outside the PyLayer, as performing
            # an in-place cast on input tensors may cause problems when used
            # together with Paddle native layers.
            inp = cast_if_needed(inp, self.activation_dtype)

            # Get persistent fp8 weight buffer. None if buffer does not exist.
            weight_fp8, weight_t_fp8 = self.get_fp8_weights_scratchpad_and_cast(is_first_microbatch)

            out = _Linear.apply(
                self.weight,
                weight_fp8,
                weight_t_fp8,
                inp,
                self.bias if self.gemm_bias_fused_add else None,
                self.has_bias and self.gemm_bias_fused_add,
                self.fp8_enabled,
                self.fp8_calibration,
                self.fp8_meta,
                self.activation_dtype,
                paddle.is_grad_enabled(),
                self.parallel_mode,
                self.tensor_parallel,
                self.sequence_parallel,
                self.tp_group,
                self.tp_size,
                self.fuse_wgrad_accumulation,
                is_first_microbatch,
                self.gather_output,
            )

        if not self.gemm_bias_fused_add:
            out = out + cast_if_needed_inplace(self.bias, self.activation_dtype)

        return out

    def _pd_forward(
        self,
        inp: paddle.Tensor,
        is_first_microbatch: Optional[bool] = None,
    ) -> paddle.Tensor:
        """Calls Paddle OP"""
        if is_first_microbatch is not None:
            warnings.warn(
                "`is_first_microbatch` is not supported for paddle backend and is ignored."
            )
        if self.parallel_mode == "column" and self.tensor_parallel:
            inp = identity(inp, self.tp_group)
        out = F.linear(inp, self.weight, self.bias if self.gemm_bias_fused_add else None)
        if self.parallel_mode == "row" and self.tensor_parallel:
            out, _ = allreduce(out, self.tp_group)
            out = out + self.bias if self.bias is not None else out
        return out

    def forward(self, *args, **kwargs):
        """
        Apply the linear transformation to the input.

        Parameters
        ----------
        inp : paddle.Tensor
             Input tensor.
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
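
        For example, with gradient accumulation over two microbatches
        (a sketch; `layer` and the `micro_batch_*` tensors are placeholder names):

        .. code-block:: python

            out = layer(micro_batch_0, is_first_microbatch=True)
            out = layer(micro_batch_1, is_first_microbatch=False)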
        """
        if self.backend == "transformer_engine":
            return self._te_forward(*args, **kwargs)
        if self.backend == "paddle":
            return self._pd_forward(*args, **kwargs)
        raise AttributeError(f"Backend {self.backend} is not supported.")