layers.py 7.03 KB
Newer Older
Rick Ho's avatar
Rick Ho committed
1
2
3
4
r'''
Layers that FMoE provides to users
'''
import torch
Rick Ho's avatar
Rick Ho committed
5
import torch.nn as nn
Rick Ho's avatar
Rick Ho committed
6
import torch.nn.functional as F
Rick Ho's avatar
Rick Ho committed
7

Rick Ho's avatar
Rick Ho committed
8
9
10
11
from .functions import moe_prepare_forward
from .functions import MOEScatter, MOEGather, MOELinear
from .functions import AllGather

Rick Ho's avatar
Rick Ho committed
12
13

class FMoELinear(nn.Module):
    r'''
    A bank of ``num_expert`` bias-free linear layers held in one parameter.

    Several experts may be placed on the same worker. Storing all of their
    weights in a single ``(num_expert, out_feat, in_feat)`` tensor lets the
    per-expert matrix products be dispatched together through ``MOELinear``
    instead of looping over separate ``nn.Linear`` modules.
    '''

    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024):
        super().__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        # Allocated uninitialised; reset_parameters() fills every expert slot.
        self.weight = nn.Parameter(torch.Tensor(num_expert, out_feat, in_feat))
        self.reset_parameters()

    def reset_parameters(self):
        r'''
        Initialise each expert's weight exactly as a freshly constructed
        ``nn.Linear(in_feat, out_feat)`` would initialise its own weight.
        '''
        dest = self.weight.data
        for expert_idx in range(self.num_expert):
            # A throw-away nn.Linear is built only to borrow its default
            # initialisation; its bias is discarded.
            template = nn.Linear(in_features=self.in_feat,
                                 out_features=self.out_feat)
            dest[expert_idx] = template.weight.data

    def forward(self, inp, fwd_expert_count):
        r'''
        Apply the expert weights to ``inp`` via the ``MOELinear`` autograd
        function. ``fwd_expert_count`` is presumably the number of rows of
        ``inp`` belonging to each local expert (rows grouped by expert) —
        confirm against ``MOELinear``.
        '''
        return MOELinear.apply(inp, self.weight, fwd_expert_count)


Rick Ho's avatar
Rick Ho committed
44
class FMoENaiveGate(nn.Module):
    r'''
    A naive gate that decides which experts each token is sent to.

    One linear layer scores all ``num_expert * world_size`` experts, and the
    ``top_k`` highest scores per token are kept. Both the selected expert
    indices and their scores (confidences) are returned to the parent module.
    Load-balancing strategies are also meant to live inside a gate module.
    '''

    def __init__(self, d_model, num_expert, world_size, top_k=2):
        super().__init__()
        total_experts = num_expert * world_size
        self.gate = nn.Linear(d_model, total_experts)
        self.top_k = top_k

    def forward(self, inp):
        r'''
        Score every expert with a linear layer and keep the top-k per token.

        With N = number of tokens (all leading dimensions of ``inp``
        flattened), returns:
        * expert indices, flattened to shape ``(N * top_k,)``, token-major;
        * softmax-normalised scores over the kept experts, ``(N, 1, top_k)``.
        '''
        k = self.top_k
        logits = self.gate(inp)
        # Order within the top-k is irrelevant, because indices and scores
        # are permuted consistently.
        top_val, top_idx = logits.topk(k, dim=-1, largest=True, sorted=False)
        flat_val = top_val.view(-1, k)
        weights = F.softmax(flat_val, dim=-1)
        flat_idx = top_idx.view(-1)
        return flat_idx, weights.unsqueeze(1)


def _fmoe_full_forward(inp, gate, linears, activation, num_expert, world_size):
    r'''
    Run the complete MoE computation for a batch of already-gated tokens.

    Steps:
    * Count how many tokens each worker sends to each expert
      (``moe_prepare_forward``).
    * Scatter the features so that the input rows of every expert are
      contiguous in memory (``MOEScatter``).
    * Run the expert MLP: the modules in ``linears`` are applied in order,
      with ``activation`` inserted between consecutive layers (not before
      the first one nor after the last one).
    * Gather the expert outputs back into the original token order
      (``MOEGather``).
    Intermediate bookkeeping such as the expert counts stays hidden inside
    this function.
    '''
    (
        pos,
        local_expert_count,
        global_expert_count,
        fwd_expert_count,
        fwd_batch_size,
    ) = moe_prepare_forward(gate, num_expert, world_size)

    hidden = MOEScatter.apply(
        inp, pos, local_expert_count, global_expert_count, fwd_batch_size,
        world_size,
    )

    for depth, expert_layer in enumerate(linears):
        layer_input = hidden if depth == 0 else activation(hidden)
        hidden = expert_layer(layer_input, fwd_expert_count)

    num_tokens = inp.shape[0]
    return MOEGather.apply(
        hidden, pos, local_expert_count, global_expert_count, num_tokens,
        world_size,
    )


Rick Ho's avatar
Rick Ho committed
107
class FMoETransformerMLP(nn.Module):
    r'''
    A complete MoE MLP module in a Transformer block.
    * `num_expert` stands for the number of experts on **each** worker.
    * `d_model` / `d_hidden` are the outer and inner widths of each expert's
    two-layer MLP (d_model -> d_hidden -> d_model).
    * `world_size` stands for the total number of workers that contain
    different experts.
    * `mp_group` can be a torch communication group, indicating that model
    parallelism is applied across the group: workers in the group hold the
    same copy of the input feature and demand the same copy of the output.
    FMoE saves computation by slicing the input across the mp group and
    performing an all-gather after the MLP computation. In that case the
    number of tokens must be divisible by the size of the group.
    * `activation` is the activation function applied between the two linear
    layers of each expert.
    * `top_k` stands for the number of experts each token is going to.
    * `pre_lnorm` applies the layer norm to the input (pre-norm) instead of
    to the sum of the MoE output and the residual (post-norm).
    '''

    def __init__(
        self,
        num_expert=32,
        d_model=1024,
        d_hidden=4096,
        world_size=1,
        mp_group=None,
        activation=torch.nn.functional.gelu,
        top_k=2,
        pre_lnorm=False
    ):
        super().__init__()
        self.num_expert = num_expert
        self.d_model = d_model
        self.d_hidden = d_hidden
        self.world_size = world_size
        self.mp_group = mp_group
        if mp_group is None:
            self.mp_size = 1
            self.mp_rank = 0
        else:
            self.mp_size = mp_group.size()
            self.mp_rank = mp_group.rank()
        self.activation = activation
        self.pre_lnorm = pre_lnorm
        self.top_k = top_k

        self.htoh4 = FMoELinear(num_expert, d_model, d_hidden)
        self.h4toh = FMoELinear(num_expert, d_hidden, d_model)

        self.gate = FMoENaiveGate(d_model, num_expert, world_size, top_k)
        # Tag the gate parameters with dp_comm='world'. This presumably tells
        # an external data-parallel wrapper to synchronise them across the
        # whole world rather than a subgroup; confirm with that wrapper.
        for p in self.gate.parameters():
            setattr(p, 'dp_comm', 'world')

        self.layer_norm = nn.LayerNorm(d_model)
        # Returned next to the output by forward() but never added to it
        # here; presumably the caller applies it (an (output, bias) pair).
        self.bias = torch.nn.parameter.Parameter(
            torch.zeros(d_model, dtype=torch.float32)
        )

    def forward(self, inp: torch.Tensor):
        r'''
        The FMoETransformerMLP module automatically performs reshape and layer
        normalization. The score of each selected expert, given by the gate,
        is multiplied onto that expert's output tensor as a weight.

        `inp` may have any leading dimensions; it is flattened to
        (tokens, d_model). Returns `(output, self.bias)`, where `output` has
        the same shape as `inp`.

        Raises ValueError when model parallelism is active and the number of
        tokens is not divisible by the model-parallel group size. Otherwise
        the trailing tokens would never be computed, and the final reshape
        would fail.
        '''
        original_shape = inp.shape
        inp = inp.reshape(-1, self.d_model)

        if self.mp_size > 1:
            B: int = inp.shape[0]
            if B % self.mp_size != 0:
                raise ValueError(
                    f'FMoETransformerMLP: the number of tokens ({B}) must be '
                    f'divisible by the model-parallel group size '
                    f'({self.mp_size})'
                )
            # Every rank in the mp group works on its own equal-sized slice.
            local_batch_size = B // self.mp_size
            batch_start = local_batch_size * self.mp_rank
            batch_end = batch_start + local_batch_size
            inp = inp[batch_start:batch_end]

        residual = inp
        if self.pre_lnorm:
            inp = self.layer_norm(inp)

        gate_top_k_idx, gate_score = self.gate(inp)

        # to: (BxLxtop_k) x d_model -- one copy per selected expert, token-major
        # to match the flattened gate indices.
        inp = inp.repeat_interleave(repeats=self.top_k, dim=0)

        x = _fmoe_full_forward(
            inp,
            gate_top_k_idx,
            [self.htoh4, self.h4toh],
            self.activation,
            self.num_expert,
            self.world_size,
        )

        # to: (BxL) x top_k x d_model
        core_out = x.view(-1, self.top_k, self.d_model)
        # to: (BxL) x 1 x d_model -- score-weighted sum over the top_k experts
        core_out = torch.bmm(gate_score, core_out)
        output = core_out.reshape(residual.shape) + residual

        if not self.pre_lnorm:
            output = self.layer_norm(output)

        if self.mp_size > 1:
            output = AllGather.apply(output,
                    self.mp_rank, self.mp_size, self.mp_group)

        return output.reshape(original_shape), self.bias