# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Interfaces that provide access to benchmarks."""

from typing import Dict

from superbench.common.utils import logger
from superbench.benchmarks import Platform, Framework, BenchmarkContext
from superbench.benchmarks.base import Benchmark


class BenchmarkRegistry:
    """Class that maintains all benchmarks.

    Provide the following functions:
        Register new benchmark.
        Get the internal benchmark name.
        Check the validation of benchmark parameters.
        Get all configurable settings of benchmark.
        Launch one benchmark and return the result.
    """
    # Registry storage. Structure:
    #   benchmarks[name][platform] = (class_def, parameters)
    #   benchmarks[name]['predefine_param'] = {parameter: default_value}
    benchmarks: Dict[str, dict] = dict()

    @classmethod
    def register_benchmark(cls, name, class_def, parameters='', platform=None):
        """Register new benchmark, key is the benchmark name.

        Args:
            name (str): internal name of benchmark.
            class_def (Benchmark): class object of benchmark.
            parameters (str): predefined parameters of benchmark.
            platform (Platform): Platform types like CUDA, ROCM, DTK.
              None means the benchmark is registered for all platforms.

        Raises:
            TypeError: raised via logger.log_and_raise when name is not a
              non-empty string, class_def is not a Benchmark subclass,
              or platform is not a member of Platform.
        """
        if not name or not isinstance(name, str):
            logger.log_and_raise(
                TypeError,
                'Name of registered benchmark is not string - benchmark: {}, type: {}'.format(name, type(name))
            )

        if not issubclass(class_def, Benchmark):
            logger.log_and_raise(
                TypeError,
                'Registered class is not subclass of Benchmark - benchmark: {}, type: {}'.format(name, type(class_def))
            )

        if name not in cls.benchmarks:
            cls.benchmarks[name] = dict()

        if platform:
            if platform not in Platform:
                platform_list = list(map(str, Platform))
                logger.log_and_raise(
                    TypeError, 'Unknown platform - benchmark: {}, supported platforms: {}, but got: {}'.format(
                        name, platform_list, platform
                    )
                )

            if platform in cls.benchmarks[name]:
                logger.warning('Duplicate registration - benchmark: {}, platform: {}'.format(name, platform))

            cls.benchmarks[name][platform] = (class_def, parameters)
        else:
            # If no platform is specified, the benchmark works for all platforms.
            for p in Platform:
                if p in cls.benchmarks[name]:
                    logger.warning('Duplicate registration - benchmark: {}, platform: {}'.format(name, p))

                cls.benchmarks[name][p] = (class_def, parameters)

        cls.__parse_and_check_args(name, class_def, parameters)

    @classmethod
    def __parse_and_check_args(cls, name, class_def, parameters):
        """Parse and check the predefine parameters.

        If ignore_invalid is True, and 'required' arguments are not set when register the benchmark,
        the arguments should be provided by user in config and skip the arguments checking.

        Args:
            name (str): internal name of benchmark.
            class_def (Benchmark): class object of benchmark.
            parameters (str): predefined parameters of benchmark.

        Raises:
            TypeError: raised via logger.log_and_raise when the predefined
              parameters cannot be parsed or contain unknown arguments.
        """
        benchmark = class_def(name, parameters)
        benchmark.add_parser_arguments()
        ret, args, unknown = benchmark.parse_args(ignore_invalid=True)
        # Any unknown argument means the predefined parameters are invalid.
        if not ret or unknown:
            logger.log_and_raise(
                TypeError,
                'Registered benchmark has invalid arguments - benchmark: {}, parameters: {}'.format(name, parameters)
            )
        elif args is not None:
            cls.benchmarks[name]['predefine_param'] = vars(args)
            logger.debug('Benchmark registration - benchmark: {}, predefine_parameters: {}'.format(name, vars(args)))
        else:
            # Required arguments were not provided at registration time; they
            # must come from the user config later, so skip checking here.
            cls.benchmarks[name]['predefine_param'] = dict()
            logger.info(
                'Benchmark registration - benchmark: {}, missing required parameters or invalid parameters, '
                'skip the arguments checking.'.format(name)
            )

    @classmethod
    def is_benchmark_context_valid(cls, benchmark_context):
        """Check whether the benchmark context is valid or not.

        Args:
            benchmark_context (BenchmarkContext): the benchmark context.

        Return:
            ret (bool): return True if context is valid.
        """
        if isinstance(benchmark_context, BenchmarkContext) and benchmark_context.name:
            return True
        else:
            logger.error('Benchmark has invalid context')
            return False

    @classmethod
    def __get_benchmark_name(cls, benchmark_context):
        """Return the internal benchmark name.

        Args:
            benchmark_context (BenchmarkContext): the benchmark context.

        Return:
            benchmark_name (str): internal benchmark name, None means context is invalid.
        """
        if not cls.is_benchmark_context_valid(benchmark_context):
            return None

        benchmark_name = benchmark_context.name
        framework = benchmark_context.framework

        # Framework-specific benchmarks are registered under a prefixed name,
        # e.g. framework 'pytorch' + benchmark 'foo' -> 'pytorch-foo'.
        if framework != Framework.NONE:
            benchmark_name = framework.value + '-' + benchmark_name

        return benchmark_name

    @classmethod
    def create_benchmark_context(cls, name, platform=Platform.CPU, parameters='', framework=Framework.NONE):
        """Create the context for a benchmark run.

        Args:
            name (str): name of benchmark in config file.
            platform (Platform): Platform types like Platform.CPU, Platform.CUDA, Platform.ROCM, Platform.DTK.
            parameters (str): predefined parameters of benchmark.
            framework (Framework): Framework types like Framework.PYTORCH, Framework.ONNXRUNTIME.

        Return:
            benchmark_context (BenchmarkContext): the benchmark context.
        """
        return BenchmarkContext(name, platform, parameters, framework)

    @classmethod
    def get_benchmark_configurable_settings(cls, benchmark_context):
        """Get all configurable settings of benchmark.

        Args:
            benchmark_context (BenchmarkContext): the benchmark context.

        Return:
            All configurable settings in raw string, None means context is invalid or no benchmark is found.
        """
        if not cls.is_benchmark_context_valid(benchmark_context):
            return None

        benchmark_name = cls.__get_benchmark_name(benchmark_context)
        platform = benchmark_context.platform

        (benchmark_class, predefine_params) = cls.__select_benchmark(benchmark_name, platform)
        if benchmark_class:
            benchmark = benchmark_class(benchmark_name)
            benchmark.add_parser_arguments()
            return benchmark.get_configurable_settings()
        else:
            return None

    @classmethod
    def get_all_benchmark_predefine_settings(cls):
        """Get all registered benchmarks' predefine settings.

        Return:
            benchmark_params (dict[str, dict]): key is benchmark name,
              value is the dict with structure: {'parameter': default_value}.
        """
        benchmark_params = dict()
        for name in cls.benchmarks:
            benchmark_params[name] = cls.benchmarks[name]['predefine_param']
        return benchmark_params

    @classmethod
    def launch_benchmark(cls, benchmark_context):
        """Select and Launch benchmark.

        Args:
            benchmark_context (BenchmarkContext): the benchmark context.

        Return:
            benchmark (Benchmark): the benchmark instance contains all results,
              None means context is invalid or no benchmark is found.
        """
        if not cls.is_benchmark_context_valid(benchmark_context):
            return None

        benchmark_name = cls.__get_benchmark_name(benchmark_context)

        benchmark = None
        if benchmark_name:
            platform = benchmark_context.platform
            parameters = benchmark_context.parameters
            (benchmark_class, predefine_params) = cls.__select_benchmark(benchmark_name, platform)
            if benchmark_class:
                # Context parameters come after the predefined ones so the
                # user-supplied values win during argument parsing.
                if predefine_params:
                    parameters = predefine_params + ' ' + parameters

                benchmark = benchmark_class(benchmark_name, parameters)
                benchmark.run()

        return benchmark

    @classmethod
    def is_benchmark_registered(cls, benchmark_context):
        """Check whether the benchmark is registered or not.

        Args:
            benchmark_context (BenchmarkContext): the benchmark context.

        Return:
            ret (bool): return True if context is valid and benchmark is registered.
        """
        if not cls.is_benchmark_context_valid(benchmark_context):
            return False

        benchmark_name = cls.__get_benchmark_name(benchmark_context)
        platform = benchmark_context.platform

        if cls.benchmarks.get(benchmark_name, {}).get(platform) is None:
            return False

        return True

    @classmethod
    def __select_benchmark(cls, name, platform):
        """Select benchmark by name and platform.

        Args:
            name (str): internal name of benchmark.
            platform (Platform): Platform type of benchmark.

        Return:
            benchmark_class (Benchmark): class object of benchmark, None if not found.
            predefine_params (str): predefined parameters which is set when register the benchmark.
        """
        if name not in cls.benchmarks or platform not in cls.benchmarks[name]:
            logger.warning('Benchmark has no implementation, name: {}, platform: {}'.format(name, platform))
            return (None, None)

        (benchmark_class, predefine_params) = cls.benchmarks[name][platform]

        return (benchmark_class, predefine_params)

    @classmethod
    def clean_benchmarks(cls):
        """Clean up the benchmark registry."""
        cls.benchmarks.clear()