moe_function.py
import torch
import torch.distributed
from torch.autograd import Function

import moe_cuda


class MOELocal(Function):
    r"""MoE forward/backward on a single worker: scatter tokens to their
    experts, run the batched expert computation, and gather the results back
    into the original token order."""

    @staticmethod
    def forward(ctx, inp, gate, weight):
        # Count the tokens assigned to each expert and compute the permutation
        # that groups tokens of the same expert together.
        expert_count, pos = moe_cuda.expert_count(gate, weight.shape[0])
        # Reorder the input so that each expert sees a contiguous slice.
        input_buf, = moe_cuda.local_scatter(inp, pos)
        output_buf, = moe_cuda.forward(input_buf, weight, expert_count)
        # Undo the permutation so outputs line up with the original tokens.
        output, = moe_cuda.local_gather(output_buf, pos)

        variables = [input_buf, gate, weight, expert_count, pos]
        ctx.save_for_backward(*variables)

        return output

    @staticmethod
    def backward(ctx, grad_out):
        input_buf, gate, weight, expert_count, pos = ctx.saved_tensors

        # Reorder the output gradient into expert-major order, backpropagate
        # through the experts, then restore the original token order.
        grad_out_buf, = moe_cuda.local_scatter(grad_out.contiguous(), pos)
        grad_inp_buf, grad_weight = moe_cuda.backward(
                grad_out_buf, input_buf, weight, expert_count)
        grad_inp, = moe_cuda.local_gather(grad_inp_buf, pos)

        # The integer gate tensor receives no gradient.
        return grad_inp, None, grad_weight
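

# A hypothetical usage sketch for the single-worker path (not part of the
# original file).  The shapes are assumptions inferred from how the buffers
# are indexed above: one expert index per token in `gate` (int64 assumed) and
# one weight matrix per expert stacked along dim 0 of `weight`.  Running it
# requires a CUDA device and a built `moe_cuda` extension.
def _example_local_usage(num_expert=4, batch=16, d_model=64):
    inp = torch.randn(batch, d_model, device='cuda', requires_grad=True)
    gate = torch.randint(num_expert, (batch,), device='cuda')
    weight = torch.randn(num_expert, d_model, d_model, device='cuda',
                         requires_grad=True)
    out = MOELocal.apply(inp, gate, weight)
    out.sum().backward()
    return out, inp.grad, weight.grad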


class MOEGlobal(Function):
    r"""MoE forward/backward across `world_size` workers: tokens are first
    grouped by destination expert locally, then exchanged between workers with
    all-to-all so that each worker runs only its own experts."""

    @staticmethod
    def forward(ctx, inp, gate, weight, world_size):
        num_expert = weight.shape[0]

        # Per-expert counts over all world_size * num_expert global experts,
        # plus the permutation that groups local tokens by destination expert.
        local_expert_count, pos = moe_cuda.expert_count(gate,
                world_size * num_expert)
        # Exchange the counts so that every worker knows how many tokens it
        # will receive for each of its local experts.
        global_expert_count = torch.empty_like(
                local_expert_count.reshape(world_size, num_expert))
        torch.distributed.all_to_all_single(global_expert_count,
                local_expert_count.reshape(world_size, num_expert))
        batch_size = int(global_expert_count.sum().item())

        # Group local tokens by destination expert, then send them to the
        # workers that own those experts.
        local_input_buf, = moe_cuda.local_scatter(inp, pos)
        global_input_buf, = moe_cuda.global_scatter(local_input_buf,
                local_expert_count, global_expert_count,
                batch_size, world_size)

        # Run the local experts on the received tokens.
        global_output_buf, = moe_cuda.forward(global_input_buf, weight,
                global_expert_count)

        # Send the expert outputs back to the workers the tokens came from and
        # restore the original token order.
        local_output_buf, = moe_cuda.global_gather(global_output_buf,
                local_expert_count, global_expert_count,
                inp.shape[0], world_size)
        output, = moe_cuda.local_gather(local_output_buf, pos)

        variables = [global_input_buf, gate, weight,
                local_expert_count, global_expert_count, pos]
        ctx.save_for_backward(*variables)
        # Plain Python integers cannot go through save_for_backward.
        ctx.moe_args = (num_expert, batch_size, world_size)

        return output

    @staticmethod
    def backward(ctx, grad_out):
        (global_input_buf, gate, weight, local_expert_count,
                global_expert_count, pos) = ctx.saved_tensors
        num_expert, batch_size, world_size = ctx.moe_args

        # Reorder the local output gradient by destination expert and send it
        # back to the workers that ran those experts in the forward pass.
        grad_out_buf, = moe_cuda.local_scatter(grad_out.contiguous(), pos)
        global_grad_out_buf, = moe_cuda.global_scatter(grad_out_buf,
                local_expert_count, global_expert_count,
                batch_size, world_size)

        grad_inp_buf, grad_weight = moe_cuda.backward(
                global_grad_out_buf, global_input_buf, weight,
                global_expert_count)

        # Return the input gradients to the workers the tokens came from and
        # restore the original token order.  The target size is the local
        # batch size, mirroring the use of inp.shape[0] in forward.
        local_grad_inp_buf, = moe_cuda.global_gather(grad_inp_buf,
                local_expert_count, global_expert_count,
                grad_out.shape[0], world_size)
        grad_inp, = moe_cuda.local_gather(local_grad_inp_buf, pos)

        # No gradients for the gate tensor or for world_size.
        return grad_inp, None, grad_weight, None
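

# A single-process illustration (hypothetical numbers, not part of the
# original file) of what the all_to_all_single call on the expert counts
# computes, assuming the standard collective semantics: each rank's count
# matrix is split along dim 0, and rank r ends up holding row r from every
# rank's matrix.
def _example_count_exchange():
    world_size = 2  # two ranks, each hosting two experts in this sketch
    # counts[w] plays the role of rank w's local_expert_count reshaped to
    # (world_size, num_expert): row j counts the tokens rank w routes to the
    # experts hosted on rank j.
    counts = torch.tensor([[[3, 1],
                            [2, 4]],
                           [[0, 5],
                            [6, 1]]])
    # What the collective would deliver as global_expert_count on each rank.
    global_counts = [counts[:, r, :] for r in range(world_size)]
    # Rank 0 receives 3 + 0 tokens for its expert 0 and 1 + 5 for its expert 1,
    # so its expert forward pass processes batch_size = 9 tokens.
    assert int(global_counts[0].sum()) == 9
    return global_counts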


def moe(inp, gate, weight, world_size):
    if world_size is not None:
        return MOEGlobal.apply(inp, gate, weight, world_size)
    else:
        return MOELocal.apply(inp, gate, weight)
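

# A hypothetical sketch of the distributed path (not part of the original
# file).  It assumes torch.distributed has already been initialised with one
# rank per GPU, e.g. via torch.distributed.init_process_group('nccl'), that
# every rank hosts `num_expert` experts, and that gate values are global
# expert ids in [0, world_size * num_expert).  Shapes and sizes are made up
# for illustration.
def _example_moe_global(num_expert=4, tokens_per_rank=16, d_model=64):
    world_size = torch.distributed.get_world_size()
    inp = torch.randn(tokens_per_rank, d_model, device='cuda',
                      requires_grad=True)
    gate = torch.randint(world_size * num_expert, (tokens_per_rank,),
                         device='cuda')
    weight = torch.randn(num_expert, d_model, d_model, device='cuda',
                         requires_grad=True)
    out = moe(inp, gate, weight, world_size)
    out.sum().backward()
    return out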