mlp.py 5.64 KB
Newer Older
1
2
3
4
5
# Copyright (c) 2022, Tri Dao.

import torch
import torch.nn as nn
import torch.nn.functional as F
Tri Dao's avatar
Tri Dao committed
6
7
8
9
10
11
from torch.distributed import ProcessGroup

try:
    from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
except ImportError:
    ColumnParallelLinear, RowParallelLinear = None, None
12
13

try:
Tri Dao's avatar
Tri Dao committed
14
    from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP
15
except ImportError:
16
    FusedMLP, ParallelFusedMLP = None, None
17
18
19


class Mlp(nn.Module):
    """Two-layer feed-forward block computing ``fc2(activation(fc1(x)))``.

    Args:
        in_features: size of the input's last dimension.
        hidden_features: width of the intermediate projection; ``None`` means
            ``4 * in_features``.
        out_features: size of the output's last dimension; ``None`` means
            ``in_features``.
        activation: callable applied to the output of ``fc1``.
        bias1: whether ``fc1`` has a bias term.
        bias2: whether ``fc2`` has a bias term.
        return_residual: if True, ``forward`` returns ``(output, x)`` so the
            caller can add the residual itself.
        device, dtype: passed through to both ``nn.Linear`` constructors.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.gelu,
        bias1=True,
        bias2=True,
        return_residual=False,
        device=None,
        dtype=None,
    ):
        super().__init__()
        if out_features is None:
            out_features = in_features
        if hidden_features is None:
            hidden_features = in_features * 4
        linear_kwargs = dict(device=device, dtype=dtype)
        self.return_residual = return_residual
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **linear_kwargs)
        self.activation = activation
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **linear_kwargs)

    def forward(self, x):
        """Apply fc1 -> activation -> fc2; optionally also return the input."""
        out = self.fc2(self.activation(self.fc1(x)))
        if self.return_residual:
            return out, x
        return out
Tri Dao's avatar
Tri Dao committed
46
47


Tri Dao's avatar
Tri Dao committed
48
class ParallelMLP(nn.Module):
    """Tensor-parallel two-layer MLP computing ``fc2(activation(fc1(x)))``.

    The computation mirrors ``Mlp``, but ``fc1`` is a ``ColumnParallelLinear``
    and ``fc2`` is a ``RowParallelLinear`` from ``flash_attn.ops.fused_dense``.
    Both layers are built over ``process_group``. How weights are sharded and
    which collectives run is decided by those layers. Presumably ``fc1`` splits
    its output columns and ``fc2`` its input rows across the group; confirm in
    fused_dense. Unlike ``Mlp``, there is no ``return_residual`` option.

    Args:
        in_features: size of the input's last dimension.
        hidden_features: intermediate width; ``None`` means ``4 * in_features``.
        out_features: output width; ``None`` means ``in_features``.
        activation: callable applied between ``fc1`` and ``fc2``.
        process_group: forwarded to both parallel linear layers.
        sequence_parallel: forwarded to both parallel linear layers.
        bias1: whether ``fc1`` has a bias term.
        bias2: whether ``fc2`` has a bias term.
        device, dtype: forwarded to both parallel linear layers.

    Raises:
        ImportError: if ``flash_attn.ops.fused_dense`` could not be imported
            (``ColumnParallelLinear`` / ``RowParallelLinear`` are ``None``).
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.gelu,
        process_group: ProcessGroup = None,
        sequence_parallel=True,
        bias1=True,
        bias2=True,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        # Use an explicit raise rather than `assert`. Asserts are stripped
        # under `python -O`, which would turn a missing optional dependency
        # into an opaque "'NoneType' object is not callable" below. This also
        # matches the check in ParallelGatedMlp.
        if ColumnParallelLinear is None or RowParallelLinear is None:
            raise ImportError("fused_dense is not installed")
        out_features = out_features if out_features is not None else in_features
        hidden_features = hidden_features if hidden_features is not None else in_features * 4
        self.fc1 = ColumnParallelLinear(
            in_features,
            hidden_features,
            process_group,
            bias=bias1,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )
        self.activation = activation
        self.fc2 = RowParallelLinear(
            hidden_features,
            out_features,
            process_group,
            bias=bias2,
            sequence_parallel=sequence_parallel,
            **factory_kwargs,
        )

    def forward(self, x):
        """Apply fc1 -> activation -> fc2 (any communication happens inside the layers)."""
        y = self.fc1(x)
        y = self.activation(y)
        y = self.fc2(y)
        return y


Tri Dao's avatar
Tri Dao committed
93
class GatedMlp(nn.Module):
    """Gated feed-forward block: ``fc2(value * activation(gate))``.

    ``fc1`` projects to ``2 * hidden_features`` channels. These are split in
    half along the last dimension: the first half is the value and the second
    half is the gate. With the default ``activation=F.sigmoid`` this is exactly
    GLU, and it is computed with ``F.glu``.

    Args:
        in_features: size of the input's last dimension.
        hidden_features: gated width before rounding; ``None`` means
            ``int(8 * in_features / 3)``.
        out_features: output width; ``None`` means ``in_features``.
        activation: callable applied to the gate half.
        bias1: whether ``fc1`` has a bias term.
        bias2: whether ``fc2`` has a bias term.
        multiple_of: ``hidden_features`` is rounded up to a multiple of this.
        return_residual: if True, ``forward`` returns ``(output, x)``.
        device, dtype: passed through to both ``nn.Linear`` constructors.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.sigmoid,
        bias1=True,
        bias2=True,
        multiple_of=256,
        return_residual=False,
        device=None,
        dtype=None,
    ):
        super().__init__()
        if out_features is None:
            out_features = in_features
        if hidden_features is None:
            hidden_features = int(8 * in_features / 3)
        # Round the gated width up to the next multiple of `multiple_of`.
        num_blocks = (hidden_features + multiple_of - 1) // multiple_of
        hidden_features = num_blocks * multiple_of
        linear_kwargs = dict(device=device, dtype=dtype)
        self.return_residual = return_residual
        # fc1 produces value and gate side by side, hence twice the width.
        self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **linear_kwargs)
        self.activation = activation
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **linear_kwargs)

    def forward(self, x):
        """Project, gate the first half by activation(second half), then project back."""
        projected = self.fc1(x)
        if self.activation == F.sigmoid:
            # F.glu(t) == first_half(t) * sigmoid(second_half(t)).
            gated = F.glu(projected, dim=-1)
        else:
            value, gate = projected.chunk(2, dim=-1)
            gated = value * self.activation(gate)
        out = self.fc2(gated)
        return (out, x) if self.return_residual else out
128
129


Tri Dao's avatar
Tri Dao committed
130
class ParallelGatedMlp(nn.Module):
    """Tensor-parallel version of the gated MLP (GLU-style when activation is sigmoid).

    ``fc1`` is a ``ColumnParallelLinear`` producing ``2 * hidden_features``
    channels, split into value (first half) and gate (second half). ``fc2`` is
    a ``RowParallelLinear`` mapping back to ``out_features``. Both layers come
    from ``flash_attn.ops.fused_dense`` and receive ``process_group`` and
    ``sequence_parallel``; the sharding/communication details live there.

    Args:
        in_features: size of the input's last dimension.
        process_group: forwarded to both parallel linear layers.
        hidden_features: gated width before rounding; ``None`` means
            ``int(8 * in_features / 3)``.
        out_features: output width; ``None`` means ``in_features``.
        activation: callable applied to the gate half.
        bias1: whether ``fc1`` has a bias term.
        bias2: whether ``fc2`` has a bias term.
        multiple_of: ``hidden_features`` is rounded up to a multiple of this.
        sequence_parallel: forwarded to both parallel linear layers.
        device, dtype: forwarded to both parallel linear layers.

    Raises:
        ImportError: if the fused_dense parallel layers are unavailable.
    """

    def __init__(
        self,
        in_features,
        process_group,
        hidden_features=None,
        out_features=None,
        activation=F.sigmoid,
        bias1=True,
        bias2=True,
        multiple_of=256,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ):
        super().__init__()
        if out_features is None:
            out_features = in_features
        if hidden_features is None:
            hidden_features = int(8 * in_features / 3)
        # Round the gated width up to the next multiple of `multiple_of`.
        num_blocks = (hidden_features + multiple_of - 1) // multiple_of
        hidden_features = num_blocks * multiple_of
        if ColumnParallelLinear is None or RowParallelLinear is None:
            raise ImportError("fused_dense is not installed")
        shared_kwargs = dict(
            sequence_parallel=sequence_parallel, device=device, dtype=dtype
        )
        # fc1 produces value and gate side by side, hence twice the width.
        self.fc1 = ColumnParallelLinear(
            in_features, 2 * hidden_features, process_group, bias=bias1, **shared_kwargs
        )
        self.activation = activation
        self.fc2 = RowParallelLinear(
            hidden_features, out_features, process_group, bias=bias2, **shared_kwargs
        )

    def forward(self, x):
        """Project, gate the first half by activation(second half), then project back."""
        projected = self.fc1(x)
        if self.activation == F.sigmoid:
            # F.glu(t) == first_half(t) * sigmoid(second_half(t)).
            gated = F.glu(projected, dim=-1)
        else:
            value, gate = projected.chunk(2, dim=-1)
            gated = value * self.activation(gate)
        return self.fc2(gated)