# registry.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Interfaces that provide access to benchmarks."""

from typing import Dict

from superbench.common.utils import logger
from superbench.benchmarks import Platform, Framework, BenchmarkContext
from superbench.benchmarks.base import Benchmark


class BenchmarkRegistry:
    """Class that maintains all benchmarks.

    Provide the following functions:
        Register new benchmark.
        Get the internal benchmark name.
        Check the validation of benchmark parameters.
        Get all configurable settings of benchmark.
        Launch one benchmark and return the result.
    """
    # Class-level registry shared by all callers, with layout:
    #   benchmark name -> {
    #       <Platform member>: (class_def, raw predefined parameter string),
    #       'predefine_param': dict of the parsed predefined parameters,
    #   }
    benchmarks: Dict[str, dict] = dict()

    @classmethod
    def register_benchmark(cls, name, class_def, parameters='', platform=None):
        """Register new benchmark, key is the benchmark name.

        Every check (name, class, platform and predefined parameters) runs before the
        registry is modified. A rejected registration therefore leaves no partial entry behind.

        Args:
            name (str): internal name of benchmark.
            class_def (Benchmark): class object of benchmark.
            parameters (str): predefined parameters of benchmark.
            platform (Platform): Platform types like CUDA, ROCM. When not given,
              the benchmark is registered for every member of Platform.

        Raises:
            TypeError: if name is not a non-empty string, class_def is not a subclass of
              Benchmark, platform is unknown, or the predefined parameters do not parse.
        """
        if not name or not isinstance(name, str):
            logger.log_and_raise(
                TypeError,
                'Name of registered benchmark is not string - benchmark: {}, type: {}'.format(name, type(name))
            )

        # issubclass() raises its own (unlogged) TypeError for non-class arguments,
        # so reject non-classes explicitly first.
        if not isinstance(class_def, type) or not issubclass(class_def, Benchmark):
            logger.log_and_raise(
                TypeError,
                'Registered class is not subclass of Benchmark - benchmark: {}, type: {}'.format(name, type(class_def))
            )

        if platform:
            if platform not in Platform:
                platform_list = list(map(str, Platform))
                logger.log_and_raise(
                    TypeError, 'Unknown platform - benchmark: {}, supportted platforms: {}, but got: {}'.format(
                        name, platform_list, platform
                    )
                )
            target_platforms = [platform]
        else:
            # If not specified the tag, means the benchmark works for all platforms.
            target_platforms = list(Platform)

        # Parse the predefined parameters before touching the registry, so that an
        # invalid parameter string cannot leave a half-registered benchmark behind.
        benchmark = class_def(name, parameters)
        benchmark.add_parser_arguments()
        ret, args, unknown = benchmark.parse_args()
        if not ret or len(unknown) >= 1:
            logger.log_and_raise(
                TypeError,
                'Registered benchmark has invalid arguments - benchmark: {}, parameters: {}'.format(name, parameters)
            )
        predefine_param = vars(args)

        # All checks passed: commit the registration.
        entry = cls.benchmarks.setdefault(name, dict())
        for p in target_platforms:
            if p in entry:
                logger.warning('Duplicate registration - benchmark: {}, platform: {}'.format(name, p))
            entry[p] = (class_def, parameters)

        entry['predefine_param'] = predefine_param
        logger.debug('Benchmark registration - benchmark: {}, predefine_parameters: {}'.format(name, predefine_param))

    @classmethod
    def is_benchmark_context_valid(cls, benchmark_context):
        """Check whether the benchmark context is valid or not.

        A context is valid when it is a BenchmarkContext instance with a non-empty name.

        Args:
            benchmark_context (BenchmarkContext): the benchmark context.

        Return:
            ret (bool): return True if context is valid.
        """
        if isinstance(benchmark_context, BenchmarkContext) and benchmark_context.name:
            return True

        logger.error('Benchmark has invalid context')
        return False

    @classmethod
    def __get_benchmark_name(cls, benchmark_context):
        """Return the internal benchmark name.

        The internal name is '<framework value>-<name>' when a framework other than
        Framework.NONE is set, otherwise it is the context name unchanged.

        Args:
            benchmark_context (BenchmarkContext): the benchmark context.

        Return:
            benchmark_name (str): internal benchmark name, None means context is invalid.
        """
        if not cls.is_benchmark_context_valid(benchmark_context):
            return None

        benchmark_name = benchmark_context.name
        framework = benchmark_context.framework

        if framework != Framework.NONE:
            benchmark_name = framework.value + '-' + benchmark_name

        return benchmark_name

    @classmethod
    def create_benchmark_context(cls, name, platform=Platform.CPU, parameters='', framework=Framework.NONE):
        """Create a benchmark context from its individual settings.

        Args:
            name (str): name of benchmark in config file.
            platform (Platform): Platform types like Platform.CPU, Platform.CUDA, Platform.ROCM.
            parameters (str): predefined parameters of benchmark.
            framework (Framework): Framework types like Framework.PYTORCH, Framework.ONNX.

        Return:
            benchmark_context (BenchmarkContext): the benchmark context.
        """
        return BenchmarkContext(name, platform, parameters, framework)

    @classmethod
    def get_benchmark_configurable_settings(cls, benchmark_context):
        """Get all configurable settings of benchmark.

        Args:
            benchmark_context (BenchmarkContext): the benchmark context.

        Return:
            All configurable settings in raw string, None means context is invalid or no benchmark is found.
        """
        if not cls.is_benchmark_context_valid(benchmark_context):
            return None

        benchmark_name = cls.__get_benchmark_name(benchmark_context)
        platform = benchmark_context.platform

        benchmark_class, _ = cls.__select_benchmark(benchmark_name, platform)
        if not benchmark_class:
            return None

        benchmark = benchmark_class(benchmark_name)
        benchmark.add_parser_arguments()
        return benchmark.get_configurable_settings()

    @classmethod
    def get_all_benchmark_predefine_settings(cls):
        """Get all registered benchmarks' predefine settings.

        Return:
            benchmark_params (dict[str, dict]): key is benchmark name,
              value is the dict with structure: {'parameter': default_value}.
        """
        return {name: entry['predefine_param'] for name, entry in cls.benchmarks.items()}

    @classmethod
    def launch_benchmark(cls, benchmark_context):
        """Select and Launch benchmark.

        Args:
            benchmark_context (BenchmarkContext): the benchmark context.

        Return:
            benchmark (Benchmark): the benchmark instance contains all results,
              None means context is invalid or no benchmark is found.
        """
        if not cls.is_benchmark_context_valid(benchmark_context):
            return None

        benchmark_name = cls.__get_benchmark_name(benchmark_context)

        benchmark = None
        if benchmark_name:
            platform = benchmark_context.platform
            parameters = benchmark_context.parameters
            (benchmark_class, predefine_params) = cls.__select_benchmark(benchmark_name, platform)
            if benchmark_class:
                if predefine_params:
                    # Registration-time parameters are placed before the context parameters;
                    # presumably so the latter override the former when parsed — confirm
                    # against Benchmark.parse_args.
                    parameters = predefine_params + ' ' + parameters

                benchmark = benchmark_class(benchmark_name, parameters)
                benchmark.run()

        return benchmark

    @classmethod
    def is_benchmark_registered(cls, benchmark_context):
        """Check whether the benchmark is registered or not.

        Args:
            benchmark_context (BenchmarkContext): the benchmark context.

        Return:
            ret (bool): return True if context is valid and benchmark is registered.
        """
        if not cls.is_benchmark_context_valid(benchmark_context):
            return False

        benchmark_name = cls.__get_benchmark_name(benchmark_context)
        platform = benchmark_context.platform

        return cls.benchmarks.get(benchmark_name, {}).get(platform) is not None

    @classmethod
    def __select_benchmark(cls, name, platform):
        """Select benchmark by name and platform.

        Args:
            name (str): internal name of benchmark.
            platform (Platform): Platform type of benchmark.

        Return:
            benchmark_class (Benchmark): class object of benchmark.
            predefine_params (str): predefined parameters which is set when register the benchmark.
            Both are None (and a warning is logged) when no implementation matches.
        """
        if name not in cls.benchmarks or platform not in cls.benchmarks[name]:
            logger.warning('Benchmark has no implementation, name: {}, platform: {}'.format(name, platform))
            return (None, None)

        (benchmark_class, predefine_params) = cls.benchmarks[name][platform]

        return (benchmark_class, predefine_params)

    @classmethod
    def clean_benchmarks(cls):
        """Clean up the benchmark registry."""
        cls.benchmarks.clear()