layers.py 8.58 KB
Newer Older
Rick Ho's avatar
Rick Ho committed
1
2
3
r'''
Layers that FMoE provides to users
'''
Jiezhong Qiu's avatar
Jiezhong Qiu committed
4
import math
Rick Ho's avatar
Rick Ho committed
5
import torch
Rick Ho's avatar
Rick Ho committed
6
import torch.nn as nn
7
import numpy as np
Rick Ho's avatar
Rick Ho committed
8

Rick Ho's avatar
Rick Ho committed
9
10
from .functions import moe_prepare_forward
from .functions import MOEScatter, MOEGather, MOELinear
Sengxian's avatar
Sengxian committed
11
from .functions import AllGather, Slice
Rick Ho's avatar
Rick Ho committed
12
from .gates import NaiveGate
Rick Ho's avatar
Rick Ho committed
13

Rick Ho's avatar
Rick Ho committed
14
15

class FMoELinear(nn.Module):
    r'''
    A linear layer that contains multiple experts.
    As multiple experts can be placed on the same worker, the computation can be
    performed in parallel to increase the performance.
    The FMoELinear module provides such function.

    * `num_expert` is the number of experts held by this module.
    * `in_feat` / `out_feat` are the input / output feature sizes shared by
    every expert.
    * `bias` controls whether each expert owns a bias vector.
    * `rank` is added to the seed used by `reset_parameters`, so different
    ranks draw different initial values.
    '''

    def __init__(self, num_expert: int, in_feat: int, out_feat: int,
            bias: bool = True, rank: int = 0):
        super().__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.rank = rank
        # One (out_feat x in_feat) matrix per expert; the storage is left
        # uninitialized here and filled by reset_parameters() below.
        self.weight = nn.Parameter(torch.Tensor(num_expert, out_feat, in_feat))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(num_expert, out_feat))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        r'''
        Initialize the weight as linear layers: kaiming-uniform with
        `a=sqrt(5)` computed on a single expert's matrix, and the bias (if
        any) from U(-1/sqrt(fan_in), 1/sqrt(fan_in)).
        Values are drawn from a numpy Generator seeded with a draw from
        numpy's global RNG plus `self.rank`.
        '''
        rng = np.random.default_rng(np.random.randint(2048) + self.rank)

        # copied from torch.nn.init.kaiming_uniform_
        fan = nn.init._calculate_correct_fan(self.weight[0], 'fan_in')
        gain = nn.init.calculate_gain('leaky_relu', math.sqrt(5))
        std = gain / math.sqrt(fan)
        bound = math.sqrt(3.0) * std

        device = self.weight.device
        dtype = self.weight.dtype
        weight = rng.uniform(-bound, bound, size=tuple(self.weight.size()))
        self.weight.data = torch.tensor(weight, dtype=dtype, device=device)

        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[0])
            bound = 1 / math.sqrt(fan_in)
            bias = rng.uniform(-bound, bound, size=tuple(self.bias.size()))
            self.bias.data = torch.tensor(bias, dtype=dtype, device=device)

    def forward(self, inp, fwd_expert_count):
        r'''
        Call MOE function.
        The first `fwd_expert_count[0]` rows of `inp` go to expert 0, the next
        `fwd_expert_count[1]` rows to expert 1, and so on, so the rows must
        already be grouped by expert.
        '''
        x = MOELinear.apply(inp, self.weight, fwd_expert_count)
        if self.bias is not None:
            # Expand each expert's bias to one row per token routed to it.
            # TODO: torch.repeat_interleave seems to have numerical
            # instability in backward, possibly leading to incorrect
            # gradients. Expanding the bias with a for-loop avoids this but
            # is ~50% slower. This part should finally go to
            # MOELinear.apply, like MOELinear.apply(x, weight, bias, count).
            bias = torch.repeat_interleave(self.bias,
                fwd_expert_count.to(self.bias.device), dim=0)
            x = x + bias
        return x

    def extra_repr(self) -> str:
        # Implicit concatenation keeps the repr free of the source
        # indentation that a backslash continuation inside the literal
        # would embed.
        return ('num_expert={}, in_features={}, out_features={}, '
                'bias={}, rank={}').format(
            self.num_expert, self.in_feat,
            self.out_feat, self.bias is not None, self.rank
        )

Rick Ho's avatar
Rick Ho committed
101

Rick Ho's avatar
Rick Ho committed
102
103
104
105
106
107
108
109
110
111
def mark_module_parallel_comm(module, comm):
    r'''
    Tag every parameter of `module` with a `dp_comm` attribute equal to
    `comm`, i.e. the communicator it is data-parallel in. `comm` may be one
    of `'world', 'dp', 'none'`.
    '''
    for param in module.parameters():
        param.dp_comm = comm


def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
    r'''
    Private helper running one complete MoE pass over `inp`, routed by the
    expert indices in `gate`:
    * `moe_prepare_forward` counts the tokens each worker sends to each expert
    and returns the `pos` indices used by scatter and gather.
    * `MOEScatter` moves the features so that the input of every expert is
    contiguous in memory.
    * `expert_fn(features, fwd_expert_count)` runs the experts.
    * `MOEGather` brings the expert outputs back and restores the original
    order, producing `inp.shape[0]` rows.
    The intermediate expert counts never leave this function.
    '''
    prepared = moe_prepare_forward(gate, num_expert, world_size)
    pos, local_count, global_count, fwd_count, fwd_batch_size = prepared

    scattered = MOEScatter.apply(
        inp, pos, local_count, global_count, fwd_batch_size, world_size
    )
    expert_out = expert_fn(scattered, fwd_count)

    n_rows = inp.shape[0]
    return MOEGather.apply(
        expert_out, pos, local_count, global_count, n_rows, world_size
    )


Rick Ho's avatar
Rick Ho committed
138
class FMoE(nn.Module):
    r'''
    A general moe implementation that supports an arbitrary module as the
    expert.
    * `num_expert` stands for the number of experts on **each** worker.
    * `world_size` stands for the total number of workers that contain
    different experts.
    * `mp_group` can be a torch's communication group, indicating that model
    parallel is applied across the group, which means that workers in the group
    hold the same copy of the input feature, and demand the same copy of the
    output. FMoE saves computation by slicing the input in the mp group and
    performing all-gather after the MLP computation.
    * `top_k` stands for the number of experts each token is going to.
    * `gate` is a gate class which can be found in `fmoe.gates`.
    * `expert` can be specified as a module class, it is used to generate
    `num_expert` expert modules. When it is `None` the experts are treated as
    fused: `self.experts` is then expected to be assigned elsewhere (e.g. by
    a subclass) to a module called as `self.experts(inp, fwd_expert_count)`.
    '''
    def __init__(self, num_expert=32, d_model=1024, world_size=1, mp_group=None,
            top_k=2, gate=NaiveGate, expert=None):
        super().__init__()
        self.num_expert = num_expert
        self.d_model = d_model
        self.world_size = world_size
        self.mp_group = mp_group
        if mp_group is None:
            self.mp_size = 1
            self.mp_rank = 0
        else:
            self.mp_size = mp_group.size()
            self.mp_rank = mp_group.rank()
        self.top_k = top_k
        self.gate = gate(d_model, num_expert, world_size, top_k)
        if expert is not None:
            self.experts = nn.ModuleList([expert(d_model)
                for _ in range(num_expert)])
            self.experts_fused = False
        else:
            self.experts_fused = True

    def expert_fn(self, inp, fwd_expert_count):
        r'''
        Run the local experts on `inp`, whose rows are grouped by expert:
        the first `fwd_expert_count[0]` rows belong to expert 0, and so on.
        Fused experts receive the whole batch in a single call.
        '''
        if self.experts_fused:
            return self.experts(inp, fwd_expert_count)
        outputs = []
        base_idx = 0
        for i in range(self.num_expert):
            batch_size = fwd_expert_count[i].item()
            inp_slice = inp[base_idx:base_idx + batch_size]
            outputs.append(self.experts[i](inp_slice))
            base_idx += batch_size
        return torch.cat(outputs, dim=0)

    def mark_parallel_comm(self, expert_dp_comm='none'):
        r'''
        Automatically mark the data parallel comms of the parameters within the
        module. This can be typically called at the end of the __init__ function
        in child classes.
        Expert parameters are marked with `expert_dp_comm`; gate parameters are
        always marked `'world'`.
        '''
        # With fused experts `__init__` never assigns `self.experts`; use
        # getattr so the gate is still marked instead of raising
        # AttributeError when no subclass has provided the experts.
        experts = getattr(self, 'experts', None)
        if experts is not None:
            comm = expert_dp_comm
            if isinstance(experts, list):
                for e in experts:
                    mark_module_parallel_comm(e, comm)
            else:
                mark_module_parallel_comm(experts, comm)
        mark_module_parallel_comm(self.gate, 'world')

    def forward(self, inp):
        r'''
        The FMoE module first computes the gate output, and then conducts the
        MoE forward according to the gate. The outputs of each token's `top_k`
        experts are combined as a weighted sum using the gate scores.
        '''
        if self.mp_size > 1:
            # Each model-parallel rank only processes its slice of the batch.
            inp = Slice.apply(inp,
                    self.mp_rank, self.mp_size, self.mp_group)

        gate_top_k_idx, gate_score = self.gate(inp)
        # to: (BxLxtop_k) x d_model
        inp = inp.repeat_interleave(repeats=self.top_k, dim=0)
        x = _fmoe_general_global_forward(inp, gate_top_k_idx, self.expert_fn,
                self.num_expert, self.world_size)
        # to: (BxL) x top_k x d_model
        x = x.view(-1, self.top_k, self.d_model)
        # to: (BxL) x d_model
        # NOTE(review): gate_score is presumably (BxL) x 1 x top_k, as
        # required for this bmm to produce one row per token.
        x = torch.bmm(gate_score, x).reshape(-1, self.d_model)

        if self.mp_size > 1:
            # Reassemble the full batch across the model-parallel group.
            x = AllGather.apply(x,
                    self.mp_rank, self.mp_size, self.mp_group)
        return x