layers.py 6.73 KB
Newer Older
Rick Ho's avatar
Rick Ho committed
1
2
3
4
r'''
Layers that FMoE provides to users
'''
import torch
Rick Ho's avatar
Rick Ho committed
5
6
import torch.nn as nn

Rick Ho's avatar
Rick Ho committed
7
8
9
from .functions import moe_prepare_forward
from .functions import MOEScatter, MOEGather, MOELinear
from .functions import AllGather
Rick Ho's avatar
Rick Ho committed
10
from .gates import NaiveGate
Rick Ho's avatar
Rick Ho committed
11

Rick Ho's avatar
Rick Ho committed
12
13

class FMoELinear(nn.Module):
    r'''
    A bank of independent linear transforms, one per local expert.

    The weights of all experts hosted on this worker are stored in a single
    parameter of shape `(num_expert, out_feat, in_feat)`, so they can be
    evaluated together by `MOELinear` rather than one expert at a time.
    Only weights are kept; there is no bias term.
    '''

    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024):
        super().__init__()
        self.num_expert, self.in_feat, self.out_feat = num_expert, in_feat, out_feat
        # Allocated uninitialised on purpose: reset_parameters() overwrites
        # every expert slice right below.
        self.weight = nn.Parameter(torch.Tensor(num_expert, out_feat, in_feat))
        self.reset_parameters()

    def reset_parameters(self):
        r'''
        Give each expert's weight slice exactly the initial values that a
        freshly constructed `nn.Linear(in_feat, out_feat)` would receive.
        '''
        with torch.no_grad():
            for slot in range(self.num_expert):
                # A throw-away nn.Linear is built per expert purely to reuse
                # its default weight initialisation.
                reference = nn.Linear(in_features=self.in_feat,
                        out_features=self.out_feat)
                self.weight[slot].copy_(reference.weight)

    def forward(self, inp, fwd_expert_count):
        r'''
        Apply the per-expert linear maps through the `MOELinear` autograd
        function. `fwd_expert_count` is passed through unchanged; presumably
        it holds how many consecutive rows of `inp` belong to each expert
        (confirm against `MOELinear`).
        '''
        weight = self.weight
        return MOELinear.apply(inp, weight, fwd_expert_count)


Rick Ho's avatar
Rick Ho committed
44
45
46
47
48
49
50
51
52
53
def mark_module_parallel_comm(module, comm):
    r'''
    Tag every parameter reachable from `module` with a `dp_comm` attribute
    naming the communicator its gradients should be data-parallel reduced
    over. `comm` is expected to be one of `'world'`, `'dp'` or `'none'`.
    '''
    for param in module.parameters():
        param.dp_comm = comm


def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
    r'''
    Private driver for one MoE forward pass.

    1. `moe_prepare_forward` turns the gate's expert indices into routing
       metadata (positions, local/global per-expert counts, the forward
       per-expert counts and the size of the scattered batch).
    2. `MOEScatter` moves the input rows to their target experts so that each
       expert's inputs are contiguous.
    3. `expert_fn(scattered, fwd_expert_count)` runs the experts.
    4. `MOEGather` sends the expert outputs back and restores a batch of
       `inp.shape[0]` rows in the original order.

    The routing metadata never escapes this function.
    '''
    (pos, local_count, global_count,
     fwd_count, fwd_batch_size) = moe_prepare_forward(gate, num_expert, world_size)

    scattered = MOEScatter.apply(inp, pos, local_count, global_count,
            fwd_batch_size, world_size)
    expert_out = expert_fn(scattered, fwd_count)

    # The gathered result has as many rows as the (already top_k-expanded)
    # input this function received.
    out_rows = inp.shape[0]
    return MOEGather.apply(expert_out, pos, local_count, global_count,
            out_rows, world_size)


Rick Ho's avatar
Rick Ho committed
80
81

class FMoE(nn.Module):
    r'''
    A general MoE implementation that supports an arbitrary module as the
    expert. Either `expert` or `expert_fn` is required.
    * `num_expert` stands for the number of experts on **each** worker.
    * `d_model` is the feature size of each token, both in and out.
    * `world_size` stands for the total number of workers that contain
    different experts.
    * `mp_group` can be a torch communication group, indicating that model
    parallelism is applied across the group. This means that workers in the
    group hold the same copy of the input feature and expect the same copy of
    the output. FMoE saves computation by slicing the input within the mp
    group and performing an all-gather after the MLP computation.
    * `top_k` stands for the number of experts each token is sent to.
    * `gate` is a gate class, such as those found in `fmoe.gates`. It is
    instantiated as `gate(d_model, num_expert, world_size, top_k)`.
    * `expert` can be specified as a module class. It is called as
    `expert(d_model)` to build `num_expert` expert modules, which are
    registered in an `nn.ModuleList` so that their parameters are trained,
    moved and saved together with this module.
    * `expert_fn` is specified as a callable object or a function. It is
    called during forward as `expert_fn(inp, fwd_expert_count)`: the
    (contiguous) input tensor and the array of the number of input rows for
    each expert. When `expert_fn` is given, `expert` is ignored and
    `self.experts` starts as `None` (subclasses may assign their own module).
    '''

    def __init__(self, num_expert=32, d_model=1024, world_size=1, mp_group=None,
            top_k=2, gate=NaiveGate, expert=None, expert_fn=None):
        super().__init__()
        self.num_expert = num_expert
        self.d_model = d_model
        self.world_size = world_size
        self.mp_group = mp_group
        if mp_group is None:
            self.mp_size = 1
            self.mp_rank = 0
        else:
            self.mp_size = mp_group.size()
            self.mp_rank = mp_group.rank()
        self.top_k = top_k
        self.gate = gate(d_model, num_expert, world_size, top_k)
        if expert_fn is None:
            assert expert is not None, 'Either expert or expert_fn should be set'
            # nn.ModuleList rather than a plain list: a plain list hides the
            # experts' parameters from .parameters()/optimizers, .to()/.cuda()
            # and state_dict().
            self.experts = nn.ModuleList(
                [expert(d_model) for _ in range(num_expert)])
            # Store the *bound* method so it has the same two-argument call
            # signature `(inp, fwd_expert_count)` as a user-supplied expert_fn.
            expert_fn = self._default_expert_fn
        else:
            # Always define the attribute so mark_parallel_comm() can test it;
            # assigning a module here later (e.g. in a subclass) still works.
            self.experts = None
        self.expert_fn = expert_fn

    def _default_expert_fn(self, inp, fwd_expert_count):
        r'''
        Run each module in `self.experts` on its contiguous slice of `inp`.
        Expert `i` consumes the next `fwd_expert_count[i]` rows, and the
        per-expert outputs are concatenated along dim 0 in expert order.
        '''
        outputs = []
        base_idx = 0
        for i in range(self.num_expert):
            batch_size = fwd_expert_count[i].item()
            inp_slice = inp[base_idx:base_idx + batch_size]
            outputs.append(self.experts[i](inp_slice))
            base_idx += batch_size
        return torch.cat(outputs, dim=0)

    def mark_parallel_comm(self):
        r'''
        Automatically mark the data parallel comms of the parameters within the
        module. This can typically be called at the end of the __init__ function
        in child classes.
        Experts are marked `'none'` when `world_size > mp_size`, otherwise
        `'dp'`. The gate is always marked `'world'`.
        '''
        if self.experts is not None:
            if self.world_size > self.mp_size:
                comm = 'none'
            else:
                comm = 'dp'
            # A plain list is still accepted for subclasses that assign one.
            if isinstance(self.experts, list):
                for e in self.experts:
                    mark_module_parallel_comm(e, comm)
            else:
                mark_module_parallel_comm(self.experts, comm)
        mark_module_parallel_comm(self.gate, 'world')

    def forward(self, inp):
        r'''
        The FMoE module first computes the gate output, and then conducts the
        MoE forward according to the gate. Each token's top_k expert outputs
        are combined by a batched matrix product with the gate scores.
        '''
        if self.mp_size > 1:
            # Every rank of the model-parallel group holds the same batch and
            # processes only its own contiguous chunk of rows; AllGather at the
            # end presumably reassembles the full batch.
            # NOTE(review): when B is not divisible by mp_size the trailing
            # B % mp_size rows fall outside every rank's chunk — confirm that
            # callers guarantee divisibility.
            B: int = inp.shape[0]
            local_batch_size = B // self.mp_size
            batch_start = local_batch_size * self.mp_rank
            batch_end = min(batch_start + local_batch_size, B)
            inp = inp[batch_start:batch_end]

        gate_top_k_idx, gate_score = self.gate(inp)
        # to: (BxLxtop_k) x d_model
        inp = inp.repeat_interleave(repeats=self.top_k, dim=0)
        x = _fmoe_general_global_forward(inp, gate_top_k_idx, self.expert_fn,
                self.num_expert, self.world_size)
        # to: (BxL) x top_k x d_model
        x = x.view(-1, self.top_k, self.d_model)
        # to: (BxL) x d_model
        x = torch.bmm(gate_score, x).reshape(-1, self.d_model)

        if self.mp_size > 1:
            x = AllGather.apply(x,
                    self.mp_rank, self.mp_size, self.mp_group)
        return x