# Copyright (c) 2022, Tri Dao.
# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/fused_dense/fused_dense.py
# We make it work with pytorch amp and with bfloat16.
# The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.distributed import ProcessGroup

from torch.cuda.amp import custom_bwd, custom_fwd

# import fused_dense_cuda  # from apex
import fused_dense_lib as fused_dense_cuda

from flash_attn.ops.gelu_activation import gelu_bwd
from flash_attn.utils.distributed import all_gather_raw, reduce_scatter_raw, all_reduce_raw
from flash_attn.utils.distributed import reduce_scatter, all_reduce


class FusedDenseFunc(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(ctx, x, weight, bias, return_residual=False, process_group=None,
                sequence_parallel=True):
        """
        If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
        with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
        """
        ctx.compute_weight_gradient = weight.requires_grad
        ctx.return_residual = return_residual
        ctx.process_group = process_group
        ctx.sequence_parallel = sequence_parallel

        if torch.is_autocast_enabled():
            x = x.to(dtype=torch.get_autocast_gpu_dtype())
        x = x.contiguous()
        if process_group is not None and sequence_parallel:
            # We want to kick off the all_gather early, before weight dtype conversion
            total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
        else:
            total_x = x

        if torch.is_autocast_enabled():
            weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
            bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None
        weight = weight.contiguous()
        if process_group is not None and sequence_parallel:
            handle_x.wait()
        batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
        batch_dim = batch_shape.numel()
        # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
        if min(batch_dim, n, *weight.shape) > 65535 * 32:
            raise RuntimeError('fused_dense only supports matrix dims <= 2M')
        output = F.linear(total_x, weight, bias)
        if ctx.compute_weight_gradient:
            ctx.save_for_backward(x, weight)
        else:
            ctx.save_for_backward(weight)
        return output if not return_residual else (output, x)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output, *args):
        grad_output = grad_output.contiguous()
        if ctx.return_residual:
            grad_input, = args
            grad_input = grad_input.contiguous()
        process_group = ctx.process_group
        sequence_parallel = ctx.sequence_parallel
        if ctx.compute_weight_gradient:
            x, weight = ctx.saved_tensors
            if process_group is not None and sequence_parallel:
                total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
            else:
                total_x = x
        else:
            weight, = ctx.saved_tensors
            total_x = None
        batch_shape = grad_output.shape[:-1]
        batch_dim = batch_shape.numel()
        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
        if ctx.needs_input_grad[0]:
            if not ctx.return_residual:
                grad_input = F.linear(grad_output, weight.t())
            else:
                grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]),
                                         grad_output, weight)
            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
            if process_group is not None:
                reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
                grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
        else:
            grad_input = None
        if ctx.needs_input_grad[1]:
            assert ctx.compute_weight_gradient
            if process_group is not None and sequence_parallel:
                handle_x.wait()
            grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad(
                total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
            )
        else:
            grad_weight = None
            grad_bias = grad_output if ctx.needs_input_grad[2] else None
        if process_group is not None and ctx.needs_input_grad[0]:
            handle_grad_input.wait()
        return grad_input, grad_weight, grad_bias, None, None, None


def fused_dense_func(x: Tensor, weight: Tensor, bias: Optional[Tensor] = None,
                     return_residual: bool = False, process_group: Optional[ProcessGroup] = None,
                     sequence_parallel: bool = True):
    dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
                      or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
    if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible:
        return FusedDenseFunc.apply(x, weight, bias, return_residual, process_group,
                                    sequence_parallel)
    else:
        assert process_group is None
        out = F.linear(x, weight, bias)
        return out if not return_residual else (out, x)
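
# Illustrative sketch (not part of the original module): for CUDA tensors in fp16/bf16
# (or fp32 under autocast), fused_dense_func takes the fused-kernel path and matches
# F.linear; otherwise it falls back to the plain F.linear path above. Shapes and the
# tolerance below are assumptions for the example, which also assumes a CUDA device.
def _example_fused_dense_func():
    x = torch.randn(8, 512, 1024, device='cuda', dtype=torch.float16)
    weight = torch.randn(4096, 1024, device='cuda', dtype=torch.float16)
    bias = torch.randn(4096, device='cuda', dtype=torch.float16)
    out = fused_dense_func(x, weight, bias)
    out_ref = F.linear(x, weight, bias)
    assert torch.allclose(out, out_ref, atol=1e-2, rtol=1e-2)
    return out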


class FusedDense(nn.Linear):

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 return_residual: bool = False, device=None, dtype=None) -> None:
        super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
        self.return_residual = return_residual

    def forward(self, x, process_group=None):
        """
        If process_group is not None, we're doing Tensor Parallel with sequence parallelism:
        we do an all_gather of x before doing the matmul.
        """
        return fused_dense_func(x, self.weight, self.bias, return_residual=self.return_residual,
                                process_group=process_group)
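
# Illustrative usage sketch (not part of the original module; assumes a CUDA device).
# FusedDense is a drop-in replacement for nn.Linear that dispatches to the fused kernel
# for fp16/bf16 inputs; with return_residual=True it also hands back the input so a
# post-norm block can fuse the residual connection into the backward pass.
def _example_fused_dense_module():
    layer = FusedDense(1024, 4096, bias=True, return_residual=True,
                       device='cuda', dtype=torch.bfloat16)
    x = torch.randn(4, 256, 1024, device='cuda', dtype=torch.bfloat16)
    out, residual = layer(x)  # residual is x itself
    return out, residual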


class ColumnParallelLinear(nn.Linear):

    def __init__(self, in_features: int, out_features: int, process_group: ProcessGroup,
                 bias: bool = True, sequence_parallel=True, device=None, dtype=None) -> None:
        world_size = torch.distributed.get_world_size(process_group)
        if out_features % world_size != 0:
            raise ValueError(f'out_features ({out_features}) must be divisible by '
                             f'world_size ({world_size})')
        super().__init__(in_features, out_features // world_size, bias=bias,
                         device=device, dtype=dtype)
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel

    def forward(self, x):
        # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
        # we do an all_gather of x before doing the matmul.
        # If not, then the input is already gathered.
        return fused_dense_func(x, self.weight, self.bias, process_group=self.process_group,
                                sequence_parallel=self.sequence_parallel)


class RowParallelLinear(nn.Linear):

    def __init__(self, in_features: int, out_features: int, process_group: ProcessGroup,
                 bias: bool = True, sequence_parallel=True, device=None, dtype=None) -> None:
        world_size = torch.distributed.get_world_size(process_group)
        rank = torch.distributed.get_rank(process_group)
        if in_features % world_size != 0:
            raise ValueError(f'in_features ({in_features}) must be divisible by '
                             f'world_size ({world_size})')
        # Only rank 0 will have bias
        super().__init__(in_features // world_size, out_features, bias=bias and rank == 0,
                         device=device, dtype=dtype)
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel

    def forward(self, x):
        """
        We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
        a reduce_scatter of the result.
        """
        out = fused_dense_func(x, self.weight, self.bias)
        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
        return reduce_fn(out, self.process_group)
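
# Illustrative sketch of how the two parallel layers compose into a tensor-parallel MLP
# (not part of the original module). It assumes torch.distributed is already initialized
# (e.g. via torchrun), that `tp_group` is the tensor-parallel ProcessGroup, and that the
# token count divides the world size.
def _example_tensor_parallel_mlp(tp_group, seq_parallel=True):
    world_size = torch.distributed.get_world_size(tp_group)
    # fc1 shards its 4096 output features across ranks; fc2 shards its 4096 input features.
    fc1 = ColumnParallelLinear(1024, 4096, tp_group, sequence_parallel=seq_parallel,
                               device='cuda', dtype=torch.float16)
    fc2 = RowParallelLinear(4096, 1024, tp_group, sequence_parallel=seq_parallel,
                            device='cuda', dtype=torch.float16)
    # With sequence_parallel=True each rank holds a slice of the tokens along the first
    # dim; fc1 all-gathers them before its matmul and fc2 reduce-scatters its output.
    tokens_per_rank = 4096 // world_size if seq_parallel else 4096
    x = torch.randn(tokens_per_rank, 1024, device='cuda', dtype=torch.float16)
    return fc2(F.gelu(fc1(x), approximate='tanh'))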


class FusedDenseGeluDenseFunc(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(ctx, x, weight1, bias1, weight2, bias2, save_pre_act=True, return_residual=False,
                checkpoint_lvl=0, heuristic=0, process_group=None, sequence_parallel=True):
        """
        If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
        with sequence parallelism: we do an all_gather of x before doing the matmul.
        If sequence_parallel=False, then the input is already gathered.

        checkpoint_lvl:
        0: no recomputation in the bwd
        1: recompute gelu_out in the bwd
        2: recompute gelu_in and gelu_out in the bwd
        """
        assert -1 <= heuristic <= 4
        if not save_pre_act:
            checkpoint_lvl = 2
        assert checkpoint_lvl in [0, 1, 2]
        ctx.return_residual = return_residual
        ctx.process_group = process_group
        ctx.sequence_parallel = sequence_parallel
        ctx.checkpoint_lvl = checkpoint_lvl
        ctx.heuristic = heuristic

        if torch.is_autocast_enabled():
            x = x.to(dtype=torch.get_autocast_gpu_dtype())
        x = x.contiguous()
        if process_group is not None and sequence_parallel:
            # We want to kick off the all_gather early, before weight dtype conversion
            total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
        else:
            total_x = x

        if torch.is_autocast_enabled():
            dtype = torch.get_autocast_gpu_dtype()
            weight1, weight2 = [a.to(dtype=dtype) for a in [weight1, weight2]]
            bias1 = bias1.to(dtype=dtype) if bias1 is not None else None
            bias2 = bias2.to(dtype=dtype) if bias2 is not None else None
        weight1 = weight1.contiguous()
        bias1 = bias1.contiguous() if bias1 is not None else None
        weight2 = weight2.contiguous()
        bias2 = bias2.contiguous() if bias2 is not None else None
        if process_group is not None and sequence_parallel:
            handle_x.wait()
        batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
        batch_dim = batch_shape.numel()
        # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
        if min(batch_dim, n, *weight1.shape, *weight2.shape) > 65535 * 32:
            raise RuntimeError('fused_dense only supports matrix dims <= 2M')
        if heuristic == -1:
            gelu_in = F.linear(total_x, weight1, bias1)
            output1 = F.gelu(gelu_in, approximate='tanh')
            # This is before adding bias1
            # gelu_in = F.linear(total_x.reshape(batch_dim, n), weight1)
            # with torch.jit.fuser('fuser2'):
            #     output1 = bias_gelu(gelu_in, bias1)
        else:
            output1, *rest = fused_dense_cuda.linear_gelu_forward(
                total_x.reshape(batch_dim, n), weight1, bias1, save_pre_act, heuristic
            )
            if save_pre_act:
                gelu_in = rest[0]
        output2 = F.linear(output1, weight2, bias2)
        if checkpoint_lvl == 0:
            ctx.save_for_backward(x, weight1, weight2, gelu_in, output1)
        elif checkpoint_lvl == 1:
            ctx.save_for_backward(x, weight1, weight2, gelu_in)
        elif checkpoint_lvl == 2:
            ctx.save_for_backward(x, weight1, weight2, bias1)
        output2 = output2.reshape(*batch_shape, output2.shape[-1])
        return output2 if not return_residual else (output2, x)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output, *args):
        grad_output = grad_output.contiguous()
        checkpoint_lvl = ctx.checkpoint_lvl
        if ctx.return_residual:
            grad_input, = args
            grad_input = grad_input.contiguous()
        process_group = ctx.process_group
        sequence_parallel = ctx.sequence_parallel
        x, weight1, weight2, *rest = ctx.saved_tensors
        if process_group is None or not sequence_parallel:
            total_x = x
        batch_shape = grad_output.shape[:-1]
        batch_dim = batch_shape.numel()
        if checkpoint_lvl in [0, 1]:
            if process_group is not None and sequence_parallel:
                total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
            if checkpoint_lvl == 0:
                gelu_in, output1 = rest
            elif checkpoint_lvl == 1:
                gelu_in, = rest
                output1 = F.gelu(gelu_in, approximate='tanh')
        elif checkpoint_lvl == 2:
            bias1, = rest
            if process_group is not None and sequence_parallel:
                total_x, _ = all_gather_raw(x, process_group)
            if ctx.heuristic == -1:
                gelu_in = F.linear(total_x, weight1, bias1)
                output1 = F.gelu(gelu_in, approximate='tanh')
            else:
                output1, gelu_in = fused_dense_cuda.linear_gelu_forward(
                    total_x.reshape(batch_dim, total_x.shape[-1]), weight1, bias1, True,
                    ctx.heuristic
                )

        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
        output1 = output1.reshape(batch_dim, output1.shape[-1])
        gelu_in = gelu_in.reshape(batch_dim, gelu_in.shape[-1])
        if ctx.needs_input_grad[3]:
            grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(
                output1, grad_output, ctx.needs_input_grad[4]
            )
        else:
            grad_weight2 = None
            grad_bias2 = grad_output if ctx.needs_input_grad[4] else None
        if ctx.heuristic == -1:
            # grad_gelu = matmul_dgelu(grad_output, weight2, gelu_in)
            grad_output1 = F.linear(grad_output, weight2.t())
            with torch.jit.fuser('fuser2'):
                grad_gelu = gelu_bwd(grad_output1, gelu_in)
        else:
            # The cublasLt epilogue has to compute both gelu grad and bias grad, we can't
            # just compute gelu grad
            grad_gelu, grad_bias1 = fused_dense_cuda.bias_gelu_linear_dgrad_bgrad(
                weight2, grad_output, gelu_in, ctx.heuristic
            )
            if not ctx.needs_input_grad[2]:
                grad_bias1 = None
        if ctx.needs_input_grad[0]:
            if not ctx.return_residual:
                grad_input = F.linear(grad_gelu, weight1.t())
            else:
                grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]),
                                         grad_gelu, weight1)
            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
            if process_group is not None:
                reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
                grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
        else:
            grad_input = None
        if ctx.heuristic == -1:
            if ctx.needs_input_grad[1]:
                if process_group is not None and sequence_parallel:
                    handle_x.wait()
                grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad(
                    total_x.reshape(batch_dim, total_x.shape[-1]), grad_gelu,
                    ctx.needs_input_grad[2]
                )
            else:
                grad_weight1 = None
                grad_bias1 = grad_gelu if ctx.needs_input_grad[2] else None
        else:
            if ctx.needs_input_grad[1]:
                if process_group is not None and sequence_parallel:
                    handle_x.wait()
                grad_weight1 = F.linear(grad_gelu.t(),
                                        total_x.reshape(batch_dim, total_x.shape[-1]).t())
            else:
                grad_weight1 = None
        if process_group is not None and ctx.needs_input_grad[0]:
            handle_grad_input.wait()
        return (grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2,
                None, None, None, None, None, None)


def fused_dense_gelu_dense_func(
    x: Tensor, weight1: Tensor, weight2: Tensor, bias1: Optional[Tensor] = None,
    bias2: Optional[Tensor] = None,
    save_pre_act: bool = True, return_residual: bool = False,
    checkpoint_lvl: int = 0, heuristic: int = 0,
    process_group: Optional[ProcessGroup] = None,
    sequence_parallel: bool = True
):
    dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
                      or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
    if (x.is_cuda and weight1.is_cuda and weight2.is_cuda and (bias1 is None or bias1.is_cuda)
        and (bias2 is None or bias2.is_cuda) and dtype_eligible):
        return FusedDenseGeluDenseFunc.apply(
            x, weight1, bias1, weight2, bias2, save_pre_act, return_residual,
            checkpoint_lvl, heuristic, process_group, sequence_parallel
        )
    else:
        assert process_group is None
        gelu_in = F.linear(x, weight1, bias1)
        output1 = F.gelu(gelu_in, approximate='tanh')
        output2 = F.linear(output1, weight2, bias2)
        return output2 if not return_residual else (output2, x)
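
# Illustrative check (not part of the original module): on CUDA with fp16/bf16 weights the
# fused path should agree with the unfused reference above up to numerical tolerance.
# Shapes, scaling, and tolerances are assumptions for the sketch; heuristic=0 assumes
# CUDA >= 11.8 per the notes in FusedDenseGeluDense below.
def _example_check_fused_gelu_mlp():
    x = torch.randn(4, 128, 1024, device='cuda', dtype=torch.float16)
    w1 = torch.randn(4096, 1024, device='cuda', dtype=torch.float16) * 0.02
    b1 = torch.zeros(4096, device='cuda', dtype=torch.float16)
    w2 = torch.randn(1024, 4096, device='cuda', dtype=torch.float16) * 0.02
    b2 = torch.zeros(1024, device='cuda', dtype=torch.float16)
    out_fused = fused_dense_gelu_dense_func(x, w1, w2, b1, b2, heuristic=0)
    out_ref = F.linear(F.gelu(F.linear(x, w1, b1), approximate='tanh'), w2, b2)
    assert torch.allclose(out_fused, out_ref, atol=1e-2, rtol=1e-2)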


class FusedDenseGeluDense(nn.Module):

    def __init__(self, in_features, hidden_features, out_features=None, bias1=True,
                 bias2=True, return_residual=False, checkpoint_lvl=0, heuristic=0,
                 device=None, dtype=None):
        """
        If process_group is not None, we're doing Tensor Parallel with sequence parallelism:
        we do an all_gather of x before doing the matmul, gelu, then matmul.
        Finally we do a reduce_scatter of the output.

        checkpoint_lvl (increasing lvl means slower but more memory saving):
            0: no recomputation in the bwd
            1: recompute gelu_out in the bwd
            2: recompute gelu_in and gelu_out in the bwd
        heuristic:
            -1: don't fuse gemm + gelu (separate kernel)
            0..4: use this heuristic for the algo section in the fused gemm + gelu
            For CUDA >= 11.8, you'd want heuristic=0 for both fp16 and bf16 for best perf.
            For CUDA <= 11.7, you'd want heuristic=1 for fp16 and heuristic=-1 for bf16.
        return_residual: whether to return the input x along with the output. This is for
            performance reasons: for post-norm architectures, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.
        """
        assert checkpoint_lvl in [0, 1, 2]
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        if out_features is None:
            out_features = in_features
        self.return_residual = return_residual
        self.checkpoint_lvl = checkpoint_lvl
        self.heuristic = heuristic
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

    def forward(self, x, process_group=None):
        out = fused_dense_gelu_dense_func(
            x, self.fc1.weight, self.fc2.weight, self.fc1.bias, self.fc2.bias,
            save_pre_act=self.training, return_residual=self.return_residual,
            checkpoint_lvl=self.checkpoint_lvl, heuristic=self.heuristic,
            process_group=process_group
        )
        if self.return_residual:
            out, x = out
        if process_group is not None:
            out = reduce_scatter(out, process_group)
        return out if not self.return_residual else (out, x)
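
# Illustrative usage sketch (not part of the original module; assumes a CUDA device and
# CUDA >= 11.8, hence heuristic=0). checkpoint_lvl=1 recomputes gelu_out in the backward
# pass to save activation memory.
def _example_fused_mlp_module():
    mlp = FusedDenseGeluDense(1024, 4096, 1024, checkpoint_lvl=1, heuristic=0,
                              device='cuda', dtype=torch.float16)
    x = torch.randn(2, 256, 1024, device='cuda', dtype=torch.float16)
    return mlp(x)  # Linear -> GELU(tanh) -> Linear with fused kernels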


class ParallelFusedDenseGeluDense(nn.Module):

    def __init__(self, in_features, hidden_features, out_features=None,
                 process_group: ProcessGroup = None, bias1=True, bias2=True,
                 sequence_parallel=True, checkpoint_lvl=0, heuristic=0, device=None, dtype=None):
        """
        process_group is required. We're doing Tensor Parallel with sequence parallelism:
        we do an all_gather of x before doing the matmul, gelu, then matmul.
        Finally we do a reduce_scatter of the output.

        checkpoint_lvl (increasing lvl means slower but more memory saving):
            0: no recomputation in the bwd
            1: recompute gelu_out in the bwd
            2: recompute gelu_in and gelu_out in the bwd
        heuristic:
            -1: don't fuse gemm + gelu (separate kernel)
            0..4: use this heuristic for the algo section in the fused gemm + gelu
            For CUDA >= 11.8, you'd want heuristic=0 for both fp16 and bf16 for best perf.
            For CUDA <= 11.7, you'd want heuristic=1 for fp16 and heuristic=-1 for bf16.
        """
        assert checkpoint_lvl in [0, 1, 2]
        assert process_group is not None
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        if out_features is None:
            out_features = in_features
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel
        self.checkpoint_lvl = checkpoint_lvl
        self.heuristic = heuristic
        self.fc1 = ColumnParallelLinear(in_features, hidden_features, process_group,
                                        bias=bias1, **factory_kwargs)
        self.fc2 = RowParallelLinear(hidden_features, out_features, process_group,
                                     bias=bias2, **factory_kwargs)

    def forward(self, x):
        out = fused_dense_gelu_dense_func(
            x, self.fc1.weight, self.fc2.weight, self.fc1.bias, self.fc2.bias,
            save_pre_act=self.training, checkpoint_lvl=self.checkpoint_lvl,
            heuristic=self.heuristic,
            process_group=self.process_group,
            sequence_parallel=self.sequence_parallel
        )
        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
        return reduce_fn(out, self.process_group)
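

# Illustrative usage sketch (not part of the original module). It assumes torch.distributed
# is already initialized (e.g. via torchrun), that `tp_group` is the tensor-parallel
# ProcessGroup, and that the token count divides the world size; heuristic=0 assumes
# CUDA >= 11.8. With sequence_parallel=True the input and output hold this rank's slice
# of the tokens along the first dimension.
def _example_parallel_fused_mlp(tp_group):
    mlp = ParallelFusedDenseGeluDense(1024, 4096, 1024, process_group=tp_group,
                                      sequence_parallel=True, heuristic=0,
                                      device='cuda', dtype=torch.bfloat16)
    world_size = torch.distributed.get_world_size(tp_group)
    x = torch.randn(2048 // world_size, 1024, device='cuda', dtype=torch.bfloat16)
    return mlp(x)  # all_gather -> fc1 -> gelu -> fc2 -> reduce_scatter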