mlp.py
# Copyright (c) 2022, Tri Dao.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed import ProcessGroup

try:
    from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
except ImportError:
    ColumnParallelLinear, RowParallelLinear = None, None

try:
    from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP
except ImportError:
    FusedMLP, ParallelFusedMLP = None, None
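# Both fused imports above fall back to None when flash_attn's fused_dense
# extension is unavailable; Mlp and GatedMlp below are pure PyTorch, while the
# Parallel* variants require the extension.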


class Mlp(nn.Module):

    def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.gelu,
                 bias1=True, bias2=True, return_residual=False, device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features * 4
        self.return_residual = return_residual
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
        self.activation = activation
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

    def forward(self, x):
        y = self.fc1(x)
        y = self.activation(y)
        y = self.fc2(y)
        return y if not self.return_residual else (y, x)


class ParallelMLP(nn.Module):

    def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.gelu,
                 process_group: ProcessGroup = None, sequence_parallel=True,
                 bias1=True, bias2=True, device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        assert ColumnParallelLinear is not None, "Need to install fused_dense"
        assert RowParallelLinear is not None, "Need to install fused_dense"
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features * 4
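        # Megatron-style tensor parallelism: fc1 shards the output (column) dim and
        # fc2 shards the input (row) dim, so the intermediate activation stays local
        # to each rank and communication happens only inside the parallel linears.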
        self.fc1 = ColumnParallelLinear(in_features, hidden_features, process_group, bias=bias1,
                                        sequence_parallel=sequence_parallel, **factory_kwargs)
        self.activation = activation
        self.fc2 = RowParallelLinear(hidden_features, out_features, process_group, bias=bias2,
                                     sequence_parallel=sequence_parallel, **factory_kwargs)

    def forward(self, x):
        y = self.fc1(x)
        y = self.activation(y)
        y = self.fc2(y)
        return y


class GatedMlp(nn.Module):

    def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.sigmoid,
                 bias1=True, bias2=True, multiple_of=256, return_residual=False,
                 device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        out_features = out_features or in_features
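        # Default hidden width of ~(8/3) * in_features keeps the parameter count of
        # the gated MLP close to a standard 4x MLP; it is then rounded up to a
        # multiple of `multiple_of` for hardware-friendly dimensions.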
        hidden_features = hidden_features or int(8 * in_features / 3)
        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
        self.return_residual = return_residual
        self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **factory_kwargs)
        self.activation = activation
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

    def forward(self, x):
        y = self.fc1(x)
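        # fc1 packs the value and the gate along the last dimension; with a sigmoid
        # activation this is exactly GLU, so F.glu handles the split and gating.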
        if self.activation == F.sigmoid:  # Special case for GLU
            y = F.glu(y, dim=-1)
        else:
            y, gate = y.chunk(2, dim=-1)
            y = y * self.activation(gate)
        y = self.fc2(y)
        return y if not self.return_residual else (y, x)


class ParallelGatedMlp(nn.Module):
    """ Parallel GatedMlp """

    def __init__(self, in_features, process_group, hidden_features=None, out_features=None,
                 activation=F.sigmoid, bias1=True, bias2=True, multiple_of=256,
                 sequence_parallel=True, device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or int(8 * in_features / 3)
        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
        if ColumnParallelLinear is None or RowParallelLinear is None:
            raise ImportError('fused_dense is not installed')
        self.fc1 = ColumnParallelLinear(in_features, 2 * hidden_features, process_group, bias=bias1,
                                        sequence_parallel=sequence_parallel, **factory_kwargs)
        self.activation = activation
        self.fc2 = RowParallelLinear(hidden_features, out_features, process_group, bias=bias2,
                                     sequence_parallel=sequence_parallel, **factory_kwargs)

    def forward(self, x):
        y = self.fc1(x)
        if self.activation == F.sigmoid:  # Special case for GLU
            y = F.glu(y, dim=-1)
        else:
            y, gate = y.chunk(2, dim=-1)
            y = y * self.activation(gate)
        y = self.fc2(y)
        return y
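

# Example usage (illustrative, not part of the upstream API): quick CPU shape
# check for the non-parallel modules.
if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(2, 16, 128)  # (batch, seqlen, hidden_dim)

    mlp = Mlp(128)  # hidden_features defaults to 4 * in_features
    assert mlp(x).shape == (2, 16, 128)

    gated = GatedMlp(128, activation=F.silu)  # SwiGLU-style gating
    assert gated(x).shape == (2, 16, 128)

    print("Mlp and GatedMlp forward shape checks passed")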