import json
import os
import sys
from typing import List, Callable, Dict, Type, Union
import pytest
import torch
import torch.nn as nn

from fmoe.gates import NaiveGate
from fmoe.layers import FMoE
from fmoe.transformer import _Expert
from moe import BruteForceMoELinear, BruteForceMoE, NaiveExpert, LinearExpert


def _perform_forward(moe: nn.Module, moe_raw: nn.Module, batch_size, d_model, top_k):
    """Run one forward and backward pass through both MoE implementations.

    A single random batch of shape ``(batch_size, d_model)`` is created on the
    GPU. ``moe`` receives it as-is. ``moe_raw`` receives it with every row
    repeated ``top_k`` times, together with the gate decisions obtained by
    calling ``moe.gate`` on the same batch. Gradients of both models are
    cleared first, then the mean of each output is back-propagated so the
    caller can compare the resulting ``.grad`` tensors.

    Returns:
        Tuple ``(moe_out, raw_out)`` holding the two scalar mean outputs.
    """
    for model in (moe, moe_raw):
        model.zero_grad()

    inp = torch.rand(batch_size, d_model).cuda()
    gate_idx, gate_score = moe.gate(inp)
    # Each row is repeated top_k times, presumably one copy per selected
    # expert as expected by the brute-force reference -- confirm in moe.py.
    raw_inp = inp.repeat_interleave(repeats=top_k, dim=0)

    moe_out = moe(inp).mean()
    raw_out = moe_raw(raw_inp, gate_idx, gate_score).mean()

    for out in (moe_out, raw_out):
        out.backward()

    return moe_out, raw_out


def _assert_numercial(names, moe_out_list, raw_out_list, rank):
32
33
34
35
36
37
38
39
40
41
42
43
    for name, mo, ro in zip(names, moe_out_list, raw_out_list):
        err = (mo - ro).abs().sum()
        print("Rank {} {} abs err {}".format(rank, name, err))
        if err > 1e-3:
            sys.stderr.write("=========== moe out ==============\n")
            sys.stderr.write("{}\n".format(mo))
            sys.stderr.write("=========== raw out ==============\n")
            sys.stderr.write("{}\n".format(ro))
            assert False


@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("top_k", [2, 3])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("d_hidden", [32])
@pytest.mark.parametrize("rank", [0])
@pytest.mark.parametrize("world_size", [1])
def test_fmoe_linear(
    num_expert,
    top_k,
    batch_size,
    d_model,
    d_hidden,
    rank,
    world_size,
    activation=torch.nn.functional.gelu,
):
    """Compare FMoE driven by a fused ``_Expert`` against ``BruteForceMoELinear``.

    The reference model is loaded with the same ``htoh4`` / ``h4toh`` weights
    as the FMoE experts. When ``world_size > 1`` these weights are gathered from
    all ranks. One forward/backward pass is run through both models, and the
    mean outputs and both weight gradients are compared.
    """
    torch.manual_seed(42 + rank)
    torch.cuda.manual_seed(42 + rank)

    experts = _Expert(num_expert, d_model, d_hidden, activation).cuda()

    def expert_fn(inp, gate):
        return experts(inp, gate)

    moe = FMoE(
        num_expert=num_expert,
        d_model=d_model,
        gate=NaiveGate,
        world_size=world_size,
        mp_group=None,
        expert_fn=expert_fn,
        top_k=top_k,
    ).cuda()

    moe_raw = BruteForceMoELinear(
        activation=activation,
        num_expert=num_expert,
        d_model=d_model,
        d_hidden=d_hidden,
        world_size=world_size,
        top_k=top_k,
    ).cuda()

    if world_size == 1:
        moe_raw.weight_htoh4.data = experts.htoh4.weight.data.clone()
        moe_raw.weight_h4toh.data = experts.h4toh.weight.data.clone()
    else:
        # The reference model is given the weights of every rank: gather the
        # local expert weights from all ranks and concatenate them in rank order.
        weight_htoh4_array = [
            torch.empty_like(experts.htoh4.weight.data) for _ in range(world_size)
        ]
        torch.distributed.all_gather(weight_htoh4_array, experts.htoh4.weight.data)
        moe_raw.weight_htoh4.data = torch.cat(weight_htoh4_array, dim=0)

        weight_h4toh_array = [
            torch.empty_like(experts.h4toh.weight.data) for _ in range(world_size)
        ]
        torch.distributed.all_gather(weight_h4toh_array, experts.h4toh.weight.data)
        moe_raw.weight_h4toh.data = torch.cat(weight_h4toh_array, dim=0)

    moe_out, raw_out = _perform_forward(moe, moe_raw, batch_size, d_model, top_k)

    moe_out_list = moe_out, experts.htoh4.weight.grad, experts.h4toh.weight.grad
    raw_out_list = raw_out, moe_raw.weight_htoh4.grad, moe_raw.weight_h4toh.grad

    if world_size > 1:
        _, htoh4_grad, h4toh_grad = raw_out_list
        # Sum the reference gradients over all ranks, then keep only the slice
        # of experts owned by this rank so it lines up with the local FMoE grads.
        torch.distributed.all_reduce(htoh4_grad)
        torch.distributed.all_reduce(h4toh_grad)
        htoh4_grad = htoh4_grad[rank * num_expert : (rank + 1) * num_expert]
        h4toh_grad = h4toh_grad[rank * num_expert : (rank + 1) * num_expert]
        # Rebuild from `raw_out` explicitly rather than re-using the throwaway `_`.
        raw_out_list = raw_out, htoh4_grad, h4toh_grad

    names = ["output", "htoh4 weight grad", "h4toh weight grad"]
    _assert_numercial(names, moe_out_list, raw_out_list, rank)


@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("top_k", [2, 3])
@pytest.mark.parametrize("expert", ["NaiveExpert", "LinearExpert"])
@pytest.mark.parametrize("rank", [0])
@pytest.mark.parametrize("world_size", [1])
def test_fmoe(
    batch_size,
    num_expert,
    d_model,
    top_k,
    expert: Union[Type[nn.Module], str],
    rank,
    world_size,
):
    """Compare FMoE built from ``expert`` modules against ``BruteForceMoE``.

    ``expert`` is either an ``nn.Module`` class or its name as a string, which
    is resolved from this module's globals. Both models are given identical
    expert parameters. When ``world_size > 1`` these are gathered from all
    ranks. One forward/backward pass is then run, and the test compares the
    mean outputs plus, per expert, the sum of all parameter gradients.
    """
    torch.manual_seed(42 + rank)
    torch.cuda.manual_seed(42 + rank)

    # pytest parametrizes the expert by name; resolve it to the class.
    if isinstance(expert, str):
        expert = globals()[expert]

    moe = FMoE(
        num_expert=num_expert,
        d_model=d_model,
        gate=NaiveGate,
        world_size=world_size,
        mp_group=None,
        expert=expert,
        top_k=top_k,
    ).cuda()

    moe_raw = BruteForceMoE(
        expert=expert,
        num_expert=num_expert,
        d_model=d_model,
        world_size=world_size,
        top_k=top_k,
    ).cuda()

    if world_size == 1:
        for expert_moe, expert_raw in zip(moe.experts, moe_raw.experts):
            for para_moe, para_raw in zip(
                expert_moe.parameters(), expert_raw.parameters()
            ):
                para_raw.data = para_moe.data.clone()
    else:
        assert len(moe.experts) >= 1
        num_params = len(list(moe.experts[0].parameters()))
        for idx in range(num_params):
            # Stack the idx-th parameter of every local expert, gather those
            # stacks from all ranks, and copy them into the reference experts.
            para_tensor = torch.cat(
                [
                    list(local_expert.parameters())[idx].unsqueeze(0)
                    for local_expert in moe.experts
                ]
            )
            para_array = [torch.empty_like(para_tensor) for _ in range(world_size)]
            torch.distributed.all_gather(para_array, para_tensor)
            para_tensor_gathered = torch.cat(para_array, dim=0)
            assert para_tensor_gathered.shape[0] == len(moe_raw.experts)
            for expert_id in range(para_tensor_gathered.shape[0]):
                list(moe_raw.experts[expert_id].parameters())[
                    idx
                ].data = para_tensor_gathered[expert_id]

    moe_out, raw_out = _perform_forward(moe, moe_raw, batch_size, d_model, top_k)

    def get_experts_grad(experts: List[nn.Module]):
        # Reduce each expert's parameter gradients to one scalar. A parameter
        # without a gradient contributes a 0-d zero on its own device/dtype; it
        # must be 0-d like `p.grad.sum()`, otherwise torch.stack rejects the
        # mixed shapes (the old `torch.zeros(1)` had shape (1,)).
        return torch.stack(
            [
                torch.stack(
                    [
                        p.grad.sum()
                        if p.grad is not None
                        else torch.zeros((), dtype=p.dtype, device=p.device)
                        for p in item.parameters()
                    ]
                ).sum()
                for item in experts
            ]
        )

    moe_grad = get_experts_grad(moe.experts)
    raw_grad = get_experts_grad(moe_raw.experts)

    if world_size > 1:
        # Sum reference grads over ranks, keep this rank's slice of experts.
        torch.distributed.all_reduce(raw_grad)
        raw_grad = raw_grad[rank * num_expert : (rank + 1) * num_expert]

    moe_out_list = [moe_out, moe_grad]
    raw_out_list = [raw_out, raw_grad]
    names = ["forward", "backward"]

    _assert_numercial(names, moe_out_list, raw_out_list, rank)


if __name__ == "__main__":
    # Run one small configuration of each test directly, without pytest.
    common_kwargs = dict(
        batch_size=4,
        num_expert=4,
        d_model=8,
        top_k=2,
        rank=0,
        world_size=1,
    )
    test_fmoe_linear(d_hidden=16, **common_kwargs)
    test_fmoe(expert=NaiveExpert, **common_kwargs)