r"""
The fmoe.functions module contains functions that are directly warped up from
C/CUDA functions to complete distributed communication, computation and gradient
computation.
"""

import torch
from torch.autograd import Function
import fmoe_cuda
from .utils import get_torch_default_comm


def moe_prepare_forward(gate, num_expert, world_size, comm=None):
    r"""
    Prepare necessary information from gate output for MoE computation.

    Args:
        gate: a 1-d Long Tensor representing the target expert of each input
        sample.
        num_expert: number of experts on each worker.
        world_size: number of workers that hold different experts.
        comm: the communicator of all workers in the expert-parallel group.
    """
    if world_size > 1:
        if comm is None:
            comm = get_torch_default_comm()
        fmoe_cuda.ensure_nccl(comm, gate)

    with torch.no_grad():
        # Sort samples by their target expert so that each expert's samples
        # become contiguous, and count how many samples go to each expert.
        _, pos = torch.sort(gate)
        gate_idx, gate_count = torch.unique(gate, return_counts=True)
        local_expert_count = torch.zeros(
            num_expert * world_size, device=gate.device, dtype=torch.long
        )
        local_expert_count.index_put_((gate_idx.long(),), gate_count)

        if world_size > 1:
            # Exchange per-expert counts with the other workers so that each
            # worker knows how many samples it will receive for its experts.
            (global_expert_count,) = fmoe_cuda.expert_exchange(
                local_expert_count, num_expert, world_size
            )
        else:
            global_expert_count = local_expert_count
        fwd_expert_count = global_expert_count.view(
            world_size, num_expert
        ).sum(dim=0)
        fwd_batch_size = int(fwd_expert_count.sum().item())
    return (
        pos,
        local_expert_count.cpu(),
        global_expert_count.cpu(),
        fwd_expert_count.cpu(),
        fwd_batch_size,
    )
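
# A worked example on a single worker (world_size == 1, so no communication is
# involved); the gate below is hypothetical, and the exact order inside `pos`
# may differ for samples routed to the same expert:
#
#     gate = torch.tensor([1, 0, 1, 1])    # 4 samples, num_expert == 2
#     pos, local_cnt, global_cnt, fwd_cnt, fwd_bs = \
#         moe_prepare_forward(gate, num_expert=2, world_size=1)
#     # pos        -> tensor([1, 0, 2, 3])  sample 1 first (expert 0), then
#     #                                     samples 0, 2, 3 (expert 1)
#     # local_cnt  -> tensor([1, 3])        one sample for expert 0, three for expert 1
#     # global_cnt -> tensor([1, 3])        identical to local_cnt when world_size == 1
#     # fwd_cnt    -> tensor([1, 3])
#     # fwd_bs     -> 4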


class MOEScatter(Function):
    r"""
    Scatter input samples from [batch x sequences] to contiguous alone experts.
    If `world_size` is greater than 1, the samples will first be locally
    scattered, and then exchanged across workers.
    """
    @staticmethod
    def forward(
        ctx,
        inp,
        pos,
        local_expert_count,
        global_expert_count,
        fwd_batch_size,
        world_size,
    ):
        (local_input_buf,) = fmoe_cuda.local_scatter(inp, pos)
        if world_size > 1:
            (global_input_buf,) = fmoe_cuda.global_scatter(
                local_input_buf,
                local_expert_count,
                global_expert_count,
                fwd_batch_size,
                world_size,
            )
        else:
            global_input_buf = local_input_buf
        ctx.moe_args = inp.shape[0], world_size
        variables = (pos, local_expert_count, global_expert_count)
        ctx.save_for_backward(*variables)
        return global_input_buf

    @staticmethod
    def backward(ctx, global_grad_in):
        (pos, local_expert_count, global_expert_count) = ctx.saved_tensors
        (local_batch_size, world_size) = ctx.moe_args

        if world_size > 1:
            (local_grad_in,) = fmoe_cuda.global_gather(
                global_grad_in,
                local_expert_count,
                global_expert_count,
                local_batch_size,
                world_size,
            )
        else:
            local_grad_in = global_grad_in
        (grad_in,) = fmoe_cuda.local_gather(local_grad_in, pos)
        return grad_in, None, None, None, None, None
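
# Illustrative sketch of how MOEScatter is driven by the statistics produced
# by `moe_prepare_forward` (the tensors `inp` and `gate` are hypothetical):
#
#     pos, local_cnt, global_cnt, fwd_cnt, fwd_bs = \
#         moe_prepare_forward(gate, num_expert, world_size)
#     expert_inp = MOEScatter.apply(
#         inp, pos, local_cnt, global_cnt, fwd_bs, world_size
#     )
#     # expert_inp holds the rows of `inp` reordered (and, when world_size > 1,
#     # exchanged across workers) so that each expert's samples are contiguous.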


class MOELinear(Function):
    r"""
    Computes linear operators within one GPU on different experts simutaneously.
    """
    @staticmethod
    def forward(ctx, global_input_buf, weight, fwd_expert_count):
        (global_output_buf,) = fmoe_cuda.forward(
            global_input_buf, weight, fwd_expert_count
        )
        variables = (global_input_buf, weight, fwd_expert_count)
        ctx.save_for_backward(*variables)
        return global_output_buf

    @staticmethod
    def backward(ctx, grad_out):
        (input_buf, weight, fwd_expert_count) = ctx.saved_tensors
        grad_inp_buf, grad_weight = fmoe_cuda.backward(
            grad_out, input_buf, weight, fwd_expert_count
        )
        return grad_inp_buf, grad_weight, None
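
# Illustrative continuation of the sketch above: `weight` is assumed to be the
# stacked per-expert weight tensor in whatever layout fmoe_cuda.forward
# expects, and `fwd_cnt` is the fwd_expert_count from `moe_prepare_forward`:
#
#     expert_out = MOELinear.apply(expert_inp, weight, fwd_cnt)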


class MOEGather(Function):
    r"""
    Gather output samples from contiguous alone experts back to [batch x
    sequences]. Works symmetrically with MOEScatter.
    """
    @staticmethod
    def forward(
        ctx,
        global_output_buf,
        pos,
        local_expert_count,
        global_expert_count,
        local_batch_size,
        world_size,
    ):
        if world_size > 1:
            (local_output_buf,) = fmoe_cuda.global_gather(
                global_output_buf,
                local_expert_count,
                global_expert_count,
                local_batch_size,
                world_size,
            )
        else:
            local_output_buf = global_output_buf
        (output,) = fmoe_cuda.local_gather(local_output_buf, pos)

        ctx.moe_args = (global_output_buf.shape[0], world_size)
        variables = (pos, local_expert_count, global_expert_count)
        ctx.save_for_backward(*variables)
        return output

    @staticmethod
    def backward(ctx, grad_out):
        pos, local_expert_count, global_expert_count = ctx.saved_tensors
        fwd_batch_size, world_size = ctx.moe_args
        (grad_out_buf,) = fmoe_cuda.local_scatter(grad_out.contiguous(), pos)
        if world_size > 1:
            (global_grad_out_buf,) = fmoe_cuda.global_scatter(
                grad_out_buf,
                local_expert_count,
                global_expert_count,
                fwd_batch_size,
                world_size,
            )
        else:
            global_grad_out_buf = grad_out_buf
        return global_grad_out_buf, None, None, None, None, None
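
# Illustrative sketch: MOEGather undoes MOEScatter with the same routing
# statistics; the local batch size is the number of rows of the original
# (pre-scatter) input:
#
#     out = MOEGather.apply(
#         expert_out, pos, local_cnt, global_cnt, inp.shape[0], world_size
#     )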


class AllGather(Function):
    r'''
    A wrapper for the All-Gather function to support auto-differentiation.
    '''
    @staticmethod
    def forward(ctx, inp, rank, world_size, group):
        tensor_list = [torch.empty_like(inp) for _ in range(world_size)]
        torch.distributed.all_gather(tensor_list, inp, group=group)
        torch.cuda.synchronize()
        output = torch.cat(tensor_list, dim=0)
        ctx.args = rank, inp.shape[0]
        return output

    @staticmethod
    def backward(ctx, grad_out):
        rank, dim0 = ctx.args
        return grad_out[rank * dim0:(rank + 1) * dim0], None, None, None
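
# Minimal sketch of how AllGather might be used in a data-parallel group
# (illustrative; assumes torch.distributed is initialized and `x` is a
# hypothetical local tensor whose first dimension is the per-rank batch):
#
#     rank = torch.distributed.get_rank(group)
#     world_size = torch.distributed.get_world_size(group)
#     full = AllGather.apply(x, rank, world_size, group)
#     # full.shape[0] == world_size * x.shape[0]; the backward pass returns
#     # only this rank's slice of the incoming gradient.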


class Slice(Function):
    r'''
    A wrapper for the Slice function to support auto-differentiation.
    '''
    @staticmethod
    def forward(ctx, inp, rank, world_size, group):
        B: int = inp.shape[0]
        local_batch_size = B // world_size
        batch_start = local_batch_size * rank
        batch_end = min(batch_start + local_batch_size, B)
        inp = inp[batch_start:batch_end]
        ctx.args = world_size, group
        return inp

    @staticmethod
    def backward(ctx, grad_out):
        world_size, group = ctx.args
        tensor_list = [torch.empty_like(grad_out) for _ in range(world_size)]
        torch.distributed.all_gather(tensor_list, grad_out, group=group)
        torch.cuda.synchronize()
        grad_out = torch.cat(tensor_list, dim=0)
        return grad_out, None, None, None
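
# Slice is the mirror of AllGather: the forward pass keeps only this rank's
# share of the batch, and the backward pass all-gathers the gradient so every
# rank receives the gradient of the full batch (same assumptions as above,
# with the batch assumed to divide evenly by world_size):
#
#     local_x = Slice.apply(x, rank, world_size, group)
#     # local_x.shape[0] == x.shape[0] // world_size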