"vscode:/vscode.git/clone" did not exist on "2d9dd14f27c9041f47cb2bb6e9a7e6374ccd03e6"
layers.py 6.14 KB
Newer Older
Rick Ho's avatar
Rick Ho committed
1
2
3
4
r'''
Layers that FMoE provides to users
'''
import torch
Rick Ho's avatar
Rick Ho committed
5
6
import torch.nn as nn

Rick Ho's avatar
Rick Ho committed
7
8
9
from .functions import moe_prepare_forward
from .functions import MOEScatter, MOEGather, MOELinear
from .functions import AllGather
Rick Ho's avatar
Rick Ho committed
10
from .gates import NaiveGate
Rick Ho's avatar
Rick Ho committed
11

Rick Ho's avatar
Rick Ho committed
12
13

class FMoELinear(nn.Module):
    r'''
    A bank of `num_expert` bias-free linear layers stored in one weight tensor
    of shape `(num_expert, out_feat, in_feat)`.
    Because several experts may be hosted on the same worker, evaluating all of
    them through a single fused call lets their computation run in parallel.
    '''

    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024):
        super().__init__()
        self.num_expert = num_expert
        self.in_feat = in_feat
        self.out_feat = out_feat
        # Allocated uninitialized here; every expert slice is filled in by
        # reset_parameters() right below.
        self.weight = nn.Parameter(torch.Tensor(num_expert, out_feat, in_feat))
        self.reset_parameters()

    def reset_parameters(self):
        r'''
        Give each expert's `(out_feat, in_feat)` slice the default
        initialization of a freshly constructed `nn.Linear(in_feat, out_feat)`.
        '''
        for expert_idx in range(self.num_expert):
            # The temporary nn.Linear exists only to borrow its default weight
            # initialization; its bias is discarded.
            template = nn.Linear(in_features=self.in_feat,
                                 out_features=self.out_feat)
            self.weight.data[expert_idx].copy_(template.weight.data)

    def forward(self, inp, fwd_expert_count):
        r'''
        Apply every expert's weight to its contiguous slice of `inp` via the
        fused `MOELinear` autograd function; `fwd_expert_count` is forwarded
        to it unchanged.
        '''
        return MOELinear.apply(inp, self.weight, fwd_expert_count)


def _fmoe_full_forward(inp, gate, linears, activation, num_expert, world_size):
    r'''
    Private driver for the full MoE computation, hiding intermediate results
    (e.g. per-expert token counts) from the caller:
    * Use `moe_prepare_forward` on the expert assignment `gate` to count the
    tokens going from each worker to each expert.
    * Scatter the rows of `inp` so that the input of each expert is contiguous
    in memory.
    * Run the expert MLP: each module in `linears` is called with the expert
    counts, and `activation` is applied between consecutive modules only.
    * Gather the expert outputs back into the original row order of `inp`.
    '''
    prep = moe_prepare_forward(gate, num_expert, world_size)
    (pos, local_expert_count, global_expert_count,
     fwd_expert_count, fwd_batch_size) = prep

    # Dispatch rows to their target experts (possibly on other workers).
    hidden = MOEScatter.apply(inp, pos, local_expert_count,
                              global_expert_count, fwd_batch_size, world_size)

    for depth, layer in enumerate(linears):
        # No activation before the first layer; one between each pair.
        if depth > 0:
            hidden = activation(hidden)
        hidden = layer(hidden, fwd_expert_count)

    # Bring results back so they line up with the rows of `inp`.
    num_rows = inp.shape[0]
    return MOEGather.apply(hidden, pos, local_expert_count,
                           global_expert_count, num_rows, world_size)


Rick Ho's avatar
Rick Ho committed
75
class FMoETransformerMLP(nn.Module):
    r'''
    A complete MoE MLP module in a Transformer block.
    * `num_expert` stands for the number of experts on **each** worker.
    * `world_size` stands for the total number of workers that contain
    different experts.
    * `mp_group` can be a torch communication group, indicating that model
    parallelism is applied across the group: workers in the group hold the
    same copy of the input feature and expect the same copy of the output.
    FMoE saves computation by slicing the input across the mp group and
    performing an all-gather after the MLP computation. The number of tokens
    (all leading dimensions of the input flattened) must therefore be
    divisible by the group size.
    * `activation` is the activation function applied between the two linear
    layers of each expert.
    * `gate` is the gate class; it is instantiated as
    `gate(d_model, num_expert, world_size, top_k)`.
    * `top_k` stands for the number of experts each token is sent to.
    * `pre_lnorm` selects applying the layer norm to the input before the
    experts (True) instead of to the residual sum after them (False).
    '''

    def __init__(
        self,
        num_expert=32,
        d_model=1024,
        d_hidden=4096,
        world_size=1,
        mp_group=None,
        activation=torch.nn.functional.gelu,
        gate=NaiveGate,
        top_k=2,
        pre_lnorm=False
    ):
        super().__init__()
        self.num_expert = num_expert
        self.d_model = d_model
        self.d_hidden = d_hidden
        self.world_size = world_size
        self.mp_group = mp_group
        if mp_group is None:
            self.mp_size = 1
            self.mp_rank = 0
        else:
            self.mp_size = mp_group.size()
            self.mp_rank = mp_group.rank()
        self.activation = activation
        self.pre_lnorm = pre_lnorm
        self.top_k = top_k

        # Two-layer expert MLP: d_model -> d_hidden -> d_model.
        self.htoh4 = FMoELinear(num_expert, d_model, d_hidden)
        self.h4toh = FMoELinear(num_expert, d_hidden, d_model)

        # `dp_comm` is a tag attached to parameters; presumably read by the
        # data-parallel gradient-sync code elsewhere (TODO confirm). Expert
        # weights are tagged 'none' when experts differ across workers.
        if self.world_size > self.mp_size:
            for p in self.htoh4.parameters():
                setattr(p, 'dp_comm', 'none')
            for p in self.h4toh.parameters():
                setattr(p, 'dp_comm', 'none')

        self.gate = gate(d_model, num_expert, world_size, top_k)
        for p in self.gate.parameters():
            setattr(p, 'dp_comm', 'world')

        self.layer_norm = nn.LayerNorm(d_model)
        # Zero bias returned next to the output by forward(); it is never
        # added here.
        self.bias = torch.nn.parameter.Parameter(
            torch.zeros(d_model, dtype=torch.float32)
        )

    def forward(self, inp: torch.Tensor):
        r'''
        The FMoETransformerMLP module automatically performs reshape and layer
        normalization. The score of the selected gate given by the expert is
        multiplied to the experts' output tensors as a weight.

        Returns a tuple `(output, self.bias)`, where `output` has the same
        shape as `inp`.

        Raises `ValueError` when a model-parallel group is used and the
        number of tokens is not divisible by the group size (each rank must
        process an equally sized slice for the final all-gather to restore
        the input shape).
        '''
        original_shape = inp.shape
        # Flatten all leading dimensions into one token dimension.
        inp = inp.reshape(-1, self.d_model)

        if self.mp_size > 1:
            B: int = inp.shape[0]
            if B % self.mp_size != 0:
                # Previously the trailing rows were silently dropped and the
                # final reshape failed only after computation and
                # communication had already happened.
                raise ValueError(
                    f'FMoETransformerMLP: number of tokens ({B}) must be '
                    f'divisible by the model-parallel group size '
                    f'({self.mp_size})'
                )
            local_batch_size = B // self.mp_size
            batch_start = local_batch_size * self.mp_rank
            # Each rank of the mp group handles its own contiguous slice.
            inp = inp[batch_start:batch_start + local_batch_size]

        residual = inp
        if self.pre_lnorm:
            inp = self.layer_norm(inp)

        gate_top_k_idx, gate_score = self.gate(inp)

        # to: (BxLxtop_k) x d_model
        inp = inp.repeat_interleave(repeats=self.top_k, dim=0)

        x = _fmoe_full_forward(
            inp,
            gate_top_k_idx,
            [self.htoh4, self.h4toh],
            self.activation,
            self.num_expert,
            self.world_size,
        )

        # to: (BxL) x top_k x d_model
        core_out = x.view(-1, self.top_k, self.d_model)
        # to: (BxL) x 1 x d_model -- assumes gate_score is (BxL) x 1 x top_k
        # (TODO confirm against the gate implementation)
        core_out = torch.bmm(gate_score, core_out)
        output = core_out.reshape(residual.shape) + residual

        if not self.pre_lnorm:
            output = self.layer_norm(output)

        if self.mp_size > 1:
            # Reassemble the full token set from all ranks of the mp group.
            output = AllGather.apply(output,
                    self.mp_rank, self.mp_size, self.mp_group)

        return output.reshape(original_shape), self.bias