layers.py 7.92 KB
Newer Older
Sengxian's avatar
Sengxian committed
1
r"""
Rick Ho's avatar
Rick Ho committed
2
Layers that FMoE provides to users
Sengxian's avatar
Sengxian committed
3
"""
Rick Ho's avatar
Rick Ho committed
4
import torch
Rick Ho's avatar
Rick Ho committed
5
6
import torch.nn as nn

Rick Ho's avatar
Rick Ho committed
7
from .functions import prepare_forward
Rick Ho's avatar
Rick Ho committed
8
from .functions import MOEScatter, MOEGather, MOELinear
Sengxian's avatar
Sengxian committed
9
from .functions import AllGather, Slice
Rick Ho's avatar
Rick Ho committed
10
from .gates import NaiveGate
Rick Ho's avatar
Rick Ho committed
11

Rick Ho's avatar
Rick Ho committed
12
13

class FMoELinear(nn.Module):
    r"""
    A linear layer that contains multiple experts.
    As multiple experts can be placed on the same worker, the computation can be
    performed in parallel to increase the performance.
    The FMoELinear module provides such function.

    Parameters are stored stacked over experts: `weight` has shape
    `(num_expert, out_feat, in_feat)` and `bias` (when enabled) has shape
    `(num_expert, out_feat)`.
    """

    def __init__(
        self,
        num_expert: int,
        in_feat: int,
        out_feat: int,
        bias: bool = True,
        rank: int = 0,
    ):
        super().__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.rank = rank
        self.weight = nn.Parameter(torch.Tensor(num_expert, out_feat, in_feat))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(num_expert, out_feat))
        else:
            self.register_parameter("bias", None)
        # torch.Tensor(...) only allocates storage; without an explicit
        # initialization the parameters would hold arbitrary memory contents.
        self.reset_parameters()

    def reset_parameters(self):
        r"""
        Initialize every expert the way `torch.nn.Linear` does by default:
        weight and bias are drawn from U(-1/sqrt(in_feat), 1/sqrt(in_feat))
        (equivalent to kaiming_uniform_ with a=sqrt(5) on each expert's
        `(out_feat, in_feat)` weight matrix).
        """
        # max(..., 1) avoids a division by zero for a degenerate in_feat == 0.
        bound = 1.0 / max(self.in_feat, 1) ** 0.5
        with torch.no_grad():
            self.weight.uniform_(-bound, bound)
            if self.bias is not None:
                self.bias.uniform_(-bound, bound)

    def forward(self, inp, fwd_expert_count):
        r"""
        Call MOE function: apply each expert's linear map to its slice of
        `inp`, with `fwd_expert_count` giving the per-expert row counts
        (as consumed by `MOELinear`).
        """
        x = MOELinear.apply(inp, fwd_expert_count, self.weight, self.bias)
        return x

    def extra_repr(self) -> str:
        # Implicit string concatenation keeps the repr free of the source
        # indentation that a backslash continuation inside the literal leaked.
        return (
            "num_expert={}, in_features={}, out_features={}, "
            "bias={}, rank={}".format(
                self.num_expert,
                self.in_feat,
                self.out_feat,
                self.bias is not None,
                self.rank,
            )
        )

Rick Ho's avatar
Rick Ho committed
57

Rick Ho's avatar
Rick Ho committed
58
def mark_module_parallel_comm(module, comm):
    r"""
    Tag every parameter reachable from `module` with a `dp_comm` attribute
    set to `comm`, marking which communicator it does data parallel in.
    `comm` is expected to be one of `'world'`, `'dp'` or `'none'`.
    """
    for param in module.parameters():
        param.dp_comm = comm
Rick Ho's avatar
Rick Ho committed
65
66
67


def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
    r"""
    Private helper that runs a complete MoE pass:

    * count the tokens each worker sends to each expert;
    * scatter the input features so that the rows routed to one expert are
      contiguous in memory;
    * run the experts via `expert_fn(features, fwd_expert_count)`;
    * gather the expert outputs back and reorder them as sentences.

    Intermediate results such as the expert counts are hidden from callers.
    If `gate` is 2-D, its second dimension is treated as top-k and the gather
    is asked for `inp.shape[0] * gate.shape[1]` rows; otherwise for
    `inp.shape[0]` rows.
    """
    (
        pos,
        local_expert_count,
        global_expert_count,
        fwd_expert_count,
        fwd_batch_size,
    ) = prepare_forward(gate, num_expert, world_size)

    # A 2-D gate carries several expert choices per token.
    topk = gate.shape[1] if len(gate.shape) == 2 else 1

    # NOTE(review): presumably `pos` indexes flattened (token, k) slots, so
    # integer division by topk recovers the source token row — confirm
    # against prepare_forward.
    scattered = MOEScatter.apply(
        inp,
        pos // topk,
        local_expert_count,
        global_expert_count,
        fwd_batch_size,
        world_size,
    )
    expert_out = expert_fn(scattered, fwd_expert_count)

    return MOEGather.apply(
        expert_out,
        pos,
        local_expert_count,
        global_expert_count,
        inp.shape[0] * topk,
        world_size,
    )


Rick Ho's avatar
Rick Ho committed
107
class FMoE(nn.Module):
    r"""
    A general moe implementation that supports an arbitrary module as the
    expert.
    * `num_expert` stands for the number of experts on **each** worker.
    * `world_size` stands for the total number of workers that contains
    different experts.
    * `mp_group` can be a torch's communication group, indicating that model
    parallel is applied across the group, which means that workers in the group
    hold the same copy of the input feature, and demands the same copy of the
    output. FMoE saves computation by slicing the input in the mp group and
    performing all-gather after the MLP computation.
    * `top_k` stands for the number of experts each token is going to.
    * `gate` is a gate class which can be found in `fmoe.gates`.
    * `expert` can be specified as a module class, it is used to generate
    `num_expert` expert modules. It may also be a list of module classes (one
    per expert), in which case `num_expert` becomes the length of that list.
    When `expert` is None the experts are "fused": `self.experts` must be
    assigned afterwards (e.g. by a subclass) with a module that is called as
    `self.experts(inp, fwd_expert_count)`.
    * `gate_hook` is only stored as `self.gate_hook`; this class never calls it.
    * `mask` / `mask_dict`: used only when both are given. Rows whose flattened
    `mask` value is 0 are routed through the experts; for each `(k, v)` in
    `mask_dict`, rows with `mask == k` receive `v` (broadcast to
    `(top_k, d_model)`) as their expert output instead.
    """

    def __init__(
        self,
        num_expert=32,
        d_model=1024,
        world_size=1,
        mp_group=None,
        top_k=2,
        gate=NaiveGate,
        expert=None,
        gate_hook=None,
        mask=None,
        mask_dict=None,
    ):
        super().__init__()
        self.num_expert = num_expert
        self.d_model = d_model
        self.world_size = world_size
        self.mp_group = mp_group
        if mp_group is None:
            self.mp_size = 1
            self.mp_rank = 0
        else:
            self.mp_size = mp_group.size()
            self.mp_rank = mp_group.rank()
        self.top_k = top_k
        if isinstance(expert, list):
            self.experts = nn.ModuleList([e(d_model) for e in expert])
            self.experts_fused = False
            self.num_expert = num_expert = len(expert)
        elif expert is not None:
            self.experts = nn.ModuleList(
                [expert(d_model) for _ in range(num_expert)]
            )
            self.experts_fused = False
        else:
            self.experts_fused = True
            # Define the attribute so `mark_parallel_comm` does not raise
            # AttributeError; assigning an nn.Module later (as fused-expert
            # subclasses do) replaces this placeholder.
            self.experts = None
        self.gate = gate(d_model, num_expert, world_size, top_k)
        self.gate_hook = gate_hook
        self.mask = mask
        self.mask_dict = mask_dict

    def expert_fn(self, inp, fwd_expert_count):
        r"""
        The default expert function which either calls the experts as a whole
        or as separate experts.
        In the separate case, `inp` is consumed in consecutive row slices whose
        lengths are `fwd_expert_count[i]`, one slice per expert, and the
        per-expert outputs are concatenated along dim 0.
        """
        if self.experts_fused:
            return self.experts(inp, fwd_expert_count)
        outputs = []
        base_idx = 0
        for i in range(self.num_expert):
            batch_size = fwd_expert_count[i].item()
            inp_slice = inp[base_idx : base_idx + batch_size]
            outputs.append(self.experts[i](inp_slice))
            base_idx += batch_size
        return torch.cat(outputs, dim=0)

    def mark_parallel_comm(self, expert_dp_comm="none"):
        r"""
        Automatically mark the data parallel comms of the parameters within the
        module. This can be typically called at the end of the __init__ function
        in child classes.
        Expert parameters get `expert_dp_comm`; gate parameters get `"world"`.
        """
        if self.experts is not None:
            comm = expert_dp_comm
            if isinstance(self.experts, list):
                for e in self.experts:
                    mark_module_parallel_comm(e, comm)
            else:
                mark_module_parallel_comm(self.experts, comm)
        mark_module_parallel_comm(self.gate, "world")

    def forward(self, inp):
        r"""
        The FMoE module first computes gate output, and then conduct MoE forward
        according to the gate.  The score of the selected gate given by the
        expert is multiplied to the experts' output tensors as a weight.
        """
        if self.mp_size > 1:
            inp = Slice.apply(inp, self.mp_rank, self.mp_size, self.mp_group)

        gate_top_k_idx, gate_score = self.gate(inp)

        use_mask = self.mask is not None and self.mask_dict is not None
        if use_mask:
            # NOTE(review): `self.mask` is not sliced when mp_size > 1, so it
            # presumably has to match the sliced `inp` — confirm with callers.
            mask = self.mask.view(-1)
            # delete masked tensors, to: (BxL') x d_model
            inp = inp[mask == 0, :]
            gate_top_k_idx = gate_top_k_idx[mask == 0, :]

        fwd = _fmoe_general_global_forward(
            inp,
            gate_top_k_idx,
            self.expert_fn,
            self.num_expert,
            self.world_size,
        )

        if use_mask:
            # recover deleted tensors
            # to: (BxL') x top_k x d_model
            fwd = fwd.view(-1, self.top_k, self.d_model)
            # to: (BxL) x top_k x d_model
            x = torch.zeros(
                mask.shape[0],
                self.top_k,
                self.d_model,
                device=fwd.device,
                dtype=fwd.dtype,
            )
            x[mask == 0] = fwd
            for k, v in self.mask_dict.items():
                x[mask == k] = v
        else:
            x = fwd.view(-1, self.top_k, self.d_model)

        # Weighted sum over the top_k expert outputs of each token.
        gate_score = gate_score.view(x.shape[0], 1, self.top_k)
        x = torch.bmm(gate_score, x).reshape(-1, self.d_model)

        if self.mp_size > 1:
            x = AllGather.apply(x, self.mp_rank, self.mp_size, self.mp_group)
        return x