test_numerical.py 12.9 KB
Newer Older
1
import sys
2
from collections import OrderedDict
3
from typing import List, Type, Union
Sengxian's avatar
Sengxian committed
4

5
import pytest
Rick Ho's avatar
Rick Ho committed
6
import torch
7
import torch.nn as nn
8
import numpy as np
9

10
from copy import deepcopy
11
12
13
from fmoe.gates import NaiveGate
from fmoe.layers import FMoE
from fmoe.transformer import _Expert
14
from fmoe.distributed import DistributedGroupedDataParallel as LocalDDP
Rick Ho's avatar
Rick Ho committed
15
from fmoe.megatron.layers import _megatron_init_method
16
from moe import BruteForceMoELinear, BruteForceMoE, NaiveExpert, LinearExpert
Rick Ho's avatar
Rick Ho committed
17
18


19
def _perform_forward(
    moe: nn.Module, moe_raw: nn.Module, batch_size, d_model, top_k, rank, mp_group, data_type='torch.FloatTensor'
):
    """Run one forward + backward pass through both MoE implementations.

    The fused model (`moe`) and the brute-force reference (`moe_raw`) consume
    the same random input and the same gate decisions, so their outputs and
    gradients are directly comparable.

    Returns a tuple ``(fused_out, ref_out, fused_input_grad, ref_input_grad)``.
    """
    moe.zero_grad()
    moe_raw.zero_grad()

    fused_in = torch.rand(batch_size, d_model).type(data_type).cuda()

    if mp_group is not None:
        # Every rank inside a model-parallel group must see the same input
        # and identical gate parameters; broadcast them from the group leader.
        leader = rank // mp_group.size() * mp_group.size()
        for shared in (fused_in, moe.gate.gate.weight.data, moe.gate.gate.bias.data):
            torch.distributed.broadcast(shared, leader, group=mp_group)

    ref_in = fused_in.clone()
    fused_in.requires_grad = True
    ref_in.requires_grad = True

    # Reuse the fused model's gate so both models route tokens identically.
    gate_idx, gate_score = moe.gate(ref_in)
    fused_out = moe(fused_in)
    ref_out = moe_raw(ref_in, gate_idx, gate_score)

    ref_out.mean().backward()
    fused_out.mean().backward()

    return fused_out, ref_out, fused_in.grad, ref_in.grad
49
50


51
def _assert_numerical(names, moe_out_list, raw_out_list, rank, precision=1e-3):
52
    for name, mo, ro in zip(names, moe_out_list, raw_out_list):
53
        err = (mo - ro).abs().max()
54
        print("Rank {} {} abs err {}".format(rank, name, err))
55
        if err > precision:
Sengxian's avatar
Sengxian committed
56
            sys.stderr.write(f"=========== {name} moe out ==============\n")
57
            sys.stderr.write("{}\n".format(mo))
Sengxian's avatar
Sengxian committed
58
            sys.stderr.write(f"=========== {name} raw out ==============\n")
59
            sys.stderr.write("{}\n".format(ro))
60
61
            sys.stderr.write(f"=========== {name} diff ==============\n")
            sys.stderr.write("{}\n{}\n".format((mo - ro).abs(), err))
62
63
64
            assert False


65
class MyMoE(FMoE):
    """FMoE whose experts are the fused two-layer FFN (`_Expert`).

    Expert weights are initialised with Megatron's init method from a fixed
    seed so that every run starts from reproducible parameters.
    """

    def __init__(
        self, num_expert, d_model, d_hidden, world_size, mp_group, top_k, activation
    ):
        super().__init__(
            num_expert=num_expert,
            d_model=d_model,
            gate=NaiveGate,
            top_k=top_k,
            world_size=world_size,
            mp_group=mp_group,
        )
        self.experts = _Expert(num_expert, d_model, d_hidden, activation)
        # Fixed seed for reproducible weights; the two init calls must stay
        # in this order because the generator is stateful.
        generator = np.random.default_rng(1234)
        _megatron_init_method(self.experts.htoh4, generator, 1.0)
        _megatron_init_method(self.experts.h4toh, generator, 1.0)
82

83

84
@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("top_k", [2, 3])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("d_hidden", [32])
@pytest.mark.parametrize("rank", [0])
@pytest.mark.parametrize("world_size", [1])
@pytest.mark.parametrize("mp_group", [None])
@pytest.mark.parametrize("dp_group", [None])
@pytest.mark.parametrize("world_group", [None])
@pytest.mark.parametrize("data_type", ['torch.FloatTensor', 'torch.DoubleTensor', 'torch.HalfTensor'])
def test_fmoe_linear(
    num_expert,
    top_k,
    batch_size,
    d_model,
    d_hidden,
    rank,
    world_size,
    mp_group,
    dp_group,
    world_group,
    data_type,
    activation=torch.nn.functional.gelu,
):
    """Numerically compare MyMoE (fused FFN experts) against BruteForceMoELinear.

    Both models are given identical expert weights (copied locally, or gathered
    across workers with all_gather when world_size > 1), run the same
    forward/backward pass, and must agree on the output, the input gradient,
    and the weight/bias gradients of both expert layers.
    """
    torch.manual_seed(42 + rank)
    torch.cuda.manual_seed(42 + rank)

    moe = MyMoE(
        num_expert, d_model, d_hidden, world_size, mp_group, top_k, activation
    ).type(data_type).cuda()

    moe_raw = BruteForceMoELinear(
        activation=activation,
        num_expert=num_expert,
        d_model=d_model,
        d_hidden=d_hidden,
        world_size=world_size,
        top_k=top_k,
    ).type(data_type).cuda()

    def _gathered(param):
        # Collect one expert parameter tensor from every worker and
        # concatenate the shards along the expert dimension: the raw model
        # holds all world_size * num_expert experts locally.
        shards = [torch.empty_like(param.data) for _ in range(world_size)]
        torch.distributed.all_gather(shards, param.data)
        return torch.cat(shards, dim=0)

    if world_size == 1:
        moe_raw.weight_htoh4.data = moe.experts.htoh4.weight.data.clone()
        moe_raw.bias_htoh4.data = moe.experts.htoh4.bias.data.clone()
        moe_raw.weight_h4toh.data = moe.experts.h4toh.weight.data.clone()
        moe_raw.bias_h4toh.data = moe.experts.h4toh.bias.data.clone()
    else:
        moe_raw.weight_htoh4.data = _gathered(moe.experts.htoh4.weight)
        moe_raw.bias_htoh4.data = _gathered(moe.experts.htoh4.bias)
        moe_raw.weight_h4toh.data = _gathered(moe.experts.h4toh.weight)
        moe_raw.bias_h4toh.data = _gathered(moe.experts.h4toh.bias)

    moe_out, raw_out, moe_grad_in, raw_grad_in = _perform_forward(
        moe, moe_raw, batch_size, d_model, top_k, rank, mp_group, data_type=data_type
    )

    moe_out_list = (
        moe_out,
        moe_grad_in,
        moe.experts.htoh4.weight.grad,
        moe.experts.h4toh.weight.grad,
        moe.experts.htoh4.bias.grad,
        moe.experts.h4toh.bias.grad,
    )
    raw_out_list = (
        raw_out,
        raw_grad_in,
        moe_raw.weight_htoh4.grad,
        moe_raw.weight_h4toh.grad,
        moe_raw.bias_htoh4.grad,
        moe_raw.bias_h4toh.grad,
    )

    if world_size > 1:
        # The raw model computed gradients for every expert on every worker;
        # sum them across workers, then keep only this rank's expert slice,
        # averaged over the model-parallel replicas.
        _, __, htoh4_w_grad, h4toh_w_grad, htoh4_b_grad, h4toh_b_grad = raw_out_list
        torch.distributed.all_reduce(htoh4_w_grad)
        torch.distributed.all_reduce(h4toh_w_grad)
        torch.distributed.all_reduce(htoh4_b_grad)
        torch.distributed.all_reduce(h4toh_b_grad)
        mp_size = mp_group.size() if mp_group else 1

        def _local_slice(grad):
            # This worker's own experts, averaged over model-parallel copies.
            return grad[rank * num_expert : (rank + 1) * num_expert] / mp_size

        raw_out_list = (
            _,
            __,
            _local_slice(htoh4_w_grad),
            _local_slice(h4toh_w_grad),
            _local_slice(htoh4_b_grad),
            _local_slice(h4toh_b_grad),
        )

    names = [
        "output",
        "input grad",
        "htoh4 weight grad",
        "h4toh weight grad",
        "htoh4 bias grad",
        "h4toh bias grad",
    ]

    # fp16 accumulates much larger error; loosen the tolerance accordingly.
    precision = 5e-1 if data_type == 'torch.HalfTensor' else 1e-3
    _assert_numerical(names, moe_out_list, raw_out_list, rank, precision=precision)
208

Sengxian's avatar
Sengxian committed
209

210
211
212
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("top_k", [2, 3])
@pytest.mark.parametrize("expert", [NaiveExpert, LinearExpert])
@pytest.mark.parametrize("rank", [0])
@pytest.mark.parametrize("world_size", [1])
@pytest.mark.parametrize("mp_group", [None])
@pytest.mark.parametrize("dp_group", [None])
@pytest.mark.parametrize("world_group", [None])
def test_fmoe(
    batch_size,
    num_expert,
    d_model,
    top_k,
    expert: Union[Type[nn.Module], str],
    rank,
    world_size,
    mp_group,
    dp_group,
    world_group,
):
    """Numerically compare a generic FMoE against BruteForceMoE.

    Both models receive identical expert parameters (copied locally, or
    gathered via all_gather when world_size > 1), run the same
    forward/backward pass, and must agree on the forward output, the summed
    expert gradients, and the input gradient.
    """
    torch.manual_seed(42 + rank)
    torch.cuda.manual_seed(42 + rank)

    # `expert` may arrive as a class name string (e.g. from a spawn-based
    # multi-process launcher); resolve it to the class object.
    if isinstance(expert, str):
        expert = globals()[expert]

    moe = FMoE(
        num_expert=num_expert,
        d_model=d_model,
        gate=NaiveGate,
        world_size=world_size,
        mp_group=mp_group,
        expert=expert,
        top_k=top_k,
    ).cuda()

    # The brute-force reference holds all world_size * num_expert experts.
    moe_raw = BruteForceMoE(
        expert=expert,
        num_expert=num_expert,
        d_model=d_model,
        world_size=world_size,
        top_k=top_k,
    ).cuda()

    if world_size == 1:
        # Single worker: copy each expert's parameters one-to-one.
        for expert_moe, expert_raw in zip(moe.experts, moe_raw.experts):
            for para_moe, para_raw in zip(
                expert_moe.parameters(), expert_raw.parameters()
            ):
                para_raw.data = para_moe.data.clone()
    else:
        # Multiple workers: for each parameter position, stack this worker's
        # experts, all_gather across workers, and scatter the result into the
        # reference model's full expert list.
        assert len(moe.experts) >= 1
        for idx, para in enumerate(moe.experts[0].parameters()):
            para_tensor = torch.cat(
                [list(expert.parameters())[idx].unsqueeze(0) for expert in moe.experts]
            )
            para_array = [torch.empty_like(para_tensor) for _ in range(world_size)]
            torch.distributed.all_gather(para_array, para_tensor)
            para_tensor_gathered = torch.cat(para_array, dim=0)
            assert para_tensor_gathered.shape[0] == len(moe_raw.experts)
            for expertID in range(para_tensor_gathered.shape[0]):
                list(moe_raw.experts[expertID].parameters())[
                    idx
                ].data = para_tensor_gathered[expertID]

    moe_out, raw_out, moe_grad_in, raw_grad_in = _perform_forward(
        moe, moe_raw, batch_size, d_model, top_k, rank, mp_group
    )

    def get_experts_grad(experts: List[nn.Module]):
        # Collapse all parameter gradients of each expert into one scalar,
        # yielding a (num_experts,) tensor that is cheap to compare.
        return torch.stack(
            [
                torch.stack(
                    [
                        p.grad.sum() if p.grad is not None else torch.zeros(1).cuda()
                        for p in item.parameters()
                    ]
                ).sum()
                for item in experts
            ]
        )

    moe_grad, raw_grad = (
        get_experts_grad(moe.experts),
        get_experts_grad(moe_raw.experts),
    )

    if world_size > 1:
        # Sum the reference gradients over all workers, then keep only this
        # rank's expert slice, averaged over the model-parallel replicas.
        torch.distributed.all_reduce(raw_grad)
        mp_size = mp_group.size() if mp_group else 1
        raw_grad = raw_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size

    moe_out_list = [moe_out, moe_grad, moe_grad_in]
    raw_out_list = [raw_out, raw_grad, raw_grad_in]
    names = ["forward", "backward", "grad_in"]

    _assert_numerical(names, moe_out_list, raw_out_list, rank)
Sengxian's avatar
Sengxian committed
309
310


311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
class MyModule(nn.Module):
    def __init__(self, dim=8):
        super(MyModule, self).__init__()
        self.model = nn.Sequential(
            OrderedDict(
                [
                    ("linear1", nn.Linear(dim, dim)),
                    ("relu1", nn.ReLU()),
                    ("linear2", nn.Linear(dim, dim)),
                    ("relu2", nn.ReLU()),
                    ("linear3", nn.Linear(dim, dim)),
                ]
            )
        )

    def set_comm(self):
        for p in self.model._modules["linear1"].parameters():
            setattr(p, "dp_comm", "mp")
        for p in self.model._modules["linear2"].parameters():
            setattr(p, "dp_comm", "dp")
        for p in self.model._modules["linear3"].parameters():
            setattr(p, "dp_comm", "world")

    def forward(self, inp):
        return self.model(inp)


def _test_fmoe_local_ddp(rank, world_size, mp_group, dp_group, world_group):
    """Check DistributedGroupedDataParallel's grouped gradient all-reduce.

    A plain model and a LocalDDP-wrapped deep copy run the same
    forward/backward pass; the plain model's gradients are then manually
    all-reduced and averaged over the process group each layer was tagged
    with (mp / dp / world, see MyModule.set_comm), and compared against what
    ``allreduce_params`` produced on the wrapped copy.
    """
    batch_size, dim = 4, 8

    torch.manual_seed(42 + rank)
    torch.cuda.manual_seed(42 + rank)

    model = MyModule().cuda()
    # Wrap a deepcopy so both models start from identical weights.
    model_ddp = LocalDDP(deepcopy(model), mp_group, dp_group, world_group)
    model.set_comm()
    model_ddp.module.set_comm()

    inp = torch.randn(batch_size, dim).cuda()

    raw_out = model(inp).mean()
    ddp_out = model_ddp(inp).mean()

    raw_out.backward()
    ddp_out.backward()

    # Reference reduction: average each layer's weight gradient over the
    # group it is tagged with.
    torch.distributed.all_reduce(
        model.model._modules["linear1"].weight.grad.data, group=mp_group
    )
    model.model._modules["linear1"].weight.grad /= mp_group.size()
    torch.distributed.all_reduce(
        model.model._modules["linear2"].weight.grad.data, group=dp_group
    )
    model.model._modules["linear2"].weight.grad /= dp_group.size()
    torch.distributed.all_reduce(
        model.model._modules["linear3"].weight.grad.data, group=world_group
    )
    model.model._modules["linear3"].weight.grad /= world_group.size()
    # The wrapper performs the equivalent reduction driven by the dp_comm tags.
    model_ddp.allreduce_params(reduce_after=False, fp32_allreduce=False)

    raw_out_list = [
        model.model._modules["linear1"].weight.grad,
        model.model._modules["linear2"].weight.grad,
        model.model._modules["linear3"].weight.grad,
    ]
    ddp_out_list = [
        model_ddp.module.model._modules["linear1"].weight.grad,
        model_ddp.module.model._modules["linear2"].weight.grad,
        model_ddp.module.model._modules["linear3"].weight.grad,
    ]

    names = ["mp grad", "dp grad", "wp grad"]

    _assert_numerical(names, ddp_out_list, raw_out_list, rank)
385
386


387
388
if __name__ == "__main__":
    # Smoke-run the linear test with a tiny single-worker configuration.
    # Fix: the original call omitted the required `data_type` argument, so
    # running this file directly raised TypeError.
    test_fmoe_linear(
        batch_size=2,
        num_expert=2,
        d_model=2,
        top_k=2,
        d_hidden=16,
        rank=0,
        world_size=1,
        mp_group=None,
        dp_group=None,
        world_group=None,
        data_type='torch.FloatTensor',
    )