r"""
The smart schedule proposed in FasterMoE.
"""
import torch
from torch.autograd.function import Function

from fmoe.functions import prepare_forward, ensure_comm
from fmoe.functions import _local_scatter, _local_gather 
import fmoe_cuda as fmoe_native
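
# For intuition: `_local_scatter(inp, pos)` reorders the rows of `inp` into
# expert-major buffer order (roughly `inp.index_select(0, pos)`), and
# `_local_gather(buf, pos, n)` is its inverse, writing the rows of `buf` back
# to their original positions in a fresh `n`-row buffer.  This sketches the
# observed semantics, not the exact implementation.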


class MoEForward(Function):
    @staticmethod
    def forward(
            ctx,
            expert_fn,
            inp,  # a `models` argument is reserved here for expert shadowing
            pos_s, pos_g,
            local_expert_count, global_expert_count,
            stored_models,
            fwd_batch_size, out_batch_size,
            world_size):
        local_input_buf = _local_scatter(inp, pos_s)

        # TODO: leave this for future work on expert shadowing
        # model_params = [[tuple(m.parameters()) for m in node] for node in models]

        # Inputs and outputs of each shard's expert call are stashed on ctx
        # so that backward can replay the recorded graphs shard by shard.
        ctx.gibs = [None] * world_size
        ctx.gobs = [None] * world_size
        def _expert_forward(x, y, idx):
            # Called by the native scheduler outside of autograd tracking:
            # detach the buffer, then re-enable grad so that a graph is
            # recorded for this shard and can be replayed in backward.
            x = x.data
            with torch.enable_grad():
                x.requires_grad = True
                y0 = expert_fn(x, [x.shape[0]])
            ctx.gibs[idx] = x
            ctx.gobs[idx] = y0
            y.copy_(y0)

        # The native smart-schedule kernel pipelines the all-to-all token
        # exchange with expert computation, invoking `_expert_forward` on
        # each shard of received tokens as it becomes available.
        local_output_buf, gib = fmoe_native.smart_sch_forward(
                local_input_buf,
                local_expert_count, global_expert_count, 
                stored_models, fwd_batch_size,
                world_size, _expert_forward)
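        # `gib` (the buffer of received tokens) is saved below mainly to keep
        # it alive until backward; backward itself unpacks and ignores it.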

        out = _local_gather(local_output_buf, pos_g, out_batch_size,
                maybe_overlap=False)
        
        variables = (pos_s, pos_g, local_expert_count, global_expert_count,
                stored_models, gib)
        
        ctx.moe_args = fwd_batch_size, inp.shape[0], world_size
        ctx.save_for_backward(*variables)

        return out

    @staticmethod
    def backward(ctx, grad_out):
        (pos_s, pos_g, local_expert_count, global_expert_count,
                stored_models, _) = ctx.saved_tensors
        (fwd_batch_size, inp_batch_size, world_size) = ctx.moe_args

        def _expert_backward(grad_y, grad_x, idx):
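            # Replay the autograd graph recorded for shard `idx` during the
            # forward pass, then hand the input gradient back to the native
            # scheduler's buffer.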
            y = ctx.gobs[idx]
            torch.autograd.backward([y], [grad_y])
            x = ctx.gibs[idx]
            grad_x.copy_(x.grad)

        # Mirror of the forward pass: scatter the incoming gradient into
        # buffer order, let the native scheduler overlap the reverse
        # all-to-all with the per-shard expert backward, then gather the
        # input gradient back into the original token order.
        grad_out_buf = _local_scatter(grad_out.contiguous(), pos_g)
        grad_in_buf = fmoe_native.smart_sch_backward(
                grad_out_buf,
                local_expert_count, global_expert_count,
                stored_models,
                pos_s.shape[0], fwd_batch_size,
                world_size, _expert_backward)
        grad_in = _local_gather(grad_in_buf, pos_s, inp_batch_size)

        # One gradient per argument of `forward`; only `inp` gets a gradient.
        return (None, grad_in, None, None, None, None, None, None, None, None)

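
def _reentrant_autograd_sketch():
    r"""Illustration only, not used by fmoe: the re-entrant autograd pattern
    that ``_expert_forward`` and ``_expert_backward`` above rely on, shown on
    a plain ``Linear`` standing in for the expert."""
    expert = torch.nn.Linear(4, 4)
    # The scheduler hands the callback a detached buffer; turn it into a
    # leaf that records a graph even though the caller holds no grad context.
    x = torch.randn(2, 4).data
    with torch.enable_grad():
        x.requires_grad = True
        y = expert(x)
    # Later, when the incoming gradient arrives, replay the recorded graph.
    torch.autograd.backward([y], [torch.ones_like(y)])
    return x.grad  # in the real code this is copied into the native buffer
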

def _fmoe_general_global_forward(inp, gate, expert_fn, n_expert, world_size):
    # TODO: support multiple input tensors
    assert isinstance(inp, torch.Tensor)
    # TODO: support more than one expert per process
    assert n_expert == 1
    (
        pos,
        local_expert_count,
        global_expert_count,
        fwd_expert_count,
        fwd_batch_size,
    ) = prepare_forward(gate, n_expert, world_size)
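    # `pos` orders token slots by destination expert; the count tensors give
    # the number of tokens exchanged with each worker, and `fwd_batch_size`
    # is how many tokens this worker's expert processes after the exchange.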

    # TODO: expert shadowing is to be supported; for now the mask is all
    # False, i.e. no expert is shadowed on any worker
    stored_models = torch.zeros(n_expert * world_size, dtype=torch.bool)

    # The gate is either (batch,) with one expert index per token, or
    # (batch, topk) with `topk` expert indices per token.
    topk = 1
    if len(gate.shape) == 2:
        topk = gate.shape[1]
    out_batch_size = inp.shape[0] * topk

    # `pos` indexes the topk-expanded token array, so dividing it by `topk`
    # maps each slot back to its original input row for the scatter, while
    # the undivided `pos` addresses the expanded output for the gather.
    return MoEForward.apply(expert_fn, inp,
            torch.div(pos, topk, rounding_mode='floor'), pos,
            local_expert_count, global_expert_count, stored_models,
            fwd_batch_size, out_batch_size, world_size)
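
# Usage sketch (illustrative; assumes an initialized fmoe communication group
# and the fmoe_cuda extension):
#
#     inp  : (batch, d_model) tensor holding this worker's tokens
#     gate : (batch, topk) int64 tensor of expert indices from the gate
#     out = _fmoe_general_global_forward(inp, gate, expert_fn,
#                                        n_expert=1, world_size=world_size)
#     out  : (batch * topk, d_model), one row per (token, expert-choice) pair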