# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""SuperBench Runner."""

import random
from pathlib import Path

from joblib import Parallel, delayed
from omegaconf import ListConfig, OmegaConf

from superbench.common.utils import SuperBenchLogger, logger
from superbench.runner.ansible import AnsibleClient


class SuperBenchRunner():
    """SuperBench runner class.

    Drives benchmark execution on remote nodes through Ansible: deploys the
    containerized environment, checks it, then runs each enabled benchmark in
    its configured mode ('local' per-process or 'torch.distributed').
    """
    def __init__(self, sb_config, docker_config, ansible_config, output_dir):
        """Initialize.

        Args:
            sb_config (DictConfig): SuperBench config object.
            docker_config (DictConfig): Docker config object.
            ansible_config (DictConfig): Ansible config object.
            output_dir (str): Dir for output.
        """
        self._sb_config = sb_config
        self._docker_config = docker_config
        self._ansible_config = ansible_config
        self._output_dir = output_dir
        self._ansible_client = AnsibleClient(ansible_config)

        self.__set_logger('sb-run.log')
        logger.info('Runner uses config: %s.', self._sb_config)
        logger.info('Runner writes to: %s.', self._output_dir)

        self._sb_benchmarks = self._sb_config.superbench.benchmarks
        # Fill config defaults before computing the enabled-benchmark list.
        self.__validate_sb_config()
        self._sb_enabled_benchmarks = self.__get_enabled_benchmarks()
        logger.info('Runner will run: %s', self._sb_enabled_benchmarks)

    def __set_logger(self, filename):
        """Set logger and add file handler.

        Args:
            filename (str): Log file name, created under the output dir.
        """
        SuperBenchLogger.add_handler(logger.logger, filename=str(Path(self._output_dir) / filename))

    def __validate_sb_config(self):
        """Validate SuperBench config object.

        Currently only fills in defaults for missing optional fields
        (env, modes, proc_num, prefix) so later code can rely on them.

        Raise:
            InvalidConfigError: If input config is invalid.
        """
        # TODO: add validation and defaulting
        if not self._sb_config.superbench.env:
            self._sb_config.superbench.env = {}
        for name in self._sb_benchmarks:
            if not self._sb_benchmarks[name].modes:
                self._sb_benchmarks[name].modes = []
            for idx, mode in enumerate(self._sb_benchmarks[name].modes):
                if mode.name == 'local':
                    if not mode.proc_num:
                        self._sb_benchmarks[name].modes[idx].proc_num = 1
                    if not mode.prefix:
                        self._sb_benchmarks[name].modes[idx].prefix = ''
                elif mode.name == 'torch.distributed':
                    if not mode.proc_num:
                        # Default of 8 processes per node — presumably one per
                        # GPU on an 8-GPU node; TODO confirm against docs.
                        self._sb_benchmarks[name].modes[idx].proc_num = 8

    def __get_enabled_benchmarks(self):
        """Get enabled benchmarks list.

        An explicit `superbench.enable` (str or list) wins; otherwise fall
        back to every benchmark whose own `enable` flag is truthy.

        Return:
            list: List of benchmarks which will be executed.
        """
        if self._sb_config.superbench.enable:
            if isinstance(self._sb_config.superbench.enable, str):
                return [self._sb_config.superbench.enable]
            elif isinstance(self._sb_config.superbench.enable, (list, ListConfig)):
                return list(self._sb_config.superbench.enable)
        return [k for k, v in self._sb_benchmarks.items() if v.enable]

    def __get_mode_command(self, benchmark_name, mode):
        """Get runner command for given mode.

        Args:
            benchmark_name (str): Benchmark name.
            mode (DictConfig): Runner mode.

        Return:
            str: Runner command.
        """
        exec_command = ('sb exec -c sb.config.yaml -C superbench.enable={name}').format(name=benchmark_name)
        mode_command = exec_command
        if mode.name == 'local':
            # Prefix may reference {proc_rank}/{proc_num}, e.g. to pin GPUs.
            mode_command = '{prefix} {command}'.format(
                prefix=mode.prefix.format(proc_rank=mode.proc_rank, proc_num=mode.proc_num),
                command=exec_command,
            )
        elif mode.name == 'torch.distributed':
            # TODO: replace with torch.distributed.run in v1.9
            # TODO: only supports node_num=1 and node_num=all currently
            # $NODE_RANK/$MASTER_ADDR/$MASTER_PORT/$NNODES are expanded by the
            # remote shell from the environment sourced in _run_proc.
            mode_command = (
                'python3 -m torch.distributed.launch '
                '--use_env --no_python --nproc_per_node={proc_num} '
                '--nnodes={node_num} --node_rank=$NODE_RANK '
                '--master_addr=$MASTER_ADDR --master_port=$MASTER_PORT '
                '{command} {torch_distributed_suffix}'
            ).format(
                proc_num=mode.proc_num,
                node_num=1 if mode.node_num == 1 else '$NNODES',
                command=exec_command,
                torch_distributed_suffix=(
                    'superbench.benchmarks.{name}.parameters.distributed_impl=ddp '
                    'superbench.benchmarks.{name}.parameters.distributed_backend=nccl'
                ).format(name=benchmark_name),
            )
        return mode_command.strip()

    def deploy(self):    # pragma: no cover
        """Deploy SuperBench environment."""
        logger.info('Preparing SuperBench environment.')
        extravars = {
            # Random high port in [16384, 32767] for container SSH.
            'ssh_port': random.randint(1 << 14, (1 << 15) - 1),
            'output_dir': self._output_dir,
            'docker_image': self._docker_config.image,
            'gpu_vendor': 'nvidia',
        }
        # Pass registry credentials only when both username and password are set.
        if bool(self._docker_config.username) and bool(self._docker_config.password):
            extravars.update(
                {
                    'docker_registry': self._docker_config.registry,
                    'docker_username': self._docker_config.username,
                    'docker_password': self._docker_config.password,
                }
            )
        self._ansible_client.run(self._ansible_client.get_playbook_config('deploy.yaml', extravars=extravars))

    def check_env(self):    # pragma: no cover
        """Check SuperBench environment.

        Saves the resolved config to `sb.config.yaml` in the output dir and
        runs the env-check playbook with the configured KEY=VALUE env lines.
        """
        logger.info('Checking SuperBench environment.')
        OmegaConf.save(config=self._sb_config, f=str(Path(self._output_dir) / 'sb.config.yaml'))
        self._ansible_client.run(
            self._ansible_client.get_playbook_config(
                'check_env.yaml',
                extravars={
                    'output_dir': self._output_dir,
                    'env': '\n'.join(f'{k}={v}' for k, v in self._sb_config.superbench.env.items()),
                }
            )
        )

    def _run_proc(self, benchmark_name, mode, vars):
        """Run the process.

        Args:
            benchmark_name (str): Benchmark name.
            mode (DictConfig): Runner mode.
            vars (dict): Process variables.

        Returns:
            int: Process return code.
        """
        # NOTE(review): updates the shared mode config in place; relies on
        # joblib pickling args per worker so ranks don't clobber each other.
        mode.update(vars)
        logger.info('Runner is going to run %s in %s mode, proc rank %d.', benchmark_name, mode.name, mode.proc_rank)
        rc = self._ansible_client.run(
            self._ansible_client.get_shell_config(
                (
                    'docker exec sb-workspace bash -c '
                    "'set -o allexport && source sb.env && set +o allexport && {command}'"
                ).format(command=self.__get_mode_command(benchmark_name, mode), )
            ),
            sudo=True
        )
        return rc

    def run(self):
        """Run the SuperBench benchmarks distributedly."""
        self.check_env()
        for benchmark_name in self._sb_benchmarks:
            if benchmark_name not in self._sb_enabled_benchmarks:
                continue
            benchmark_config = self._sb_benchmarks[benchmark_name]
            for mode in benchmark_config.modes:
                if mode.name == 'local':
                    # Fan out proc_num ranks, in parallel only when requested.
                    Parallel(n_jobs=mode.proc_num if mode.parallel else 1)(
                        delayed(self._run_proc)(benchmark_name, mode, {
                            'proc_rank': proc_rank
                        }) for proc_rank in range(mode.proc_num)
                    )
                elif mode.name == 'torch.distributed':
                    # The launcher itself spawns the per-GPU processes.
                    self._run_proc(benchmark_name, mode, {'proc_rank': 0})