# Copyright (c) 2023, Tri Dao.
# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/fused_dense/fused_dense.py
# We make it work with pytorch amp and with bfloat16.
# The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py
5
from typing import Optional
6
from functools import partial
7
8
9
10

import torch
import torch.nn as nn
import torch.nn.functional as F
Tri Dao's avatar
Tri Dao committed
11
from torch import Tensor
12
from torch.distributed import ProcessGroup
13
14
15
16
from torch.cuda.amp import custom_bwd, custom_fwd

# import fused_dense_cuda  # from apex
import fused_dense_lib as fused_dense_cuda
17

18
from flash_attn.ops.gelu_activation import gelu_bwd
19
20
from flash_attn.utils.distributed import all_gather_raw, reduce_scatter_raw, all_reduce_raw
from flash_attn.utils.distributed import reduce_scatter, all_reduce
21
22


23
24
25
26
27
@torch.jit.script
def relu_bwd(g, x):
    """Backward of ReLU: pass ``g`` through where ``x >= 0``, zero elsewhere.

    The result is cast to the dtype of ``x``.
    """
    keep_mask = x >= 0
    masked_grad = torch.where(keep_mask, g, 0.0)
    return masked_grad.to(dtype=x.dtype)


Tri Dao's avatar
Tri Dao committed
28
class FusedDenseFunc(torch.autograd.Function):
    """Autograd function for a linear layer, with optional tensor parallelism.

    The forward pass is ``F.linear(x, weight, bias)``. The backward pass obtains the weight
    and bias gradients from a single ``fused_dense_cuda.linear_bias_wgrad`` call. When a
    process group is given, the collectives are launched with ``async_op=True`` so that
    they can overlap with the local dtype conversions / matmuls.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, x, weight, bias, return_residual=False, process_group=None,
                sequence_parallel=True):
        """
        If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
        with sequence parallelism: we do an all_gather_raw of x before doing the matmul.

        Arguments:
            x: input whose last dimension is the feature dimension.
            weight: weight, used as in ``F.linear``.
            bias: optional bias, used as in ``F.linear``.
            return_residual: if True, also return the (local, un-gathered) input ``x``.
            process_group: optional process group for tensor parallelism.
            sequence_parallel: whether ``x`` must be all-gathered before the matmul.

        Returns:
            ``output`` or, if ``return_residual`` is True, ``(output, x)``.
        """
        ctx.compute_weight_gradient = weight.requires_grad
        ctx.return_residual = return_residual
        ctx.process_group = process_group
        ctx.sequence_parallel = sequence_parallel

        if torch.is_autocast_enabled():
            x = x.to(dtype=torch.get_autocast_gpu_dtype())
        x = x.contiguous()
        if process_group is not None and sequence_parallel:
            # We want to kick off the all_gather early, before weight dtype conversion
            total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
        else:
            total_x = x

        if torch.is_autocast_enabled():
            weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
            bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None
        weight = weight.contiguous()
        if process_group is not None and sequence_parallel:
            # The gathered input is needed from here on.
            handle_x.wait()
        batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
        batch_dim = batch_shape.numel()
        # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
        # NOTE(review): with `min`, this only fires when *every* dim exceeds the limit, while
        # the message talks about any dim. `max` looks intended -- confirm against the
        # cuBLASLt limits before changing, since that would reject more inputs.
        if min(batch_dim, n, *weight.shape) > 65535 * 32:
            raise RuntimeError('fused_dense only supports matrix dims <= 2M')
        output = F.linear(total_x, weight, bias)
        # Only keep x around if the weight gradient will be computed in the backward.
        if ctx.compute_weight_gradient:
            ctx.save_for_backward(x, weight)
        else:
            ctx.save_for_backward(weight)
        return output if not return_residual else (output, x)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output, *args):
        """
        Arguments:
            grad_output: gradient w.r.t. ``output``.
            *args: when ``return_residual`` was True, a single extra gradient w.r.t. the
                returned residual ``x``; it is accumulated into the input gradient.

        Returns:
            Gradients for ``(x, weight, bias)`` followed by ``None`` for the three
            non-tensor arguments of ``forward``.
        """
        grad_output = grad_output.contiguous()
        if ctx.return_residual:
            grad_input, = args
            grad_input = grad_input.contiguous()
        process_group = ctx.process_group
        sequence_parallel = ctx.sequence_parallel
        if ctx.compute_weight_gradient:
            x, weight = ctx.saved_tensors
            if process_group is not None and sequence_parallel:
                # Re-gather x asynchronously; it is only needed for the weight gradient.
                total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
            else:
                total_x = x
        else:
            weight, = ctx.saved_tensors
            total_x = None
        batch_shape = grad_output.shape[:-1]
        batch_dim = batch_shape.numel()
        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
        if ctx.needs_input_grad[0]:
            if not ctx.return_residual:
                grad_input = F.linear(grad_output, weight.t())
            else:
                # Fuse the residual gradient accumulation with the matmul.
                grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]),
                                         grad_output, weight)
            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
            if process_group is not None:
                reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
                grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
        else:
            grad_input = None
        if ctx.needs_input_grad[1]:
            assert ctx.compute_weight_gradient
            if process_group is not None and sequence_parallel:
                handle_x.wait()
            grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad(
                total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
            )
        else:
            grad_weight = None
            # Un-reduced (batch_dim, out_features) gradient is returned for the bias here.
            grad_bias = grad_output if ctx.needs_input_grad[2] else None
        if process_group is not None and ctx.needs_input_grad[0]:
            handle_grad_input.wait()
        return grad_input, grad_weight, grad_bias, None, None, None
116
117


Tri Dao's avatar
Tri Dao committed
118
def fused_dense_func(x: Tensor, weight: Tensor, bias: Optional[Tensor] = None,
                     return_residual: bool = False, process_group: Optional[ProcessGroup] = None,
                     sequence_parallel: bool = True):
    """Apply a linear layer, using FusedDenseFunc when the inputs are eligible.

    The fused path is taken when x, weight and bias (if given) are CUDA tensors and x is
    fp16/bf16, or fp32 with autocast enabled. Otherwise this falls back to ``F.linear``,
    which does not support ``process_group``.

    Returns:
        ``out`` or, if ``return_residual`` is True, ``(out, x)``.
    """
    dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
                      or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
    if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
        return FusedDenseFunc.apply(x, weight, bias, return_residual, process_group,
                                    sequence_parallel)
    else:
        # The plain F.linear fallback has no communication, so it cannot be tensor parallel.
        assert process_group is None, 'process_group is only supported on the fused path'
        out = F.linear(x, weight, bias)
        return out if not return_residual else (out, x)
130
131


Tri Dao's avatar
Tri Dao committed
132
class FusedDense(nn.Linear):
    """``nn.Linear`` whose forward goes through ``fused_dense_func``.

    Arguments:
        return_residual: if True, ``forward`` returns ``(out, x)`` instead of ``out``.
        Remaining arguments are forwarded to ``nn.Linear``.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 return_residual: bool = False, device=None, dtype=None) -> None:
        super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
        self.return_residual = return_residual

    def forward(self, x, process_group=None):
        """
        If process_group is not None, we're doing Tensor Parallel with sequence parallelism:
        we do an all_gather of x before doing the matmul.

        ``sequence_parallel`` is not passed on, so ``fused_dense_func`` uses its default.
        """
        return fused_dense_func(x, self.weight, self.bias, return_residual=self.return_residual,
                                process_group=process_group)


class ColumnParallelLinear(nn.Linear):
    """Linear layer whose ``out_features`` are split evenly across ``process_group``.

    Each rank holds a weight with ``out_features // world_size`` output features.

    Raises:
        ValueError: if ``out_features`` is not divisible by the world size of the group.
    """

    def __init__(self, in_features: int, out_features: int, process_group: ProcessGroup,
                 bias: bool = True, sequence_parallel=True, device=None, dtype=None) -> None:
        world_size = torch.distributed.get_world_size(process_group)
        if out_features % world_size != 0:
            raise ValueError(f'out_features ({out_features}) must be divisible by '
                             f'world_size ({world_size})')
        super().__init__(in_features, out_features // world_size, bias=bias,
                         device=device, dtype=dtype)
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel

    def forward(self, x):
        # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
        # we do an all_gather of x before doing the matmul.
        # If not, then the input is already gathered.
        return fused_dense_func(x, self.weight, self.bias, process_group=self.process_group,
                                sequence_parallel=self.sequence_parallel)
167
168
169
170
171


class RowParallelLinear(nn.Linear):
    """Linear layer whose ``in_features`` are split evenly across ``process_group``.

    Each rank holds a weight with ``in_features // world_size`` input features; the partial
    outputs are combined across ranks in ``forward``.

    Raises:
        ValueError: if ``in_features`` is not divisible by the world size of the group.
    """

    def __init__(self, in_features: int, out_features: int, process_group: ProcessGroup,
                 bias: bool = True, sequence_parallel=True, device=None, dtype=None) -> None:
        world_size = torch.distributed.get_world_size(process_group)
        rank = torch.distributed.get_rank(process_group)
        if in_features % world_size != 0:
            raise ValueError(f'in_features ({in_features}) must be divisible by '
                             f'world_size ({world_size})')
        # Only rank 0 will have bias
        super().__init__(in_features // world_size, out_features, bias=bias and rank == 0,
                         device=device, dtype=dtype)
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel

    def forward(self, x):
        """
        We're doing Tensor Parallel: we do the local matmul and then combine the result
        across ranks, with a reduce_scatter if self.sequence_parallel is True and an
        all_reduce otherwise.
        """
        # No process_group here: the matmul is purely local, the reduction follows.
        out = fused_dense_func(x, self.weight, self.bias)
        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
        return reduce_fn(out, self.process_group)
192
193


194
class FusedMLPFunc(torch.autograd.Function):
    """Autograd function for ``linear -> activation -> linear`` with optional tensor parallelism.

    With ``heuristic == -1`` the first matmul and the activation run as separate PyTorch
    ops; otherwise they go through ``fused_dense_cuda.linear_act_forward``.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, x, weight1, bias1, weight2, bias2, activation='gelu_approx', save_pre_act=True,
                return_residual=False, checkpoint_lvl=0, heuristic=0, process_group=None,
                sequence_parallel=True):
        """
        If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
        with sequence parallelism: we do an all_gather of x before doing the matmul.
        If sequence_parallel=False, then the input is already gathered.

        checkpoint_lvl:
        0: no recomputation in the bwd
        1: recompute gelu_out / relu_out in the bwd
        2: recompute pre_act and gelu_out / relu_out in the bwd

        If save_pre_act is False, checkpoint_lvl is forced to 2.

        Returns:
            ``output2`` or, if ``return_residual`` is True, ``(output2, x)``.
        """
        assert -1 <= heuristic <= 4
        assert activation in ['gelu_approx', 'relu']
        if not save_pre_act:
            checkpoint_lvl = 2
        assert checkpoint_lvl in [0, 1, 2]
        ctx.return_residual = return_residual
        ctx.process_group = process_group
        ctx.sequence_parallel = sequence_parallel
        ctx.checkpoint_lvl = checkpoint_lvl
        ctx.activation = activation
        ctx.heuristic = heuristic

        if torch.is_autocast_enabled():
            x = x.to(dtype=torch.get_autocast_gpu_dtype())
        x = x.contiguous()
        if process_group is not None and sequence_parallel:
            # We want to kick off the all_gather early, before weight dtype conversion
            total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
        else:
            total_x = x

        if torch.is_autocast_enabled():
            dtype = torch.get_autocast_gpu_dtype()
            weight1, weight2 = [a.to(dtype=dtype) for a in [weight1, weight2]]
            bias1 = bias1.to(dtype=dtype) if bias1 is not None else None
            bias2 = bias2.to(dtype=dtype) if bias2 is not None else None
        weight1 = weight1.contiguous()
        bias1 = bias1.contiguous() if bias1 is not None else None
        weight2 = weight2.contiguous()
        bias2 = bias2.contiguous() if bias2 is not None else None
        if process_group is not None and sequence_parallel:
            handle_x.wait()
        batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
        batch_dim = batch_shape.numel()
        # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
        # NOTE(review): with `min`, this only fires when *every* dim exceeds the limit, while
        # the message talks about any dim. `max` looks intended -- confirm against the
        # cuBLASLt limits before changing, since that would reject more inputs.
        if min(batch_dim, n, *weight1.shape, *weight2.shape) > 65535 * 32:
            raise RuntimeError('fused_dense only supports matrix dims <= 2M')
        if heuristic == -1:
            pre_act = F.linear(total_x, weight1, bias1)
            activation_fn = (partial(F.gelu, approximate='tanh') if activation == 'gelu_approx'
                             else F.relu)
            output1 = activation_fn(pre_act)
            # This is before adding bias1
            # pre_act = F.linear(total_x.reshape(batch_dim, n), weight1)
            # with torch.jit.fuser('fuser2'):
            #     output1 = bias_gelu(pre_act, bias1)
        else:
            is_gelu = activation == 'gelu_approx'
            output1, *rest = fused_dense_cuda.linear_act_forward(
                total_x.reshape(batch_dim, n), weight1, bias1, is_gelu, save_pre_act, heuristic
            )
            # When save_pre_act is False, checkpoint_lvl is 2 and pre_act is not needed below.
            if save_pre_act:
                pre_act = rest[0]
        output2 = F.linear(output1, weight2, bias2)
        if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == 'relu'):
            # For RELU the pre_act is very small (just a bit-mask) so we just save it
            ctx.save_for_backward(x, weight1, weight2, pre_act, output1)
        elif checkpoint_lvl == 1:
            ctx.save_for_backward(x, weight1, weight2, pre_act)
        elif checkpoint_lvl == 2:
            # bias1 is needed to recompute pre_act in the backward.
            ctx.save_for_backward(x, weight1, weight2, bias1)
        output2 = output2.reshape(*batch_shape, output2.shape[-1])
        return output2 if not return_residual else (output2, x)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output, *args):
        """
        Arguments:
            grad_output: gradient w.r.t. ``output2``.
            *args: when ``return_residual`` was True, a single extra gradient w.r.t. the
                returned residual ``x``; it is accumulated into the input gradient.

        Returns:
            Gradients for ``(x, weight1, bias1, weight2, bias2)`` followed by ``None`` for
            the seven non-tensor arguments of ``forward``.
        """
        grad_output = grad_output.contiguous()
        checkpoint_lvl = ctx.checkpoint_lvl
        activation = ctx.activation
        activation_fn = (partial(F.gelu, approximate='tanh') if activation == 'gelu_approx'
                         else F.relu)
        if ctx.return_residual:
            grad_input, = args
            grad_input = grad_input.contiguous()
        process_group = ctx.process_group
        sequence_parallel = ctx.sequence_parallel
        x, weight1, weight2, *rest = ctx.saved_tensors
        if process_group is None or not sequence_parallel:
            total_x = x
        # Handle of a pending async all_gather of x, if any. It stays None when x is not
        # gathered or is gathered synchronously (checkpoint_lvl == 2), so that the waits
        # further down never touch an undefined handle.
        handle_x = None
        batch_shape = grad_output.shape[:-1]
        batch_dim = batch_shape.numel()
        if checkpoint_lvl in [0, 1]:
            if process_group is not None and sequence_parallel:
                total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
            if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == 'relu'):
                pre_act, output1 = rest
            elif checkpoint_lvl == 1:
                pre_act, = rest
                output1 = activation_fn(pre_act)
        elif checkpoint_lvl == 2:
            bias1, = rest
            if process_group is not None and sequence_parallel:
                # total_x is needed right away for the recomputation, so gather synchronously.
                total_x, _ = all_gather_raw(x, process_group)
            if ctx.heuristic == -1:
                pre_act = F.linear(total_x, weight1, bias1)
                output1 = activation_fn(pre_act)
            else:
                output1, pre_act = fused_dense_cuda.linear_act_forward(
                    total_x.reshape(batch_dim, total_x.shape[-1]), weight1, bias1,
                    activation == 'gelu_approx', True, ctx.heuristic
                )

        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
        output1 = output1.reshape(batch_dim, output1.shape[-1])
        pre_act = pre_act.reshape(batch_dim, pre_act.shape[-1])
        if ctx.needs_input_grad[3]:
            grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(
                output1, grad_output, ctx.needs_input_grad[4]
            )
        else:
            grad_weight2 = None
            grad_bias2 = grad_output if ctx.needs_input_grad[4] else None
        if ctx.heuristic == -1:
            # grad_pre_act = matmul_dgelu(grad_output, weight2, pre_act)
            grad_output1 = F.linear(grad_output, weight2.t())
            with torch.jit.fuser('fuser2'):
                activation_grad_fn = gelu_bwd if activation == 'gelu_approx' else relu_bwd
                grad_pre_act = activation_grad_fn(grad_output1, pre_act)
        else:
            # The cublasLt epilogue has to compute both gelu/relu grad and bias grad, we can't
            # just compute gelu/relu grad
            grad_pre_act, grad_bias1 = fused_dense_cuda.bias_act_linear_dgrad_bgrad(
                weight2, grad_output, pre_act, activation == 'gelu_approx', ctx.heuristic
            )
            if not ctx.needs_input_grad[2]:
                grad_bias1 = None
        if ctx.needs_input_grad[0]:
            if not ctx.return_residual:
                grad_input = F.linear(grad_pre_act, weight1.t())
            else:
                # Fuse the residual gradient accumulation with the matmul.
                grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]),
                                         grad_pre_act, weight1)
            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
            if process_group is not None:
                reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
                grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
        else:
            grad_input = None
        if ctx.heuristic == -1:
            if ctx.needs_input_grad[1]:
                if handle_x is not None:
                    handle_x.wait()
                grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad(
                    total_x.reshape(batch_dim, total_x.shape[-1]), grad_pre_act,
                    ctx.needs_input_grad[2]
                )
            else:
                grad_weight1 = None
                grad_bias1 = grad_pre_act if ctx.needs_input_grad[2] else None
        else:
            if ctx.needs_input_grad[1]:
                if handle_x is not None:
                    handle_x.wait()
                grad_weight1 = F.linear(grad_pre_act.t(),
                                        total_x.reshape(batch_dim, total_x.shape[-1]).t())
            else:
                grad_weight1 = None
        if process_group is not None and ctx.needs_input_grad[0]:
            handle_grad_input.wait()
        return (grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2,
                None, None, None, None, None, None, None)
Tri Dao's avatar
Tri Dao committed
373
374


375
def fused_mlp_func(
    x: Tensor, weight1: Tensor, weight2: Tensor, bias1: Optional[Tensor] = None,
    bias2: Optional[Tensor] = None, activation: str = 'gelu_approx',
    save_pre_act: bool = True, return_residual: bool = False,
    checkpoint_lvl: int = 0, heuristic: int = 0,
    process_group: Optional[ProcessGroup] = None,
    sequence_parallel: bool = True
):
    """Apply ``linear -> activation -> linear``, using FusedMLPFunc when eligible.

    The fused path is taken when all tensors are on CUDA, x is fp16/bf16 (or fp32 with
    autocast enabled) and the dimension constraint below holds. Otherwise this falls back
    to plain PyTorch ops, which ignore ``save_pre_act``, ``checkpoint_lvl`` and
    ``heuristic`` and do not support ``process_group``.

    Returns:
        ``output2`` or, if ``return_residual`` is True, ``(output2, x)``.
    """
    assert activation in ['gelu_approx', 'relu']
    dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
                      or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
    # If we save pre-activation, dimension must be divisible by 128 (relu) or 8 (gelu)
    dim_eligible = not save_pre_act or (x.shape[-1] % (128 if activation == 'relu' else 8) == 0)
    if (x.is_cuda and weight1.is_cuda and weight2.is_cuda and (bias1 is None or bias1.is_cuda)
            and (bias2 is None or bias2.is_cuda) and dtype_eligible and dim_eligible):
        return FusedMLPFunc.apply(
            x, weight1, bias1, weight2, bias2, activation, save_pre_act, return_residual,
            checkpoint_lvl, heuristic, process_group, sequence_parallel
        )
    else:
        # The unfused fallback has no communication, so it cannot be tensor parallel.
        assert process_group is None, 'process_group is only supported on the fused path'
        pre_act = F.linear(x, weight1, bias1)
        # relu can overwrite pre_act in place since pre_act is not used afterwards.
        activation_fn = (partial(F.gelu, approximate='tanh') if activation == 'gelu_approx'
                         else partial(F.relu, inplace=True))
        output1 = activation_fn(pre_act)
        output2 = F.linear(output1, weight2, bias2)
        return output2 if not return_residual else (output2, x)
402
403


404
class FusedMLP(nn.Module):
    """Two-layer MLP (``fc1 -> activation -> fc2``) evaluated through ``fused_mlp_func``."""

    def __init__(self, in_features, hidden_features, out_features=None, bias1=True,
                 bias2=True, activation='gelu_approx', return_residual=False,
                 checkpoint_lvl=0, heuristic='auto', device=None, dtype=None):
        """
        If process_group (an argument of forward) is not None, we're doing Tensor Parallel
        with sequence parallelism: we do an all_gather of x before doing the matmul, gelu,
        then matmul. Finally we do a reduce_scatter of the output.

        out_features defaults to in_features when it is None.

        checkpoint_lvl (increasing lvl means slower but more memory saving):
            0: no recomputation in the bwd
            1: recompute gelu_out in the bwd
            2: recompute pre_act and gelu_out in the bwd
        heuristic:
            -1: don't fuse gemm + gelu (separate kernel)
            0..4: use this heuristic for the algo section in the fused gemm + gelu
            'auto': heuristic will be picked automatically:
                For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf.
                For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16.
                For H100, we set heuristic=-1 for both fp16 and bf16 as the fused cuBlasLt
                implementation is slower than the unfused version.
        return_residual: whether to return the input x along with the output. This is for
            performance reason: for post-norm architecture, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.
        """
        assert checkpoint_lvl in [0, 1, 2]
        assert activation in ['gelu_approx', 'relu']
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        if out_features is None:
            out_features = in_features
        self.activation = activation
        self.return_residual = return_residual
        self.checkpoint_lvl = checkpoint_lvl
        self.heuristic = heuristic
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

    def forward(self, x, process_group=None):
        """Run the MLP on x; returns ``out`` or ``(out, x)`` if ``return_residual`` is set.

        The pre-activation is only saved in training mode (``save_pre_act=self.training``).
        """
        dtype = x.dtype if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype()
        if self.heuristic != 'auto':
            heuristic = self.heuristic
        elif not x.is_cuda:
            # fused_mlp_func takes its unfused fallback for non-CUDA inputs and ignores the
            # heuristic there, so skip the CUDA queries: they fail on machines or builds
            # without CUDA (torch.version.cuda is None on such builds).
            heuristic = -1
        elif self.activation == 'gelu_approx':
            if torch.cuda.get_device_capability('cuda') == (9, 0):
                heuristic = -1
            else:
                cuda_ver = tuple(map(int, torch.version.cuda.split('.')))
                heuristic = 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1)
        else:
            heuristic = 0
        out = fused_mlp_func(
            x, self.fc1.weight, self.fc2.weight, self.fc1.bias, self.fc2.bias,
            activation=self.activation, save_pre_act=self.training,
            return_residual=self.return_residual, checkpoint_lvl=self.checkpoint_lvl,
            heuristic=heuristic, process_group=process_group
        )
        if self.return_residual:
            out, x = out
        if process_group is not None:
            out = reduce_scatter(out, process_group)
        return out if not self.return_residual else (out, x)


469
class ParallelFusedMLP(nn.Module):
    """Tensor-parallel two-layer MLP evaluated through ``fused_mlp_func``.

    ``fc1`` is a ColumnParallelLinear and ``fc2`` a RowParallelLinear; only their weights
    and biases are used here, the computation itself goes through ``fused_mlp_func``.
    """

    def __init__(self, in_features, hidden_features, out_features=None, activation='gelu_approx',
                 process_group: ProcessGroup = None, bias1=True, bias2=True,
                 sequence_parallel=True, checkpoint_lvl=0, heuristic='auto',
                 device=None, dtype=None):
        """
        process_group is required. We're doing Tensor Parallel with sequence parallelism:
        we do an all_gather of x before doing the matmul, gelu, then matmul.
        Finally we do a reduce_scatter of the output.
        If sequence_parallel is False, the output is combined with an all_reduce instead.

        out_features defaults to in_features when it is None.

        checkpoint_lvl (increasing lvl means slower but more memory saving):
            0: no recomputation in the bwd
            1: recompute gelu_out in the bwd
            2: recompute pre_act and gelu_out in the bwd
        heuristic:
            -1: don't fuse gemm + gelu (separate kernel)
            0..4: use this heuristic for the algo section in the fused gemm + gelu
            'auto': heuristic will be picked automatically:
                For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf.
                For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16.
        """
        assert checkpoint_lvl in [0, 1, 2]
        assert activation in ['gelu_approx', 'relu']
        assert process_group is not None
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        if out_features is None:
            out_features = in_features
        self.activation = activation
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel
        self.checkpoint_lvl = checkpoint_lvl
        self.heuristic = heuristic
        self.fc1 = ColumnParallelLinear(in_features, hidden_features, process_group,
                                        bias=bias1, **factory_kwargs)
        self.fc2 = RowParallelLinear(hidden_features, out_features, process_group,
                                     bias=bias2, **factory_kwargs)

    def forward(self, x):
        """Run the MLP on x and combine the per-rank outputs across the process group.

        The pre-activation is only saved in training mode (``save_pre_act=self.training``).
        """
        dtype = x.dtype if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype()
        if self.heuristic == 'auto':
            # NOTE(review): unlike FusedMLP.forward there is no H100 (capability 9.0)
            # special case here -- confirm whether that is intentional.
            if self.activation == 'gelu_approx':
                cuda_ver = tuple(map(int, torch.version.cuda.split('.')))
                heuristic = 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1)
            else:
                heuristic = 0
        else:
            heuristic = self.heuristic
        out = fused_mlp_func(
            x, self.fc1.weight, self.fc2.weight, self.fc1.bias, self.fc2.bias,
            activation=self.activation, save_pre_act=self.training,
            checkpoint_lvl=self.checkpoint_lvl, heuristic=heuristic,
            process_group=self.process_group,
            sequence_parallel=self.sequence_parallel
        )
        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
        return reduce_fn(out, self.process_group)