mlp.py 3.18 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
# Copyright (c) 2022, Tri Dao.

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from flash_attn.ops.fused_dense import fused_dense_gelu_dense_function_td
    from flash_attn.ops.fused_dense import fused_dense_res_gelu_dense_function_td
except ImportError:
    fused_dense_gelu_dense_function_td = None
    fused_dense_res_gelu_dense_function_td = None


class Mlp(nn.Module):
    """Two-layer perceptron computing ``fc2(activation(fc1(x)))``.

    Args:
        in_features: size of the input's last dimension.
        hidden_features: width of the hidden layer; falls back to ``in_features``
            when omitted (or falsy).
        out_features: size of the output's last dimension; falls back to
            ``in_features`` when omitted (or falsy).
        activation: callable applied to the output of ``fc1`` (``F.gelu`` by default).
        device, dtype: forwarded to both ``nn.Linear`` constructors.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.gelu,
                 device=None, dtype=None):
        super().__init__()
        width = in_features if not hidden_features else hidden_features
        n_out = in_features if not out_features else out_features
        # Registration order (fc1, activation, fc2) is kept so state_dict / parameters()
        # ordering matches checkpoints produced by earlier versions.
        self.fc1 = nn.Linear(in_features, width, device=device, dtype=dtype)
        self.activation = activation
        self.fc2 = nn.Linear(width, n_out, device=device, dtype=dtype)

    def forward(self, x):
        """Project ``x`` with ``fc1``, apply ``activation``, then project with ``fc2``."""
        hidden = self.activation(self.fc1(x))
        return self.fc2(hidden)


class FusedDenseGeluDense(nn.Module):
    """Linear -> GELU -> Linear MLP computed by a fused function from ``flash_attn.ops.fused_dense``.

    ``fc1`` and ``fc2`` are ordinary ``nn.Linear`` modules, but they act only as parameter
    holders: ``forward`` passes their weights and biases directly to the fused function
    instead of calling the layers.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, bias=True,
                 checkpoint_lvl=0, heuristic=0, return_residual=False, device=None, dtype=None):
        """
        in_features: size of the input's last dimension.
        hidden_features: width of the hidden layer; defaults to in_features when None/falsy.
        out_features: size of the output's last dimension; defaults to in_features when
            None/falsy.
        bias: must be truthy; a bias-free variant is currently not supported.
        checkpoint_lvl (increasing lvl means slower but more memory saving):
            0: no recomputation in the bwd
            1: recompute gelu_out in the bwd
            2: recompute gelu_in and gelu_out in the bwd
        heuristic:
            -1: don't fuse gemm + gelu (separate kernel)
            0..4: use this heuristic for the algo selection in the fused gemm + gelu
            For CUDA >= 11.8, you'd want heuristic=0 for both fp16 and bf16 for best perf.
            For CUDA <= 11.7, you'd want heuristic=1 for fp16 and heuristic=-1 for bf16.
        return_residual: whether to return the input x along with the output. This is for
            performance reasons: for post-norm architecture, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.
        device, dtype: forwarded to both nn.Linear constructors.

        Raises AssertionError if checkpoint_lvl is not 0, 1 or 2, if bias is falsy, or if the
        fused functions are unavailable (their module-level names are None).
        """
        assert checkpoint_lvl in [0, 1, 2], (
            f'checkpoint_lvl must be 0, 1 or 2, got {checkpoint_lvl!r}')
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # Truthiness test rather than `bias == True` (PEP 8 / E712).
        assert bias, "DenseGeluDense module without bias is currently not supported"
        # These names are None when the optional flash_attn fused-dense import failed.
        assert (fused_dense_gelu_dense_function_td is not None
                and fused_dense_res_gelu_dense_function_td is not None), 'fused_dense_lib is not installed'
        self.checkpoint_lvl = checkpoint_lvl
        self.heuristic = heuristic
        self.return_residual = return_residual
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, **factory_kwargs)
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, **factory_kwargs)

    def forward(self, x):
        """Run the fused fc1 -> gelu -> fc2 computation on ``x``.

        ``x`` must be a CUDA tensor. When ``return_residual`` is set, the residual variant of
        the fused function is called; per the constructor docstring it also returns the input
        x (presumably as an ``(out, x)`` pair -- confirm against flash_attn.ops.fused_dense).
        """
        assert x.is_cuda, 'FusedDenseGeluDense requires a CUDA tensor input'
        fn = (fused_dense_res_gelu_dense_function_td if self.return_residual
              else fused_dense_gelu_dense_function_td)
        return fn(x, self.fc1.weight, self.fc1.bias, self.fc2.weight, self.fc2.bias,
                  self.checkpoint_lvl, self.heuristic)