functions.py 8.33 KB
Newer Older
1
2
3
4
5
6
r"""
The fmoe.functions module contains functions that directly wrap C/CUDA
functions to perform distributed communication, computation and gradient
computation.
"""

Rick Ho's avatar
Rick Ho committed
7
8
9
import torch
from torch.autograd import Function
import fmoe_cuda
10
from .utils import get_torch_default_comm
Rick Ho's avatar
Rick Ho committed
11
12


Rick Ho's avatar
Rick Ho committed
13
14
15
16
17
18
def _ensure_nccl(t, comm=None):
    r"""
    Hand a communicator and tensor ``t`` to ``fmoe_cuda.ensure_nccl``.

    Presumably this makes the extension's NCCL state ready for use with
    ``t`` (inferred from the name — TODO confirm against the C++ side).

    Args:
        t: tensor forwarded unchanged to ``fmoe_cuda.ensure_nccl``.
        comm: communicator to use; when ``None`` the result of
        ``get_torch_default_comm()`` is used instead.
    """
    chosen_comm = get_torch_default_comm() if comm is None else comm
    fmoe_cuda.ensure_nccl(chosen_comm, t)


19
def count_by_gate(gate, num_expert, world_size, require_pos=True):
    r"""
    Count how many samples the gate routes to each expert and, optionally,
    compute the scatter positions of the samples.

    Args:
        gate: Long Tensor of target expert indices. It is flattened with
        ``view(-1)``; entries equal to -1 are ignored when counting.
        num_expert: number of experts on each worker.
        world_size: number of workers that hold different experts.
        require_pos: when False, ``pos`` is not computed and ``None`` is
        returned in its place.

    Returns:
        ``(pos, local_expert_count, global_expert_count)``.

        - ``local_expert_count`` is a long tensor with
          ``num_expert * world_size`` entries counting this worker's samples
          per expert.
        - ``global_expert_count`` is the output of
          ``fmoe_cuda.expert_exchange`` when ``world_size > 1``, and the very
          same tensor object as ``local_expert_count`` otherwise.
        - ``pos`` is a long tensor whose length equals the number of
          non-ignored gate entries. It is filled in place by
          ``fmoe_cuda.assign_pos_``.
    """
    with torch.no_grad():
        flat = gate.view(-1)
        routed = flat[flat != -1]

        local_expert_count = torch.zeros(
            num_expert * world_size, device=gate.device, dtype=torch.long
        )
        # Add a 1 into the slot of the expert each routed sample targets.
        local_expert_count.index_add_(
            0,
            routed,
            torch.ones(routed.numel(), device=gate.device, dtype=torch.long),
        )

        if world_size <= 1:
            global_expert_count = local_expert_count
        else:
            _ensure_nccl(gate)
            (global_expert_count,) = fmoe_cuda.expert_exchange(
                local_expert_count, num_expert, world_size
            )

        pos = None
        if require_pos:
            # assign_pos_ is given the inclusive prefix sum as int32; its last
            # entry is the total number of routed samples.
            cum_count = torch.cumsum(local_expert_count, dim=0).int()
            total_routed = cum_count[-1].item()
            pos = torch.empty((total_routed,), device=gate.device, dtype=torch.long)
            fmoe_cuda.assign_pos_(cum_count, gate, pos)
    return pos, local_expert_count, global_expert_count


def prepare_forward(gate, num_expert, world_size, comm=None):
    r"""
    Prepare necessary information from gate output for MoE computation.

    Args:
        gate: Long Tensor giving the target expert of each input sample.
        num_expert: number of experts on each worker.
        world_size: number of workers that hold different experts.
        comm: the communicator of all workers in the expert-parallel group.

    Returns:
        ``(pos, local_expert_count, global_expert_count, fwd_expert_count,
        fwd_batch_size)``:

        - ``pos`` is as produced by ``count_by_gate``.
        - The three count tensors are moved to CPU. ``fwd_expert_count`` is
          ``global_expert_count`` viewed as ``(world_size, num_expert)`` and
          summed over workers.
        - ``fwd_batch_size`` is the sum of ``fwd_expert_count`` as a Python
          int.
    """
    if world_size > 1:
        # Initialise with the caller's communicator up front: count_by_gate
        # calls _ensure_nccl without passing a communicator.
        _ensure_nccl(gate, comm=comm)

    pos, local_expert_count, global_expert_count = count_by_gate(
        gate, num_expert, world_size
    )

    with torch.no_grad():
        counts_per_worker = global_expert_count.view(world_size, num_expert)
        fwd_expert_count = counts_per_worker.sum(dim=0)
        fwd_batch_size = int(fwd_expert_count.sum().item())

    local_cpu = local_expert_count.cpu()
    global_cpu = global_expert_count.cpu()
    fwd_cpu = fwd_expert_count.cpu()
    return (pos, local_cpu, global_cpu, fwd_cpu, fwd_batch_size)
Rick Ho's avatar
Rick Ho committed
75
76


77
78
79
80
81
def _local_scatter(inp, pos):
    inp_buf = torch.index_select(inp, 0, pos)
    return inp_buf


Rick Ho's avatar
Rick Ho committed
82
def _local_gather(inp, pos, out_batch_size, maybe_overlap=True):
83
84
    inp_buf = torch.zeros(out_batch_size, inp.shape[-1],
            dtype=inp.dtype, device=inp.device)
Rick Ho's avatar
Rick Ho committed
85
86
87
88
    if maybe_overlap:
        inp_buf.index_add_(0, pos, inp)
    else:
        inp_buf.index_copy_(0, pos, inp)
89
90
91
    return inp_buf


Rick Ho's avatar
Rick Ho committed
92
class MOEScatter(Function):
    r"""
    Scatter input samples from [batch x sequences] into buffers that are
    contiguous per expert.

    If `world_size` is greater than 1, the samples are first scattered
    locally and then exchanged across workers.
    """

    @staticmethod
    def forward(
        ctx,
        inp,
        pos,
        local_expert_count,
        global_expert_count,
        fwd_batch_size,
        world_size,
    ):
        r"""
        Args:
            inp: input samples, one row per sample along dim 0.
            pos: Long Tensor of row indices into ``inp`` (see
                ``count_by_gate``).
            local_expert_count: per-expert counts of this worker's samples.
            global_expert_count: per-expert counts after the cross-worker
                exchange.
            fwd_batch_size: batch size given to ``fmoe_cuda.global_scatter``.
            world_size: number of workers; no exchange happens when it is 1.

        Returns:
            The expert-ordered input buffer.
        """
        local_input_buf = _local_scatter(inp, pos)
        if world_size > 1:
            (global_input_buf,) = fmoe_cuda.global_scatter(
                local_input_buf,
                local_expert_count,
                global_expert_count,
                fwd_batch_size,
                world_size,
            )
        else:
            global_input_buf = local_input_buf
        # Shapes needed to size the buffers in backward.
        ctx.moe_args = inp.shape[0], pos.shape[0], world_size
        variables = (pos, local_expert_count, global_expert_count)
        ctx.save_for_backward(*variables)
        return global_input_buf

    @staticmethod
    def backward(ctx, global_grad_in):
        r"""
        Route gradients back to the original sample order.

        This mirrors ``forward``: a cross-worker gather when
        ``world_size > 1``, then a local gather into ``inp``'s batch size.
        Gradients of samples that were scattered to several rows are summed.
        """
        (pos, local_expert_count, global_expert_count) = ctx.saved_tensors
        (inp_batch_size, buf_batch_size, world_size) = ctx.moe_args

        if world_size > 1:
            # Autograd may hand us a non-contiguous (e.g. expanded) gradient.
            # The CUDA extension is given the tensor directly (presumably
            # reading it as a dense buffer — TODO confirm), so make it
            # contiguous first, as MOEGather.backward does. This is a no-op
            # for already-contiguous gradients.
            (local_grad_in,) = fmoe_cuda.global_gather(
                global_grad_in.contiguous(),
                local_expert_count,
                global_expert_count,
                buf_batch_size,
                world_size,
            )
        else:
            local_grad_in = global_grad_in
        grad_in = _local_gather(local_grad_in, pos, inp_batch_size)
        # One gradient slot per forward input; only `inp` is differentiable.
        return grad_in, None, None, None, None, None


class MOELinear(Function):
    r"""
    Computes linear operators on different experts simultaneously within one
    GPU, delegating both passes to the ``fmoe_cuda`` extension.
    """

    @staticmethod
    def forward(ctx, global_input_buf, fwd_expert_count, weight, bias=None):
        r"""
        Args:
            global_input_buf: expert-ordered input buffer (as produced by
                ``MOEScatter``).
            fwd_expert_count: number of rows belonging to each local expert
                (``prepare_forward`` returns it on CPU).
            weight: per-expert weights. Presumably stacked along dim 0 by
                expert — TODO confirm against ``fmoe_cuda.linear_forward``.
            bias: optional per-expert bias; may be ``None``.

        Returns:
            The first (and only) output of ``fmoe_cuda.linear_forward``.
        """
        (out_buf,) = fmoe_cuda.linear_forward(
            global_input_buf, fwd_expert_count, weight, bias
        )
        ctx.save_for_backward(global_input_buf, fwd_expert_count, weight, bias)
        return out_buf

    @staticmethod
    def backward(ctx, grad_out):
        r"""
        Return gradients for (input buffer, expert counts, weight, bias).

        The expert counts get no gradient. The bias gradient is dropped when
        no bias tensor was supplied in forward.
        """
        saved_input, saved_counts, saved_weight, saved_bias = ctx.saved_tensors
        grad_input, grad_weight, grad_bias = fmoe_cuda.linear_backward(
            grad_out, saved_input, saved_counts, saved_weight, saved_bias
        )
        bias_was_tensor = torch.is_tensor(saved_bias)
        return (
            grad_input,
            None,
            grad_weight,
            grad_bias if bias_was_tensor else None,
        )
Rick Ho's avatar
Rick Ho committed
169
170
171


class MOEGather(Function):
    r"""
    Gather output samples from per-expert contiguous buffers back to
    [batch x sequences]. Works symmetrically with MOEScatter.
    """

    @staticmethod
    def forward(
        ctx,
        global_output_buf,
        pos,
        local_expert_count,
        global_expert_count,
        local_batch_size,
        world_size,
    ):
        r"""
        Args:
            global_output_buf: expert-ordered output buffer.
            pos: Long Tensor of destination rows for the local gather.
            local_expert_count: per-expert counts of this worker's samples.
            global_expert_count: per-expert counts after the cross-worker
                exchange.
            local_batch_size: batch size given to ``fmoe_cuda.global_gather``
                and number of rows of the returned tensor.
            world_size: number of workers; no exchange happens when it is 1.

        Returns:
            A tensor with ``local_batch_size`` rows, filled with
            ``index_copy_`` (``maybe_overlap=False``).
        """
        exchanged = global_output_buf
        if world_size > 1:
            (exchanged,) = fmoe_cuda.global_gather(
                global_output_buf,
                local_expert_count,
                global_expert_count,
                local_batch_size,
                world_size,
            )

        gathered = _local_gather(
            exchanged, pos, local_batch_size, maybe_overlap=False
        )

        # backward needs the expert-buffer row count to size the re-scatter.
        ctx.moe_args = (global_output_buf.shape[0], world_size)
        ctx.save_for_backward(pos, local_expert_count, global_expert_count)
        return gathered

    @staticmethod
    def backward(ctx, grad_out):
        r"""
        Send gradients back into expert order.

        This is the mirror of ``forward``: a local scatter followed, when
        ``world_size > 1``, by a cross-worker scatter.
        """
        pos, local_expert_count, global_expert_count = ctx.saved_tensors
        fwd_batch_size, world_size = ctx.moe_args

        grad_buf = _local_scatter(grad_out.contiguous(), pos)
        if world_size > 1:
            (grad_buf,) = fmoe_cuda.global_scatter(
                grad_buf,
                local_expert_count,
                global_expert_count,
                fwd_batch_size,
                world_size,
            )
        # One gradient slot per forward input; only the buffer is differentiable.
        return grad_buf, None, None, None, None, None
Rick Ho's avatar
Rick Ho committed
221
222
223


class AllGather(Function):
    r"""
    A wrapper for the All-Gather function to support auto-differentiation.

    Forward concatenates every worker's ``inp`` along dim 0. Backward returns
    only this rank's slice of the incoming gradient.
    """

    @staticmethod
    def forward(ctx, inp, rank, world_size, group):
        pieces = [torch.empty_like(inp) for _ in range(world_size)]
        torch.distributed.all_gather(pieces, inp, group=group)
        torch.cuda.synchronize()
        # Remember where this rank's rows sit in the concatenated output.
        ctx.args = rank, inp.shape[0]
        return torch.cat(pieces, dim=0)

    @staticmethod
    def backward(ctx, grad_out):
        rank, rows = ctx.args
        begin = rank * rows
        return grad_out[begin : begin + rows], None, None, None
Sengxian's avatar
Sengxian committed
241
242
243


class Slice(Function):
    r"""
    A wrapper for the Slice function to support auto-differentiation.

    Forward keeps this rank's ``B // world_size`` rows of ``inp``. Backward
    all-gathers the gradient slices from every worker and concatenates them
    along dim 0.
    """

    @staticmethod
    def forward(ctx, inp, rank, world_size, group):
        total_rows = inp.shape[0]
        rows_per_rank = total_rows // world_size
        begin = rows_per_rank * rank
        end = min(begin + rows_per_rank, total_rows)
        ctx.args = world_size, group
        return inp[begin:end]

    @staticmethod
    def backward(ctx, grad_out):
        world_size, group = ctx.args
        pieces = [torch.empty_like(grad_out) for _ in range(world_size)]
        torch.distributed.all_gather(pieces, grad_out, group=group)
        torch.cuda.synchronize()
        return torch.cat(pieces, dim=0), None, None, None