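"""Tests for the GShardGate and SwitchGate gate implementations from fmoe.gates.

Each test launches its body in a distributed worker via _run_distributed
(imported from test_ddp) so that a torch.distributed process group is available.
"""
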
import pytest

import os
import sys
import json
import math

import torch
import torch.distributed as dist
import torch.nn.functional as F
from fmoe.gates import GShardGate, SwitchGate
from test_ddp import _run_distributed


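# Initialize a (possibly single-process) NCCL process group, reading the rank
# and world size from OpenMPI environment variables when they are set.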
def _ensure_initialized():
    if not dist.is_initialized():
        os.environ["RANK"] = os.environ.get("OMPI_COMM_WORLD_RANK", "0")
        os.environ["WORLD_SIZE"] = os.environ.get("OMPI_COMM_WORLD_SIZE", "1")
        os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["RANK"]
        os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "localhost")
        os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "12211")
        dist.init_process_group(backend="nccl")


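# Capacity test for GShardGate: the body below runs in a distributed worker
# spawned by _run_distributed.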
@pytest.mark.parametrize("d_model", [1024])
@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("n_expert", [1, 4])
@pytest.mark.parametrize("cap", [.1, 1.1])
def test_gshard_gate(d_model, batch_size, n_expert, cap):
    # only a single worker process is launched, so the total number of
    # experts available to the gate is 1 * n_expert
    if 1 * n_expert < 2:
        pytest.skip("Not enough experts")
    _run_distributed('_test_gshard_gate',
            1,
            {
                'd_model': d_model,
                'batch_size': batch_size,
                'n_expert': n_expert,
                'cap': cap
            },
            script=__file__
    )


def _test_gshard_gate(d_model, batch_size, n_expert, cap):
    _ensure_initialized()
    gate = GShardGate(d_model, n_expert, dist.get_world_size(),
            capacity=(cap, cap)).cuda()
    x = torch.rand(batch_size, d_model).cuda()
    topk_idx, topk_val = gate(x)
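    # Count how many tokens each expert receives; an index of -1 marks a token
    # that was dropped because the expert was already at capacity.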
    counts = [0 for _ in range(n_expert * dist.get_world_size())]
    for v in topk_idx.cpu().view(-1).numpy():
        if v != -1:
            counts[v] += 1
    real_cap = math.ceil(cap * batch_size)
    for i in counts:
        assert i <= real_cap

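    # The top-k scores returned by the gate must match the raw gate scores at
    # the selected expert indices.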
    gate_score = gate.gate(x)
    for i in range(batch_size):
        for j in range(gate.top_k):
            v = topk_idx[i, j]
            if v != -1:
                assert topk_val[i, j] == gate_score[i, v]


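# Capacity test for SwitchGate (single-expert routing), driven through
# _run_distributed in the same way as the GShard test.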
@pytest.mark.parametrize("d_model", [1024])
@pytest.mark.parametrize("batch_size", [4096])
@pytest.mark.parametrize("n_expert", [1, 16])
@pytest.mark.parametrize("cap", [.1, .8])
def test_switch_gate(d_model, batch_size, n_expert, cap):
    _run_distributed('_test_switch_gate',
            1,
            {
                'd_model': d_model,
                'batch_size': batch_size,
                'n_expert': n_expert,
                'cap': cap
            },
            script=__file__
    )


def _test_switch_gate(d_model, batch_size, n_expert, cap):
    _ensure_initialized()
    gate = SwitchGate(d_model, n_expert, dist.get_world_size(),
            capacity=(cap, cap)).cuda()
    x = torch.rand(batch_size, d_model).cuda()
    rng = torch.cuda.get_rng_state() # save rng state
    topk_idx, topk_val = gate(x)
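    # Per-expert capacity check, as in the GShard test above.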
    counts = [0 for _ in range(n_expert * dist.get_world_size())]
    for v in topk_idx.cpu().view(-1).numpy():
        if v != -1:
            counts[v] += 1
    real_cap = math.ceil(cap * batch_size)
    for i in counts:
        assert i <= real_cap

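    # Recompute the scores the same way SwitchGate.forward() is expected to:
    # restore the saved rng state so the sampled noise matches, then apply an
    # fp32 softmax, and compare against the values returned by the gate.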
    score = gate.gate(x)

    if gate.training:
        # reset rng state to make sure noise is the same as in gate.forward()
        torch.cuda.set_rng_state(rng)
        # random uniform number from [1-eps, 1+eps]
        noise = torch.rand_like(score)
        noise = noise * 2 * gate.switch_eps + 1.0 - gate.switch_eps
        score += noise

    # fp32 softmax for numerical stability
    score = F.softmax(score.float(), dim=-1)

    for i in range(batch_size):
        v = topk_idx[i]
        if v != -1:
            assert topk_val[i] == score[i, topk_idx[i]]


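# When launched by _run_distributed, argv[1] names the worker function and
# argv[2] carries its keyword arguments as JSON; otherwise run a small local
# smoke test.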
if __name__ == '__main__':
    if len(sys.argv) >= 3:
        args = json.loads(sys.argv[2])
        locals()[sys.argv[1]](**args)
    else:
        _ensure_initialized()
        # test_gshard_gate(4096, 1024, 4, .2)
        test_switch_gate(8, 16, 4, .1)
        # test_switch_gate(4096, 1024, 4, .2)