r"""
FMoE core layer
"""
import torch
import torch.nn as nn

from .functions import prepare_forward, ensure_comm
from .functions import MOEScatter, MOEGather
from .functions import AllGather, Slice
from .gates import NaiveGate



def mark_module_parallel_comm(module, comm):
    r"""
    Mark all parameters in `module` as being data-parallel within the
    communication group `comm`, where `comm` may be one of
    `'world', 'dp', 'none'`.
    """
    for p in module.parameters():
        setattr(p, "dp_comm", comm)
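
# Illustrative usage (a sketch only, not executed anywhere in this module):
# after marking, every parameter of `expert` carries a `dp_comm` attribute
# that a data-parallel training wrapper can inspect to decide which process
# group should synchronize it.
#
#     expert = nn.Linear(1024, 1024)
#     mark_module_parallel_comm(expert, "dp")
#     assert all(p.dp_comm == "dp" for p in expert.parameters())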


def _fmoe_general_global_forward(inp, gate, expert_fn, num_expert, world_size):
    r"""
    A private function that performs the following steps to complete the MoE
    computation.
    * Count the number of tokens from each worker to each expert.
    * Send the features to their target position so that input features to each
    expert are contiguous in memory.
    * Perform the forward computation of the experts using `expert_fn`.
    * Gather the output features of the experts back, and reorder them into
    the original order of the input tokens.
    Intermediate results like expert counts are hidden from users by this
    function.
    """
    (
        pos,
        local_expert_count,
        global_expert_count,
        fwd_expert_count,
        fwd_batch_size,
    ) = prepare_forward(gate, num_expert, world_size)
    # infer top-k from the gate's shape: a 2-d gate index tensor is (batch, top_k)
    topk = 1
    if len(gate.shape) == 2:
        topk = gate.shape[1]
    x = MOEScatter.apply(
        inp, pos // topk,
        local_expert_count, global_expert_count, fwd_batch_size, world_size
    )
    x = expert_fn(x, fwd_expert_count)

    out_batch_size = inp.shape[0]
    if len(gate.shape) == 2:
        out_batch_size *= gate.shape[1]

    x = MOEGather.apply(
        x, pos,
        local_expert_count, global_expert_count,
        out_batch_size, world_size
    )
    return x
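
# A rough sketch of how the helper above is driven (illustrative only; the
# real call site is `FMoE.forward` below).  The scatter/gather autograd
# functions require the compiled fmoe_cuda extension, and `inp` / `batch_size`
# are assumed to be defined by the caller.
#
#     def toy_expert_fn(x, fwd_expert_count):
#         return x * 2  # a single toy "expert" applied to all of its tokens
#
#     gate_idx = torch.zeros(batch_size, 1, dtype=torch.int64)  # route all to expert 0
#     out = _fmoe_general_global_forward(inp, gate_idx, toy_expert_fn,
#                                        num_expert=1, world_size=1)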


class FMoE(nn.Module):
    r"""
    A general MoE implementation that supports an arbitrary module as the
    expert.
    * `num_expert` stands for the number of experts on **each** worker.
    * `world_size` stands for the total number of workers that contain
    different experts.
    * `slice_group` can be a torch communication group, indicating that a
    model-parallel scheme is applied across the group: workers in the group
    hold the same copy of the input features and require the same copy of the
    output. For each worker, FMoE only computes the output of a certain slice
    of the input batch, and the outputs are all-gathered after computation.
    * `top_k` stands for the number of experts each token is routed to.
    * `gate` is a gate class which can be found in `fmoe.gates`.
    * `expert` can be specified as a module class; it is used to generate
    `num_expert` expert modules.
    """

    def __init__(
        self,
        num_expert=32,
        d_model=1024,
        world_size=1,
        mp_group=None,  # being deprecated
        slice_group=None,
        moe_group=None,
        top_k=2,
        gate=NaiveGate,
        expert=None,
        gate_hook=None,
        mask=None,
        mask_dict=None,
    ):
        super().__init__()
        self.num_expert = num_expert
        self.d_model = d_model
        self.world_size = world_size

        self.slice_group = slice_group
        if mp_group is not None:
            print('[Warning] mp_group is being deprecated')
            self.slice_group = mp_group
        if self.slice_group is None:
            self.slice_size = 1
            self.slice_rank = 0
        else:
            # use self.slice_group so the deprecated mp_group fallback also works
            self.slice_size = self.slice_group.size()
            self.slice_rank = self.slice_group.rank()

        self.top_k = top_k
        if type(expert) is list:
            self.experts = nn.ModuleList([e(d_model) for e in expert])
            self.experts_fused = False
            self.num_expert = num_expert = len(expert)
        elif expert is not None:
            self.experts = nn.ModuleList([expert(d_model)
                for _ in range(num_expert)])
            self.experts_fused = False
        else:
            self.experts_fused = True

        self.gate = gate(d_model, num_expert, world_size, top_k)
        self.gate_hook = gate_hook
        self.mask = mask
        self.mask_dict = mask_dict
        self.moe_group = moe_group

    def expert_fn(self, inp, fwd_expert_count):
        r"""
        The default expert function, which either calls the fused experts as a
        whole or iterates over the separate expert modules.
        """
        if self.experts_fused:
            return self.experts(inp, fwd_expert_count)
        outputs = []
        base_idx = 0
        for i in range(self.num_expert):
            batch_size = fwd_expert_count[i].item()
            inp_slice = inp[base_idx : base_idx + batch_size]
            outputs.append(self.experts[i](inp_slice))
            base_idx += batch_size
        return torch.cat(outputs, dim=0)

    def mark_parallel_comm(self, expert_dp_comm="none"):
        r"""
        Automatically mark the data-parallel comms of the parameters within the
        module. This can typically be called at the end of the __init__ function
        in child classes.
        """
        if self.experts is not None:
            comm = expert_dp_comm
            if isinstance(self.experts, list):
                for e in self.experts:
                    mark_module_parallel_comm(e, comm)
            else:
                mark_module_parallel_comm(self.experts, comm)
        mark_module_parallel_comm(self.gate, "gate")
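
    # Typical use from a subclass (a sketch; `MyMoELayer` and its expert setup
    # are hypothetical):
    #
    #     class MyMoELayer(FMoE):
    #         def __init__(self, num_expert, d_model, **kwargs):
    #             super().__init__(num_expert, d_model, **kwargs)
    #             # ... build or replace self.experts here ...
    #             self.mark_parallel_comm(expert_dp_comm="none")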

    def forward(self, inp):
        r"""
        The FMoE module first computes the gate output, and then conducts the
        MoE forward pass according to the gate.  The gate scores of the
        selected experts are multiplied onto the experts' output tensors as
        weights.
        """
        if self.world_size > 1:
            ensure_comm(inp, self.moe_group)
        if self.slice_size > 1:
            inp = Slice.apply(inp, self.slice_rank,
                    self.slice_size, self.slice_group)

        gate_top_k_idx, gate_score = self.gate(inp)

        if self.gate_hook is not None:
            self.gate_hook(gate_top_k_idx, gate_score, None)

        # delete masked tensors
        if self.mask is not None and self.mask_dict is not None:
            mask = self.mask.view(-1)
            # to: (BxL') x d_model
            inp = inp[mask == 0, :]
            gate_top_k_idx = gate_top_k_idx[mask == 0, :]

        fwd = _fmoe_general_global_forward(
            inp, gate_top_k_idx,
            self.expert_fn, self.num_expert, self.world_size
        )

        # recover deleted tensors
        if self.mask is not None and self.mask_dict is not None:
            # to: (BxL') x top_k x d_model
            fwd = fwd.view(-1, self.top_k, self.d_model)
            # to: (BxL) x top_k x d_model
            x = torch.zeros(mask.shape[0], self.top_k, self.d_model,
                            device=fwd.device, dtype=fwd.dtype)
            # recover
            x[mask == 0] = fwd
            for k, v in self.mask_dict.items():
                x[mask == k] = v
        else:
            x = fwd.view(-1, self.top_k, self.d_model)

        # combine the top-k expert outputs with their gate scores:
        # (N, 1, top_k) bmm (N, top_k, d_model) -> (N, 1, d_model) -> (N, d_model)
        gate_score = gate_score.view(x.shape[0], 1, self.top_k)
        x = torch.bmm(gate_score, x).reshape(-1, self.d_model)

        if self.slice_size > 1:
            x = AllGather.apply(x, self.slice_rank,
                    self.slice_size, self.slice_group)
        return x