r"""
The smart schedule proposed in FasterMoE.
"""
import torch
from torch.autograd.function import Function

from fmoe.functions import prepare_forward, ensure_comm
from fmoe.functions import _local_scatter, _local_gather 
import fmoe_cuda as fmoe_native
from fmoe.fastermoe import expert_utils

from .shadow_policy import get_shadow_policy


class MoEForward(Function):
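    r"""
    Autograd function that runs the FasterMoE smart schedule: the local input
    is scattered into a contiguous buffer, the CUDA extension overlaps the
    all-to-all token exchange with expert computation (optionally shadowing
    heavily-loaded experts by broadcasting their parameters), and the output
    is gathered back into the original token order.  Expert inputs and
    outputs are recorded so that backward can replay autograd through them.
    """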
    @staticmethod
    def forward(
            ctx,
            expert_fn,
            experts,
            inp, # models,
            pos_s, pos_g,
            local_expert_count, global_expert_count,
            stored_models,
            fwd_batch_size, out_batch_size,
            num_expert,
            world_size):
        local_input_buf = _local_scatter(inp, pos_s)

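        # Saved expert inputs / outputs, indexed by store_idx; the 2x sizing
        # leaves room for shadowed expert runs in addition to the regular
        # local ones.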
        ctx.gibs = [None] * (world_size * num_expert * 2)
        ctx.gobs = [None] * (world_size * num_expert * 2)
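        # Called back by the C++ smart scheduler once per expert micro-batch:
        # run the expert on `x`, remember the input / output tensors for
        # backward, and copy the result into the pre-allocated buffer `y`.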
        def _expert_forward(x, y, expert_idx, store_idx):
            nothing = lambda a: a
            x = x.data
            with torch.enable_grad():
                x.requires_grad = True
                try:
                    # To skip torch autograd's version check.
                    with torch.autograd.graph.saved_tensors_hooks(nothing, nothing):
                        y0 = expert_fn(x, torch.tensor([x.shape[0]], dtype=torch.int64), expert_idx)
                except Exception:
                    # Ignore the error and fall back for compatibility with
                    # older versions of PyTorch.
                    y0 = expert_fn(x, torch.tensor([x.shape[0]], dtype=torch.int64), expert_idx)
            ctx.gibs[store_idx] = x
            ctx.gobs[store_idx] = y0
            y.copy_(y0)

        ctx.experts = experts
        if stored_models.any():
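            # Shadowed experts are shipped between workers as one flat
            # parameter buffer, so every local expert is expected to expose
            # the same flattened parameter size.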
            ctx.expert_size = expert_utils.get_expert_param_size(experts, 0)
            for i in range(num_expert):
                assert ctx.expert_size == expert_utils.get_expert_param_size(experts, i), \
                        "shadowed experts must share the same parameter size"
        else:
            ctx.expert_size = 0
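        # Callbacks used by the scheduler for shadowing: get_param_fn copies
        # an expert's flattened parameters into a buffer for broadcasting,
        # stash_fn temporarily installs received parameters into a local
        # expert (remembering them in ctx.shadows for backward), and pop_fn
        # restores the original parameters afterwards.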
        get_param_fn = lambda out, idx: expert_utils.get_expert_params(experts, out, idx)
        pop_fn = lambda idx: expert_utils.pop_expert_params(experts, idx)
        ctx.shadows = [None] * world_size * num_expert
        def stash_fn(params, store_idx, expert_idx):
            expert_utils.stash_expert_params(experts, params, expert_idx)
            ctx.shadows[store_idx] = params

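        # The native kernel overlaps all-to-all communication with expert
        # computation, calling back into the Python closures above to run the
        # experts and to manage shadowed parameters.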
        local_output_buf, gib = fmoe_native.smart_sch_forward(
                local_input_buf,
                local_expert_count, global_expert_count, 
                stored_models, fwd_batch_size, ctx.expert_size,
                world_size, _expert_forward, get_param_fn, stash_fn, pop_fn)

        out = _local_gather(local_output_buf, pos_g, out_batch_size,
                maybe_overlap=False)
        
        # gib and local_input_buf must be kept alive, because the tensors in
        # ctx.gibs are created on top of their memory.
        variables = (pos_s, pos_g, local_expert_count, global_expert_count,
                stored_models, gib, local_input_buf)
        
        ctx.moe_args = fwd_batch_size, inp.shape[0], num_expert, world_size
        ctx.save_for_backward(*variables)

        return out

    @staticmethod
    def backward(ctx, grad_out):
        (pos_s, pos_g, local_expert_count, global_expert_count,
                stored_models, _1, _2) = ctx.saved_tensors
        (fwd_batch_size, inp_batch_size, num_expert, world_size) = ctx.moe_args

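        # Replays autograd through the expert output recorded during forward
        # and writes the resulting input gradient into the scheduler's buffer.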
        def _expert_backward(grad_y, grad_x, expert_idx, store_idx):
            y = ctx.gobs[store_idx]
            x = ctx.gibs[store_idx]
            torch.autograd.backward([y], [grad_y])
            grad_x.copy_(x.grad)

        experts = ctx.experts
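        # Shadowing callbacks for backward: re-install the shadowed
        # parameters, collect their gradients, reduce them onto the owning
        # worker (`root`), and assign the reduced gradients back to the local
        # experts.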
        def stash_fn(store_idx, expert_idx):
            expert_utils.stash_expert_params(experts, ctx.shadows[store_idx], expert_idx)
        pop_fn = lambda idx: expert_utils.pop_expert_params(experts, idx)
        def collect_fn(store_idx, root, expert_idx): 
            grad = ctx.shadows[store_idx]
            expert_utils.collect_expert_grads(experts, grad, expert_idx)
            fmoe_native.reduce_grad(grad, root, ctx.expert_size)
        set_grad_fn = lambda store_idx, expert_idx: expert_utils.set_grads(experts, ctx.shadows[store_idx], expert_idx)

        grad_out_buf = _local_scatter(grad_out.contiguous(), pos_g)
        grad_in_buf = fmoe_native.smart_sch_backward(
                grad_out_buf,
                local_expert_count, global_expert_count,
                stored_models,
                pos_s.shape[0], fwd_batch_size,
                world_size,
                _expert_backward, stash_fn, pop_fn, collect_fn, set_grad_fn)
        grad_in = _local_gather(grad_in_buf, pos_s, inp_batch_size)

        return (None, None, grad_in, None, None, None, None, None, None, None, None, None)


policy_fn = None


def _fmoe_general_global_forward(inp, gate, expert_fn, n_expert, world_size, experts=None, stored_models=None):
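    r"""
    Entry point of the FasterMoE smart schedule: count how many tokens go to
    each expert, decide which experts to shadow according to the shadow
    policy, and dispatch the actual computation to ``MoEForward``.
    """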
    # TODO: Support using multiple tensors as input.
    assert isinstance(inp, torch.Tensor)
    (
        pos,
        local_expert_count,
        global_expert_count,
        fwd_expert_count,
        fwd_batch_size,
    ) = prepare_forward(gate, n_expert, world_size)

    global policy_fn
    if policy_fn is None:
        policy_fn = get_shadow_policy(d_model=inp.shape[-1])

    if stored_models is None:
        stored_models = policy_fn(local_expert_count, global_expert_count,
                n_expert, world_size)

    topk = 1
    if len(gate.shape) == 2:
        topk = gate.shape[1]
    out_batch_size = inp.shape[0] * topk

    return MoEForward.apply(expert_fn, experts, inp,
            torch.div(pos, topk, rounding_mode='floor'), pos,
            local_expert_count, global_expert_count, stored_models,
            fwd_batch_size, out_batch_size, n_expert, world_size)
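
# Minimal usage sketch (illustrative only; ``moe_group``, ``gate_idx`` and the
# three-argument ``expert_fn`` come from the surrounding FMoE layer and are
# assumptions, not part of this module):
#
#     ensure_comm(inp, moe_group)
#     out = _fmoe_general_global_forward(
#             inp, gate_idx, expert_fn, n_expert, world_size, experts=experts)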