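"""Multi-process tests for the fmoe layers.

Each test case re-launches this script once per rank through
`_run_distributed`, passing the target function name and a JSON blob of its
arguments, and then runs the numerical checks from `test_numerical` over NCCL
process groups.
"""
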
import json
import random
import os
import sys
import subprocess
from typing import Dict
import socket as sock

import pytest
import torch
import torch.distributed as dist

from test_numerical import test_fmoe as _test_fmoe
from test_numerical import test_fmoe_linear as _test_fmoe_linear
from test_numerical import _test_fmoe_local_ddp


def _ensure_initialized():
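    """Fill in single-process defaults (rank, world size, master address) from
    Open MPI environment variables and initialize the NCCL process group once."""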
    if 'RANK' not in os.environ:
        os.environ["RANK"] = os.environ.get("OMPI_COMM_WORLD_RANK", "0")
        os.environ["WORLD_SIZE"] = os.environ.get("OMPI_COMM_WORLD_SIZE", "1")
        os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["RANK"]
        os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "localhost")
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")


# Offset added to the base MASTER_PORT (9010) so that every launch of
# _run_distributed rendezvouses on a fresh port.
port_count = 0


def _run_distributed(func, world_size, args: Dict, script=__file__, env=None):
    """Spawn `world_size` worker processes that re-run this script, each
    executing `func` with `args` as if launched by Open MPI."""
    device_count = torch.cuda.device_count()
    if device_count < world_size:
        pytest.skip("Not enough GPUs: only {} found".format(device_count))

    # Copy the caller-supplied environment instead of mutating a shared
    # (default-argument) dict across calls.
    env = dict(env) if env is not None else {}
    env["MASTER_ADDR"] = "localhost"
    # Use a fresh rendezvous port for every launch.
    global port_count
    env["MASTER_PORT"] = str(9010 + port_count)
    port_count += 1
    env["OMPI_COMM_WORLD_SIZE"] = str(world_size)
    # Popen rejects None values, so fall back to an empty string.
    env["LD_LIBRARY_PATH"] = os.environ.get("LD_LIBRARY_PATH", "")

    ps = []
    for i in range(world_size):
        env["OMPI_COMM_WORLD_RANK"] = str(i)
        env["CUDA_VISIBLE_DEVICES"] = str(i % device_count)
        p = subprocess.Popen(
            [sys.executable, script, func, json.dumps(args)],
            stdout=subprocess.PIPE,
            env=env,
        )
        ps.append(p)

    for p in ps:
        p.wait()
        assert p.returncode == 0


@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("top_k", [2])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("d_hidden", [32])
@pytest.mark.parametrize("mp_size", [1, 2])
@pytest.mark.parametrize("data_type", ['torch.float32', 'torch.bfloat16', 'torch.float16'])
def test_fmoe_linear_distributed(
    num_expert, top_k, batch_size, d_model, d_hidden, mp_size, data_type
):
    _run_distributed(
        "_test_fmoe_linear",
        mp_size * 2,
        {
            "num_expert": num_expert,
            "top_k": top_k,
            "batch_size": batch_size,
            "d_model": d_model,
            "d_hidden": d_hidden,
            "mp_size": mp_size,
            "data_type": data_type
        },
    )


@pytest.mark.parametrize("num_expert", [4, 8])
@pytest.mark.parametrize("top_k", [2])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("d_model", [16])
@pytest.mark.parametrize("expert", ["NaiveExpert", "LinearExpert"])
@pytest.mark.parametrize("mp_size", [1, 2])
@pytest.mark.parametrize("data_type", ['torch.float32', 'torch.bfloat16', 'torch.float16'])
def test_fmoe_distributed(num_expert, top_k, batch_size, d_model, expert, mp_size, data_type):
    _run_distributed(
        "_test_fmoe",
        mp_size * 2,
        {
            "num_expert": num_expert,
            "top_k": top_k,
            "batch_size": batch_size,
            "d_model": d_model,
            "expert": expert,
            "mp_size": mp_size,
            "data_type": data_type,
        },
    )


@pytest.mark.parametrize("mp_size", [1, 2])
def test_fmoe_local_ddp(mp_size):
    _run_distributed(
        _test_fmoe_local_ddp.__name__, mp_size * 2, {"mp_size": mp_size},
    )


if __name__ == "__main__":
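    # Two entry modes: when re-launched by _run_distributed, argv carries the
    # target function name and a JSON blob of its arguments; otherwise fall
    # back to a single standalone run of _test_fmoe at the bottom.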
    if len(sys.argv) >= 3:
        args = json.loads(sys.argv[2])
        os.environ["RANK"] = os.environ.get("OMPI_COMM_WORLD_RANK", "0")
        os.environ["WORLD_SIZE"] = os.environ.get("OMPI_COMM_WORLD_SIZE", "1")
        os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["RANK"]
        torch.distributed.init_process_group(backend="nccl")
        args["rank"] = torch.distributed.get_rank()
        args["world_size"] = torch.distributed.get_world_size()
        args["mp_group"] = [
            torch.distributed.new_group(
                ranks=[j * args["mp_size"] + i for i in range(args["mp_size"])],
                backend="nccl",
            )
            for j in range(args["world_size"] // args["mp_size"])
        ][args["rank"] // args["mp_size"]]
        args["dp_group"] = [
            torch.distributed.new_group(
                ranks=[
                    i * args["mp_size"] + j
                    for i in range(args["world_size"] // args["mp_size"])
                ],
                backend="nccl",
            )
            for j in range(args["mp_size"])
        ][args["rank"] % args["mp_size"]]
        args["world_group"] = torch.distributed.new_group(
            ranks=list(range(args["world_size"])), backend="nccl",
        )
        del args["mp_size"]
        locals()[sys.argv[1]](**args)
    else:
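        # Standalone run: RANK, WORLD_SIZE and MASTER_ADDR/PORT are expected to
        # be provided by an external launcher before init_process_group.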
        torch.distributed.init_process_group(backend="nccl")
        args = dict(mp_size=1, data_type='torch.float16')
        args["rank"] = torch.distributed.get_rank()
        args["world_size"] = torch.distributed.get_world_size()
        args["mp_group"] = [
            torch.distributed.new_group(
                ranks=[j * args["mp_size"] + i for i in range(args["mp_size"])],
                backend="nccl",
            )
            for j in range(args["world_size"] // args["mp_size"])
        ][args["rank"] // args["mp_size"]]
        args["dp_group"] = [
            torch.distributed.new_group(
                ranks=[
                    i * args["mp_size"] + j
                    for i in range(args["world_size"] // args["mp_size"])
                ],
                backend="nccl",
            )
            for j in range(args["mp_size"])
        ][args["rank"] % args["mp_size"]]
        args["world_group"] = torch.distributed.new_group(
            ranks=list(range(args["world_size"])), backend="nccl",
        )
        del args["mp_size"]
        _test_fmoe(4, 2, 16, 2, 'NaiveExpert', **args)