functions.py 7.05 KB
Newer Older
1
2
3
4
5
6
r"""
The fmoe.functions module contains functions that are directly warped up from
C/CUDA functions to complete distributed communication, computation and gradient
computation.
"""

Rick Ho's avatar
Rick Ho committed
7
8
9
import torch
from torch.autograd import Function
import fmoe_cuda
10
from .utils import get_torch_default_comm
Rick Ho's avatar
Rick Ho committed
11
12


Rick Ho's avatar
Rick Ho committed
13
def moe_prepare_forward(gate, num_expert, world_size, comm=None):
    r"""
    Derive the routing bookkeeping for MoE computation from the gate output.

    Args:
        gate: a 1-d Long Tensor giving, for every input sample, the index of
            its target expert in a range of size ``num_expert * world_size``.
        num_expert: number of experts on each worker.
        world_size: number of workers that hold different experts.
        comm: the communicator of all workers in the expert-parallel group.
            Only used when ``world_size > 1``; when ``None`` it falls back to
            ``get_torch_default_comm()``.

    Returns:
        A 5-tuple ``(pos, local_expert_count, global_expert_count,
        fwd_expert_count, fwd_batch_size)``:

        * ``pos``: the indices that sort ``gate`` (left on ``gate``'s device).
        * ``local_expert_count``: number of samples per expert index, length
          ``num_expert * world_size`` (moved to CPU).
        * ``global_expert_count``: the result of ``fmoe_cuda.expert_exchange``
          on the local counts, or the local counts themselves when
          ``world_size == 1`` (moved to CPU).
        * ``fwd_expert_count``: ``global_expert_count`` viewed as
          ``(world_size, num_expert)`` and summed over the worker axis
          (moved to CPU).
        * ``fwd_batch_size``: the sum of ``fwd_expert_count`` as a Python int.
    """
    distributed = world_size > 1
    if distributed:
        # Presumably initialises NCCL for the communicator before the
        # exchange below.
        fmoe_cuda.ensure_nccl(
            get_torch_default_comm() if comm is None else comm, gate
        )

    with torch.no_grad():
        _, pos = torch.sort(gate)

        # Histogram of gate: how many samples target each expert index.
        experts, counts = torch.unique(gate, return_counts=True)
        local_expert_count = torch.zeros(
            num_expert * world_size, device=gate.device, dtype=torch.long
        )
        local_expert_count[experts.long()] = counts

        if distributed:
            (global_expert_count,) = fmoe_cuda.expert_exchange(
                local_expert_count, num_expert, world_size
            )
        else:
            global_expert_count = local_expert_count

        # Fold the worker axis away, leaving one total per local expert.
        per_worker = global_expert_count.view(world_size, num_expert)
        fwd_expert_count = per_worker.sum(dim=0)
        fwd_batch_size = int(fwd_expert_count.sum().item())

    return (
        pos,
        local_expert_count.cpu(),
        global_expert_count.cpu(),
        fwd_expert_count.cpu(),
        fwd_batch_size,
    )
Rick Ho's avatar
Rick Ho committed
53
54
55


class MOEScatter(Function):
    r"""
    Scatter input samples from [batch x sequences] order into a buffer that is
    contiguous along experts. If `world_size` is greater than 1, the samples
    are first scattered locally and then exchanged across workers.
    """

    @staticmethod
    def forward(
        ctx,
        inp,
        pos,
        local_expert_count,
        global_expert_count,
        fwd_batch_size,
        world_size,
    ):
        # MOEGather.backward makes its input contiguous before calling the
        # same ``local_scatter`` kernel; do the same here so a strided input
        # is not handed to the kernel. ``contiguous()`` is a no-op when the
        # tensor already is contiguous.
        (local_input_buf,) = fmoe_cuda.local_scatter(inp.contiguous(), pos)
        if world_size > 1:
            (global_input_buf,) = fmoe_cuda.global_scatter(
                local_input_buf,
                local_expert_count,
                global_expert_count,
                fwd_batch_size,
                world_size,
            )
        else:
            global_input_buf = local_input_buf
        # The number of local samples is needed to size the gather in backward.
        ctx.moe_args = inp.shape[0], world_size
        variables = (pos, local_expert_count, global_expert_count)
        ctx.save_for_backward(*variables)
        return global_input_buf

    @staticmethod
    def backward(ctx, global_grad_in):
        (pos, local_expert_count, global_expert_count) = ctx.saved_tensors
        (local_batch_size, world_size) = ctx.moe_args

        # Upstream gradients are not guaranteed to be contiguous; normalise
        # before handing them to the CUDA kernels.
        global_grad_in = global_grad_in.contiguous()
        if world_size > 1:
            (local_grad_in,) = fmoe_cuda.global_gather(
                global_grad_in,
                local_expert_count,
                global_expert_count,
                local_batch_size,
                world_size,
            )
        else:
            local_grad_in = global_grad_in
        (grad_in,) = fmoe_cuda.local_gather(local_grad_in, pos)
        # Only ``inp`` is differentiable; the other five inputs are bookkeeping.
        return grad_in, None, None, None, None, None


class MOELinear(Function):
    r"""
    Compute the linear operators of several experts simultaneously on one GPU.
    The rows of the input buffer are presumably split among experts according
    to ``fwd_expert_count`` (the split is done inside ``fmoe_cuda``).
    """

    @staticmethod
    def forward(ctx, global_input_buf, weight, fwd_expert_count):
        outputs = fmoe_cuda.forward(global_input_buf, weight, fwd_expert_count)
        global_output_buf = outputs[0] if len(outputs) == 1 else None
        if global_output_buf is None:
            (global_output_buf,) = outputs
        # All three inputs are needed again to compute the gradients.
        ctx.save_for_backward(global_input_buf, weight, fwd_expert_count)
        return global_output_buf

    @staticmethod
    def backward(ctx, grad_out):
        input_buf, weight, fwd_expert_count = ctx.saved_tensors
        grads = fmoe_cuda.backward(grad_out, input_buf, weight, fwd_expert_count)
        grad_inp_buf, grad_weight = grads
        # fwd_expert_count is routing bookkeeping and receives no gradient.
        return grad_inp_buf, grad_weight, None


class MOEGather(Function):
    r"""
    Gather output samples from the buffer that is contiguous along experts
    back into [batch x sequences] order. This is the counterpart of
    MOEScatter.
    """

    @staticmethod
    def forward(
        ctx,
        global_output_buf,
        pos,
        local_expert_count,
        global_expert_count,
        local_batch_size,
        world_size,
    ):
        if world_size <= 1:
            local_output_buf = global_output_buf
        else:
            # Bring the samples processed on other workers back to this one.
            (local_output_buf,) = fmoe_cuda.global_gather(
                global_output_buf,
                local_expert_count,
                global_expert_count,
                local_batch_size,
                world_size,
            )
        (output,) = fmoe_cuda.local_gather(local_output_buf, pos)

        # backward scatters into a buffer with as many rows as the expert-side
        # buffer had in forward.
        ctx.moe_args = (global_output_buf.shape[0], world_size)
        ctx.save_for_backward(pos, local_expert_count, global_expert_count)
        return output

    @staticmethod
    def backward(ctx, grad_out):
        pos, local_expert_count, global_expert_count = ctx.saved_tensors
        fwd_batch_size, world_size = ctx.moe_args
        (grad_out_buf,) = fmoe_cuda.local_scatter(grad_out.contiguous(), pos)
        no_grads = (None,) * 5
        if world_size <= 1:
            return (grad_out_buf,) + no_grads
        (global_grad_out_buf,) = fmoe_cuda.global_scatter(
            grad_out_buf,
            local_expert_count,
            global_expert_count,
            fwd_batch_size,
            world_size,
        )
        return (global_grad_out_buf,) + no_grads
Rick Ho's avatar
Rick Ho committed
179
180
181


class AllGather(Function):
    r"""
    A wrapper for the All-Gather function to support auto-differentiation.

    The forward pass gathers ``inp`` from every rank of ``group`` and
    concatenates the pieces along dimension 0. The backward pass returns the
    slice of the incoming gradient that corresponds to this rank's piece,
    which assumes every rank contributed a tensor with the same first
    dimension.
    """

    @staticmethod
    def forward(ctx, inp, rank, world_size, group):
        # Collectives need contiguous buffers. ``empty_like`` preserves the
        # strides of its argument, so a strided ``inp`` would also make every
        # receive buffer strided. ``contiguous()`` is a no-op otherwise.
        inp = inp.contiguous()
        tensor_list = [torch.empty_like(inp) for _ in range(world_size)]
        torch.distributed.all_gather(tensor_list, inp, group=group)
        if inp.is_cuda:
            # A device sync only makes sense for CUDA tensors; for CPU (e.g.
            # gloo) tensors it would fail on machines without CUDA.
            torch.cuda.synchronize()
        output = torch.cat(tensor_list, dim=0)
        ctx.args = rank, inp.shape[0]
        return output

    @staticmethod
    def backward(ctx, grad_out):
        rank, dim0 = ctx.args
        # Only ``inp`` is differentiable; rank, world_size and group are not.
        return grad_out[rank * dim0 : (rank + 1) * dim0], None, None, None
Sengxian's avatar
Sengxian committed
199
200
201


class Slice(Function):
    r"""
    A wrapper for the Slice function to support auto-differentiation.

    The forward pass keeps this rank's share of ``inp`` along dimension 0, i.e.
    rows ``[rank * (B // world_size), (rank + 1) * (B // world_size))``. When
    ``B`` is not divisible by ``world_size``, the trailing ``B % world_size``
    rows fall outside every rank's share. The backward pass all-gathers the
    per-rank gradients and zero-fills those dropped rows. The returned
    gradient therefore always has ``B`` rows, as autograd requires.
    """

    @staticmethod
    def forward(ctx, inp, rank, world_size, group):
        B: int = inp.shape[0]
        local_batch_size = B // world_size
        batch_start = local_batch_size * rank
        batch_end = min(batch_start + local_batch_size, B)
        inp = inp[batch_start:batch_end]
        # Keep the full batch size so backward can restore the input's shape.
        ctx.args = world_size, group, B
        return inp

    @staticmethod
    def backward(ctx, grad_out):
        world_size, group, B = ctx.args
        # Collectives need contiguous buffers, and ``empty_like`` preserves
        # the strides of its argument.
        grad_out = grad_out.contiguous()
        tensor_list = [torch.empty_like(grad_out) for _ in range(world_size)]
        torch.distributed.all_gather(tensor_list, grad_out, group=group)
        if grad_out.is_cuda:
            torch.cuda.synchronize()
        grad_in = torch.cat(tensor_list, dim=0)
        missing = B - grad_in.shape[0]
        if missing > 0:
            # Rows dropped in forward did not affect the output: zero gradient.
            padding = grad_in.new_zeros((missing,) + tuple(grad_in.shape[1:]))
            grad_in = torch.cat([grad_in, padding], dim=0)
        return grad_in, None, None, None