# Copyright (c) 2022, Tri Dao.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed import ProcessGroup

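# Descriptive note (added): the tensor-parallel linear layers and fused MLP classes
# below come from flash-attn's optional fused_dense CUDA extension; they fall back
# to None when that extension is not installed.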
try:
    from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
except ImportError:
    ColumnParallelLinear, RowParallelLinear = None, None

try:
    from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP
except ImportError:
    FusedMLP, ParallelFusedMLP = None, None


class Mlp(nn.Module):
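    """Standard two-layer MLP: fc1 -> activation -> fc2.

    If return_residual=True, forward returns (output, input) so the caller can
    apply the residual connection itself.
    """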

    def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.gelu,
                 bias1=True, bias2=True, return_residual=False, device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features * 4
        self.return_residual = return_residual
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
        self.activation = activation
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

    def forward(self, x):
        y = self.fc1(x)
        y = self.activation(y)
        y = self.fc2(y)
        return y if not self.return_residual else (y, x)


class ParallelMLP(nn.Module):
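    """Tensor-parallel two-layer MLP: fc1 is a ColumnParallelLinear and fc2 a
    RowParallelLinear, so the hidden dimension is sharded across the given
    process_group (optionally with sequence parallelism).
    """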

    def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.gelu,
                 process_group: ProcessGroup = None, sequence_parallel=True,
                 bias1=True, bias2=True, device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        assert ColumnParallelLinear is not None, "Need to install fused_dense"
        assert RowParallelLinear is not None, "Need to install fused_dense"
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features * 4
        self.fc1 = ColumnParallelLinear(in_features, hidden_features, process_group, bias=bias1,
                                        sequence_parallel=sequence_parallel, **factory_kwargs)
        self.activation = activation
        self.fc2 = RowParallelLinear(hidden_features, out_features, process_group, bias=bias2,
                                     sequence_parallel=sequence_parallel, **factory_kwargs)

    def forward(self, x):
        y = self.fc1(x)
        y = self.activation(y)
        y = self.fc2(y)
        return y


class GatedMlp(nn.Module):
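    """Gated MLP (GLU-style): fc1 projects to 2 * hidden_features, which is split
    into a value half and a gate half; the gated product is projected back down
    by fc2. The hidden size is rounded up to a multiple of `multiple_of`.
    """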

    def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.sigmoid,
                 bias1=True, bias2=True, multiple_of=256, return_residual=False,
                 device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        out_features = out_features or in_features
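        # Default hidden size follows the usual GLU convention of ~8/3 * in_features,
        # which keeps the parameter count close to a standard 4x MLP since fc1 is doubled.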
        hidden_features = hidden_features or int(8 * in_features / 3)
        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
        self.return_residual = return_residual
        self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **factory_kwargs)
        self.activation = activation
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

    def forward(self, x):
        y = self.fc1(x)
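        # fc1 output has size 2 * hidden_features: one half is the value, the other the gate.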
        if self.activation == F.sigmoid:  # Special case for GLU
            y = F.glu(y, dim=-1)
        else:
            y, gate = y.chunk(2, dim=-1)
            y = y * self.activation(gate)
        y = self.fc2(y)
        return y if not self.return_residual else (y, x)
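
# Minimal usage sketch (illustrative only; the shapes below are arbitrary):
#     mlp = Mlp(in_features=512)
#     out = mlp(torch.randn(2, 1024, 512))
#     swiglu = GatedMlp(in_features=512, activation=F.silu)  # SwiGLU-style gating
#     out = swiglu(torch.randn(2, 1024, 512))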