import json
import os
import tempfile
import unittest

from launch import *
class TestWrapUdfInTorchDistLauncher(unittest.TestCase):
    """Tests for wrap_udf_in_torch_dist_launcher().

    The helper rewrites a user-defined training command so its python
    invocation is launched through ``torch.distributed.launch`` with the
    given cluster topology flags.
    """

    def test_simple(self):
        # A plain single-command UDF is wrapped with torch.distributed.launch.
        udf_command = "python3.7 path/to/some/trainer.py arg1 arg2"
        wrapped_udf_command = wrap_udf_in_torch_dist_launcher(
            udf_command=udf_command,
            num_trainers=2,
            num_nodes=2,
            node_rank=1,
            master_addr="127.0.0.1",
            master_port=1234,
        )
        expected = (
            "python3.7 -m torch.distributed.launch "
            "--nproc_per_node=2 --nnodes=2 --node_rank=1 --master_addr=127.0.0.1 "
            "--master_port=1234 path/to/some/trainer.py arg1 arg2"
        )
        self.assertEqual(wrapped_udf_command, expected)

    def test_chained_udf(self):
        # Only the python invocation inside a shell-chained command ("&&")
        # is wrapped; the leading "cd" is preserved untouched.
        udf_command = (
            "cd path/to && python3.7 path/to/some/trainer.py arg1 arg2"
        )
        wrapped_udf_command = wrap_udf_in_torch_dist_launcher(
            udf_command=udf_command,
            num_trainers=2,
            num_nodes=2,
            node_rank=1,
            master_addr="127.0.0.1",
            master_port=1234,
        )
        expected = (
            "cd path/to && python3.7 -m torch.distributed.launch "
            "--nproc_per_node=2 --nnodes=2 --node_rank=1 --master_addr=127.0.0.1 "
            "--master_port=1234 path/to/some/trainer.py arg1 arg2"
        )
        self.assertEqual(wrapped_udf_command, expected)

    def test_py_versions(self):
        # The wrapper must recognize any common python binary name, not
        # just a hard-coded "python".
        py_binaries = (
            "python3.7",
            "python3.8",
            "python3.9",
            "python3",
            "python",
        )
        udf_command = "{python_bin} path/to/some/trainer.py arg1 arg2"

        for py_bin in py_binaries:
            wrapped_udf_command = wrap_udf_in_torch_dist_launcher(
                udf_command=udf_command.format(python_bin=py_bin),
                num_trainers=2,
                num_nodes=2,
                node_rank=1,
                master_addr="127.0.0.1",
                master_port=1234,
            )
            expected = (
                "{python_bin} -m torch.distributed.launch ".format(
                    python_bin=py_bin
                )
                + "--nproc_per_node=2 --nnodes=2 --node_rank=1 --master_addr=127.0.0.1 "
                "--master_port=1234 path/to/some/trainer.py arg1 arg2"
            )
            self.assertEqual(wrapped_udf_command, expected)


class TestWrapCmdWithLocalEnvvars(unittest.TestCase):
    """Tests for wrap_cmd_with_local_envvars()."""

    def test_simple(self):
        # The env-var assignments are exported inside a subshell "(...)" so
        # they apply to the wrapped command without leaking to the caller.
        self.assertEqual(
            wrap_cmd_with_local_envvars("ls && pwd", "VAR1=value1 VAR2=value2"),
            "(export VAR1=value1 VAR2=value2; ls && pwd)",
        )


class TestConstructDglServerEnvVars(unittest.TestCase):
    """Tests for construct_dgl_server_env_vars()."""

    def test_simple(self):
        # Each kwarg maps to one DGL_*/OMP_* env var; keep_alive=False must
        # serialize as "0", and the string keeps a trailing space so more
        # env vars can be appended directly.
        self.assertEqual(
            construct_dgl_server_env_vars(
                num_samplers=2,
                num_server_threads=3,
                tot_num_clients=4,
                part_config="path/to/part.config",
                ip_config="path/to/ip.config",
                num_servers=5,
                graph_format="csc",
                keep_alive=False,
            ),
            (
                "DGL_ROLE=server "
                "DGL_NUM_SAMPLER=2 "
                "OMP_NUM_THREADS=3 "
                "DGL_NUM_CLIENT=4 "
                "DGL_CONF_PATH=path/to/part.config "
                "DGL_IP_CONFIG=path/to/ip.config "
                "DGL_NUM_SERVER=5 "
                "DGL_GRAPH_FORMAT=csc "
                "DGL_KEEP_ALIVE=0 "
            ),
        )


class TestConstructDglClientEnvVars(unittest.TestCase):
    """Tests for construct_dgl_client_env_vars()."""

    def test_simple(self):
        # With a pythonpath: PYTHONPATH is appended at the end.
        self.assertEqual(
            construct_dgl_client_env_vars(
                num_samplers=1,
                tot_num_clients=2,
                part_config="path/to/part.config",
                ip_config="path/to/ip.config",
                num_servers=3,
                graph_format="csc",
                num_omp_threads=4,
                group_id=0,
                pythonpath="some/pythonpath/",
            ),
            (
                "DGL_DIST_MODE=distributed "
                "DGL_ROLE=client "
                "DGL_NUM_SAMPLER=1 "
                "DGL_NUM_CLIENT=2 "
                "DGL_CONF_PATH=path/to/part.config "
                "DGL_IP_CONFIG=path/to/ip.config "
                "DGL_NUM_SERVER=3 "
                "DGL_GRAPH_FORMAT=csc "
                "OMP_NUM_THREADS=4 "
                "DGL_GROUP_ID=0 "
                "PYTHONPATH=some/pythonpath/ "
            ),
        )
        # Without a pythonpath: no PYTHONPATH entry is emitted at all.
        self.assertEqual(
            construct_dgl_client_env_vars(
                num_samplers=1,
                tot_num_clients=2,
                part_config="path/to/part.config",
                ip_config="path/to/ip.config",
                num_servers=3,
                graph_format="csc",
                num_omp_threads=4,
                group_id=0,
            ),
            (
                "DGL_DIST_MODE=distributed "
                "DGL_ROLE=client "
                "DGL_NUM_SAMPLER=1 "
                "DGL_NUM_CLIENT=2 "
                "DGL_CONF_PATH=path/to/part.config "
                "DGL_IP_CONFIG=path/to/ip.config "
                "DGL_NUM_SERVER=3 "
                "DGL_GRAPH_FORMAT=csc "
                "OMP_NUM_THREADS=4 "
                "DGL_GROUP_ID=0 "
            ),
        )


def test_submit_jobs():
    """Smoke-test submit_jobs() in dry-run mode.

    Builds a throwaway workspace (ip config + partition config) for an
    8-machine cluster, runs submit_jobs(dry_run=True), and asserts that
    every generated client/server shell command carries the expected
    workspace cd, exported extra envs, DGL_* env vars, and (for clients)
    the torch.distributed.launch flags.
    """

    class Args:
        # Bare attribute bag standing in for the argparse namespace.
        pass

    args = Args()

    with tempfile.TemporaryDirectory() as test_dir:
        num_machines = 8
        # One "<ip> <port>" line per machine.
        ip_config = os.path.join(test_dir, "ip_config.txt")
        with open(ip_config, "w") as f:
            for i in range(num_machines):
                f.write("{} {}\n".format("127.0.0." + str(i), 30050))
        # Minimal partition config: only num_parts is consulted here.
        part_config = os.path.join(test_dir, "ogb-products.json")
        with open(part_config, "w") as f:
            json.dump({"num_parts": num_machines}, f)
        args.num_trainers = 8
        args.num_samplers = 1
        args.num_servers = 4
        args.workspace = test_dir
        # Paths are relative to the workspace.
        args.part_config = "ogb-products.json"
        args.ip_config = "ip_config.txt"
        args.server_name = "ogb-products"
        args.keep_alive = False
        args.num_server_threads = 1
        args.graph_format = "csc"
        args.extra_envs = ["NCCL_DEBUG=INFO"]
        args.num_omp_threads = 1
        udf_command = "python3 train_dist.py --num_epochs 10"
        clients_cmd, servers_cmd = submit_jobs(args, udf_command, dry_run=True)

        def common_checks():
            # NOTE: deliberately reads `cmd` from the enclosing scope — it is
            # rebound by the loops below before each call.
            assert "cd " + test_dir in cmd
            assert "export " + args.extra_envs[0] in cmd
            assert f"DGL_NUM_SAMPLER={args.num_samplers}" in cmd
            # Total clients = trainers * (samplers + 1) per machine.
            assert (
                f"DGL_NUM_CLIENT={args.num_trainers*(args.num_samplers+1)*num_machines}"
                in cmd
            )
            assert f"DGL_CONF_PATH={args.part_config}" in cmd
            assert f"DGL_IP_CONFIG={args.ip_config}" in cmd
            assert f"DGL_NUM_SERVER={args.num_servers}" in cmd
            assert f"DGL_GRAPH_FORMAT={args.graph_format}" in cmd
            assert f"OMP_NUM_THREADS={args.num_omp_threads}" in cmd
            # The trainer script and its args survive the wrapping.
            assert udf_command[len("python3 ") :] in cmd

        for cmd in clients_cmd:
            common_checks()
            assert "DGL_DIST_MODE=distributed" in cmd
            assert "DGL_ROLE=client" in cmd
            assert "DGL_GROUP_ID=0" in cmd
            # Client commands are wrapped with torch.distributed.launch.
            assert (
                f"python3 -m torch.distributed.launch --nproc_per_node={args.num_trainers} --nnodes={num_machines}"
                in cmd
            )
            assert "--master_addr=127.0.0" in cmd
            assert "--master_port=1234" in cmd

        for cmd in servers_cmd:
            common_checks()
            assert "DGL_ROLE=server" in cmd
            assert "DGL_KEEP_ALIVE=0" in cmd
            assert "DGL_SERVER_ID=" in cmd


if __name__ == "__main__":
    unittest.main()