# test_numerical.py -- numerical correctness tests for FastMoE layers.
1
import sys
2
from collections import OrderedDict
3
from typing import List, Type, Union
Sengxian's avatar
Sengxian committed
4

5
import pytest
Rick Ho's avatar
Rick Ho committed
6
import torch
7
import torch.nn as nn
8
import numpy as np
9

10
from copy import deepcopy
11
12
13
from fmoe.gates import NaiveGate
from fmoe.layers import FMoE
from fmoe.transformer import _Expert
14
from fmoe.distributed import DistributedGroupedDataParallel as LocalDDP
Rick Ho's avatar
Rick Ho committed
15
from fmoe.megatron.layers import _megatron_init_method
16
from moe import BruteForceMoELinear, BruteForceMoE, NaiveExpert, LinearExpert
Rick Ho's avatar
Rick Ho committed
17
18


19
def _perform_forward(
    moe: nn.Module, moe_raw: nn.Module, batch_size, d_model, top_k, rank, mp_group, data_type='torch.FloatTensor'
):
    """Run one forward/backward pass through both the fused MoE and the
    brute-force reference on the same random input.

    Returns a tuple ``(moe_out, raw_out, moe_input_grad, raw_input_grad)``.
    """
    moe.zero_grad()
    moe_raw.zero_grad()

    x = torch.rand(batch_size, d_model).type(data_type).cuda()

    if mp_group is not None:
        # Every rank of a model-parallel group must see identical input and
        # gate parameters; broadcast them from the group's first rank.
        src_rank = rank // mp_group.size() * mp_group.size()
        for t in (x, moe.gate.gate.weight.data, moe.gate.gate.bias.data):
            torch.distributed.broadcast(t, src_rank, group=mp_group)

    x_ref = x.clone()
    x.requires_grad = True
    x_ref.requires_grad = True

    # The reference model takes the gate decision explicitly, so both models
    # route tokens identically.
    gate_idx, gate_score = moe.gate(x_ref)
    fused_out = moe(x)
    ref_out = moe_raw(x_ref, gate_idx, gate_score)

    ref_out.mean().backward()
    fused_out.mean().backward()

    return fused_out, ref_out, x.grad, x_ref.grad
49
50


51
def _assert_numerical(names, moe_out_list, raw_out_list, rank, precision=1e-3):
52
    for name, mo, ro in zip(names, moe_out_list, raw_out_list):
53
        err = (mo - ro).abs().max()
Rick Ho's avatar
Rick Ho committed
54
55
        if err.dtype == torch.bfloat16:
            precision *= 100
56
        print("Rank {} {} abs err {}".format(rank, name, err))
57
        if err > precision:
Sengxian's avatar
Sengxian committed
58
            sys.stderr.write(f"=========== {name} moe out ==============\n")
59
            sys.stderr.write("{}\n".format(mo))
Sengxian's avatar
Sengxian committed
60
            sys.stderr.write(f"=========== {name} raw out ==============\n")
61
            sys.stderr.write("{}\n".format(ro))
62
63
            sys.stderr.write(f"=========== {name} diff ==============\n")
            sys.stderr.write("{}\n{}\n".format((mo - ro).abs(), err))
64
65
66
            assert False


67
class MyMoE(FMoE):
    """An FMoE using a NaiveGate and a two-layer `_Expert` FFN, with
    deterministic Megatron-style weight initialization."""

    def __init__(
        self, num_expert, d_model, d_hidden, world_size, mp_group, top_k, activation
    ):
        super().__init__(
            num_expert=num_expert,
            d_model=d_model,
            gate=NaiveGate,
            world_size=world_size,
            slice_group=mp_group,
            top_k=top_k,
        )
        self.experts = _Expert(num_expert, d_model, d_hidden, activation)
        # Fixed seed makes the init reproducible; initializing htoh4 before
        # h4toh fixes which random numbers each layer consumes.
        rng = np.random.default_rng(1234)
        for layer in (self.experts.htoh4, self.experts.h4toh):
            _megatron_init_method(layer, rng, 1.0)
84

85

86
@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("top_k", [2, 3])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("d_hidden", [32])
@pytest.mark.parametrize("rank", [0])
@pytest.mark.parametrize("world_size", [1])
@pytest.mark.parametrize("mp_group", [None])
@pytest.mark.parametrize("dp_group", [None])
@pytest.mark.parametrize("world_group", [None])
@pytest.mark.parametrize("data_type", ['torch.FloatTensor', 'torch.DoubleTensor', 'torch.HalfTensor'])
def test_fmoe_linear(
    num_expert,
    top_k,
    batch_size,
    d_model,
    d_hidden,
    rank,
    world_size,
    mp_group,
    dp_group,
    world_group,
    data_type,
    activation=torch.nn.functional.gelu,
):
    """Compare MyMoE (fused linear experts) against BruteForceMoELinear on
    outputs, input gradients, and weight/bias gradients of both layers."""
    # Per-rank seed so ranks generate distinct expert weights.
    torch.manual_seed(42 + rank)
    torch.cuda.manual_seed(42 + rank)

    moe = MyMoE(
        num_expert, d_model, d_hidden, world_size, mp_group, top_k, activation
    ).type(data_type).cuda()

    moe_raw = BruteForceMoELinear(
        activation=activation,
        num_expert=num_expert,
        d_model=d_model,
        d_hidden=d_hidden,
        world_size=world_size,
        top_k=top_k,
    ).type(data_type).cuda()

    if world_size == 1:
        # Single process: the reference copies this rank's expert weights.
        moe_raw.weight_htoh4.data = moe.experts.htoh4.weight.data.clone()
        moe_raw.bias_htoh4.data = moe.experts.htoh4.bias.data.clone()
        moe_raw.weight_h4toh.data = moe.experts.h4toh.weight.data.clone()
        moe_raw.bias_h4toh.data = moe.experts.h4toh.bias.data.clone()
    else:
        # Multi-process: the reference model holds ALL experts, so gather
        # every rank's weights and concatenate along the expert dimension.
        # NOTE: all_gather calls must run in the same order on every rank.
        weight_htoh4_array = [
            torch.empty_like(moe.experts.htoh4.weight.data) for _ in range(world_size)
        ]
        bias_htoh4_array = [
            torch.empty_like(moe.experts.htoh4.bias.data) for _ in range(world_size)
        ]
        torch.distributed.all_gather(weight_htoh4_array, moe.experts.htoh4.weight.data)
        torch.distributed.all_gather(bias_htoh4_array, moe.experts.htoh4.bias.data)
        moe_raw.weight_htoh4.data = torch.cat(weight_htoh4_array, dim=0)
        moe_raw.bias_htoh4.data = torch.cat(bias_htoh4_array, dim=0)

        weight_h4toh_array = [
            torch.empty_like(moe.experts.h4toh.weight.data) for _ in range(world_size)
        ]
        bias_h4toh_array = [
            torch.empty_like(moe.experts.h4toh.bias.data) for _ in range(world_size)
        ]
        torch.distributed.all_gather(weight_h4toh_array, moe.experts.h4toh.weight.data)
        torch.distributed.all_gather(bias_h4toh_array, moe.experts.h4toh.bias.data)
        moe_raw.weight_h4toh.data = torch.cat(weight_h4toh_array, dim=0)
        moe_raw.bias_h4toh.data = torch.cat(bias_h4toh_array, dim=0)

    moe_out, raw_out, moe_grad_in, raw_grad_in = _perform_forward(
        moe, moe_raw, batch_size, d_model, top_k, rank, mp_group, data_type=data_type
    )

    # Quantities to compare, in matching order with `names` below.
    moe_out_list = (
        moe_out,
        moe_grad_in,
        moe.experts.htoh4.weight.grad,
        moe.experts.h4toh.weight.grad,
        moe.experts.htoh4.bias.grad,
        moe.experts.h4toh.bias.grad,
    )
    raw_out_list = (
        raw_out,
        raw_grad_in,
        moe_raw.weight_htoh4.grad,
        moe_raw.weight_h4toh.grad,
        moe_raw.bias_htoh4.grad,
        moe_raw.bias_h4toh.grad,
    )

    if world_size > 1:
        # The reference model received all experts, so sum its gradients
        # across ranks and keep only this rank's expert slice; dividing by
        # the mp group size accounts for the replicated input within it.
        _, __, htoh4_w_grad, h4toh_w_grad, htoh4_b_grad, h4toh_b_grad = raw_out_list
        torch.distributed.all_reduce(htoh4_w_grad)
        torch.distributed.all_reduce(h4toh_w_grad)
        torch.distributed.all_reduce(htoh4_b_grad)
        torch.distributed.all_reduce(h4toh_b_grad)
        mp_size = mp_group.size() if mp_group else 1
        htoh4_w_grad = (
            htoh4_w_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
        )
        h4toh_w_grad = (
            h4toh_w_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
        )
        htoh4_b_grad = (
            htoh4_b_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
        )
        h4toh_b_grad = (
            h4toh_b_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
        )
        raw_out_list = _, __, htoh4_w_grad, h4toh_w_grad, htoh4_b_grad, h4toh_b_grad

    names = [
        "output",
        "input grad",
        "htoh4 weight grad",
        "h4toh weight grad",
        "htoh4 bias grad",
        "h4toh bias grad",
    ]

    # Half precision needs a much looser tolerance than fp32/fp64.
    precision = 5e-1 if data_type == 'torch.HalfTensor' else 1e-3
    _assert_numerical(names, moe_out_list, raw_out_list, rank, precision=precision)
210

Sengxian's avatar
Sengxian committed
211

212
213
214
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("top_k", [2, 3])
@pytest.mark.parametrize("expert", [NaiveExpert, LinearExpert])
@pytest.mark.parametrize("rank", [0])
@pytest.mark.parametrize("world_size", [1])
@pytest.mark.parametrize("mp_group", [None])
@pytest.mark.parametrize("dp_group", [None])
@pytest.mark.parametrize("world_group", [None])
@pytest.mark.parametrize("data_type", [torch.float32, torch.float16, torch.bfloat16])
def test_fmoe(
    batch_size,
    num_expert,
    d_model,
    top_k,
    expert: Union[Type[nn.Module], str],
    rank,
    world_size,
    mp_group,
    dp_group,
    world_group,
    data_type
):
    """Compare a generic FMoE (arbitrary expert module) against BruteForceMoE
    on forward output, per-expert gradient sums, and input gradients."""
    # Per-rank seed so ranks generate distinct expert weights.
    torch.manual_seed(42 + rank)
    torch.cuda.manual_seed(42 + rank)

    # The expert may arrive as a class name string; resolve it in this module.
    if isinstance(expert, str):
        expert = globals()[expert]

    moe = FMoE(
        num_expert=num_expert,
        d_model=d_model,
        gate=NaiveGate,
        world_size=world_size,
        mp_group=mp_group,
        expert=expert,
        top_k=top_k,
    ).cuda().to(data_type)

    moe_raw = BruteForceMoE(
        expert=expert,
        num_expert=num_expert,
        d_model=d_model,
        world_size=world_size,
        top_k=top_k,
    ).cuda().to(data_type)

    if world_size == 1:
        # Single process: copy each expert's parameters into the reference.
        for expert_moe, expert_raw in zip(moe.experts, moe_raw.experts):
            for para_moe, para_raw in zip(
                expert_moe.parameters(), expert_raw.parameters()
            ):
                para_raw.data = para_moe.data.clone()
    else:
        # Multi-process: gather the idx-th parameter of every local expert
        # from all ranks and scatter them into the reference's expert list.
        assert len(moe.experts) >= 1
        for idx, para in enumerate(moe.experts[0].parameters()):
            para_tensor = torch.cat(
                [list(expert.parameters())[idx].unsqueeze(0) for expert in moe.experts]
            )
            para_array = [torch.empty_like(para_tensor) for _ in range(world_size)]
            torch.distributed.all_gather(para_array, para_tensor)
            para_tensor_gathered = torch.cat(para_array, dim=0)
            assert para_tensor_gathered.shape[0] == len(moe_raw.experts)
            for expertID in range(para_tensor_gathered.shape[0]):
                list(moe_raw.experts[expertID].parameters())[
                    idx
                ].data = para_tensor_gathered[expertID]

    moe_out, raw_out, moe_grad_in, raw_grad_in = _perform_forward(
        moe, moe_raw, batch_size, d_model, top_k, rank, mp_group, data_type
    )

    def get_experts_grad(experts: List[nn.Module]):
        # Collapse each expert to one scalar: the sum of all its param grads
        # (zero for params that received no gradient).
        return torch.stack(
            [
                torch.stack(
                    [
                        p.grad.sum() if p.grad is not None else torch.zeros(1).cuda()
                        for p in item.parameters()
                    ]
                ).sum()
                for item in experts
            ]
        )

    moe_grad, raw_grad = (
        get_experts_grad(moe.experts),
        get_experts_grad(moe_raw.experts),
    )

    if world_size > 1:
        # Sum reference grads across ranks, then keep this rank's expert
        # slice, averaging over the model-parallel group size.
        torch.distributed.all_reduce(raw_grad)
        mp_size = mp_group.size() if mp_group else 1
        raw_grad = raw_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size

    moe_out_list = [moe_out, moe_grad, moe_grad_in]
    raw_out_list = [raw_out, raw_grad, raw_grad_in]
    names = ["forward", "backward", "grad_in"]

    _assert_numerical(names, moe_out_list, raw_out_list, rank)
Sengxian's avatar
Sengxian committed
313
314


315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
class MyModule(nn.Module):
    """A small 3-layer MLP whose linear layers can each be tagged with a
    different `dp_comm` mode for DistributedGroupedDataParallel tests."""

    def __init__(self, dim=8):
        super().__init__()
        layers = OrderedDict()
        layers["linear1"] = nn.Linear(dim, dim)
        layers["relu1"] = nn.ReLU()
        layers["linear2"] = nn.Linear(dim, dim)
        layers["relu2"] = nn.ReLU()
        layers["linear3"] = nn.Linear(dim, dim)
        self.model = nn.Sequential(layers)

    def set_comm(self):
        # Tag each linear layer's parameters with the communicator group
        # that DistributedGroupedDataParallel should reduce them over.
        comm_of = {"linear1": "mp", "linear2": "dp", "linear3": "world"}
        for layer_name, comm in comm_of.items():
            for p in self.model._modules[layer_name].parameters():
                p.dp_comm = comm

    def forward(self, inp):
        return self.model(inp)


def _test_fmoe_local_ddp(rank, world_size, mp_group, dp_group, world_group):
    """Check DistributedGroupedDataParallel's grouped gradient all-reduce
    against manually reduced gradients on a plain copy of the model.

    Each linear layer is tagged with a different dp_comm group ("mp", "dp",
    "world"); after backward, the DDP wrapper's allreduce_params must match
    an explicit all_reduce + divide over the corresponding group.
    """
    batch_size, dim = 4, 8

    # Per-rank seed: each rank builds different weights and inputs.
    torch.manual_seed(42 + rank)
    torch.cuda.manual_seed(42 + rank)

    model = MyModule().cuda()
    model_ddp = LocalDDP(deepcopy(model),
            mp_group=mp_group, dp_group=dp_group, world_group=world_group)
    # Use the DDP wrapper's copy as the reference so both start identical
    # (LocalDDP may synchronize parameters at construction -- TODO confirm).
    model = deepcopy(model_ddp.module)
    model.set_comm()
    model_ddp.module.set_comm()

    inp = torch.randn(batch_size, dim).cuda()

    raw_out = model(inp).mean()
    ddp_out = model_ddp(inp).mean()

    raw_out.backward()
    ddp_out.backward()

    # Manually average each layer's weight grad over its assigned group.
    torch.distributed.all_reduce(
        model.model._modules["linear1"].weight.grad.data, group=mp_group
    )
    model.model._modules["linear1"].weight.grad /= mp_group.size()
    torch.distributed.all_reduce(
        model.model._modules["linear2"].weight.grad.data, group=dp_group
    )
    model.model._modules["linear2"].weight.grad /= dp_group.size()
    torch.distributed.all_reduce(
        model.model._modules["linear3"].weight.grad.data, group=world_group
    )
    model.model._modules["linear3"].weight.grad /= world_group.size()
    # The wrapper performs the same reductions per dp_comm tag.
    model_ddp.allreduce_params(reduce_after=False, fp32_allreduce=False)

    raw_out_list = [
        model.model._modules["linear1"].weight.grad,
        model.model._modules["linear2"].weight.grad,
        model.model._modules["linear3"].weight.grad,
    ]
    ddp_out_list = [
        model_ddp.module.model._modules["linear1"].weight.grad,
        model_ddp.module.model._modules["linear2"].weight.grad,
        model_ddp.module.model._modules["linear3"].weight.grad,
    ]

    names = ["mp grad", "dp grad", "wp grad"]

    _assert_numerical(names, ddp_out_list, raw_out_list, rank)
391
392


Colin's avatar
Colin committed
393
394
395
396
397
398
399
400
401
402
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("num_expert", [None])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("top_k", [2, 3])
@pytest.mark.parametrize("expert", [ [NaiveExpert for _ in range(4)], [LinearExpert, NaiveExpert, LinearExpert, NaiveExpert, LinearExpert, NaiveExpert, LinearExpert, NaiveExpert] ])
@pytest.mark.parametrize("rank", [0])
@pytest.mark.parametrize("world_size", [1])
@pytest.mark.parametrize("mp_group", [None])
@pytest.mark.parametrize("dp_group", [None])
@pytest.mark.parametrize("world_group", [None])
@pytest.mark.parametrize("data_type", [torch.float32])
def test_fmoe_experts(
    batch_size,
    num_expert,
    d_model,
    top_k,
    expert: Union[Type[nn.Module], str],
    rank,
    world_size,
    mp_group,
    dp_group,
    world_group,
    data_type
):
    """Like test_fmoe, but the expert argument is a *list* of per-expert
    module classes (heterogeneous experts) and num_expert is None."""
    # Per-rank seed so ranks generate distinct expert weights.
    torch.manual_seed(42 + rank)
    torch.cuda.manual_seed(42 + rank)

    # The expert may arrive as a class name string; resolve it in this module.
    if isinstance(expert, str):
        expert = globals()[expert]

    moe = FMoE(
        num_expert=num_expert,
        d_model=d_model,
        gate=NaiveGate,
        world_size=world_size,
        mp_group=mp_group,
        expert=expert,
        top_k=top_k,
    ).cuda().to(data_type)

    moe_raw = BruteForceMoE(
        expert=expert,
        num_expert=num_expert,
        d_model=d_model,
        world_size=world_size,
        top_k=top_k,
    ).cuda().to(data_type)

    if world_size == 1:
        # Single process: copy each expert's parameters into the reference.
        for expert_moe, expert_raw in zip(moe.experts, moe_raw.experts):
            for para_moe, para_raw in zip(
                expert_moe.parameters(), expert_raw.parameters()
            ):
                para_raw.data = para_moe.data.clone()
    else:
        # Multi-process: gather the idx-th parameter of every local expert
        # from all ranks and scatter them into the reference's expert list.
        # NOTE(review): this branch indexes with num_expert below, which is
        # None for the current parametrization -- it only runs distributed.
        assert len(moe.experts) >= 1
        for idx, para in enumerate(moe.experts[0].parameters()):
            para_tensor = torch.cat(
                [list(expert.parameters())[idx].unsqueeze(0) for expert in moe.experts]
            )
            para_array = [torch.empty_like(para_tensor) for _ in range(world_size)]
            torch.distributed.all_gather(para_array, para_tensor)
            para_tensor_gathered = torch.cat(para_array, dim=0)
            assert para_tensor_gathered.shape[0] == len(moe_raw.experts)
            for expertID in range(para_tensor_gathered.shape[0]):
                list(moe_raw.experts[expertID].parameters())[
                    idx
                ].data = para_tensor_gathered[expertID]

    moe_out, raw_out, moe_grad_in, raw_grad_in = _perform_forward(
        moe, moe_raw, batch_size, d_model, top_k, rank, mp_group, data_type
    )

    def get_experts_grad(experts: List[nn.Module]):
        # Collapse each expert to one scalar: the sum of all its param grads
        # (zero for params that received no gradient).
        return torch.stack(
            [
                torch.stack(
                    [
                        p.grad.sum() if p.grad is not None else torch.zeros(1).cuda()
                        for p in item.parameters()
                    ]
                ).sum()
                for item in experts
            ]
        )

    moe_grad, raw_grad = (
        get_experts_grad(moe.experts),
        get_experts_grad(moe_raw.experts),
    )

    if world_size > 1:
        # Sum reference grads across ranks, then keep this rank's expert
        # slice, averaging over the model-parallel group size.
        torch.distributed.all_reduce(raw_grad)
        mp_size = mp_group.size() if mp_group else 1
        raw_grad = raw_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size

    moe_out_list = [moe_out, moe_grad, moe_grad_in]
    raw_out_list = [raw_out, raw_grad, raw_grad_in]
    names = ["forward", "backward", "grad_in"]

    _assert_numerical(names, moe_out_list, raw_out_list, rank)


496
if __name__ == "__main__":
    # Direct smoke run of one single-process configuration (no pytest needed).
    test_fmoe(
        batch_size=2,
        num_expert=2,
        d_model=2,
        top_k=2,
        expert=[NaiveExpert] * 4,
        rank=0,
        world_size=1,
        mp_group=None,
        dp_group=None,
        world_group=None,
        data_type=torch.bfloat16,
    )