# test_launch.py - unit tests for tools.launch.wrap_udf_in_torch_dist_launcher
import unittest

from tools.launch import wrap_udf_in_torch_dist_launcher


class TestWrapUdfInTorchDistLauncher(unittest.TestCase):
    """Tests for wrap_udf_in_torch_dist_launcher().

    Every test wraps a user-defined command (udf_command) with the same
    launcher settings and checks that the ``-m torch.distributed.launch``
    invocation and its flags are inserted directly after the python binary,
    leaving the rest of the command untouched.
    """

    # Launcher settings passed to the function under test in every case.
    _LAUNCH_KWARGS = {
        "num_trainers": 2,
        "num_nodes": 2,
        "node_rank": 1,
        "master_addr": "127.0.0.1",
        "master_port": 1234,
    }

    # Text expected between the python binary and the trainer script for
    # the settings in _LAUNCH_KWARGS.
    _LAUNCHER_ARGS = (
        "-m torch.distributed.launch "
        "--nproc_per_node=2 --nnodes=2 --node_rank=1 --master_addr=127.0.0.1 "
        "--master_port=1234"
    )

    # Trainer script and arguments used as the tail of every udf_command.
    _TRAINER = "path/to/some/trainer.py arg1 arg2"

    def _wrap(self, udf_command):
        """Return udf_command wrapped using the shared launcher settings."""
        return wrap_udf_in_torch_dist_launcher(
            udf_command=udf_command, **self._LAUNCH_KWARGS
        )

    def _expected(self, python_prefix):
        """Return the expected wrapped command for a command prefix.

        ``python_prefix`` is everything up to and including the python
        binary, e.g. ``"python3.7"`` or ``"cd path/to && python3.7"``.
        """
        return " ".join((python_prefix, self._LAUNCHER_ARGS, self._TRAINER))

    def test_simple(self):
        # test that a simple udf_command is correctly wrapped
        udf_command = "python3.7 " + self._TRAINER
        self.assertEqual(self._wrap(udf_command), self._expected("python3.7"))

    def test_chained_udf(self):
        # test that a chained udf_command is properly handled: the launcher
        # must be attached to the python call, not to the leading "cd"
        udf_command = "cd path/to && python3.7 " + self._TRAINER
        self.assertEqual(
            self._wrap(udf_command),
            self._expected("cd path/to && python3.7"),
        )

    def test_py_versions(self):
        # test that this correctly handles different py versions/binaries
        py_binaries = (
            "python3.7", "python3.8", "python3.9", "python3", "python"
        )

        for py_bin in py_binaries:
            # subTest keeps the loop going after a failure and names the
            # offending binary in the report.
            with self.subTest(python_bin=py_bin):
                udf_command = "{python_bin} {trainer}".format(
                    python_bin=py_bin, trainer=self._TRAINER
                )
                self.assertEqual(
                    self._wrap(udf_command), self._expected(py_bin)
                )


# Run the test cases in this module when the file is executed directly.
if __name__ == '__main__':
    unittest.main()