linear.py
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Linear API"""

from typing import Union, Tuple

import paddle
import paddle.nn.functional as F
from paddle.nn.initializer import Constant

from .base import TransformerEngineBaseLayer, get_workspace
from ..cpp_extensions import gemm
from ..utils import cast_if_needed

__all__ = ["Linear"]


class _Linear(paddle.autograd.PyLayer):
    """TE implementation of non-FP8 Linear"""

    @staticmethod
    def forward(
        ctx,
        weight: paddle.Tensor,
        inp: paddle.Tensor,
        bias: paddle.Tensor,
        use_bias: bool,
        activation_dtype: paddle.dtype,
    ) -> paddle.Tensor:
        # Make sure input dimensions are compatible
        in_features = weight.shape[-1]
        assert inp.shape[-1] == in_features, "GEMM not possible"
        inputmat = inp.reshape((-1, in_features))

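        # One fused GEMM over the flattened activations: with gemm's default
        # "TN" layout the weight operand is transposed, so this computes
        # out = inputmat @ weight^T (+ bias), i.e. y = xA^T + b.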
        out, _, _ = gemm(
            weight,
            inputmat,
            activation_dtype,
            get_workspace(),
            bias=bias,
            use_bias=use_bias,
        )

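        # Stash the flattened input and the weight: backward needs them for
        # the wgrad and dgrad GEMMs respectively.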
        ctx.save_for_backward(
            inputmat,
            weight,
        )
        ctx.activation_dtype = activation_dtype
        ctx.use_bias = use_bias
        ctx.inp_shape = inp.shape
        ctx.requires_dgrad = not inp.stop_gradient

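        # Restore the original leading dimensions, with the new feature size last.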
        return out.reshape((-1, *inp.shape[1:-1], out.shape[-1]))

    @staticmethod
    def backward(ctx, grad_output: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]:
        inputmat, weight = ctx.saved_tensor()
        # Flatten grad_output to 2D to match the flattened forward activations;
        # dgrad is reshaped back to the original input shape before returning.
        grad_output = grad_output.reshape((-1, grad_output.shape[-1]))
        if ctx.requires_dgrad:
            # dgrad = grad_output @ weight ("NN" layout).
            dgrad, _, _ = gemm(
                weight,
                grad_output,
                ctx.activation_dtype,
                get_workspace(),
                layout="NN",
                grad=True,
            )

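        # wgrad = grad_output^T @ inputmat ("NT" layout); when use_bias is set
        # the same fused GEMM also reduces grad_output into grad_bias.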
        wgrad, grad_bias = None, None    # defaults guard the returns below
        if not weight.stop_gradient:
            wgrad, grad_bias, _ = gemm(
                inputmat,
                grad_output,
                ctx.activation_dtype,
                get_workspace(),
                layout="NT",
                grad=True,
                use_bias=ctx.use_bias,
            )

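        # PyLayer.backward returns one gradient per tensor input to forward,
        # in order (weight, inp, bias); the bias slot is omitted when no bias
        # tensor was passed in.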
        if not ctx.use_bias:
            return (
                wgrad if not weight.stop_gradient else None,
                dgrad.reshape(ctx.inp_shape) if ctx.requires_dgrad else None,
            )

        return (
            wgrad if not weight.stop_gradient else None,
            dgrad.reshape(ctx.inp_shape) if ctx.requires_dgrad else None,
            grad_bias,
        )


class Linear(TransformerEngineBaseLayer):
    """
    Applies a linear transformation to the incoming data :math:`y = xA^T + b`
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        weight_attr: Union[paddle.ParamAttr, None] = None,
        bias_attr: Union[paddle.ParamAttr, None, bool] = None,
        backend: str = 'transformer_engine',
    ) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.backend = backend
        self._weight_attr = weight_attr
        self._bias_attr = bias_attr
        self._dtype = self._helper.get_default_dtype()

        # TE stores the weight as [out_features, in_features] (transposed
        # relative to paddle's [in_features, out_features] convention), to
        # match the transposed-weight GEMM in _Linear.
        self.weight = self.create_parameter(
            shape=[out_features, in_features]
            if self.backend == 'transformer_engine' else [in_features, out_features],
            attr=self._weight_attr,
            dtype=self._dtype,
            is_bias=False,
        )

        self.has_bias = self._bias_attr is not False
        if self.has_bias:
            self.bias = self.create_parameter(
                shape=[out_features],
                attr=self._bias_attr if self._bias_attr is not None else paddle.ParamAttr(
                    initializer=Constant(value=0.0)),
                dtype=self._dtype,
                is_bias=True,
            )
        else:
            self.bias = None

    def _te_forward(
        self,
        inp: paddle.Tensor,
    ) -> paddle.Tensor:
        """
        Apply the linear transformation to the input.
        """

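        # prepare_forward() is provided by TransformerEngineBaseLayer; the
        # casts below assume it sets self.activation_dtype for this call.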
        with self.prepare_forward(inp) as inp:
            out = _Linear.apply(
                cast_if_needed(self.weight, self.activation_dtype),
                cast_if_needed(inp, self.activation_dtype),
                cast_if_needed(self.bias, self.activation_dtype),
                self.has_bias,
                self.activation_dtype,
            )

        return out

    def _pd_forward(
        self,
        inp: paddle.Tensor,
    ) -> paddle.Tensor:
        """Calls Paddle OP"""
        return F.linear(inp, self.weight, self.bias)

    def forward(self, *args, **kwargs):
        """Dispatch the forward pass to the configured backend."""
        if self.backend == 'transformer_engine':
            return self._te_forward(*args, **kwargs)
        if self.backend == 'paddle':
            return self._pd_forward(*args, **kwargs)
        raise AttributeError(f"Backend {self.backend} is not supported.")
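
# Minimal usage sketch (illustrative, not part of the API surface; assumes a
# CUDA build of Transformer Engine's paddle extensions and made-up sizes):
#
#     layer = Linear(in_features=1024, out_features=1024)
#     inp = paddle.rand((16, 128, 1024))  # (batch, seq, hidden)
#     out = layer(inp)                    # -> (16, 128, 1024)
#     out.sum().backward()                # fills weight/bias/input grads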