r"""
The fmoe.functions module contains functions that are directly wrapped up from
C/CUDA functions to perform distributed communication, computation, and
gradient computation.
"""

import torch
from torch.autograd import Function
import fmoe_cuda
from .utils import get_torch_default_comm


def _ensure_nccl(t, comm=None):
    r"""Make sure the NCCL communicator used by fmoe_cuda is initialized."""
    if comm is None:
        comm = get_torch_default_comm()
    fmoe_cuda.ensure_nccl(comm, t)


def count_by_gate(gate, num_expert, world_size, require_pos=True):
    r"""
    Count how many samples `gate` routes to each expert.  If `world_size > 1`,
    the per-expert counts are exchanged across workers.  When `require_pos` is
    set, also compute `pos`, the sample indices reordered so that samples
    routed to the same expert become contiguous.
    """
    with torch.no_grad():
        local_expert_count = torch.zeros(
            num_expert * world_size, device=gate.device, dtype=torch.int32
        )
        fmoe_cuda.expert_count(gate, local_expert_count)
        local_expert_count = local_expert_count.long()

        if world_size > 1:
            _ensure_nccl(gate)
            global_expert_count = fmoe_cuda.expert_exchange(
                local_expert_count, num_expert, world_size
            )
        else:
            global_expert_count = local_expert_count
        if not require_pos:
            pos = None
        else:
            lec_cum = torch.cumsum(local_expert_count, dim=0).int()
            pos_size = lec_cum[-1].item()
            pos = torch.empty((pos_size,), device=gate.device, dtype=torch.long)
            fmoe_cuda.assign_pos(lec_cum, gate, pos)
    return pos, local_expert_count, global_expert_count


def prepare_forward(gate, num_expert, world_size, comm=None):
    r"""
    Prepare necessary information from gate output for MoE computation.

    Args:
        gate: a 1-d Long Tensor representing the target expert of each input
        sample.
        num_expert: number of experts on each worker.
        world_size: number of workers that hold different experts.
        comm: the communicator of all workers in the expert-parallel group.
    """
    if world_size > 1:
        _ensure_nccl(gate, comm=comm)

    pos, local_expert_count, global_expert_count = count_by_gate(
        gate, num_expert, world_size
    )
    with torch.no_grad():
        fwd_expert_count = global_expert_count.view(world_size,
                num_expert).sum(dim=0)
        fwd_batch_size = int(fwd_expert_count.sum().item())
    return (
        pos,
        local_expert_count.cpu(),
        global_expert_count.cpu(),
        fwd_expert_count.cpu(),
        fwd_batch_size,
    )
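

# Illustrative sketch, not part of the original module: how `prepare_forward`
# is typically driven by a top-1 gate on a single worker.  It assumes a CUDA
# device and a built `fmoe_cuda` extension; the helper name
# `_demo_prepare_forward` is hypothetical.
def _demo_prepare_forward(num_expert=4, batch_size=8):
    # Each sample selects one expert; with world_size == 1 no NCCL exchange
    # takes place and the global counts equal the local counts.
    gate = torch.randint(0, num_expert, (batch_size,),
                         dtype=torch.long, device="cuda")
    pos, lec, gec, fec, fwd_bsz = prepare_forward(gate, num_expert, world_size=1)
    # `pos` reorders sample indices so that samples routed to the same expert
    # become adjacent; for top-1 gating `fwd_bsz` equals `batch_size`.
    return pos, lec, gec, fec, fwd_bsz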


def _local_scatter(inp, pos):
    r"""Reorder rows of `inp` by `pos`, grouping samples of the same expert."""
    inp_buf = torch.index_select(inp, 0, pos)
    return inp_buf


def _local_gather(inp, pos, out_batch_size, maybe_overlap=True):
    r"""
    Place rows of `inp` back at the positions given by `pos` in a
    zero-initialized buffer of `out_batch_size` rows.  When `maybe_overlap`
    is set, repeated positions are accumulated; otherwise rows are copied.
    """
    inp_buf = torch.zeros(out_batch_size, inp.shape[-1],
            dtype=inp.dtype, device=inp.device)
    if maybe_overlap:
        inp_buf.index_add_(0, pos, inp)
    else:
        inp_buf.index_copy_(0, pos, inp)
    return inp_buf
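

# Illustrative sketch, not part of the original module: when `pos` is a
# permutation of all row indices, `_local_gather` inverts `_local_scatter`,
# i.e. rows reordered into expert order are placed back at their original
# positions.  The helper name `_demo_local_scatter_gather` is hypothetical.
def _demo_local_scatter_gather():
    inp = torch.arange(12.0).view(6, 2)
    pos = torch.tensor([3, 0, 5, 1, 4, 2])
    buf = _local_scatter(inp, pos)  # buf[i] == inp[pos[i]]
    out = _local_gather(buf, pos, inp.shape[0], maybe_overlap=False)
    assert torch.equal(out, inp)  # the round trip recovers the input
    return out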


class MOEScatter(Function):
    r"""
    Scatter input samples from [batch x sequences] to contiguous alone experts.
    If `world_size` is greater than 1, the samples will first be locally
    scattered, and then exchanged across workers.
    """

    @staticmethod
    def forward(
        ctx,
        inp,
        pos,
        local_expert_count,
        global_expert_count,
        fwd_batch_size,
        world_size,
    ):
        local_input_buf = _local_scatter(inp, pos)
        if world_size > 1:
            global_input_buf = fmoe_cuda.global_scatter(
                local_input_buf,
                local_expert_count,
                global_expert_count,
                fwd_batch_size,
                world_size,
            )
        else:
            global_input_buf = local_input_buf
        ctx.moe_args = inp.shape[0], pos.shape[0], world_size
        variables = (pos, local_expert_count, global_expert_count)
        ctx.save_for_backward(*variables)
        return global_input_buf

    @staticmethod
    def backward(ctx, global_grad_in):
        (pos, local_expert_count, global_expert_count) = ctx.saved_tensors
        (inp_batch_size, buf_batch_size, world_size) = ctx.moe_args

        if world_size > 1:
            local_grad_in = fmoe_cuda.global_gather(
                global_grad_in,
                local_expert_count,
                global_expert_count,
                buf_batch_size,
                world_size,
            )
        else:
            local_grad_in = global_grad_in
        grad_in = _local_gather(local_grad_in, pos, inp_batch_size)
        return grad_in, None, None, None, None, None
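

# Illustrative sketch, not part of the original module: applying `MOEScatter`
# on a single worker, reusing the routing metadata from `prepare_forward`.
# A CUDA device and the `fmoe_cuda` extension are assumed; the helper name
# `_demo_moe_scatter` is hypothetical.
def _demo_moe_scatter(inp, gate, num_expert):
    pos, lec, gec, fec, fwd_bsz = prepare_forward(gate, num_expert, world_size=1)
    # With world_size == 1 this reduces to reordering the rows of `inp` by
    # `pos`; with more workers it would also perform the all-to-all exchange.
    scattered = MOEScatter.apply(inp, pos, lec, gec, fwd_bsz, 1)
    return scattered, pos, fec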


class MOELinear(Function):
    r"""
    Computes linear operators within one GPU on different experts simutaneously.
    """

    @staticmethod
    def forward(ctx, global_input_buf, fwd_expert_count, weight, bias=None):
        global_output_buf = fmoe_cuda.linear_forward(
            global_input_buf, fwd_expert_count, weight, bias
        )
        variables = (global_input_buf, fwd_expert_count, weight, bias)
        ctx.save_for_backward(*variables)
        return global_output_buf

    @staticmethod
    def backward(ctx, grad_out):
        (input_buf, fwd_expert_count, weight, bias) = ctx.saved_tensors
        grad_inp_buf, grad_weight, grad_bias = fmoe_cuda.linear_backward(
            grad_out, input_buf, fwd_expert_count, weight, bias
        )

        if not torch.is_tensor(bias):
            grad_bias = None

        return grad_inp_buf, None, grad_weight, grad_bias
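

# Illustrative sketch, not part of the original module: running `MOELinear`
# over an expert-contiguous buffer.  The per-expert weight layout
# `(num_expert, out_feat, in_feat)` is an assumption about the
# `fmoe_cuda.linear_forward` kernel, and `_demo_moe_linear` is hypothetical.
def _demo_moe_linear(global_input_buf, fwd_expert_count, num_expert, out_feat):
    in_feat = global_input_buf.shape[-1]
    weight = torch.randn(num_expert, out_feat, in_feat,
                         device=global_input_buf.device)
    # The first `fwd_expert_count[0]` rows go through `weight[0]`, the next
    # `fwd_expert_count[1]` rows through `weight[1]`, and so on; bias omitted.
    return MOELinear.apply(global_input_buf, fwd_expert_count, weight, None)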


class MOEGather(Function):
    r"""
    Gather output samples from contiguous alone experts back to [batch x
    sequences]. Works symmetrically with MOEScatter.
    """

    @staticmethod
    def forward(
        ctx,
        global_output_buf,
        pos,
        local_expert_count,
        global_expert_count,
        local_batch_size,
        world_size,
    ):
        if world_size > 1:
            local_output_buf = fmoe_cuda.global_gather(
                global_output_buf,
                local_expert_count,
                global_expert_count,
                pos.shape[0],
                world_size,
            )
        else:
            local_output_buf = global_output_buf
        output = _local_gather(local_output_buf, pos, local_batch_size,
                maybe_overlap=False)

        ctx.moe_args = (global_output_buf.shape[0], world_size)
        variables = (pos, local_expert_count, global_expert_count)
        ctx.save_for_backward(*variables)
        return output

    @staticmethod
    def backward(ctx, grad_out):
        pos, local_expert_count, global_expert_count = ctx.saved_tensors
        fwd_batch_size, world_size = ctx.moe_args
        grad_out_buf = _local_scatter(grad_out.contiguous(), pos)
        if world_size > 1:
            global_grad_out_buf = fmoe_cuda.global_scatter(
                grad_out_buf,
                local_expert_count,
                global_expert_count,
                fwd_batch_size,
                world_size,
            )
        else:
            global_grad_out_buf = grad_out_buf
        return global_grad_out_buf, None, None, None, None, None
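

# Illustrative end-to-end sketch, not part of the original module: a
# single-worker MoE forward that chains `prepare_forward`, `MOEScatter`,
# `MOELinear` and `MOEGather`.  A CUDA device, the `fmoe_cuda` extension and a
# per-expert weight of shape `(num_expert, out_feat, in_feat)` are assumed;
# `_demo_single_worker_moe` is hypothetical.
def _demo_single_worker_moe(inp, gate, weight):
    num_expert = weight.shape[0]
    pos, lec, gec, fec, fwd_bsz = prepare_forward(gate, num_expert, world_size=1)
    buf = MOEScatter.apply(inp, pos, lec, gec, fwd_bsz, 1)
    buf = MOELinear.apply(buf, fec, weight, None)
    # Gathering with `inp.shape[0]` restores the original sample order.
    return MOEGather.apply(buf, pos, lec, gec, inp.shape[0], 1)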


class AllGather(Function):
    r"""
    A wrapper for the All-Gather function to support auto-differentiation.
    """

    @staticmethod
    def forward(ctx, inp, rank, world_size, group):
        tensor_list = [torch.empty_like(inp) for _ in range(world_size)]
        torch.distributed.all_gather(tensor_list, inp, group=group)
        torch.cuda.synchronize()
        output = torch.cat(tensor_list, dim=0)
        ctx.args = rank, inp.shape[0]
        return output

    @staticmethod
    def backward(ctx, grad_out):
        rank, dim0 = ctx.args
        return grad_out[rank * dim0 : (rank + 1) * dim0], None, None, None
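

# Illustrative sketch, not part of the original module: using `AllGather` to
# collect a tensor from every rank of a data-parallel group while keeping it
# differentiable.  An initialized `torch.distributed` process group is
# assumed; `_demo_all_gather` is hypothetical.
def _demo_all_gather(local_tensor):
    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
    group = torch.distributed.group.WORLD
    # Forward concatenates every rank's rows; backward slices the incoming
    # gradient back down to this rank's own rows.
    return AllGather.apply(local_tensor, rank, world_size, group)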


class Slice(Function):
    r"""
    A wrapper for the Slice function to support auto-differentiation.
    """

    @staticmethod
    def forward(ctx, inp, rank, world_size, group):
        B: int = inp.shape[0]
        local_batch_size = B // world_size
        batch_start = local_batch_size * rank
        batch_end = min(batch_start + local_batch_size, B)
        inp = inp[batch_start:batch_end]
        ctx.args = world_size, group
        return inp

    @staticmethod
    def backward(ctx, grad_out):
        world_size, group = ctx.args
        tensor_list = [torch.empty_like(grad_out) for _ in range(world_size)]
        torch.distributed.all_gather(tensor_list, grad_out, group=group)
        torch.cuda.synchronize()
        grad_out = torch.cat(tensor_list, dim=0)
        return grad_out, None, None, None
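

# Illustrative sketch, not part of the original module: `Slice` is the
# counterpart of `AllGather` above, keeping only this rank's shard of a
# replicated batch in the forward pass and all-gathering the gradient in the
# backward pass.  An initialized `torch.distributed` process group and a batch
# divisible by the world size are assumed; `_demo_slice` is hypothetical.
def _demo_slice(replicated_batch):
    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
    group = torch.distributed.group.WORLD
    return Slice.apply(replicated_batch, rank, world_size, group)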