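"""Autograd functions wrapping the fmoe_cuda kernels.

These primitives implement the scatter -> expert-parallel linear ->
gather pipeline of a distributed mixture-of-experts (MoE) layer.
"""
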
import torch
from torch.autograd import Function
import fmoe_cuda


def moe_prepare_forward(gate, num_expert, world_size, comm=None):
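    """Sort tokens by expert and exchange per-expert counts across workers.

    Returns the sort permutation `pos`, the local and global per-expert
    token counts (moved to CPU), the number of tokens each local expert
    will process, and the total forward batch size on this worker.
    """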
    if comm is None:
        comm = torch.distributed.distributed_c10d._default_pg
    if world_size > 1:
        fmoe_cuda.ensure_nccl(comm, gate)

    with torch.no_grad():
        # Sort tokens by their expert index and count tokens per expert.
        _, pos = torch.sort(gate)
        gate_idx, gate_count = torch.unique(gate, return_counts=True)
        # Dense per-expert counts, indexed by global expert id.
        local_expert_count = torch.zeros(num_expert * world_size,
                device=gate.device, dtype=torch.long)
        local_expert_count.index_put_((gate_idx.long(), ), gate_count)

        # Exchange counts so this worker learns how many tokens it will
        # receive for each of its local experts.
        if world_size > 1:
            global_expert_count, = fmoe_cuda.expert_exchange(
                    local_expert_count, num_expert, world_size)
        else:
            global_expert_count = local_expert_count
        # Total number of tokens each local expert processes this step.
        fwd_expert_count = global_expert_count.view(world_size,
                num_expert).sum(dim=0)
        fwd_batch_size = int(fwd_expert_count.sum().item())
    return (pos, local_expert_count.cpu(), global_expert_count.cpu(), 
            fwd_expert_count.cpu(), fwd_batch_size)


class MOEScatter(Function):
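    """Scatter input samples to the experts that process them.

    Samples are first permuted locally by `pos`; when world_size > 1 they
    are then exchanged across workers so that every expert receives all
    of the tokens routed to it. The backward pass runs the reverse
    gather on the incoming gradient.
    """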
    @staticmethod
    def forward(ctx, inp, pos, local_expert_count, global_expert_count,
            fwd_batch_size, world_size):
        local_input_buf, = fmoe_cuda.local_scatter(inp, pos)
        if world_size > 1:
            global_input_buf, = fmoe_cuda.global_scatter(local_input_buf,
                    local_expert_count, global_expert_count,
                    fwd_batch_size, world_size)
        else:
            global_input_buf = local_input_buf
        ctx.moe_args = fwd_batch_size, inp.shape[0], world_size
        variables = (pos, local_expert_count, global_expert_count)
        ctx.save_for_backward(*variables)
        return global_input_buf

    @staticmethod
    def backward(ctx, global_grad_in):
        (pos, local_expert_count, global_expert_count) = ctx.saved_tensors
        (fwd_batch_size, local_batch_size, world_size) = ctx.moe_args

        if world_size > 1:
            local_grad_in, = fmoe_cuda.global_gather(global_grad_in,
                    local_expert_count, global_expert_count,
                    local_batch_size, world_size)
        else:
            local_grad_in = global_grad_in
        grad_in, = fmoe_cuda.local_gather(local_grad_in, pos)
        return grad_in, None, None, None, None, None


class MOELinear(Function):
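    """Batched per-expert linear transformation.

    `fwd_expert_count` delimits the contiguous slice of the input buffer
    that belongs to each expert.
    """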
    @staticmethod
    def forward(ctx, global_input_buf, weight, fwd_expert_count):
        global_output_buf, = fmoe_cuda.forward(global_input_buf, weight,
                fwd_expert_count)
        variables = (global_input_buf, weight, fwd_expert_count)
        ctx.save_for_backward(*variables)
        return global_output_buf

    @staticmethod
    def backward(ctx, grad_out):
        (input_buf, weight, fwd_expert_count) = ctx.saved_tensors
        grad_inp_buf, grad_weight = fmoe_cuda.backward(
                grad_out, input_buf, weight, fwd_expert_count)
        return grad_inp_buf, grad_weight, None


class MOEGather(Function):
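    """Gather expert outputs back into the original sample order.

    The inverse of MOEScatter: when world_size > 1, outputs are first
    exchanged back to the workers that own the samples, then un-permuted
    by `pos`.
    """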
    @staticmethod
    def forward(ctx, global_output_buf, pos, local_expert_count, 
            global_expert_count, local_batch_size, world_size):
        if world_size > 1:
            local_output_buf, = fmoe_cuda.global_gather(global_output_buf,
                    local_expert_count, global_expert_count,
                    local_batch_size, world_size)
        else:
            local_output_buf = global_output_buf
        output, = fmoe_cuda.local_gather(local_output_buf, pos)

        ctx.moe_args = local_batch_size, global_output_buf.shape[0], world_size
        variables = (pos, local_expert_count, global_expert_count)
        ctx.save_for_backward(*variables)
        return output

    @staticmethod
    def backward(ctx, grad_out):
        pos, local_expert_count, global_expert_count = ctx.saved_tensors
        local_batch_size, fwd_batch_size, world_size = ctx.moe_args
        grad_out_buf, = fmoe_cuda.local_scatter(grad_out.contiguous(), pos)
        if world_size > 1:
            global_grad_out_buf, = fmoe_cuda.global_scatter(grad_out_buf,
                    local_expert_count, global_expert_count,
                    fwd_batch_size, world_size)
        else:
            global_grad_out_buf = grad_out_buf
        return global_grad_out_buf, None, None, None, None, None
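

# A minimal usage sketch of how these pieces compose into one MoE layer,
# assuming hypothetical tensors `inp` (token features), `gate` (the expert
# index chosen for each token), and a stacked per-expert `weight`:
#
#   pos, lec, gec, fec, fwd_bsz = moe_prepare_forward(
#           gate, num_expert, world_size)
#   x = MOEScatter.apply(inp, pos, lec, gec, fwd_bsz, world_size)
#   x = MOELinear.apply(x, weight, fec)
#   out = MOEGather.apply(x, pos, lec, gec, inp.shape[0], world_size)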