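"""Distributed (multi-GPU) tests for the FMoE layers.

Each test launches ``world_size`` copies of this file as subprocesses, one per
GPU, and re-runs the numerical checks from ``test_numerical`` under
``torch.distributed`` with the NCCL backend.
"""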
import json
import os
import subprocess
import sys
from typing import Dict

import pytest
import torch

from test_numerical import test_fmoe as _test_fmoe
from test_numerical import test_fmoe_linear as _test_fmoe_linear


def _run_distributed(func, world_size, args: Dict):
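    """Launch ``world_size`` subprocesses, each running ``func`` with ``args``.

    ``func`` is looked up by name in this module inside the child process, and
    ``args`` is forwarded to every rank as a JSON string on the command line.
    """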
    if torch.cuda.device_count() < world_size:
        pytest.skip("Not enough GPUs")
    ps = []
    # Rendezvous settings and MPI-style world size, inherited by the child
    # processes launched below.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "36666"
    os.environ["OMPI_COMM_WORLD_SIZE"] = str(world_size)

    for i in range(world_size):
        # Each child sees its own OpenMPI-style rank via the environment.
        os.environ["OMPI_COMM_WORLD_RANK"] = str(i)
        p = subprocess.Popen(
            [sys.executable, __file__, func, json.dumps(args)], stdout=subprocess.PIPE
        )
        ps.append(p)

    # Every rank must exit cleanly for the test to pass.
    for p in ps:
        retc = p.wait()
        assert retc == 0


@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("top_k", [2])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("d_hidden", [32])
@pytest.mark.parametrize("mp_size", [1, 2])
def test_fmoe_linear_distributed(
    num_expert, top_k, batch_size, d_model, d_hidden, mp_size
):
    _run_distributed(
        "_test_fmoe_linear",
        mp_size * 2,
        {
            "num_expert": num_expert,
            "top_k": top_k,
            "batch_size": batch_size,
            "d_model": d_model,
            "d_hidden": d_hidden,
            "mp_size": mp_size,
        },
    )


@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("top_k", [2])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("expert", ["NaiveExpert", "LinearExpert"])
@pytest.mark.parametrize("mp_size", [1, 2])
def test_fmoe_distributed(num_expert, top_k, batch_size, d_model, expert, mp_size):
    _run_distributed(
        "_test_fmoe",
        mp_size * 2,
        {
            "num_expert": num_expert,
            "top_k": top_k,
            "batch_size": batch_size,
            "d_model": d_model,
            "expert": expert,
            "mp_size": mp_size,
        },
    )


if __name__ == "__main__":
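    # When launched by ``_run_distributed``, argv carries the target function
    # name and its JSON-encoded arguments; otherwise run one test directly.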
    if len(sys.argv) >= 3:
        args = json.loads(sys.argv[2])
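        # Map the OpenMPI-style variables set by the parent onto the names
        # torch.distributed expects, and pin this process to a single GPU.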
        os.environ["RANK"] = os.environ.get("OMPI_COMM_WORLD_RANK", "0")
        os.environ["WORLD_SIZE"] = os.environ.get("OMPI_COMM_WORLD_SIZE", "1")
        os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["RANK"]
        torch.distributed.init_process_group(backend="nccl")
        args["rank"] = torch.distributed.get_rank()
        args["world_size"] = torch.distributed.get_world_size()
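        # Build one NCCL group per block of ``mp_size`` consecutive ranks; every
        # rank must create every group, then keeps only the one it belongs to.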
        args["mp_group"] = (
            [
                torch.distributed.new_group(
                    ranks=[j * args["mp_size"] + i for i in range(args["mp_size"])],
                    backend="nccl",
                )
                for j in range(args["world_size"] // args["mp_size"])
            ][args["rank"] // args["mp_size"]]
            if args["mp_size"] > 1
            else None
        )
        del args["mp_size"]
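        # Dispatch by name to the requested test function.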
        locals()[sys.argv[1]](**args)
    else:
        test_fmoe_linear_distributed(
            num_expert=4, top_k=2, batch_size=4, d_model=8, d_hidden=8, mp_size=2
        )