# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/fused_dense/fused_dense.py
# We make it work with pytorch amp and with bfloat16.
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd

# import fused_dense_cuda  # from apex
import fused_dense_lib as fused_dense_cuda
from flash_attn.ops.gelu_activation import gelu_bwd


class FusedDenseFunc(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(ctx, x, weight, bias, return_residual=False):
        if torch.is_autocast_enabled():
            dtype = torch.get_autocast_gpu_dtype()
            x, weight = [a.to(dtype=dtype) for a in [x, weight]]
            bias = bias.to(dtype=dtype) if bias is not None else None
        ctx.return_residual = return_residual
        x = x.contiguous()
        weight = weight.contiguous()
        ctx.save_for_backward(x, weight)
        batch_shape, n = x.shape[:-1], x.shape[-1]
        batch_dim = batch_shape.numel()
        assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
        output = F.linear(x, weight, bias)
        return output if not return_residual else (output, x)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output, *args):
        grad_output = grad_output.contiguous()
        if ctx.return_residual:
            grad_input, = args
            grad_input = grad_input.contiguous()
        x, weight = ctx.saved_tensors
        batch_shape, n = x.shape[:-1], x.shape[-1]
        batch_dim = batch_shape.numel()
        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
        if ctx.needs_input_grad[1]:
            grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad(
                x.reshape(batch_dim, n), grad_output, ctx.needs_input_grad[2]
            )
        else:
            grad_weight = None
            grad_bias = grad_output if ctx.needs_input_grad[2] else None
        if ctx.needs_input_grad[0]:
            if not ctx.return_residual:
                grad_input = F.linear(grad_output, weight.t())
            else:
                grad_input = torch.addmm(grad_input.reshape(batch_dim, n), grad_output, weight)
            grad_input = grad_input.reshape_as(x)
        else:
            grad_input = None
        return grad_input, grad_weight, grad_bias, None


def fused_dense_func(x: Tensor, weight: Tensor, bias: Optional[Tensor] = None,
                     return_residual: bool = False):
    batch_dim = x.shape[:-1].numel()
    dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
                      or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
    if (x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and batch_dim <= 64 * 1024
        and dtype_eligible):
        return FusedDenseFunc.apply(x, weight, bias, return_residual)
    else:
        out = F.linear(x, weight, bias)
        return out if not return_residual else (out, x)
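
# Usage sketch (illustrative; the shapes and dtype below are assumptions, not part of this
# module). On CUDA with fp16/bf16 tensors and batch_dim <= 64k this dispatches to the fused
# autograd function; otherwise it falls back to F.linear and returns the same result.
#     x = torch.randn(4, 512, 1024, device='cuda', dtype=torch.float16, requires_grad=True)
#     weight = torch.randn(2048, 1024, device='cuda', dtype=torch.float16, requires_grad=True)
#     bias = torch.randn(2048, device='cuda', dtype=torch.float16, requires_grad=True)
#     out = fused_dense_func(x, weight, bias)                                  # shape (4, 512, 2048)
#     out, residual = fused_dense_func(x, weight, bias, return_residual=True)  # residual is x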


class FusedDense(nn.Linear):

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 return_residual: bool = False, device=None, dtype=None) -> None:
        super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
        self.return_residual = return_residual

    def forward(self, x):
        return fused_dense_func(x, self.weight, self.bias, return_residual=self.return_residual)
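
# Usage sketch (hedged; the layer sizes are made up for illustration). FusedDense is a drop-in
# replacement for nn.Linear that routes its forward through fused_dense_func:
#     layer = FusedDense(1024, 4096, bias=True, device='cuda', dtype=torch.float16)
#     y = layer(x)        # same semantics as nn.Linear(1024, 4096) applied to x
# With return_residual=True, forward returns (output, x) so that a caller can fuse the residual
# add into the backward of the matmul.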


class FusedDenseGeluDenseFunc(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(ctx, x, weight1, bias1, weight2, bias2, save_gelu_in=True, return_residual=False,
                checkpoint_lvl=0, heuristic=0):
        """checkpoint_lvl:
        0: no recomputation in the bwd
        1: recompute gelu_out in the bwd
        2: recompute gelu_in and gelu_out in the bwd
        """
        assert -1 <= heuristic <= 4
        if torch.is_autocast_enabled():
            dtype = torch.get_autocast_gpu_dtype()
            x, weight1, weight2 = [a.to(dtype=dtype) for a in [x, weight1, weight2]]
            bias1 = bias1.to(dtype=dtype) if bias1 is not None else None
            bias2 = bias2.to(dtype=dtype) if bias2 is not None else None
        if not save_gelu_in:
            checkpoint_lvl = 2
        assert checkpoint_lvl in [0, 1, 2]
        ctx.return_residual = return_residual
        x = x.contiguous()
        weight1 = weight1.contiguous()
        bias1 = bias1.contiguous() if bias1 is not None else None
        weight2 = weight2.contiguous()
        bias2 = bias2.contiguous() if bias2 is not None else None
        batch_shape, n = x.shape[:-1], x.shape[-1]
        batch_dim = batch_shape.numel()
        assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
        if heuristic == -1:
            gelu_in = F.linear(x, weight1, bias1)
            output1 = F.gelu(gelu_in, approximate='tanh')
            # gelu_in = F.linear(x.reshape(batch_dim, n), weight1)  # This is before adding bias1
            # with torch.jit.fuser('fuser2'):
            #     output1 = bias_gelu(gelu_in, bias1)
        else:
            output1, *rest = fused_dense_cuda.linear_gelu_forward(x.reshape(batch_dim, n), weight1,
                                                                  bias1, save_gelu_in, heuristic)
            if save_gelu_in:
                gelu_in = rest[0]
        output2 = F.linear(output1, weight2, bias2)
        ctx.checkpoint_lvl = checkpoint_lvl
        ctx.heuristic = heuristic
        if checkpoint_lvl == 0:
            ctx.save_for_backward(x, weight1, weight2, gelu_in, output1)
        elif checkpoint_lvl == 1:
            ctx.save_for_backward(x, weight1, weight2, gelu_in)
        elif checkpoint_lvl == 2:
            ctx.save_for_backward(x, weight1, weight2, bias1)
        output2 = output2.reshape(*batch_shape, output2.shape[-1])
        return output2 if not return_residual else (output2, x)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output, *args):
        grad_output = grad_output.contiguous()
        checkpoint_lvl = ctx.checkpoint_lvl
        if ctx.return_residual:
            grad_input, = args
            grad_input = grad_input.contiguous()
        x, weight1, weight2, *rest = ctx.saved_tensors
        batch_shape, n = x.shape[:-1], x.shape[-1]
        batch_dim = batch_shape.numel()
        if checkpoint_lvl == 0:
            gelu_in, output1 = rest
        elif checkpoint_lvl == 1:
            gelu_in, = rest
            output1 = F.gelu(gelu_in, approximate='tanh')
        elif checkpoint_lvl == 2:
            bias1, = rest
            if ctx.heuristic == -1:
                gelu_in = F.linear(x, weight1, bias1)
                output1 = F.gelu(gelu_in, approximate='tanh')
            else:
                output1, gelu_in = fused_dense_cuda.linear_gelu_forward(
                    x.reshape(batch_dim, n), weight1, bias1, True, ctx.heuristic
                )

        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
        output1 = output1.reshape(batch_dim, output1.shape[-1])
        gelu_in = gelu_in.reshape(batch_dim, gelu_in.shape[-1])
        if ctx.needs_input_grad[3]:
            grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(
                output1, grad_output, ctx.needs_input_grad[4]
            )
        else:
            grad_weight2 = None
            grad_bias2 = grad_output if ctx.needs_input_grad[4] else None
        if ctx.heuristic == -1:
            # grad_gelu = matmul_dgelu(grad_output, weight2, gelu_in)
            grad_output1 = F.linear(grad_output, weight2.t())
            with torch.jit.fuser('fuser2'):
                grad_gelu = gelu_bwd(grad_output1, gelu_in)
            if ctx.needs_input_grad[1]:
                grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad(
                    x.reshape(batch_dim, n), grad_gelu, ctx.needs_input_grad[2]
                )
            else:
                grad_weight1 = None
                grad_bias1 = grad_gelu if ctx.needs_input_grad[2] else None
        else:
            # The cublasLt epilogue has to compute both the gelu grad and the bias grad;
            # we can't compute just the gelu grad.
            grad_gelu, grad_bias1 = fused_dense_cuda.bias_gelu_linear_dgrad_bgrad(
                weight2, grad_output, gelu_in, ctx.heuristic
            )
            if not ctx.needs_input_grad[2]:
                grad_bias1 = None
            if ctx.needs_input_grad[1]:
                grad_weight1 = F.linear(grad_gelu.t(), x.reshape(batch_dim, n).t())
            else:
                grad_weight1 = None
        if ctx.needs_input_grad[0]:
            if not ctx.return_residual:
                grad_input = F.linear(grad_gelu, weight1.t())
            else:
                grad_input = torch.addmm(grad_input.reshape(batch_dim, n), grad_gelu, weight1)
            grad_input = grad_input.reshape_as(x)
        else:
            grad_input = None
        return grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2, None, None, None, None


def fused_dense_gelu_dense_func(
    x: Tensor, weight1: Tensor, weight2: Tensor, bias1: Optional[Tensor] = None,
    bias2: Optional[Tensor] = None,
    save_gelu_in: bool = True, return_residual: bool = False,
    checkpoint_lvl: int = 0, heuristic: int = 0
):
    batch_dim = x.shape[:-1].numel()
    dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
                      or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
    if (x.is_cuda and weight1.is_cuda and weight2.is_cuda and (bias1 is None or bias1.is_cuda)
        and (bias2 is None or bias2.is_cuda) and batch_dim <= 64 * 1024
        and dtype_eligible):
        return FusedDenseGeluDenseFunc.apply(
            x, weight1, bias1, weight2, bias2,
            save_gelu_in, return_residual, checkpoint_lvl, heuristic
        )
    else:
        gelu_in = F.linear(x, weight1, bias1)
        output1 = F.gelu(gelu_in, approximate='tanh')
        output2 = F.linear(output1, weight2, bias2)
        return output2 if not return_residual else (output2, x)
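
# Equivalence sketch (illustrative shapes, not part of the API): the fused path computes the same
# Linear -> GELU(tanh approximation) -> Linear as the fallback branch above, so a rough check on
# small CUDA fp16 tensors could look like:
#     x = torch.randn(2, 128, 1024, device='cuda', dtype=torch.float16)
#     w1 = torch.randn(4096, 1024, device='cuda', dtype=torch.float16)
#     b1 = torch.randn(4096, device='cuda', dtype=torch.float16)
#     w2 = torch.randn(1024, 4096, device='cuda', dtype=torch.float16)
#     b2 = torch.randn(1024, device='cuda', dtype=torch.float16)
#     ref = F.linear(F.gelu(F.linear(x, w1, b1), approximate='tanh'), w2, b2)
#     out = fused_dense_gelu_dense_func(x, w1, w2, b1, b2, heuristic=0)
#     torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)  # loose fp16 tolerance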


class FusedDenseGeluDense(nn.Module):

    def __init__(self, in_features, hidden_features, out_features=None, bias1=True,
                 bias2=True, return_residual=False, checkpoint_lvl=0, heuristic=0,
                 device=None, dtype=None):
        """
        checkpoint_lvl (higher levels are slower but save more memory):
            0: no recomputation in the bwd
            1: recompute gelu_out in the bwd
            2: recompute gelu_in and gelu_out in the bwd
        heuristic:
            -1: don't fuse gemm + gelu (separate kernel)
            0..4: use this heuristic for the algorithm selection in the fused gemm + gelu
            For CUDA >= 11.8, you'd want heuristic=0 for both fp16 and bf16 for best perf.
            For CUDA <= 11.7, you'd want heuristic=1 for fp16 and heuristic=-1 for bf16.
        return_residual: whether to return the input x along with the output. This is for
            performance reasons: for a post-norm architecture, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.
        """
        assert checkpoint_lvl in [0, 1, 2]
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        if out_features is None:
            out_features = in_features
        self.return_residual = return_residual
        self.checkpoint_lvl = checkpoint_lvl
        self.heuristic = heuristic
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

    def forward(self, x):
        return fused_dense_gelu_dense_func(
            x, self.fc1.weight, self.fc2.weight, self.fc1.bias, self.fc2.bias,
            save_gelu_in=self.training, return_residual=self.return_residual,
            checkpoint_lvl=self.checkpoint_lvl, heuristic=self.heuristic
        )
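

# Usage sketch for FusedDenseGeluDense (hedged; sizes and dtype are assumptions). It is a fused
# replacement for the common Linear -> GELU(tanh) -> Linear MLP block:
#     mlp = FusedDenseGeluDense(1024, 4096, 1024, checkpoint_lvl=1, heuristic=0,
#                               device='cuda', dtype=torch.bfloat16)
#     y = mlp(x)          # x: (batch, seqlen, 1024) CUDA bf16 tensor
# In training mode save_gelu_in=self.training is True, so the backward can reuse or recompute the
# pre-GELU activations according to checkpoint_lvl; in eval mode gelu_in is not saved.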