"examples/vscode:/vscode.git/clone" did not exist on "8f061987404d5b935ca26e0953a1e84a7131cf33"
functions.py 5.17 KB
Newer Older
Rick Ho's avatar
Rick Ho committed
1
2
3
4
5
import torch
from torch.autograd import Function
import fmoe_cuda


def moe_prepare_forward(gate, num_expert, world_size, comm=None):
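    """
    Prepare the routing metadata for one MoE forward pass.

    gate is a 1-d tensor holding the expert index assigned to each input
    sample; num_expert is the number of experts hosted on each worker and
    world_size the number of workers, so expert indices range over
    num_expert * world_size.  The function sorts the samples by expert,
    counts how many samples are routed to each expert, exchanges those counts
    across workers when world_size > 1, and returns the sorted positions, the
    local and global per-expert counts, the per-expert counts seen by this
    worker's experts, and the total number of samples this worker will process.
    """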
    if comm is None:
        comm = torch.distributed.distributed_c10d._default_pg
    if world_size > 1:
        fmoe_cuda.ensure_nccl(comm, gate)

    with torch.no_grad():
        _, pos = torch.sort(gate)
        gate_idx, gate_count = torch.unique(gate, return_counts=True)
        local_expert_count = torch.zeros(
            num_expert * world_size, device=gate.device, dtype=torch.long
        )
        local_expert_count.index_put_((gate_idx.long(),), gate_count)

        if world_size > 1:
            (global_expert_count,) = fmoe_cuda.expert_exchange(
                local_expert_count, num_expert, world_size
            )
        else:
            global_expert_count = local_expert_count
        fwd_expert_count = global_expert_count.view(world_size, num_expert).sum(dim=0)
        fwd_batch_size = int(fwd_expert_count.sum().item())
    return (
        pos,
        local_expert_count.cpu(),
        global_expert_count.cpu(),
        fwd_expert_count.cpu(),
        fwd_batch_size,
    )


class MOEScatter(Function):
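    """
    Scatter the input samples so that all samples routed to the same expert
    become contiguous in memory (local_scatter) and, when world_size > 1,
    exchange the buffers across workers (global_scatter) so each sample lands
    on the worker that hosts its target expert.  The backward pass routes the
    gradients back along the same path.
    """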
    @staticmethod
    def forward(
        ctx,
        inp,
        pos,
        local_expert_count,
        global_expert_count,
        fwd_batch_size,
        world_size,
    ):
        (local_input_buf,) = fmoe_cuda.local_scatter(inp, pos)
        if world_size > 1:
            (global_input_buf,) = fmoe_cuda.global_scatter(
                local_input_buf,
                local_expert_count,
                global_expert_count,
                fwd_batch_size,
                world_size,
            )
        else:
            global_input_buf = local_input_buf
        ctx.moe_args = fwd_batch_size, inp.shape[0], world_size
        variables = (pos, local_expert_count, global_expert_count)
        ctx.save_for_backward(*variables)
        return global_input_buf

    @staticmethod
    def backward(ctx, global_grad_in):
        (pos, local_expert_count, global_expert_count) = ctx.saved_tensors
        (fwd_batch_size, local_batch_size, world_size) = ctx.moe_args

        if world_size > 1:
            (local_grad_in,) = fmoe_cuda.global_gather(
                global_grad_in,
                local_expert_count,
                global_expert_count,
                local_batch_size,
                world_size,
            )
        else:
            local_grad_in = global_grad_in
        (grad_in,) = fmoe_cuda.local_gather(local_grad_in, pos)
        return grad_in, None, None, None, None, None


class MOELinear(Function):
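    """
    Apply the per-expert weight matrices to the scattered input buffer.  The
    fmoe_cuda kernel multiplies each expert's contiguous slice of the buffer,
    whose length is given by fwd_expert_count, with that expert's weight.
    """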
    @staticmethod
    def forward(ctx, global_input_buf, weight, fwd_expert_count):
        (global_output_buf,) = fmoe_cuda.forward(
            global_input_buf, weight, fwd_expert_count
        )
        variables = (global_input_buf, weight, fwd_expert_count)
        ctx.save_for_backward(*variables)
        return global_output_buf

    @staticmethod
    def backward(ctx, grad_out):
        (input_buf, weight, fwd_expert_count) = ctx.saved_tensors
        grad_inp_buf, grad_weight = fmoe_cuda.backward(
            grad_out, input_buf, weight, fwd_expert_count
        )
        return grad_inp_buf, grad_weight, None


class MOEGather(Function):
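    """
    Inverse of MOEScatter: when world_size > 1, collect the expert outputs
    back from the workers that computed them (global_gather), then restore
    the original sample order (local_gather).  The backward pass scatters the
    output gradients back to the experts.
    """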
    @staticmethod
    def forward(
        ctx,
        global_output_buf,
        pos,
        local_expert_count,
        global_expert_count,
        local_batch_size,
        world_size,
    ):
        if world_size > 1:
            (local_output_buf,) = fmoe_cuda.global_gather(
                global_output_buf,
                local_expert_count,
                global_expert_count,
                local_batch_size,
                world_size,
            )
        else:
            local_output_buf = global_output_buf
        (output,) = fmoe_cuda.local_gather(local_output_buf, pos)

        ctx.moe_args = local_batch_size, global_output_buf.shape[0], world_size
        variables = (pos, local_expert_count, global_expert_count)
        ctx.save_for_backward(*variables)
        return output

    @staticmethod
    def backward(ctx, grad_out):
        pos, local_expert_count, global_expert_count = ctx.saved_tensors
        local_batch_size, fwd_batch_size, world_size = ctx.moe_args
        (grad_out_buf,) = fmoe_cuda.local_scatter(grad_out.contiguous(), pos)
        if world_size > 1:
            (global_grad_out_buf,) = fmoe_cuda.global_scatter(
                grad_out_buf,
                local_expert_count,
                global_expert_count,
                fwd_batch_size,
                world_size,
            )
        else:
            global_grad_out_buf = grad_out_buf
        return global_grad_out_buf, None, None, None, None, None


class AllGather(Function):
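    """
    Differentiable all-gather along dim 0: every worker receives the
    concatenation of all workers' inputs, while the backward pass returns
    only the gradient slice belonging to this worker's rank.
    """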
    @staticmethod
    def forward(ctx, inp, rank, world_size, group):
        tensor_list = [torch.empty_like(inp) for _ in range(world_size)]
        torch.distributed.all_gather(tensor_list, inp, group=group)
        torch.cuda.synchronize()
        output = torch.cat(tensor_list, dim=0)
        ctx.args = rank, inp.shape[0]
        return output

    @staticmethod
    def backward(ctx, grad_out):
        rank, dim0 = ctx.args
        return grad_out[rank * dim0:(rank + 1) * dim0], None, None, None
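

# A minimal usage sketch, chaining the functions above into one MoE forward
# pass.  The names inp, gate, weight, num_expert and world_size are
# hypothetical placeholders and the chaining is illustrative only:
#
#   pos, local_count, global_count, fwd_count, fwd_bs = moe_prepare_forward(
#       gate, num_expert, world_size)
#   x = MOEScatter.apply(inp, pos, local_count, global_count, fwd_bs, world_size)
#   x = MOELinear.apply(x, weight, fwd_count)
#   out = MOEGather.apply(x, pos, local_count, global_count, inp.shape[0], world_size)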