# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Linear API"""

from typing import Union, Tuple, Dict, Any

import paddle
import paddle.nn.functional as F
from paddle.nn.initializer import Constant

from .base import (
    TransformerEngineBaseLayer,
    get_workspace,
    _2X_ACC_FPROP,
    _2X_ACC_DGRAD,
    _2X_ACC_WGRAD,
)

from ..fp8 import get_fp8_te_dtype
from ..constants import FP8FwdTensors, FP8BwdTensors
from ..cpp_extensions import gemm, fp8_gemm, cast_to_fp8, cast_transpose
from ..utils import (
    cast_if_needed,
    cast_if_needed_inplace,
    assert_dim_for_fp8_forward_exec,
    get_bias_dtype,
)

__all__ = ["Linear", "_linear_fwd", "_linear_fwd_fp8", "_linear_bwd", "_linear_fwd_non_fp8"]


def _linear_fwd_fp8(
    inputmat: paddle.Tensor,
    inputmat_fp8_index: FP8FwdTensors,
    weight: paddle.Tensor,
    weight_fp8_index: FP8FwdTensors,
    bias: paddle.Tensor,
    use_bias: bool,
    fp8_meta: Dict[str, Any],
    activation_dtype: paddle.dtype,
    is_grad_enabled: bool,
):
    """FP8 path of Linear Fwd.

    Casts ``weight`` to FP8 and multiplies it with ``inputmat`` through
    ``fp8_gemm``. ``inputmat`` is forwarded to the GEMM untouched together
    with the forward ``scale_inv``, so it is presumably already an FP8
    tensor when it arrives here (TODO confirm against callers).

    Returns:
        A ``(out, weight_t_fp8)`` pair. ``weight_t_fp8`` is the transposed
        FP8 weight produced by ``cast_transpose`` when ``is_grad_enabled``
        is true, and ``None`` otherwise.
    """
    fwd_dtype = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
    bias = cast_if_needed_inplace(bias, get_bias_dtype(activation_dtype))
    fwd_scaling = fp8_meta["scaling_fwd"]

    # The transposed copy is only produced when gradients are enabled.
    weight_t_fp8 = None
    if not is_grad_enabled:
        weight_fp8 = cast_to_fp8(weight, fwd_scaling, weight_fp8_index, fwd_dtype)
    else:
        weight_fp8, weight_t_fp8 = cast_transpose(
            weight, fwd_scaling, weight_fp8_index, fwd_dtype
        )

    out = fp8_gemm(
        weight_fp8,
        fwd_scaling.scale_inv,
        weight_fp8_index,
        fwd_dtype,
        inputmat,
        fwd_scaling.scale_inv,
        inputmat_fp8_index,
        fwd_dtype,
        activation_dtype,
        get_workspace(),
        bias=bias,
        use_bias=use_bias,
        use_split_accumulator=_2X_ACC_FPROP,
    )
    return out, weight_t_fp8


def _linear_fwd_non_fp8(
    inputmat: paddle.Tensor,
    inputmat_fp8_index: FP8FwdTensors,
    weight: paddle.Tensor,
    weight_fp8_index: FP8FwdTensors,
    bias: paddle.Tensor,
    use_bias: bool,
    fp8_calibration: bool,
    fp8_meta: Dict[str, Any],
    activation_dtype: paddle.dtype,
    activation: str = "",
):
    """Non-FP8 path of Linear Fwd.

    Runs a plain ``gemm`` of ``weight`` and ``inputmat``. When
    ``fp8_calibration`` is set, the absolute maxima of the input and the
    weight are written into row 0 of the forward ``amax_history``.

    Returns:
        ``out`` alone by default, or ``(out, gelu_out)`` when
        ``activation == 'gelu'``.
    """
    # Layer parameters are created as float32 by default. Cast them to
    # activation_dtype when the dtypes differ. The cast is in place, so it
    # only needs to be performed once throughout the training process.
    weight = cast_if_needed_inplace(weight, activation_dtype)
    bias = cast_if_needed_inplace(bias, activation_dtype)

    if fp8_calibration:
        # Record amax of the input first, then of the weight.
        for fp8_index, tensor in ((inputmat_fp8_index, inputmat), (weight_fp8_index, weight)):
            fp8_meta["scaling_fwd"].amax_history[0, fp8_index.value] = \
                paddle.max(paddle.abs(tensor)).item()

    fuse_gelu = activation == 'gelu'
    outputs = gemm(
        weight,
        inputmat,
        activation_dtype,
        get_workspace(),
        bias=bias,
        use_bias=use_bias,
        gelu=fuse_gelu,
    )

    if not fuse_gelu:
        out, _, _ = outputs
        return out

    # With GELU requested the first and third slots are unpacked as
    # gelu_out and out respectively.
    gelu_out, _, out = outputs
    return out, gelu_out


def _linear_fwd(
    inputmat: paddle.Tensor,
    inputmat_fp8_index: FP8FwdTensors,
    weight: paddle.Tensor,
    weight_fp8_index: FP8FwdTensors,
    bias: paddle.Tensor,
    use_bias: bool,
    fp8_enabled: bool,
    fp8_calibration: bool,
    fp8_meta: Dict[str, Any],
    activation_dtype: paddle.dtype,
    is_grad_enabled: bool,
):
    """Dispatch Linear Fwd to the FP8 or the non-FP8 implementation.

    Returns:
        ``(out, weight_t_fp8)``. The second element is always ``None`` on
        the non-FP8 path.
    """
    if not fp8_enabled:
        out = _linear_fwd_non_fp8(
            inputmat,
            inputmat_fp8_index,
            weight,
            weight_fp8_index,
            bias,
            use_bias,
            fp8_calibration,
            fp8_meta,
            activation_dtype,
        )
        return out, None

    out, weight_t_fp8 = _linear_fwd_fp8(
        inputmat,
        inputmat_fp8_index,
        weight,
        weight_fp8_index,
        bias,
        use_bias,
        fp8_meta,
        activation_dtype,
        is_grad_enabled,
    )
    return out, weight_t_fp8


def _linear_bwd_fp8(
    inputmat: paddle.Tensor,
    inputmat_t: paddle.Tensor,
    inputmat_fp8_index: FP8FwdTensors,
    weight_t_fp8: paddle.Tensor,
    weight_fp8_index: FP8FwdTensors,
    grad_output: paddle.Tensor,
    grad_output_c: paddle.Tensor,
    grad_output_t: paddle.Tensor,
    grad_output_fp8_index: FP8BwdTensors,
    fwd_scale_inverses: paddle.Tensor,
    fp8_meta: Dict[str, Any],
    requires_dgrad: bool,
    requires_wgrad: bool,
    activation_dtype: paddle.dtype,
):
    """FP8 path of Linear Bwd.

    Computes the input gradient with ``fp8_gemm``. The weight gradient is
    computed with ``fp8_gemm`` too, unless the recipe's
    ``override_linear_precision.wgrad`` flag is set, in which case a regular
    ``gemm`` on ``inputmat`` and ``grad_output`` is used instead.

    Returns:
        ``(dgrad, wgrad)``; each is ``None`` when the matching
        ``requires_*`` flag is false.
    """
    fwd_dtype = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
    bwd_dtype = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=False)

    dgrad = None
    if requires_dgrad:
        dgrad = fp8_gemm(
            weight_t_fp8,
            fwd_scale_inverses,
            weight_fp8_index,
            fwd_dtype,
            grad_output_c,
            fp8_meta["scaling_bwd"].scale_inv,
            grad_output_fp8_index,
            bwd_dtype,
            activation_dtype,
            get_workspace(),
            use_split_accumulator=_2X_ACC_DGRAD,
        )

    wgrad = None
    if requires_wgrad:
        if fp8_meta["recipe"].override_linear_precision.wgrad:
            # The recipe asks for a non-FP8 wgrad GEMM.
            wgrad, _, _ = gemm(
                inputmat,
                grad_output,
                activation_dtype,
                get_workspace(),
                layout="NT",
                grad=True,
            )
        else:
            wgrad = fp8_gemm(
                inputmat_t,
                fwd_scale_inverses,
                inputmat_fp8_index,
                fwd_dtype,
                grad_output_t,
                fp8_meta["scaling_bwd"].scale_inv,
                grad_output_fp8_index,
                bwd_dtype,
                activation_dtype,
                get_workspace(),
                use_split_accumulator=_2X_ACC_WGRAD,
            )

    return dgrad, wgrad


def _linear_bwd_non_fp8(
    inputmat: paddle.Tensor,
    weight: paddle.Tensor,
    grad_output: paddle.Tensor,
    requires_bgrad: bool,
    requires_dgrad: bool,
    activation_dtype: paddle.dtype,
    gelu_input: Union[paddle.Tensor, None] = None,
    activation: str = "",
):
    """Non-FP8 path of Linear Bwd, optionally fusing GELU backward and dbias.

    The weight gradient is computed only when ``weight.stop_gradient`` is
    false. In that case the bias gradient slot is taken from the same
    ``gemm`` call (which receives ``use_bias=requires_bgrad``). When the
    weight needs no gradient but the bias does, the bias gradient is
    ``grad_output`` summed over axis 0.

    Returns:
        ``(dgrad, wgrad, bgrad)``. ``dgrad`` is ``None`` unless
        ``requires_dgrad``; ``wgrad`` is ``None`` when the weight needs no
        gradient.
    """
    dgrad = wgrad = bgrad = None
    weight_needs_grad = not weight.stop_gradient

    if requires_dgrad:
        fuse_gelu_bwd = activation == 'gelu'
        dgrad, _, _ = gemm(
            weight,
            grad_output,
            activation_dtype,
            get_workspace(),
            layout="NN",
            gelu=fuse_gelu_bwd,
            gelu_input=gelu_input,
            grad=True,
        )

    if weight_needs_grad:
        # bgrad comes out of the second slot of the wgrad GEMM.
        wgrad, bgrad, _ = gemm(
            inputmat,
            grad_output,
            activation_dtype,
            get_workspace(),
            layout="NT",
            grad=True,
            use_bias=requires_bgrad,
        )
    elif requires_bgrad:
        # No wgrad GEMM to fuse with, so reduce grad_output directly.
        bgrad = grad_output.sum(axis=0)

    return dgrad, wgrad, bgrad


def _linear_bwd(
    inputmat: paddle.Tensor,
    inputmat_t: paddle.Tensor,
    inputmat_fp8_index: FP8FwdTensors,
    weight: paddle.Tensor,
    weight_t_fp8: paddle.Tensor,
    weight_fp8_index: FP8FwdTensors,
    grad_output: paddle.Tensor,
    grad_output_c: paddle.Tensor,
    grad_output_t: paddle.Tensor,
    grad_output_fp8_index: FP8BwdTensors,
    fwd_scale_inverses: paddle.Tensor,
    requires_bgrad: bool,
    fp8_enabled: bool,
    fp8_meta: Dict[str, Any],
    requires_dgrad: bool,
    activation_dtype: paddle.dtype,
):
    """Dispatch Linear Bwd to the FP8 or the non-FP8 implementation.

    Returns:
        ``(dgrad, wgrad, bgrad)``. On the FP8 path ``bgrad`` is always
        ``None`` here, because that path does not compute a bias gradient.
    """
    weight_needs_grad = not weight.stop_gradient

    if fp8_enabled:
        dgrad, wgrad = _linear_bwd_fp8(
            inputmat,
            inputmat_t,
            inputmat_fp8_index,
            weight_t_fp8,
            weight_fp8_index,
            grad_output,
            grad_output_c,
            grad_output_t,
            grad_output_fp8_index,
            fwd_scale_inverses,
            fp8_meta,
            requires_dgrad,
            weight_needs_grad,
            activation_dtype,
        )
        return dgrad, wgrad, None

    dgrad, wgrad, bgrad = _linear_bwd_non_fp8(
        inputmat,
        weight,
        grad_output,
        requires_bgrad,
        requires_dgrad,
        activation_dtype,
    )
    return dgrad, wgrad, bgrad
321
322
323


class _Linear(paddle.autograd.PyLayer):
    """TE implementation of Linear"""

    @staticmethod
    def forward(
        ctx,
        weight: paddle.Tensor,
        inp: paddle.Tensor,
        bias: Union[paddle.Tensor, None],
        use_bias: bool,
        fp8_enabled: bool,
        fp8_calibration: bool,
        fp8_meta: Dict[str, Any],
        activation_dtype: paddle.dtype,
        is_grad_enabled: bool,
    ) -> paddle.Tensor:
        """Run the linear forward pass and stash what backward needs.

        ``inp`` is flattened to 2D (``-1, in_features``) for the GEMM and the
        result is reshaped back so that only the last dimension changes.
        When ``fp8_enabled`` is set, the input is cast to FP8 before the
        GEMM. Its transposed FP8 copy is additionally produced when
        gradients are enabled and the recipe does not override the wgrad
        precision.

        Raises:
            AssertionError: if the last dimension of ``inp`` does not match
                the last dimension of ``weight``.
        """
        # Make sure input dimensions are compatible
        in_features = weight.shape[-1]
        assert inp.shape[-1] == in_features, "GEMM not possible"
        inputmat = inp.reshape((-1, in_features))
        if fp8_enabled:
            assert_dim_for_fp8_forward_exec(inputmat)
            assert_dim_for_fp8_forward_exec(weight)

        # Keep the un-cast input: it is what gets saved for backward when
        # the wgrad GEMM is not done in FP8.
        inputmat_no_fp8 = inputmat
        # Only set by the cast_transpose branch below; initialised here so
        # the name is bound on every path.
        inputmat_t = None

        # FP8 casting
        if fp8_enabled:
            fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)

            if not fp8_meta["recipe"].override_linear_precision.wgrad:
                if is_grad_enabled:
                    inputmat, inputmat_t = cast_transpose(
                        inputmat,
                        fp8_meta["scaling_fwd"],
                        FP8FwdTensors.GEMM1_INPUT,
                        fp8_dtype_forward,
                    )
                else:
                    inputmat = cast_to_fp8(
                        inputmat,
                        fp8_meta["scaling_fwd"],
                        FP8FwdTensors.GEMM1_INPUT,
                        fp8_dtype_forward,
                    )
            else:
                # wgrad precision is overridden, so no FP8 transpose of the
                # input is needed.
                inputmat, inputmat_t = cast_to_fp8(
                    inputmat,
                    fp8_meta["scaling_fwd"],
                    FP8FwdTensors.GEMM1_INPUT,
                    fp8_dtype_forward,
                ), None

        # GEMM Fwd
        out, weight_t_fp8 = _linear_fwd(
            inputmat,
            FP8FwdTensors.GEMM1_INPUT,
            weight,
            FP8FwdTensors.GEMM1_WEIGHT,
            bias,
            use_bias,
            fp8_enabled,
            fp8_calibration,
            fp8_meta,
            activation_dtype,
            is_grad_enabled,
        )

        if is_grad_enabled:
            fp8_wgrad = fp8_enabled and not fp8_meta["recipe"].override_linear_precision.wgrad
            # Save only the input variant the wgrad GEMM will consume, and
            # nothing at all when the weight needs no gradient.
            ctx.save_for_backward(
                inputmat_no_fp8 if not weight.stop_gradient and not fp8_wgrad else None,
                inputmat_t if not weight.stop_gradient and fp8_wgrad else None,
                weight,
                weight_t_fp8 if fp8_enabled else None,
                # Snapshot of the forward scale_inv taken at forward time.
                fp8_meta["scaling_fwd"].scale_inv.clone() if fp8_enabled else None,
            )
            ctx.activation_dtype = activation_dtype
            ctx.fp8_enabled = fp8_enabled
            ctx.fp8_meta = fp8_meta
            ctx.use_bias = use_bias
            ctx.inp_shape = inp.shape
            ctx.requires_dgrad = not inp.stop_gradient
            ctx.requires_bgrad = use_bias and not bias.stop_gradient

        return out.reshape((-1, *inp.shape[1:-1], out.shape[-1]))

    @staticmethod
    def backward(ctx, grad_output: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]:
        """Compute gradients for ``weight``, ``inp`` and (optionally) ``bias``.

        Returns:
            ``(wgrad, dgrad)`` when the layer was run without bias, and
            ``(wgrad, dgrad, bgrad)`` otherwise. Entries are ``None`` for
            tensors that do not require a gradient.
        """
        with TransformerEngineBaseLayer.prepare_backward(ctx.fp8_enabled,
                                                         ctx.fp8_meta,
                                                         name="_Linear"):
            (
                inputmat,
                inputmat_t,
                weight,
                weight_t_fp8,
                fwd_scale_inverses,
            ) = ctx.saved_tensor()

            (
                grad_output,
                grad_output_c,
                grad_output_t,
                bgrad,
            ) = TransformerEngineBaseLayer.grad_output_preprocess(ctx, grad_output)

            dgrad, wgrad, bgrad_ = _linear_bwd(
                inputmat,
                inputmat_t,
                FP8FwdTensors.GEMM1_INPUT,
                weight,
                weight_t_fp8,
                FP8FwdTensors.GEMM1_WEIGHT,
                grad_output,
                grad_output_c,
                grad_output_t,
                FP8BwdTensors.GRAD_OUTPUT1,
                fwd_scale_inverses,
                ctx.requires_bgrad,
                ctx.fp8_enabled,
                ctx.fp8_meta,
                ctx.requires_dgrad,
                ctx.activation_dtype,
            )

            if not ctx.fp8_enabled:
                # bgrad is fused with gemm for non-FP8 path
                bgrad = bgrad_

            # Without bias only two gradients are returned (weight, inp).
            # NOTE(review): presumably PyLayer expects one gradient per
            # tensor input of forward -- confirm.
            if not ctx.use_bias:
                return (
                    wgrad if not weight.stop_gradient else None,
                    dgrad.reshape(ctx.inp_shape) if ctx.requires_dgrad else None,
                )

            return (
                wgrad if not weight.stop_gradient else None,
                dgrad.reshape(ctx.inp_shape) if ctx.requires_dgrad else None,
                bgrad if ctx.requires_bgrad else None,
            )


class Linear(TransformerEngineBaseLayer):
    """
    Applies a linear transformation to the incoming data :math:`y = xA^T + b`

    Parameters
    ----------
    in_features : int
        Size of the last dimension of the input.
    out_features : int
        Size of the last dimension of the output.
    weight_attr : paddle.ParamAttr or None
        Attribute passed to ``create_parameter`` for the weight.
    bias_attr : paddle.ParamAttr, None or bool
        ``False`` disables the bias. ``None`` or ``True`` creates a bias
        initialised to zero. Any other value is passed to
        ``create_parameter`` as the bias attribute.
    backend : str
        ``'transformer_engine'`` (default) or ``'paddle'``. Selects both the
        weight layout and the implementation used by ``forward``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        weight_attr: Union[paddle.ParamAttr, None] = None,
        bias_attr: Union[paddle.ParamAttr, None, bool] = None,
        backend: str = 'transformer_engine',
    ) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.backend = backend
        self._weight_attr = weight_attr
        self._bias_attr = bias_attr
        self._dtype = self._helper.get_default_dtype()

        # TE linear weight is in column major, i.e. [out_features,
        # in_features]. Any other backend gets [in_features, out_features].
        self.weight = self.create_parameter(
            shape=[out_features, in_features]
            if self.backend == 'transformer_engine' else [in_features, out_features],
            attr=self._weight_attr,
            dtype=self._dtype,
            is_bias=False,
        )

        # Only an explicit False turns the bias off.
        self.has_bias = self._bias_attr is not False
        use_default_bias = self._bias_attr is None or self._bias_attr is True
        if self.has_bias:
            self.bias = self.create_parameter(
                shape=[out_features],
                attr=self._bias_attr if not use_default_bias else paddle.ParamAttr(
                    initializer=Constant(value=0.0)),
                dtype=self._dtype,
                is_bias=True,
            )
        else:
            self.bias = None

    def _te_forward(
        self,
        inp: paddle.Tensor,
    ) -> paddle.Tensor:
        """
        Apply the linear transformation to the input.
        """
        with self.prepare_forward(inp) as inp:
            # Layer input should be casted outside PyLayer, as performing
            # inplace cast to input tensors may cause problems when used
            # together with Paddle native layers.
            inp = cast_if_needed(inp, self.activation_dtype)
            out = _Linear.apply(
                self.weight,
                inp,
                self.bias,
                self.has_bias,
                self.fp8_enabled,
                self.fp8_calibration,
                self.fp8_meta,
                self.activation_dtype,
                paddle.is_grad_enabled(),
            )

        return out

    def _pd_forward(
        self,
        inp: paddle.Tensor,
    ) -> paddle.Tensor:
        """Calls Paddle OP"""
        return F.linear(inp, self.weight, self.bias)

    def forward(self, *args, **kwargs):
        """Run the forward pass on the configured backend.

        Raises:
            AttributeError: if ``self.backend`` is neither
                ``'transformer_engine'`` nor ``'paddle'``.
        """
        if self.backend == 'transformer_engine':
            return self._te_forward(*args, **kwargs)
        if self.backend == 'paddle':
            return self._pd_forward(*args, **kwargs)
        raise AttributeError(f"Backend {self.backend} is not supported.")