r"""
The fmoe.functions module contains functions that directly wrap C/CUDA
functions to perform distributed communication, computation and gradient
computation.
"""

import torch
from torch.autograd import Function
import fmoe_cuda
from .utils import get_torch_default_comm


def _ensure_nccl(t, comm=None):
    r"""
    Hand ``t`` and a communicator to ``fmoe_cuda.ensure_nccl``.

    Args:
        t: tensor forwarded to ``fmoe_cuda.ensure_nccl``.
        comm: communicator to use; when ``None``, the one returned by
        ``get_torch_default_comm()`` is used instead.
    """
    chosen_comm = comm if comm is not None else get_torch_default_comm()
    fmoe_cuda.ensure_nccl(chosen_comm, t)


def count_by_gate(gate, num_expert, world_size, comm=None):
    r"""
    Count how many input samples are routed to each expert.

    Args:
        gate: 1-d integer tensor holding the global expert index chosen for
        each input sample. Values must lie in
        ``[0, num_expert * world_size)`` because they index the count tensor.
        num_expert: number of experts on each worker.
        world_size: number of workers that hold different experts.
        comm: communicator handed to ``_ensure_nccl`` when
        ``world_size > 1``. ``None`` (the default) keeps the previous
        behavior of using the default communicator.

    Returns:
        A tuple ``(pos, local_expert_count, global_expert_count)``:

        - ``pos``: the index permutation that sorts ``gate``.
        - ``local_expert_count``: long tensor of length
          ``num_expert * world_size`` with this worker's sample count per
          expert.
        - ``global_expert_count``: the output of
          ``fmoe_cuda.expert_exchange`` on those counts when
          ``world_size > 1``; otherwise the very same tensor as
          ``local_expert_count``.
    """
    # TODO: support -1 in gate, which means ignore this input
    with torch.no_grad():
        _, pos = torch.sort(gate)
        gate_idx, gate_count = torch.unique(gate, return_counts=True)
        local_expert_count = torch.zeros(
            num_expert * world_size, device=gate.device, dtype=torch.long
        )
        # Experts that received no sample keep their zero count.
        local_expert_count.index_put_((gate_idx.long(),), gate_count)

        if world_size > 1:
            _ensure_nccl(gate, comm=comm)
            (global_expert_count,) = fmoe_cuda.expert_exchange(
                local_expert_count, num_expert, world_size
            )
        else:
            global_expert_count = local_expert_count
    return pos, local_expert_count, global_expert_count


def prepare_forward(gate, num_expert, world_size, comm=None):
    r"""
    Prepare necessary information from gate output for MoE computation.

    Args:
        gate: a 1-d Long Tensor representing the target expert of each input
        sample.
        num_expert: number of experts on each worker.
        world_size: number of workers that hold different experts.
        comm: the communicator of all workers in the expert-parallel group.

    Returns:
        A tuple ``(pos, local_expert_count, global_expert_count,
        fwd_expert_count, fwd_batch_size)``:

        - ``pos``: the sorting permutation of ``gate``, left on its device.
        - ``local_expert_count`` / ``global_expert_count``: the per-expert
          counts from ``count_by_gate``, copied to the CPU.
        - ``fwd_expert_count``: ``global_expert_count`` viewed as
          ``(world_size, num_expert)`` and summed over workers (one count
          per local expert), copied to the CPU.
        - ``fwd_batch_size``: the total of ``fwd_expert_count`` as a Python
          ``int``.
    """
    if world_size > 1:
        # NOTE(review): count_by_gate calls _ensure_nccl again without a
        # communicator; this presumably relies on the first initialization
        # (with ``comm``) taking effect -- confirm in fmoe_cuda.ensure_nccl.
        _ensure_nccl(gate, comm=comm)

    pos, local_expert_count, global_expert_count = count_by_gate(
        gate, num_expert, world_size
    )
    with torch.no_grad():
        # Collapse the per-worker dimension so each local expert gets the
        # number of samples sent to it from all workers.
        fwd_expert_count = global_expert_count.view(
            world_size, num_expert
        ).sum(dim=0)
        # .item() forces a device sync; the size is needed as a host int.
        fwd_batch_size = int(fwd_expert_count.sum().item())
    return (
        pos,
        local_expert_count.cpu(),
        global_expert_count.cpu(),
        fwd_expert_count.cpu(),
        fwd_batch_size,
    )


class MOEScatter(Function):
    r"""
    Scatter input samples from [batch x sequences] to be contiguous along
    experts. If `world_size` is greater than 1, the samples will first be
    locally scattered, and then exchanged across workers.
    """

    @staticmethod
    def forward(
        ctx,
        inp,
        pos,
        local_expert_count,
        global_expert_count,
        fwd_batch_size,
        world_size,
    ):
        r"""
        Reorder ``inp`` by ``pos`` and, for ``world_size > 1``, exchange the
        result across workers.

        Args:
            inp: input samples; ``inp.shape[0]`` is recorded as the local
            batch size for the backward pass.
            pos: ordering indices passed to ``fmoe_cuda.local_scatter``
            (e.g. the ``pos`` produced by ``prepare_forward``).
            local_expert_count, global_expert_count: per-expert counts
            passed to ``fmoe_cuda.global_scatter`` and saved for backward.
            fwd_batch_size: size passed to ``fmoe_cuda.global_scatter``.
            world_size: number of expert-parallel workers.

        Returns:
            The scattered buffer (the local one when ``world_size == 1``).
        """
        (local_input_buf,) = fmoe_cuda.local_scatter(inp, pos)
        if world_size > 1:
            (global_input_buf,) = fmoe_cuda.global_scatter(
                local_input_buf,
                local_expert_count,
                global_expert_count,
                fwd_batch_size,
                world_size,
            )
        else:
            global_input_buf = local_input_buf
        # backward hands inp.shape[0] to global_gather as the local batch size.
        ctx.moe_args = inp.shape[0], world_size
        ctx.save_for_backward(pos, local_expert_count, global_expert_count)
        return global_input_buf

    @staticmethod
    def backward(ctx, global_grad_in):
        r"""
        Undo the scatter on the gradient: gather across workers (if any),
        then restore the original sample order. Only ``inp`` gets a gradient.
        """
        (pos, local_expert_count, global_expert_count) = ctx.saved_tensors
        (local_batch_size, world_size) = ctx.moe_args
        # Autograd may deliver non-contiguous gradients; normalize the layout
        # before the fmoe_cuda kernels, mirroring the grad_out.contiguous()
        # call in MOEGather.backward.
        global_grad_in = global_grad_in.contiguous()

        if world_size > 1:
            (local_grad_in,) = fmoe_cuda.global_gather(
                global_grad_in,
                local_expert_count,
                global_expert_count,
                local_batch_size,
                world_size,
            )
        else:
            local_grad_in = global_grad_in
        (grad_in,) = fmoe_cuda.local_gather(local_grad_in, pos)
        return grad_in, None, None, None, None, None


class MOELinear(Function):
    r"""
    Computes linear operators within one GPU on different experts
    simultaneously.
    """

    @staticmethod
    def forward(ctx, global_input_buf, weight, fwd_expert_count):
        r"""
        Run ``fmoe_cuda.linear_forward`` on the expert-ordered input buffer.

        Args:
            global_input_buf: samples grouped by expert (e.g. the output of
            ``MOEScatter``).
            weight: expert weights; layout is whatever
            ``fmoe_cuda.linear_forward`` expects -- TODO confirm.
            fwd_expert_count: per-expert sample counts (e.g. from
            ``prepare_forward``).

        Returns:
            The output buffer produced by ``fmoe_cuda.linear_forward``.
        """
        (global_output_buf,) = fmoe_cuda.linear_forward(
            global_input_buf, weight, fwd_expert_count
        )
        ctx.save_for_backward(global_input_buf, weight, fwd_expert_count)
        return global_output_buf

    @staticmethod
    def backward(ctx, grad_out):
        r"""
        Return gradients for the input buffer and the weight; the counts get
        no gradient.
        """
        (input_buf, weight, fwd_expert_count) = ctx.saved_tensors
        # Autograd may deliver non-contiguous gradients; normalize the layout
        # before the CUDA kernel, as MOEGather.backward does for its grad_out.
        grad_out = grad_out.contiguous()
        grad_inp_buf, grad_weight = fmoe_cuda.linear_backward(
            grad_out, input_buf, weight, fwd_expert_count
        )
        return grad_inp_buf, grad_weight, None


class MOEGather(Function):
    r"""
    Gather output samples from being contiguous along experts back to
    [batch x sequences]. Works symmetrically with MOEScatter.
    """

    @staticmethod
    def forward(
        ctx,
        global_output_buf,
        pos,
        local_expert_count,
        global_expert_count,
        local_batch_size,
        world_size,
    ):
        r"""
        Gather ``global_output_buf`` back from other workers (if any), then
        restore the original sample order with ``pos``.

        Args:
            global_output_buf: expert-ordered outputs; its ``shape[0]`` is
            recorded for the backward scatter.
            pos: ordering indices passed to ``fmoe_cuda.local_gather``.
            local_expert_count, global_expert_count: per-expert counts
            passed to ``fmoe_cuda.global_gather`` and saved for backward.
            local_batch_size: size passed to ``fmoe_cuda.global_gather``.
            world_size: number of expert-parallel workers.

        Returns:
            The outputs in the original sample order.
        """
        if world_size > 1:
            (local_output_buf,) = fmoe_cuda.global_gather(
                global_output_buf,
                local_expert_count,
                global_expert_count,
                local_batch_size,
                world_size,
            )
        else:
            local_output_buf = global_output_buf
        (output,) = fmoe_cuda.local_gather(local_output_buf, pos)

        # backward hands this size to global_scatter as the forward batch size.
        ctx.moe_args = (global_output_buf.shape[0], world_size)
        ctx.save_for_backward(pos, local_expert_count, global_expert_count)
        return output

    @staticmethod
    def backward(ctx, grad_out):
        r"""
        Scatter the gradient back to expert order (and across workers when
        ``world_size > 1``). Only ``global_output_buf`` gets a gradient.
        """
        pos, local_expert_count, global_expert_count = ctx.saved_tensors
        fwd_batch_size, world_size = ctx.moe_args
        (grad_out_buf,) = fmoe_cuda.local_scatter(grad_out.contiguous(), pos)
        if world_size > 1:
            (global_grad_out_buf,) = fmoe_cuda.global_scatter(
                grad_out_buf,
                local_expert_count,
                global_expert_count,
                fwd_batch_size,
                world_size,
            )
        else:
            global_grad_out_buf = grad_out_buf
        return global_grad_out_buf, None, None, None, None, None


class AllGather(Function):
    r"""
    A wrapper for the All-Gather function to support auto-differentiation.

    Forward concatenates ``inp`` from every worker of ``group`` along dim 0.
    Every worker's ``inp`` must share the same shape, because the receive
    buffers are allocated with ``torch.empty_like(inp)``. Backward returns
    only this worker's rows of the incoming gradient.
    """

    @staticmethod
    def forward(ctx, inp, rank, world_size, group):
        # The NCCL/Gloo all_gather implementations reject non-contiguous
        # tensors; .contiguous() is a no-op when inp already is.
        inp = inp.contiguous()
        tensor_list = [torch.empty_like(inp) for _ in range(world_size)]
        torch.distributed.all_gather(tensor_list, inp, group=group)
        torch.cuda.synchronize()
        output = torch.cat(tensor_list, dim=0)
        ctx.args = rank, inp.shape[0]
        return output

    @staticmethod
    def backward(ctx, grad_out):
        rank, dim0 = ctx.args
        # NOTE(review): the local slice is taken without summing across
        # workers, presumably because every worker computes an identical
        # gradient for the gathered tensor -- confirm with the caller.
        return grad_out[rank * dim0 : (rank + 1) * dim0], None, None, None


class Slice(Function):
    r"""
    A wrapper for the Slice function to support auto-differentiation.

    Forward keeps this worker's equal share of ``inp`` along dim 0; backward
    all-gathers the gradients of every worker's share back into the full
    batch. The batch size must be divisible by ``world_size``: an uneven
    split would drop trailing samples, and the backward all-gather needs
    equally sized tensors.
    """

    @staticmethod
    def forward(ctx, inp, rank, world_size, group):
        B: int = inp.shape[0]
        if B % world_size != 0:
            raise ValueError(
                f"Slice: batch size {B} is not divisible by "
                f"world_size {world_size}"
            )
        local_batch_size = B // world_size
        batch_start = local_batch_size * rank
        ctx.args = world_size, group
        return inp[batch_start : batch_start + local_batch_size]

    @staticmethod
    def backward(ctx, grad_out):
        world_size, group = ctx.args
        # The NCCL/Gloo all_gather implementations reject non-contiguous
        # tensors, and autograd may hand us one.
        grad_out = grad_out.contiguous()
        tensor_list = [torch.empty_like(grad_out) for _ in range(world_size)]
        torch.distributed.all_gather(tensor_list, grad_out, group=group)
        torch.cuda.synchronize()
        return torch.cat(tensor_list, dim=0), None, None, None