"tasks/vision/finetune_utils.py" did not exist on "8d7f508a51585ab1f14827be00ee4afd1d5a748f"
mlp.py 5.05 KB
Newer Older
1
2
3
4
5
# Copyright (c) 2022, Tri Dao.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed import ProcessGroup

try:
    from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
except ImportError:
    ColumnParallelLinear, RowParallelLinear = None, None

try:
    from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP
except ImportError:
    FusedMLP, ParallelFusedMLP = None, None


class Mlp(nn.Module):
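    """Standard two-layer MLP: fc1 -> activation -> fc2.

    If ``return_residual=True``, forward returns ``(output, input)`` so the caller
    can apply the residual connection itself.
    """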

    def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.gelu,
                 bias1=True, bias2=True, return_residual=False, device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features * 4
        self.return_residual = return_residual
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
        self.activation = activation
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

    def forward(self, x):
        y = self.fc1(x)
        y = self.activation(y)
        y = self.fc2(y)
        return y if not self.return_residual else (y, x)


class ParallelMLP(nn.Module):
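    """Tensor-parallel Mlp: fc1 is a ColumnParallelLinear and fc2 a RowParallelLinear,
    both sharded across ``process_group``. Requires the flash_attn fused_dense extension.
    """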

    def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.gelu,
                 process_group: ProcessGroup = None, sequence_parallel=True,
                 bias1=True, bias2=True, device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        assert ColumnParallelLinear is not None, "Need to install fused_dense"
        assert RowParallelLinear is not None, "Need to install fused_dense"
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features * 4
        self.fc1 = ColumnParallelLinear(in_features, hidden_features, process_group, bias=bias1,
                                        sequence_parallel=sequence_parallel, **factory_kwargs)
        self.activation = activation
        self.fc2 = RowParallelLinear(hidden_features, out_features, process_group, bias=bias2,
                                     sequence_parallel=sequence_parallel, **factory_kwargs)

    def forward(self, x):
        y = self.fc1(x)
        y = self.activation(y)
        y = self.fc2(y)
        return y


class GatedMlp(nn.Module):
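    """Gated MLP (GLU family): fc1 projects to ``2 * hidden_features`` channels, which
    are split into a value and a gate; the value is multiplied element-wise by
    ``activation(gate)`` before fc2. The default ``activation=F.sigmoid`` gives the
    classic GLU; passing ``F.silu`` gives SwiGLU. If ``return_residual=True``,
    forward returns ``(output, input)``.
    """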

    def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.sigmoid,
                 bias1=True, bias2=True, multiple_of=256, return_residual=False,
                 device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        out_features = out_features or in_features
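        # Default hidden size is ~(8/3) * in_features, rounded up to a multiple of
        # `multiple_of`; the 2/3 factor keeps the gated MLP's parameter count close
        # to that of a standard 4x MLP.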
        hidden_features = hidden_features or int(8 * in_features / 3)
        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
        self.return_residual = return_residual
        self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **factory_kwargs)
        self.activation = activation
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

    def forward(self, x):
        y = self.fc1(x)
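        # F.glu splits the last dim in half and computes value * sigmoid(gate) in one call.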
        if self.activation == F.sigmoid:  # Special case for GLU
            y = F.glu(y, dim=-1)
        else:
            y, gate = y.chunk(2, dim=-1)
            y = y * self.activation(gate)
        y = self.fc2(y)
        return y if not self.return_residual else (y, x)


class ParallelGatedMlp(GatedMlp):
    """ Parallel GatedMlp """

    def __init__(self, in_features, process_group, hidden_features=None, out_features=None,
                 activation=F.sigmoid, bias1=True, bias2=True, multiple_of=256,
                 return_residual=False, sequence_parallel=True, device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(in_features, hidden_features=hidden_features, out_features=out_features,
                         activation=activation, bias1=bias1, bias2=bias2, multiple_of=multiple_of,
                         return_residual=return_residual, device=device, dtype=dtype)
        out_features = out_features or in_features
        hidden_features = hidden_features or int(8 * in_features / 3)
        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of

        if ColumnParallelLinear is None or RowParallelLinear is None:
            raise ImportError('fused_dense is not installed')
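        # GatedMlp.__init__ above created dense fc1/fc2; replace them with tensor-parallel layers.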
        self.fc1 = ColumnParallelLinear(in_features, 2 * hidden_features, process_group,
                                        bias=bias1,
                                        sequence_parallel=sequence_parallel, **factory_kwargs)
        self.fc2 = RowParallelLinear(hidden_features, out_features, process_group,
                                     bias=bias2,
                                     sequence_parallel=sequence_parallel, **factory_kwargs)
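
# Example usage (illustrative sketch, not part of the module itself; sizes are arbitrary):
#   mlp = Mlp(768)                                # GELU MLP with 4x hidden expansion
#   swiglu = GatedMlp(768, activation=F.silu)     # SwiGLU-style gated MLP
#   y = mlp(torch.randn(2, 16, 768))              # output shape (2, 16, 768)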