linear.py 27.7 KB
Newer Older
1
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
#
# See LICENSE for license information.
"""Linear API"""

6
import warnings
7
from typing import Union, Tuple, Dict, Any, Optional
8
9
10
11
12

import paddle
import paddle.nn.functional as F
from paddle.nn.initializer import Constant

13
14
15
16
17
18
19
from .base import (
    TransformerEngineBaseLayer,
    get_workspace,
    _2X_ACC_FPROP,
    _2X_ACC_DGRAD,
    _2X_ACC_WGRAD,
)
20

21
from ..constants import FP8FwdTensors, FP8BwdTensors, GemmParallelModes, dist_group_type
22
from ..cpp_extensions import gemm, fp8_gemm, cast_to_fp8, cast_transpose, transpose
23
from ..distributed import (
24
    allgather,
25
26
27
    allreduce,
    get_tp_group_and_world_size,
    identity,
28
    reduce_scatter,
29
30
31
    track_rng_state,
    set_tensor_dist_attr,
    set_weight_tensor_dist_attr,
32
    mark_as_sequence_parallel_parameter,
33
34
)
from ..fp8 import get_fp8_te_dtype
35
from ..utils import (
36
    assert_dim_for_fp8_forward_exec,
37
38
    cast_if_needed,
    cast_if_needed_inplace,
39
    divide,
40
    get_bias_dtype,
Tian Zheng's avatar
Tian Zheng committed
41
42
    save_for_backward_allow_none,
    saved_tensor_allow_none,
43
    clear_tensor_data,
44
45
)

46
__all__ = ["Linear"]
47
48
49
50
51
52


def _linear_fwd_fp8(
    inputmat: paddle.Tensor,
    inputmat_fp8_index: FP8FwdTensors,
    weight: paddle.Tensor,
    weight_fp8: Optional[paddle.Tensor],
    weight_t_fp8: Optional[paddle.Tensor],
    weight_fp8_index: FP8FwdTensors,
    bias: paddle.Tensor,
    use_bias: bool,
    fp8_meta: Dict[str, Any],
    activation_dtype: paddle.dtype,
    parallel_mode: Union[str, None],
    tensor_parallel: bool,
    sequence_parallel: bool,
    tp_group: Union[dist_group_type, None],
    is_grad_enabled: bool,
    is_first_microbatch: Optional[bool] = None,
):
    """FP8 path of Linear Fwd.

    Casts the weight to FP8 (updating the caller-provided `weight_fp8` /
    `weight_t_fp8` buffers when present), runs the forward GEMM in FP8, and
    applies the tensor-/sequence-parallel output collectives.

    Returns
    -------
    (out, weight_t_fp8) : the GEMM output in `activation_dtype`, and the
        transposed FP8 weight for the backward pass (None when
        `is_grad_enabled` is False).
    """
    fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
    bias_dtype = get_bias_dtype(activation_dtype)
    bias = cast_if_needed(bias, bias_dtype)

    # Sequence parallelism shards the input along the sequence dim; for a
    # column-parallel GEMM the full input must first be gathered across TP ranks.
    if parallel_mode == "column" and sequence_parallel:
        inputmat_total, _ = allgather(inputmat, tp_group)
    else:
        inputmat_total = inputmat

    # Re-cast the FP8 weights on the first microbatch of a minibatch (weights
    # are unchanged between microbatches, so later ones reuse the cached cast).
    update_fp8_weights = is_first_microbatch is None or is_first_microbatch
    if is_grad_enabled:
        if update_fp8_weights:
            # Fused cast+transpose: the transposed copy is needed for dgrad.
            weight_fp8, weight_t_fp8 = cast_transpose(
                weight,
                fp8_meta["scaling_fwd"],
                weight_fp8_index,
                fp8_dtype_forward,
                cast_out=weight_fp8,
                transpose_out=weight_t_fp8,
            )
    else:
        # Inference: no backward pass, so the transposed weight is not needed.
        weight_t_fp8 = None
        if update_fp8_weights:
            weight_fp8 = cast_to_fp8(
                weight,
                fp8_meta["scaling_fwd"],
                weight_fp8_index,
                fp8_dtype_forward,
                out=weight_fp8,
            )

    # out = inputmat_total @ weight^T (+ bias), computed in FP8 with output in
    # activation_dtype.
    out = fp8_gemm(
        weight_fp8,
        fp8_meta["scaling_fwd"].scale_inv,
        weight_fp8_index,
        fp8_dtype_forward,
        inputmat_total,
        fp8_meta["scaling_fwd"].scale_inv,
        inputmat_fp8_index,
        fp8_dtype_forward,
        activation_dtype,
        get_workspace(),
        bias=bias,
        use_bias=use_bias,
        use_split_accumulator=_2X_ACC_FPROP,
    )

    # Row-parallel output: reduce partial results across TP ranks
    # (reduce-scatter when sequence parallel, otherwise all-reduce).
    if parallel_mode == "row" and sequence_parallel:
        out, _ = reduce_scatter(out, tp_group)
    elif parallel_mode == "row" and tensor_parallel:
        out, _ = allreduce(out, tp_group)
    return out, weight_t_fp8


def _linear_fwd_non_fp8(
    inputmat: paddle.Tensor,
    inputmat_fp8_index: FP8FwdTensors,
    weight: paddle.Tensor,
    weight_fp8_index: FP8FwdTensors,
    bias: paddle.Tensor,
    use_bias: bool,
    fp8_calibration: bool,
    fp8_meta: Dict[str, Any],
    activation_dtype: paddle.dtype,
    parallel_mode: Union[str, None],
    tensor_parallel: bool,
    sequence_parallel: bool,
    tp_group: Union[dist_group_type, None],
    activation: str = "",
):
    """Non-FP8 path of Linear Fwd.

    Runs the forward GEMM in `activation_dtype`, optionally fusing GELU into
    the GEMM epilogue, and optionally recording amax statistics for FP8
    calibration. Returns `out`, or `(out, gelu_out)` when
    `activation == 'gelu'` (gelu_out is the pre-activation saved for bwd).
    """

    # Sequence parallelism shards the input along the sequence dim; for a
    # column-parallel GEMM the full input must first be gathered across TP ranks.
    if parallel_mode == "column" and sequence_parallel:
        inputmat_total, _ = allgather(inputmat, tp_group)
    else:
        inputmat_total = inputmat

    # Layer parameters are initialized as float32 dtype by default.
    # Cast the parameters to activation_dtype if the current dtype
    # does not match activation_dtype. The casting is inplace, so it
    # only needs to performed once throughout the training process.
    weight = cast_if_needed_inplace(weight, activation_dtype)
    bias = cast_if_needed_inplace(bias, activation_dtype)

    # Calibration mode: record absolute-max statistics so FP8 scales can be
    # derived later, while still computing in high precision.
    if fp8_calibration:
        # amax of input
        fp8_meta["scaling_fwd"].amax_history[0, inputmat_fp8_index.value] = \
            paddle.max(paddle.abs(inputmat_total)).item()
        # amax of weight
        fp8_meta["scaling_fwd"].amax_history[0, weight_fp8_index.value] = \
            paddle.max(paddle.abs(weight)).item()
        fp8_meta["update_amax_and_scale_fwd"] = True

    outputs = gemm(weight,
                   inputmat_total,
                   activation_dtype,
                   get_workspace(),
                   bias=bias,
                   use_bias=use_bias,
                   gelu=(activation == 'gelu'))

    if activation == 'gelu':
        # NOTE: with a fused GELU there is no output collective below — the
        # row-parallel reduction paths are only taken on the plain-GEMM return.
        gelu_out, _, out = outputs
        return out, gelu_out

    out, _, _ = outputs

    # Row-parallel output: reduce partial results across TP ranks
    # (reduce-scatter when sequence parallel, otherwise all-reduce).
    if parallel_mode == "row" and sequence_parallel:
        out, _ = reduce_scatter(out, tp_group)
    elif parallel_mode == "row" and tensor_parallel:
        out, _ = allreduce(out, tp_group)
    return out


def _linear_fwd(
    inputmat: paddle.Tensor,
    inputmat_fp8_index: FP8FwdTensors,
    weight: paddle.Tensor,
    weight_fp8: Optional[paddle.Tensor],
    weight_t_fp8: Optional[paddle.Tensor],
    weight_fp8_index: FP8FwdTensors,
    bias: paddle.Tensor,
    use_bias: bool,
    fp8_enabled: bool,
    fp8_calibration: bool,
    fp8_meta: Dict[str, Any],
    activation_dtype: paddle.dtype,
    parallel_mode: Union[str, None],
    tensor_parallel: bool,
    sequence_parallel: bool,
    tp_group: Union[dist_group_type, None],
    is_grad_enabled: bool,
    is_first_microbatch: bool = None,
):
    """Dispatch Linear forward to the FP8 or non-FP8 implementation.

    Returns a 2-tuple ``(out, weight_t_fp8)``; the second element is the
    transposed FP8 weight produced by the FP8 path, or None on the
    non-FP8 path.
    """
    if fp8_enabled:
        # FP8 path also yields the transposed FP8 weight for backward.
        out, weight_t_fp8 = _linear_fwd_fp8(
            inputmat,
            inputmat_fp8_index,
            weight,
            weight_fp8,
            weight_t_fp8,
            weight_fp8_index,
            bias,
            use_bias,
            fp8_meta,
            activation_dtype,
            parallel_mode,
            tensor_parallel,
            sequence_parallel,
            tp_group,
            is_grad_enabled,
            is_first_microbatch,
        )
        return out, weight_t_fp8

    # High-precision path (optionally recording FP8 calibration stats).
    out = _linear_fwd_non_fp8(
        inputmat,
        inputmat_fp8_index,
        weight,
        weight_fp8_index,
        bias,
        use_bias,
        fp8_calibration,
        fp8_meta,
        activation_dtype,
        parallel_mode,
        tensor_parallel,
        sequence_parallel,
        tp_group,
    )
    return out, None


def _linear_bwd_fp8(
    inputmat: paddle.Tensor,
    inputmat_t: paddle.Tensor,
    inputmat_fp8_index: FP8FwdTensors,
    weight_t_fp8: paddle.Tensor,
    weight_fp8_index: FP8FwdTensors,
    grad_output: paddle.Tensor,
    grad_output_c: paddle.Tensor,
    grad_output_t: paddle.Tensor,
    grad_output_fp8_index: FP8BwdTensors,
    fwd_scale_inverses: paddle.Tensor,
    fp8_meta: Dict[str, Any],
    requires_dgrad: bool,
    requires_wgrad: bool,
    activation_dtype: paddle.dtype,
    parallel_mode: Union[str, None],
    tensor_parallel: bool,
    sequence_parallel: bool,
    tp_group: Union[dist_group_type, None],
):
    """FP8 path of Linear Bwd.

    Computes dgrad (grad wrt input) and wgrad (grad wrt weight) with FP8
    GEMMs, overlapping TP communication with computation where possible.
    Returns ``(dgrad, wgrad)``; either may be None when not required.
    """
    dgrad, wgrad, handle = None, None, None

    # Overlap input AG with dgrad
    # (async all-gather of the sharded input is only needed for wgrad; it is
    # launched async when there is a dgrad GEMM to hide it behind).
    inputmat_total = None
    inputmat_t_total = None
    if requires_wgrad and parallel_mode == "column" and sequence_parallel:
        inputmat_total, handle = allgather(inputmat, tp_group, sync_op=not requires_dgrad)
    else:
        inputmat_total = inputmat
        inputmat_t_total = inputmat_t

    fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
    fp8_dtype_backward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=False)
    if requires_dgrad:
        # dgrad = grad_output @ weight, using the transposed FP8 weight saved
        # in the forward pass.
        dgrad = fp8_gemm(
            weight_t_fp8,
            fwd_scale_inverses,
            weight_fp8_index,
            fp8_dtype_forward,
            grad_output_c,
            fp8_meta["scaling_bwd"].scale_inv,
            grad_output_fp8_index,
            fp8_dtype_backward,
            activation_dtype,
            get_workspace(),
            use_split_accumulator=_2X_ACC_DGRAD,
        )
        # Free the FP8-cast grad_output copy as soon as it is consumed.
        clear_tensor_data(grad_output_c)

        # Overlap dgrad-RS/AR with wgrad
        if parallel_mode == "column" and sequence_parallel:
            # Input all-gather must complete before we reuse the comm stream.
            if handle is not None:
                handle.wait()
            dgrad, handle = reduce_scatter(dgrad, tp_group, sync_op=False)
        elif parallel_mode == "column" and tensor_parallel:
            dgrad, handle = allreduce(dgrad, tp_group, sync_op=False)

    if requires_wgrad:
        if not fp8_meta["recipe"].override_linear_precision.wgrad:
            # FP8 wgrad needs the transposed input; compute it lazily if the
            # forward pass did not save one (e.g. sequence-parallel case).
            if inputmat_t_total is None:
                inputmat_t_total = transpose(inputmat_total, fp8_dtype_backward)
                clear_tensor_data(inputmat_total)
            # wgrad = grad_output^T @ input, in FP8.
            wgrad = fp8_gemm(
                inputmat_t_total,
                fwd_scale_inverses,
                inputmat_fp8_index,
                fp8_dtype_forward,
                grad_output_t,
                fp8_meta["scaling_bwd"].scale_inv,
                grad_output_fp8_index,
                fp8_dtype_backward,
                activation_dtype,
                get_workspace(),
                use_split_accumulator=_2X_ACC_WGRAD,
            )
            clear_tensor_data(inputmat_t_total, grad_output_t)
        else:
            # Recipe overrides wgrad precision: compute it in activation_dtype.
            wgrad, _, _ = gemm(
                inputmat_total,
                grad_output,
                activation_dtype,
                get_workspace(),
                layout="NT",
                grad=True,
            )
            clear_tensor_data(inputmat_total)

    # Make sure the async dgrad all-reduce has finished before returning.
    if parallel_mode == "column" and tensor_parallel and handle is not None:
        handle.wait()

    return dgrad, wgrad


def _linear_bwd_non_fp8(
    inputmat: paddle.Tensor,
    weight: paddle.Tensor,
    grad_output: paddle.Tensor,
    requires_bgrad: bool,
    requires_dgrad: bool,
    requires_wgrad: bool,
    activation_dtype: paddle.dtype,
    parallel_mode: Union[str, None],
    tensor_parallel: bool,
    sequence_parallel: bool,
    tp_group: Union[dist_group_type, None],
    gelu_input: Union[paddle.Tensor, None] = None,
    activation: str = "",
):
    """
    Performs Linear Backward. Optionally, fuses GELU backward and dbias.

    Returns ``(dgrad, wgrad, bgrad)``; each may be None when not required.
    """
    dgrad, wgrad, bgrad, handle = None, None, None, None

    # Overlap input AG with dgrad
    # (async all-gather of the sharded input is only needed for wgrad; it is
    # launched async when there is a dgrad GEMM to hide it behind).
    inputmat_total = None
    if requires_wgrad and parallel_mode == "column" and sequence_parallel:
        inputmat_total, handle = allgather(inputmat, tp_group, sync_op=not requires_dgrad)
    else:
        inputmat_total = inputmat

    if requires_dgrad:
        # dgrad = grad_output @ weight; GELU backward is fused into this GEMM
        # when the forward pass used a fused GELU (gelu_input is its input).
        dgrad, _, _ = gemm(
            weight,
            grad_output,
            activation_dtype,
            get_workspace(),
            layout="NN",
            gelu=(activation == 'gelu'),
            gelu_input=gelu_input,
            grad=True,
        )
        # Overlap dgrad-RS/AR with wgrad
        if parallel_mode == "column" and sequence_parallel:
            # Input all-gather must complete before we reuse the comm stream.
            if handle is not None:
                handle.wait()
            dgrad, handle = reduce_scatter(dgrad, tp_group, sync_op=False)
        elif parallel_mode == "column" and tensor_parallel:
            dgrad, handle = allreduce(dgrad, tp_group, sync_op=False)

    if requires_wgrad:
        # wgrad = grad_output^T @ input; bgrad is fused into the same GEMM.
        wgrad, bgrad, _ = gemm(
            inputmat_total,
            grad_output,
            activation_dtype,
            get_workspace(),
            layout="NT",
            grad=True,
            use_bias=requires_bgrad,
        )
    elif requires_bgrad:
        # No wgrad GEMM to fuse into: reduce grad_output directly.
        bgrad = grad_output.sum(axis=0)

    # Make sure the async dgrad all-reduce has finished before returning.
    if parallel_mode == "column" and tensor_parallel and handle is not None:
        handle.wait()

    return dgrad, wgrad, bgrad


def _linear_bwd(
    inputmat: paddle.Tensor,
    inputmat_t: paddle.Tensor,
    inputmat_fp8_index: FP8FwdTensors,
    weight: paddle.Tensor,
    weight_t_fp8: paddle.Tensor,
    weight_fp8_index: FP8FwdTensors,
    grad_output: paddle.Tensor,
    grad_output_c: paddle.Tensor,
    grad_output_t: paddle.Tensor,
    grad_output_fp8_index: FP8BwdTensors,
    fwd_scale_inverses: paddle.Tensor,
    requires_bgrad: bool,
    fp8_enabled: bool,
    fp8_meta: Dict[str, Any],
    requires_dgrad: bool,
    requires_wgrad: bool,
    activation_dtype: paddle.dtype,
    parallel_mode: Union[str, None],
    tensor_parallel: bool,
    sequence_parallel: bool,
    tp_group: Union[dist_group_type, None],
):
    """Dispatch Linear backward to the FP8 or non-FP8 implementation.

    Returns ``(dgrad, wgrad, bgrad)``. On the FP8 path bgrad is always
    None (it is produced earlier, during grad-output preprocessing).
    """
    if fp8_enabled:
        # FP8 backward does not compute bgrad here.
        dgrad, wgrad = _linear_bwd_fp8(
            inputmat,
            inputmat_t,
            inputmat_fp8_index,
            weight_t_fp8,
            weight_fp8_index,
            grad_output,
            grad_output_c,
            grad_output_t,
            grad_output_fp8_index,
            fwd_scale_inverses,
            fp8_meta,
            requires_dgrad,
            requires_wgrad,
            activation_dtype,
            parallel_mode,
            tensor_parallel,
            sequence_parallel,
            tp_group,
        )
        return dgrad, wgrad, None

    # Non-FP8 backward fuses bgrad with the wgrad GEMM.
    dgrad, wgrad, bgrad = _linear_bwd_non_fp8(
        inputmat,
        weight,
        grad_output,
        requires_bgrad,
        requires_dgrad,
        requires_wgrad,
        activation_dtype,
        parallel_mode,
        tensor_parallel,
        sequence_parallel,
        tp_group,
    )
    return dgrad, wgrad, bgrad
462
463
464


class _Linear(paddle.autograd.PyLayer):
    """TE implementation of Linear"""

    @staticmethod
    def forward(
        ctx,
        weight: paddle.Tensor,
        weight_fp8: Optional[paddle.Tensor],
        weight_t_fp8: Optional[paddle.Tensor],
        inp: paddle.Tensor,
        bias: paddle.Tensor,
        use_bias: bool,
        fp8_enabled: bool,
        fp8_calibration: bool,
        fp8_meta: Dict[str, Any],
        activation_dtype: paddle.dtype,
        is_grad_enabled: bool,
        parallel_mode: Union[str, None],
        tensor_parallel: bool,
        sequence_parallel: bool,
        tp_group: Union[dist_group_type, None],
        tp_size: int,
        is_first_microbatch: Optional[bool],
    ) -> paddle.Tensor:
        """Forward: optionally FP8-cast the input, run the GEMM, and stash
        everything the backward pass needs on ``ctx``."""
        # Make sure input dimensions are compatible
        in_features = weight.shape[-1]
        assert inp.shape[-1] == in_features, "GEMM not possible"
        # Flatten all leading dims so the GEMM sees a 2-D [tokens, features] input.
        inputmat = inp.reshape((-1, in_features))
        if fp8_enabled:
            assert_dim_for_fp8_forward_exec(inputmat)
            assert_dim_for_fp8_forward_exec(weight)

        # Keep a reference to the pre-cast input; backward may need it in
        # high precision (see saved_inputmat selection below).
        inputmat_no_fp8 = inputmat

        # FP8 casting
        inputmat_t = None
        if fp8_enabled:
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
            if (not fp8_meta["recipe"].override_linear_precision.wgrad and is_grad_enabled
                    and not sequence_parallel):
                # Fused cast+transpose: the transposed input feeds the FP8
                # wgrad GEMM in backward.
                inputmat, inputmat_t = cast_transpose(
                    inputmat,
                    fp8_meta["scaling_fwd"],
                    FP8FwdTensors.GEMM1_INPUT,
                    fp8_dtype_forward,
                )
            else:
                # Plain cast; the transpose (if needed) is computed in backward
                # after the all-gather.
                inputmat = cast_to_fp8(
                    inputmat,
                    fp8_meta["scaling_fwd"],
                    FP8FwdTensors.GEMM1_INPUT,
                    fp8_dtype_forward,
                )

        # GEMM Fwd
        out, weight_t_fp8 = _linear_fwd(
            inputmat,
            FP8FwdTensors.GEMM1_INPUT,
            weight,
            weight_fp8,
            weight_t_fp8,
            FP8FwdTensors.GEMM1_WEIGHT,
            bias,
            use_bias,
            fp8_enabled,
            fp8_calibration,
            fp8_meta,
            activation_dtype,
            parallel_mode,
            tensor_parallel,
            sequence_parallel,
            tp_group,
            is_grad_enabled,
            is_first_microbatch,
        )

        if is_grad_enabled:
            # Save the FP8-cast input for the sequence-parallel FP8 case
            # (backward re-gathers it), otherwise the original-precision input.
            saved_inputmat = None
            if fp8_enabled and sequence_parallel:
                saved_inputmat = inputmat
            else:
                saved_inputmat = inputmat_no_fp8
            # NOTE: order here must match the unpacking in backward().
            save_for_backward_allow_none(
                ctx,
                saved_inputmat,
                inputmat_t,
                weight,
                weight_t_fp8 if fp8_enabled else None,
                # Clone: scale_inv is mutated by later amax/scale updates.
                fp8_meta["scaling_fwd"].scale_inv.clone() if fp8_enabled else None,
            )
            ctx.activation_dtype = activation_dtype
            ctx.fp8_enabled = fp8_enabled
            ctx.fp8_meta = fp8_meta
            ctx.use_bias = use_bias
            ctx.inp_shape = inp.shape
            ctx.parallel_mode = parallel_mode
            ctx.tensor_parallel = tensor_parallel
            ctx.sequence_parallel = sequence_parallel
            ctx.tp_group = tp_group
            ctx.tp_size = tp_size
            ctx.requires_dgrad = not inp.stop_gradient
            ctx.requires_wgrad = not weight.stop_gradient
            ctx.requires_bgrad = use_bias and not bias.stop_gradient
            ctx.is_first_microbatch = is_first_microbatch

        # Restore the original leading dims of the input on the output.
        return out.reshape((-1, *inp.shape[1:-1], out.shape[-1]))

    @staticmethod
    def backward(ctx, grad_output: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]:
        """Backward: compute dgrad/wgrad/bgrad and return grads matching the
        forward input signature (see the tuple-arity notes below)."""
        with TransformerEngineBaseLayer.prepare_backward(ctx.fp8_enabled,
                                                         ctx.fp8_meta,
                                                         ctx.tp_group,
                                                         ctx.tp_size,
                                                         name="_Linear"):
            # Unpack in the exact order used by save_for_backward_allow_none
            # in forward().
            (    # pylint: disable=unbalanced-tuple-unpacking
                inputmat,
                inputmat_t,
                weight,
                weight_t_fp8,
                fwd_scale_inverses,
            ) = saved_tensor_allow_none(ctx)

            # Shared preprocessing: FP8 cast/transpose of grad_output and, on
            # the FP8 path, the (possibly fused) bgrad.
            (
                grad_output,
                grad_output_c,
                grad_output_t,
                bgrad,
            ) = TransformerEngineBaseLayer.grad_output_preprocess(ctx, grad_output,
                                                                  ctx.parallel_mode == "row")

            dgrad, wgrad, bgrad_ = _linear_bwd(
                inputmat,
                inputmat_t,
                FP8FwdTensors.GEMM1_INPUT,
                weight,
                weight_t_fp8,
                FP8FwdTensors.GEMM1_WEIGHT,
                grad_output,
                grad_output_c,
                grad_output_t,
                FP8BwdTensors.GRAD_OUTPUT1,
                fwd_scale_inverses,
                ctx.requires_bgrad,
                ctx.fp8_enabled,
                ctx.fp8_meta,
                ctx.requires_dgrad,
                ctx.requires_wgrad,
                ctx.activation_dtype,
                ctx.parallel_mode,
                ctx.tensor_parallel,
                ctx.sequence_parallel,
                ctx.tp_group,
            )

            if not ctx.fp8_enabled:
                # bgrad is fused with gemm for non-FP8 path
                bgrad = bgrad_

            # When FP8 weight caching is active (is_first_microbatch set),
            # forward received weight_fp8/weight_t_fp8 as extra inputs, so
            # backward must return matching (None) grad slots for them.
            if not ctx.fp8_enabled or ctx.is_first_microbatch is None:
                weight_cache_grad = ()
            else:
                # weight_fp8 and weight_t_fp8 are stop_gradient tensors
                weight_cache_grad = (None, None)

            # Grad tuple order mirrors forward's tensor inputs:
            # weight, [weight_fp8, weight_t_fp8,] inp, [bias].
            if not ctx.use_bias:
                return (
                    wgrad if ctx.requires_wgrad else None,
                    *weight_cache_grad,
                    dgrad.reshape(ctx.inp_shape) if ctx.requires_dgrad else None,
                )

            return (
                wgrad if ctx.requires_wgrad else None,
                *weight_cache_grad,
                dgrad.reshape(ctx.inp_shape) if ctx.requires_dgrad else None,
                bgrad if ctx.requires_bgrad else None,
            )


class Linear(TransformerEngineBaseLayer):
    """
    Applies a linear transformation to the incoming data :math:`y = xA^T + b`

    Parameters
    ----------
    in_features : int
                 size of each input sample.
    out_features : int
                  size of each output sample.
    weight_attr: Union[paddle.ParamAttr, None], default = None
                optional `paddle.ParamAttr` for weight.
    bias_attr: Union[paddle.ParamAttr, None, bool], default = None
              optional `paddle.ParamAttr` for bias.
    backend: {'transformer_engine', 'paddle'}, default = 'transformer_engine'
             if set to 'paddle', a framework only no-FP8 path is executed with limited optimization.

    Parallelism parameters
    ----------------------
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    parallel_mode : {None, 'column', 'row'}, default = `None`
                   used to decide whether this Linear layer is Column Parallel Linear or Row
                   Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
                   When set to `None`, no communication is performed.
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        weight_attr: Union[paddle.ParamAttr, None] = None,
        bias_attr: Union[paddle.ParamAttr, None, bool] = None,
        parallel_mode: Optional[str] = None,
        sequence_parallel: bool = False,
        tp_group: Union[dist_group_type, None] = None,
        backend: str = 'transformer_engine',
    ) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.backend = backend
        self._weight_attr = weight_attr
        self._bias_attr = bias_attr
        self._dtype = self._helper.get_default_dtype()

        # Set parallel configs
        self.tp_group, self.tp_size = get_tp_group_and_world_size(tp_group,
                                                                  enable_tp=parallel_mode
                                                                  is not None)
        self.tensor_parallel = self.tp_size > 1
        self.parallel_mode = parallel_mode
        assert (self.parallel_mode
                in GemmParallelModes), f"parallel_mode {parallel_mode} not supported"

        # Shard the appropriate dimension across TP ranks: column parallel
        # splits the output dim, row parallel splits the input dim.
        if self.parallel_mode == "column":
            self.out_features = divide(self.out_features, self.tp_size)
        elif self.parallel_mode == "row":
            self.in_features = divide(self.in_features, self.tp_size)

        # Sequence parallelism is only meaningful together with tensor parallelism.
        self.sequence_parallel = self.tensor_parallel and sequence_parallel

        # Initialize weight parameter
        with track_rng_state(enable=self.tensor_parallel):
            # TE linear weight is in column major
            self.weight = self.create_parameter(
                shape=[self.out_features, self.in_features]
                if self.backend == 'transformer_engine' else [self.in_features, self.out_features],
                attr=self._weight_attr,
                dtype=self._dtype,
                is_bias=False,
            )
        set_weight_tensor_dist_attr(self.weight, self.tensor_parallel, self.parallel_mode,
                                    self.backend)

        # Initialize bias parameter
        # bias_attr=False disables the bias; None/True requests a default
        # zero-initialized bias.
        self.has_bias = self._bias_attr is not False
        use_default_bias = self._bias_attr is None or self._bias_attr is True
        if self.has_bias:
            self.bias = self.create_parameter(
                shape=[self.out_features],
                attr=self._bias_attr if not use_default_bias else paddle.ParamAttr(
                    initializer=Constant(value=0.0)),
                dtype=self._dtype,
                is_bias=True,
            )
            # Column parallel: the bias is sharded along with the output dim.
            if parallel_mode == "column":
                set_tensor_dist_attr(self.bias, self.tensor_parallel, axis=0)
            if parallel_mode == "row" and self.sequence_parallel:
                mark_as_sequence_parallel_parameter(self.bias)
        else:
            self.bias = None

        # Record the weight shape so base-layer FP8 weight buffers can be allocated.
        self.fp8_weight_shapes.append(self.weight.shape)

        # For RPL, bias has to be added after TP collectives
        # So it cannot be fused with the GEMM
        if self.parallel_mode == "row" and self.tensor_parallel and self.has_bias:
            self.gemm_bias_fused_add = False
        else:
            self.gemm_bias_fused_add = True

    def _te_forward(
        self,
        inp: paddle.Tensor,
        is_first_microbatch: Optional[bool] = None,
    ) -> paddle.Tensor:
        """
        Apply the linear transformation to the input.
        """
        with self.prepare_forward(inp, is_first_microbatch=is_first_microbatch) as inp:
            # Layer input should be casted outside PyLayer, as performing
            # inplace cast to input tensors may cause problems when used
            # together with Paddle native layers.
            inp = cast_if_needed(inp, self.activation_dtype)

            # Get persistent fp8 weight buffer. None if buffer does not exist.
            weight_fp8, weight_t_fp8 = self.get_fp8_weights_scratchpad(is_first_microbatch)

            out = _Linear.apply(
                self.weight,
                weight_fp8,
                weight_t_fp8,
                inp,
                self.bias if self.gemm_bias_fused_add else None,
                self.has_bias and self.gemm_bias_fused_add,
                self.fp8_enabled,
                self.fp8_calibration,
                self.fp8_meta,
                self.activation_dtype,
                paddle.is_grad_enabled(),
                self.parallel_mode,
                self.tensor_parallel,
                self.sequence_parallel,
                self.tp_group,
                self.tp_size,
                is_first_microbatch,
            )

        # Row-parallel + bias: add the bias after the TP reduction (it was
        # excluded from the GEMM above).
        if not self.gemm_bias_fused_add:
            out = out + cast_if_needed_inplace(self.bias, self.activation_dtype)

        return out

    def _pd_forward(
        self,
        inp: paddle.Tensor,
        is_first_microbatch: Optional[bool] = None,
    ) -> paddle.Tensor:
        """Calls Paddle OP"""
        if is_first_microbatch is not None:
            warnings.warn(
                "`is_first_microbatch` is not supported for paddle backend and is ignored.")
        if self.parallel_mode == 'column' and self.tensor_parallel:
            # identity in forward; presumably its backward performs the TP
            # gradient reduction — see ..distributed.identity. TODO confirm.
            inp = identity(inp, self.tp_group)
        out = F.linear(inp, self.weight, self.bias if self.gemm_bias_fused_add else None)
        if self.parallel_mode == 'row' and self.tensor_parallel:
            out, _ = allreduce(out, self.tp_group)
            # Bias is added after the all-reduce for row-parallel layers.
            out = out + self.bias if self.bias is not None else out
        return out

    def forward(self, *args, **kwargs):
        """
        Apply the linear transformation to the input.

        Parameters
        ----------
        inp : paddle.Tensor
             Input tensor.
        is_first_microbatch : {True, False, None}, default = None
                             During training using either gradient accumulation or
                             pipeline parallelism a minibatch of data is further split
                             into microbatches. Between the microbatches of the same minibatch
                             the model weights are not updated. Setting this parameter indicates
                             whether the current microbatch is the first in a minibatch or not.
                             When set, this parameter enables additional optimizations:

                             * during FP8 training, it allows caching of the FP8 versions of
                               the weights
        """
        if self.backend == 'transformer_engine':
            return self._te_forward(*args, **kwargs)
        if self.backend == 'paddle':
            return self._pd_forward(*args, **kwargs)
        raise AttributeError(f"Backend {self.backend} is not supported.")