# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Linear API"""

from typing import Union, Tuple, Dict, Any, Optional

import paddle
import paddle.nn.functional as F
from paddle.nn.initializer import Constant

from .base import (
    TransformerEngineBaseLayer,
    get_workspace,
    _2X_ACC_FPROP,
    _2X_ACC_DGRAD,
    _2X_ACC_WGRAD,
)
from ..constants import FP8FwdTensors, FP8BwdTensors, GemmParallelModes, dist_group_type
from ..cpp_extensions import gemm, fp8_gemm, cast_to_fp8, cast_transpose, transpose
from ..distributed import (
    allgather,
    allreduce,
    get_tp_group_and_world_size,
    identity,
    reduce_scatter,
    track_rng_state,
    set_tensor_dist_attr,
    set_weight_tensor_dist_attr,
    mark_as_sequence_parallel_parameter,
)
from ..fp8 import get_fp8_te_dtype
from ..utils import (
    assert_dim_for_fp8_forward_exec,
    cast_if_needed,
    cast_if_needed_inplace,
    divide,
    get_bias_dtype,
    save_for_backward_allow_none,
    saved_tensor_allow_none,
)

__all__ = ["Linear"]


def _linear_fwd_fp8(
    inputmat: paddle.Tensor,
    inputmat_fp8_index: FP8FwdTensors,
    weight: paddle.Tensor,
    weight_fp8_index: FP8FwdTensors,
    bias: paddle.Tensor,
    use_bias: bool,
    fp8_meta: Dict[str, Any],
    activation_dtype: paddle.dtype,
    parallel_mode: Union[str, None],
    tensor_parallel: bool,
    sequence_parallel: bool,
    tp_group: Union[dist_group_type, None],
    is_grad_enabled: bool,
):
    """FP8 path of Linear Fwd"""
    fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
    bias_dtype = get_bias_dtype(activation_dtype)
    bias = cast_if_needed(bias, bias_dtype)

    if parallel_mode == "column" and sequence_parallel:
        inputmat_total, _ = allgather(inputmat, tp_group)
    else:
        inputmat_total = inputmat

    if is_grad_enabled:
        weight_fp8, weight_t_fp8 = cast_transpose(
            weight,
            fp8_meta["scaling_fwd"],
            weight_fp8_index,
            fp8_dtype_forward,
        )
    else:
        weight_t_fp8 = None
        weight_fp8 = cast_to_fp8(
            weight,
            fp8_meta["scaling_fwd"],
            weight_fp8_index,
            fp8_dtype_forward,
        )

    out = fp8_gemm(
        weight_fp8,
        fp8_meta["scaling_fwd"].scale_inv,
        weight_fp8_index,
        fp8_dtype_forward,
        inputmat_total,
        fp8_meta["scaling_fwd"].scale_inv,
        inputmat_fp8_index,
        fp8_dtype_forward,
        activation_dtype,
        get_workspace(),
        bias=bias,
        use_bias=use_bias,
        use_split_accumulator=_2X_ACC_FPROP,
    )

    if parallel_mode == "row" and sequence_parallel:
        out, _ = reduce_scatter(out, tp_group)
    elif parallel_mode == "row" and tensor_parallel:
        out, _ = allreduce(out, tp_group)

    return out, weight_t_fp8


def _linear_fwd_non_fp8(
    inputmat: paddle.Tensor,
    inputmat_fp8_index: FP8FwdTensors,
    weight: paddle.Tensor,
    weight_fp8_index: FP8FwdTensors,
    bias: paddle.Tensor,
    use_bias: bool,
    fp8_calibration: bool,
    fp8_meta: Dict[str, Any],
    activation_dtype: paddle.dtype,
    parallel_mode: Union[str, None],
    tensor_parallel: bool,
    sequence_parallel: bool,
    tp_group: Union[dist_group_type, None],
    activation: str = "",
):
    """Non-FP8 path of Linear Fwd"""

    if parallel_mode == "column" and sequence_parallel:
        inputmat_total, _ = allgather(inputmat, tp_group)
    else:
        inputmat_total = inputmat

    # Layer parameters are initialized as float32 dtype by default.
    # Cast the parameters to activation_dtype if the current dtype
    # does not match activation_dtype. The casting is done in place, so it
    # only needs to be performed once throughout the training process.
    weight = cast_if_needed_inplace(weight, activation_dtype)
    bias = cast_if_needed_inplace(bias, activation_dtype)

    if fp8_calibration:
        # amax of input
        fp8_meta["scaling_fwd"].amax_history[0, inputmat_fp8_index.value] = \
            paddle.max(paddle.abs(inputmat_total)).item()
        # amax of weight
        fp8_meta["scaling_fwd"].amax_history[0, weight_fp8_index.value] = \
            paddle.max(paddle.abs(weight)).item()

    outputs = gemm(weight,
                   inputmat_total,
                   activation_dtype,
                   get_workspace(),
                   bias=bias,
                   use_bias=use_bias,
                   gelu=(activation == 'gelu'))

    if activation == 'gelu':
        gelu_out, _, out = outputs
        return out, gelu_out

    out, _, _ = outputs

    if parallel_mode == "row" and sequence_parallel:
        out, _ = reduce_scatter(out, tp_group)
    elif parallel_mode == "row" and tensor_parallel:
        out, _ = allreduce(out, tp_group)
    return out


def _linear_fwd(
    inputmat: paddle.Tensor,
    inputmat_fp8_index: FP8FwdTensors,
    weight: paddle.Tensor,
    weight_fp8_index: FP8FwdTensors,
    bias: paddle.Tensor,
    use_bias: bool,
    fp8_enabled: bool,
    fp8_calibration: bool,
    fp8_meta: Dict[str, Any],
    activation_dtype: paddle.dtype,
    parallel_mode: Union[str, None],
    tensor_parallel: bool,
    sequence_parallel: bool,
    tp_group: Union[dist_group_type, None],
    is_grad_enabled: bool,
):
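    """Linear Fwd, dispatching to the FP8 or non-FP8 path"""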
    if fp8_enabled:
        out, weight_t_fp8 = _linear_fwd_fp8(
            inputmat,
            inputmat_fp8_index,
            weight,
            weight_fp8_index,
            bias,
            use_bias,
            fp8_meta,
            activation_dtype,
            parallel_mode,
            tensor_parallel,
            sequence_parallel,
            tp_group,
            is_grad_enabled,
        )
    else:
        out = _linear_fwd_non_fp8(
            inputmat,
            inputmat_fp8_index,
            weight,
            weight_fp8_index,
            bias,
            use_bias,
            fp8_calibration,
            fp8_meta,
            activation_dtype,
            parallel_mode,
            tensor_parallel,
            sequence_parallel,
            tp_group,
        )
    return (
        out,
        weight_t_fp8 if fp8_enabled else None,
    )


def _linear_bwd_fp8(
    inputmat: paddle.Tensor,
    inputmat_t: paddle.Tensor,
    inputmat_fp8_index: FP8FwdTensors,
    weight_t_fp8: paddle.Tensor,
    weight_fp8_index: FP8FwdTensors,
    grad_output: paddle.Tensor,
    grad_output_c: paddle.Tensor,
    grad_output_t: paddle.Tensor,
    grad_output_fp8_index: FP8BwdTensors,
    fwd_scale_inverses: paddle.Tensor,
    fp8_meta: Dict[str, Any],
    requires_dgrad: bool,
    requires_wgrad: bool,
    activation_dtype: paddle.dtype,
    parallel_mode: Union[str, None],
    tensor_parallel: bool,
    sequence_parallel: bool,
    tp_group: Union[dist_group_type, None],
):
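    """FP8 path of Linear Bwd"""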
    dgrad, wgrad, handle = None, None, None

    # Overlap input AG with dgrad
    inputmat_total = None
    inputmat_t_total = None
    if requires_wgrad and parallel_mode == "column" and sequence_parallel:
        inputmat_total, handle = allgather(inputmat, tp_group, sync_op=not requires_dgrad)
    else:
        inputmat_total = inputmat
        inputmat_t_total = inputmat_t

    fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
    fp8_dtype_backward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=False)
    if requires_dgrad:
        dgrad = fp8_gemm(
            weight_t_fp8,
            fwd_scale_inverses,
            weight_fp8_index,
            fp8_dtype_forward,
            grad_output_c,
            fp8_meta["scaling_bwd"].scale_inv,
            grad_output_fp8_index,
            fp8_dtype_backward,
            activation_dtype,
            get_workspace(),
            use_split_accumulator=_2X_ACC_DGRAD,
        )

        # Overlap dgrad-RS/AR with wgrad
        if parallel_mode == "column" and sequence_parallel:
            if handle is not None:
                handle.wait()
            dgrad, handle = reduce_scatter(dgrad, tp_group, sync_op=False)
        elif parallel_mode == "column" and tensor_parallel:
            dgrad, handle = allreduce(dgrad, tp_group, sync_op=False)

    if requires_wgrad:
        if not fp8_meta["recipe"].override_linear_precision.wgrad:
            if inputmat_t_total is None:
                inputmat_t_total = transpose(inputmat_total, fp8_dtype_backward)
            wgrad = fp8_gemm(
                inputmat_t_total,
                fwd_scale_inverses,
                inputmat_fp8_index,
                fp8_dtype_forward,
                grad_output_t,
                fp8_meta["scaling_bwd"].scale_inv,
                grad_output_fp8_index,
                fp8_dtype_backward,
                activation_dtype,
                get_workspace(),
                use_split_accumulator=_2X_ACC_WGRAD,
            )
        else:
            wgrad, _, _ = gemm(
                inputmat_total,
                grad_output,
                activation_dtype,
                get_workspace(),
                layout="NT",
                grad=True,
            )

    if parallel_mode == "column" and tensor_parallel and handle is not None:
        handle.wait()

    return dgrad, wgrad


def _linear_bwd_non_fp8(
    inputmat: paddle.Tensor,
    weight: paddle.Tensor,
    grad_output: paddle.Tensor,
    requires_bgrad: bool,
    requires_dgrad: bool,
    requires_wgrad: bool,
    activation_dtype: paddle.dtype,
    parallel_mode: Union[str, None],
    tensor_parallel: bool,
    sequence_parallel: bool,
    tp_group: Union[dist_group_type, None],
    gelu_input: Union[paddle.Tensor, None] = None,
    activation: str = "",
):
    """
    Performs Linear Backward. Optionally, fuses GELU backward and dbias.
    """
    dgrad, wgrad, bgrad, handle = None, None, None, None

    # Overlap input AG with dgrad
    inputmat_total = None
    if requires_wgrad and parallel_mode == "column" and sequence_parallel:
        inputmat_total, handle = allgather(inputmat, tp_group, sync_op=not requires_dgrad)
    else:
        inputmat_total = inputmat

    if requires_dgrad:
        dgrad, _, _ = gemm(
            weight,
            grad_output,
            activation_dtype,
            get_workspace(),
            layout="NN",
            gelu=(activation == 'gelu'),
            gelu_input=gelu_input,
            grad=True,
        )
        # Overlap dgrad-RS/AR with wgrad
        if parallel_mode == "column" and sequence_parallel:
            if handle is not None:
                handle.wait()
            dgrad, handle = reduce_scatter(dgrad, tp_group, sync_op=False)
        elif parallel_mode == "column" and tensor_parallel:
            dgrad, handle = allreduce(dgrad, tp_group, sync_op=False)

    if requires_wgrad:
        wgrad, bgrad, _ = gemm(
            inputmat_total,
            grad_output,
            activation_dtype,
            get_workspace(),
            layout="NT",
            grad=True,
            use_bias=requires_bgrad,
        )
    elif requires_bgrad:
        bgrad = grad_output.sum(axis=0)

    if parallel_mode == "column" and tensor_parallel and handle is not None:
        handle.wait()

    return dgrad, wgrad, bgrad


def _linear_bwd(
    inputmat: paddle.Tensor,
    inputmat_t: paddle.Tensor,
    inputmat_fp8_index: FP8FwdTensors,
    weight: paddle.Tensor,
    weight_t_fp8: paddle.Tensor,
    weight_fp8_index: FP8FwdTensors,
    grad_output: paddle.Tensor,
    grad_output_c: paddle.Tensor,
    grad_output_t: paddle.Tensor,
    grad_output_fp8_index: FP8BwdTensors,
    fwd_scale_inverses: paddle.Tensor,
    requires_bgrad: bool,
    fp8_enabled: bool,
    fp8_meta: Dict[str, Any],
    requires_dgrad: bool,
    requires_wgrad: bool,
    activation_dtype: paddle.dtype,
    parallel_mode: Union[str, None],
    tensor_parallel: bool,
    sequence_parallel: bool,
    tp_group: Union[dist_group_type, None],
):
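    """Linear Bwd, dispatching to the FP8 or non-FP8 path"""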
    dgrad, wgrad, bgrad = None, None, None
    if fp8_enabled:
        dgrad, wgrad = _linear_bwd_fp8(
            inputmat,
            inputmat_t,
            inputmat_fp8_index,
            weight_t_fp8,
            weight_fp8_index,
            grad_output,
            grad_output_c,
            grad_output_t,
            grad_output_fp8_index,
            fwd_scale_inverses,
            fp8_meta,
            requires_dgrad,
            requires_wgrad,
            activation_dtype,
            parallel_mode,
            tensor_parallel,
            sequence_parallel,
            tp_group,
        )
    else:
        dgrad, wgrad, bgrad = _linear_bwd_non_fp8(
            inputmat,
            weight,
            grad_output,
            requires_bgrad,
            requires_dgrad,
            requires_wgrad,
            activation_dtype,
            parallel_mode,
            tensor_parallel,
            sequence_parallel,
            tp_group,
        )
    return dgrad, wgrad, bgrad


class _Linear(paddle.autograd.PyLayer):
    """TE implementation of Linear"""

    @staticmethod
    def forward(
        ctx,
        weight: paddle.Tensor,
        inp: paddle.Tensor,
        bias: paddle.Tensor,
        use_bias: bool,
        fp8_enabled: bool,
        fp8_calibration: bool,
        fp8_meta: Dict[str, Any],
        activation_dtype: paddle.dtype,
        is_grad_enabled: bool,
        parallel_mode: Union[str, None],
        tensor_parallel: bool,
        sequence_parallel: bool,
        tp_group: Union[dist_group_type, None],
        tp_size: int,
    ) -> paddle.Tensor:
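        """Forward pass of the TE Linear PyLayer"""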
        # Make sure input dimensions are compatible
        in_features = weight.shape[-1]
        assert inp.shape[-1] == in_features, "GEMM not possible"
        inputmat = inp.reshape((-1, in_features))
        if fp8_enabled:
            assert_dim_for_fp8_forward_exec(inputmat)
            assert_dim_for_fp8_forward_exec(weight)

        inputmat_no_fp8 = inputmat

        # FP8 casting
        inputmat_t = None
        if fp8_enabled:
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
            if (not fp8_meta["recipe"].override_linear_precision.wgrad and is_grad_enabled
                    and not sequence_parallel):
                inputmat, inputmat_t = cast_transpose(
                    inputmat,
                    fp8_meta["scaling_fwd"],
                    FP8FwdTensors.GEMM1_INPUT,
                    fp8_dtype_forward,
                )
            else:
                inputmat = cast_to_fp8(
                    inputmat,
                    fp8_meta["scaling_fwd"],
                    FP8FwdTensors.GEMM1_INPUT,
                    fp8_dtype_forward,
                )

        # GEMM Fwd
        out, weight_t_fp8 = _linear_fwd(
            inputmat,
            FP8FwdTensors.GEMM1_INPUT,
            weight,
            FP8FwdTensors.GEMM1_WEIGHT,
            bias,
            use_bias,
            fp8_enabled,
            fp8_calibration,
            fp8_meta,
            activation_dtype,
            parallel_mode,
            tensor_parallel,
            sequence_parallel,
            tp_group,
            is_grad_enabled,
        )

        if is_grad_enabled:
            saved_inputmat = None
            if fp8_enabled and sequence_parallel:
                saved_inputmat = inputmat
            else:
                saved_inputmat = inputmat_no_fp8
            save_for_backward_allow_none(
                ctx,
                saved_inputmat,
                inputmat_t,
                weight,
                weight_t_fp8 if fp8_enabled else None,
                fp8_meta["scaling_fwd"].scale_inv.clone() if fp8_enabled else None,
            )
            ctx.activation_dtype = activation_dtype
            ctx.fp8_enabled = fp8_enabled
            ctx.fp8_meta = fp8_meta
            ctx.use_bias = use_bias
            ctx.inp_shape = inp.shape
            ctx.parallel_mode = parallel_mode
            ctx.tensor_parallel = tensor_parallel
            ctx.sequence_parallel = sequence_parallel
            ctx.tp_group = tp_group
            ctx.tp_size = tp_size
            ctx.requires_dgrad = not inp.stop_gradient
            ctx.requires_wgrad = not weight.stop_gradient
            ctx.requires_bgrad = use_bias and not bias.stop_gradient

        return out.reshape((-1, *inp.shape[1:-1], out.shape[-1]))

    @staticmethod
    def backward(ctx, grad_output: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]:
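        """Backward pass of the TE Linear PyLayer"""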
        with TransformerEngineBaseLayer.prepare_backward(ctx.fp8_enabled,
                                                         ctx.fp8_meta,
                                                         ctx.tp_group,
                                                         ctx.tp_size,
                                                         name="_Linear"):

            (    # pylint: disable=unbalanced-tuple-unpacking
                inputmat,
                inputmat_t,
                weight,
                weight_t_fp8,
                fwd_scale_inverses,
            ) = saved_tensor_allow_none(ctx)

            (
                grad_output,
                grad_output_c,
                grad_output_t,
                bgrad,
            ) = TransformerEngineBaseLayer.grad_output_preprocess(ctx, grad_output,
                                                                  ctx.parallel_mode == "row")

            dgrad, wgrad, bgrad_ = _linear_bwd(
                inputmat,
                inputmat_t,
                FP8FwdTensors.GEMM1_INPUT,
                weight,
                weight_t_fp8,
                FP8FwdTensors.GEMM1_WEIGHT,
                grad_output,
                grad_output_c,
                grad_output_t,
                FP8BwdTensors.GRAD_OUTPUT1,
                fwd_scale_inverses,
                ctx.requires_bgrad,
                ctx.fp8_enabled,
                ctx.fp8_meta,
                ctx.requires_dgrad,
                ctx.requires_wgrad,
                ctx.activation_dtype,
                ctx.parallel_mode,
                ctx.tensor_parallel,
                ctx.sequence_parallel,
                ctx.tp_group,
            )

            if not ctx.fp8_enabled:
                # bgrad is fused with gemm for non-FP8 path
                bgrad = bgrad_

            if not ctx.use_bias:
                return (
                    wgrad if ctx.requires_wgrad else None,
                    dgrad.reshape(ctx.inp_shape) if ctx.requires_dgrad else None,
                )

            return (
                wgrad if ctx.requires_wgrad else None,
                dgrad.reshape(ctx.inp_shape) if ctx.requires_dgrad else None,
                bgrad if ctx.requires_bgrad else None,
            )


class Linear(TransformerEngineBaseLayer):
    """
    Applies a linear transformation to the incoming data :math:`y = xA^T + b`

    Parameters
    ----------
    in_features : int
                 size of each input sample.
    out_features : int
                  size of each output sample.
    weight_attr: Union[paddle.ParamAttr, None], default = None
                optional `paddle.ParamAttr` for weight.
    bias_attr: Union[paddle.ParamAttr, None, bool], default = None
              optional `paddle.ParamAttr` for bias.
    backend: {'transformer_engine', 'paddle'}, default = 'transformer_engine'
             if set to 'paddle', a framework-only, non-FP8 path is executed with limited optimization.

    Parallelism parameters
    ----------------------
    tp_group : ProcessGroup, default = `None`
              tensor parallel process group.
    parallel_mode : {None, 'column', 'row'}, default = `None`
                   used to decide whether this Linear layer is Column Parallel Linear or Row
                   Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
                   When set to `None`, no communication is performed.
    sequence_parallel : bool, default = `False`
                       if set to `True`, uses sequence parallelism.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        weight_attr: Union[paddle.ParamAttr, None] = None,
        bias_attr: Union[paddle.ParamAttr, None, bool] = None,
        parallel_mode: Optional[str] = None,
        sequence_parallel: bool = False,
        tp_group: Union[dist_group_type, None] = None,
        backend: str = 'transformer_engine',
    ) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.backend = backend
        self._weight_attr = weight_attr
        self._bias_attr = bias_attr
        self._dtype = self._helper.get_default_dtype()

        # Set parallel configs
        self.tp_group, self.tp_size = get_tp_group_and_world_size(tp_group,
                                                                  enable_tp=parallel_mode
                                                                  is not None)
        self.tensor_parallel = self.tp_size > 1
        self.parallel_mode = parallel_mode
        assert (self.parallel_mode
                in GemmParallelModes), f"parallel_mode {parallel_mode} not supported"

        if self.parallel_mode == "column":
            self.out_features = divide(self.out_features, self.tp_size)
        elif self.parallel_mode == "row":
            self.in_features = divide(self.in_features, self.tp_size)
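        # Illustrative example: with parallel_mode="column", out_features=4096 and
        # tp_size=2, each rank holds a [2048, in_features] weight shard and computes
        # the corresponding 2048-feature slice of the output.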

        self.sequence_parallel = self.tensor_parallel and sequence_parallel

        # Initialize weight parameter
        with track_rng_state(enable=self.tensor_parallel):
            # TE linear weight is in column major
            self.weight = self.create_parameter(
                shape=[self.out_features, self.in_features]
                if self.backend == 'transformer_engine' else [self.in_features, self.out_features],
                attr=self._weight_attr,
                dtype=self._dtype,
                is_bias=False,
            )
        set_weight_tensor_dist_attr(self.weight, self.tensor_parallel, self.parallel_mode,
                                    self.backend)

        # Initialize bias parameter
        self.has_bias = self._bias_attr is not False
        use_default_bias = self._bias_attr is None or self._bias_attr is True
        if self.has_bias:
            self.bias = self.create_parameter(
                shape=[self.out_features],
                attr=self._bias_attr if not use_default_bias else paddle.ParamAttr(
                    initializer=Constant(value=0.0)),
                dtype=self._dtype,
                is_bias=True,
            )
            if parallel_mode == "column":
                set_tensor_dist_attr(self.bias, self.tensor_parallel, axis=0)
            if parallel_mode == "row" and self.sequence_parallel:
                mark_as_sequence_parallel_parameter(self.bias)
        else:
            self.bias = None

        # For row-parallel linear (RPL), the bias has to be added after the TP
        # collectives, so it cannot be fused with the GEMM.
        if self.parallel_mode == "row" and self.tensor_parallel and self.has_bias:
            self.gemm_bias_fused_add = False
        else:
            self.gemm_bias_fused_add = True

    def _te_forward(
        self,
        inp: paddle.Tensor,
    ) -> paddle.Tensor:
        """
        Apply the linear transformation to the input.
        """
        with self.prepare_forward(inp) as inp:
            # The layer input should be cast outside the PyLayer, as performing an
            # inplace cast on input tensors may cause problems when used
            # together with Paddle native layers.
            inp = cast_if_needed(inp, self.activation_dtype)
            out = _Linear.apply(
                self.weight,
                inp,
                self.bias if self.gemm_bias_fused_add else None,
                self.has_bias and self.gemm_bias_fused_add,
                self.fp8_enabled,
                self.fp8_calibration,
                self.fp8_meta,
                self.activation_dtype,
                paddle.is_grad_enabled(),
                self.parallel_mode,
                self.tensor_parallel,
                self.sequence_parallel,
                self.tp_group,
                self.tp_size,
            )

        if not self.gemm_bias_fused_add:
            out = out + cast_if_needed_inplace(self.bias, self.activation_dtype)

        return out

    def _pd_forward(
        self,
        inp: paddle.Tensor,
    ) -> paddle.Tensor:
        """Calls Paddle OP"""
        if self.parallel_mode == 'column' and self.tensor_parallel:
            inp = identity(inp, self.tp_group)
        out = F.linear(inp, self.weight, self.bias if self.gemm_bias_fused_add else None)
        if self.parallel_mode == 'row' and self.tensor_parallel:
            out, _ = allreduce(out, self.tp_group)
            out = out + self.bias if self.bias is not None else out
        return out

    def forward(self, *args, **kwargs):
        """
        Apply the linear transformation to the input.

        Parameters
        ----------
        inp : paddle.Tensor
             Input tensor.
        """
        if self.backend == 'transformer_engine':
            return self._te_forward(*args, **kwargs)
        if self.backend == 'paddle':
            return self._pd_forward(*args, **kwargs)
        raise AttributeError(f"Backend {self.backend} is not supported.")