# Copyright (c) 2023, Tri Dao.
# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/fused_dense/fused_dense.py
# We make it work with pytorch amp and with bfloat16.
# The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py
from functools import partial
from typing import Optional

# import fused_dense_cuda  # from apex
import fused_dense_lib as fused_dense_cuda
import torch
import torch.nn as nn
import torch.nn.functional as F
from flash_attn.ops.activations import gelu_bwd, relu_bwd, sqrelu_bwd, sqrelu_fwd
from flash_attn.utils.distributed import (
    all_gather_raw,
    all_reduce,
    all_reduce_raw,
    reduce_scatter,
    reduce_scatter_raw,
)
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.distributed import ProcessGroup


class FusedDenseFunc(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(
        ctx, x, weight, bias, return_residual=False, process_group=None, sequence_parallel=True
    ):
        """
        If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
        with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
        """
        ctx.compute_weight_gradient = weight.requires_grad
        ctx.return_residual = return_residual
        ctx.process_group = process_group
        ctx.sequence_parallel = sequence_parallel

        if torch.is_autocast_enabled():
            x = x.to(dtype=torch.get_autocast_gpu_dtype())
        x = x.contiguous()
        if process_group is not None and sequence_parallel:
            # We want to kick off the all_gather early, before weight dtype conversion
            total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
        else:
            total_x = x

        if torch.is_autocast_enabled():
            weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
            bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None
        weight = weight.contiguous()
        if process_group is not None and sequence_parallel:
            handle_x.wait()
        batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
        batch_dim = batch_shape.numel()
        # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
        if min(batch_dim, n, *weight.shape) > 65535 * 32:
            raise RuntimeError("fused_dense only supports matrix dims <= 2M")
        output = F.linear(total_x, weight, bias)
        if ctx.compute_weight_gradient:
            ctx.save_for_backward(x, weight)
        else:
            ctx.save_for_backward(weight)
        return output if not return_residual else (output, x)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output, *args):
        grad_output = grad_output.contiguous()
        if ctx.return_residual:
            (grad_input,) = args
            grad_input = grad_input.contiguous()
        process_group = ctx.process_group
        sequence_parallel = ctx.sequence_parallel
        if ctx.compute_weight_gradient:
            x, weight = ctx.saved_tensors
            if process_group is not None and sequence_parallel:
                total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
            else:
                total_x = x
        else:
            (weight,) = ctx.saved_tensors
            total_x = None
        batch_shape = grad_output.shape[:-1]
        batch_dim = batch_shape.numel()
        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
        if ctx.needs_input_grad[0]:
            if not ctx.return_residual:
                grad_input = F.linear(grad_output, weight.t())
            else:
                grad_input = torch.addmm(
                    grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_output, weight
                )
            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
            if process_group is not None:
                reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
                grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
        else:
            grad_input = None
        if ctx.needs_input_grad[1]:
            assert ctx.compute_weight_gradient
            if process_group is not None and sequence_parallel:
                handle_x.wait()
            grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad(
                total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
            )
        else:
            grad_weight = None
            grad_bias = grad_output if ctx.needs_input_grad[2] else None
        if process_group is not None and ctx.needs_input_grad[0]:
            handle_grad_input.wait()
        return grad_input, grad_weight, grad_bias, None, None, None
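
# Note (added commentary, not part of the original file): forward and backward above rely on
# the overlap idiom from flash_attn.utils.distributed: launch the collective with
# async_op=True, do independent work (dtype conversion, local GEMMs), then wait on the
# returned handle right before the gathered/reduced tensor is consumed. A minimal sketch of
# the idiom, assuming an initialized process group:
#
#     total_x, handle = all_gather_raw(x, process_group, async_op=True)
#     ...  # independent compute runs while the all-gather is in flight
#     handle.wait()  # block only once total_x is actually needed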


def fused_dense_func(
    x: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    return_residual: bool = False,
    process_group: Optional[ProcessGroup] = None,
    sequence_parallel: bool = True,
):
    dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or (
        x.dtype == torch.float32 and torch.is_autocast_enabled()
    )
    if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
        return FusedDenseFunc.apply(
            x, weight, bias, return_residual, process_group, sequence_parallel
        )
    else:
        assert process_group is None
        out = F.linear(x, weight, bias)
        return out if not return_residual else (out, x)
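
# Example (illustrative sketch, not part of the original file): fused_dense_func behaves like
# F.linear but routes CUDA fp16/bf16 inputs through FusedDenseFunc, which fuses the bias
# gradient into the weight-gradient kernel in the backward. Shapes and dtypes are assumptions.
#
#     x = torch.randn(8, 512, 1024, device="cuda", dtype=torch.bfloat16, requires_grad=True)
#     w = torch.randn(4096, 1024, device="cuda", dtype=torch.bfloat16, requires_grad=True)
#     b = torch.zeros(4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
#     out = fused_dense_func(x, w, b)  # same result as F.linear(x, w, b)
#     out.sum().backward()  # grads for x, w, b come from FusedDenseFunc.backward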


class FusedDense(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        return_residual: bool = False,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
        self.return_residual = return_residual

    def forward(self, x, process_group=None):
        """
        If process_group is not None, we're doing Tensor Parallel with sequence parallelism:
        we do an all_gather of x before doing the matmul.
        """
        return fused_dense_func(
            x,
            self.weight,
            self.bias,
            return_residual=self.return_residual,
            process_group=process_group,
        )
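
# Example (illustrative sketch, not part of the original file): FusedDense is a drop-in
# replacement for nn.Linear; the fusion changes how gradients are computed, not the math.
# Sizes and dtype below are assumptions.
#
#     layer = FusedDense(1024, 4096, bias=True, device="cuda", dtype=torch.float16)
#     y = layer(torch.randn(4, 128, 1024, device="cuda", dtype=torch.float16))
#     # With return_residual=True, forward returns (output, input) so the caller can fuse
#     # the residual connection into this layer's backward.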


class ColumnParallelLinear(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        process_group: ProcessGroup,
        bias: bool = True,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ) -> None:
        world_size = torch.distributed.get_world_size(process_group)
        if out_features % world_size != 0:
            raise ValueError(
                f"out_features ({out_features}) must be divisible by " f"world_size ({world_size})"
            )
        super().__init__(
            in_features, out_features // world_size, bias=bias, device=device, dtype=dtype
        )
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel

    def forward(self, x):
        # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
        # we do an all_gather of x before doing the matmul.
        # If not, then the input is already gathered.
        return fused_dense_func(
            x,
            self.weight,
            self.bias,
            process_group=self.process_group,
            sequence_parallel=self.sequence_parallel,
        )
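
# Example (illustrative sketch, not part of the original file): with a tensor-parallel group
# of world_size ranks, each rank holds an (out_features // world_size, in_features) weight
# shard and produces only its slice of the output features. torch.distributed is assumed to
# be initialized and process_group to be a valid tensor-parallel group.
#
#     fc1 = ColumnParallelLinear(1024, 4096, process_group, device="cuda", dtype=torch.bfloat16)
#     y_local = fc1(x_local)  # x_local: this rank's sequence shard (all-gathered internally);
#                             # y_local: full sequence length, 4096 // world_size features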


class RowParallelLinear(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        process_group: ProcessGroup,
        bias: bool = True,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ) -> None:
        world_size = torch.distributed.get_world_size(process_group)
        rank = torch.distributed.get_rank(process_group)
        if in_features % world_size != 0:
            raise ValueError(
                f"in_features ({in_features}) must be divisible by " f"world_size ({world_size})"
            )
        # Only rank 0 will have bias
        super().__init__(
            in_features // world_size,
            out_features,
            bias=bias and rank == 0,
            device=device,
            dtype=dtype,
        )
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel

    def forward(self, x):
        """
        We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
        a reduce_scatter of the result.
        """
        out = fused_dense_func(x, self.weight, self.bias)
        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
        return reduce_fn(out, self.process_group)
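
# Example (illustrative sketch, not part of the original file): ColumnParallelLinear and
# RowParallelLinear are meant to be stacked Megatron-style. The column-parallel layer
# all-gathers the sequence-sharded input and shards the hidden features; the row-parallel
# layer consumes that feature shard, computes a partial sum, and reduce-scatters it back into
# sequence shards, so the pair needs only two collectives. process_group is assumed to exist.
#
#     fc1 = ColumnParallelLinear(1024, 4096, process_group, device="cuda", dtype=torch.bfloat16)
#     fc2 = RowParallelLinear(4096, 1024, process_group, device="cuda", dtype=torch.bfloat16)
#     out_local = fc2(F.relu(fc1(x_local)))  # out_local is again a sequence shard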


class FusedMLPFunc(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(
        ctx,
        x,
        weight1,
        bias1,
        weight2,
        bias2,
        activation="gelu_approx",
        save_pre_act=True,
        return_residual=False,
        checkpoint_lvl=0,
        heuristic=0,
        process_group=None,
        sequence_parallel=True,
    ):
        """
        If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
        with sequence parallelism: we do an all_gather of x before doing the matmul.
        If sequence_parallel=False, then the input is already gathered.

        checkpoint_lvl:
        0: no recomputation in the bwd
        1: recompute gelu_out / relu_out in the bwd
        2: recompute pre_act and gelu_out / relu_out in the bwd
        """
        assert -1 <= heuristic <= 4
        assert activation in ["gelu_approx", "relu", "sqrelu"]
        if activation == "sqrelu":
            assert heuristic == -1
        if not save_pre_act:
            checkpoint_lvl = 2
        assert checkpoint_lvl in [0, 1, 2]
        ctx.return_residual = return_residual
        ctx.process_group = process_group
        ctx.sequence_parallel = sequence_parallel
        ctx.checkpoint_lvl = checkpoint_lvl
        ctx.activation = activation
        ctx.heuristic = heuristic

        if torch.is_autocast_enabled():
            x = x.to(dtype=torch.get_autocast_gpu_dtype())
        x = x.contiguous()
        if process_group is not None and sequence_parallel:
            # We want to kick off the all_gather early, before weight dtype conversion
            total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
        else:
            total_x = x

        if torch.is_autocast_enabled():
            dtype = torch.get_autocast_gpu_dtype()
            weight1, weight2 = [a.to(dtype=dtype) for a in [weight1, weight2]]
            bias1 = bias1.to(dtype=dtype) if bias1 is not None else None
            bias2 = bias2.to(dtype=dtype) if bias2 is not None else None
        weight1 = weight1.contiguous()
        bias1 = bias1.contiguous() if bias1 is not None else None
        weight2 = weight2.contiguous()
        bias2 = bias2.contiguous() if bias2 is not None else None
        if process_group is not None and sequence_parallel:
            handle_x.wait()
        batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
        batch_dim = batch_shape.numel()
        # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
        if min(batch_dim, n, *weight1.shape, *weight2.shape) > 65535 * 32:
            raise RuntimeError("fused_dense only supports matrix dims <= 2M")
        if heuristic == -1:
            pre_act = F.linear(total_x, weight1, bias1)
            activation_fn = (
                partial(F.gelu, approximate="tanh")
                if activation == "gelu_approx"
                else (sqrelu_fwd if activation == "sqrelu" else F.relu)
            )
            with torch.jit.fuser("fuser2"):
                output1 = activation_fn(pre_act)
            # This is before adding bias1
            # pre_act = F.linear(total_x.reshape(batch_dim, n), weight1)
            # with torch.jit.fuser('fuser2'):
            #     output1 = bias_gelu(pre_act, bias1)
        else:
            is_gelu = activation == "gelu_approx"
            output1, *rest = fused_dense_cuda.linear_act_forward(
                total_x.reshape(batch_dim, n), weight1, bias1, is_gelu, save_pre_act, heuristic
            )
            if save_pre_act:
                pre_act = rest[0]
        output2 = F.linear(output1, weight2, bias2)
        if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == "relu"):
            # For RELU the pre_act is very small (just a bit-mask) so we just save it
            ctx.save_for_backward(x, weight1, weight2, pre_act, output1)
        elif checkpoint_lvl == 1:
            ctx.save_for_backward(x, weight1, weight2, pre_act)
        elif checkpoint_lvl == 2:
            ctx.save_for_backward(x, weight1, weight2, bias1)
        output2 = output2.reshape(*batch_shape, output2.shape[-1])
        return output2 if not return_residual else (output2, x)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output, *args):
        grad_output = grad_output.contiguous()
        checkpoint_lvl = ctx.checkpoint_lvl
        activation = ctx.activation
        activation_fn = (
            partial(F.gelu, approximate="tanh")
            if activation == "gelu_approx"
            else (sqrelu_fwd if activation == "sqrelu" else F.relu)
        )
        if ctx.return_residual:
            (grad_input,) = args
            grad_input = grad_input.contiguous()
        process_group = ctx.process_group
        sequence_parallel = ctx.sequence_parallel
        x, weight1, weight2, *rest = ctx.saved_tensors
        if process_group is None or not sequence_parallel:
            total_x = x
        batch_shape = grad_output.shape[:-1]
        batch_dim = batch_shape.numel()
        if checkpoint_lvl in [0, 1]:
            if process_group is not None and sequence_parallel:
                total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
            if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == "relu"):
                pre_act, output1 = rest
            elif checkpoint_lvl == 1:
                (pre_act,) = rest
                with torch.jit.fuser("fuser2"):
                    output1 = activation_fn(pre_act)
        elif checkpoint_lvl == 2:
            (bias1,) = rest
            if process_group is not None and sequence_parallel:
                total_x, _ = all_gather_raw(x, process_group)
            if ctx.heuristic == -1:
                pre_act = F.linear(total_x, weight1, bias1)
                with torch.jit.fuser("fuser2"):
                    output1 = activation_fn(pre_act)
            else:
                output1, pre_act = fused_dense_cuda.linear_act_forward(
                    total_x.reshape(batch_dim, total_x.shape[-1]),
                    weight1,
                    bias1,
                    activation == "gelu_approx",
                    True,
                    ctx.heuristic,
                )

        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
        output1 = output1.reshape(batch_dim, output1.shape[-1])
        pre_act = pre_act.reshape(batch_dim, pre_act.shape[-1])
        if ctx.needs_input_grad[3]:
            grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(
                output1, grad_output, ctx.needs_input_grad[4]
            )
        else:
            grad_weight2 = None
            grad_bias2 = grad_output if ctx.needs_input_grad[4] else None
        if ctx.heuristic == -1:
            # grad_pre_act = matmul_dgelu(grad_output, weight2, pre_act)
            grad_output1 = F.linear(grad_output, weight2.t())
            activation_grad_fn = (
                gelu_bwd
                if activation == "gelu_approx"
                else (sqrelu_bwd if activation == "sqrelu" else relu_bwd)
            )
            with torch.jit.fuser("fuser2"):
                grad_pre_act = activation_grad_fn(grad_output1, pre_act)
        else:
            # The cublasLt epilogue has to compute both the gelu/relu grad and the bias grad;
            # we can't just compute the gelu/relu grad.
            grad_pre_act, grad_bias1 = fused_dense_cuda.bias_act_linear_dgrad_bgrad(
                weight2, grad_output, pre_act, activation == "gelu_approx", ctx.heuristic
            )
            if not ctx.needs_input_grad[2]:
                grad_bias1 = None
        if ctx.needs_input_grad[0]:
            if not ctx.return_residual:
                grad_input = F.linear(grad_pre_act, weight1.t())
            else:
                grad_input = torch.addmm(
                    grad_input.reshape(batch_dim, grad_input.shape[-1]), grad_pre_act, weight1
                )
            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
            if process_group is not None:
                reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
                grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
        else:
            grad_input = None
        if ctx.heuristic == -1:
            if ctx.needs_input_grad[1]:
                if process_group is not None and sequence_parallel:
                    handle_x.wait()
                grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad(
                    total_x.reshape(batch_dim, total_x.shape[-1]),
                    grad_pre_act,
                    ctx.needs_input_grad[2],
                )
            else:
                grad_weight1 = None
                grad_bias1 = grad_pre_act if ctx.needs_input_grad[2] else None
        else:
            if ctx.needs_input_grad[1]:
                if process_group is not None and sequence_parallel:
                    handle_x.wait()
                grad_weight1 = F.linear(
                    grad_pre_act.t(), total_x.reshape(batch_dim, total_x.shape[-1]).t()
                )
            else:
                grad_weight1 = None
        if process_group is not None and ctx.needs_input_grad[0]:
            handle_grad_input.wait()
        return (
            grad_input,
            grad_weight1,
            grad_bias1,
            grad_weight2,
            grad_bias2,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


def fused_mlp_func(
    x: Tensor,
    weight1: Tensor,
    weight2: Tensor,
    bias1: Optional[Tensor] = None,
    bias2: Optional[Tensor] = None,
    activation: str = "gelu_approx",
    save_pre_act: bool = True,
    return_residual: bool = False,
    checkpoint_lvl: int = 0,
    heuristic: int = 0,
    process_group: Optional[ProcessGroup] = None,
    sequence_parallel: bool = True,
):
    assert activation in ["gelu_approx", "relu", "sqrelu"]
    dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or (
        x.dtype == torch.float32 and torch.is_autocast_enabled()
    )
    # If we save pre-activation, dimension must be divisible by 128 (relu) or 8 (gelu)
    dim_eligible = not save_pre_act or (x.shape[-1] % (128 if activation == "relu" else 8) == 0)
    if (
        x.is_cuda
        and weight1.is_cuda
        and weight2.is_cuda
        and (bias1 is None or bias1.is_cuda)
        and (bias2 is None or bias2.is_cuda)
        and dtype_eligible
        and dim_eligible
    ):
        return FusedMLPFunc.apply(
            x,
            weight1,
            bias1,
            weight2,
            bias2,
            activation,
            save_pre_act,
            return_residual,
            checkpoint_lvl,
            heuristic,
            process_group,
            sequence_parallel,
        )
    else:
        assert process_group is None
        pre_act = F.linear(x, weight1, bias1)
        activation_fn = (
            partial(F.gelu, approximate="tanh")
            if activation == "gelu_approx"
            else partial(F.relu, inplace=True)
        )
        output1 = activation_fn(pre_act)
        output2 = F.linear(output1, weight2, bias2)
        return output2 if not return_residual else (output2, x)
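
# Example (illustrative sketch, not part of the original file): fused_mlp_func computes
# fc2(activation(fc1(x))) as one autograd Function. Shapes and dtypes are assumptions; note
# that this wrapper takes (weight1, weight2, bias1, bias2) while the underlying
# FusedMLPFunc.apply takes (weight1, bias1, weight2, bias2).
#
#     w1 = torch.randn(4096, 1024, device="cuda", dtype=torch.float16, requires_grad=True)
#     w2 = torch.randn(1024, 4096, device="cuda", dtype=torch.float16, requires_grad=True)
#     x = torch.randn(2, 128, 1024, device="cuda", dtype=torch.float16, requires_grad=True)
#     out = fused_mlp_func(x, w1, w2, activation="gelu_approx", checkpoint_lvl=1)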


class FusedMLP(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        bias1=True,
        bias2=True,
        activation="gelu_approx",
        return_residual=False,
        checkpoint_lvl=0,
        heuristic="auto",
        device=None,
        dtype=None,
    ):
        """
        If process_group is not None, we're doing Tensor Parallel with sequence parallelism:
        we do an all_gather of x before doing the matmul, gelu, then matmul.
        Finally we do a reduce_scatter of the output.

        checkpoint_lvl (increasing lvl means slower but more memory saving):
            0: no recomputation in the bwd
            1: recompute gelu_out in the bwd
            2: recompute pre_act and gelu_out in the bwd
        heuristic:
            -1: don't fuse gemm + gelu (separate kernel)
            0..4: use this heuristic for the algo section in the fused gemm + gelu
            'auto': heuristic will be picked automatically:
                For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf.
                For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16.
                For H100, we set heuristic=-1 for both fp16 and bf16 as the fused cuBlasLt
                implementation is slower than the unfused version.
        return_residual: whether to return the input x along with the output. This is for
            performance reasons: for a post-norm architecture, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.
        """
        assert checkpoint_lvl in [0, 1, 2]
        assert activation in ["gelu_approx", "relu", "sqrelu"]
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features * 4
        self.activation = activation
        self.return_residual = return_residual
        self.checkpoint_lvl = checkpoint_lvl
        self.heuristic = heuristic if activation != "sqrelu" else -1
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

    def forward(self, x, process_group=None):
        dtype = x.dtype if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype()
        if self.heuristic == "auto":
            if self.activation == "gelu_approx":
                if torch.cuda.get_device_capability("cuda") == (9, 0):
                    heuristic = -1
                else:
                    cuda_ver = tuple(map(int, torch.version.cuda.split(".")))
                    heuristic = 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1)
            else:
                heuristic = 0
        else:
            heuristic = self.heuristic
        out = fused_mlp_func(
            x,
            self.fc1.weight,
            self.fc2.weight,
            self.fc1.bias,
            self.fc2.bias,
            activation=self.activation,
            save_pre_act=self.training,
            return_residual=self.return_residual,
            checkpoint_lvl=self.checkpoint_lvl,
            heuristic=heuristic,
            process_group=process_group,
        )
        if self.return_residual:
            out, x = out
        if process_group is not None:
            out = reduce_scatter(out, process_group)
        return out if not self.return_residual else (out, x)
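
# Example (illustrative sketch, not part of the original file): FusedMLP mirrors the usual
# Linear -> GELU -> Linear block and works under torch.autocast, since fused_mlp_func accepts
# fp32 inputs when autocast is enabled. Sizes below are assumptions.
#
#     mlp = FusedMLP(1024, 4096, activation="gelu_approx", checkpoint_lvl=1).cuda()
#     with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
#         y = mlp(torch.randn(4, 128, 1024, device="cuda"))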


class ParallelFusedMLP(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation="gelu_approx",
        process_group: ProcessGroup = None,
        bias1=True,
        bias2=True,
        sequence_parallel=True,
        checkpoint_lvl=0,
        heuristic="auto",
        device=None,
        dtype=None,
    ):
        """
        process_group is required. We're doing Tensor Parallel with sequence parallelism:
        we do an all_gather of x before doing the matmul, gelu, then matmul.
        Finally we do a reduce_scatter of the output.

        checkpoint_lvl (increasing lvl means slower but more memory saving):
            0: no recomputation in the bwd
            1: recompute gelu_out in the bwd
            2: recompute pre_act and gelu_out in the bwd
        heuristic:
            -1: don't fuse gemm + gelu (separate kernel)
            0..4: use this heuristic for the algo section in the fused gemm + gelu
            'auto': heuristic will be picked automatically:
                For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf.
                For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16.
        """
        assert checkpoint_lvl in [0, 1, 2]
        assert activation in ["gelu_approx", "relu", "sqrelu"]
        assert process_group is not None
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features * 4
        self.activation = activation
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel
        self.checkpoint_lvl = checkpoint_lvl
        self.heuristic = heuristic if activation != "sqrelu" else -1
        self.fc1 = ColumnParallelLinear(
            in_features, hidden_features, process_group, bias=bias1, **factory_kwargs
        )
        self.fc2 = RowParallelLinear(
            hidden_features, out_features, process_group, bias=bias2, **factory_kwargs
        )

    def forward(self, x):
        dtype = x.dtype if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype()
        if self.heuristic == "auto":
            if self.activation == "gelu_approx":
                cuda_ver = tuple(map(int, torch.version.cuda.split(".")))
                heuristic = 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1)
            else:
                heuristic = 0
        else:
            heuristic = self.heuristic
        out = fused_mlp_func(
            x,
            self.fc1.weight,
            self.fc2.weight,
            self.fc1.bias,
            self.fc2.bias,
            activation=self.activation,
            save_pre_act=self.training,
            checkpoint_lvl=self.checkpoint_lvl,
            heuristic=heuristic,
            process_group=self.process_group,
            sequence_parallel=self.sequence_parallel,
        )
        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
        return reduce_fn(out, self.process_group)
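
# Example (illustrative sketch, not part of the original file): ParallelFusedMLP owns a
# ColumnParallelLinear (fc1) and a RowParallelLinear (fc2) but calls fused_mlp_func directly,
# so the all_gather / reduce_scatter happen once around the whole MLP rather than around each
# linear. The group setup below is an assumption (one tensor-parallel group over all ranks).
#
#     torch.distributed.init_process_group("nccl")
#     tp_group = torch.distributed.new_group()
#     mlp = ParallelFusedMLP(
#         1024, 4096, process_group=tp_group, sequence_parallel=True,
#         device="cuda", dtype=torch.bfloat16,
#     )
#     y_shard = mlp(x_shard)  # x_shard / y_shard: this rank's slice of the sequence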