import pytest

import os
import sys
import json
import math

import torch
import torch.distributed as dist

from fmoe.gates import GShardGate, SwitchGate
from test_ddp import _run_distributed
def _ensure_initialized():
    """Lazily create the NCCL process group so helpers run standalone or under mpirun.

    Rank/world-size default to the OpenMPI environment variables when present,
    falling back to a single-process (rank 0, world size 1) setup.
    """
    if dist.is_initialized():
        return
    env = os.environ
    env["RANK"] = env.get("OMPI_COMM_WORLD_RANK", "0")
    env["WORLD_SIZE"] = env.get("OMPI_COMM_WORLD_SIZE", "1")
    # Pin each rank to its own GPU before torch initializes CUDA.
    env["CUDA_VISIBLE_DEVICES"] = env["RANK"]
    env.setdefault("MASTER_ADDR", "localhost")
    env.setdefault("MASTER_PORT", "12211")
    dist.init_process_group(backend="nccl")


24
25
26
27
@pytest.mark.parametrize("d_model", [1024])
@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("n_expert", [1, 4])
@pytest.mark.parametrize("cap", [.1, 1.1])
def test_gshard_gate(d_model, batch_size, n_expert, cap):
    """Launch the GShard gate capacity check in a spawned distributed job."""
    # The literal 1 in the original skip condition is the world size that is
    # also passed to _run_distributed below; name it once so the two stay in
    # sync if the test is ever scaled up.
    world_size = 1
    # GShardGate performs top-2 gating, so it needs at least 2 experts total.
    if world_size * n_expert < 2:
        pytest.skip("No enough experts")
    _run_distributed('_test_gshard_gate',
            world_size,
            {
                'd_model': d_model,
                'batch_size': batch_size,
                'n_expert': n_expert,
                'cap': cap
            },
            script=__file__
    )


def _test_gshard_gate(d_model, batch_size, n_expert, cap):
    """Check GShardGate never routes more than ceil(cap * batch_size) tokens
    to any single expert; dropped tokens are marked with index -1."""
    _ensure_initialized()
    world_size = dist.get_world_size()
    gate = GShardGate(d_model, n_expert, world_size,
            capacity=(cap, cap)).cuda()
    inp = torch.rand(batch_size, d_model).cuda()
    topk_idx, _topk_val = gate(inp)
    # Tally how many tokens landed on each global expert.
    tally = [0] * (n_expert * world_size)
    for expert in topk_idx.cpu().view(-1).numpy():
        if expert != -1:
            tally[expert] += 1
    limit = math.ceil(cap * batch_size)
    for hits in tally:
        assert hits <= limit
Rick Ho's avatar
Rick Ho committed
56
57


58
59
60
61
@pytest.mark.parametrize("d_model", [1024])
@pytest.mark.parametrize("batch_size", [4096])
@pytest.mark.parametrize("n_expert", [1, 16])
@pytest.mark.parametrize("cap", [.1, .8])
def test_switch_gate(d_model, batch_size, n_expert, cap):
    """Launch the Switch gate capacity check in a one-process distributed job."""
    kwargs = {
        'd_model': d_model,
        'batch_size': batch_size,
        'n_expert': n_expert,
        'cap': cap
    }
    _run_distributed('_test_switch_gate', 1, kwargs, script=__file__)


def _test_switch_gate(d_model, batch_size, n_expert, cap):
    """Verify SwitchGate honours the per-expert capacity ceil(cap * batch_size).

    Tokens dropped by the gate carry index -1 and are excluded from the count.
    """
    _ensure_initialized()
    n_global_experts = n_expert * dist.get_world_size()
    gate = SwitchGate(d_model, n_expert, dist.get_world_size(),
            capacity=(cap, cap)).cuda()
    x = torch.rand(batch_size, d_model).cuda()
    topk_idx, topk_val = gate(x)
    real_cap = math.ceil(cap * batch_size)
    counts = [0] * n_global_experts
    for v in topk_idx.cpu().view(-1).numpy():
        if v != -1:
            counts[v] += 1
    assert all(c <= real_cap for c in counts)


Rick Ho's avatar
Rick Ho committed
90
if __name__ == '__main__':
    if len(sys.argv) < 3:
        # Manual invocation: run a tiny smoke test in the current process.
        _ensure_initialized()
        # test_gshard_gate(4096, 1024, 4, .2)
        test_gshard_gate(8, 16, 1, .1)
        # test_switch_gate(4096, 1024, 4, .2)
    else:
        # Re-entry path used by _run_distributed: argv[1] names the helper to
        # call, argv[2] carries its keyword arguments serialized as JSON.
        target = locals()[sys.argv[1]]
        target(**json.loads(sys.argv[2]))