moe_function.py 3.44 KB
Newer Older
Rick Ho's avatar
Rick Ho committed
1
2
import torch
from torch.autograd import Function
Rick Ho's avatar
Rick Ho committed
3
import fmoe_cuda
Rick Ho's avatar
Rick Ho committed
4
5
6
7
8


class MOELocal(Function):
    """Single-process mixture-of-experts op implemented as a custom autograd Function.

    All experts live in ``weight`` on this process. Rows of ``inp`` are
    reordered with ``fmoe_cuda.local_scatter`` so that rows sharing an expert
    are grouped together, pushed through ``fmoe_cuda.forward`` and then put
    back with ``fmoe_cuda.local_gather`` using the same permutation.
    """

    @staticmethod
    def forward(ctx, inp, gate, weight):
        """Run each row of ``inp`` through the expert selected by ``gate``.

        Args:
            ctx: autograd context used to stash tensors for ``backward``.
            inp: input rows; row ``i`` is routed to expert ``gate[i]``
                (assumed 2-D ``(batch, in_feat)`` — TODO confirm).
            gate: integer expert index for every row of ``inp``.
            weight: stacked expert weights; ``weight.shape[0]`` is the
                number of experts.

        Returns:
            The tensor produced by ``fmoe_cuda.local_gather`` on the expert
            outputs (presumably in the original row order of ``inp``).
        """
        num_expert = weight.shape[0]

        # Permutation that sorts rows by ascending expert id; it matches the
        # ascending order of ``expert_count`` below.
        _, pos = torch.sort(gate)

        # Rows routed to each expert; experts that received no rows stay 0.
        gate_idx, gate_count = torch.unique(gate, return_counts=True)
        expert_count = torch.zeros(num_expert, device=weight.device,
                dtype=torch.long)
        expert_count.index_put_((gate_idx.long(), ), gate_count)
        # A host-side copy of the counts is what the extension is given.
        expert_count = expert_count.cpu()

        input_buf, = fmoe_cuda.local_scatter(inp, pos)
        output_buf, = fmoe_cuda.forward(input_buf, weight, expert_count)
        output, = fmoe_cuda.local_gather(output_buf, pos)

        # Only what ``backward`` reads is saved; ``gate`` itself is not needed.
        ctx.save_for_backward(input_buf, weight, expert_count, pos)
        return output

    @staticmethod
    def backward(ctx, grad_out):
        """Back-propagate through the grouped expert computation.

        Args:
            ctx: context populated by ``forward``.
            grad_out: gradient w.r.t. the output of ``forward``.

        Returns:
            ``(grad_inp, None, grad_weight)`` — one entry per ``forward``
            input; ``gate`` holds discrete indices and gets no gradient.
        """
        input_buf, weight, expert_count, pos = ctx.saved_tensors

        # Put the incoming gradient into the same expert-grouped order as
        # ``input_buf`` was in during the forward pass.
        grad_out_buf, = fmoe_cuda.local_scatter(grad_out.contiguous(), pos)
        grad_inp_buf, grad_weight = fmoe_cuda.backward(
                grad_out_buf, input_buf, weight, expert_count)
        grad_inp, = fmoe_cuda.local_gather(grad_inp_buf, pos)

        return grad_inp, None, grad_weight


class MOEGlobal(Function):
    """Mixture-of-experts op whose experts are spread over ``world_size`` workers.

    ``gate`` indexes ``world_size * num_expert`` experts in total, where
    ``num_expert = weight.shape[0]`` is the number held by this worker. Rows
    are grouped locally by expert, exchanged between workers and computed
    via ``fmoe_cuda.global_fused_forward``, then regrouped locally.
    """

    @staticmethod
    def forward(ctx, inp, gate, weight, world_size):
        """Route the rows of ``inp`` to (possibly remote) experts.

        Args:
            ctx: autograd context used to stash tensors for ``backward``.
            inp: local input rows; ``inp.shape[0]`` is the local batch size.
            gate: integer expert index per row, in
                ``[0, world_size * weight.shape[0])`` (assumed — TODO confirm
                range checking happens in ``fmoe_cuda.expert_count``).
            weight: stacked weights of the experts held by this worker.
            world_size: number of participating workers.

        Returns:
            The tensor produced by ``fmoe_cuda.local_gather`` on the returned
            expert outputs (presumably in the original row order of ``inp``).
        """
        num_expert = weight.shape[0]
        local_batch_size = inp.shape[0]

        local_expert_count, pos = fmoe_cuda.expert_count(gate,
                world_size * num_expert)
        global_expert_count, fwd_expert_count = fmoe_cuda.expert_exchange(
                local_expert_count, num_expert, world_size)
        # Total rows this worker processes for its own experts; ``.item()``
        # copies the value to the host.
        fwd_batch_size = int(fwd_expert_count.sum().item())

        local_input_buf, = fmoe_cuda.local_scatter(inp, pos)
        local_output_buf, global_input_buf = fmoe_cuda.global_fused_forward(
                local_input_buf, weight,
                local_expert_count, global_expert_count,
                fwd_batch_size, local_batch_size, world_size)
        output, = fmoe_cuda.local_gather(local_output_buf, pos)

        # Only what ``backward`` reads is saved; ``gate`` itself is not needed.
        ctx.save_for_backward(global_input_buf, weight,
                local_expert_count, global_expert_count, fwd_expert_count,
                pos)
        ctx.moe_args = (num_expert, local_batch_size, fwd_batch_size,
                world_size)
        return output

    @staticmethod
    def backward(ctx, grad_out):
        """Back-propagate through the exchange and the expert computation.

        Args:
            ctx: context populated by ``forward``.
            grad_out: gradient w.r.t. the output of ``forward``.

        Returns:
            ``(grad_inp, None, grad_weight, None)`` — one entry per
            ``forward`` input; ``gate`` and ``world_size`` get no gradient.
        """
        (input_buf, weight,
                local_expert_count, global_expert_count, fwd_expert_count,
                pos) = ctx.saved_tensors
        _, local_batch_size, fwd_batch_size, world_size = ctx.moe_args

        # Mirror the forward data movement: group locally, then send each
        # gradient slice to the worker that owns the corresponding expert.
        grad_out_buf, = fmoe_cuda.local_scatter(grad_out.contiguous(), pos)
        global_grad_out_buf, = fmoe_cuda.global_scatter(grad_out_buf,
                local_expert_count, global_expert_count,
                fwd_batch_size, world_size)

        grad_inp_buf, grad_weight = fmoe_cuda.backward(
                global_grad_out_buf, input_buf, weight, fwd_expert_count)

        # Return input gradients to the workers the rows came from, then
        # undo the local grouping.
        local_grad_inp_buf, = fmoe_cuda.global_gather(grad_inp_buf,
                local_expert_count, global_expert_count,
                local_batch_size, world_size)
        grad_inp, = fmoe_cuda.local_gather(local_grad_inp_buf, pos)

        return grad_inp, None, grad_weight, None
Rick Ho's avatar
Rick Ho committed
87
88
89


def moe(inp, gate, weight, world_size=None):
    """Apply a mixture-of-experts layer, picking the local or distributed op.

    Args:
        inp: input rows to route through the experts.
        gate: integer expert index for every row of ``inp``.
        weight: stacked weights of the experts held by this process.
        world_size: number of workers the experts are spread across.
            ``None`` (the default) or any value not greater than 1 selects
            the single-process ``MOELocal`` path; otherwise ``MOEGlobal``
            is used.

    Returns:
        Whatever ``MOELocal.apply`` / ``MOEGlobal.apply`` returns.
    """
    if world_size is not None and world_size > 1:
        return MOEGlobal.apply(inp, gate, weight, world_size)
    return MOELocal.apply(inp, gate, weight)