import sys
from collections import OrderedDict
from typing import List, Type, Union

import pytest
import torch
import torch.nn as nn
from copy import deepcopy

from fmoe.gates import NaiveGate
from fmoe.layers import FMoE
from fmoe.transformer import _Expert
from fmoe.distributed import DistributedGroupedDataParallel as LocalDDP
from moe import BruteForceMoELinear, BruteForceMoE, NaiveExpert, LinearExpert


def _perform_forward(
    moe: nn.Module, moe_raw: nn.Module, batch_size, d_model, top_k, rank, mp_group
):
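    """Run one forward/backward pass through both the FMoE module and the
    brute-force reference, and return the two scalar (mean) outputs."""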
    moe.zero_grad()
    moe_raw.zero_grad()
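
    # Build a random input; with a model-parallel group, broadcast the input
    # and the gate parameters so every rank in the group computes on the same
    # data.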
    if not mp_group:
        inp = torch.rand(batch_size, d_model).cuda()
    else:
        group_sender = rank // mp_group.size() * mp_group.size()
        inp = torch.rand(batch_size, d_model).cuda()
        torch.distributed.broadcast(inp, group_sender, group=mp_group)
        torch.distributed.broadcast(
            moe.gate.gate.weight.data, group_sender, group=mp_group
        )
        torch.distributed.broadcast(
            moe.gate.gate.bias.data, group_sender, group=mp_group
        )

    gate_idx, gate_score = moe.gate(inp)
    inp_repeated = inp.repeat_interleave(repeats=top_k, dim=0)
    moe_out = moe(inp).mean()
    raw_out = moe_raw(inp_repeated, gate_idx, gate_score).mean()

    moe_out.backward()
    raw_out.backward()

    return moe_out, raw_out


def _assert_numercial(names, moe_out_list, raw_out_list, rank):
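    """Compare each pair of tensors by summed absolute error and fail if any
    pair differs by more than 1e-3, dumping both tensors to stderr."""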
    for name, mo, ro in zip(names, moe_out_list, raw_out_list):
        err = (mo - ro).abs().sum()
        print("Rank {} {} abs err {}".format(rank, name, err))
        if err > 1e-3:
            sys.stderr.write("=========== moe out ==============\n")
            sys.stderr.write("{}\n".format(mo))
            sys.stderr.write("=========== raw out ==============\n")
            sys.stderr.write("{}\n".format(ro))
            assert False, "MoE output does not match the reference"


class MyMoE(FMoE):
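    """FMoE layer whose experts are the two-layer feed-forward `_Expert`
    modules from fmoe.transformer."""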
    def __init__(
        self, num_expert, d_model, d_hidden, world_size, mp_group, top_k, activation
    ):
        super().__init__(
            num_expert=num_expert,
            d_model=d_model,
            gate=NaiveGate,
            world_size=world_size,
            mp_group=mp_group,
            top_k=top_k,
        )
        self.experts = _Expert(num_expert, d_model, d_hidden, activation)


@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("top_k", [2, 3])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("d_hidden", [32])
@pytest.mark.parametrize("rank", [0])
@pytest.mark.parametrize("world_size", [1])
@pytest.mark.parametrize("mp_group", [None])
@pytest.mark.parametrize("dp_group", [None])
@pytest.mark.parametrize("world_group", [None])
def test_fmoe_linear(
    num_expert,
    top_k,
    batch_size,
    d_model,
    d_hidden,
    rank,
    world_size,
    mp_group,
    dp_group,
    world_group,
    activation=torch.nn.functional.gelu,
):
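    """Compare FMoE with FFN experts against BruteForceMoELinear, checking the
    forward output and the gradients of every expert weight and bias."""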
    torch.manual_seed(42 + rank)
    torch.cuda.manual_seed(42 + rank)

    moe = MyMoE(
        num_expert, d_model, d_hidden, world_size, mp_group, top_k, activation
    ).cuda()

    moe_raw = BruteForceMoELinear(
        activation=activation,
        num_expert=num_expert,
        d_model=d_model,
        d_hidden=d_hidden,
        world_size=world_size,
        top_k=top_k,
    ).cuda()

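    # Mirror the expert parameters into the brute-force reference; with more
    # than one worker, all-gather every rank's experts so the reference holds
    # the full expert set.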
    if world_size == 1:
        moe_raw.weight_htoh4.data = moe.experts.htoh4.weight.data.clone()
        moe_raw.bias_htoh4.data = moe.experts.htoh4.bias.data.clone()
        moe_raw.weight_h4toh.data = moe.experts.h4toh.weight.data.clone()
        moe_raw.bias_h4toh.data = moe.experts.h4toh.bias.data.clone()
    else:
        weight_htoh4_array = [
            torch.empty_like(moe.experts.htoh4.weight.data) for _ in range(world_size)
        ]
        bias_htoh4_array = [
            torch.empty_like(moe.experts.htoh4.bias.data) for _ in range(world_size)
        ]
        torch.distributed.all_gather(weight_htoh4_array, moe.experts.htoh4.weight.data)
        torch.distributed.all_gather(bias_htoh4_array, moe.experts.htoh4.bias.data)
        moe_raw.weight_htoh4.data = torch.cat(weight_htoh4_array, dim=0)
        moe_raw.bias_htoh4.data = torch.cat(bias_htoh4_array, dim=0)

        weight_h4toh_array = [
            torch.empty_like(moe.experts.h4toh.weight.data) for _ in range(world_size)
        ]
        bias_h4toh_array = [
            torch.empty_like(moe.experts.h4toh.bias.data) for _ in range(world_size)
        ]
        torch.distributed.all_gather(weight_h4toh_array, moe.experts.h4toh.weight.data)
        torch.distributed.all_gather(bias_h4toh_array, moe.experts.h4toh.bias.data)
        moe_raw.weight_h4toh.data = torch.cat(weight_h4toh_array, dim=0)
        moe_raw.bias_h4toh.data = torch.cat(bias_h4toh_array, dim=0)

    moe_out, raw_out = _perform_forward(
        moe, moe_raw, batch_size, d_model, top_k, rank, mp_group
    )

    moe_out_list = (
        moe_out,
        moe.experts.htoh4.weight.grad,
        moe.experts.h4toh.weight.grad,
        moe.experts.htoh4.bias.grad,
        moe.experts.h4toh.bias.grad,
    )
    raw_out_list = (
        raw_out,
        moe_raw.weight_htoh4.grad,
        moe_raw.weight_h4toh.grad,
        moe_raw.bias_htoh4.grad,
        moe_raw.bias_h4toh.grad,
    )

    if world_size > 1:
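        # The reference holds every expert, so its gradients are all-reduced
        # across workers, then sliced down to this rank's experts and averaged
        # over the model-parallel group.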
        _, htoh4_w_grad, h4toh_w_grad, htoh4_b_grad, h4toh_b_grad = raw_out_list
        torch.distributed.all_reduce(htoh4_w_grad)
        torch.distributed.all_reduce(h4toh_w_grad)
        torch.distributed.all_reduce(htoh4_b_grad)
        torch.distributed.all_reduce(h4toh_b_grad)
        mp_size = mp_group.size() if mp_group else 1
        htoh4_w_grad = (
            htoh4_w_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
        )
        h4toh_w_grad = (
            h4toh_w_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
        )
        htoh4_b_grad = (
            htoh4_b_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
        )
        h4toh_b_grad = (
            h4toh_b_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
        )
        raw_out_list = _, htoh4_w_grad, h4toh_w_grad, htoh4_b_grad, h4toh_b_grad

    names = [
        "output",
        "htoh4 weight grad",
        "h4toh weight grad",
        "htoh4 bias grad",
        "h4toh bias grad",
    ]
    _assert_numercial(names, moe_out_list, raw_out_list, rank)


@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("top_k", [2, 3])
@pytest.mark.parametrize("expert", [NaiveExpert, LinearExpert])
@pytest.mark.parametrize("rank", [0])
@pytest.mark.parametrize("world_size", [1])
@pytest.mark.parametrize("mp_group", [None])
@pytest.mark.parametrize("dp_group", [None])
@pytest.mark.parametrize("world_group", [None])
def test_fmoe(
    batch_size,
    num_expert,
    d_model,
    top_k,
    expert: Union[Type[nn.Module], str],
    rank,
    world_size,
    mp_group,
    dp_group,
    world_group,
):
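    """Compare FMoE against BruteForceMoE for generic expert modules."""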
    torch.manual_seed(42 + rank)
    torch.cuda.manual_seed(42 + rank)

    if isinstance(expert, str):
        expert = globals()[expert]

    moe = FMoE(
        num_expert=num_expert,
        d_model=d_model,
        gate=NaiveGate,
        world_size=world_size,
        mp_group=mp_group,
        expert=expert,
        top_k=top_k,
    ).cuda()

    moe_raw = BruteForceMoE(
        expert=expert,
        num_expert=num_expert,
        d_model=d_model,
        world_size=world_size,
        top_k=top_k,
    ).cuda()

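    # Copy (or all-gather) the expert parameters into the brute-force reference
    # so both models start from identical weights.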
    if world_size == 1:
        for expert_moe, expert_raw in zip(moe.experts, moe_raw.experts):
            for para_moe, para_raw in zip(
                expert_moe.parameters(), expert_raw.parameters()
            ):
                para_raw.data = para_moe.data.clone()
    else:
        assert len(moe.experts) >= 1
        for idx, para in enumerate(moe.experts[0].parameters()):
            para_tensor = torch.cat(
                [list(expert.parameters())[idx].unsqueeze(0) for expert in moe.experts]
            )
            para_array = [torch.empty_like(para_tensor) for _ in range(world_size)]
            torch.distributed.all_gather(para_array, para_tensor)
            para_tensor_gathered = torch.cat(para_array, dim=0)
            assert para_tensor_gathered.shape[0] == len(moe_raw.experts)
            for expertID in range(para_tensor_gathered.shape[0]):
                list(moe_raw.experts[expertID].parameters())[
                    idx
                ].data = para_tensor_gathered[expertID]

    moe_out, raw_out = _perform_forward(
        moe, moe_raw, batch_size, d_model, top_k, rank, mp_group
    )

    def get_experts_grad(experts: List[nn.Module]):
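        # Reduce each expert's parameter gradients to a single scalar so the
        # two implementations can be compared expert by expert.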
        return torch.stack(
            [
                torch.stack(
                    [
                        p.grad.sum() if p.grad is not None else torch.zeros(1).cuda()
                        for p in item.parameters()
                    ]
                ).sum()
                for item in experts
            ]
        )

    moe_grad, raw_grad = (
        get_experts_grad(moe.experts),
        get_experts_grad(moe_raw.experts),
    )

    if world_size > 1:
        torch.distributed.all_reduce(raw_grad)
        mp_size = mp_group.size() if mp_group else 1
        raw_grad = raw_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size

    moe_out_list = [moe_out, moe_grad]
    raw_out_list = [raw_out, raw_grad]
    names = ["forward", "backward"]

    _assert_numercial(names, moe_out_list, raw_out_list, rank)


class MyModule(nn.Module):
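    """Small MLP whose three linear layers are tagged by `set_comm` for
    model-parallel, data-parallel, and world gradient reduction."""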
    def __init__(self, dim=8):
        super().__init__()
        self.model = nn.Sequential(
            OrderedDict(
                [
                    ("linear1", nn.Linear(dim, dim)),
                    ("relu1", nn.ReLU()),
                    ("linear2", nn.Linear(dim, dim)),
                    ("relu2", nn.ReLU()),
                    ("linear3", nn.Linear(dim, dim)),
                ]
            )
        )

    def set_comm(self):
        for p in self.model._modules["linear1"].parameters():
            setattr(p, "dp_comm", "mp")
        for p in self.model._modules["linear2"].parameters():
            setattr(p, "dp_comm", "dp")
        for p in self.model._modules["linear3"].parameters():
            setattr(p, "dp_comm", "world")

    def forward(self, inp):
        return self.model(inp)


def _test_fmoe_local_ddp(rank, world_size, mp_group, dp_group, world_group):
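    """Check DistributedGroupedDataParallel's grouped gradient all-reduce
    against manually reduced gradients over the mp, dp, and world groups."""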
    batch_size, dim = 4, 8

    torch.manual_seed(42 + rank)
    torch.cuda.manual_seed(42 + rank)

    model = MyModule().cuda()
    model_ddp = LocalDDP(deepcopy(model), mp_group, dp_group, world_group)
    model.set_comm()
    model_ddp.module.set_comm()

    inp = torch.randn(batch_size, dim).cuda()

    raw_out = model(inp).mean()
    ddp_out = model_ddp(inp).mean()

    raw_out.backward()
    ddp_out.backward()

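    # Reference: all-reduce each layer's gradient over its assigned group by
    # hand, then let the DDP wrapper do the same via allreduce_params.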
    torch.distributed.all_reduce(
        model.model._modules["linear1"].weight.grad.data, group=mp_group
    )
    model.model._modules["linear1"].weight.grad /= mp_group.size()
    torch.distributed.all_reduce(
        model.model._modules["linear2"].weight.grad.data, group=dp_group
    )
    model.model._modules["linear2"].weight.grad /= dp_group.size()
    torch.distributed.all_reduce(
        model.model._modules["linear3"].weight.grad.data, group=world_group
    )
    model.model._modules["linear3"].weight.grad /= world_group.size()
    model_ddp.allreduce_params(reduce_after=False, fp32_allreduce=False)

    raw_out_list = [
        model.model._modules["linear1"].weight.grad,
        model.model._modules["linear2"].weight.grad,
        model.model._modules["linear3"].weight.grad,
    ]
    ddp_out_list = [
        model_ddp.module.model._modules["linear1"].weight.grad,
        model_ddp.module.model._modules["linear2"].weight.grad,
        model_ddp.module.model._modules["linear3"].weight.grad,
    ]

    names = ["mp grad", "dp grad", "wp grad"]

    _assert_numercial(names, ddp_out_list, raw_out_list, rank)


if __name__ == "__main__":
    test_fmoe_linear(
        batch_size=4,
        num_expert=4,
        d_model=8,
        top_k=2,
        d_hidden=16,
        rank=0,
        world_size=1,
        mp_group=None,
        dp_group=None,
        world_group=None,
    )
    test_fmoe(
        batch_size=4,
        num_expert=4,
        d_model=8,
        top_k=2,
        expert=NaiveExpert,
        rank=0,
        world_size=1,
        mp_group=None,
        dp_group=None,
        world_group=None,
    )