Unverified Commit 31b6f085 authored by guoshzhao's avatar guoshzhao Committed by GitHub
Browse files

Benchmarks: Code Revision - Support benchmark re-registration, keep the latest one. (#23)



* support benchmark re-registration.

* address comments
Co-authored-by: default avatarGuoshuai Zhao <guzhao@microsoft.com>
parent 5b9b5cc8
...@@ -6,7 +6,6 @@ ...@@ -6,7 +6,6 @@
from typing import Dict from typing import Dict
from superbench.common.utils import logger from superbench.common.utils import logger
from superbench.common.errors import DuplicateBenchmarkRegistrationError
from superbench.benchmarks import Platform, Framework, BenchmarkContext from superbench.benchmarks import Platform, Framework, BenchmarkContext
from superbench.benchmarks.base import Benchmark from superbench.benchmarks.base import Benchmark
...@@ -56,24 +55,17 @@ def register_benchmark(cls, name, class_def, parameters=None, platform=None): ...@@ -56,24 +55,17 @@ def register_benchmark(cls, name, class_def, parameters=None, platform=None):
name, platform_list, platform name, platform_list, platform
) )
) )
if platform in cls.benchmarks[name]:
logger.warning('Duplicate registration - benchmark: {}, platform: {}'.format(name, platform))
if platform not in cls.benchmarks[name]: cls.benchmarks[name][platform] = (class_def, parameters)
cls.benchmarks[name][platform] = (class_def, parameters)
else:
logger.log_and_raise(
DuplicateBenchmarkRegistrationError,
'Duplicate registration - benchmark: {}, platform: {}'.format(name, platform)
)
else: else:
# If not specified the tag, means the # If not specified the tag, means the benchmark works for all platforms.
# benchmark works for all platforms.
for p in Platform: for p in Platform:
if p not in cls.benchmarks[name]: if p in cls.benchmarks[name]:
cls.benchmarks[name][p] = (class_def, parameters) logger.warning('Duplicate registration - benchmark: {}, platform: {}'.format(name, p))
else:
logger.log_and_raise( cls.benchmarks[name][p] = (class_def, parameters)
DuplicateBenchmarkRegistrationError, 'Duplicate registration - benchmark: {}'.format(name)
)
@classmethod @classmethod
def is_benchmark_context_valid(cls, benchmark_context): def is_benchmark_context_valid(cls, benchmark_context):
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Exception types for SuperBench errors."""
class DuplicateBenchmarkRegistrationError(Exception):
"""An error is raised for duplicate benchmark registration."""
pass
...@@ -116,7 +116,6 @@ def create_benchmark(params='--num_steps=8'): ...@@ -116,7 +116,6 @@ def create_benchmark(params='--num_steps=8'):
assert (name) assert (name)
(benchmark_class, predefine_params) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(name, context.platform) (benchmark_class, predefine_params) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(name, context.platform)
assert (benchmark_class) assert (benchmark_class)
BenchmarkRegistry.clean_benchmarks()
return benchmark_class(name, predefine_params + ' ' + context.parameters) return benchmark_class(name, predefine_params + ' ' + context.parameters)
......
...@@ -63,8 +63,6 @@ def test_register_benchmark(): ...@@ -63,8 +63,6 @@ def test_register_benchmark():
context = BenchmarkContext('accumulation', platform) context = BenchmarkContext('accumulation', platform)
assert (BenchmarkRegistry.is_benchmark_registered(context)) assert (BenchmarkRegistry.is_benchmark_registered(context))
BenchmarkRegistry.clean_benchmarks()
# Register the benchmark for CUDA platform if use platform=Platform.CUDA. # Register the benchmark for CUDA platform if use platform=Platform.CUDA.
BenchmarkRegistry.register_benchmark('accumulation-cuda', AccumulationBenchmark, platform=Platform.CUDA) BenchmarkRegistry.register_benchmark('accumulation-cuda', AccumulationBenchmark, platform=Platform.CUDA)
context = BenchmarkContext('accumulation-cuda', Platform.CUDA) context = BenchmarkContext('accumulation-cuda', Platform.CUDA)
...@@ -72,8 +70,6 @@ def test_register_benchmark(): ...@@ -72,8 +70,6 @@ def test_register_benchmark():
context = BenchmarkContext('accumulation-cuda', Platform.ROCM) context = BenchmarkContext('accumulation-cuda', Platform.ROCM)
assert (BenchmarkRegistry.is_benchmark_registered(context) is False) assert (BenchmarkRegistry.is_benchmark_registered(context) is False)
BenchmarkRegistry.clean_benchmarks()
def test_is_benchmark_context_valid(): def test_is_benchmark_context_valid():
"""Test interface BenchmarkRegistry.is_benchmark_context_valid().""" """Test interface BenchmarkRegistry.is_benchmark_context_valid()."""
...@@ -102,8 +98,6 @@ def test_get_benchmark_name(): ...@@ -102,8 +98,6 @@ def test_get_benchmark_name():
name = BenchmarkRegistry._BenchmarkRegistry__get_benchmark_name(context) name = BenchmarkRegistry._BenchmarkRegistry__get_benchmark_name(context)
assert (name == benchmark_names[i]) assert (name == benchmark_names[i])
BenchmarkRegistry.clean_benchmarks()
def test_check_parameters(): def test_check_parameters():
"""Test interface BenchmarkRegistry.check_parameters().""" """Test interface BenchmarkRegistry.check_parameters()."""
...@@ -118,8 +112,6 @@ def test_check_parameters(): ...@@ -118,8 +112,6 @@ def test_check_parameters():
context = BenchmarkContext('accumulation', Platform.CPU, parameters='--lower=1') context = BenchmarkContext('accumulation', Platform.CPU, parameters='--lower=1')
assert (BenchmarkRegistry.check_parameters(context) is False) assert (BenchmarkRegistry.check_parameters(context) is False)
BenchmarkRegistry.clean_benchmarks()
def test_get_benchmark_configurable_settings(): def test_get_benchmark_configurable_settings():
"""Test BenchmarkRegistry interface. """Test BenchmarkRegistry interface.
...@@ -139,8 +131,6 @@ def test_get_benchmark_configurable_settings(): ...@@ -139,8 +131,6 @@ def test_get_benchmark_configurable_settings():
--upper_bound int The upper bound for accumulation.""" --upper_bound int The upper bound for accumulation."""
assert (settings == expected) assert (settings == expected)
BenchmarkRegistry.clean_benchmarks()
def test_launch_benchmark(): def test_launch_benchmark():
"""Test interface BenchmarkRegistry.launch_benchmark().""" """Test interface BenchmarkRegistry.launch_benchmark()."""
...@@ -194,10 +184,8 @@ def test_launch_benchmark(): ...@@ -194,10 +184,8 @@ def test_launch_benchmark():
) )
assert (result == expected) assert (result == expected)
# Failed to launch benchmark. # Failed to launch benchmark due to 'benchmark not found'.
context = BenchmarkContext( context = BenchmarkContext(
'accumulation', Platform.CPU, parameters='--lower_bound=1 --upper_bound=4', framework=Framework.PYTORCH 'accumulation-fail', Platform.CPU, parameters='--lower_bound=1 --upper_bound=4', framework=Framework.PYTORCH
) )
assert (BenchmarkRegistry.check_parameters(context) is False) assert (BenchmarkRegistry.check_parameters(context) is False)
BenchmarkRegistry.clean_benchmarks()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment