# test_ddp.py
import json
import os
import sys
from typing import Dict

import pytest
import torch

from test_numerical import test_fmoe as _test_fmoe
from test_numerical import test_fmoe_linear as _test_fmoe_linear
11
from test_numerical import _test_fmoe_local_ddp
12
13


14
def _run_distributed(func, world_size, args: Dict, script=__file__):
15
16
    if torch.cuda.device_count() < world_size:
        pytest.skip("No enough GPU")
17
18
19
    import subprocess
    import os

20
    ps = []
21
22
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "36666"
23
    os.environ["OMPI_COMM_WORLD_SIZE"] = str(world_size)
24

25
    for i in range(world_size):
26
27
        os.environ["OMPI_COMM_WORLD_RANK"] = str(i)
        p = subprocess.Popen(
28
            [sys.executable, script, func, json.dumps(args)], stdout=subprocess.PIPE
29
30
31
32
33
34
35
36
37
38
39
40
41
42
        )
        ps.append(p)

    for p in ps:
        p.wait()
        retc = p.poll()
        assert retc == 0


@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("top_k", [2])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("d_hidden", [32])
@pytest.mark.parametrize("mp_size", [1, 2])
@pytest.mark.parametrize(
    "data_type", ["torch.FloatTensor", "torch.DoubleTensor", "torch.HalfTensor"]
)
def test_fmoe_linear_distributed(
    num_expert, top_k, batch_size, d_model, d_hidden, mp_size, data_type
):
    """Run the ``_test_fmoe_linear`` worker across ``2 * mp_size`` processes.

    Every parametrized value is forwarded unchanged to the worker as a keyword
    argument. The process count is twice the model-parallel size.
    """
    worker_args = dict(
        num_expert=num_expert,
        top_k=top_k,
        batch_size=batch_size,
        d_model=d_model,
        d_hidden=d_hidden,
        mp_size=mp_size,
        data_type=data_type,
    )
    n_procs = 2 * mp_size
    _run_distributed("_test_fmoe_linear", n_procs, worker_args)


@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("top_k", [2])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("expert", ["NaiveExpert", "LinearExpert"])
@pytest.mark.parametrize("mp_size", [1, 2])
def test_fmoe_distributed(num_expert, top_k, batch_size, d_model, expert, mp_size):
    """Run the ``_test_fmoe`` worker across ``2 * mp_size`` processes.

    ``expert`` is passed through by name, along with the other parametrized
    values, as keyword arguments for the worker.
    """
    worker_args = dict(
        num_expert=num_expert,
        top_k=top_k,
        batch_size=batch_size,
        d_model=d_model,
        expert=expert,
        mp_size=mp_size,
    )
    n_procs = 2 * mp_size
    _run_distributed("_test_fmoe", n_procs, worker_args)


84
85
86
87
88
89
90
@pytest.mark.parametrize("mp_size", [1, 2])
def test_fmoe_local_ddp(mp_size):
    """Run the ``_test_fmoe_local_ddp`` worker across ``2 * mp_size`` processes."""
    # The worker is addressed by its function name; the child process resolves
    # that name back to the imported function.
    worker_name = _test_fmoe_local_ddp.__name__
    _run_distributed(worker_name, 2 * mp_size, dict(mp_size=mp_size))


91
92
93
94
95
96
97
98
99
if __name__ == "__main__":
    if len(sys.argv) < 3:
        # No worker payload on the command line: act as the launcher and run a
        # couple of representative configurations directly.
        test_fmoe_local_ddp(mp_size=2)
        test_fmoe_linear_distributed(
            num_expert=4,
            top_k=2,
            batch_size=4,
            d_model=8,
            d_hidden=8,
            mp_size=2,
            data_type="torch.HalfTensor",
        )
    else:
        # Worker mode: argv is ``<script> <worker name> <json kwargs>``, as
        # assembled by ``_run_distributed``.
        worker_name = sys.argv[1]
        worker_kwargs = json.loads(sys.argv[2])

        # Map the OMPI_* launcher variables onto RANK / WORLD_SIZE for the
        # default env:// initialisation, and expose to this process only the
        # GPU whose index equals its rank.
        env = os.environ
        env["RANK"] = env.get("OMPI_COMM_WORLD_RANK", "0")
        env["WORLD_SIZE"] = env.get("OMPI_COMM_WORLD_SIZE", "1")
        env["CUDA_VISIBLE_DEVICES"] = env["RANK"]

        dist = torch.distributed
        dist.init_process_group(backend="nccl")
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        worker_kwargs["rank"] = rank
        worker_kwargs["world_size"] = world_size

        # mp_size only shapes the process groups; the worker does not receive it.
        mp_size = worker_kwargs.pop("mp_size")
        n_blocks = world_size // mp_size

        # Every rank creates every group, in the same order, then keeps the one
        # it belongs to. Model-parallel groups are contiguous blocks of mp_size
        # ranks.
        mp_groups = [
            dist.new_group(
                ranks=list(range(b * mp_size, (b + 1) * mp_size)), backend="nccl"
            )
            for b in range(n_blocks)
        ]
        worker_kwargs["mp_group"] = mp_groups[rank // mp_size]

        # Data-parallel groups take one rank at the same offset from each block.
        dp_groups = [
            dist.new_group(
                ranks=list(range(offset, n_blocks * mp_size, mp_size)),
                backend="nccl",
            )
            for offset in range(mp_size)
        ]
        worker_kwargs["dp_group"] = dp_groups[rank % mp_size]

        worker_kwargs["world_group"] = dist.new_group(
            ranks=list(range(world_size)), backend="nccl"
        )

        # At module scope globals() is the namespace holding the imported workers.
        globals()[worker_name](**worker_kwargs)