import sys
from typing import List, Type, Union

import pytest
import torch
import torch.nn as nn

from fmoe.gates import NaiveGate
from fmoe.layers import FMoE
from fmoe.transformer import _Expert
from moe import BruteForceMoELinear, BruteForceMoE, NaiveExpert, LinearExpert


def _perform_forward(
    moe: nn.Module, moe_raw: nn.Module, batch_size, d_model, top_k, rank, mp_group
):
    """Run forward and backward through both MoE implementations on one batch.

    Both models see the same random input; under model parallelism the batch
    and the gate parameters are first broadcast from the leading rank of
    ``mp_group`` so every member of the group computes on identical data.

    Returns a tuple ``(fused_mean, raw_mean)`` of the scalar mean outputs.
    """
    moe.zero_grad()
    moe_raw.zero_grad()

    x = torch.rand(batch_size, d_model).cuda()
    if mp_group:
        # First rank of this model-parallel group owns the canonical copy.
        root = rank // mp_group.size() * mp_group.size()
        torch.distributed.broadcast(x, root, group=mp_group)
        torch.distributed.broadcast(moe.gate.gate.weight.data, root, group=mp_group)
        torch.distributed.broadcast(moe.gate.gate.bias.data, root, group=mp_group)

    # The brute-force model takes the gate decision explicitly, with the
    # input replicated top_k times (one row per selected expert).
    idx, score = moe.gate(x)
    out_fused = moe(x).mean()
    out_raw = moe_raw(x.repeat_interleave(repeats=top_k, dim=0), idx, score).mean()

    out_fused.backward()
    out_raw.backward()

    return out_fused, out_raw


def _assert_numercial(names, moe_out_list, raw_out_list, rank):
44
45
46
47
48
49
50
51
52
53
54
55
    for name, mo, ro in zip(names, moe_out_list, raw_out_list):
        err = (mo - ro).abs().sum()
        print("Rank {} {} abs err {}".format(rank, name, err))
        if err > 1e-3:
            sys.stderr.write("=========== moe out ==============\n")
            sys.stderr.write("{}\n".format(mo))
            sys.stderr.write("=========== raw out ==============\n")
            sys.stderr.write("{}\n".format(ro))
            assert False


@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("top_k", [2, 3])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("d_hidden", [32])
@pytest.mark.parametrize("rank", [0])
@pytest.mark.parametrize("world_size", [1])
@pytest.mark.parametrize("mp_group", [None])
def test_fmoe_linear(
    num_expert,
    top_k,
    batch_size,
    d_model,
    d_hidden,
    rank,
    world_size,
    mp_group,
    activation=torch.nn.functional.gelu,
):
    """Check FMoE with the fused linear `_Expert` against a brute-force MoE.

    The brute-force reference is loaded with copies of the expert weights
    (gathered from all ranks when world_size > 1), both models run the same
    input via `_perform_forward`, and the outputs plus the htoh4/h4toh weight
    gradients are compared by `_assert_numercial`.
    """
    # Per-rank seeding so different ranks hold different expert weights.
    torch.manual_seed(42 + rank)
    torch.cuda.manual_seed(42 + rank)

    experts = _Expert(num_expert, d_model, d_hidden, activation).cuda()

    def expert_fn(inp, gate):
        # Adapter handing FMoE's (input, gate-assignment) pair to `_Expert`.
        return experts(inp, gate)

    moe = FMoE(
        num_expert=num_expert,
        d_model=d_model,
        gate=NaiveGate,
        world_size=world_size,
        mp_group=mp_group,
        expert_fn=expert_fn,
        top_k=top_k,
    ).cuda()

    moe_raw = BruteForceMoELinear(
        activation=activation,
        num_expert=num_expert,
        d_model=d_model,
        d_hidden=d_hidden,
        world_size=world_size,
        top_k=top_k,
    ).cuda()

    if world_size == 1:
        # Single process: the reference model simply mirrors the local weights.
        moe_raw.weight_htoh4.data = experts.htoh4.weight.data.clone()
        moe_raw.bias_htoh4.data = experts.htoh4.bias.data.clone()
        moe_raw.weight_h4toh.data = experts.h4toh.weight.data.clone()
        moe_raw.bias_h4toh.data = experts.h4toh.bias.data.clone()
    else:
        # Multi-process: the reference model holds ALL experts, so gather every
        # rank's weights and concatenate along the expert dimension (dim 0).
        weight_htoh4_array = [
            torch.empty_like(experts.htoh4.weight.data) for _ in range(world_size)
        ]
        bias_htoh4_array = [
            torch.empty_like(experts.htoh4.bias.data) for _ in range(world_size)
        ]
        torch.distributed.all_gather(weight_htoh4_array, experts.htoh4.weight.data)
        torch.distributed.all_gather(bias_htoh4_array, experts.htoh4.bias.data)
        moe_raw.weight_htoh4.data = torch.cat(weight_htoh4_array, dim=0)
        moe_raw.bias_htoh4.data = torch.cat(bias_htoh4_array, dim=0)

        weight_h4toh_array = [
            torch.empty_like(experts.h4toh.weight.data) for _ in range(world_size)
        ]
        bias_h4toh_array = [
            torch.empty_like(experts.h4toh.bias.data) for _ in range(world_size)
        ]
        torch.distributed.all_gather(weight_h4toh_array, experts.h4toh.weight.data)
        torch.distributed.all_gather(bias_h4toh_array, experts.h4toh.bias.data)
        moe_raw.weight_h4toh.data = torch.cat(weight_h4toh_array, dim=0)
        moe_raw.bias_h4toh.data = torch.cat(bias_h4toh_array, dim=0)

    moe_out, raw_out = _perform_forward(
        moe, moe_raw, batch_size, d_model, top_k, rank, mp_group
    )

    # Compare the scalar outputs and the two weight gradients (biases are
    # not compared here).
    moe_out_list = moe_out, experts.htoh4.weight.grad, experts.h4toh.weight.grad
    raw_out_list = raw_out, moe_raw.weight_htoh4.grad, moe_raw.weight_h4toh.grad

    if world_size > 1:
        # The reference model accumulated gradients for all experts locally;
        # reduce across ranks, then keep only this rank's expert slice.
        _, htoh4_grad, h4toh_grad = raw_out_list
        torch.distributed.all_reduce(htoh4_grad)
        torch.distributed.all_reduce(h4toh_grad)
        # NOTE(review): the /mp_size rescale presumably undoes duplicate
        # contributions from ranks in the same model-parallel group (they all
        # saw the same broadcast input) — confirm against FMoE's mp semantics.
        mp_size = mp_group.size() if mp_group else 1
        htoh4_grad = htoh4_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
        h4toh_grad = h4toh_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size
        raw_out_list = _, htoh4_grad, h4toh_grad

    names = ["output", "htoh4 weight grad", "h4toh weight grad"]
    _assert_numercial(names, moe_out_list, raw_out_list, rank)


@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("top_k", [2, 3])
@pytest.mark.parametrize("expert", [NaiveExpert, LinearExpert])
@pytest.mark.parametrize("rank", [0])
@pytest.mark.parametrize("world_size", [1])
@pytest.mark.parametrize("mp_group", [None])
def test_fmoe(
    batch_size,
    num_expert,
    d_model,
    top_k,
    expert: Union[Type[nn.Module], str],
    rank,
    mp_group,
    world_size,
):
    """Check FMoE with generic per-expert modules against `BruteForceMoE`.

    Expert parameters are copied (or all-gathered when world_size > 1) into
    the brute-force reference, both models run the same input, and the forward
    outputs plus per-expert summed gradients are compared.
    """
    # Per-rank seeding so different ranks hold different expert weights.
    torch.manual_seed(42 + rank)
    torch.cuda.manual_seed(42 + rank)

    # A launcher may pass the expert class by name; resolve it to the class.
    if isinstance(expert, str):
        expert = globals()[expert]

    moe = FMoE(
        num_expert=num_expert,
        d_model=d_model,
        gate=NaiveGate,
        world_size=world_size,
        mp_group=mp_group,
        expert=expert,
        top_k=top_k,
    ).cuda()

    moe_raw = BruteForceMoE(
        expert=expert,
        num_expert=num_expert,
        d_model=d_model,
        world_size=world_size,
        top_k=top_k,
    ).cuda()

    if world_size == 1:
        # Single process: copy each expert's parameters one-to-one.
        for expert_moe, expert_raw in zip(moe.experts, moe_raw.experts):
            for para_moe, para_raw in zip(
                expert_moe.parameters(), expert_raw.parameters()
            ):
                para_raw.data = para_moe.data.clone()
    else:
        assert len(moe.experts) >= 1
        # For each parameter slot, stack it over the local experts, gather the
        # stacks from all ranks, and scatter into the reference's expert list.
        # NOTE(review): the loop variable `expert` below shadows the `expert`
        # argument (already consumed above) — works, but worth renaming.
        for idx, para in enumerate(moe.experts[0].parameters()):
            para_tensor = torch.cat(
                [list(expert.parameters())[idx].unsqueeze(0) for expert in moe.experts]
            )
            para_array = [torch.empty_like(para_tensor) for _ in range(world_size)]
            torch.distributed.all_gather(para_array, para_tensor)
            para_tensor_gathered = torch.cat(para_array, dim=0)
            assert para_tensor_gathered.shape[0] == len(moe_raw.experts)
            for expertID in range(para_tensor_gathered.shape[0]):
                list(moe_raw.experts[expertID].parameters())[
                    idx
                ].data = para_tensor_gathered[expertID]

    moe_out, raw_out = _perform_forward(
        moe, moe_raw, batch_size, d_model, top_k, rank, mp_group
    )

    def get_experts_grad(experts: List[nn.Module]):
        # One scalar per expert: the sum of all its parameter gradients
        # (zero stands in for parameters that received no gradient).
        return torch.stack(
            [
                torch.stack(
                    [
                        p.grad.sum() if p.grad is not None else torch.zeros(1).cuda()
                        for p in item.parameters()
                    ]
                ).sum()
                for item in experts
            ]
        )

    moe_grad, raw_grad = (
        get_experts_grad(moe.experts),
        get_experts_grad(moe_raw.experts),
    )

    if world_size > 1:
        # Reduce reference gradients across ranks, then keep this rank's slice.
        torch.distributed.all_reduce(raw_grad)
        # NOTE(review): /mp_size presumably undoes duplicate contributions from
        # the model-parallel group (same broadcast input on each member) —
        # confirm against FMoE's mp semantics.
        mp_size = mp_group.size() if mp_group else 1
        raw_grad = raw_grad[rank * num_expert : (rank + 1) * num_expert] / mp_size

    moe_out_list = [moe_out, moe_grad]
    raw_out_list = [raw_out, raw_grad]
    names = ["forward", "backward"]

    _assert_numercial(names, moe_out_list, raw_out_list, rank)


if __name__ == "__main__":
    # Smoke-run both tests single-process (no pytest, no distributed setup).
    common = dict(
        batch_size=4,
        num_expert=4,
        d_model=8,
        top_k=2,
        rank=0,
        world_size=1,
        mp_group=None,
    )
    test_fmoe_linear(d_hidden=16, **common)
    test_fmoe(expert=NaiveExpert, **common)