r"""
The fmoe.functions module contains functions that are directly wrapped up from
C/CUDA functions to complete distributed communication, computation and gradient
computation.
"""

import torch
from torch.autograd import Function
import fmoe_cuda
from .utils import get_torch_default_comm


def ensure_comm(t, comm):
    r"""
    Call ``fmoe_cuda.ensure_nccl`` for tensor ``t`` with a communicator.

    Args:
        t: tensor forwarded to ``fmoe_cuda.ensure_nccl``.
        comm: communicator to use; when None, the one returned by
            ``get_torch_default_comm()`` is used instead.

    NOTE(review): presumably this lazily sets up the NCCL communicator the
    CUDA kernels use for ``t``'s device; the actual semantics live in the
    extension and are not visible here.
    """
    chosen_comm = comm if comm is not None else get_torch_default_comm()
    fmoe_cuda.ensure_nccl(chosen_comm, t)


def count_by_gate(gate, num_expert, world_size, require_pos=True):
    r"""
    Count the samples routed to each expert by ``gate`` and, optionally,
    build the position tensor used to reorder the samples.

    Everything runs under ``torch.no_grad()``.

    Args:
        gate: tensor of target expert indices, passed to
            ``fmoe_cuda.expert_count`` and ``fmoe_cuda.assign_pos``.
        num_expert: number of experts on each worker.
        world_size: number of workers; the count buffer has
            ``num_expert * world_size`` entries.
        require_pos: when falsy, ``pos`` is not computed and None is returned
            in its place.

    Returns:
        ``(pos, local_expert_count, global_expert_count)``.
        ``local_expert_count`` is an int64 tensor on ``gate.device``.
        ``global_expert_count`` comes from ``fmoe_cuda.expert_exchange`` when
        ``world_size > 1``; otherwise it is the very same tensor as
        ``local_expert_count``.
        ``pos`` (when computed) is an int64 tensor on ``gate.device`` whose
        length is the sum of ``local_expert_count``.
    """
    pos = None
    with torch.no_grad():
        # The CUDA counter fills an int32 buffer; widen to int64 afterwards.
        counts_i32 = torch.zeros(
            num_expert * world_size, device=gate.device, dtype=torch.int32
        )
        fmoe_cuda.expert_count(gate, counts_i32)
        local_expert_count = counts_i32.long()

        # Only multi-worker runs exchange counts between workers.
        global_expert_count = (
            fmoe_cuda.expert_exchange(local_expert_count, num_expert, world_size)
            if world_size > 1
            else local_expert_count
        )

        if require_pos:
            # Inclusive prefix sum of the counts; its last entry is the total
            # number of positions to allocate.
            lec_cum = torch.cumsum(local_expert_count, dim=0).int()
            pos = torch.empty(
                (lec_cum[-1].item(),), device=gate.device, dtype=torch.long
            )
            fmoe_cuda.assign_pos(lec_cum, gate, pos)
    return pos, local_expert_count, global_expert_count


def prepare_forward(gate, num_expert, world_size):
    r"""
    Prepare necessary information from gate output for MoE computation.

    Args:
        gate: a 1-d Long Tensor representing the target expert of each input
            sample.
        num_expert: number of experts on each worker.
        world_size: number of workers that hold different experts.

    Returns:
        A tuple ``(pos, local_expert_count, global_expert_count,
        fwd_expert_count, fwd_batch_size)``:

        * ``pos`` -- the position tensor from ``count_by_gate``, left on
          ``gate.device``.
        * ``local_expert_count`` -- per-expert counts from ``count_by_gate``,
          moved to CPU.
        * ``global_expert_count`` -- the exchanged counts from
          ``count_by_gate``, moved to CPU.
        * ``fwd_expert_count`` -- ``global_expert_count`` viewed as
          ``(world_size, num_expert)`` and summed over the worker dimension
          (one entry per local expert), moved to CPU.
        * ``fwd_batch_size`` -- Python ``int``, the sum of
          ``fwd_expert_count``.
    """
    pos, local_expert_count, global_expert_count = count_by_gate(
        gate, num_expert, world_size
    )
    with torch.no_grad():
        # Collapse per-(worker, expert) counts into one count per local
        # expert; presumably the number of rows each local expert processes
        # after the exchange.
        fwd_expert_count = global_expert_count.view(world_size, num_expert).sum(
            dim=0
        )
        fwd_batch_size = int(fwd_expert_count.sum().item())
    return (
        pos,
        local_expert_count.cpu(),
        global_expert_count.cpu(),
        fwd_expert_count.cpu(),
        fwd_batch_size,
    )


def _local_scatter(inp, pos):
    inp_buf = torch.index_select(inp, 0, pos)
    return inp_buf


Rick Ho's avatar
Rick Ho committed
74
def _local_gather(inp, pos, out_batch_size, maybe_overlap=True):
75
76
    inp_buf = torch.zeros(out_batch_size, inp.shape[-1],
            dtype=inp.dtype, device=inp.device)
Rick Ho's avatar
Rick Ho committed
77
78
79
80
    if maybe_overlap:
        inp_buf.index_add_(0, pos, inp)
    else:
        inp_buf.index_copy_(0, pos, inp)
81
82
83
    return inp_buf


Rick Ho's avatar
Rick Ho committed
84
class MOEScatter(Function):
    r"""
    Scatter input samples from [batch x sequences] into buffers that are
    contiguous along experts.
    If `world_size` is greater than 1, the samples will first be locally
    scattered, and then exchanged across workers.
    """

    @staticmethod
    def forward(
        ctx,
        inp,
        pos,
        local_expert_count,
        global_expert_count,
        fwd_batch_size,
        world_size,
    ):
        r"""
        Args:
            inp: input rows; dim 0 is the local batch.
            pos: row order for the local scatter (row ``i`` of the local
                buffer is ``inp[pos[i]]``), e.g. as returned by
                ``prepare_forward``.
            local_expert_count, global_expert_count: per-expert counts, e.g.
                as returned by ``prepare_forward``; only used when
                ``world_size > 1``.
            fwd_batch_size: row count passed to ``fmoe_cuda.global_scatter``.
            world_size: number of workers.
        """
        local_input_buf = _local_scatter(inp, pos)
        if world_size > 1:
            global_input_buf = fmoe_cuda.global_scatter(
                local_input_buf,
                local_expert_count,
                global_expert_count,
                fwd_batch_size,
                world_size,
            )
        else:
            global_input_buf = local_input_buf
        # Sizes needed to undo both stages in backward.
        ctx.moe_args = inp.shape[0], pos.shape[0], world_size
        variables = (pos, local_expert_count, global_expert_count)
        ctx.save_for_backward(*variables)
        return global_input_buf

    @staticmethod
    def backward(ctx, global_grad_in):
        (pos, local_expert_count, global_expert_count) = ctx.saved_tensors
        (inp_batch_size, buf_batch_size, world_size) = ctx.moe_args

        if world_size > 1:
            # Autograd may hand us a non-contiguous gradient. The raw CUDA
            # exchange presumably expects a dense buffer, so densify it first,
            # as MOEGather.backward does for its gradient. This is a no-op when
            # the gradient is already contiguous.
            local_grad_in = fmoe_cuda.global_gather(
                global_grad_in.contiguous(),
                local_expert_count,
                global_expert_count,
                buf_batch_size,
                world_size,
            )
        else:
            local_grad_in = global_grad_in
        # Accumulate (index_add_) gradient rows back to their original batch
        # positions.
        grad_in = _local_gather(local_grad_in, pos, inp_batch_size)
        return grad_in, None, None, None, None, None


class MOEGather(Function):
    r"""
    Gather output samples from buffers that are contiguous along experts back
    to [batch x sequences]. Works symmetrically with MOEScatter.
    """

    @staticmethod
    def forward(
        ctx,
        global_output_buf,
        pos,
        local_expert_count,
        global_expert_count,
        local_batch_size,
        world_size,
    ):
        # Everything backward needs is known up front, so record it first.
        ctx.moe_args = (global_output_buf.shape[0], world_size)
        ctx.save_for_backward(pos, local_expert_count, global_expert_count)

        local_output_buf = global_output_buf
        if world_size > 1:
            # Cross-worker exchange back into a local buffer of pos.shape[0]
            # rows.
            local_output_buf = fmoe_cuda.global_gather(
                global_output_buf,
                local_expert_count,
                global_expert_count,
                pos.shape[0],
                world_size,
            )
        # Place each row at its original batch index (copy, no summation).
        return _local_gather(
            local_output_buf, pos, local_batch_size, maybe_overlap=False
        )

    @staticmethod
    def backward(ctx, grad_out):
        pos, local_expert_count, global_expert_count = ctx.saved_tensors
        fwd_batch_size, world_size = ctx.moe_args
        # Mirror of forward: reorder the gradient by ``pos``, then (if there
        # are several workers) exchange it across workers.
        scattered = _local_scatter(grad_out.contiguous(), pos)
        if not world_size > 1:
            return scattered, None, None, None, None, None
        exchanged = fmoe_cuda.global_scatter(
            scattered,
            local_expert_count,
            global_expert_count,
            fwd_batch_size,
            world_size,
        )
        return exchanged, None, None, None, None, None


class AllGather(Function):
    r"""
    A wrapper for the All-Gather function to support auto-differentiation.

    Forward concatenates, along dim 0, the ``inp`` tensors of every rank in
    ``group``. Backward returns to this rank the rows of the gradient that
    line up with its own ``inp``.
    """

    @staticmethod
    def forward(ctx, inp, rank, world_size, group):
        # One receive buffer per rank, each shaped like the local input, so
        # every rank must contribute a tensor of the same shape.
        gathered = [torch.empty_like(inp) for _ in range(world_size)]
        torch.distributed.all_gather(gathered, inp, group=group)
        torch.cuda.synchronize()
        ctx.args = rank, inp.shape[0]
        return torch.cat(gathered, dim=0)

    @staticmethod
    def backward(ctx, grad_out):
        # NOTE(review): only this rank's slice of grad_out is returned, with
        # no cross-rank reduction. That is correct only if grad_out is the
        # same on every rank (the gathered output is used identically
        # everywhere) -- confirm with callers.
        rank, rows = ctx.args
        first = rank * rows
        return grad_out[first : first + rows], None, None, None


class Slice(Function):
    r"""
    A wrapper for the Slice function to support auto-differentiation.

    Forward keeps only this rank's contiguous chunk of the batch (dim 0).
    Backward all-gathers the chunk gradients from every rank in ``group`` to
    rebuild a gradient for the whole batch.
    """

    @staticmethod
    def forward(ctx, inp, rank, world_size, group):
        r"""
        Return rows ``[rank * L, min((rank + 1) * L, B))`` of ``inp``, where
        ``B = inp.shape[0]`` and ``L = B // world_size``.

        When ``B`` is not divisible by ``world_size``, the trailing
        ``B % world_size`` rows are not assigned to any rank.
        """
        B: int = inp.shape[0]
        local_batch_size = B // world_size
        batch_start = local_batch_size * rank
        batch_end = min(batch_start + local_batch_size, B)
        # Remember the full batch size so backward can return a gradient whose
        # shape matches ``inp`` even when trailing rows were dropped above.
        ctx.args = world_size, group, B
        return inp[batch_start:batch_end]

    @staticmethod
    def backward(ctx, grad_out):
        world_size, group, B = ctx.args
        # Collective backends generally expect dense tensors; also makes the
        # empty_like receive buffers below contiguous.
        grad_out = grad_out.contiguous()
        tensor_list = [torch.empty_like(grad_out) for _ in range(world_size)]
        torch.distributed.all_gather(tensor_list, grad_out, group=group)
        torch.cuda.synchronize()
        grad_in = torch.cat(tensor_list, dim=0)
        missing = B - grad_in.shape[0]
        if missing > 0:
            # Rows no rank received did not contribute to any output, so their
            # gradient is zero. Without this padding the returned gradient
            # would not match inp's shape and autograd would reject it.
            pad = grad_in.new_zeros((missing,) + tuple(grad_in.shape[1:]))
            grad_in = torch.cat([grad_in, pad], dim=0)
        return grad_in, None, None, None