import pytest
import os
import math

import torch
import torch.distributed as dist
from fmoe.gates import GShardGate
def _ensure_initialized():
    """Initialize the default NCCL process group if it is not up yet.

    Falls back to single-process defaults when the OpenMPI launcher
    variables are absent, and pins this rank to its own GPU by setting
    CUDA_VISIBLE_DEVICES to the rank id.
    """
    if dist.is_initialized():
        return
    defaults = {
        "RANK": os.environ.get("OMPI_COMM_WORLD_RANK", "0"),
        "WORLD_SIZE": os.environ.get("OMPI_COMM_WORLD_SIZE", "1"),
        "MASTER_ADDR": os.environ.get("MASTER_ADDR", "localhost"),
        "MASTER_PORT": os.environ.get("MASTER_PORT", "12211"),
    }
    os.environ.update(defaults)
    # Must happen after RANK is populated above.
    os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["RANK"]
    dist.init_process_group(backend="nccl")


@pytest.mark.parametrize("d_model", [8, 1024])
@pytest.mark.parametrize("batch_size", [16, 4096])
@pytest.mark.parametrize("n_expert", [1, 4, 16])
@pytest.mark.parametrize("cap", [.1, .5, 1.1])
def test_gshard_gate(d_model, batch_size, n_expert, cap):
    """Verify GShardGate's capacity limit.

    No expert may receive more than ceil(cap * batch_size) tokens;
    tokens dropped by the capacity limit are marked with index -1 in
    topk_idx and are excluded from the per-expert counts.
    """
    _ensure_initialized()
    world_size = dist.get_world_size()
    if world_size * n_expert < 2:
        pytest.skip("No enough experts")
    gate = GShardGate(d_model, n_expert, world_size,
            capacity=(cap, cap)).cuda()
    x = torch.rand(batch_size, d_model).cuda()
    topk_idx, topk_val = gate(x)

    # The gate is built over `world_size` workers, so topk_idx holds
    # *global* expert indices in [0, n_expert * world_size).  The
    # histogram must cover all of them: the original n_expert-sized
    # list raised IndexError whenever world_size > 1.
    counts = [0] * (n_expert * world_size)
    for v in topk_idx.cpu().view(-1).numpy():
        if v != -1:
            counts[v] += 1
    real_cap = math.ceil(cap * batch_size)
    for c in counts:
        assert c <= real_cap
Rick Ho's avatar
Rick Ho committed
38
39
40


if __name__ == '__main__':
    # Standalone (non-pytest) entry point: run one representative
    # configuration — d_model=4096, batch_size=1024, n_expert=4, cap=0.2.
    _ensure_initialized()
    test_gshard_gate(4096, 1024, 4, .2)