r"""
Layers that FMoE provides to users.
"""
import torch
import torch.nn as nn

from .functions import prepare_forward
from .functions import MOEScatter, MOEGather, MOELinear
from .functions import AllGather, Slice
from .gates import NaiveGate


class FMoELinear(nn.Module):
    r"""
    A linear layer that contains multiple experts.
    As multiple experts can be placed on the same worker, the computation can be
    performed in parallel to increase the performance.
    The FMoELinear module provides such function.

    The weight has shape ``(num_expert, out_feat, in_feat)`` and the optional
    bias has shape ``(num_expert, out_feat)``. Both are initialized by
    ``reset_parameters`` with the same uniform scheme ``nn.Linear`` uses.
    """

    def __init__(
        self,
        num_expert: int,
        in_feat: int,
        out_feat: int,
        bias: bool = True,
        rank: int = 0,
    ):
        super().__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        # Only recorded (shown by extra_repr); not used in the computation here.
        self.rank = rank
        self.weight = nn.Parameter(torch.Tensor(num_expert, out_feat, in_feat))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(num_expert, out_feat))
        else:
            self.register_parameter("bias", None)
        # torch.Tensor(...) allocates uninitialized memory; without this the
        # parameters would start from arbitrary values (possibly NaN/inf).
        self.reset_parameters()

    def reset_parameters(self):
        r"""
        Initialize every expert's weight and bias from
        ``U(-1/sqrt(in_feat), 1/sqrt(in_feat))``.

        This equals ``nn.Linear``'s default (kaiming-uniform with
        ``a=sqrt(5)`` for the weight, fan-in bound for the bias), applied per
        expert with ``fan_in = in_feat``. Computing the bound directly avoids
        ``kaiming_uniform_`` deriving a wrong fan-in from the 3-D weight.
        """
        bound = self.in_feat ** -0.5 if self.in_feat > 0 else 0.0
        nn.init.uniform_(self.weight, -bound, bound)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, inp, fwd_expert_count):
        r"""
        Call MOE function: apply each expert's weight via ``MOELinear`` and,
        if present, add each expert's bias to that expert's rows.

        ``fwd_expert_count[i]`` is used as the number of consecutive rows of
        the result belonging to expert ``i`` (presumably ``inp`` is already
        grouped by expert — TODO confirm against ``MOELinear``).
        """
        x = MOELinear.apply(inp, self.weight, fwd_expert_count)
        if self.bias is not None:
            # TODO: torch.repeat_interleave seems to have numerical instability
            # in backward, leading to incorrect gradient computation both for
            # this approach and for indexing self.bias with
            # torch.arange(num_expert).repeat_interleave(count).
            # Expanding each expert's bias in a for-loop and concatenating
            # avoids it but is ~50% slower. This should finally go into
            # MOELinear.apply, like MOELinear.apply(x, weight, bias, count).
            bias = torch.repeat_interleave(
                self.bias, fwd_expert_count.to(self.bias.device), dim=0
            )
            x = x + bias
        return x

    def extra_repr(self) -> str:
        # Built from adjacent literals; the previous backslash continuation
        # inside the string leaked the next line's indentation into the repr.
        return (
            f"num_expert={self.num_expert}, in_features={self.in_feat}, "
            f"out_features={self.out_feat}, bias={self.bias is not None}, "
            f"rank={self.rank}"
        )


def mark_module_parallel_comm(module, comm):
    r"""
    Tag every parameter of `module` with a `dp_comm` attribute set to `comm`,
    recording which communicator it does data parallel in. `comm` may be one
    of `'world', 'dp', 'none'`.
    """
    params = module.parameters()
    for param in params:
        param.dp_comm = comm


def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
    r"""
    A private function that performs the following steps to complete the MoE
    computation.
    * Count the number of tokens from each worker to each expert.
    * Send the features to their target position so that input features to each
    expert are contiguous in memory.
    * Perform the forward computation of the experts using `expert_fn`
    * Gather the output features of experts back, and reorder them as sentences.
    Intermediate results like expert counts are hidden from users by this
    function.
    """
    (
        pos,
        local_expert_count,
        global_expert_count,
        fwd_expert_count,
        fwd_batch_size,
    ) = prepare_forward(gate, num_expert, world_size)

    # A 2-D `gate` has one column per selected expert; otherwise each token
    # is routed to exactly one expert.
    topk = gate.shape[1] if len(gate.shape) == 2 else 1

    # `pos // topk` maps each routed slot back to its source row of `inp`
    # (presumably pos enumerates (token, k) pairs — TODO confirm against
    # prepare_forward).
    scattered = MOEScatter.apply(
        inp,
        pos // topk,
        local_expert_count,
        global_expert_count,
        fwd_batch_size,
        world_size,
    )
    expert_out = expert_fn(scattered, fwd_expert_count)

    # The gathered result has inp.shape[0] * topk rows: one per token and
    # selected expert.
    return MOEGather.apply(
        expert_out,
        pos,
        local_expert_count,
        global_expert_count,
        inp.shape[0] * topk,
        world_size,
    )


class FMoE(nn.Module):
    r"""
    A general moe implementation that supports an arbitrary module as the
    expert.
    * `num_expert` stands for the number of experts on **each** worker.
    * `d_model` is the feature size of each input and output row.
    * `world_size` stands for the total number of workers that contains
    different experts.
    * `mp_group` can be a torch's communication group, indicating that model
    parallel is applied across the group, which means that workers in the group
    hold the same copy of the input feature, and demands the same copy of the
    output. FMoE saves computation by slicing the input in the mp group and
    performing all-gather after the MLP computation.
    * `top_k` stands for the number of experts each token is going to.
    * `gate` is a gate class which can be found in `fmoe.gates`.
    * `expert` can be specified as a module class, it is used to generate
    `num_expert` expert modules. When it is `None`, the experts are treated
    as fused: `self.experts` is not created here, and a subclass is expected
    to assign a module called as `self.experts(inp, fwd_expert_count)`.
    * `gate_hook` is stored as `self.gate_hook`; this class does not call it.
    """

    def __init__(
        self,
        num_expert=32,
        d_model=1024,
        world_size=1,
        mp_group=None,
        top_k=2,
        gate=NaiveGate,
        expert=None,
        gate_hook=None,
    ):
        super().__init__()
        self.num_expert = num_expert
        self.d_model = d_model
        self.world_size = world_size
        self.mp_group = mp_group
        if mp_group is None:
            self.mp_size = 1
            self.mp_rank = 0
        else:
            self.mp_size = mp_group.size()
            self.mp_rank = mp_group.rank()
        self.top_k = top_k
        self.gate = gate(d_model, num_expert, world_size, top_k)
        if expert is not None:
            self.experts = nn.ModuleList(
                [expert(d_model) for _ in range(num_expert)]
            )
            self.experts_fused = False
        else:
            self.experts_fused = True
        # NOTE(review): stored but never invoked anywhere in this class.
        self.gate_hook = gate_hook

    def expert_fn(self, inp, fwd_expert_count):
        r"""
        The default expert function which either calls the experts as a whole
        or as separate experts.

        In the non-fused case, expert ``i`` receives the next
        ``fwd_expert_count[i]`` consecutive rows of ``inp``, and the
        per-expert outputs are concatenated in expert order.
        """
        if self.experts_fused:
            return self.experts(inp, fwd_expert_count)
        # One host transfer for all counts instead of one `.item()` per expert.
        counts = fwd_expert_count.tolist()
        outputs = []
        base_idx = 0
        for i in range(self.num_expert):
            batch_size = counts[i]
            inp_slice = inp[base_idx : base_idx + batch_size]
            outputs.append(self.experts[i](inp_slice))
            base_idx += batch_size
        return torch.cat(outputs, dim=0)

    def mark_parallel_comm(self, expert_dp_comm="none"):
        r"""
        Automatically mark the data parallel comms of the parameters within the
        module. This can be typically called at the end of the __init__ function
        in child classes.

        Expert parameters are marked with `expert_dp_comm`; gate parameters
        are always marked `"world"`. If no `experts` attribute has been set
        (fused mode without a subclass assignment), only the gate is marked.
        """
        # getattr: in fused mode `self.experts` is only created by subclasses,
        # so plain attribute access could raise AttributeError.
        experts = getattr(self, "experts", None)
        if experts is not None:
            if isinstance(experts, list):
                for e in experts:
                    mark_module_parallel_comm(e, expert_dp_comm)
            else:
                mark_module_parallel_comm(experts, expert_dp_comm)
        mark_module_parallel_comm(self.gate, "world")

    def forward(self, inp):
        r"""
        The FMoE module first computes gate output, and then conduct MoE forward
        according to the gate. For each token, the outputs of its `top_k`
        selected experts are combined as a weighted sum, using the gate scores
        as the weights.
        """
        if self.mp_size > 1:
            # Each model-parallel rank processes only its slice of the input.
            inp = Slice.apply(inp, self.mp_rank, self.mp_size, self.mp_group)

        gate_top_k_idx, gate_score = self.gate(inp)

        x = _fmoe_general_global_forward(
            inp,
            gate_top_k_idx,
            self.expert_fn,
            self.num_expert,
            self.world_size,
        )

        # (B, 1, top_k) @ (B, top_k, d_model) -> (B, 1, d_model): weighted
        # sum of each token's top_k expert outputs.
        x = x.view(inp.shape[0], self.top_k, self.d_model)
        gate_score = gate_score.view(inp.shape[0], 1, self.top_k)
        x = torch.bmm(gate_score, x).reshape(-1, self.d_model)

        if self.mp_size > 1:
            # Reassemble the full output so every mp rank holds the same copy.
            x = AllGather.apply(x, self.mp_rank, self.mp_size, self.mp_group)
        return x