# Copyright (c) 2023, Tri Dao.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed import ProcessGroup

# The flash_attn fused ops below are optional. Each name falls back to None when
# its import fails, so this module stays importable without them; any code that
# relies on one of these names must check it against None before using it.
try:
    from flash_attn.ops.activations import swiglu
except ImportError:
    swiglu = None

try:
    from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
except ImportError:
    ColumnParallelLinear, RowParallelLinear = None, None

try:
    from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP
except ImportError:
    FusedMLP, ParallelFusedMLP = None, None


class Mlp(nn.Module):
    """Two-layer feed-forward block computing ``fc2(activation(fc1(x)))``.

    Args:
        in_features: Input feature size of ``fc1``.
        hidden_features: Output size of ``fc1`` / input size of ``fc2``.
            Defaults to ``4 * in_features``.
        out_features: Output size of ``fc2``. Defaults to ``in_features``.
        activation: Callable applied between the two linear layers.
        bias1: Whether ``fc1`` has a bias term.
        bias2: Whether ``fc2`` has a bias term.
        return_residual: If True, ``forward`` returns ``(output, x)`` so the
            caller gets the untouched input back alongside the result.
        device: Forwarded to both ``nn.Linear`` constructors.
        dtype: Forwarded to both ``nn.Linear`` constructors.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.gelu,
        bias1=True,
        bias2=True,
        return_residual=False,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features if out_features is not None else in_features
        hidden_features = hidden_features if hidden_features is not None else in_features * 4
        self.return_residual = return_residual
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
        self.activation = activation
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

    def forward(self, x):
        """Apply fc1 -> activation -> fc2; also return ``x`` if ``return_residual``."""
        y = self.fc1(x)
        y = self.activation(y)
        y = self.fc2(y)
        return y if not self.return_residual else (y, x)


class ParallelMLP(nn.Module):
    """Two-layer feed-forward block built from the fused_dense parallel linears.

    Computes ``fc2(activation(fc1(x)))`` where ``fc1`` is a
    ``ColumnParallelLinear`` and ``fc2`` is a ``RowParallelLinear``, both
    constructed with the same ``process_group`` and ``sequence_parallel`` flag.
    NOTE(review): how these layers shard weights and communicate is defined in
    flash_attn.ops.fused_dense, not here - consult that module for details.

    Args:
        in_features: Input feature size passed to ``fc1``.
        hidden_features: Output size of ``fc1`` / input size of ``fc2``.
            Defaults to ``4 * in_features``.
        out_features: Output size passed to ``fc2``. Defaults to ``in_features``.
        activation: Callable applied between the two linear layers.
        process_group: Process group handed to both parallel linear layers.
        sequence_parallel: Forwarded to both parallel linear layers.
        bias1: Whether ``fc1`` has a bias term.
        bias2: Whether ``fc2`` has a bias term.
        device: Forwarded to both layer constructors.
        dtype: Forwarded to both layer constructors.

    Raises:
        ImportError: If the fused_dense parallel linear layers are unavailable.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.gelu,
        process_group: ProcessGroup = None,
        sequence_parallel=True,
        bias1=True,
        bias2=True,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        # Raise explicitly instead of asserting: asserts are stripped under
        # `python -O`, which would turn a missing dependency into an opaque
        # "'NoneType' object is not callable" further down.
        if ColumnParallelLinear is None or RowParallelLinear is None:
            raise ImportError("Need to install fused_dense")
        out_features = out_features if out_features is not None else in_features
        hidden_features = hidden_features if hidden_features is not None else in_features * 4
        self.fc1 = ColumnParallelLinear(
            in_features,
            hidden_features,
            process_group,
            bias=bias1,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )
        self.activation = activation
        self.fc2 = RowParallelLinear(
            hidden_features,
            out_features,
            process_group,
            bias=bias2,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )

    def forward(self, x):
        """Apply fc1 -> activation -> fc2 and return the result."""
        y = self.fc1(x)
        y = self.activation(y)
        y = self.fc2(y)
        return y


class GatedMlp(nn.Module):
    """Gated feed-forward block: one projection split into value and gate halves.

    ``fc1`` projects to ``2 * hidden_features``; the result is split in two
    along the last dimension, the first half is multiplied by the activated
    second half, and ``fc2`` projects the product to ``out_features``.

    Args:
        in_features: Input feature size of ``fc1``.
        hidden_features: Width of each half of the ``fc1`` output. Defaults to
            ``int(8 * in_features / 3)``. The value (default or explicit) is
            then rounded up to the next multiple of ``multiple_of``.
        out_features: Output size of ``fc2``. Defaults to ``in_features``.
        activation: Callable applied to the gate half. ``F.sigmoid`` and
            ``F.silu`` select dedicated code paths in ``forward``.
        bias1: Whether ``fc1`` has a bias term.
        bias2: Whether ``fc2`` has a bias term.
        multiple_of: Granularity that ``hidden_features`` is rounded up to.
        return_residual: If True, ``forward`` returns ``(output, x)``.
        device: Forwarded to both ``nn.Linear`` constructors.
        dtype: Forwarded to both ``nn.Linear`` constructors.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.sigmoid,
        bias1=True,
        bias2=True,
        multiple_of=256,
        return_residual=False,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features if out_features is not None else in_features
        hidden_features = (
            hidden_features if hidden_features is not None else int(8 * in_features / 3)
        )
        # Ceil-round to a multiple of `multiple_of`.
        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
        self.return_residual = return_residual
        # A single projection produces both halves, hence the factor of 2.
        self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **factory_kwargs)
        self.activation = activation
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

    def forward(self, x):
        """Run the gated MLP on ``x``; also return ``x`` if ``return_residual``."""
        y = self.fc1(x)
        if self.activation == F.sigmoid:  # Special case for GLU
            y = F.glu(y, dim=-1)
        elif self.activation == F.silu and swiglu is not None:  # Special case for SwiGLU
            # Only taken when the optional flash_attn `swiglu` op was imported.
            # NOTE(review): presumably a fused equivalent of the generic branch
            # below - confirm against flash_attn.ops.activations.
            y, gate = y.chunk(2, dim=-1)
            y = swiglu(gate, y)
        else:
            y, gate = y.chunk(2, dim=-1)
            y = y * self.activation(gate)
        y = self.fc2(y)
        return y if not self.return_residual else (y, x)


class ParallelGatedMlp(nn.Module):
    """Parallel GatedMlp.

    Same gating scheme as a gated MLP - ``fc1`` projects to
    ``2 * hidden_features``, the output is split in two along the last
    dimension and the first half is multiplied by the activated second half -
    but ``fc1`` is a ``ColumnParallelLinear`` and ``fc2`` a
    ``RowParallelLinear``, both built with the same ``process_group`` and
    ``sequence_parallel`` flag.
    NOTE(review): how these layers shard weights and communicate is defined in
    flash_attn.ops.fused_dense, not here - consult that module for details.

    Args:
        in_features: Input feature size passed to ``fc1``.
        process_group: Process group handed to both parallel linear layers.
        hidden_features: Width of each half of the ``fc1`` output. Defaults to
            ``int(8 * in_features / 3)``. The value (default or explicit) is
            then rounded up to the next multiple of ``multiple_of``.
        out_features: Output size passed to ``fc2``. Defaults to ``in_features``.
        activation: Callable applied to the gate half. ``F.sigmoid`` selects
            the ``F.glu`` path in ``forward``.
        bias1: Whether ``fc1`` has a bias term.
        bias2: Whether ``fc2`` has a bias term.
        multiple_of: Granularity that ``hidden_features`` is rounded up to.
        sequence_parallel: Forwarded to both parallel linear layers.
        device: Forwarded to both layer constructors.
        dtype: Forwarded to both layer constructors.

    Raises:
        ImportError: If the fused_dense parallel linear layers are unavailable.
    """

    def __init__(
        self,
        in_features,
        process_group,
        hidden_features=None,
        out_features=None,
        activation=F.sigmoid,
        bias1=True,
        bias2=True,
        multiple_of=256,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features if out_features is not None else in_features
        hidden_features = (
            hidden_features if hidden_features is not None else int(8 * in_features / 3)
        )
        # Ceil-round to a multiple of `multiple_of`.
        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
        if ColumnParallelLinear is None or RowParallelLinear is None:
            raise ImportError("fused_dense is not installed")
        # A single projection produces both halves, hence the factor of 2.
        self.fc1 = ColumnParallelLinear(
            in_features,
            2 * hidden_features,
            process_group,
            bias=bias1,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )
        self.activation = activation
        self.fc2 = RowParallelLinear(
            hidden_features,
            out_features,
            process_group,
            bias=bias2,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )

    def forward(self, x):
        """Run the gated MLP on ``x`` and return the ``fc2`` output."""
        y = self.fc1(x)
        if self.activation == F.sigmoid:  # Special case for GLU
            y = F.glu(y, dim=-1)
        else:
            y, gate = y.chunk(2, dim=-1)
            y = y * self.activation(gate)
        y = self.fc2(y)
        return y