".github/scripts/unittest-linux/install.sh" did not exist on "b8402baa258e1d755e1ff1256db96f63def2a878"
test_gates.py 3.75 KB
Newer Older
Rich Ho's avatar
Rich Ho committed
1
import pytest
2

Rick Ho's avatar
Rick Ho committed
3
import os
4
5
import sys
import json
Rich Ho's avatar
Rich Ho committed
6
import math
7

Rick Ho's avatar
Rick Ho committed
8
9
import torch
import torch.distributed as dist
GODVIX's avatar
GODVIX committed
10
import torch.nn.functional as F
Rich Ho's avatar
Rich Ho committed
11
from fmoe.gates import GShardGate, SwitchGate
Rick Ho's avatar
Rick Ho committed
12
from fmoe.functions import ensure_comm
Rick Ho's avatar
Rick Ho committed
13
from test_ddp import _ensure_initialized, _run_distributed
Rich Ho's avatar
Rich Ho committed
14
15


16
17
18
19
@pytest.mark.parametrize("d_model", [1024])
@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("n_expert", [1, 4])
@pytest.mark.parametrize("cap", [.1, 1.1])
def test_gshard_gate(d_model, batch_size, n_expert, cap):
    """Launch the GShard gate worker via the distributed test harness."""
    # GShard uses top-2 gating, so at least two experts must exist in
    # total; the world size is fixed at 1 here (hence 1 * n_expert).
    if 1 * n_expert < 2:
        pytest.skip("No enough experts")
    worker_kwargs = {
        'd_model': d_model,
        'batch_size': batch_size,
        'n_expert': n_expert,
        'cap': cap,
    }
    _run_distributed('_test_gshard_gate', 1, worker_kwargs, script=__file__)


def _test_gshard_gate(d_model, batch_size, n_expert, cap):
    """Worker body: run GShardGate on random input and validate its outputs.

    Checks that (1) no expert receives more tokens than the configured
    capacity and (2) the gate's returned values match a recomputed
    top-k + softmax over the raw gate scores.
    """
    _ensure_initialized()
    rank = torch.distributed.get_rank()
    world_size = dist.get_world_size()
    gate = GShardGate(d_model, n_expert, world_size,
            capacity=(cap, cap)).cuda()
    inp = torch.rand(batch_size, d_model).cuda()
    ensure_comm(inp, None)
    topk_idx, topk_val = gate(inp)

    # Tally tokens routed to each expert; -1 marks a token that was
    # dropped because its expert hit capacity.
    counts = [0] * (n_expert * world_size)
    for expert_id in topk_idx.cpu().view(-1).numpy():
        if expert_id != -1:
            counts[expert_id] += 1
    real_cap = math.ceil(cap * batch_size)
    for c in counts:
        assert c <= real_cap

    # Recompute the expected gate values: select top-k raw scores,
    # then softmax across the k selected entries.
    raw_score = gate.gate(inp)
    sel_val, sel_idx = torch.topk(
        raw_score, k=gate.top_k, dim=-1, largest=True, sorted=False
    )
    expected = F.softmax(sel_val.view(-1, gate.top_k), dim=-1)

    for row in range(batch_size):
        for col in range(gate.top_k):
            if topk_idx[row, col] != -1:
                assert topk_val[row, col] == expected[row, col]


@pytest.mark.parametrize("d_model", [1024])
@pytest.mark.parametrize("batch_size", [4096])
@pytest.mark.parametrize("n_expert", [1, 16])
@pytest.mark.parametrize("cap", [.1, .8])
def test_switch_gate(d_model, batch_size, n_expert, cap):
    """Launch the Switch gate worker via the distributed test harness."""
    worker_kwargs = {
        'd_model': d_model,
        'batch_size': batch_size,
        'n_expert': n_expert,
        'cap': cap,
    }
    _run_distributed('_test_switch_gate', 1, worker_kwargs, script=__file__)


def _test_switch_gate(d_model, batch_size, n_expert, cap):
    """Worker body: run SwitchGate on random input and validate its outputs.

    Checks the per-expert capacity limit and reproduces the gate's
    (possibly noisy) softmax score by replaying the saved CUDA RNG state.
    """
    _ensure_initialized()
    world_size = dist.get_world_size()
    gate = SwitchGate(d_model, n_expert, world_size,
            capacity=(cap, cap)).cuda()
    inp = torch.rand(batch_size, d_model).cuda()
    # Snapshot the CUDA RNG so the noise drawn inside gate.forward()
    # can be replayed below.
    rng = torch.cuda.get_rng_state()
    topk_idx, topk_val = gate(inp)

    # Tally tokens per expert; -1 marks a dropped token.
    counts = [0] * (n_expert * world_size)
    for expert_id in topk_idx.cpu().view(-1).numpy():
        if expert_id != -1:
            counts[expert_id] += 1
    real_cap = math.ceil(cap * batch_size)
    for c in counts:
        assert c <= real_cap

    score = gate.gate(inp)

    if gate.training:
        # Rewind the RNG so torch.rand_like draws exactly the noise
        # used inside gate.forward().
        torch.cuda.set_rng_state(rng)
        # Uniform noise in [1 - eps, 1 + eps].
        noise = torch.rand_like(score)
        noise = noise * 2 * gate.switch_eps + 1.0 - gate.switch_eps
        score += noise

    # Softmax in fp32 for numerical stability.
    score = F.softmax(score.float(), dim=-1)

    for row in range(batch_size):
        expert_id = topk_idx[row]
        if expert_id != -1:
            assert topk_val[row] == score[row, expert_id]

if __name__ == '__main__':
    # When spawned by _run_distributed, argv[1] names the worker entry
    # point and argv[2] carries its kwargs as JSON.
    if len(sys.argv) >= 3:
        worker = locals()[sys.argv[1]]
        worker(**json.loads(sys.argv[2]))
    else:
        # Ad-hoc single-process run for local debugging.
        _test_gshard_gate(8, 16, 4, .1)