layers.py 8.48 KB
Newer Older
Sengxian's avatar
Sengxian committed
1
r"""
Rick Ho's avatar
Rick Ho committed
2
Layers that FMoE provides to users
Sengxian's avatar
Sengxian committed
3
"""
Rick Ho's avatar
Rick Ho committed
4
import torch
Rick Ho's avatar
Rick Ho committed
5
import torch.nn as nn
6
import math
Rick Ho's avatar
Rick Ho committed
7

8
from .functions import prepare_forward, ensure_comm
Rick Ho's avatar
Rick Ho committed
9
from .functions import MOEScatter, MOEGather, MOELinear
Sengxian's avatar
Sengxian committed
10
from .functions import AllGather, Slice
Rick Ho's avatar
Rick Ho committed
11
from .gates import NaiveGate
Rick Ho's avatar
Rick Ho committed
12

Rick Ho's avatar
Rick Ho committed
13
14

class FMoELinear(nn.Module):
    r"""
    A linear layer that contains multiple experts.
    As multiple experts can be placed on the same worker, the computation can be
    performed in parallel to increase the performance.
    The FMoELinear module provides such function.

    * `num_expert` is the number of experts held by this layer.
    * `in_feat` / `out_feat` are the input / output feature sizes of every
    expert; the weight is stored as a `(num_expert, out_feat, in_feat)` tensor.
    * `bias` controls whether a per-expert bias of shape
    `(num_expert, out_feat)` is created; it is initialized to zero.
    * `rank` is stored on the module and reported by `extra_repr`; it is not
    used by the computation in this class.
    """

    def __init__(
        self,
        num_expert: int,
        in_feat: int,
        out_feat: int,
        bias: bool = True,
        rank: int = 0,
    ):
        super().__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.rank = rank
        # Uninitialized storage; filled in by reset_parameters() below.
        self.weight = nn.Parameter(torch.Tensor(num_expert, out_feat, in_feat))
        if bias:
            self.bias = nn.Parameter(torch.zeros(num_expert, out_feat))
        else:
            self.register_parameter("bias", None)

        self.reset_parameters()

    def forward(self, inp, fwd_expert_count):
        r"""
        Apply every expert's linear transform through the fused `MOELinear`
        autograd function.

        `fwd_expert_count[i]` is presumably the number of consecutive rows of
        `inp` that belong to expert `i` (TODO confirm against `MOELinear`).
        """
        x = MOELinear.apply(inp, fwd_expert_count, self.weight, self.bias)
        return x

    def extra_repr(self) -> str:
        # A single-line literal: a backslash continuation inside the string
        # would embed the next line's indentation into the repr.
        return "num_expert={}, in_features={}, out_features={}, bias={}, rank={}".format(
            self.num_expert,
            self.in_feat,
            self.out_feat,
            self.bias is not None,
            self.rank,
        )

    def reset_parameters(self):
        r"""
        Initialize every expert's weight the way `torch.nn.Linear` does.
        The bias, if any, is left at zero.
        """
        # Approach is the same as in torch.nn.Linear
        # https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py#L88
        # bias is left to zero, similar as megatron
        #
        # Each expert is initialized on its own 2-D (out_feat, in_feat) slice,
        # so that torch derives fan_in = in_feat. Calling kaiming_uniform_ on
        # the full 3-D weight would give fan_in = out_feat * in_feat (dim 1
        # times the trailing dims) and shrink the bound by sqrt(out_feat).
        with torch.no_grad():
            for i in range(self.num_expert):
                torch.nn.init.kaiming_uniform_(self.weight[i], a=math.sqrt(5))

Rick Ho's avatar
Rick Ho committed
67

Rick Ho's avatar
Rick Ho committed
68
def mark_module_parallel_comm(module, comm):
    r"""
    Tag every parameter of `module` with the data-parallel communicator it
    should be synchronized over.

    The tag is stored as the `dp_comm` attribute on each parameter; `comm`
    may be one of `'world', 'dp', 'none'`.
    """
    for param in module.parameters():
        param.dp_comm = comm
Rick Ho's avatar
Rick Ho committed
75
76


77
def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
    r"""
    Run one full MoE dispatch / compute / combine round.

    Steps:
    * `prepare_forward` counts how many tokens go from each worker to each
    expert and produces the position mapping `pos`.
    * `MOEScatter` moves the input features so that the rows for each expert
    are contiguous in memory.
    * `expert_fn(features, fwd_expert_count)` runs the experts.
    * `MOEGather` sends the expert outputs back and restores the original row
    order, producing `top_k` output rows per input row.
    The intermediate counts never leave this function.
    """
    (
        pos,
        local_expert_count,
        global_expert_count,
        fwd_expert_count,
        fwd_batch_size,
    ) = prepare_forward(gate, num_expert, world_size)

    # A 2-D gate carries several expert indices per row (its second dim);
    # anything else is treated as one index per row.
    top_k = gate.shape[1] if len(gate.shape) == 2 else 1

    # Integer division by top_k maps each entry of `pos` back to an `inp`
    # row (presumably `pos` indexes flattened (row, k) slots - confirm).
    dispatched = MOEScatter.apply(
        inp,
        pos // top_k,
        local_expert_count,
        global_expert_count,
        fwd_batch_size,
        world_size,
    )
    expert_out = expert_fn(dispatched, fwd_expert_count)

    # Each input row yields top_k gathered output rows.
    return MOEGather.apply(
        expert_out,
        pos,
        local_expert_count,
        global_expert_count,
        inp.shape[0] * top_k,
        world_size,
    )


Rick Ho's avatar
Rick Ho committed
117
class FMoE(nn.Module):
    r"""
    A general moe implementation that supports an arbitrary module as the
    expert.
    * `num_expert` stands for the number of experts on **each** worker.
    * `world_size` stands for the total number of workers that contains
    different experts.
    * `mp_group` can be a torch's communication group, indicating that model
    parallel is applied across the group, which means that workers in the group
    hold the same copy of the input feature, and demands the same copy of the
    output. FMoE saves computation by slicing the input in the mp group and
    performing all-gather after the MLP computation.
    * `moe_group` is the communication group handed to `ensure_comm` when
    `world_size > 1`.
    * `top_k` stands for the number of experts each token is going to.
    * `gate` is a gate class which can be found in `fmoe.gates`; it is
    constructed as `gate(d_model, num_expert, world_size, top_k)`.
    * `expert` can be specified as a module class, it is used to generate
    `num_expert` expert modules, each built as `expert(d_model)`. A list (or
    tuple) of module classes may be given instead: one expert is built per
    class and `num_expert` is overridden by its length. When `expert` is None
    the experts are treated as fused, and `self.experts` is presumably
    assigned by a subclass.
    * `gate_hook`, if given, is called as
    `gate_hook(gate_top_k_idx, gate_score, None)` right after gating.
    * `mask` / `mask_dict`: when both are given, only rows whose flattened
    `mask` value is 0 are sent through the experts. Rows with mask value `k`
    receive `mask_dict[k]` in place of their `top_k` expert outputs before the
    gate-score weighting. Other rows keep zeros.
    """

    def __init__(
        self,
        num_expert=32,
        d_model=1024,
        world_size=1,
        mp_group=None,
        moe_group=None,
        top_k=2,
        gate=NaiveGate,
        expert=None,
        gate_hook=None,
        mask=None,
        mask_dict=None,
    ):
        super().__init__()
        self.num_expert = num_expert
        self.d_model = d_model
        self.world_size = world_size
        self.mp_group = mp_group
        if mp_group is None:
            self.mp_size = 1
            self.mp_rank = 0
        else:
            self.mp_size = mp_group.size()
            self.mp_rank = mp_group.rank()
        self.top_k = top_k
        if isinstance(expert, (list, tuple)):
            self.experts = nn.ModuleList([e(d_model) for e in expert])
            self.experts_fused = False
            # One expert per given class; the gate below must see this count.
            self.num_expert = num_expert = len(expert)
        elif expert is not None:
            self.experts = nn.ModuleList(
                [expert(d_model) for _ in range(num_expert)]
            )
            self.experts_fused = False
        else:
            # No expert class: expert_fn will call `self.experts` once with
            # all rows and the per-expert counts.
            self.experts_fused = True
        self.gate = gate(d_model, num_expert, world_size, top_k)
        self.gate_hook = gate_hook
        self.mask = mask
        self.mask_dict = mask_dict
        self.moe_group = moe_group

    def expert_fn(self, inp, fwd_expert_count):
        r"""
        The default expert function which either calls the experts as a whole
        or as separate experts.

        In the non-fused case the rows of `inp` are consumed in order:
        the first `fwd_expert_count[0]` rows go to expert 0, the next
        `fwd_expert_count[1]` rows to expert 1, and so on. The outputs are
        concatenated along dim 0.
        """
        if self.experts_fused:
            return self.experts(inp, fwd_expert_count)
        outputs = []
        base_idx = 0
        for i in range(self.num_expert):
            batch_size = fwd_expert_count[i].item()
            inp_slice = inp[base_idx : base_idx + batch_size]
            outputs.append(self.experts[i](inp_slice))
            base_idx += batch_size
        return torch.cat(outputs, dim=0)

    def mark_parallel_comm(self, expert_dp_comm="none"):
        r"""
        Automatically mark the data parallel comms of the parameters within the
        module. This can be typically called at the end of the __init__ function
        in child classes.

        Expert parameters are tagged with `expert_dp_comm`; gate parameters
        are tagged with `"moe"`.
        """
        # `self.experts` is never assigned by __init__ when `expert` is None,
        # and nn.Module raises AttributeError for missing attributes, so look
        # it up defensively instead of reading it directly.
        experts = getattr(self, "experts", None)
        if experts is not None:
            if isinstance(experts, list):
                for e in experts:
                    mark_module_parallel_comm(e, expert_dp_comm)
            else:
                mark_module_parallel_comm(experts, expert_dp_comm)
        mark_module_parallel_comm(self.gate, "moe")

    def forward(self, inp):
        r"""
        The FMoE module first computes gate output, and then conduct MoE forward
        according to the gate.  The score of the selected gate given by the
        expert is multiplied to the experts' output tensors as a weight.

        `inp` is presumably a 2-D `(rows, d_model)` tensor (TODO confirm). The
        result is reshaped to `(-1, d_model)`, and all-gathered across the mp
        group when `mp_size > 1`.
        """
        if self.world_size > 1:
            ensure_comm(inp, self.moe_group)
        if self.mp_size > 1:
            # Each mp rank processes only its slice; AllGather restores it.
            inp = Slice.apply(inp, self.mp_rank, self.mp_size, self.mp_group)

        gate_top_k_idx, gate_score = self.gate(inp)

        if self.gate_hook is not None:
            self.gate_hook(gate_top_k_idx, gate_score, None)

        # Evaluate the mask condition once so both branches below agree.
        use_mask = self.mask is not None and self.mask_dict is not None

        # delete masked tensors
        if use_mask:
            mask = self.mask.view(-1)
            # to: (BxL') x d_model
            inp = inp[mask == 0, :]
            gate_top_k_idx = gate_top_k_idx[mask == 0, :]

        fwd = _fmoe_general_global_forward(
            inp, gate_top_k_idx,
            self.expert_fn, self.num_expert, self.world_size
        )

        # recover deleted tensors
        if use_mask:
            # to: (BxL') x top_k x d_model
            fwd = fwd.view(-1, self.top_k, self.d_model)
            # to: (BxL) x top_k x d_model
            x = torch.zeros(
                mask.shape[0], self.top_k, self.d_model,
                device=fwd.device, dtype=fwd.dtype,
            )
            # recover
            x[mask == 0] = fwd
            for k, v in self.mask_dict.items():
                x[mask == k] = v
        else:
            x = fwd.view(-1, self.top_k, self.d_model)

        # Weighted sum over the top_k expert outputs of each row:
        # (rows, 1, top_k) @ (rows, top_k, d_model) -> (rows, 1, d_model).
        gate_score = gate_score.view(x.shape[0], 1, self.top_k)
        x = torch.bmm(gate_score, x).reshape(-1, self.d_model)

        if self.mp_size > 1:
            x = AllGather.apply(x, self.mp_rank, self.mp_size, self.mp_group)
        return x