test_runner.py 3.54 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""SuperBench Runner test."""

import unittest
import shutil
import tempfile
from pathlib import Path

from omegaconf import OmegaConf

from superbench.runner import SuperBenchRunner


class RunnerTestCase(unittest.TestCase):
    """A class for runner test cases."""
    def setUp(self):
        """Hook method for setting up the test fixture before exercising it.

        Loads the default config from ``superbench/config/default.yaml`` (resolved relative
        to this file), creates a temporary output directory, and constructs the
        ``SuperBenchRunner`` under test with that config and directory.
        """
        default_config_file = Path(__file__).parent / '../../superbench/config/default.yaml'
        self.default_config = OmegaConf.load(str(default_config_file))
        self.output_dir = tempfile.mkdtemp()
        # unittest skips tearDown when setUp raises, so also register the removal as a
        # cleanup: the directory is then deleted even if constructing the runner fails.
        # ignore_errors makes this a no-op when the directory has already been removed.
        self.addCleanup(shutil.rmtree, self.output_dir, ignore_errors=True)

        self.runner = SuperBenchRunner(self.default_config, None, None, self.output_dir)

    def tearDown(self):
        """Hook method for deconstructing the test fixture after testing it."""
        shutil.rmtree(self.output_dir)

    def test_set_logger(self):
        """Test log file exists."""
        expected_log_file = Path(self.runner._output_dir) / 'sb-run.log'
        self.assertTrue(expected_log_file.is_file())

35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
    def test_get_mode_command(self):
        """Test __get_mode_command.

        Each case supplies a mode config and a base command; the string returned by the
        runner's private ``__get_mode_command`` is compared with the expected command.
        """
        test_cases = [
            # A mode name with no launcher: the command is expected back unchanged.
            {
                'mode': {'name': 'non_exist'},
                'exec_command': 'sb exec',
                'expected_command': 'sb exec',
            },
            # torch.distributed with neither proc_num nor node_num given.
            {
                'mode': {'name': 'torch.distributed'},
                'exec_command': 'sb exec',
                'expected_command': (
                    'python3 -m torch.distributed.launch '
                    '--use_env --no_python --nproc_per_node=8 '
                    '--nnodes=$NNODES --node_rank=$NODE_RANK '
                    '--master_addr=$MASTER_ADDR --master_port=$MASTER_PORT '
                    'sb exec'
                ),
            },
            # torch.distributed with one process per node on 'all' nodes.
            {
                'mode': {'name': 'torch.distributed', 'proc_num': 1, 'node_num': 'all'},
                'exec_command': 'sb exec',
                'expected_command': (
                    'python3 -m torch.distributed.launch '
                    '--use_env --no_python --nproc_per_node=1 '
                    '--nnodes=$NNODES --node_rank=$NODE_RANK '
                    '--master_addr=$MASTER_ADDR --master_port=$MASTER_PORT '
                    'sb exec'
                ),
            },
            # torch.distributed with eight processes and an explicit node count of 1.
            {
                'mode': {'name': 'torch.distributed', 'proc_num': 8, 'node_num': 1},
                'exec_command': 'sb exec',
                'expected_command': (
                    'python3 -m torch.distributed.launch '
                    '--use_env --no_python --nproc_per_node=8 '
                    '--nnodes=1 --node_rank=$NODE_RANK '
                    '--master_addr=$MASTER_ADDR --master_port=$MASTER_PORT '
                    'sb exec'
                ),
            },
        ]
        for case in test_cases:
            with self.subTest(msg='Testing with case', test_case=case):
                mode_config = OmegaConf.create(case['mode'])
                actual_command = self.runner._SuperBenchRunner__get_mode_command(
                    mode_config, case['exec_command']
                )
                self.assertEqual(actual_command, case['expected_command'])

100
101
    def test_run(self):
        """Test that run() can be called with an empty ``_sb_enabled_benchmarks`` list."""
        runner = self.runner
        # Replace the runner's enabled benchmark list with an empty one before running.
        runner._sb_enabled_benchmarks = []
        runner.run()