test_numerical.py 8.73 KB
Newer Older
1
2
3
4
import json
import os
import sys
from typing import List, Callable, Dict, Type, Union
Sengxian's avatar
Sengxian committed
5

6
import pytest
Rick Ho's avatar
Rick Ho committed
7
import torch
8
9
10
11
12
13
import torch.nn as nn

from fmoe.gates import NaiveGate
from fmoe.layers import FMoE
from fmoe.transformer import _Expert
from moe import BruteForceMoELinear, BruteForceMoE, NaiveExpert, LinearExpert
Rick Ho's avatar
Rick Ho committed
14

Sengxian's avatar
Sengxian committed
15
16
# Process rank and world size for distributed runs.  These module globals
# default to single-process values and are overwritten in the __main__
# section below once torch.distributed is initialized.
rank = 0
world_size = 1
Rick Ho's avatar
Rick Ho committed
17
18


19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
def _perform_forward(moe: nn.Module, moe_raw: nn.Module, batch_size, d_model, top_k):
    """Run a forward + backward pass on both the fused and brute-force MoE.

    A single random batch is fed through ``moe``; the same batch, with each
    row repeated ``top_k`` times, is fed through ``moe_raw`` together with
    the gate outputs so both modules route tokens identically.

    Returns the scalar mean outputs of both modules (gradients are left on
    the parameters for the caller to inspect).
    """
    for module in (moe, moe_raw):
        module.zero_grad()

    x = torch.rand(batch_size, d_model).cuda()
    gate_idx, gate_score = moe.gate(x)
    x_rep = x.repeat_interleave(repeats=top_k, dim=0)

    out_fused = moe(x).mean()
    out_raw = moe_raw(x_rep, gate_idx, gate_score).mean()
    out_fused.backward()
    out_raw.backward()

    return out_fused, out_raw


def _assert_numercial(names, moe_out_list, raw_out_list):
    """Compare paired tensors and fail loudly when they diverge.

    For each ``(name, fused, raw)`` triple the summed absolute error is
    printed; if it exceeds 1e-3 both tensors are dumped to stderr and the
    test is failed.
    """
    for name, fused, raw in zip(names, moe_out_list, raw_out_list):
        err = (fused - raw).abs().sum()
        print("Rank {} {} abs err {}".format(rank, name, err))
        if not err > 1e-3:
            continue
        sys.stderr.write("=========== moe out ==============\n")
        sys.stderr.write("{}\n".format(fused))
        sys.stderr.write("=========== raw out ==============\n")
        sys.stderr.write("{}\n".format(raw))
        assert False


@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("top_k", [2, 3])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("d_hidden", [32])
def test_fmoe_linear(
    num_expert,
    top_k,
    batch_size,
    d_model,
    d_hidden,
    activation=torch.nn.functional.gelu,
):
    """Check the fused FMoE linear layer against a brute-force reference.

    Both modules are initialized with identical weights (gathered across
    workers in the distributed case), run on the same input, and their
    outputs and weight gradients are compared numerically.
    """
    torch.manual_seed(42 + rank)
    torch.cuda.manual_seed(42 + rank)

    experts = _Expert(num_expert, d_model, d_hidden, activation).cuda()

    def expert_fn(inp, gate):
        return experts(inp, gate)

    moe = FMoE(
        num_expert=num_expert,
        d_model=d_model,
        gate=NaiveGate,
        world_size=world_size,
        mp_group=None,
        expert_fn=expert_fn,
        top_k=top_k,
    ).cuda()

    moe_raw = BruteForceMoELinear(
        activation=activation,
        num_expert=num_expert,
        d_model=d_model,
        d_hidden=d_hidden,
        world_size=world_size,
        top_k=top_k,
    ).cuda()

    if world_size == 1:
        # Single process: the reference simply copies the expert weights.
        moe_raw.weight_htoh4.data = experts.htoh4.weight.data.clone()
        moe_raw.weight_h4toh.data = experts.h4toh.weight.data.clone()
    else:
        # Distributed: the reference holds the experts of *all* workers,
        # so gather every worker's weights and concatenate them.
        def _gathered(weight):
            shards = [torch.empty_like(weight) for _ in range(world_size)]
            torch.distributed.all_gather(shards, weight)
            return torch.cat(shards, dim=0)

        moe_raw.weight_htoh4.data = _gathered(experts.htoh4.weight.data)
        moe_raw.weight_h4toh.data = _gathered(experts.h4toh.weight.data)

    moe_out, raw_out = _perform_forward(moe, moe_raw, batch_size, d_model, top_k)

    moe_out_list = moe_out, experts.htoh4.weight.grad, experts.h4toh.weight.grad
    raw_out_list = raw_out, moe_raw.weight_htoh4.grad, moe_raw.weight_h4toh.grad

    if world_size > 1:
        # The reference accumulates gradients for every expert globally;
        # reduce them across workers and keep only this worker's slice.
        _, htoh4_grad, h4toh_grad = raw_out_list
        torch.distributed.all_reduce(htoh4_grad)
        torch.distributed.all_reduce(h4toh_grad)
        local = slice(rank * num_expert, (rank + 1) * num_expert)
        raw_out_list = _, htoh4_grad[local], h4toh_grad[local]

    names = ["output", "htoh4 weight grad", "h4toh weight grad"]
    _assert_numercial(names, moe_out_list, raw_out_list)

Sengxian's avatar
Sengxian committed
118

119
120
121
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("top_k", [2, 3])
@pytest.mark.parametrize("expert", ["NaiveExpert", "LinearExpert"])
def test_fmoe(
    batch_size, num_expert, d_model, top_k, expert: Union[Type[nn.Module], str]
):
    """Check FMoE with pluggable expert modules against a brute-force MoE.

    Expert parameters are mirrored into the reference implementation
    (gathered across workers when distributed), then the outputs and the
    summed per-expert gradients of both modules are compared.
    """
    torch.manual_seed(42 + rank)
    torch.cuda.manual_seed(42 + rank)

    # The distributed launcher passes the expert class by name.
    if isinstance(expert, str):
        expert = globals()[expert]

    moe = FMoE(
        num_expert=num_expert,
        d_model=d_model,
        gate=NaiveGate,
        world_size=world_size,
        mp_group=None,
        expert=expert,
        top_k=top_k,
    ).cuda()

    moe_raw = BruteForceMoE(
        expert=expert,
        num_expert=num_expert,
        d_model=d_model,
        world_size=world_size,
        top_k=top_k,
    ).cuda()

    if world_size == 1:
        # Mirror each local expert's parameters into the reference.
        for exp_fused, exp_raw in zip(moe.experts, moe_raw.experts):
            for p_fused, p_raw in zip(exp_fused.parameters(), exp_raw.parameters()):
                p_raw.data = p_fused.data.clone()
    else:
        # Gather the idx-th parameter of every local expert from every
        # worker, then scatter the result over the reference's global
        # expert list.
        assert len(moe.experts) >= 1
        for idx, para in enumerate(moe.experts[0].parameters()):
            stacked = torch.cat(
                [list(e.parameters())[idx].unsqueeze(0) for e in moe.experts]
            )
            shards = [torch.empty_like(stacked) for _ in range(world_size)]
            torch.distributed.all_gather(shards, stacked)
            gathered = torch.cat(shards, dim=0)
            assert gathered.shape[0] == len(moe_raw.experts)
            for eid in range(gathered.shape[0]):
                list(moe_raw.experts[eid].parameters())[idx].data = gathered[eid]

    moe_out, raw_out = _perform_forward(moe, moe_raw, batch_size, d_model, top_k)

    def get_experts_grad(experts: List[nn.Module]):
        # Sum of every gradient entry per expert (0 for params w/o grads).
        totals = []
        for item in experts:
            per_param = [
                p.grad.sum() if p.grad is not None else torch.zeros(1).cuda()
                for p in item.parameters()
            ]
            totals.append(torch.stack(per_param).sum())
        return torch.stack(totals)

    moe_grad = get_experts_grad(moe.experts)
    raw_grad = get_experts_grad(moe_raw.experts)

    if world_size > 1:
        # Reference gradients cover all experts globally; reduce and keep
        # this worker's slice for comparison.
        torch.distributed.all_reduce(raw_grad)
        raw_grad = raw_grad[rank * num_expert : (rank + 1) * num_expert]

    moe_out_list = [moe_out, moe_grad]
    raw_out_list = [raw_out, raw_grad]
    names = ["forward", "backward"]

    _assert_numercial(names, moe_out_list, raw_out_list)


def _run_distributed(func: Callable, args: Dict):
Sengxian's avatar
Sengxian committed
204
205
206
    import subprocess
    import os

207
    ps, n = [], 2
Sengxian's avatar
Sengxian committed
208
209
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "36666"
Sengxian's avatar
Sengxian committed
210
    os.environ["OMPI_COMM_WORLD_SIZE"] = str(n)
Sengxian's avatar
Sengxian committed
211
212

    for i in range(n):
Sengxian's avatar
Sengxian committed
213
        os.environ["OMPI_COMM_WORLD_RANK"] = str(i)
Sengxian's avatar
Sengxian committed
214
        os.environ["CUDA_VISIBLE_DEVICES"] = str(i)
215
216
217
218
        p = subprocess.Popen(
            [sys.executable, __file__, func.__name__, json.dumps(args)],
            stdout=subprocess.PIPE,
        )
Sengxian's avatar
Sengxian committed
219
220
221
222
223
224
225
226
        ps.append(p)

    for p in ps:
        p.wait()
        retc = p.poll()
        assert retc == 0


227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("top_k", [2])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("d_hidden", [32])
def test_fmoe_linear_distributed(
    num_expert, top_k, batch_size, d_model, d_hidden,
):
    """Run test_fmoe_linear across two spawned worker processes."""
    args = {
        "num_expert": num_expert,
        "top_k": top_k,
        "batch_size": batch_size,
        "d_model": d_model,
        "d_hidden": d_hidden,
    }
    _run_distributed(test_fmoe_linear, args)


@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("top_k", [2])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("expert", ["NaiveExpert", "LinearExpert"])
def test_fmoe_distributed(
    num_expert, top_k, batch_size, d_model, expert,
):
    """Run test_fmoe across two spawned worker processes."""
    args = {
        "num_expert": num_expert,
        "top_k": top_k,
        "batch_size": batch_size,
        "d_model": d_model,
        "expert": expert,
    }
    _run_distributed(test_fmoe, args)


Sengxian's avatar
Sengxian committed
267
if __name__ == "__main__":
    # Map Open MPI launcher variables onto the names torch.distributed reads.
    os.environ["RANK"] = os.environ.get("OMPI_COMM_WORLD_RANK", "0")
    os.environ["WORLD_SIZE"] = os.environ.get("OMPI_COMM_WORLD_SIZE", "1")

    if int(os.environ["WORLD_SIZE"]) > 1:
        torch.distributed.init_process_group(backend="nccl")
        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()

    if len(sys.argv) < 3:
        # Direct invocation: run one small smoke configuration of each test.
        test_fmoe_linear(batch_size=4, num_expert=4, d_model=8, top_k=2, d_hidden=16)
        test_fmoe(batch_size=4, num_expert=4, d_model=8, top_k=2, expert=NaiveExpert)
    else:
        # Worker invocation (see _run_distributed): argv[1] names the test
        # function, argv[2] carries its kwargs as JSON.
        globals()[sys.argv[1]](**json.loads(sys.argv[2]))