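r"""
test_mimo.py: numerical tests for FMoE's multiple-input / multiple-output
(MIMO) support, where a single FMoE layer routes a dict or a list of tensors
through shared experts and the two outputs are compared against each other.
"""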
import sys

import pytest
import torch
import torch.nn as nn
import numpy as np

from fmoe.gates import NaiveGate
from fmoe.layers import FMoE
from fmoe.linear import FMoELinear
from fmoe.megatron.layers import _megatron_init_method


def _assert_numerical(names, moe_out_list, raw_out_list, rank, precision=1e-3):
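    r"""
    Assert that each pair of tensors in ``moe_out_list`` and ``raw_out_list``
    agrees within ``precision``, dumping both tensors and their difference to
    stderr on failure.
    """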
    for name, mo, ro in zip(names, moe_out_list, raw_out_list):
        err = (mo - ro).abs().max()
        print("Rank {} {} abs err {}".format(rank, name, err))
        if err > precision:
            sys.stderr.write(f"=========== {name} moe out ==============\n")
            sys.stderr.write("{}\n".format(mo))
            sys.stderr.write(f"=========== {name} raw out ==============\n")
            sys.stderr.write("{}\n".format(ro))
            sys.stderr.write(f"=========== {name} diff ==============\n")
            sys.stderr.write("{}\n{}\n".format((mo - ro).abs(), err))
            assert False


class MyExpert(nn.Module):
    r"""
    An expert using 2 FMoELinear modules to speed up the computation of experts
    within one worker.
    """

    def __init__(self, num_expert, d_model, d_hidden, activation, rank=0):
        super().__init__()
        self.htoh4 = FMoELinear(num_expert, d_model, d_hidden, bias=True, rank=rank)
        self.h4toh = FMoELinear(num_expert, d_hidden, d_model, bias=True, rank=rank)
        self.activation = activation

    def forward(self, inp, fwd_expert_count):
        r"""
        First expand input to 4h (the hidden size is variable, but is called h4
        for convenience). Then perform activation. Finally shrink back to h.
        """
        if isinstance(inp, dict):
            x = inp["x"]
            y = inp["y"]
        elif isinstance(inp, list):
            x = inp[0]
            y = inp[1]
        else:
            raise NotImplementedError
        x = self.htoh4(x, fwd_expert_count)
        x = self.activation(x)
        x = self.h4toh(x, fwd_expert_count)
        y = self.htoh4(y, fwd_expert_count)
        y = self.activation(y)
        y = self.h4toh(y, fwd_expert_count)
        if isinstance(inp, dict):
            ret = {"x": x, "y": y}
        elif isinstance(inp, list):
            ret = [x, y]

        return ret


class MyGate(NaiveGate):
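    r"""
    A NaiveGate wrapper that computes expert scores from only the ``x``
    component of the dict / list input, so both components are dispatched to
    the same experts.
    """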
    def __init__(self, d_model, num_expert, world_size, top_k=2):
        super().__init__(d_model, num_expert, world_size, top_k)

    def forward(self, inp, return_all_scores=False):
        if isinstance(inp, dict):
            x = inp["x"]
        elif isinstance(inp, list):
            x = inp[0]
        else:
            raise NotImplementedError
        return super().forward(x, return_all_scores)


class MyMoE(FMoE):
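    r"""
    An FMoE layer wired up with MyExpert and MyGate, with expert weights
    initialized via Megatron's init method under a fixed RNG seed so the test
    is reproducible.
    """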
    def __init__(
        self, num_expert, d_model, d_hidden, world_size, mp_group, top_k, activation
    ):
        super().__init__(
            num_expert=num_expert,
            d_model=d_model,
            gate=MyGate,
            world_size=world_size,
            mp_group=mp_group,
            top_k=top_k,
        )
        self.experts = MyExpert(num_expert, d_model, d_hidden, activation)

        rng = np.random.default_rng(1234)
        _megatron_init_method(self.experts.htoh4, rng, 1.0)
        _megatron_init_method(self.experts.h4toh, rng, 1.0)


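# Sweep expert count, top-k, tensor shapes, dtypes, and dict-vs-list inputs;
# all cases run on a single worker (world_size=1, rank=0) with no process
# groups.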
@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("top_k", [2, 3])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("d_hidden", [32])
@pytest.mark.parametrize("rank", [0])
@pytest.mark.parametrize("world_size", [1])
@pytest.mark.parametrize("mp_group", [None])
@pytest.mark.parametrize("dp_group", [None])
@pytest.mark.parametrize("world_group", [None])
@pytest.mark.parametrize(
    "data_type", ["torch.FloatTensor", "torch.DoubleTensor", "torch.HalfTensor"]
)
@pytest.mark.parametrize("list_input", [False, True])
def test_fmoe_mimo_linear(
    num_expert,
    top_k,
    batch_size,
    d_model,
    d_hidden,
    rank,
    world_size,
    mp_group,
    dp_group,
    world_group,
    data_type,
    list_input,
    activation=torch.nn.functional.gelu,
):

    torch.manual_seed(42 + rank)
    torch.cuda.manual_seed(42 + rank)

    moe = MyMoE(
        num_expert=num_expert,
        d_model=d_model,
        d_hidden=d_hidden,
        world_size=world_size,
        mp_group=mp_group,
        top_k=top_k,
        activation=activation,
    ).cuda()

    x = torch.rand(batch_size, d_model).cuda()
    inp = [x, x.clone()] if list_input else {"x": x, "y": x.clone()}
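    # Both components of the MIMO input are identical copies of x, share the
    # same routing (MyGate looks only at "x") and the same expert weights, so
    # the two outputs must agree numerically.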
    moe_out = moe(inp)

    if list_input:
        _assert_numerical(["x"], [moe_out[0]], [moe_out[1]], rank)
    else:
        _assert_numerical(["x"], [moe_out["x"]], [moe_out["y"]], rank)


if __name__ == "__main__":
    test_fmoe_mimo_linear(
        batch_size=2,
        num_expert=2,
        d_model=2,
        top_k=2,
        d_hidden=16,
        rank=0,
        world_size=1,
        mp_group=None,
        dp_group=None,
        world_group=None,
        data_type=torch.float32,
        list_input=True
    )
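
# The parametrized sweep can be run with pytest; a CUDA-capable GPU and an
# installed fastmoe build are required, since the layer and inputs are moved
# to .cuda():
#
#     pytest test_mimo.py -v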