test_gates.py 2.17 KB
Newer Older
Rich Ho's avatar
Rich Ho committed
1
import pytest
Rick Ho's avatar
Rick Ho committed
2
import os
Rich Ho's avatar
Rich Ho committed
3
import math
Rick Ho's avatar
Rick Ho committed
4
5
import torch
import torch.distributed as dist
Rich Ho's avatar
Rich Ho committed
6
from fmoe.gates import GShardGate, SwitchGate
Rick Ho's avatar
Rick Ho committed
7
8


Rich Ho's avatar
Rich Ho committed
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
def _ensure_initialized():
    """Create the default ``torch.distributed`` process group if needed.

    Does nothing when a process group already exists.  Otherwise the
    rendezvous environment variables are filled in and an NCCL process
    group is initialized:

    * ``RANK`` / ``WORLD_SIZE`` are taken from the Open MPI variables
      ``OMPI_COMM_WORLD_RANK`` / ``OMPI_COMM_WORLD_SIZE`` and fall back to
      a single-process setup (``"0"`` / ``"1"``).  Any pre-existing value
      of ``RANK`` / ``WORLD_SIZE`` is overwritten.
    * ``CUDA_VISIBLE_DEVICES`` is set to the rank.
    * ``MASTER_ADDR`` / ``MASTER_PORT`` keep their current value when
      already set, else default to ``localhost`` / ``12211``.
    """
    if dist.is_initialized():
        return

    env = os.environ
    rank = env.get("OMPI_COMM_WORLD_RANK", "0")
    env["RANK"] = rank
    env["WORLD_SIZE"] = env.get("OMPI_COMM_WORLD_SIZE", "1")
    # Expose only the device whose index equals this process's rank.
    env["CUDA_VISIBLE_DEVICES"] = rank
    env.setdefault("MASTER_ADDR", "localhost")
    env.setdefault("MASTER_PORT", "12211")
    dist.init_process_group(backend="nccl")


@pytest.mark.parametrize("d_model", [8, 1024])
@pytest.mark.parametrize("batch_size", [16, 4096])
@pytest.mark.parametrize("n_expert", [1, 4, 16])
@pytest.mark.parametrize("cap", [.1, .5, 1.1])
def test_gshard_gate(d_model, batch_size, n_expert, cap):
    """Check that ``GShardGate`` never assigns more tokens than the cap.

    A gate is built with ``capacity=(cap, cap)``, fed a random
    ``(batch_size, d_model)`` input, and the number of tokens routed to
    every expert index is compared against ``ceil(cap * batch_size)``.

    The test is skipped when fewer than two experts exist across all
    workers (presumably because the gate selects two experts per token;
    confirm against ``fmoe.gates``).
    """
    _ensure_initialized()
    world_size = dist.get_world_size()
    tot_expert = world_size * n_expert
    if tot_expert < 2:
        pytest.skip("No enough experts")
    gate = GShardGate(d_model, n_expert, world_size,
                      capacity=(cap, cap)).cuda()
    x = torch.rand(batch_size, d_model).cuda()
    topk_idx, _ = gate(x)

    # The gate is constructed over n_expert experts on each of world_size
    # workers, so indices may address any of the world_size * n_expert
    # experts; sizing the histogram by n_expert alone would raise
    # IndexError on multi-worker runs.
    counts = [0] * tot_expert
    for v in topk_idx.cpu().view(-1).tolist():
        # -1 entries are not counted (presumably tokens dropped by the
        # capacity limit; confirm against fmoe.gates).
        if v != -1:
            counts[v] += 1

    real_cap = math.ceil(cap * batch_size)
    for expert, n_tokens in enumerate(counts):
        assert n_tokens <= real_cap, (
            f"expert {expert} received {n_tokens} tokens, "
            f"exceeding capacity {real_cap}"
        )
Rick Ho's avatar
Rick Ho committed
38
39


Rich Ho's avatar
Rich Ho committed
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
@pytest.mark.parametrize("d_model", [8, 1024])
@pytest.mark.parametrize("batch_size", [16, 4096])
@pytest.mark.parametrize("n_expert", [1, 4, 16])
@pytest.mark.parametrize("cap", [.1, .5, 1.1])
def test_switch_gate(d_model, batch_size, n_expert, cap):
    """Check that ``SwitchGate`` never assigns more tokens than the cap.

    A gate is built with ``capacity=(cap, cap)``, fed a random
    ``(batch_size, d_model)`` input, and the number of tokens routed to
    every expert index is compared against ``ceil(cap * batch_size)``.
    """
    _ensure_initialized()
    world_size = dist.get_world_size()
    gate = SwitchGate(d_model, n_expert, world_size,
                      capacity=(cap, cap)).cuda()
    x = torch.rand(batch_size, d_model).cuda()
    topk_idx, _ = gate(x)

    # The gate is constructed over n_expert experts on each of world_size
    # workers, so indices may address any of the world_size * n_expert
    # experts; sizing the histogram by n_expert alone would raise
    # IndexError on multi-worker runs.
    counts = [0] * (world_size * n_expert)
    for v in topk_idx.cpu().view(-1).tolist():
        # -1 entries are not counted (presumably tokens dropped by the
        # capacity limit; confirm against fmoe.gates).
        if v != -1:
            counts[v] += 1

    real_cap = math.ceil(cap * batch_size)
    for expert, n_tokens in enumerate(counts):
        assert n_tokens <= real_cap, (
            f"expert {expert} received {n_tokens} tokens, "
            f"exceeding capacity {real_cap}"
        )


Rick Ho's avatar
Rick Ho committed
59
if __name__ == '__main__':
    # Manual run outside pytest: make sure the process group exists, then
    # exercise one fixed configuration of a single gate test.
    _ensure_initialized()
    # To run the GShard gate check instead, enable the call below.
    # test_gshard_gate(4096, 1024, 4, .2)
    test_switch_gate(d_model=4096, batch_size=1024, n_expert=4, cap=.2)