import torch
from torch.autograd import Function

import fmoe_cuda


class MOELocal(Function):
    """Autograd function for mixture-of-experts computation on a single
    worker (no cross-process communication).

    ``inp`` holds the samples, ``gate`` the expert index chosen for each
    sample, and ``weight`` the stacked per-expert parameters
    (``weight.shape[0]`` experts).  ``gate`` receives no gradient.
    """

    @staticmethod
    def forward(ctx, inp, gate, weight):
        # Order samples by their expert id so that each expert's inputs
        # are contiguous in the work buffer.
        _, pos = torch.sort(gate)

        # Per-expert sample histogram; experts that receive no samples
        # keep a count of zero.
        active_experts, counts = torch.unique(gate, return_counts=True)
        expert_count = torch.zeros(
                weight.shape[0], device=weight.device, dtype=torch.long)
        expert_count.index_put_((active_experts.long(),), counts)
        counts_cpu = expert_count.cpu()

        # Permute inputs into expert order, run all experts, then map the
        # results back to sample order.
        # NOTE(review): the distributed path (MOEGlobal) un-permutes its
        # output with local_scatter; confirm local_gather is intended here.
        input_buf, = fmoe_cuda.local_gather(inp, pos)
        output_buf, = fmoe_cuda.forward(input_buf, weight, counts_cpu)
        output, = fmoe_cuda.local_gather(output_buf, pos)

        ctx.save_for_backward(input_buf, gate, weight, counts_cpu, pos)
        return output

    @staticmethod
    def backward(ctx, grad_out):
        input_buf, gate, weight, expert_count, pos = ctx.saved_tensors

        # Permute incoming gradients into expert order, backprop through
        # the experts, then restore the original sample order.
        grad_out_buf, = fmoe_cuda.local_scatter(grad_out.contiguous(), pos)
        grad_inp_buf, grad_weight = fmoe_cuda.backward(
                grad_out_buf, input_buf, weight, expert_count)
        grad_inp, = fmoe_cuda.local_gather(grad_inp_buf, pos)

        # gate is an integer routing tensor and gets no gradient.
        return grad_inp, None, grad_weight


class MOEGlobal(Function):
    """Autograd function for mixture-of-experts computation spread over
    ``world_size`` workers, moving samples between workers via the
    ``fmoe_cuda`` NCCL helpers.

    ``gate`` holds global expert indices in
    ``[0, world_size * num_expert)``; neither ``gate`` nor ``world_size``
    receives a gradient.
    """

    @staticmethod
    def forward(ctx, inp, gate, weight, world_size):
        # Hand the default process group to the extension so it can set up
        # (or reuse) its NCCL communicator.
        fmoe_cuda.ensure_nccl(
            torch.distributed.distributed_c10d._default_pg, inp)

        num_expert = weight.shape[0]

        # Order local samples by their global expert id.
        _, pos = torch.sort(gate)

        # Count local samples per global expert (world_size * num_expert
        # slots); experts receiving nothing keep a count of zero.
        active_experts, counts = torch.unique(gate, return_counts=True)
        local_expert_count = torch.zeros(
                weight.shape[0] * world_size,
                device=weight.device, dtype=torch.long)
        local_expert_count.index_put_((active_experts.long(),), counts)

        # Exchange per-expert counts with the other workers, then derive
        # the total workload of each local expert and the total number of
        # samples this worker must process.
        global_expert_count, = fmoe_cuda.expert_exchange(
                local_expert_count, num_expert, world_size)
        fwd_expert_count = global_expert_count.view(
                world_size, num_expert).sum(dim=0).cpu()
        fwd_batch_size = int(fwd_expert_count.sum().item())

        # Permute local inputs into expert order.
        local_input_buf, = fmoe_cuda.local_gather(inp, pos)

        # The fused forward call takes the count tensors on the host.
        local_expert_count = local_expert_count.cpu()
        global_expert_count = global_expert_count.cpu()

        # Scatter samples to their experts' workers, run the local experts,
        # and gather the results back, all in one fused call.
        local_output_buf, global_input_buf = fmoe_cuda.global_fused_forward(
                local_input_buf, weight,
                local_expert_count, global_expert_count,
                fwd_batch_size, inp.shape[0], world_size)

        # Restore the original local sample order.
        output, = fmoe_cuda.local_scatter(local_output_buf, pos)

        ctx.save_for_backward(global_input_buf, gate, weight,
                local_expert_count, global_expert_count, fwd_expert_count,
                pos)
        ctx.moe_args = (num_expert, inp.shape[0], fwd_batch_size, world_size)

        return output

    @staticmethod
    def backward(ctx, grad_out):
        (input_buf, gate, weight,
                local_expert_count, global_expert_count, fwd_expert_count,
                pos) = ctx.saved_tensors
        num_expert, local_batch_size, fwd_batch_size, world_size = ctx.moe_args

        # Permute incoming gradients into expert order, then route each
        # slice to the worker that ran the corresponding expert.
        grad_out_buf, = fmoe_cuda.local_scatter(grad_out.contiguous(), pos)
        global_grad_out_buf, = fmoe_cuda.global_scatter(grad_out_buf,
                local_expert_count, global_expert_count,
                fwd_batch_size, world_size)

        # Backprop through this worker's experts.
        grad_inp_buf, grad_weight = fmoe_cuda.backward(
                global_grad_out_buf, input_buf, weight, fwd_expert_count)

        # Send input gradients back to their source workers and restore
        # the original local sample order.
        local_grad_inp_buf, = fmoe_cuda.global_gather(grad_inp_buf,
                local_expert_count, global_expert_count,
                local_batch_size, world_size)
        grad_inp, = fmoe_cuda.local_gather(local_grad_inp_buf, pos)

        # No gradients for gate (integer routing) or world_size (python int).
        return grad_inp, None, grad_weight, None


def moe(inp, gate, weight, world_size):
    """Run a mixture-of-experts layer, dispatching to the distributed
    implementation when more than one worker participates.

    A ``world_size`` of ``None`` (or any value <= 1) selects the
    single-worker path, which ignores ``world_size`` entirely.
    """
    distributed = world_size is not None and world_size > 1
    if distributed:
        return MOEGlobal.apply(inp, gate, weight, world_size)
    return MOELocal.apply(inp, gate, weight)