# moe_function.py (3.14 KB)
# Source: repository blame view; changes committed by Rick Ho.
import torch
from torch.autograd import Function
import moe_cuda


class MOELocal(Function):
    """Single-process mixture-of-experts layer as a custom autograd Function.

    All experts live in this process. Grouping rows by expert, the expert
    computation itself, and the inverse regrouping are delegated to the
    ``moe_cuda`` extension.
    """

    @staticmethod
    def forward(ctx, inp, gate, weight):
        """Run every row of ``inp`` through the expert selected by ``gate``.

        Args:
            ctx: autograd context; stores what ``backward`` needs.
            inp: input rows.
            gate: per-row expert selection (presumably an integer tensor with
                one entry per row of ``inp`` -- TODO confirm against moe_cuda).
            weight: stacked expert weights; ``weight.shape[0]`` is used as the
                number of experts.

        Returns:
            The expert outputs, gathered back with the same ``pos`` mapping
            that was used to scatter the inputs.
        """
        num_expert = weight.shape[0]
        # Per-expert row counts plus the position mapping used to scatter
        # rows into expert order and gather them back afterwards.
        expert_count, pos = moe_cuda.expert_count(gate, num_expert)
        input_buf, = moe_cuda.local_scatter(inp, pos)
        output_buf, = moe_cuda.forward(input_buf, weight, expert_count)
        output, = moe_cuda.local_gather(output_buf, pos)

        # ``gate`` is deliberately not saved: backward only needs
        # ``expert_count`` and ``pos``, and saving it would just keep the
        # tensor alive for the lifetime of the autograd graph.
        ctx.save_for_backward(input_buf, weight, expert_count, pos)
        return output

    @staticmethod
    def backward(ctx, grad_out):
        """Back-propagate through the local expert computation.

        Returns:
            Gradients for ``(inp, gate, weight)``. ``gate`` receives ``None``;
            no gradient is produced for it.
        """
        input_buf, weight, expert_count, pos = ctx.saved_tensors

        # Scatter the incoming gradient with the same mapping as forward so
        # its rows line up with ``input_buf``.
        grad_out_buf, = moe_cuda.local_scatter(grad_out.contiguous(), pos)
        grad_inp_buf, grad_weight = moe_cuda.backward(
                grad_out_buf, input_buf, weight, expert_count)
        grad_inp, = moe_cuda.local_gather(grad_inp_buf, pos)

        return grad_inp, None, grad_weight


class MOEGlobal(Function):
    """Mixture-of-experts layer with experts spread over ``world_size`` workers.

    Each worker holds ``weight.shape[0]`` local experts, and ``gate`` selects
    among ``world_size * num_expert`` experts in total. The cross-worker
    exchange of counts and rows is done by ``moe_cuda`` (``expert_exchange``,
    ``global_fused_forward``, ``global_scatter`` and ``global_gather``;
    presumably an all-to-all -- TODO confirm).
    """

    @staticmethod
    def forward(ctx, inp, gate, weight, world_size):
        """Send rows to their (possibly remote) experts and collect the results.

        Args:
            ctx: autograd context; stores what ``backward`` needs.
            inp: this worker's input rows; ``inp.shape[0]`` is the local
                batch size.
            gate: per-row expert selection over all
                ``world_size * num_expert`` experts (presumably an integer
                tensor -- TODO confirm).
            weight: this worker's expert weights; ``weight.shape[0]`` is the
                number of local experts.
            world_size: number of participating workers (a plain Python
                value, not a tensor).

        Returns:
            Output rows for ``inp``, gathered back with the same ``pos``
            mapping that was used to scatter the inputs.
        """
        num_expert = weight.shape[0]
        local_batch_size = inp.shape[0]

        # Row counts per global expert, and the mapping that groups local
        # rows by destination expert.
        local_expert_count, pos = moe_cuda.expert_count(gate,
                world_size * num_expert)
        global_expert_count, fwd_expert_count = moe_cuda.expert_exchange(
                local_expert_count, num_expert, world_size)
        # Number of rows this worker's experts will process. It is passed on
        # as a Python int; .item() blocks on the device if the tensor lives
        # on a GPU.
        fwd_batch_size = int(fwd_expert_count.sum().item())

        local_input_buf, = moe_cuda.local_scatter(inp, pos)
        local_output_buf, global_input_buf = moe_cuda.global_fused_forward(
                local_input_buf, weight,
                local_expert_count, global_expert_count,
                fwd_batch_size, local_batch_size, world_size)
        output, = moe_cuda.local_gather(local_output_buf, pos)

        # ``global_input_buf`` is the expert input consumed by
        # moe_cuda.backward. ``gate`` is deliberately not saved: backward
        # never reads it, and saving it would only extend its lifetime.
        ctx.save_for_backward(global_input_buf, weight,
                local_expert_count, global_expert_count, fwd_expert_count,
                pos)
        # Non-tensor values cannot go through save_for_backward.
        ctx.moe_args = (num_expert, local_batch_size, fwd_batch_size,
                world_size)
        return output

    @staticmethod
    def backward(ctx, grad_out):
        """Back-propagate through the distributed expert computation.

        Returns:
            Gradients for ``(inp, gate, weight, world_size)``. ``gate`` and
            ``world_size`` receive ``None``.
        """
        (input_buf, weight,
                local_expert_count, global_expert_count, fwd_expert_count,
                pos) = ctx.saved_tensors
        _, local_batch_size, fwd_batch_size, world_size = ctx.moe_args

        # Mirror forward: group the gradient locally, then send it to the
        # workers that own the experts.
        grad_out_buf, = moe_cuda.local_scatter(grad_out.contiguous(), pos)
        global_grad_out_buf, = moe_cuda.global_scatter(grad_out_buf,
                local_expert_count, global_expert_count,
                fwd_batch_size, world_size)

        grad_inp_buf, grad_weight = moe_cuda.backward(
                global_grad_out_buf, input_buf, weight, fwd_expert_count)

        # Return the input gradients to the workers the rows came from, then
        # undo the local grouping.
        local_grad_inp_buf, = moe_cuda.global_gather(grad_inp_buf,
                local_expert_count, global_expert_count,
                local_batch_size, world_size)
        grad_inp, = moe_cuda.local_gather(local_grad_inp_buf, pos)

        return grad_inp, None, grad_weight, None


def moe(inp, gate, weight, world_size=None):
    """Apply the mixture-of-experts layer, picking the local or distributed path.

    Args:
        inp: input rows.
        gate: per-row expert selection, passed through unchanged.
        weight: this process's expert weights.
        world_size: number of workers. ``None`` (the default) or any value
            ``<= 1`` runs the single-process ``MOELocal``; larger values run
            ``MOEGlobal`` across ``world_size`` workers.

    Returns:
        The result of the selected autograd Function's ``apply``.
    """
    if world_size is not None and world_size > 1:
        return MOEGlobal.apply(inp, gate, weight, world_size)
    return MOELocal.apply(inp, gate, weight)