# Copyright (c) 2022, Tri Dao.

import torch
import torch.nn as nn
import torch.nn.functional as F

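# The fused dense + GELU kernels come from the optional fused_dense_lib CUDA
# extension; fall back to None so the pure-PyTorch Mlp below still works.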
try:
    from flash_attn.ops.fused_dense import fused_dense_gelu_dense_function_td
    from flash_attn.ops.fused_dense import fused_dense_res_gelu_dense_function_td
except ImportError:
    fused_dense_gelu_dense_function_td = None
    fused_dense_res_gelu_dense_function_td = None


class Mlp(nn.Module):

    def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.gelu,
                 return_residual=False, device=None, dtype=None):
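        """Standard two-layer MLP: fc1 -> activation -> fc2.

        hidden_features and out_features default to in_features when not given.
        If return_residual is True, forward returns (output, input); returning the
        input lets a post-norm architecture fuse the residual connection into the
        backward pass (see FusedDenseGeluDense below).
        """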
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.return_residual = return_residual
        self.fc1 = nn.Linear(in_features, hidden_features, **factory_kwargs)
        self.activation = activation
        self.fc2 = nn.Linear(hidden_features, out_features, **factory_kwargs)

    def forward(self, x):
        y = self.fc1(x)
        y = self.activation(y)
        y = self.fc2(y)
        return y if not self.return_residual else (y, x)


class FusedDenseGeluDense(nn.Module):

    def __init__(self, in_features, hidden_features=None, out_features=None, bias=True,
                 checkpoint_lvl=0, heuristic=0, return_residual=False, device=None, dtype=None):
        """
        checkpoint_lvl (a higher level is slower but saves more memory):
            0: no recomputation in the bwd
            1: recompute gelu_out in the bwd
            2: recompute gelu_in and gelu_out in the bwd
        heuristic:
            -1: don't fuse gemm + gelu (use a separate kernel)
            0..4: use this heuristic for the algorithm selection in the fused gemm + gelu
            For CUDA >= 11.8, heuristic=0 gives the best performance for both fp16 and bf16.
            For CUDA <= 11.7, use heuristic=1 for fp16 and heuristic=-1 for bf16.
        return_residual: whether to return the input x along with the output. This is done
            for performance reasons: for a post-norm architecture, returning the input lets us
            fuse the backward of nn.Linear with the residual connection.
        """
        assert checkpoint_lvl in [0, 1, 2]
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        assert bias, "DenseGeluDense module without bias is currently not supported"
        assert (fused_dense_gelu_dense_function_td is not None
                and fused_dense_res_gelu_dense_function_td is not None), 'fused_dense_lib is not installed'
        self.checkpoint_lvl = checkpoint_lvl
        self.heuristic = heuristic
        self.return_residual = return_residual
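        # Parameters live in ordinary nn.Linear modules; forward only consumes
        # their .weight and .bias tensors and passes them to the fused kernel.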
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, **factory_kwargs)
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, **factory_kwargs)

    def forward(self, x):
        assert x.is_cuda
        fn = (fused_dense_gelu_dense_function_td if not self.return_residual
              else fused_dense_res_gelu_dense_function_td)
        return fn(x, self.fc1.weight, self.fc1.bias, self.fc2.weight, self.fc2.bias,
                  self.checkpoint_lvl, self.heuristic)
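

if __name__ == '__main__':
    # Minimal smoke test (not part of the original module); the sizes, dtype and
    # heuristic below are illustrative assumptions rather than requirements.
    torch.manual_seed(0)
    mlp = Mlp(in_features=512, hidden_features=2048, return_residual=True)
    x = torch.randn(2, 128, 512)
    y, residual = mlp(x)
    print('Mlp:', y.shape, residual is x)

    # The fused path needs the fused_dense_lib extension and a CUDA input; per the
    # docstring above, heuristic=0 is the recommended setting for CUDA >= 11.8.
    if torch.cuda.is_available() and fused_dense_gelu_dense_function_td is not None:
        fused_mlp = FusedDenseGeluDense(512, 2048, heuristic=0,
                                        device='cuda', dtype=torch.float16)
        out = fused_mlp(torch.randn(2, 128, 512, device='cuda', dtype=torch.float16))
        print('FusedDenseGeluDense:', out.shape)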