layers.py 7.71 KB
Newer Older
Rick Ho's avatar
Rick Ho committed
1
2
3
4
r'''
Layers that FMoE provides to users
'''
import torch
Rick Ho's avatar
Rick Ho committed
5
6
import torch.nn as nn

Rick Ho's avatar
Rick Ho committed
7
8
from .functions import moe_prepare_forward
from .functions import MOEScatter, MOEGather, MOELinear
Sengxian's avatar
Sengxian committed
9
from .functions import AllGather, Slice
Rick Ho's avatar
Rick Ho committed
10
from .gates import NaiveGate
Rick Ho's avatar
Rick Ho committed
11

Rick Ho's avatar
Rick Ho committed
12
13

class FMoELinear(nn.Module):
    r'''
    A linear layer that contains multiple experts.
    As multiple experts can be placed on the same worker, the computation can be
    performed in parallel to increase the performance.
    The FMoELinear module provides such function.
    '''
    def __init__(self, num_expert: int, in_feat: int, out_feat: int,
            bias: bool = True, rank: int = 0):
        r'''
        * `num_expert` is the number of experts held by this layer.
        * `in_feat` / `out_feat` are the input / output feature sizes of every
        expert.
        * `bias` selects whether a per-expert bias parameter is created.
        * `rank` is only stored and reported by `extra_repr` in this class.
        '''
        super().__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.rank = rank
        # NOTE(review): weight and bias are allocated here without an explicit
        # initializer; presumably the owner of this layer initializes them --
        # confirm against the callers.
        self.weight = nn.Parameter(torch.Tensor(num_expert, out_feat, in_feat))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(num_expert, out_feat))
        else:
            # Register the name so that `self.bias is None` is a valid check.
            self.register_parameter('bias', None)

    def forward(self, inp, fwd_expert_count):
        r'''
        Call MOE function on `inp`, then add the bias (if any) of expert `i`
        to `fwd_expert_count[i]` consecutive rows of the result.
        Assumes the rows of `inp` are grouped by expert in expert order --
        TODO confirm against `MOELinear`.
        '''
        x = MOELinear.apply(inp, self.weight, fwd_expert_count)
        if self.bias is not None:
            # TODO: torch.repeat_interleave seems to have numerical
            # instability in backward, leading to incorrect
            # gradient computation for solutions 1 and 2.
            # Solution 3 uses a for-loop to expand the bias,
            # but is 50% slower.
            # This part should finally go to MOELinear.apply,
            # like MOELinear.apply(x, weight, bias, count)

            # Solution 1
            bias = torch.repeat_interleave(self.bias,
                fwd_expert_count.to(self.bias.device), dim=0)

            # Solution 2
            # bias_idx = torch.arange(self.num_expert)\
            #     .repeat_interleave(fwd_expert_count)
            # bias = self.bias[bias_idx]

            # Solution 3
            # bias = []
            # for i in range(self.num_expert):
            #    if fwd_expert_count[i] > 0:
            #        bias.append(
            #            self.bias[i].unsqueeze(0).expand(
            #                fwd_expert_count[i], -1
            #            )
            #        )
            # bias = torch.cat(bias, dim=0)
            x = x + bias
        return x

    def extra_repr(self) -> str:
        r'''
        Return the one-line summary of the layer configuration that is shown
        when the module is printed.
        '''
        # Adjacent literals are used instead of a backslash continuation inside
        # the string, which would embed the next line's indentation in the
        # output.
        return (
            'num_expert={}, in_features={}, out_features={}, '
            'bias={}, rank={}'
        ).format(
            self.num_expert, self.in_feat, self.out_feat,
            self.bias is not None, self.rank,
        )

Rick Ho's avatar
Rick Ho committed
76

Rick Ho's avatar
Rick Ho committed
77
78
79
80
81
82
83
84
85
86
def mark_module_parallel_comm(module, comm):
    r'''
    Tag every parameter of `module` with a `dp_comm` attribute set to `comm`,
    i.e. the communicator in which the parameter does data parallel. `comm`
    may be one of `'world', 'dp', 'none'`.
    '''
    for param in module.parameters():
        param.dp_comm = comm


def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
    r'''
    Private driver of the MoE computation. It chains the following steps:
    * `moe_prepare_forward` counts the tokens going from each worker to each
    expert and yields the positions and counts used below.
    * `MOEScatter` sends the features to their target position so that the
    input features of each expert are contiguous in memory.
    * `expert_fn` performs the forward computation of the experts.
    * `MOEGather` brings the expert outputs back and reorders them as
    sentences, producing `inp.shape[0]` rows.
    Intermediate results like expert counts never leave this function.
    '''
    routing = moe_prepare_forward(gate, num_expert, world_size)
    pos, local_cnt, global_cnt, fwd_expert_count, fwd_batch_size = routing

    scattered = MOEScatter.apply(
        inp, pos, local_cnt, global_cnt, fwd_batch_size, world_size
    )
    expert_out = expert_fn(scattered, fwd_expert_count)
    return MOEGather.apply(
        expert_out, pos, local_cnt, global_cnt, inp.shape[0], world_size
    )


Rick Ho's avatar
Rick Ho committed
113
class FMoE(nn.Module):
    r'''
    A general moe implementation that supports an arbitrary module as the
    expert.
    * `num_expert` stands for the number of experts on **each** worker.
    * `world_size` stands for the total number of workers that contains
    different experts.
    * `mp_group` can be a torch's communication group, indicating that model
    parallel is applied across the group, which means that workers in the group
    hold the same copy of the input feature, and demands the same copy of the
    output. FMoE saves computation by slicing the input in the mp group and
    performing all-gather after the MLP computation.
    * `top_k` stands for the number of experts each token is going to.
    * `gate` is a gate class which can be found in `fmoe.gates`.
    * `expert` can be specified as a module class, it is used to generate
    `num_expert` expert modules.
    '''
    def __init__(self, num_expert=32, d_model=1024, world_size=1, mp_group=None,
            top_k=2, gate=NaiveGate, expert=None):
        super().__init__()
        self.num_expert = num_expert
        self.d_model = d_model
        self.world_size = world_size
        self.mp_group = mp_group
        if mp_group is None:
            # No model parallelism: behave as a group of size one.
            self.mp_size = 1
            self.mp_rank = 0
        else:
            self.mp_size = mp_group.size()
            self.mp_rank = mp_group.rank()
        self.top_k = top_k
        self.gate = gate(d_model, num_expert, world_size, top_k)
        if expert is not None:
            # One independent expert module per local expert, each built as
            # `expert(d_model)`.
            self.experts = nn.ModuleList([expert(d_model)
                for _ in range(num_expert)])
            self.experts_fused = False
        else:
            # `self.experts` is intentionally not assigned here; in fused mode
            # `expert_fn` calls it as `self.experts(inp, fwd_expert_count)`,
            # so it has to be provided elsewhere (e.g. by a child class).
            self.experts_fused = True

    def expert_fn(self, inp, fwd_expert_count):
        r'''
        The default expert function which either calls the experts as a whole
        or as separate experts.
        In the separate case, expert `i` receives the next
        `fwd_expert_count[i]` consecutive rows of `inp`, and the outputs are
        concatenated in expert order.
        '''
        if self.experts_fused:
            return self.experts(inp, fwd_expert_count)
        outputs = []
        base_idx = 0
        for i in range(self.num_expert):
            batch_size = fwd_expert_count[i].item()
            inp_slice = inp[base_idx:base_idx + batch_size]
            outputs.append(self.experts[i](inp_slice))
            base_idx += batch_size
        return torch.cat(outputs, dim=0)

    def mark_parallel_comm(self, expert_dp_comm='none'):
        r'''
        Automatically mark the data parallel comms of the parameters within the
        module. This can be typically called at the end of the __init__ function
        in child classes.
        Expert parameters are marked with `expert_dp_comm`, gate parameters
        with `'world'`.
        '''
        # `nn.Module` raises AttributeError for attributes that were never
        # assigned, so a plain `self.experts is not None` check fails when no
        # experts have been attached; look the attribute up with a default.
        experts = getattr(self, 'experts', None)
        if experts is not None:
            if isinstance(experts, list):
                for e in experts:
                    mark_module_parallel_comm(e, expert_dp_comm)
            else:
                mark_module_parallel_comm(experts, expert_dp_comm)
        mark_module_parallel_comm(self.gate, 'world')

    def forward(self, inp):
        r'''
        The FMoE module first computes gate output, and then conduct MoE forward
        according to the gate.  The score of the selected gate given by the
        expert is multiplied to the experts' output tensors as a weight.
        '''
        if self.mp_size > 1:
            # Each worker of the model parallel group handles a slice of the
            # input; the full output is restored by the all-gather below.
            inp = Slice.apply(inp,
                    self.mp_rank, self.mp_size, self.mp_group)

        gate_top_k_idx, gate_score = self.gate(inp)
        # to: (BxLxtop_k) x d_model
        inp = inp.repeat_interleave(repeats=self.top_k, dim=0)
        x = _fmoe_general_global_forward(inp, gate_top_k_idx, self.expert_fn,
                self.num_expert, self.world_size)
        # to: (BxL) x top_k x d_model
        x = x.view(-1, self.top_k, self.d_model)
        # to: (BxL) x d_model
        # NOTE(review): bmm requires gate_score to be 3-D, presumably
        # (BxL) x 1 x top_k -- confirm against the gate implementation.
        x = torch.bmm(gate_score, x).reshape(-1, self.d_model)

        if self.mp_size > 1:
            x = AllGather.apply(x,
                    self.mp_rank, self.mp_size, self.mp_group)
        return x