functions.py
r"""
The fmoe.functions module contains functions that are directly wrapped from
C/CUDA functions to perform distributed communication, computation, and
gradient computation.
"""

import torch
from torch.autograd import Function
import fmoe_cuda
from .utils import get_torch_default_comm


def count_by_gate(gate, num_expert, world_size, comm=None):
    # TODO: support -1 in gate, which means ignore this input
    with torch.no_grad():
        # Sort sample indices by their target expert so that samples routed to
        # the same expert become contiguous.
        _, pos = torch.sort(gate)
        # Count how many local samples go to each of the
        # num_expert * world_size global experts.
        gate_idx, gate_count = torch.unique(gate, return_counts=True)
        local_expert_count = torch.zeros(
            num_expert * world_size, device=gate.device, dtype=torch.long
        )
        local_expert_count.index_put_((gate_idx.long(),), gate_count)

        if world_size > 1:
            # Exchange the per-expert counts with the other workers so that
            # each worker learns how many samples it will receive.
            (global_expert_count,) = fmoe_cuda.expert_exchange(
                local_expert_count, num_expert, world_size
            )
        else:
            global_expert_count = local_expert_count
    return pos, local_expert_count, global_expert_count
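
# Illustrative sketch (not part of the original module): counting routed
# samples on a single worker. The gate values and expert count below are
# assumptions chosen only for demonstration.
#
#     gate = torch.tensor([1, 0, 1, 1], device="cuda")
#     pos, local_cnt, global_cnt = count_by_gate(
#         gate, num_expert=2, world_size=1, comm=None
#     )
#     # pos       -> sample indices sorted by expert: tensor([1, 0, 2, 3])
#     # local_cnt -> samples per expert on this worker: tensor([1, 3])
#     # global_cnt equals local_cnt when world_size == 1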



def prepare_forward(gate, num_expert, world_size, comm=None):
    r"""
    Prepare necessary information from gate output for MoE computation.

    Args:
        gate: a 1-d Long Tensor representing the target expert of each input
            sample.
        num_expert: number of experts on each worker.
        world_size: number of workers that hold different experts.
        comm: the communicator of all workers in the expert-parallel group.
    """
    if world_size > 1:
        if comm is None:
            comm = get_torch_default_comm()
        fmoe_cuda.ensure_nccl(comm, gate)

    pos, local_expert_count, global_expert_count = count_by_gate(
        gate, num_expert, world_size
    )
    with torch.no_grad():
        fwd_expert_count = global_expert_count.view(world_size,
                num_expert).sum(dim=0)
        fwd_batch_size = int(fwd_expert_count.sum().item())
    return (
        pos,
        local_expert_count.cpu(),
        global_expert_count.cpu(),
        fwd_expert_count.cpu(),
        fwd_batch_size,
    )
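
# Illustrative sketch (not part of the original module): a single-worker call
# to prepare_forward. The gate values and expert count below are assumptions
# chosen only for demonstration.
#
#     gate = torch.tensor([0, 2, 1, 2], device="cuda")
#     pos, local_cnt, global_cnt, fwd_cnt, fwd_bs = prepare_forward(
#         gate, num_expert=4, world_size=1
#     )
#     # pos     -> sample indices sorted by expert: tensor([0, 2, 1, 3])
#     # fwd_cnt -> samples per local expert (on CPU): tensor([1, 1, 2, 0])
#     # fwd_bs  -> total samples this worker will process: 4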


class MOEScatter(Function):
    r"""
    Scatter input samples from the [batch x sequence] order into buffers that
    are grouped contiguously by expert.
    If `world_size` is greater than 1, the samples are first scattered locally
    and then exchanged across workers.
    """

    @staticmethod
    def forward(
        ctx,
        inp,
        pos,
        local_expert_count,
        global_expert_count,
        fwd_batch_size,
        world_size,
    ):
        # Reorder samples locally so that rows going to the same expert are
        # contiguous.
        (local_input_buf,) = fmoe_cuda.local_scatter(inp, pos)
        if world_size > 1:
            (global_input_buf,) = fmoe_cuda.global_scatter(
                local_input_buf,
                local_expert_count,
                global_expert_count,
                fwd_batch_size,
                world_size,
            )
        else:
            global_input_buf = local_input_buf
        ctx.moe_args = inp.shape[0], world_size
        variables = (pos, local_expert_count, global_expert_count)
        ctx.save_for_backward(*variables)
        return global_input_buf

    @staticmethod
    def backward(ctx, global_grad_in):
        (pos, local_expert_count, global_expert_count) = ctx.saved_tensors
        (local_batch_size, world_size) = ctx.moe_args

        if world_size > 1:
            (local_grad_in,) = fmoe_cuda.global_gather(
                global_grad_in,
                local_expert_count,
                global_expert_count,
                local_batch_size,
                world_size,
            )
        else:
            local_grad_in = global_grad_in
        (grad_in,) = fmoe_cuda.local_gather(local_grad_in, pos)
        return grad_in, None, None, None, None, None
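
# Illustrative sketch (not part of the original module): MOEScatter reorders a
# feature tensor by expert using the outputs of prepare_forward. The variable
# names continue the sketch above and the shapes are assumptions.
#
#     inp = torch.randn(gate.shape[0], d_model, device="cuda")
#     expert_inp = MOEScatter.apply(
#         inp, pos, local_cnt, global_cnt, fwd_bs, world_size
#     )
#     # expert_inp has fwd_bs rows grouped contiguously by expert; backward()
#     # routes gradients back to the original sample order of inp.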


class MOELinear(Function):
    r"""
    Computes linear operators within one GPU on different experts simultaneously.
    """

    @staticmethod
    def forward(ctx, global_input_buf, weight, fwd_expert_count):
        (global_output_buf,) = fmoe_cuda.linear_forward(
            global_input_buf, weight, fwd_expert_count
        )
        variables = (global_input_buf, weight, fwd_expert_count)
        ctx.save_for_backward(*variables)
        return global_output_buf

    @staticmethod
    def backward(ctx, grad_out):
        (input_buf, weight, fwd_expert_count) = ctx.saved_tensors
        grad_inp_buf, grad_weight = fmoe_cuda.linear_backward(
            grad_out, input_buf, weight, fwd_expert_count
        )
        return grad_inp_buf, grad_weight, None
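
# Illustrative sketch (not part of the original module): MOELinear applies one
# weight matrix per expert to that expert's contiguous slice of the buffer.
# The per-expert weight layout shown ([num_expert, d_hidden, d_model]) is an
# assumption; the fmoe_cuda kernel defines the exact convention.
#
#     weight = torch.randn(num_expert, d_hidden, d_model, device="cuda")
#     expert_out = MOELinear.apply(expert_inp, weight, fwd_cnt)
#     # each row of expert_inp is transformed by the weight of the expert that
#     # owns it, as determined by the counts in fwd_cnt.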


class MOEGather(Function):
    r"""
    Gather output samples from the expert-contiguous buffers back into the
    [batch x sequence] order. Works symmetrically with MOEScatter.
    """

    @staticmethod
    def forward(
        ctx,
        global_output_buf,
        pos,
        local_expert_count,
        global_expert_count,
        local_batch_size,
        world_size,
    ):
        if world_size > 1:
            (local_output_buf,) = fmoe_cuda.global_gather(
                global_output_buf,
                local_expert_count,
                global_expert_count,
                local_batch_size,
                world_size,
            )
        else:
            local_output_buf = global_output_buf
        (output,) = fmoe_cuda.local_gather(local_output_buf, pos)

        ctx.moe_args = (global_output_buf.shape[0], world_size)
        variables = (pos, local_expert_count, global_expert_count)
        ctx.save_for_backward(*variables)
        return output

    @staticmethod
    def backward(ctx, grad_out):
        pos, local_expert_count, global_expert_count = ctx.saved_tensors
        fwd_batch_size, world_size = ctx.moe_args
        (grad_out_buf,) = fmoe_cuda.local_scatter(grad_out.contiguous(), pos)
        if world_size > 1:
            (global_grad_out_buf,) = fmoe_cuda.global_scatter(
                grad_out_buf,
                local_expert_count,
                global_expert_count,
                fwd_batch_size,
                world_size,
            )
        else:
            global_grad_out_buf = grad_out_buf
        return global_grad_out_buf, None, None, None, None, None
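
# Illustrative sketch (not part of the original module): MOEGather reverses
# MOEScatter, returning expert outputs to the original sample order. Variable
# names continue the sketches above and are assumptions.
#
#     out = MOEGather.apply(
#         expert_out, pos, local_cnt, global_cnt, inp.shape[0], world_size
#     )
#     # out[i] is the expert output for input sample i, ready to be combined
#     # with the gate scores in the usual batch order.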


class AllGather(Function):
    r"""
    A wrapper for the All-Gather function to support auto-differentiation.
    """

    @staticmethod
    def forward(ctx, inp, rank, world_size, group):
        tensor_list = [torch.empty_like(inp) for _ in range(world_size)]
        torch.distributed.all_gather(tensor_list, inp, group=group)
        torch.cuda.synchronize()
        output = torch.cat(tensor_list, dim=0)
        ctx.args = rank, inp.shape[0]
        return output

    @staticmethod
    def backward(ctx, grad_out):
        rank, dim0 = ctx.args
        return grad_out[rank * dim0 : (rank + 1) * dim0], None, None, None
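
# Illustrative sketch (not part of the original module): AllGather concatenates
# each rank's tensor along dim 0 in forward() and slices this rank's share out
# of the gradient in backward(). The process group is assumed to be set up
# elsewhere via torch.distributed.
#
#     full = AllGather.apply(local_tensor, rank, world_size, group)
#     # full.shape[0] == world_size * local_tensor.shape[0]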


class Slice(Function):
    r"""
    A wrapper for the Slice function to support auto-differentiation.
    """

    @staticmethod
    def forward(ctx, inp, rank, world_size, group):
        B: int = inp.shape[0]
        local_batch_size = B // world_size
        batch_start = local_batch_size * rank
        batch_end = min(batch_start + local_batch_size, B)
        inp = inp[batch_start:batch_end]
        ctx.args = world_size, group
        return inp

    @staticmethod
    def backward(ctx, grad_out):
        world_size, group = ctx.args
        tensor_list = [torch.empty_like(grad_out) for _ in range(world_size)]
        torch.distributed.all_gather(tensor_list, grad_out, group=group)
        torch.cuda.synchronize()
        grad_out = torch.cat(tensor_list, dim=0)
        return grad_out, None, None, None
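
# Illustrative sketch (not part of the original module): Slice is the
# counterpart of AllGather, keeping only this rank's shard in forward() and
# all-gathering the gradient in backward(). It assumes every rank holds the
# same full tensor and that the batch dimension divides evenly by world_size.
#
#     shard = Slice.apply(full_tensor, rank, world_size, group)
#     # shard.shape[0] == full_tensor.shape[0] // world_size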