runner.py 3.59 KB
Newer Older
1
2
3
4
5
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""SuperBench Runner."""

6
import random
7
8
from pathlib import Path

9
10
from omegaconf import ListConfig, OmegaConf

11
from superbench.common.utils import SuperBenchLogger, logger
12
from superbench.runner.ansible import AnsibleClient
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29


class SuperBenchRunner:
    """SuperBench runner class.

    Holds the SuperBench, Docker and Ansible configurations together with an
    Ansible client, and exposes the deploy / check_env / run entry points.
    """
    def __init__(self, sb_config, docker_config, ansible_config, output_dir):
        """Initialize.

        Args:
            sb_config (DictConfig): SuperBench config object.
            docker_config (DictConfig): Docker config object.
            ansible_config (DictConfig): Ansible config object.
            output_dir (str): Dir for output.
        """
        self._sb_config = sb_config
        self._docker_config = docker_config
        self._ansible_config = ansible_config
        self._output_dir = output_dir
        self._ansible_client = AnsibleClient(ansible_config)

        # Attach the file handler first so the messages below also reach the log file.
        self.__set_logger('sb-run.log')
        logger.info('Runner uses config: %s.', self._sb_config)
        logger.info('Runner writes to: %s.', self._output_dir)

        self._sb_benchmarks = self._sb_config.superbench.benchmarks
        self._sb_enabled_benchmarks = self.__get_enabled_benchmarks()
        logger.info('Runner will run: %s', self._sb_enabled_benchmarks)

    def __set_logger(self, filename):
        """Set logger and add file handler.

        Args:
            filename (str): Log file name, created under the output directory.
        """
        SuperBenchLogger.add_handler(logger.logger, filename=str(Path(self._output_dir) / filename))

    def __get_enabled_benchmarks(self):
        """Get enabled benchmarks list.

        A non-empty top-level ``superbench.enable`` takes precedence when it is
        a string (single benchmark) or a list. Otherwise the per-benchmark
        ``enable`` flags decide which benchmarks are selected.

        Return:
            list: List of benchmarks which will be executed.
        """
        enable = self._sb_config.superbench.enable
        if enable:
            if isinstance(enable, str):
                return [enable]
            if isinstance(enable, (list, ListConfig)):
                return list(enable)
        # Top-level setting is empty or of an unsupported type: use per-benchmark flags.
        return [name for name, benchmark in self._sb_benchmarks.items() if benchmark.enable]

    def deploy(self):    # pragma: no cover
        """Deploy SuperBench environment.

        Runs the 'deploy.yaml' playbook through the Ansible client. Docker
        registry credentials are passed along only when both username and
        password are set in the Docker config.
        """
        logger.info('Preparing SuperBench environment.')
        extravars = {
            # Random port in [16384, 32767].
            'ssh_port': random.randint(1 << 14, (1 << 15) - 1),
            'output_dir': self._output_dir,
            'docker_image': self._docker_config.image,
            'gpu_vendor': 'nvidia',
        }
        if bool(self._docker_config.username) and bool(self._docker_config.password):
            extravars.update(
                {
                    'docker_registry': self._docker_config.registry,
                    'docker_username': self._docker_config.username,
                    'docker_password': self._docker_config.password,
                }
            )
        self._ansible_client.run(self._ansible_client.get_playbook_config('deploy.yaml', extravars=extravars))

    def check_env(self):    # pragma: no cover
        """Check SuperBench environment.

        Saves the SuperBench config as 'sb.config.yaml' in the output directory,
        then runs the 'check_env.yaml' playbook through the Ansible client.
        """
        logger.info('Checking SuperBench environment.')
        OmegaConf.save(config=self._sb_config, f=str(Path(self._output_dir) / 'sb.config.yaml'))
        self._ansible_client.run(
            self._ansible_client.get_playbook_config('check_env.yaml', extravars={'output_dir': self._output_dir})
        )

    def run(self):
        """Run the SuperBench benchmarks distributedly.

        Not implemented yet: currently this only logs the config and an error
        message, then returns without raising.
        """
        logger.info(self._sb_config)
        logger.error('Work in progress, not implemented yet.')