# moe_function.py
# Author (per blame annotations): Rick Ho
import torch
from torch.autograd import Function
import moe_cuda


class MOELocal(Function):
    """Single-process mixture-of-experts layer built on ``moe_cuda`` kernels.

    ``moe_cuda.expert_count`` derives per-expert counts and a permutation
    ``pos`` from ``gate``; the input rows are scattered by ``pos``, run
    through ``moe_cuda.forward`` with those counts, and gathered back with the
    same ``pos``.
    """

    @staticmethod
    def forward(ctx, inp, gate, weight):
        """Run the local MoE forward pass.

        Args:
            inp: input rows (assumed 2-D ``(batch, in_feat)`` — TODO confirm).
            gate: expert selection per row, consumed by
                ``moe_cuda.expert_count``.
            weight: expert weights; ``weight.shape[0]`` is the expert count.

        Returns:
            The first element of ``moe_cuda.local_gather``'s result.
        """
        num_expert = weight.shape[0]
        expert_count, pos = moe_cuda.expert_count(gate, num_expert)
        input_buf, = moe_cuda.local_scatter(inp, pos)
        output_buf, = moe_cuda.forward(input_buf, weight, expert_count)
        output, = moe_cuda.local_gather(output_buf, pos)

        # ``gate`` is not used in backward, so it is not saved; retaining it
        # would only keep the tensor alive until the graph is freed.
        ctx.save_for_backward(input_buf, weight, expert_count, pos)
        return output

    @staticmethod
    def backward(ctx, grad_out):
        """Mirror the forward data path to compute input/weight gradients.

        Returns one gradient per forward input: ``inp``, ``gate`` (always
        ``None``), ``weight``.
        """
        input_buf, weight, expert_count, pos = ctx.saved_tensors

        # The scatter kernel is given a contiguous buffer, as in the original.
        grad_out_buf, = moe_cuda.local_scatter(grad_out.contiguous(), pos)
        grad_inp_buf, grad_weight = moe_cuda.backward(
            grad_out_buf, input_buf, weight, expert_count)
        grad_inp, = moe_cuda.local_gather(grad_inp_buf, pos)

        return grad_inp, None, grad_weight


class MOEGlobal(Function):
    """Mixture-of-experts layer whose experts are spread over ``world_size`` workers.

    Each worker holds ``weight.shape[0]`` experts, so ``gate`` addresses
    ``world_size * num_expert`` experts in total. Rows are scattered locally
    by ``pos``, exchanged between workers with ``moe_cuda.global_scatter``,
    processed by the local experts, and returned with
    ``moe_cuda.global_gather`` / ``moe_cuda.local_gather``.
    """

    @staticmethod
    def forward(ctx, inp, gate, weight, world_size):
        """Run the distributed MoE forward pass.

        Args:
            inp: local input rows; ``inp.shape[0]`` is the local batch size.
            gate: expert selection per row, consumed by
                ``moe_cuda.expert_count``.
            weight: this worker's expert weights; ``weight.shape[0]`` is the
                number of experts per worker.
            world_size: number of participating workers.

        Returns:
            Output rows for ``inp``, as produced by ``moe_cuda.local_gather``.
        """
        num_expert = weight.shape[0]

        local_expert_count, pos = moe_cuda.expert_count(
            gate, world_size * num_expert)
        # fwd_expert_count is later passed to moe_cuda.forward as the
        # per-local-expert count, and its sum sizes the exchanged batch.
        global_expert_count, fwd_expert_count = moe_cuda.expert_exchange(
            local_expert_count, num_expert, world_size)
        # A host-side int is needed as an argument to global_scatter.
        fwd_batch_size = int(fwd_expert_count.sum().item())

        local_input_buf, = moe_cuda.local_scatter(inp, pos)
        global_input_buf, = moe_cuda.global_scatter(
            local_input_buf, local_expert_count, global_expert_count,
            fwd_batch_size, world_size)

        global_output_buf, = moe_cuda.forward(
            global_input_buf, weight, fwd_expert_count)

        local_output_buf, = moe_cuda.global_gather(
            global_output_buf, local_expert_count, global_expert_count,
            inp.shape[0], world_size)
        output, = moe_cuda.local_gather(local_output_buf, pos)

        # ``gate`` is not used in backward, so it is not saved; retaining it
        # would only keep the tensor alive until the graph is freed.
        ctx.save_for_backward(global_input_buf, weight, local_expert_count,
                              global_expert_count, fwd_expert_count, pos)
        ctx.moe_args = (num_expert, inp.shape[0], fwd_batch_size, world_size)

        return output

    @staticmethod
    def backward(ctx, grad_out):
        """Mirror the forward exchange to compute input/weight gradients.

        Returns one gradient per forward input: ``inp``, ``gate`` (``None``),
        ``weight``, ``world_size`` (``None``).
        """
        (input_buf, weight, local_expert_count, global_expert_count,
         fwd_expert_count, pos) = ctx.saved_tensors
        _, local_batch_size, fwd_batch_size, world_size = ctx.moe_args

        grad_out_buf, = moe_cuda.local_scatter(grad_out.contiguous(), pos)
        global_grad_out_buf, = moe_cuda.global_scatter(
            grad_out_buf, local_expert_count, global_expert_count,
            fwd_batch_size, world_size)

        grad_inp_buf, grad_weight = moe_cuda.backward(
            global_grad_out_buf, input_buf, weight, fwd_expert_count)

        local_grad_inp_buf, = moe_cuda.global_gather(
            grad_inp_buf, local_expert_count, global_expert_count,
            local_batch_size, world_size)
        grad_inp, = moe_cuda.local_gather(local_grad_inp_buf, pos)

        return grad_inp, None, grad_weight, None


def moe(inp, gate, weight, world_size):
    """Apply the mixture-of-experts layer, distributed or single-process.

    Args:
        inp: input rows.
        gate: expert selection per row.
        weight: expert weights; ``weight.shape[0]`` is the number of experts
            held by this process.
        world_size: number of workers for the distributed path, or ``None``
            to run entirely on this process.

    Returns:
        ``MOEGlobal.apply(...)`` when ``world_size`` is given, otherwise
        ``MOELocal.apply(...)``.
    """
    if world_size is None:
        return MOELocal.apply(inp, gate, weight)
    return MOEGlobal.apply(inp, gate, weight, world_size)