import json
import os
import sys
from typing import List, Callable, Dict, Type, Union

6
import pytest
import torch
import torch.nn as nn

from fmoe.gates import NaiveGate
from fmoe.layers import FMoE
from fmoe.transformer import _Expert
from moe import BruteForceMoELinear, BruteForceMoE, NaiveExpert, LinearExpert
# Process rank and world size for distributed runs. Defaults cover the
# single-process case; both are overwritten in __main__ when
# torch.distributed is initialized by a spawned worker.
rank = 0
world_size = 1
def _perform_forward(moe: nn.Module, moe_raw: nn.Module, batch_size, d_model, top_k):
    """Run one forward/backward pass through both MoE implementations.

    Draws a random input batch, routes it with ``moe``'s gate, feeds the
    fast module the raw batch and the brute-force module the top-k-repeated
    batch with the same gate decisions, then backpropagates both scalar
    means so the caller can compare outputs and weight gradients.
    """
    for module in (moe, moe_raw):
        module.zero_grad()

    batch = torch.rand(batch_size, d_model).cuda()
    topk_idx, topk_score = moe.gate(batch)
    # Brute-force reference consumes one row per (token, expert) pair.
    batch_repeated = batch.repeat_interleave(repeats=top_k, dim=0)

    fast_out = moe(batch).mean()
    slow_out = moe_raw(batch_repeated, topk_idx, topk_score).mean()

    fast_out.backward()
    slow_out.backward()

    return fast_out, slow_out


def _assert_numercial(names, moe_out_list, raw_out_list):
    """Assert each pair of tensors agrees within an absolute tolerance.

    Compares corresponding entries of ``moe_out_list`` and ``raw_out_list``;
    on a summed absolute error above 1e-3, dumps both tensors to stderr and
    fails the test.
    """
    for name, fast_val, slow_val in zip(names, moe_out_list, raw_out_list):
        abs_err = (fast_val - slow_val).abs().sum()
        print("Rank {} {} abs err {}".format(rank, name, abs_err))
        if abs_err > 1e-3:
            for banner, tensor in (
                ("=========== moe out ==============\n", fast_val),
                ("=========== raw out ==============\n", slow_val),
            ):
                sys.stderr.write(banner)
                sys.stderr.write("{}\n".format(tensor))
            assert False


@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("top_k", [2])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("d_hidden", [32])
def test_fmoe_linear(
    num_expert,
    top_k,
    batch_size,
    d_model,
    d_hidden,
    activation=torch.nn.functional.gelu,
):
    """Compare FMoE with fused `_Expert` FFNs against BruteForceMoELinear.

    Builds both modules with identical weights, runs one forward/backward
    pass through each via _perform_forward, and asserts that the outputs
    and the htoh4/h4toh weight gradients agree numerically.
    """
    # Seed per rank so each distributed worker draws distinct weights/inputs.
    torch.manual_seed(42 + rank)
    torch.cuda.manual_seed(42 + rank)

    experts = _Expert(num_expert, d_model, d_hidden, activation).cuda()

    def expert_fn(inp, gate):
        # FMoE callback: forward the dispatched tokens through the experts.
        return experts(inp, gate)

    moe = FMoE(
        num_expert=num_expert,
        d_model=d_model,
        gate=NaiveGate,
        world_size=world_size,
        mp_group=None,
        expert_fn=expert_fn,
        top_k=top_k,
    ).cuda()

    moe_raw = BruteForceMoELinear(
        activation=activation,
        num_expert=num_expert,
        d_model=d_model,
        d_hidden=d_hidden,
        world_size=world_size,
    ).cuda()

    if world_size == 1:
        # Single process: copy the expert weights over directly.
        moe_raw.weight_htoh4.data = experts.htoh4.weight.data.clone()
        moe_raw.weight_h4toh.data = experts.h4toh.weight.data.clone()
    else:
        # Multi-process: the brute-force module holds ALL experts, so gather
        # every rank's weights and concatenate along the expert dimension.
        weight_htoh4_array = [
            torch.empty_like(experts.htoh4.weight.data) for _ in range(world_size)
        ]
        torch.distributed.all_gather(weight_htoh4_array, experts.htoh4.weight.data)
        moe_raw.weight_htoh4.data = torch.cat(weight_htoh4_array, dim=0)

        weight_h4toh_array = [
            torch.empty_like(experts.h4toh.weight.data) for _ in range(world_size)
        ]
        torch.distributed.all_gather(weight_h4toh_array, experts.h4toh.weight.data)
        moe_raw.weight_h4toh.data = torch.cat(weight_h4toh_array, dim=0)

    moe_out, raw_out = _perform_forward(moe, moe_raw, batch_size, d_model, top_k)

    moe_out_list = moe_out, experts.htoh4.weight.grad, experts.h4toh.weight.grad
    raw_out_list = raw_out, moe_raw.weight_htoh4.grad, moe_raw.weight_h4toh.grad

    if world_size > 1:
        # The reference gradients span all experts on every rank: sum them
        # across ranks, then slice out this rank's local experts to compare
        # against the FMoE module's local gradients.
        _, htoh4_grad, h4toh_grad = raw_out_list
        torch.distributed.all_reduce(htoh4_grad)
        torch.distributed.all_reduce(h4toh_grad)
        htoh4_grad = htoh4_grad[rank * num_expert : (rank + 1) * num_expert]
        h4toh_grad = h4toh_grad[rank * num_expert : (rank + 1) * num_expert]
        raw_out_list = _, htoh4_grad, h4toh_grad

    names = ["output", "htoh4 weight grad", "h4toh weight grad"]
    _assert_numercial(names, moe_out_list, raw_out_list)

@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("top_k", [2])
@pytest.mark.parametrize("expert", ["NaiveExpert", "LinearExpert"])
def test_fmoe(
    batch_size, num_expert, d_model, top_k, expert: Union[Type[nn.Module], str]
):
    """Compare FMoE with generic per-expert modules against BruteForceMoE.

    Synchronizes the expert parameters between the two modules, runs one
    forward/backward pass through each, and asserts that the outputs and a
    per-expert scalar summary of the gradients agree numerically.
    """
    torch.manual_seed(42 + rank)
    torch.cuda.manual_seed(42 + rank)

    # Parametrized as a string so pytest test IDs stay readable; resolve the
    # actual expert class from this module's globals.
    if isinstance(expert, str):
        expert = globals()[expert]

    moe = FMoE(
        num_expert=num_expert,
        d_model=d_model,
        gate=NaiveGate,
        world_size=world_size,
        mp_group=None,
        expert=expert,
        top_k=top_k,
    ).cuda()

    moe_raw = BruteForceMoE(
        expert=expert, num_expert=num_expert, d_model=d_model, world_size=world_size,
    ).cuda()

    if world_size == 1:
        # Single process: copy every parameter of every expert one-to-one.
        for expert_moe, expert_raw in zip(moe.experts, moe_raw.experts):
            for para_moe, para_raw in zip(
                expert_moe.parameters(), expert_raw.parameters()
            ):
                para_raw.data = para_moe.data.clone()
    else:
        assert len(moe.experts) >= 1
        # Multi-process: for each parameter slot, stack that parameter from
        # all local experts, all_gather across ranks, and scatter the result
        # into the brute-force module's full expert list.
        # NOTE: the comprehension's `expert` loop variable shadows the test
        # parameter of the same name here.
        for idx, para in enumerate(moe.experts[0].parameters()):
            para_tensor = torch.cat(
                [list(expert.parameters())[idx].unsqueeze(0) for expert in moe.experts]
            )
            para_array = [torch.empty_like(para_tensor) for _ in range(world_size)]
            torch.distributed.all_gather(para_array, para_tensor)
            para_tensor_gathered = torch.cat(para_array, dim=0)
            assert para_tensor_gathered.shape[0] == len(moe_raw.experts)
            for expertID in range(para_tensor_gathered.shape[0]):
                list(moe_raw.experts[expertID].parameters())[
                    idx
                ].data = para_tensor_gathered[expertID]

    moe_out, raw_out = _perform_forward(moe, moe_raw, batch_size, d_model, top_k)

    def get_experts_grad(experts: List[nn.Module]):
        # Collapse each expert's parameter gradients into one scalar so the
        # comparison is cheap; missing grads count as zero.
        return torch.stack(
            [
                torch.stack(
                    [
                        p.grad.sum() if p.grad is not None else torch.zeros(1).cuda()
                        for p in item.parameters()
                    ]
                ).sum()
                for item in experts
            ]
        )

    moe_grad, raw_grad = (
        get_experts_grad(moe.experts),
        get_experts_grad(moe_raw.experts),
    )

    if world_size > 1:
        # Sum the reference grads across ranks, then take this rank's slice
        # to match the FMoE module's local experts.
        torch.distributed.all_reduce(raw_grad)
        raw_grad = raw_grad[rank * num_expert : (rank + 1) * num_expert]

    moe_out_list = [moe_out, moe_grad]
    raw_out_list = [raw_out, raw_grad]
    names = ["forward", "backward"]

    _assert_numercial(names, moe_out_list, raw_out_list)


def _run_distributed(func: Callable, args: Dict):
Sengxian's avatar
Sengxian committed
199
200
201
    import subprocess
    import os

202
    ps, n = [], 2
Sengxian's avatar
Sengxian committed
203
204
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "36666"
Sengxian's avatar
Sengxian committed
205
    os.environ["OMPI_COMM_WORLD_SIZE"] = str(n)
Sengxian's avatar
Sengxian committed
206
207

    for i in range(n):
Sengxian's avatar
Sengxian committed
208
        os.environ["OMPI_COMM_WORLD_RANK"] = str(i)
Sengxian's avatar
Sengxian committed
209
        os.environ["CUDA_VISIBLE_DEVICES"] = str(i)
210
211
212
213
        p = subprocess.Popen(
            [sys.executable, __file__, func.__name__, json.dumps(args)],
            stdout=subprocess.PIPE,
        )
Sengxian's avatar
Sengxian committed
214
215
216
217
218
219
220
221
        ps.append(p)

    for p in ps:
        p.wait()
        retc = p.poll()
        assert retc == 0


@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("top_k", [2])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("d_hidden", [32])
def test_fmoe_linear_distributed(
    num_expert, top_k, batch_size, d_model, d_hidden,
):
    """Run test_fmoe_linear across two emulated distributed workers."""
    worker_kwargs = dict(
        num_expert=num_expert,
        top_k=top_k,
        batch_size=batch_size,
        d_model=d_model,
        d_hidden=d_hidden,
    )
    _run_distributed(test_fmoe_linear, worker_kwargs)


@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("top_k", [2])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("expert", ["NaiveExpert", "LinearExpert"])
def test_fmoe_distributed(
    num_expert, top_k, batch_size, d_model, expert,
):
    """Run test_fmoe across two emulated distributed workers."""
    worker_kwargs = dict(
        num_expert=num_expert,
        top_k=top_k,
        batch_size=batch_size,
        d_model=d_model,
        expert=expert,
    )
    _run_distributed(test_fmoe, worker_kwargs)


if __name__ == "__main__":
    # Mirror MPI-style env vars (set by _run_distributed) into the names
    # torch.distributed reads; default to a single-process run.
    os.environ["RANK"] = os.environ.get("OMPI_COMM_WORLD_RANK", "0")
    os.environ["WORLD_SIZE"] = os.environ.get("OMPI_COMM_WORLD_SIZE", "1")
    if int(os.environ["WORLD_SIZE"]) > 1:
        # NCCL backend: each worker is pinned to one GPU via
        # CUDA_VISIBLE_DEVICES set by the parent process.
        torch.distributed.init_process_group(backend="nccl")
        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()
    if len(sys.argv) >= 3:
        # Worker invocation: argv[1] names the test function, argv[2] is its
        # kwargs as JSON. At module level, locals() is the module globals, so
        # this resolves the test function by name.
        locals()[sys.argv[1]](**json.loads(sys.argv[2]))
    else:
        # Direct invocation: run a small smoke-test configuration.
        test_fmoe_linear(batch_size=4, num_expert=4, d_model=8, top_k=2, d_hidden=16)
        test_fmoe(batch_size=4, num_expert=4, d_model=8, top_k=2, expert=NaiveExpert)