r"""
The smart schedule proposed in FasterMoE.
"""
import torch
from torch.autograd.function import Function

from fmoe.functions import prepare_forward, ensure_comm
from fmoe.functions import _local_scatter, _local_gather
import fmoe_cuda as fmoe_native


class MoEForward(Function):
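    r"""
    Autograd function implementing the FasterMoE smart schedule: input
    tokens are scattered into a communication buffer, exchanged between
    workers and fed to the local expert by the native scheduler, and the
    results are gathered back into the original order.  Backward
    backpropagates through the expert graphs recorded during forward.
    """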
    @staticmethod
    def forward(
            ctx,
            expert_fn,
            inp, # models,
            pos_s, pos_g,
            local_expert_count, global_expert_count,
            stored_models,
            fwd_batch_size, out_batch_size,
            world_size):
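        # Permute the input rows into communication-buffer order so that
        # tokens routed to the same expert are contiguous.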
        local_input_buf = _local_scatter(inp, pos_s)

        # TODO: leave this for future work on expert shadowing
        # model_params = [[tuple(m.parameters()) for m in node] for node in models]

        ctx.gibs = [None] * world_size
        ctx.gobs = [None] * world_size
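        # Callback invoked by the native scheduler for each partition of the
        # received tokens.  The input is detached (via .data) and re-marked
        # as requiring grad; the detached input and the expert output are
        # stashed on ctx so that backward can backpropagate through them.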
        def _expert_forward(x, y, idx):
            x = x.data
            x.requires_grad = True
            y0 = expert_fn(x, [x.shape[0]])
            ctx.gibs[idx] = x
            ctx.gobs[idx] = y0
            y.copy_(y0)

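        # The native smart schedule overlaps the all-to-all token exchange
        # with the expert computation performed in _expert_forward.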
        local_output_buf = fmoe_native.smart_sch_forward(
                local_input_buf,
                local_expert_count, global_expert_count, 
                stored_models, fwd_batch_size,
                world_size, _expert_forward)

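        # Restore the original (top-k expanded) ordering of the outputs.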
        out = _local_gather(local_output_buf, pos_g, out_batch_size,
                maybe_overlap=False)
        
        variables = (pos_s, pos_g, local_expert_count, global_expert_count,
                stored_models)
        
        ctx.moe_args = fwd_batch_size, inp.shape[0], world_size
        ctx.save_for_backward(*variables)

        return out

    @staticmethod
    def backward(ctx, grad_out):
        (pos_s, pos_g, local_expert_count, global_expert_count,
                stored_models) = ctx.saved_tensors
        (fwd_batch_size, inp_batch_size, world_size) = ctx.moe_args

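        # Backpropagate through the expert graph recorded for partition idx
        # and return the gradient w.r.t. the detached expert input.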
        def _expert_backward(grad, idx):
            y = ctx.gobs[idx]
            torch.autograd.backward([y], [grad])
            x = ctx.gibs[idx]
            return x.grad

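        # Scatter the incoming gradients into communication-buffer order
        # before handing them to the native backward schedule.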
        grad_out_buf = _local_scatter(grad_out.contiguous(), pos_g)
        # Assumption: the native backward entry point mirrors smart_sch_forward,
        # taking the gradient buffer and the routing metadata; per-expert
        # gradients are produced by the _expert_backward callback above.
        grad_in_buf = fmoe_native.smart_sch_backward(
                grad_out_buf,
                local_expert_count, global_expert_count,
                stored_models,
                fwd_batch_size, pos_s.shape[0],
                world_size, _expert_backward)
        grad_in = _local_gather(grad_in_buf, pos_s, inp_batch_size)

        return (None, grad_in, None, None, None, None, None, None, None, None)


def _fmoe_general_global_forward(inp, gate, expert_fn, n_expert, world_size):
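    r"""
    Run a single-tensor MoE forward pass using the FasterMoE smart schedule.

    `gate` holds the expert index chosen for each token (optionally with a
    top-k dimension), `expert_fn` applies the local expert to a batch of
    tokens, and `world_size` is the number of workers taking part in the
    all-to-all exchange.  Returns the expert outputs with `topk` rows per
    input row.
    """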
    # TODO: support multiple input tensors.
    assert isinstance(inp, torch.Tensor)
    # TODO: support more than one expert per process.
    assert n_expert == 1
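    # prepare_forward counts tokens per expert, exchanges the counts between
    # workers, and computes the permutation (pos) used to build the
    # communication buffer.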
    (
        pos,
        local_expert_count,
        global_expert_count,
        fwd_expert_count,
        fwd_batch_size,
    ) = prepare_forward(gate, n_expert, world_size)

    # TODO: support expert shadowing; for now no expert is shadowed (all zeros)
    stored_models = torch.zeros(n_expert * world_size, dtype=torch.bool)

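    # Infer top-k from the gate shape: a 2-D gate stores k expert choices per
    # token, so the gathered output holds k rows per input row.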
    topk = 1
    if len(gate.shape) == 2:
        topk = gate.shape[1]
    out_batch_size = inp.shape[0] * topk

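    # pos indexes rows of the top-k expanded output buffer, while pos // topk
    # maps each buffer slot back to its source input row for the scatter.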
    return MoEForward.apply(expert_fn, inp,
            torch.div(pos, topk, rounding_mode='floor'), pos,
            local_expert_count, global_expert_count, stored_models,
            fwd_batch_size, out_batch_size, world_size)