# Copyright (c) 2022, Tri Dao.
# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/fused_dense/fused_dense.py
# We make it work with PyTorch AMP (autocast) and with bfloat16.
# The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.distributed import ProcessGroup

# import fused_dense_cuda  # from apex
import fused_dense_lib as fused_dense_cuda

from flash_attn.ops.gelu_activation import gelu_bwd
from flash_attn.utils.distributed import all_gather_raw, reduce_scatter_raw, reduce_scatter


class FusedDenseFunc(torch.autograd.Function):
    """Autograd function computing ``F.linear(x, weight, bias)``.

    The forward is a plain ``F.linear``. The backward computes the weight (and, when requested,
    bias) gradient with ``fused_dense_cuda.linear_bias_wgrad``. When autocast is enabled, x,
    weight and bias are cast to the autocast GPU dtype before the matmul.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, x, weight, bias, return_residual=False, process_group=None):
        """
        If process_group is not None, we're doing Tensor Parallel with sequence parallelism:
        we do an all_gather_raw of x before doing the matmul.

        Returns ``output``, or ``(output, x)`` when ``return_residual`` is True, where ``x`` is
        the local (dtype-cast, contiguous) input rather than the gathered one.
        """
        ctx.compute_weight_gradient = weight.requires_grad
        ctx.return_residual = return_residual
        ctx.process_group = process_group

        if torch.is_autocast_enabled():
            x = x.to(dtype=torch.get_autocast_gpu_dtype())
        x = x.contiguous()
        if process_group is not None:
            # We want to kick off the all_gather early, before weight dtype conversion
            total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
        else:
            total_x = x

        if torch.is_autocast_enabled():
            weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
            bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None
        weight = weight.contiguous()
        if process_group is not None:
            handle_x.wait()
        batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
        batch_dim = batch_shape.numel()
        # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
        # NOTE(review): the message says every dim must be <= 2M, but min(...) only trips when
        # all of them exceed it — confirm whether max(...) was intended.
        if min(batch_dim, n, *weight.shape) > 65535 * 32:
            raise RuntimeError('fused_dense only supports matrix dims <= 2M')
        output = F.linear(total_x, weight, bias)
        # Only the local (un-gathered) x is saved; backward re-gathers it when the weight
        # gradient is needed.
        if ctx.compute_weight_gradient:
            ctx.save_for_backward(x, weight)
        else:
            ctx.save_for_backward(weight)
        return output if not return_residual else (output, x)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output, *args):
        """Return gradients for (x, weight, bias, return_residual, process_group)."""
        grad_output = grad_output.contiguous()
        if ctx.return_residual:
            # Gradient that flowed into the returned residual x; the matmul's input
            # gradient is accumulated on top of it below.
            grad_input, = args
            grad_input = grad_input.contiguous()
        process_group = ctx.process_group
        if ctx.compute_weight_gradient:
            x, weight = ctx.saved_tensors
            if process_group is not None:
                total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
            else:
                total_x = x
        else:
            weight, = ctx.saved_tensors
            total_x = None
        batch_shape = grad_output.shape[:-1]
        batch_dim = batch_shape.numel()
        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
        if ctx.needs_input_grad[0]:
            if not ctx.return_residual:
                grad_input = F.linear(grad_output, weight.t())
            else:
                # NOTE(review): with a process_group, grad_output covers the gathered batch
                # while the residual grad covers only the local x — confirm that
                # return_residual together with process_group is supported.
                grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]),
                                         grad_output, weight)
            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
            if process_group is not None:
                grad_input, handle_grad_input = reduce_scatter_raw(grad_input, process_group,
                                                                   async_op=True)
        else:
            grad_input = None
        if ctx.needs_input_grad[1]:
            assert ctx.compute_weight_gradient
            if process_group is not None:
                handle_x.wait()
            grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad(
                total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
            )
        else:
            grad_weight = None
            # Reduce over the flattened batch so the bias gradient has the bias's shape.
            grad_bias = grad_output.sum(dim=0) if ctx.needs_input_grad[2] else None
        if process_group is not None and ctx.needs_input_grad[0]:
            handle_grad_input.wait()
        return grad_input, grad_weight, grad_bias, None, None


def fused_dense_func(x: Tensor, weight: Tensor, bias: Optional[Tensor] = None,
                     return_residual: bool = False, process_group: Optional[ProcessGroup] = None):
    """Compute ``F.linear(x, weight, bias)``, routing through FusedDenseFunc when eligible.

    The fused path is used when x, weight and bias (if given) are all CUDA tensors and x is
    fp16/bf16, or fp32 with autocast enabled. Otherwise a plain ``F.linear`` is used; that
    path does not support ``process_group``.

    Returns:
        ``out``, or ``(out, x)`` when ``return_residual`` is True.
    """
    dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
                      or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
    if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
        return FusedDenseFunc.apply(x, weight, bias, return_residual, process_group)
    else:
        assert process_group is None, \
            'process_group is only supported on the fused CUDA fp16/bf16 path'
        out = F.linear(x, weight, bias)
        return out if not return_residual else (out, x)


class FusedDense(nn.Linear):
    """``nn.Linear`` whose forward goes through ``fused_dense_func``.

    Args:
        in_features, out_features, bias, device, dtype: as for ``nn.Linear``.
        return_residual: if True, forward returns ``(out, x)`` instead of ``out``.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 return_residual: bool = False, device=None, dtype=None) -> None:
        super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
        self.return_residual = return_residual

    def forward(self, x, process_group=None):
        """
        If process_group is not None, we're doing Tensor Parallel with sequence parallelism:
        we do an all_gather of x before doing the matmul.
        """
        return fused_dense_func(x, self.weight, self.bias, return_residual=self.return_residual,
                                process_group=process_group)


class ColumnParallelLinear(nn.Linear):
    """Linear layer whose output features are split across the ranks of ``process_group``.

    Each rank holds a local ``nn.Linear`` with ``out_features // world_size`` output features.
    The forward passes ``process_group`` to ``fused_dense_func``, which all-gathers x before
    the matmul.

    Raises:
        ValueError: if ``out_features`` is not divisible by the group's world size.
    """

    def __init__(self, in_features: int, out_features: int, process_group: ProcessGroup,
                 bias: bool = True, device=None, dtype=None) -> None:
        world_size = torch.distributed.get_world_size(process_group)
        if out_features % world_size != 0:
            raise ValueError(f'out_features ({out_features}) must be divisible by '
                             f'world_size ({world_size})')
        super().__init__(in_features, out_features // world_size, bias=bias,
                         device=device, dtype=dtype)
        self.process_group = process_group

    def forward(self, x):
        """
        We're doing Tensor Parallel with sequence parallelism: we do an all_gather of
        x before doing the matmul.
        """
        return fused_dense_func(x, self.weight, self.bias, process_group=self.process_group)


class RowParallelLinear(nn.Linear):
    """Linear layer whose input features are split across the ranks of ``process_group``.

    Each rank holds a local ``nn.Linear`` with ``in_features // world_size`` input features.
    Only rank 0 gets a bias (when ``bias`` is True) — presumably so the bias is counted once
    after the reduce_scatter in forward.

    Raises:
        ValueError: if ``in_features`` is not divisible by the group's world size.
    """

    def __init__(self, in_features: int, out_features: int, process_group: ProcessGroup,
                 bias: bool = True, device=None, dtype=None) -> None:
        world_size = torch.distributed.get_world_size(process_group)
        rank = torch.distributed.get_rank(process_group)
        if in_features % world_size != 0:
            raise ValueError(f'in_features ({in_features}) must be divisible by '
                             f'world_size ({world_size})')
        # Only rank 0 will have bias
        super().__init__(in_features // world_size, out_features, bias=bias and rank == 0,
                         device=device, dtype=dtype)
        self.process_group = process_group

    def forward(self, x):
        """
        We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
        a reduce_scatter of the result.
        """
        out = fused_dense_func(x, self.weight, self.bias)
        return reduce_scatter(out, self.process_group)


class FusedDenseGeluDenseFunc(torch.autograd.Function):
    """Autograd function computing ``F.linear(gelu_tanh(F.linear(x, w1, b1)), w2, b2)``.

    With ``heuristic == -1`` the first matmul and the GeLU run as separate ``F.linear`` /
    ``F.gelu`` calls. Otherwise the fused ``fused_dense_cuda`` gemm + gelu kernels are used with
    that heuristic value. ``checkpoint_lvl`` trades backward recomputation for saved
    activations.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, x, weight1, bias1, weight2, bias2, save_pre_act=True, return_residual=False,
                checkpoint_lvl=0, heuristic=0, process_group=None):
        """
        If process_group is not None, we're doing Tensor Parallel with sequence parallelism:
        we do an all_gather of x before doing the matmul.

        checkpoint_lvl:
        0: no recomputation in the bwd
        1: recompute gelu_out in the bwd
        2: recompute gelu_in and gelu_out in the bwd
        save_pre_act: if False, checkpoint_lvl is forced to 2.
        heuristic: -1 for unfused F.linear + F.gelu, 0..4 for the fused gemm + gelu kernel.

        Returns ``output2``, or ``(output2, x)`` when ``return_residual`` is True.
        """
        assert -1 <= heuristic <= 4
        if not save_pre_act:
            checkpoint_lvl = 2
        assert checkpoint_lvl in [0, 1, 2]
        ctx.return_residual = return_residual
        ctx.process_group = process_group
        ctx.checkpoint_lvl = checkpoint_lvl
        ctx.heuristic = heuristic

        if torch.is_autocast_enabled():
            x = x.to(dtype=torch.get_autocast_gpu_dtype())
        x = x.contiguous()
        if process_group is not None:
            # We want to kick off the all_gather early, before weight dtype conversion
            total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
        else:
            total_x = x

        if torch.is_autocast_enabled():
            dtype = torch.get_autocast_gpu_dtype()
            weight1, weight2 = [a.to(dtype=dtype) for a in [weight1, weight2]]
            bias1 = bias1.to(dtype=dtype) if bias1 is not None else None
            bias2 = bias2.to(dtype=dtype) if bias2 is not None else None
        weight1 = weight1.contiguous()
        bias1 = bias1.contiguous() if bias1 is not None else None
        weight2 = weight2.contiguous()
        bias2 = bias2.contiguous() if bias2 is not None else None
        if process_group is not None:
            handle_x.wait()
        batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
        batch_dim = batch_shape.numel()
        # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
        # NOTE(review): the message says every dim must be <= 2M, but min(...) only trips when
        # all of them exceed it — confirm whether max(...) was intended.
        if min(batch_dim, n, *weight1.shape, *weight2.shape) > 65535 * 32:
            raise RuntimeError('fused_dense only supports matrix dims <= 2M')
        if heuristic == -1:
            gelu_in = F.linear(total_x, weight1, bias1)
            output1 = F.gelu(gelu_in, approximate='tanh')
        else:
            output1, *rest = fused_dense_cuda.linear_gelu_forward(
                total_x.reshape(batch_dim, n), weight1, bias1, save_pre_act, heuristic
            )
            # gelu_in is only produced when save_pre_act; otherwise checkpoint_lvl == 2 and
            # it is recomputed in the backward.
            if save_pre_act:
                gelu_in = rest[0]
        output2 = F.linear(output1, weight2, bias2)
        if checkpoint_lvl == 0:
            ctx.save_for_backward(x, weight1, weight2, gelu_in, output1)
        elif checkpoint_lvl == 1:
            ctx.save_for_backward(x, weight1, weight2, gelu_in)
        elif checkpoint_lvl == 2:
            ctx.save_for_backward(x, weight1, weight2, bias1)
        output2 = output2.reshape(*batch_shape, output2.shape[-1])
        return output2 if not return_residual else (output2, x)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output, *args):
        """Return gradients for (x, weight1, bias1, weight2, bias2) plus None for the options."""
        grad_output = grad_output.contiguous()
        checkpoint_lvl = ctx.checkpoint_lvl
        if ctx.return_residual:
            # Gradient that flowed into the returned residual x.
            grad_input, = args
            grad_input = grad_input.contiguous()
        process_group = ctx.process_group
        x, weight1, weight2, *rest = ctx.saved_tensors
        # Handle of the async all_gather of x. It stays None when there is no process_group or
        # when x is gathered synchronously (checkpoint_lvl == 2), so only a real handle is waited on.
        handle_x = None
        if process_group is None:
            total_x = x
        batch_shape = grad_output.shape[:-1]
        batch_dim = batch_shape.numel()
        if checkpoint_lvl in [0, 1]:
            if process_group is not None:
                total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
            if checkpoint_lvl == 0:
                gelu_in, output1 = rest
            elif checkpoint_lvl == 1:
                gelu_in, = rest
                output1 = F.gelu(gelu_in, approximate='tanh')
        elif checkpoint_lvl == 2:
            bias1, = rest
            if process_group is not None:
                # Recomputing the first matmul needs the full x immediately, so gather
                # synchronously.
                total_x, _ = all_gather_raw(x, process_group)
            if ctx.heuristic == -1:
                gelu_in = F.linear(total_x, weight1, bias1)
                output1 = F.gelu(gelu_in, approximate='tanh')
            else:
                output1, gelu_in = fused_dense_cuda.linear_gelu_forward(
                    total_x.reshape(batch_dim, total_x.shape[-1]), weight1, bias1, True,
                    ctx.heuristic
                )

        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
        output1 = output1.reshape(batch_dim, output1.shape[-1])
        gelu_in = gelu_in.reshape(batch_dim, gelu_in.shape[-1])
        if ctx.needs_input_grad[3]:
            grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(
                output1, grad_output, ctx.needs_input_grad[4]
            )
        else:
            grad_weight2 = None
            # Reduce over the flattened batch so the bias gradient has the bias's shape.
            grad_bias2 = grad_output.sum(dim=0) if ctx.needs_input_grad[4] else None
        if ctx.heuristic == -1:
            grad_output1 = F.linear(grad_output, weight2.t())
            with torch.jit.fuser('fuser2'):
                grad_gelu = gelu_bwd(grad_output1, gelu_in)
        else:
            # The cublasLt epilogue has to compute both gelu grad and bias grad, we can't
            # just compute gelu grad
            grad_gelu, grad_bias1 = fused_dense_cuda.bias_gelu_linear_dgrad_bgrad(
                weight2, grad_output, gelu_in, ctx.heuristic
            )
            if not ctx.needs_input_grad[2]:
                grad_bias1 = None
        if ctx.needs_input_grad[0]:
            if not ctx.return_residual:
                grad_input = F.linear(grad_gelu, weight1.t())
            else:
                # NOTE(review): with a process_group, grad_gelu covers the gathered batch while
                # the residual grad covers only the local x — confirm that return_residual
                # together with process_group is supported.
                grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]),
                                         grad_gelu, weight1)
            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
            if process_group is not None:
                grad_input, handle_grad_input = reduce_scatter_raw(grad_input, process_group,
                                                                   async_op=True)
        else:
            grad_input = None
        if ctx.heuristic == -1:
            if ctx.needs_input_grad[1]:
                if handle_x is not None:
                    handle_x.wait()
                grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad(
                    total_x.reshape(batch_dim, total_x.shape[-1]), grad_gelu,
                    ctx.needs_input_grad[2]
                )
            else:
                grad_weight1 = None
                # Reduce over the flattened batch so the bias gradient has the bias's shape.
                grad_bias1 = grad_gelu.sum(dim=0) if ctx.needs_input_grad[2] else None
        else:
            if ctx.needs_input_grad[1]:
                if handle_x is not None:
                    handle_x.wait()
                grad_weight1 = F.linear(grad_gelu.t(),
                                        total_x.reshape(batch_dim, total_x.shape[-1]).t())
            else:
                grad_weight1 = None
        if process_group is not None and ctx.needs_input_grad[0]:
            handle_grad_input.wait()
        return (grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2,
                None, None, None, None, None)


def fused_dense_gelu_dense_func(
    x: Tensor, weight1: Tensor, weight2: Tensor, bias1: Optional[Tensor] = None,
    bias2: Optional[Tensor] = None,
    save_pre_act: bool = True, return_residual: bool = False,
    checkpoint_lvl: int = 0, heuristic: int = 0,
    process_group: Optional[ProcessGroup] = None
):
    """Compute ``F.linear(gelu_tanh(F.linear(x, weight1, bias1)), weight2, bias2)``.

    Uses FusedDenseGeluDenseFunc when x, both weights and any given biases are CUDA tensors
    and x is fp16/bf16, or fp32 with autocast enabled. Otherwise it falls back to plain
    ``F.linear`` / ``F.gelu``, which does not support ``process_group`` and ignores
    ``save_pre_act``, ``checkpoint_lvl`` and ``heuristic``.

    Returns:
        ``output2``, or ``(output2, x)`` when ``return_residual`` is True.
    """
    dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
                      or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
    if (x.is_cuda and weight1.is_cuda and weight2.is_cuda and (bias1 is None or bias1.is_cuda)
        and (bias2 is None or bias2.is_cuda) and dtype_eligible):
        return FusedDenseGeluDenseFunc.apply(
            x, weight1, bias1, weight2, bias2,
            save_pre_act, return_residual, checkpoint_lvl, heuristic, process_group
        )
    else:
        assert process_group is None, \
            'process_group is only supported on the fused CUDA fp16/bf16 path'
        gelu_in = F.linear(x, weight1, bias1)
        output1 = F.gelu(gelu_in, approximate='tanh')
        output2 = F.linear(output1, weight2, bias2)
        return output2 if not return_residual else (output2, x)


class FusedDenseGeluDense(nn.Module):
    """Two-layer MLP ``fc2(gelu_tanh(fc1(x)))`` built on ``fused_dense_gelu_dense_func``."""

    def __init__(self, in_features, hidden_features, out_features=None, bias1=True,
                 bias2=True, return_residual=False, checkpoint_lvl=0, heuristic=0,
                 device=None, dtype=None):
        """
        out_features: defaults to in_features.
        checkpoint_lvl (increasing lvl means slower but more memory saving):
            0: no recomputation in the bwd
            1: recompute gelu_out in the bwd
            2: recompute gelu_in and gelu_out in the bwd
        heuristic:
            -1: don't fuse gemm + gelu (separate kernel)
            0..4: use this heuristic for the algo section in the fused gemm + gelu
            For CUDA >= 11.8, you'd want heuristic=0 for both fp16 and bf16 for best perf.
            For CUDA <= 11.7, you'd want heuristic=1 for fp16 and heuristic=-1 for bf16.
        return_residual: whether to return the input x along with the output. This is for
            performance reason: for post-norm architecture, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.

        Tensor parallelism is selected per call via the ``process_group`` argument of forward.
        """
        assert checkpoint_lvl in [0, 1, 2]
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        if out_features is None:
            out_features = in_features
        self.return_residual = return_residual
        self.checkpoint_lvl = checkpoint_lvl
        self.heuristic = heuristic
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

    def forward(self, x, process_group=None):
        """
        If process_group is not None, we're doing Tensor Parallel with sequence parallelism:
        we do an all_gather of x before doing the matmul, gelu, then matmul.
        Finally we do a reduce_scatter of the output.

        Pre-activations are saved only in training mode (``save_pre_act=self.training``).
        Returns ``out``, or ``(out, x)`` when ``return_residual`` is set.
        """
        out = fused_dense_gelu_dense_func(
            x, self.fc1.weight, self.fc2.weight, self.fc1.bias, self.fc2.bias,
            save_pre_act=self.training, return_residual=self.return_residual,
            checkpoint_lvl=self.checkpoint_lvl, heuristic=self.heuristic,
            process_group=process_group
        )
        if self.return_residual:
            out, x = out
        if process_group is not None:
            out = reduce_scatter(out, process_group)
        return out if not self.return_residual else (out, x)


class ParallelFusedDenseGeluDense(nn.Module):
    """Tensor-parallel MLP: ColumnParallelLinear -> gelu_tanh -> RowParallelLinear."""

    def __init__(self, in_features, hidden_features, out_features=None,
                 process_group: Optional[ProcessGroup] = None, bias1=True, bias2=True,
                 checkpoint_lvl=0, heuristic=0, device=None, dtype=None):
        """
        process_group is required. We're doing Tensor Parallel with sequence parallelism:
        we do an all_gather of x before doing the matmul, gelu, then matmul.
        Finally we do a reduce_scatter of the output.

        out_features: defaults to in_features.
        checkpoint_lvl (increasing lvl means slower but more memory saving):
            0: no recomputation in the bwd
            1: recompute gelu_out in the bwd
            2: recompute gelu_in and gelu_out in the bwd
        heuristic:
            -1: don't fuse gemm + gelu (separate kernel)
            0..4: use this heuristic for the algo section in the fused gemm + gelu
            For CUDA >= 11.8, you'd want heuristic=0 for both fp16 and bf16 for best perf.
            For CUDA <= 11.7, you'd want heuristic=1 for fp16 and heuristic=-1 for bf16.
        """
        assert checkpoint_lvl in [0, 1, 2]
        assert process_group is not None
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        if out_features is None:
            out_features = in_features
        self.process_group = process_group
        self.checkpoint_lvl = checkpoint_lvl
        self.heuristic = heuristic
        self.fc1 = ColumnParallelLinear(in_features, hidden_features, process_group,
                                        bias=bias1, **factory_kwargs)
        self.fc2 = RowParallelLinear(hidden_features, out_features, process_group,
                                     bias=bias2, **factory_kwargs)

    def forward(self, x):
        """Run the sharded MLP on x and reduce_scatter the result across process_group.

        Pre-activations are saved only in training mode (``save_pre_act=self.training``).
        """
        out = fused_dense_gelu_dense_func(
            x, self.fc1.weight, self.fc2.weight, self.fc1.bias, self.fc2.bias,
            save_pre_act=self.training, checkpoint_lvl=self.checkpoint_lvl,
            heuristic=self.heuristic, process_group=self.process_group
        )
        return reduce_scatter(out, self.process_group)