runner.py 7.5 KB
Newer Older
1
2
3
4
5
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""SuperBench Runner."""

6
import random
7
8
from pathlib import Path

9
from joblib import Parallel, delayed
10
11
from omegaconf import ListConfig, OmegaConf

12
from superbench.common.utils import SuperBenchLogger, logger
13
from superbench.runner.ansible import AnsibleClient
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30


class SuperBenchRunner:
    """SuperBench runner class.

    Drives benchmarks on the managed nodes through an Ansible client: it deploys the
    environment (deploy.yaml), checks it (check_env.yaml), and runs each enabled benchmark
    inside the ``sb-workspace`` container in every mode listed in the SuperBench config.
    """
    def __init__(self, sb_config, docker_config, ansible_config, output_dir):
        """Initialize.

        Sets up the Ansible client, attaches a file log handler under ``output_dir``,
        fills in mode defaults, and resolves the list of enabled benchmarks.

        Args:
            sb_config (DictConfig): SuperBench config object.
            docker_config (DictConfig): Docker config object.
            ansible_config (DictConfig): Ansible config object.
            output_dir (str): Dir for output.
        """
        self._sb_config = sb_config
        self._docker_config = docker_config
        self._ansible_config = ansible_config
        self._output_dir = output_dir
        self._ansible_client = AnsibleClient(ansible_config)

        self.__set_logger('sb-run.log')
        logger.info('Runner uses config: %s.', self._sb_config)
        logger.info('Runner writes to: %s.', self._output_dir)

        self._sb_benchmarks = self._sb_config.superbench.benchmarks
        self.__validate_sb_config()
        self._sb_enabled_benchmarks = self.__get_enabled_benchmarks()
        logger.info('Runner will run: %s', self._sb_enabled_benchmarks)

    def __set_logger(self, filename):
        """Set logger and add file handler.

        Args:
            filename (str): Log file name, created inside the output directory.
        """
        SuperBenchLogger.add_handler(logger.logger, filename=str(Path(self._output_dir) / filename))

    def __validate_sb_config(self):
        """Fill in default values for benchmark modes in the SuperBench config.

        Defaults applied in place:
            * missing/empty ``modes`` -> ``[]``
            * ``local`` mode: ``proc_num`` -> 1, ``prefix`` -> ``''``
            * ``torch.distributed`` mode: ``proc_num`` -> 8

        Note:
            No validation is performed yet and nothing is raised; see the TODO below.
        """
        # TODO: add validation and defaulting
        for name in self._sb_benchmarks:
            if not self._sb_benchmarks[name].modes:
                self._sb_benchmarks[name].modes = []
            for idx, mode in enumerate(self._sb_benchmarks[name].modes):
                # Assign through the indexed path so the defaults land in the stored config.
                if mode.name == 'local':
                    if not mode.proc_num:
                        self._sb_benchmarks[name].modes[idx].proc_num = 1
                    if not mode.prefix:
                        self._sb_benchmarks[name].modes[idx].prefix = ''
                elif mode.name == 'torch.distributed':
                    if not mode.proc_num:
                        self._sb_benchmarks[name].modes[idx].proc_num = 8

    def __get_enabled_benchmarks(self):
        """Get enabled benchmarks list.

        A top-level ``superbench.enable`` (a single name or a list of names) takes
        precedence; otherwise every benchmark whose own ``enable`` is truthy is selected.

        Returns:
            list: List of benchmarks which will be executed.
        """
        enable = self._sb_config.superbench.enable
        if enable:
            if isinstance(enable, str):
                return [enable]
            if isinstance(enable, (list, ListConfig)):
                return list(enable)
        return [k for k, v in self._sb_benchmarks.items() if v.enable]

    def __get_mode_command(self, benchmark_name, mode):
        """Get runner command for given mode.

        Args:
            benchmark_name (str): Benchmark name.
            mode (DictConfig): Runner mode.

        Returns:
            str: Runner command, with surrounding whitespace stripped.
        """
        exec_command = ('sb exec -c sb.config.yaml -C superbench.enable={name}').format(name=benchmark_name)
        mode_command = exec_command
        if mode.name == 'local':
            # The prefix may reference {proc_rank} and {proc_num}; an empty prefix leaves a
            # leading space, which the final strip() removes.
            mode_command = '{prefix} {command}'.format(
                prefix=mode.prefix.format(proc_rank=mode.proc_rank, proc_num=mode.proc_num),
                command=exec_command,
            )
        elif mode.name == 'torch.distributed':
            # TODO: replace with torch.distributed.run in v1.9
            # TODO: only supports node_num=1 and node_num=all currently
            # $NODE_RANK, $MASTER_ADDR, $MASTER_PORT and $NNODES are expanded by the remote shell.
            mode_command = (
                'python3 -m torch.distributed.launch '
                '--use_env --no_python --nproc_per_node={proc_num} '
                '--nnodes={node_num} --node_rank=$NODE_RANK '
                '--master_addr=$MASTER_ADDR --master_port=$MASTER_PORT '
                '{command} {torch_distributed_suffix}'
            ).format(
                proc_num=mode.proc_num,
                node_num=1 if mode.node_num == 1 else '$NNODES',
                command=exec_command,
                torch_distributed_suffix=(
                    'superbench.benchmarks.{name}.parameters.distributed_impl=ddp '
                    'superbench.benchmarks.{name}.parameters.distributed_backend=nccl'
                ).format(name=benchmark_name),
            )
        return mode_command.strip()

    def deploy(self):    # pragma: no cover
        """Deploy SuperBench environment by running the deploy.yaml playbook."""
        logger.info('Preparing SuperBench environment.')
        extravars = {
            # Random port in [2**14, 2**15 - 1], passed to the playbook as ssh_port.
            'ssh_port': random.randint(1 << 14, (1 << 15) - 1),
            'output_dir': self._output_dir,
            'docker_image': self._docker_config.image,
            # NOTE(review): GPU vendor is hard-coded; other vendors are not handled here.
            'gpu_vendor': 'nvidia',
        }
        # Registry credentials are forwarded only when both username and password are set.
        if self._docker_config.username and self._docker_config.password:
            extravars.update(
                {
                    'docker_registry': self._docker_config.registry,
                    'docker_username': self._docker_config.username,
                    'docker_password': self._docker_config.password,
                }
            )
        self._ansible_client.run(self._ansible_client.get_playbook_config('deploy.yaml', extravars=extravars))

    def check_env(self):    # pragma: no cover
        """Check SuperBench environment.

        Saves the current config to ``<output_dir>/sb.config.yaml`` (the file that
        ``sb exec -c sb.config.yaml`` later reads), then runs the check_env.yaml playbook.
        """
        logger.info('Checking SuperBench environment.')
        OmegaConf.save(config=self._sb_config, f=str(Path(self._output_dir) / 'sb.config.yaml'))
        self._ansible_client.run(
            self._ansible_client.get_playbook_config('check_env.yaml', extravars={'output_dir': self._output_dir})
        )

    def _run_proc(self, benchmark_name, mode, vars):
        """Run the process.

        Args:
            benchmark_name (str): Benchmark name.
            mode (DictConfig): Runner mode. It is not modified.
            vars (dict): Process variables (e.g. ``{'proc_rank': 0}``), merged into a copy of ``mode``.

        Returns:
            int: Process return code.
        """
        import copy

        # Work on a private copy so per-process keys such as proc_rank are not written
        # back into the shared benchmark config held in self._sb_config.
        mode = copy.deepcopy(mode)
        mode.update(vars)
        logger.info('Runner is going to run %s in %s mode, proc rank %d.', benchmark_name, mode.name, mode.proc_rank)
        command = self.__get_mode_command(benchmark_name, mode)
        rc = self._ansible_client.run(
            self._ansible_client.get_shell_config(
                (
                    'docker exec sb-workspace bash -c '
                    '"set -o allexport && source sb.env && set +o allexport && {command}"'
                ).format(command=command)
            ),
            sudo=True
        )
        return rc

    def run(self):
        """Run the SuperBench benchmarks distributedly.

        Saves the config and checks the environment first. Then, in config order, runs each
        enabled benchmark once per configured mode. Unsupported modes are logged and skipped.
        """
        self.check_env()
        for benchmark_name in self._sb_benchmarks:
            if benchmark_name not in self._sb_enabled_benchmarks:
                continue
            benchmark_config = self._sb_benchmarks[benchmark_name]
            for mode in benchmark_config.modes:
                if mode.name == 'local':
                    # One process per rank; they run concurrently only when mode.parallel is set.
                    Parallel(n_jobs=mode.proc_num if mode.parallel else 1)(
                        delayed(self._run_proc)(benchmark_name, mode, {'proc_rank': proc_rank})
                        for proc_rank in range(mode.proc_num)
                    )
                elif mode.name == 'torch.distributed':
                    # torch.distributed.launch spawns the per-GPU processes itself.
                    self._run_proc(benchmark_name, mode, {'proc_rank': 0})
                else:
                    logger.warning('Runner skips unsupported mode %s for benchmark %s.', mode.name, benchmark_name)