import sys
from typing import List, Type, Union

import pytest
import torch
import torch.nn as nn

from fmoe.gates import NaiveGate
from fmoe.layers import FMoE
from fmoe.transformer import _Expert
from moe import BruteForceMoELinear, BruteForceMoE, NaiveExpert, LinearExpert

def _perform_forward(
    moe: nn.Module, moe_raw: nn.Module, batch_size, d_model, top_k, rank, mp_group
):
    """Run one forward/backward pass through both the fast MoE and the
    brute-force reference, and return their outputs and input gradients.

    Returns a tuple ``(moe_out, raw_out, moe_input_grad, raw_input_grad)``.
    When ``mp_group`` is given, the input and the gate parameters are
    broadcast from the group's first rank so every model-parallel worker
    computes on identical data.
    """
    moe.zero_grad()
    moe_raw.zero_grad()

    x = torch.rand(batch_size, d_model).cuda()
    if mp_group:
        # First rank of this model-parallel group acts as the source.
        sender = rank // mp_group.size() * mp_group.size()
        torch.distributed.broadcast(x, sender, group=mp_group)
        torch.distributed.broadcast(
            moe.gate.gate.weight.data, sender, group=mp_group
        )
        torch.distributed.broadcast(
            moe.gate.gate.bias.data, sender, group=mp_group
        )

    # Two leaf tensors with identical values so each module gets its own grad.
    x_raw = x.clone()
    x.requires_grad = True
    x_raw.requires_grad = True

    # The reference module consumes precomputed gate decisions and an input
    # replicated top_k times per token.
    gate_idx, gate_score = moe.gate(x_raw)
    fast_out = moe(x)
    raw_out = moe_raw(
        x_raw.repeat_interleave(repeats=top_k, dim=0), gate_idx, gate_score
    )

    raw_out.mean().backward()
    fast_out.mean().backward()

    return fast_out, raw_out, x.grad, x_raw.grad
def _assert_numercial(names, moe_out_list, raw_out_list, rank):
48
49
50
51
52
53
54
55
56
57
58
    for name, mo, ro in zip(names, moe_out_list, raw_out_list):
        err = (mo - ro).abs().sum()
        print("Rank {} {} abs err {}".format(rank, name, err))
        if err > 1e-3:
            sys.stderr.write("=========== moe out ==============\n")
            sys.stderr.write("{}\n".format(mo))
            sys.stderr.write("=========== raw out ==============\n")
            sys.stderr.write("{}\n".format(ro))
            assert False
class MyMoE(FMoE):
    """FMoE specialisation used by the tests: a NaiveGate router in front of
    the fused two-layer (htoh4 / h4toh) expert MLP."""

    def __init__(self, num_expert, d_model, d_hidden, world_size, mp_group,
            top_k, activation):
        moe_config = dict(
            num_expert=num_expert,
            d_model=d_model,
            gate=NaiveGate,
            world_size=world_size,
            mp_group=mp_group,
            top_k=top_k,
        )
        super().__init__(**moe_config)
        # Fused feed-forward experts: d_model -> d_hidden -> d_model.
        self.experts = _Expert(num_expert, d_model, d_hidden, activation)
@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("top_k", [2, 3])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("d_hidden", [32])
@pytest.mark.parametrize("rank", [0])
@pytest.mark.parametrize("world_size", [1])
@pytest.mark.parametrize("mp_group", [None])
def test_fmoe_linear(
    num_expert,
    top_k,
    batch_size,
    d_model,
    d_hidden,
    rank,
    world_size,
    mp_group,
    activation=torch.nn.functional.gelu,
):
    """Numerically compare MyMoE (fused linear experts) against
    BruteForceMoELinear: forward output, input gradient, and all four
    expert weight/bias gradients must agree.
    """
    torch.manual_seed(42 + rank)
    torch.cuda.manual_seed(42 + rank)

    moe = MyMoE(num_expert, d_model, d_hidden, world_size, mp_group, top_k,
            activation).cuda()

    moe_raw = BruteForceMoELinear(
        activation=activation,
        num_expert=num_expert,
        d_model=d_model,
        d_hidden=d_hidden,
        world_size=world_size,
        top_k=top_k,
    ).cuda()

    # Mirror the fused expert parameters into the brute-force module so both
    # start from identical weights.
    if world_size == 1:
        moe_raw.weight_htoh4.data = moe.experts.htoh4.weight.data.clone()
        moe_raw.bias_htoh4.data = moe.experts.htoh4.bias.data.clone()
        moe_raw.weight_h4toh.data = moe.experts.h4toh.weight.data.clone()
        moe_raw.bias_h4toh.data = moe.experts.h4toh.bias.data.clone()
    else:
        # The brute-force module holds every rank's experts, so gather each
        # rank's parameters and concatenate along the expert dimension.
        weight_htoh4_array = [
            torch.empty_like(moe.experts.htoh4.weight.data) for _ in range(world_size)
        ]
        bias_htoh4_array = [
            torch.empty_like(moe.experts.htoh4.bias.data) for _ in range(world_size)
        ]
        torch.distributed.all_gather(weight_htoh4_array, moe.experts.htoh4.weight.data)
        torch.distributed.all_gather(bias_htoh4_array, moe.experts.htoh4.bias.data)
        moe_raw.weight_htoh4.data = torch.cat(weight_htoh4_array, dim=0)
        moe_raw.bias_htoh4.data = torch.cat(bias_htoh4_array, dim=0)

        weight_h4toh_array = [
            torch.empty_like(moe.experts.h4toh.weight.data) for _ in range(world_size)
        ]
        bias_h4toh_array = [
            torch.empty_like(moe.experts.h4toh.bias.data) for _ in range(world_size)
        ]
        torch.distributed.all_gather(weight_h4toh_array, moe.experts.h4toh.weight.data)
        torch.distributed.all_gather(bias_h4toh_array, moe.experts.h4toh.bias.data)
        moe_raw.weight_h4toh.data = torch.cat(weight_h4toh_array, dim=0)
        moe_raw.bias_h4toh.data = torch.cat(bias_h4toh_array, dim=0)

    moe_out, raw_out, moe_grad_in, raw_grad_in = _perform_forward(
        moe, moe_raw, batch_size, d_model, top_k, rank, mp_group
    )

    moe_out_list = moe_out, moe_grad_in, moe.experts.htoh4.weight.grad, moe.experts.h4toh.weight.grad, moe.experts.htoh4.bias.grad, moe.experts.h4toh.bias.grad
    raw_out_list = raw_out, raw_grad_in, moe_raw.weight_htoh4.grad, moe_raw.weight_h4toh.grad, moe_raw.bias_htoh4.grad, moe_raw.bias_h4toh.grad

    if world_size > 1:
        # BUG FIX: raw_out_list has SIX entries (output, input grad, four
        # parameter grads); the previous code unpacked only five, raising
        # ValueError on the multi-worker path and dropping raw_grad_in when
        # rebuilding the list.
        (raw_out, raw_grad_in, htoh4_w_grad, h4toh_w_grad,
                htoh4_b_grad, h4toh_b_grad) = raw_out_list
        # Brute-force grads are per-rank partial sums; reduce across workers,
        # then keep this rank's expert slice, averaged over the mp group.
        torch.distributed.all_reduce(htoh4_w_grad)
        torch.distributed.all_reduce(h4toh_w_grad)
        torch.distributed.all_reduce(htoh4_b_grad)
        torch.distributed.all_reduce(h4toh_b_grad)
        mp_size = mp_group.size() if mp_group else 1
        htoh4_w_grad = htoh4_w_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
        h4toh_w_grad = h4toh_w_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
        htoh4_b_grad = htoh4_b_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
        h4toh_b_grad = h4toh_b_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
        raw_out_list = (raw_out, raw_grad_in, htoh4_w_grad, h4toh_w_grad,
                htoh4_b_grad, h4toh_b_grad)

    names = ["output", "input grad", "htoh4 weight grad", "h4toh weight grad", "htoh4 bias grad", "h4toh bias grad"]
    _assert_numercial(names, moe_out_list, raw_out_list, rank)
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("top_k", [2, 3])
@pytest.mark.parametrize("expert", [NaiveExpert, LinearExpert])
@pytest.mark.parametrize("rank", [0])
@pytest.mark.parametrize("world_size", [1])
@pytest.mark.parametrize("mp_group", [None])
def test_fmoe(
    batch_size,
    num_expert,
    d_model,
    top_k,
    expert: Union[Type[nn.Module], str],
    rank,
    mp_group,
    world_size,
):
    """Numerically compare a generic FMoE layer against BruteForceMoE.

    Checks forward output, a scalar summary of the expert parameter
    gradients, and the input gradient. ``expert`` may be an nn.Module
    subclass or its name as a string (looked up in this module's globals).
    """
    torch.manual_seed(42 + rank)
    torch.cuda.manual_seed(42 + rank)

    if isinstance(expert, str):
        # Allow the expert class to be passed by name (e.g. from a launcher).
        expert = globals()[expert]

    moe = FMoE(
        num_expert=num_expert,
        d_model=d_model,
        gate=NaiveGate,
        world_size=world_size,
        mp_group=mp_group,
        expert=expert,
        top_k=top_k,
    ).cuda()

    moe_raw = BruteForceMoE(
        expert=expert,
        num_expert=num_expert,
        d_model=d_model,
        world_size=world_size,
        top_k=top_k,
    ).cuda()

    # Copy the FMoE experts' parameters into the brute-force module so the
    # two start from identical weights.
    if world_size == 1:
        for expert_moe, expert_raw in zip(moe.experts, moe_raw.experts):
            for para_moe, para_raw in zip(
                expert_moe.parameters(), expert_raw.parameters()
            ):
                para_raw.data = para_moe.data.clone()
    else:
        # The brute-force module holds every rank's experts: for each
        # parameter slot, stack this rank's copies, all_gather across
        # workers, then scatter the result into moe_raw's expert list.
        assert len(moe.experts) >= 1
        for idx, para in enumerate(moe.experts[0].parameters()):
            para_tensor = torch.cat(
                [list(expert.parameters())[idx].unsqueeze(0) for expert in moe.experts]
            )
            para_array = [torch.empty_like(para_tensor) for _ in range(world_size)]
            torch.distributed.all_gather(para_array, para_tensor)
            para_tensor_gathered = torch.cat(para_array, dim=0)
            assert para_tensor_gathered.shape[0] == len(moe_raw.experts)
            for expertID in range(para_tensor_gathered.shape[0]):
                list(moe_raw.experts[expertID].parameters())[
                    idx
                ].data = para_tensor_gathered[expertID]

    moe_out, raw_out, moe_grad_in, raw_grad_in = _perform_forward(
        moe, moe_raw, batch_size, d_model, top_k, rank, mp_group
    )

    def get_experts_grad(experts: List[nn.Module]):
        # One scalar per expert: the sum over all of its parameter grads
        # (missing grads count as zero). Coarse, but enough to compare the
        # two implementations.
        return torch.stack(
            [
                torch.stack(
                    [
                        p.grad.sum() if p.grad is not None else torch.zeros(1).cuda()
                        for p in item.parameters()
                    ]
                ).sum()
                for item in experts
            ]
        )

    moe_grad, raw_grad = (
        get_experts_grad(moe.experts),
        get_experts_grad(moe_raw.experts),
    )

    if world_size > 1:
        # Reduce partial grads across workers, then keep this rank's expert
        # slice, averaged over the model-parallel group size.
        torch.distributed.all_reduce(raw_grad)
        mp_size = mp_group.size() if mp_group else 1
        raw_grad = raw_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size

    moe_out_list = [moe_out, moe_grad, moe_grad_in]
    raw_out_list = [raw_out, raw_grad, raw_grad_in]
    names = ["forward", "backward", "grad_in"]

    _assert_numercial(names, moe_out_list, raw_out_list, rank)
if __name__ == "__main__":
    # Smoke-run both tests on a single process / single device.
    shared = dict(
        batch_size=4,
        num_expert=4,
        d_model=8,
        top_k=2,
        rank=0,
        world_size=1,
        mp_group=None,
    )
    test_fmoe_linear(d_hidden=16, **shared)
    test_fmoe(expert=NaiveExpert, **shared)